├── .gitignore ├── LICENSE ├── README.md ├── demo_images ├── t1_img.png └── t2_img.png ├── setup.py └── torchange ├── __init__.py ├── configs └── changen2 │ ├── s1c1_cstar_vitb_1x256.py │ ├── s1c1_cstar_vitl_1x256.py │ ├── s1c5_cstar_vitb_1x256.py │ └── s9c1_cstar_vitb_1x256.py ├── data ├── __init__.py ├── bitemporal.py ├── hf_builder.py ├── levircd.py ├── s2looking.py ├── second.py └── xView2.py ├── metrics ├── __init__.py ├── bcd.py ├── second.py └── xview2.py ├── models ├── __init__.py ├── changemask.py ├── changen2 │ ├── README.md │ ├── __init__.py │ ├── _changestar_1x256.py │ ├── change_event_simulation.py │ └── rsdit.py ├── changesparse.py ├── changestar2.py ├── changestar_1xd.py └── segment_any_change │ ├── README.md │ ├── __init__.py │ ├── anychange.py │ ├── base.py │ ├── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py │ ├── simple_maskgen.py │ └── viz.py └── module ├── __init__.py ├── _sam_vit.py └── farseg.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## torchange - A Unified Change Representation Learning Benchmark Library 2 | 3 | torchange aims to provide out-of-the-box contemporary spatiotemporal change model implementations, standard metrics, and datasets, in pursuit of benchmarking and reproducibility. 4 | 5 | > This project is still under development. Other repositories will be gradually merged into ```torchange```. 6 | 7 | > The ```torchange``` API is in beta and may change in the near future. 8 | 9 | > Note: ```torchange``` is designed to provide straightforward implementations, thus we will adopt a single file for each algorithm without any modular encapsulation. 10 | Algorithms released before 2024 will be transferred here from our internal codebase. 11 | If you encounter any bugs, please report them in the issue section. Please be patient with new releases and bug fixes, as this is a significant burden for a single maintainer.
12 | Technical consultations are only accepted via email inquiry. 13 | 14 | > Our default training engine is [ever](https://github.com/Z-Zheng/ever/). 15 | 16 | ### News 17 | 18 | - 2024/06, we launched the ``torchange`` project. 19 | 20 | ### Features 21 | 22 | - Out-of-the-box and straightforward model implementations 23 | - Highly optimized implementations, e.g., multi-GPU sync Dice loss. 24 | - Multi-GPU metric computation and score tracker, supporting wandb. 25 | - Including the latest research advancements in ``Change``, not just architecture games. 26 | 27 | ### Installation 28 | 29 | 30 | #### nightly version (master) 31 | ```bash 32 | pip install -U --no-deps --force-reinstall git+https://github.com/Z-Zheng/pytorch-change-models 33 | ``` 34 | 35 | ### Model zoo (in progress) 36 | 37 | This is also a tutorial for junior researchers interested in contemporary change detection. 38 | 39 | 40 | #### 0. change modeling principle 41 | - (PCM) Unifying Remote Sensing Change Detection via Deep Probabilistic Change Models: from Principles, Models to Applications, ISPRS P&RS 2024. [[`Paper`](https://www.sciencedirect.com/science/article/pii/S0924271624002624)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/changesparse.py)] 42 | - (GPCM) Scalable Multi-Temporal Remote Sensing Change Data Generation via Simulating Stochastic Change Process, ICCV 2023 [[`Paper`](https://arxiv.org/pdf/2309.17031)], [[`Code`](https://github.com/Z-Zheng/Changen)] 43 | 44 | 45 | #### 1.0 unified architecture 46 | - (ChangeStar) Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery, ICCV 2021. [[`Paper`](https://arxiv.org/abs/2108.07002)], [[`Project`](https://zhuozheng.top/changestar/)], [[`Code`](https://github.com/Z-Zheng/ChangeStar)] 47 | - (ChangeStar2) Single-Temporal Supervised Learning for Universal Remote Sensing Change Detection, IJCV 2024.
[[`Paper`](https://link.springer.com/article/10.1007/s11263-024-02141-4)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/changestar2.py)] 48 | - (ChangeSparse) Unifying Remote Sensing Change Detection via Deep Probabilistic Change Models: from Principles, Models to Applications, ISPRS P&RS 2024. [[`Paper`](https://www.sciencedirect.com/science/article/pii/S0924271624002624)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/changesparse.py)] 49 | 50 | #### 1.1 one-to-many semantic change detection 51 | - (ChangeOS) Building damage assessment for rapid disaster response with a deep object-based semantic change detection framework: from natural disasters to man-made disasters, RSE 2021. [[`Paper`](https://www.sciencedirect.com/science/article/pii/S0034425721003564)], [[`Code`](https://github.com/Z-Zheng/ChangeOS)] 52 | 53 | #### 1.2 many-to-many semantic change detection 54 | - (ChangeMask) ChangeMask: Deep Multi-task Encoder-Transformer-Decoder Architecture for Semantic Change Detection, ISPRS P&RS 2022. [[`Paper`](https://www.sciencedirect.com/science/article/pii/S0924271621002835)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/changemask.py)] 55 | 56 | 57 | #### 2.0 learning change representation via single-temporal supervision 58 | - (STAR) Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery, ICCV 2021. [[`Paper`](https://arxiv.org/abs/2108.07002)], [[`Project`](https://zhuozheng.top/changestar/)], [[`Code`](https://github.com/Z-Zheng/ChangeStar)] 59 | - (G-STAR) Single-Temporal Supervised Learning for Universal Remote Sensing Change Detection, IJCV 2024. 
[[`Paper`](https://link.springer.com/article/10.1007/s11263-024-02141-4)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/changestar2.py)] 60 | - (Changen) Scalable Multi-Temporal Remote Sensing Change Data Generation via Simulating Stochastic Change Process, ICCV 2023 [[`Paper`](https://arxiv.org/pdf/2309.17031)], [[`Code`](https://github.com/Z-Zheng/Changen)] 61 | - (Changen2) Changen2: Multi-Temporal Remote Sensing Generative Change Foundation Model, IEEE TPAMI 2024 [[`Paper`](https://arxiv.org/abs/2406.17998)],[[`Code`](https://github.com/Z-Zheng/pytorch-change-models/tree/main/torchange/models/changen2)] 62 | 63 | 64 | #### 2.1 change data synthesis from single-temporal data 65 | - (Changen) Scalable Multi-Temporal Remote Sensing Change Data Generation via Simulating Stochastic Change Process, ICCV 2023 [[`Paper`](https://arxiv.org/pdf/2309.17031)], [[`Code`](https://github.com/Z-Zheng/Changen)] 66 | - (Changen2) Changen2: Multi-Temporal Remote Sensing Generative Change Foundation Model, IEEE TPAMI 2024 [[`Paper`](https://arxiv.org/abs/2406.17998)],[[`Code`](https://github.com/Z-Zheng/pytorch-change-models/tree/main/torchange/models/changen2)] 67 | 68 | 69 | #### 3.0 zero-shot change detection 70 | - (AnyChange) Segment Any Change, NeurIPS 2024 [[`Paper`](https://arxiv.org/abs/2402.01188)], [[`Code`](https://github.com/Z-Zheng/pytorch-change-models/blob/main/torchange/models/segment_any_change)] 71 | - (Changen2) Changen2: Multi-Temporal Remote Sensing Generative Change Foundation Model, IEEE TPAMI 2024 [[`Paper`](https://arxiv.org/abs/2406.17998)],[[`Code`](https://github.com/Z-Zheng/pytorch-change-models/tree/main/torchange/models/changen2)] 72 | 73 | 74 | ### License 75 | This project is under the Apache 2.0 License. See [LICENSE](https://github.com/Z-Zheng/pytorch-change-models/blob/main/LICENSE) for details. 
76 | 77 | If you find it useful in your work — whether in research, demos, products, or educational materials — we’d love to hear from you! 78 | 79 | Sharing your use case helps us: 80 | 81 | 📌 Understand real-world impact 82 | 83 | 📣 Highlight your work in our future talks or papers 84 | 85 | 🚀 Improve future versions of this project 86 | 87 | 📬 Please drop us a short message at: 88 | zhuozheng@cs.stanford.edu 89 | Feel free to include a brief description, your institution/company, a link to your work (if available), or any suggestions. 90 | 91 | 92 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=Z-Zheng/pytorch-change-models) -------------------------------------------------------------------------------- /demo_images/t1_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/9c564feecfb577069b9a9bc5859264f768581046/demo_images/t1_img.png -------------------------------------------------------------------------------- /demo_images/t2_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/9c564feecfb577069b9a9bc5859264f768581046/demo_images/t2_img.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from os import path 3 | 4 | 5 | def get_version(): 6 | init_py_path = path.join(path.abspath(path.dirname(__file__)), "torchange", 7 | "__init__.py") 8 | init_py = open(init_py_path, "r").readlines() 9 | version_line = [l.strip() for l in init_py if l.startswith("__version__")][0] 10 | version = version_line.split("=")[-1].strip().strip("'\"") 11 | return version 12 | 13 | 14 | install_requires = [ 15 | 'numpy', 16 | 'albumentations>=0.4.2', 17 | 'tifffile', 18 | 'scikit-image', 19 | 'tqdm', 
20 | 'einops', 21 | 'timm', 22 | 'datasets[vision]', 23 | ] 24 | setup( 25 | name='torchange', 26 | version=get_version(), 27 | description='pytorch-change-models', 28 | keywords='Remote Sensing, ' 29 | 'Earth Vision, ' 30 | 'Deep Learning, ' 31 | 'Change Detection, ' 32 | 'Change Data Generation, ', 33 | packages=find_packages(exclude=['projects', 'tools']), 34 | classifiers=[ 35 | 'Development Status :: 4 - Beta', 36 | 'Operating System :: OS Independent', 37 | 'Programming Language :: Python :: 3.8', 38 | 'Programming Language :: Python :: 3.9', 39 | 'Programming Language :: Python :: 3.10', 40 | 'Programming Language :: Python :: 3.11', 41 | 'Topic :: Utilities', 42 | ], 43 | url='https://github.com/Z-Zheng/pytorch-change-models', 44 | author='Zhuo Zheng', 45 | author_email='zhuozheng@cs.stanford.edu', 46 | license='CC-BY-NC 4.0', 47 | setup_requires=[], 48 | tests_require=[], 49 | install_requires=install_requires, 50 | zip_safe=False) 51 | -------------------------------------------------------------------------------- /torchange/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | __version__ = "0.0.1" 7 | 8 | import importlib 9 | import pkgutil 10 | from pathlib import Path 11 | 12 | 13 | def _import_dataclass(): 14 | data_pkg = __name__ + ".data" 15 | package_dir = Path(__file__).parent / "data" 16 | for module_info in pkgutil.iter_modules([str(package_dir)]): 17 | if not module_info.name.startswith("_"): 18 | full_module_name = f"{data_pkg}.{module_info.name}" 19 | importlib.import_module(full_module_name) 20 | 21 | 22 | _import_dataclass() 23 | -------------------------------------------------------------------------------- /torchange/configs/changen2/s1c1_cstar_vitb_1x256.py: -------------------------------------------------------------------------------- 1 | D = 256 2 | config = dict( 3 | model=dict( 4 | type='ChangeStar1xd', 5 | params=dict( 6 | GLOBAL=dict(weight=dict(path=None)), 7 | encoder=dict( 8 | bitemporal_forward=True, 9 | type='SAMEncoderFarSeg', 10 | params=dict( 11 | checkpoint=None, 12 | vit_type='vit_b', 13 | fpn_channels=D, 14 | out_channels=D, 15 | freeze_vit=False, 16 | ), 17 | ), 18 | head=dict(num_semantic_classes=1), 19 | ) 20 | ), 21 | ) 22 | -------------------------------------------------------------------------------- /torchange/configs/changen2/s1c1_cstar_vitl_1x256.py: -------------------------------------------------------------------------------- 1 | D = 256 2 | config = dict( 3 | model=dict( 4 | type='ChangeStar1xd', 5 | params=dict( 6 | GLOBAL=dict(weight=dict(path=None)), 7 | encoder=dict( 8 | bitemporal_forward=True, 9 | type='SAMEncoderFarSeg', 10 | params=dict( 11 | checkpoint=None, 12 | vit_type='vit_l', 13 | fpn_channels=D, 14 | out_channels=D, 15 | freeze_vit=False, 16 | ), 17 | ), 18 | head=dict(num_semantic_classes=1), 19 | ) 20 | ), 21 | ) 22 | -------------------------------------------------------------------------------- /torchange/configs/changen2/s1c5_cstar_vitb_1x256.py: -------------------------------------------------------------------------------- 1 | D = 256 2 | 3 | config = dict( 
4 | model=dict( 5 | type='ChangeStar1xd', 6 | params=dict( 7 | GLOBAL=dict(weight=dict(path=None)), 8 | encoder=dict( 9 | bitemporal_forward=True, 10 | type='SAMEncoderFarSeg', 11 | params=dict( 12 | checkpoint=None, 13 | vit_type='vit_b', 14 | fpn_channels=D, 15 | out_channels=D, 16 | freeze_vit=False, 17 | ), 18 | 19 | ), 20 | head=dict(num_semantic_classes=1, num_change_classes=5), 21 | ) 22 | ), 23 | ) 24 | -------------------------------------------------------------------------------- /torchange/configs/changen2/s9c1_cstar_vitb_1x256.py: -------------------------------------------------------------------------------- 1 | D = 256 2 | 3 | config = dict( 4 | model=dict( 5 | type='ChangeStar1xd', 6 | params=dict( 7 | GLOBAL=dict(weight=dict(path=None)), 8 | encoder=dict( 9 | bitemporal_forward=True, 10 | type='SAMEncoderFarSeg', 11 | params=dict( 12 | checkpoint=None, 13 | vit_type='vit_b', 14 | fpn_channels=D, 15 | out_channels=D, 16 | freeze_vit=False, 17 | ), 18 | ), 19 | head=dict(num_semantic_classes=9), 20 | ) 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /torchange/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from .hf_builder import build_dataset -------------------------------------------------------------------------------- /torchange/data/bitemporal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from pathlib import Path 8 | 9 | import albumentations as A 10 | import ever as er 11 | import numpy as np 12 | import torch 13 | from skimage.io import imread 14 | from torch.utils.data import Dataset 15 | from datasets import load_dataset, concatenate_datasets 16 | from typing import Dict 17 | 18 | PRE = 'pre' 19 | TWISE = 'temporalwise' 20 | POST = 'post' 21 | 22 | 23 | class BinarizeMask(A.DualTransform): 24 | def __init__(self): 25 | super().__init__(True, 1.0) 26 | 27 | def apply(self, img, *args, **params): 28 | return img 29 | 30 | def apply_to_mask(self, mask, *args, **params): 31 | return (mask > 0).astype(np.float32) 32 | 33 | def get_transform_init_args_names(self): 34 | return () 35 | 36 | 37 | def check_name(t1_image_name, other): 38 | assert t1_image_name == other, f'expected {t1_image_name}, but {other}' 39 | 40 | 41 | def to_bitemporal_compose(compose): 42 | assert isinstance(compose, A.Compose) 43 | 44 | return A.Compose(compose.transforms, additional_targets={ 45 | 't2_image': 'image', 46 | 't2_mask': 'mask', 47 | 'change': 'mask' 48 | }) 49 | 50 | 51 | def data_transform(T, data) -> Dict: 52 | if isinstance(T, dict): 53 | data = T[PRE](**data) 54 | 55 | if not isinstance(T[TWISE], A.NoOp): 56 | if 'mask' in data: 57 | t1_data = {'image': data['image'], 'mask': data['mask']} 58 | else: 59 | t1_data = {'image': data['image']} 60 | t1_data = T[TWISE](**t1_data) 61 | data.update(t1_data) 62 | 63 | if 't2_mask' in data: 64 | t2_data = {'image': data['t2_image'], 'mask': data['t2_mask']} 65 | else: 66 | t2_data = {'image': data['t2_image']} 67 | t2_data = T[TWISE](**t2_data) 68 | t2_data = {f't2_{k}': v for k, v in t2_data.items()} 69 | data.update(t2_data) 70 | 71 | data = T[POST](**data) 72 | else: 73 | data = T(data) 74 | return data 75 | 76 | 77 | class BitemporalDataset(Dataset): 78 | def __init__( 79 | self, 80 | t1_image_fps, 81 | t2_image_fps, 82 | t1_mask_fps=None, 83 | t2_mask_fps=None, 84 | change_fps=None, 85 | transform=None, 86 | 
name_checker=check_name 87 | ): 88 | self.t1_image_fps = t1_image_fps 89 | self.t2_image_fps = t2_image_fps 90 | self.t1_mask_fps = t1_mask_fps 91 | self.t2_mask_fps = t2_mask_fps 92 | self.change_fps = change_fps 93 | 94 | self.name_checker = name_checker 95 | 96 | self.t = { 97 | PRE: A.NoOp(), 98 | TWISE: A.NoOp(), 99 | POST: A.NoOp(), 100 | } 101 | if isinstance(transform, A.Compose): 102 | self.t[POST] = to_bitemporal_compose(transform) 103 | elif isinstance(transform, dict): 104 | for k, v in transform.items(): 105 | assert k in [PRE, TWISE, POST] 106 | if isinstance(v, A.Compose): 107 | v = to_bitemporal_compose(v) 108 | self.t[k] = v 109 | else: 110 | self.t = transform 111 | 112 | def __getitem__(self, idx): 113 | base_name = Path(self.t1_image_fps[idx]).name 114 | self.name_checker(base_name, Path(self.t2_image_fps[idx]).name) 115 | 116 | img1 = imread(self.t1_image_fps[idx]) 117 | img2 = imread(self.t2_image_fps[idx]) 118 | 119 | data = { 120 | 'image': img1, 121 | 't2_image': img2 122 | } 123 | 124 | if self.t1_mask_fps: 125 | self.name_checker(base_name, Path(self.t1_mask_fps[idx]).name) 126 | msk1 = imread(self.t1_mask_fps[idx]) 127 | data['mask'] = msk1 128 | 129 | if self.t2_mask_fps: 130 | self.name_checker(base_name, Path(self.t2_mask_fps[idx]).name) 131 | msk2 = imread(self.t2_mask_fps[idx]) 132 | data['t2_mask'] = msk2 133 | 134 | if self.change_fps: 135 | self.name_checker(base_name, Path(self.change_fps[idx]).name) 136 | cmask = imread(self.change_fps[idx]) 137 | data['change'] = cmask 138 | 139 | data = data_transform(self.t, data) 140 | 141 | img = torch.cat([data['image'], data['t2_image']], dim=0) 142 | 143 | masks = [] 144 | if 'mask' in data: 145 | masks.append(data['mask']) 146 | 147 | if 't2_mask' in data: 148 | masks.append(data['t2_mask']) 149 | 150 | if 'change' in data: 151 | masks.append(data['change']) 152 | 153 | ann = dict( 154 | masks=masks, 155 | image_filename=str(Path(self.t1_image_fps[idx]).name) 156 | ) 157 | 158 | return 
img, ann 159 | 160 | def __len__(self): 161 | return len(self.t1_image_fps) 162 | 163 | 164 | @er.registry.DATASET.register() 165 | class HFBitemporalDataset(er.ERDataset): 166 | def __init__(self, config): 167 | super().__init__(config) 168 | ds = [] 169 | for s in self.cfg.splits: 170 | d = load_dataset(self.cfg.hf_repo_name, split=s) 171 | ds.append(d) 172 | hfd = concatenate_datasets(ds) if len(ds) > 1 else ds[0] 173 | self.hfd = hfd.with_format('numpy') 174 | 175 | transform = self.cfg.transform 176 | self.t = { 177 | PRE: A.NoOp(), 178 | TWISE: A.NoOp(), 179 | POST: A.NoOp(), 180 | } 181 | if isinstance(transform, A.Compose): 182 | self.t[POST] = to_bitemporal_compose(transform) 183 | elif isinstance(transform, dict): 184 | for k, v in transform.items(): 185 | assert k in [PRE, TWISE, POST] 186 | if isinstance(v, A.Compose): 187 | v = to_bitemporal_compose(v) 188 | self.t[k] = v 189 | else: 190 | self.t = transform 191 | 192 | def _slice_data(self, data, tile_slice): 193 | if tile_slice is None: 194 | return data 195 | 196 | x1, y1, x2, y2 = tile_slice 197 | return data[y1:y2, x1:x2] 198 | 199 | def compute_tile_slice(self, idx): 200 | return idx, None 201 | 202 | def __getitem__(self, idx): 203 | idx, tile_slice = self.compute_tile_slice(idx) 204 | 205 | example = self.hfd[idx] 206 | img1 = self._slice_data(example['t1_image'], tile_slice) 207 | img2 = self._slice_data(example['t2_image'], tile_slice) 208 | 209 | data = { 210 | 'image': img1, 211 | 't2_image': img2 212 | } 213 | 214 | if 't1_mask' in example: 215 | data['mask'] = self._slice_data(example['t1_mask'], tile_slice) 216 | 217 | if 't2_mask' in example: 218 | data['t2_mask'] = self._slice_data(example['t2_mask'], tile_slice) 219 | 220 | if 'change_mask' in example: 221 | data['change'] = self._slice_data(example['change_mask'], tile_slice) 222 | 223 | data = data_transform(self.t, data) 224 | 225 | img = torch.cat([data['image'], data['t2_image']], dim=0) 226 | 227 | masks = [] 228 | if 'mask' in 
data: 229 | masks.append(data['mask']) 230 | 231 | if 't2_mask' in data: 232 | masks.append(data['t2_mask']) 233 | 234 | if 'change' in data: 235 | masks.append(data['change']) 236 | 237 | ann = dict( 238 | masks=masks, 239 | image_filename=str(example['image_name']) 240 | ) 241 | 242 | return img, ann 243 | 244 | def __len__(self): 245 | return len(self.hfd) 246 | 247 | def set_default_config(self): 248 | self.cfg.update(dict( 249 | hf_repo_name=None, 250 | splits=[], 251 | transform=None, 252 | )) 253 | -------------------------------------------------------------------------------- /torchange/data/hf_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from torchange.data.bitemporal import HFBitemporalDataset 7 | 8 | HF_DATASETS = { 9 | 'levircd': 'EVER-Z/torchange_levircd', 10 | 's2looking': 'EVER-Z/torchange_s2looking', 11 | 'second': 'EVER-Z/torchange_second', 12 | 'xView2': 'EVER-Z/torchange_xView2', 13 | 'Changen2-S1-15k': 'EVER-Z/torchange_Changen2-S1-15k', 14 | 'Changen2-S9-27k': 'EVER-Z/torchange_Changen2-S9-27k', 15 | } 16 | 17 | 18 | def build_dataset(dataset_name, splits, transform, **kwargs): 19 | assert dataset_name in HF_DATASETS 20 | 21 | if dataset_name == 'xView2': 22 | from torchange.data.xView2 import HFxView2 23 | assert 'crop_size' in kwargs 24 | assert 'stride' in kwargs 25 | assert 'training' in kwargs 26 | 27 | return HFxView2(dict( 28 | hf_repo_name=HF_DATASETS[dataset_name], 29 | splits=splits, 30 | transform=transform, 31 | crop_size=kwargs['crop_size'], 32 | stride=kwargs['stride'], 33 | training=kwargs['training'], 34 | )) 35 | 36 | dataset = HFBitemporalDataset(dict( 37 | hf_repo_name=HF_DATASETS[dataset_name], 38 | splits=splits, 39 | transform=transform, 40 | )) 41 | 42 | return 
dataset 43 | -------------------------------------------------------------------------------- /torchange/data/levircd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import ever as er 8 | import glob 9 | import math 10 | import os 11 | from torchange.data.bitemporal import BitemporalDataset 12 | 13 | 14 | @er.registry.DATASET.register() 15 | class LEVIRCD(BitemporalDataset, er.ERDataset): 16 | def __init__(self, cfg): 17 | er.ERDataset.__init__(self, cfg) 18 | 19 | A_image_fps = sorted(glob.glob(os.path.join(self.cfg.dataset_dir, 'A', '*.png'))) 20 | N = len(A_image_fps) 21 | if self.cfg.subsample_ratio < 1.0: 22 | split = math.floor(len(A_image_fps) * self.cfg.subsample_ratio) 23 | A_image_fps = A_image_fps[:split] 24 | 25 | A_image_fps = math.ceil(N / len(A_image_fps)) * A_image_fps 26 | A_image_fps = A_image_fps[:N] 27 | er.info(f'use subsample ratio of {self.cfg.subsample_ratio}, {split} training samples') 28 | 29 | B_image_fps = [fp.replace('/A/', '/B/') for fp in A_image_fps] 30 | gt_fps = [fp.replace('/A', '/label') for fp in A_image_fps] 31 | 32 | super().__init__( 33 | t1_image_fps=A_image_fps, 34 | t2_image_fps=B_image_fps, 35 | change_fps=gt_fps, 36 | transform=self.cfg.transforms 37 | ) 38 | 39 | def set_default_config(self): 40 | self.cfg.update(dict( 41 | dataset_dir=None, 42 | transforms=None, 43 | subsample_ratio=1.0 44 | )) 45 | -------------------------------------------------------------------------------- /torchange/data/s2looking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import ever as er 8 | from torchange.data.bitemporal import BitemporalDataset 9 | import glob 10 | import math 11 | import os 12 | 13 | 14 | @er.registry.DATASET.register() 15 | class S2Looking(BitemporalDataset, er.ERDataset): 16 | def __init__(self, cfg): 17 | er.ERDataset.__init__(self, cfg) 18 | A_image_fps = sorted(glob.glob(os.path.join(self.cfg.dataset_dir, 'Image1', '*.png'))) 19 | N = len(A_image_fps) 20 | if self.cfg.subsample_ratio < 1.0: 21 | split = math.floor(len(A_image_fps) * self.cfg.subsample_ratio) 22 | A_image_fps = A_image_fps[:split] 23 | 24 | A_image_fps = math.ceil(N / len(A_image_fps)) * A_image_fps 25 | A_image_fps = A_image_fps[:N] 26 | er.info(f'use subsample ratio of {self.cfg.subsample_ratio}, {split} training samples') 27 | 28 | B_image_fps = [fp.replace('Image1', 'Image2') for fp in A_image_fps] 29 | gt_fps = [fp.replace('Image1', 'label') for fp in A_image_fps] 30 | 31 | super().__init__( 32 | t1_image_fps=A_image_fps, 33 | t2_image_fps=B_image_fps, 34 | change_fps=gt_fps, 35 | transform=self.cfg.transforms 36 | ) 37 | 38 | def set_default_config(self): 39 | self.cfg.update(dict( 40 | dataset_dir=None, 41 | transforms=None, 42 | subsample_ratio=1.0 43 | )) 44 | -------------------------------------------------------------------------------- /torchange/data/second.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import ever as er 8 | from torchange.data.bitemporal import BitemporalDataset 9 | from pathlib import Path 10 | import torch 11 | 12 | 13 | @er.registry.DATASET.register() 14 | class BinarySECOND(BitemporalDataset, er.ERDataset): 15 | def __init__(self, cfg): 16 | er.ERDataset.__init__(self, cfg) 17 | root_dir = Path(self.cfg.dataset_dir) 18 | img1_fps = [str(fp) for fp in (root_dir / 'im1').glob('*.png')] 19 | img2_fps = [fp.replace('im1', 'im2') for fp in img1_fps] 20 | label1_fps = [fp.replace('im1', 'label1_wo_cm') for fp in img1_fps] 21 | 22 | super().__init__( 23 | t1_image_fps=img1_fps, 24 | t2_image_fps=img2_fps, 25 | change_fps=label1_fps, 26 | transform=self.cfg.transforms 27 | ) 28 | 29 | def __getitem__(self, idx): 30 | img, ann = super().__getitem__(idx) 31 | # to binary change mask 32 | ann['masks'][-1] = (ann['masks'][-1] > 0).float() 33 | return img, ann 34 | 35 | def set_default_config(self): 36 | self.config.update(dict( 37 | dataset_dir=None, 38 | transforms=None, 39 | )) 40 | 41 | 42 | @er.registry.DATASET.register() 43 | class SECOND(BitemporalDataset, er.ERDataset): 44 | def __init__(self, cfg): 45 | er.ERDataset.__init__(self, cfg) 46 | root_dir = Path(self.cfg.dataset_dir) 47 | img1_fps = [str(fp) for fp in (root_dir / 'im1').glob('*.png')] 48 | img2_fps = [fp.replace('im1', 'im2') for fp in img1_fps] 49 | label1_fps = [fp.replace('im1', 'label1_wo_cm') for fp in img1_fps] 50 | label2_fps = [fp.replace('im1', 'label2_wo_cm') for fp in img1_fps] 51 | 52 | super().__init__( 53 | t1_image_fps=img1_fps, 54 | t2_image_fps=img2_fps, 55 | t1_mask_fps=label1_fps, 56 | t2_mask_fps=label2_fps, 57 | transform=self.cfg.transforms 58 | ) 59 | 60 | def __getitem__(self, idx): 61 | img, ann = super().__getitem__(idx) 62 | # append binary change mask 63 | ann['masks'].append((ann['masks'][0] > 0).float()) 64 | # convert 0-6 to 255, 0-5, where 255 will be ignored. 
65 | ann['masks'][0] = ann['masks'][0].to(torch.uint8) - 1 66 | ann['masks'][1] = ann['masks'][1].to(torch.uint8) - 1 67 | 68 | return img, ann 69 | 70 | def set_default_config(self): 71 | self.config.update(dict( 72 | dataset_dir=None, 73 | transforms=None, 74 | )) 75 | -------------------------------------------------------------------------------- /torchange/data/xView2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | import numpy as np 10 | from pathlib import Path 11 | import ever as er 12 | 13 | from torchange.data.bitemporal import HFBitemporalDataset, BitemporalDataset, data_transform 14 | from skimage.io import imread 15 | import torch 16 | from tqdm import tqdm 17 | 18 | 19 | @er.registry.DATASET.register() 20 | class xView2(BitemporalDataset, er.ERDataset): 21 | def __init__(self, cfg): 22 | er.ERDataset.__init__(self, cfg) 23 | dataset_dir = self.cfg.dataset_dir 24 | pre_img_fps, post_img_fps, pre_gt_fps, post_gt_fps = [], [], [], [] 25 | splits = [] 26 | if isinstance(dataset_dir, str): 27 | splits.append(Path(dataset_dir).name) 28 | pre_img_fps, post_img_fps, pre_gt_fps, post_gt_fps = self.parse_dataset_dir(self.cfg.dataset_dir) 29 | elif isinstance(dataset_dir, (tuple, list)): 30 | for dd in dataset_dir: 31 | splits.append(Path(dd).name) 32 | a, b, c, d = self.parse_dataset_dir(dd) 33 | pre_img_fps += a 34 | post_img_fps += b 35 | pre_gt_fps += c 36 | post_gt_fps += d 37 | else: 38 | raise ValueError 39 | 40 | if self.cfg.training: 41 | self.tiles = er.sliding_window((1024, 1024), self.cfg.crop_size, self.cfg.stride) 42 | else: 43 | self.tiles = er.sliding_window((1024, 1024), 1024, 1024) 44 | 45 | if self.cfg.training: 46 | split_name = '_'.join(splits) 47 | indices_file = 
Path(os.curdir) / f'xView2_{split_name}_valid_indices_p{self.cfg.crop_size}_s{self.cfg.stride}.npy' 48 | if indices_file.exists(): 49 | self.valid_patch_indices = np.load(str(indices_file)) 50 | else: 51 | valid = np.ones([len(post_gt_fps) * self.tiles.shape[0]], dtype=np.uint8) 52 | for img_idx in tqdm(range(len(post_gt_fps)), disable=not er.dist.is_main_process()): 53 | t2_mask = imread(post_gt_fps[img_idx]).astype(np.float32) 54 | for tile_idx in range(self.tiles.shape[0]): 55 | x1, y1, x2, y2 = self.tiles[tile_idx] 56 | sub = t2_mask[y1:y2, x1:x2] 57 | sub[sub == 255] = 0 58 | if np.sum(sub) == 0: 59 | valid[img_idx * self.tiles.shape[0] + tile_idx] = 0 60 | self.valid_patch_indices = np.nonzero(valid.astype(bool))[0] 61 | np.save(str(indices_file), self.valid_patch_indices) 62 | else: 63 | self.valid_patch_indices = np.arange(len(post_gt_fps)) 64 | 65 | super().__init__( 66 | t1_image_fps=pre_img_fps, 67 | t2_image_fps=post_img_fps, 68 | t1_mask_fps=pre_gt_fps, 69 | t2_mask_fps=post_gt_fps, 70 | transform=self.cfg.transforms, 71 | name_checker=lambda x, y: True, 72 | ) 73 | 74 | def parse_dataset_dir(self, dataset_dir): 75 | dataset_dir = Path(dataset_dir) 76 | img_dir = dataset_dir / 'images' 77 | tgt_dir = dataset_dir / 'targets' 78 | 79 | pre_gt_fps = list(tgt_dir.glob('*_pre_*.png')) 80 | post_gt_fps = [tgt_dir / fp.name.replace('pre', 'post') for fp in pre_gt_fps] 81 | 82 | pre_img_fps = [img_dir / fp.name.replace('_target.png', '.png') for fp in pre_gt_fps] 83 | post_img_fps = [img_dir / fp.name.replace('_target.png', '.png') for fp in post_gt_fps] 84 | 85 | return pre_img_fps, post_img_fps, pre_gt_fps, post_gt_fps 86 | 87 | def __getitem__(self, idx): 88 | idx = self.valid_patch_indices[idx] 89 | img_idx = idx // self.tiles.shape[0] 90 | tile_idx = idx % self.tiles.shape[0] 91 | x1, y1, x2, y2 = self.tiles[tile_idx] 92 | 93 | img1 = imread(self.t1_image_fps[img_idx])[y1:y2, x1:x2] 94 | img2 = imread(self.t2_image_fps[img_idx])[y1:y2, x1:x2] 95 | 96 | 
data = { 97 | 'image': img1, 98 | 't2_image': img2 99 | } 100 | 101 | if self.t1_mask_fps: 102 | msk1 = imread(self.t1_mask_fps[img_idx])[y1:y2, x1:x2] 103 | data['mask'] = msk1 104 | 105 | if self.t2_mask_fps: 106 | msk2 = imread(self.t2_mask_fps[img_idx])[y1:y2, x1:x2] 107 | if self.cfg.ignore_t2_bg: 108 | msk2[msk2 == 0] = 255 109 | data['t2_mask'] = msk2 110 | 111 | data = data_transform(self.t, data) 112 | 113 | img = torch.cat([data['image'], data['t2_image']], dim=0) 114 | 115 | masks = [] 116 | if 'mask' in data: 117 | masks.append(data['mask']) 118 | 119 | if 't2_mask' in data: 120 | masks.append(data['t2_mask']) 121 | 122 | ann = dict( 123 | masks=masks, 124 | image_filename=str(Path(self.t1_image_fps[img_idx]).name) 125 | ) 126 | return img, ann 127 | 128 | def __len__(self): 129 | return self.valid_patch_indices.shape[0] 130 | 131 | def set_default_config(self): 132 | self.config.update(dict( 133 | dataset_dir=None, 134 | crop_size=512, 135 | stride=256, 136 | training=True, 137 | ignore_t2_bg=False, 138 | )) 139 | 140 | 141 | @er.registry.DATASET.register() 142 | class HFxView2(HFBitemporalDataset): 143 | def __init__(self, cfg): 144 | super().__init__(cfg) 145 | if self.cfg.training: 146 | self.tiles = er.sliding_window((1024, 1024), self.cfg.crop_size, self.cfg.stride) 147 | else: 148 | self.tiles = er.sliding_window((1024, 1024), 1024, 1024) 149 | 150 | self.build_index() 151 | 152 | def build_index(self): 153 | if self.cfg.training: 154 | split_name = '_'.join(self.cfg.splits) 155 | basename = f'HFxView2_{split_name}_valid_indices_p{self.cfg.crop_size}_s{self.cfg.stride}.npy' 156 | indices_file = Path(os.curdir) / basename 157 | if indices_file.exists(): 158 | self.valid_patch_indices = np.load(str(indices_file)) 159 | else: 160 | valid = np.ones([len(self.hfd) * self.tiles.shape[0]], dtype=np.uint8) 161 | t2_masks = self.hfd['t2_mask'] 162 | for img_idx in tqdm(range(len(self.hfd)), disable=not er.dist.is_main_process()): 163 | t2_mask = 
np.array(t2_masks[img_idx]).astype(np.float32) 164 | for tile_idx in range(self.tiles.shape[0]): 165 | x1, y1, x2, y2 = self.tiles[tile_idx] 166 | sub = t2_mask[y1:y2, x1:x2] 167 | sub[sub == 255] = 0 168 | if np.sum(sub) == 0: 169 | valid[img_idx * self.tiles.shape[0] + tile_idx] = 0 170 | self.valid_patch_indices = np.nonzero(valid.astype(bool))[0] 171 | np.save(str(indices_file), self.valid_patch_indices) 172 | else: 173 | self.valid_patch_indices = np.arange(len(self.hfd)) 174 | 175 | def compute_tile_slice(self, idx): 176 | idx = self.valid_patch_indices[idx] 177 | img_idx = idx // self.tiles.shape[0] 178 | tile_idx = idx % self.tiles.shape[0] 179 | return int(img_idx), self.tiles[tile_idx] 180 | 181 | def __len__(self): 182 | return self.valid_patch_indices.shape[0] 183 | 184 | def set_default_config(self): 185 | super().set_default_config() 186 | self.cfg.update(dict( 187 | hf_repo_name='EVER-Z/torchange_xView2', 188 | crop_size=512, 189 | stride=256, 190 | training=True, 191 | )) 192 | -------------------------------------------------------------------------------- /torchange/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /torchange/metrics/bcd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os.path 8 | 9 | import gc 10 | import torch 11 | import ever as er 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | 16 | @torch.no_grad() 17 | def binary_change_detection_evaluate(model, dataloader, log_dir=None, logger=None, class_names=None): 18 | model.eval() 19 | pm = er.metric.PixelMetric(2, log_dir, logger=logger, class_names=class_names) 20 | 21 | for img, gt in tqdm(dataloader, disable=not er.dist.is_main_process()): 22 | img = img.to(er.auto_device()) 23 | predictions = model(img) 24 | 25 | pr_change = (predictions['change_prediction'] > 0.5).cpu() 26 | pr_change = pr_change.numpy().astype(np.uint8) 27 | gt_change = gt['masks'][-1] 28 | gt_change = gt_change.numpy() 29 | y_true = gt_change.ravel() 30 | y_pred = pr_change.ravel() 31 | 32 | y_true = np.where(y_true > 0, np.ones_like(y_true), np.zeros_like(y_true)) 33 | 34 | pm.forward(y_true, y_pred) 35 | 36 | results = pm.summary_all() 37 | 38 | torch.cuda.empty_cache() 39 | gc.collect() 40 | 41 | return { 42 | 'eval/iou': results.iou(1), 43 | 'eval/f1': results.f1(1), 44 | 'eval/prec': results.precision(1), 45 | 'eval/rec': results.recall(1), 46 | } 47 | 48 | 49 | @er.registry.CALLBACK.register() 50 | class BinaryChangeDetectionPixelEval(er.Callback): 51 | def __init__(self, data_cfg, epoch_interval, prior=101): 52 | super().__init__( 53 | epoch_interval=epoch_interval, 54 | only_master=False, 55 | prior=prior, 56 | before_train=False, 57 | after_train=True, 58 | ) 59 | dataloader = er.builder.make_dataloader(data_cfg) 60 | self.dataloader = er.data.as_ddp_inference_loader(dataloader) 61 | self.score_tracker = er.metric.ScoreTracker() 62 | 63 | self.score_table_name = data_cfg.type 64 | 65 | def func(self): 66 | score = binary_change_detection_evaluate(self.unwrapped_model, self.dataloader, self.model_dir, self.logger) 67 | 68 | best_score = self.score_tracker.highest_score('eval/f1') 69 | if score['eval/f1'] > best_score['eval/f1']: 70 | self.save_model('model-best.pth') 71 | 72 
| self.score_tracker.append(score, self.global_step) 73 | self.score_tracker.to_csv(os.path.join(self.model_dir, f'{self.score_table_name}_scores.csv')) 74 | 75 | best_score = self.score_tracker.highest_score('eval/f1') 76 | self.logger.info(f"best F1: {best_score['eval/f1']}, at step {best_score['step']}") 77 | -------------------------------------------------------------------------------- /torchange/metrics/second.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import math 9 | import torch 10 | import ever as er 11 | from tqdm import tqdm 12 | import os 13 | 14 | palette = np.array([ 15 | [255, 255, 255], 16 | [0, 0, 255], # water 17 | [128, 128, 128], # ground 18 | [0, 128, 0], # low veg 19 | [0, 255, 0], # tree 20 | [128, 0, 0], # building 21 | [255, 0, 0], # playground 22 | ], dtype=np.uint8) 23 | 24 | land_use = ['W', 'G', 'L', 'T', 'B', 'P'] 25 | change_types = ['unchanged'] + [''] * 36 26 | 27 | for i in range(6): 28 | for j in range(6): 29 | change_types[i * 6 + j + 1] = f'{land_use[i]}2{land_use[j]}' 30 | 31 | 32 | def fast_hist(a, b, n): 33 | k = (a >= 0) & (a < n) 34 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 35 | 36 | 37 | def get_hist(image, label, num_class): 38 | hist = np.zeros((num_class, num_class)) 39 | hist += fast_hist(image.flatten(), label.flatten(), num_class) 40 | return hist 41 | 42 | 43 | def cal_kappa(hist): 44 | if hist.sum() == 0: 45 | po = 0 46 | pe = 1 47 | kappa = 0 48 | else: 49 | po = np.diag(hist).sum() / hist.sum() 50 | pe = np.matmul(hist.sum(1), hist.sum(0).T) / hist.sum() ** 2 51 | if pe == 1: 52 | kappa = 0 53 | else: 54 | kappa = (po - pe) / (1 - pe) 55 | return kappa 56 | 57 | 58 | def score_summary(hist): 59 | 
hist_fg = hist[1:, 1:] 60 | c2hist = np.zeros((2, 2)) 61 | c2hist[0][0] = hist[0][0] 62 | c2hist[0][1] = hist.sum(1)[0] - hist[0][0] 63 | c2hist[1][0] = hist.sum(0)[0] - hist[0][0] 64 | c2hist[1][1] = hist_fg.sum() 65 | hist_n0 = hist.copy() 66 | hist_n0[0][0] = 0 67 | kappa_n0 = cal_kappa(hist_n0) 68 | iu = np.diag(c2hist) / (c2hist.sum(1) + c2hist.sum(0) - np.diag(c2hist)) 69 | IoU_fg = iu[1] 70 | IoU_mean = (iu[0] + iu[1]) / 2 71 | Sek = (kappa_n0 * math.exp(IoU_fg)) / math.e 72 | Score = 0.3 * IoU_mean + 0.7 * Sek 73 | 74 | return { 75 | 'SECOND/kappa': kappa_n0, 76 | 'SECOND/mIoU': IoU_mean, 77 | 'SECOND/Sek': Sek, 78 | 'SECOND/Score': Score, 79 | 'SECOND/IoU_1': IoU_fg 80 | } 81 | 82 | 83 | @er.registry.CALLBACK.register() 84 | class SemanticChangeDetectionEval(er.Callback): 85 | def __init__(self, data_cfg, epoch_interval, prior=101): 86 | super().__init__( 87 | epoch_interval=epoch_interval, 88 | only_master=False, 89 | prior=prior, 90 | before_train=False, 91 | after_train=True, 92 | ) 93 | dataloader = er.builder.make_dataloader(data_cfg) 94 | self.dataloader = er.data.as_ddp_inference_loader(dataloader) 95 | self.score_tracker = er.metric.ScoreTracker() 96 | self.score_table_name = data_cfg.type 97 | 98 | def func(self): 99 | second_score = self.evaluate_sek() 100 | self.info(second_score) 101 | 102 | score = self.evaluate_mIoU() 103 | self.info(score) 104 | score.update(second_score) 105 | 106 | best_key = 'eval/mIoU_scd' 107 | best_score = self.score_tracker.highest_score(best_key) 108 | if score[best_key] > best_score[best_key]: 109 | self.save_model('model-best.pth') 110 | 111 | self.score_tracker.append(score, self.global_step) 112 | 113 | self.score_tracker.to_csv(os.path.join(self.model_dir, f'{self.score_table_name}_scores.csv')) 114 | 115 | best_score = self.score_tracker.highest_score(best_key) 116 | self.launcher.logger.info(f"best mIoU_scd: {best_score[best_key]}, at step {best_score['step']}") 117 | 118 | @torch.no_grad() 119 | def 
evaluate_sek(self): 120 | self.model.eval() 121 | num_class = 37 122 | hist = np.zeros((num_class, num_class)) 123 | 124 | for img, gt in tqdm(self.dataloader, disable=not er.dist.is_main_process()): 125 | img = img.to(er.auto_device()) 126 | predictions = self.model(img) 127 | CLASS = predictions['t1_semantic_prediction'].size(1) 128 | s1 = predictions['t1_semantic_prediction'].argmax(dim=1) 129 | s2 = predictions['t2_semantic_prediction'].argmax(dim=1) 130 | c = predictions['change_prediction'] > 0.5 131 | 132 | pr_sc = torch.where(c, s1 * CLASS + s2 + 1, torch.zeros_like(s1)) 133 | 134 | gt_s1 = gt['masks'][0].to(torch.int64) 135 | gt_s2 = gt['masks'][1].to(torch.int64) 136 | gt_sc = torch.where(gt['masks'][-1] > 0, gt_s1 * CLASS + gt_s2 + 1, 137 | torch.zeros_like(gt['masks'][0])) 138 | 139 | hist += get_hist(pr_sc.cpu().numpy(), gt_sc.cpu().numpy(), num_class) 140 | 141 | return score_summary(hist) 142 | 143 | @torch.no_grad() 144 | def evaluate_mIoU(self): 145 | self.model.eval() 146 | bcd = er.metric.PixelMetric(2, self.model_dir, logger=self.logger) 147 | scd = er.metric.PixelMetric(6 * 6 + 1, self.model_dir, logger=self.logger, class_names=change_types) 148 | class_freq = torch.zeros([6 * 6 + 1, ], dtype=torch.int64) 149 | 150 | for img, gt in tqdm(self.dataloader, disable=not er.dist.is_main_process()): 151 | img = img.to(er.auto_device()) 152 | predictions = self.model(img) 153 | CLASS = predictions['t1_semantic_prediction'].size(1) 154 | 155 | s1 = predictions['t1_semantic_prediction'].argmax(dim=1) 156 | s2 = predictions['t2_semantic_prediction'].argmax(dim=1) 157 | c = predictions['change_prediction'] > 0.5 158 | 159 | pr_sc = torch.where(c, s1 * CLASS + s2 + 1, torch.zeros_like(s1)) 160 | 161 | gt_s1 = gt['masks'][0].to(torch.int64) 162 | gt_s2 = gt['masks'][1].to(torch.int64) 163 | gt_sc = torch.where(gt['masks'][-1] > 0, 164 | gt_s1 * CLASS + gt_s2 + 1, 165 | torch.zeros_like(gt['masks'][0])) 166 | 167 | bcd.forward(gt['masks'][-1], c) 168 | 
scd.forward(gt_sc, pr_sc) 169 | 170 | idx, cnt = torch.unique(gt_sc, return_counts=True) 171 | class_freq.scatter_add_(dim=0, index=idx.to(torch.int64), src=cnt) 172 | 173 | valid_cls_indices = class_freq.nonzero(as_tuple=True)[0].numpy() 174 | er.info(f'effective number of change types: {valid_cls_indices.shape[0]}') 175 | 176 | er.dist.synchronize() 177 | bcd_results = bcd.summary_all() 178 | scd_results = scd.summary_all() 179 | 180 | ious = [] 181 | for i in valid_cls_indices: 182 | ious.append(scd_results.iou(int(i))) 183 | mIoU = sum(ious) / len(ious) 184 | 185 | return { 186 | 'eval/mIoU_scd': mIoU, 187 | 'eval/IoU_bcd': bcd_results.iou(1), 188 | 'eval/f1_bcd': bcd_results.f1(1), 189 | 'eval/prec_bcd': bcd_results.precision(1), 190 | 'eval/rec_bcd': bcd_results.recall(1), 191 | } 192 | -------------------------------------------------------------------------------- /torchange/metrics/xview2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from pathlib import Path 9 | 10 | import albumentations as A 11 | import albumentations.pytorch 12 | import ever as er 13 | import numpy as np 14 | import torch 15 | from tqdm import tqdm 16 | 17 | 18 | def mixed_score(loc_tb, dam_tb): 19 | loc_f1 = loc_tb.f1(1) 20 | 21 | nodamage_f1 = dam_tb.f1(1) 22 | minor_f1 = dam_tb.f1(2) 23 | major_f1 = dam_tb.f1(3) 24 | destroyed_f1 = dam_tb.f1(4) 25 | 26 | # https://github.com/DIUx-xView/xView2_scoring/blob/ea0793da6f66a71236f2c4a34536d51beff483ab/xview2_metrics.py#L247 27 | 28 | harmonic_mean = lambda xs: len(xs) / sum((x + 1e-6) ** -1 for x in xs) 29 | dam_f1 = harmonic_mean([nodamage_f1, minor_f1, major_f1, destroyed_f1]) 30 | 31 | final_f1 = 0.3 * loc_f1 + 0.7 * dam_f1 32 | return loc_f1, dam_f1, final_f1, [nodamage_f1, minor_f1, major_f1, destroyed_f1] 33 | 34 | 35 | def _accumulate_loc(op, gt, pr): 36 | gt = gt.numpy().ravel() 37 | gt = np.where(gt > 0, np.ones_like(gt), np.zeros_like(gt)) 38 | op.forward(gt, pr) 39 | 40 | 41 | def _accumulate_dam(op, gt, pr): 42 | IGNORE_INDEX = 255 43 | gt_dam = gt.numpy().ravel() 44 | # https://github.com/DIUx-xView/xView2_scoring/blob/ea0793da6f66a71236f2c4a34536d51beff483ab/xview2_metrics.py#L100 45 | valid_inds = np.where((gt_dam != IGNORE_INDEX) & (gt_dam != 0))[0] 46 | gt_dam = gt_dam[valid_inds] 47 | dam_pred = pr.cpu().numpy().ravel()[valid_inds] 48 | op.forward(gt_dam, dam_pred) 49 | 50 | 51 | def parse_prediction_v1(pred): 52 | loc_pred = pred['t1_semantic_prediction'] > 0.5 53 | dam_pred = pred['change_prediction'].argmax(dim=1) 54 | return loc_pred, dam_pred 55 | 56 | 57 | @torch.no_grad() 58 | def evaluate(model, test_dataloader, logger, model_dir, split): 59 | dataloder = test_dataloader 60 | torch.cuda.empty_cache() 61 | model.eval() 62 | 63 | loc_metric_op = er.metric.PixelMetric(2, model_dir, logger=logger) 64 | damage_metric_op = er.metric.PixelMetric(5, model_dir, 65 | logger=logger) 66 | 67 | for x, y in tqdm(dataloder, disable=not 
er.dist.is_main_process()): 68 | x = x.to(er.auto_device()) 69 | gt_loc, gt_dam = y['masks'] 70 | pred = model(x) 71 | loc_pred, dam_pred = parse_prediction_v1(pred) 72 | 73 | # single-map constrain 74 | # https://github.com/DIUx-xView/xView2_scoring/blob/ea0793da6f66a71236f2c4a34536d51beff483ab/xview2_metrics.py#L99 75 | dam_pred = loc_pred * dam_pred 76 | 77 | _accumulate_loc(loc_metric_op, gt_loc, loc_pred) 78 | 79 | _accumulate_dam(damage_metric_op, gt_dam, dam_pred) 80 | 81 | er.dist.synchronize() 82 | loc_tb = loc_metric_op.summary_all() 83 | dam_tb = damage_metric_op.summary_all() 84 | 85 | loc_f1, dam_f1, final_f1, dam_f1s = mixed_score(loc_tb, dam_tb) 86 | 87 | logger.info(f'\nOverall F1, Localization F1, Damage F1\n{final_f1:.4f}, {loc_f1:.4f}, {dam_f1:.4f}') 88 | logger.info(f'dam f1 per class\n{dam_f1s[0]:.4f}, {dam_f1s[1]:.4f}, {dam_f1s[2]:.4f}, {dam_f1s[3]:.4f}') 89 | 90 | torch.cuda.empty_cache() 91 | 92 | return { 93 | f'{split}/loc_f1': loc_f1, 94 | f'{split}/dam_f1': dam_f1, 95 | f'{split}/final_f1': final_f1, 96 | f'{split}/non': dam_f1s[0], 97 | f'{split}/minor': dam_f1s[1], 98 | f'{split}/major': dam_f1s[2], 99 | f'{split}/destroyed': dam_f1s[3], 100 | } 101 | 102 | 103 | class _xView2StandardEval(er.Callback): 104 | def __init__( 105 | self, 106 | epoch_interval=10, 107 | only_master=False, 108 | prior=101, 109 | after_train=True 110 | ): 111 | super().__init__(epoch_interval=epoch_interval, only_master=only_master, prior=prior, 112 | after_train=after_train) 113 | self.tracked_scores = er.metric.ScoreTracker() 114 | self.best_final_f1 = 0. 
115 | self.best_step = 0 116 | self.split = None 117 | 118 | def func(self): 119 | self.logger.info(f'Split: {self.split}') 120 | 121 | scores = evaluate(self.unwrapped_model, self.dataloader, self.logger, self.model_dir, self.split) 122 | self.tracked_scores.append(scores, self.global_step) 123 | 124 | if er.dist.is_main_process(): 125 | self.tracked_scores.to_csv(os.path.join(self.model_dir, f'{self.split}_tracked_scores.csv')) 126 | 127 | if scores[f'{self.split}/final_f1'] > self.best_final_f1: 128 | self.save_model('model-best.pth') 129 | self.best_final_f1 = scores[f'{self.split}/final_f1'] 130 | self.best_step = self.global_step 131 | 132 | self.logger.info(f'best scores: {self.best_final_f1}, at step: {self.best_step}') 133 | 134 | 135 | @er.registry.CALLBACK.register() 136 | class xView2StandardEval(_xView2StandardEval): 137 | def __init__( 138 | self, 139 | dataset_dir, 140 | epoch_interval=10, 141 | only_master=False, 142 | prior=101, 143 | after_train=True 144 | ): 145 | super().__init__(epoch_interval=epoch_interval, only_master=only_master, prior=prior, 146 | after_train=after_train) 147 | split = Path(dataset_dir).name 148 | assert split in ['test', 'hold'] 149 | self.split = split 150 | 151 | dataloader = er.builder.make_dataloader(dict( 152 | type='xView2', 153 | params=dict( 154 | dataset_dir=dataset_dir, 155 | training=False, 156 | transforms=A.Compose([ 157 | A.Normalize(), 158 | A.pytorch.ToTensorV2(), 159 | ]), 160 | batch_size=1, 161 | num_workers=2, 162 | ), 163 | )) 164 | self.dataloader = er.data.as_ddp_inference_loader(dataloader) 165 | 166 | 167 | @er.registry.CALLBACK.register() 168 | class HFxView2StandardEval(_xView2StandardEval): 169 | def __init__( 170 | self, 171 | split, 172 | epoch_interval=10, 173 | only_master=False, 174 | prior=101, 175 | after_train=True 176 | ): 177 | super().__init__(epoch_interval=epoch_interval, only_master=only_master, prior=prior, 178 | after_train=after_train) 179 | assert split in ['test', 'hold'] 180 
| self.split = split 181 | 182 | dataloader = er.builder.make_dataloader(dict( 183 | type='HFxView2', 184 | params=dict( 185 | hf_repo_name='EVER-Z/torchange_xView2', 186 | splits=[split], 187 | training=False, 188 | transform=A.Compose([ 189 | A.Normalize(), 190 | A.pytorch.ToTensorV2(), 191 | ]), 192 | batch_size=1, 193 | num_workers=2, 194 | ), 195 | )) 196 | self.dataloader = er.data.as_ddp_inference_loader(dataloader) 197 | -------------------------------------------------------------------------------- /torchange/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchange/models/changemask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import ever as er 9 | import ever.module.loss as L 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from einops import rearrange 14 | 15 | try: 16 | import segmentation_models_pytorch as smp 17 | from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder 18 | except ImportError: 19 | print(f"segmentation_models_pytorch not found. 
please `pip install segmentation_models_pytorch`") 20 | 21 | 22 | 23 | CHANGE = 'change_prediction' 24 | T1SEM = 't1_semantic_prediction' 25 | T2SEM = 't2_semantic_prediction' 26 | 27 | 28 | def bitemporal_forward(module, x): 29 | x = rearrange(x, 'b (t c) h w -> (b t) c h w', t=2) 30 | features = module(x) 31 | if isinstance(features, list) or isinstance(features, tuple): 32 | t1_features, t2_features = [], [] 33 | for feat in features: 34 | t1_feat, t2_feat = rearrange(feat, '(b t) c h w -> t b c h w', t=2) 35 | t1_features.append(t1_feat) 36 | t2_features.append(t2_feat) 37 | else: 38 | t1_features, t2_features = rearrange(features, '(b t) c h w -> t b c h w', t=2) 39 | return t1_features, t2_features 40 | 41 | 42 | @torch.cuda.amp.autocast(dtype=torch.float32) 43 | def mse_loss(s1_logit, s2_logit, gt_masks): 44 | c_gt = gt_masks[-1].to(torch.float32).unsqueeze(1) 45 | 46 | s1_p = s1_logit.log_softmax(dim=1).exp() 47 | s2_p = s2_logit.log_softmax(dim=1).exp() 48 | 49 | diff = (s1_p - s2_p) ** 2 50 | losses = (1 - c_gt) * diff + c_gt * (1 - diff) 51 | 52 | return losses.mean() 53 | 54 | 55 | @torch.cuda.amp.autocast(dtype=torch.float32) 56 | def loss( 57 | s1_logit, s2_logit, c_logit, 58 | gt_masks, 59 | ): 60 | s1_gt = gt_masks[0].to(torch.int64) 61 | s2_gt = gt_masks[1].to(torch.int64) 62 | 63 | s1_ce = F.cross_entropy(s1_logit, s1_gt, ignore_index=255) 64 | s1_dice = L.dice_loss_with_logits(s1_logit, s1_gt) 65 | 66 | s2_ce = F.cross_entropy(s2_logit, s2_gt, ignore_index=255) 67 | s2_dice = L.dice_loss_with_logits(s2_logit, s2_gt) 68 | 69 | c_gt = gt_masks[-1].to(torch.float32) 70 | c_dice = L.dice_loss_with_logits(c_logit, c_gt) 71 | c_bce = L.binary_cross_entropy_with_logits(c_logit, c_gt) 72 | 73 | sim_loss = mse_loss(s1_logit, s2_logit, gt_masks) 74 | return { 75 | 's1_ce_loss': s1_ce, 76 | 's1_dice_loss': s1_dice, 77 | 's2_ce_loss': s2_ce, 78 | 's2_dice_loss': s2_dice, 79 | 'c_dice_loss': c_dice, 80 | 'c_bce_loss': c_bce, 81 | # to improve semantic-change 
consistency, this is a well-known issue in ChangeMask-like SCD methods. 82 | # original implementation doesn't have this objective. 83 | 'sim_loss': sim_loss 84 | } 85 | 86 | 87 | class Squeeze(nn.Module): 88 | def __init__(self, dim): 89 | super(Squeeze, self).__init__() 90 | self.dim = dim 91 | 92 | def forward(self, x: torch.Tensor): 93 | return x.squeeze(dim=self.dim) 94 | 95 | 96 | class SpatioTemporalInteraction(nn.Sequential): 97 | def __init__(self, 98 | in_channels, 99 | out_channels, 100 | kernel_size, 101 | dilation=1, 102 | type='conv3d'): 103 | if type == 'conv3d': 104 | padding = dilation * (kernel_size - 1) // 2 105 | super(SpatioTemporalInteraction, self).__init__( 106 | nn.Conv3d(in_channels, out_channels, [2, kernel_size, kernel_size], stride=1, 107 | dilation=(1, dilation, dilation), 108 | padding=(0, padding, padding), 109 | bias=False), 110 | Squeeze(dim=2), 111 | nn.BatchNorm2d(out_channels), 112 | nn.ReLU(True) 113 | ) 114 | elif type == 'conv1plus2d': 115 | super(SpatioTemporalInteraction, self).__init__( 116 | nn.Conv3d(in_channels, out_channels, (2, 1, 1), stride=1, 117 | padding=(0, 0, 0), 118 | bias=False), 119 | Squeeze(dim=2), 120 | nn.BatchNorm2d(out_channels), 121 | nn.ReLU(True), 122 | nn.Conv2d(out_channels, out_channels, kernel_size, 1, 123 | kernel_size // 2) if kernel_size > 1 else nn.Identity(), 124 | nn.BatchNorm2d(out_channels) if kernel_size > 1 else nn.Identity(), 125 | nn.ReLU(True) if kernel_size > 1 else nn.Identity(), 126 | ) 127 | 128 | 129 | class TemporalSymmetricTransformer(nn.Module): 130 | def __init__(self, 131 | in_channels, 132 | out_channels, 133 | kernel_size, 134 | dilation=1, 135 | interaction_type='conv3d', 136 | symmetric_fusion='add'): 137 | super(TemporalSymmetricTransformer, self).__init__() 138 | 139 | if isinstance(in_channels, list) or isinstance(in_channels, tuple): 140 | self.t = nn.ModuleList([ 141 | SpatioTemporalInteraction(inc, outc, kernel_size, dilation=dilation, type=interaction_type) 142 | 
for inc, outc in zip(in_channels, out_channels) 143 | ]) 144 | else: 145 | self.t = SpatioTemporalInteraction(in_channels, out_channels, kernel_size, dilation=dilation, 146 | type=interaction_type) 147 | 148 | if symmetric_fusion == 'add': 149 | self.symmetric_fusion = lambda x, y: x + y 150 | elif symmetric_fusion == 'mul': 151 | self.symmetric_fusion = lambda x, y: x * y 152 | elif symmetric_fusion == None: 153 | self.symmetric_fusion = None 154 | 155 | def forward(self, features1, features2): 156 | if isinstance(features1, list): 157 | d12_features = [op(torch.stack([f1, f2], dim=2)) for op, f1, f2 in 158 | zip(self.t, features1, features2)] 159 | if self.symmetric_fusion: 160 | d21_features = [op(torch.stack([f2, f1], dim=2)) for op, f1, f2 in 161 | zip(self.t, features1, features2)] 162 | change_features = [self.symmetric_fusion(d12, d21) for d12, d21 in zip(d12_features, d21_features)] 163 | else: 164 | change_features = d12_features 165 | else: 166 | if self.symmetric_fusion: 167 | change_features = self.symmetric_fusion(self.t(torch.stack([features1, features2], dim=2)), 168 | self.t(torch.stack([features2, features1], dim=2))) 169 | else: 170 | change_features = self.t(torch.stack([features1, features2], dim=2)) 171 | change_features = change_features.squeeze(dim=2) 172 | return change_features 173 | 174 | 175 | @er.registry.MODEL.register() 176 | class ChangeMask(er.ERModule): 177 | def __init__(self, cfg): 178 | super().__init__(cfg) 179 | self.encoder = smp.encoders.get_encoder('efficientnet-b0', weights='imagenet') 180 | out_channels = self.encoder.out_channels 181 | self.semantic_decoder = UnetDecoder( 182 | encoder_channels=out_channels, 183 | decoder_channels=[256, 128, 64, 32, 16], 184 | ) 185 | 186 | self.change_decoder = UnetDecoder( 187 | encoder_channels=out_channels, 188 | decoder_channels=[256, 128, 64, 32, 16], 189 | ) 190 | 191 | self.temporal_transformer = TemporalSymmetricTransformer( 192 | out_channels, out_channels, 193 | 3, 
interaction_type='conv3d', symmetric_fusion='add', 194 | ) 195 | self.s = nn.Conv2d(16, self.cfg.num_semantic_classes, 1) 196 | self.c = nn.Conv2d(16, 1, 1) 197 | 198 | def forward(self, x, y=None): 199 | t1_features, t2_features = bitemporal_forward(self.encoder, x) 200 | 201 | s1_logit = self.s(self.semantic_decoder(*t1_features)) 202 | s2_logit = self.s(self.semantic_decoder(*t2_features)) 203 | 204 | temporal_features = self.temporal_transformer(t1_features, t2_features) 205 | c_logit = self.c(self.change_decoder(*temporal_features)) 206 | 207 | if self.training: 208 | return loss(s1_logit, s2_logit, c_logit, y['masks']) 209 | 210 | return { 211 | T1SEM: s1_logit.softmax(dim=1), 212 | T2SEM: s2_logit.softmax(dim=1), 213 | CHANGE: c_logit.sigmoid(), 214 | } 215 | 216 | def set_default_config(self): 217 | self.cfg.update(dict( 218 | num_semantic_classes=6 219 | )) 220 | -------------------------------------------------------------------------------- /torchange/models/changen2/README.md: -------------------------------------------------------------------------------- 1 | # Changen2 (TPAMI 2024) 2 | 3 | This is the official repository for IEEE TPAMI 2024 paper 4 | "_Changen2: Multi-Temporal Remote Sensing Generative Change Foundation Model_". 5 | 6 | Authors: 7 | [Zhuo Zheng](https://zhuozheng.top/) 8 | [Stefano Ermon](https://cs.stanford.edu/~ermon/) 9 | [Dongjun Kim](https://sites.google.com/view/dongjun-kim) 10 | [Liangpei Zhang](http://www.lmars.whu.edu.cn/prof_web/zhangliangpei/rs/index.html) 11 | [Yanfei Zhong](http://rsidea.whu.edu.cn/) 12 | 13 | Abstract: Our understanding of the temporal dynamics of the Earth's surface has been significantly advanced by deep vision models, which often require a massive amount of labeled multi-temporal images for training. 14 | However, collecting, preprocessing, and annotating multi-temporal remote sensing images at scale is non-trivial since it is expensive and knowledge-intensive. 
15 | In this paper, we present scalable multi-temporal change data generators based on generative models, which are cheap and automatic, alleviating these data problems. 16 | Our main idea is to simulate a stochastic change process over time. 17 | We describe the stochastic change process as a probabilistic graphical model, namely the generative probabilistic change model (GPCM), which factorizes the complex simulation problem into two more tractable sub-problems, i.e., condition-level change event simulation and image-level semantic change synthesis. 18 | To solve these two problems, we present Changen2, a GPCM implemented with a resolution-scalable diffusion transformer which can generate time series of remote sensing images and corresponding semantic and change labels from labeled and even unlabeled single-temporal images. 19 | Changen2 is a generative change foundation model that can be trained at scale via self-supervision, and is capable of producing change supervisory signals from unlabeled single-temporal images. 20 | Unlike existing foundation models, our generative change foundation model synthesizes change data to train task-specific foundation models for change detection. 21 | The resulting model possesses inherent zero-shot change detection capabilities and excellent transferability. 22 | Comprehensive experiments suggest Changen2 has superior spatiotemporal scalability in data generation, e.g., a Changen2 model trained on 256$^2$-pixel single-temporal images can yield time series of any length at resolutions of 1,024$^2$ pixels. 23 | Changen2 pre-trained models exhibit superior zero-shot performance (narrowing the performance gap to 3% on LEVIR-CD and approximately 10% on both S2Looking and SECOND, compared to fully supervised counterparts) and transferability across multiple types of change tasks, including ordinary and off-nadir building change, land-use/land-cover change, and disaster assessment.
24 | 25 | ## Get Started (TBD) 26 | 27 | ### Change Event Simulation 28 | ```python 29 | from torchange.models.changen2 import change_event_simulation as ces 30 | 31 | # Changen2, Sec. 3.2, Change Event Simulation 32 | ces.add_object # Object Creation 33 | ces.remove_object # Object Removal 34 | ces.random_transition # Attribute Edit 35 | # Changen2, Sec. 3.5, Fig.7 36 | ces.next_time_contour_gen 37 | ``` 38 | 39 | ### Resolution-Scalable DiT models 40 | 41 | ```python 42 | from torchange.models.changen2 import RSDiT_models 43 | ``` 44 | 45 | ### Changen2 pre-trained ChangeStar (1x256) models 46 | 47 | ```python 48 | from torchange.models.changen2 import changestar_1x256 49 | ``` 50 | 51 | ### Synthetic Change Datasets 52 | [Changen2-S1-15k](https://huggingface.co/datasets/EVER-Z/Changen2-S1-15k), a building change dataset with 15k pairs and 2 change types, 0.3-1m spatial resolution, RGB bands 53 | 54 | [Changen2-S9-27k](https://huggingface.co/datasets/EVER-Z/Changen2-S9-27k), an urban land-use/land-cover change dataset with 27k pairs and 38 change types, 0.25-0.5m spatial resolution, RGB bands 55 | 56 | 57 | ## Citation 58 | If you find our project helpful, we would greatly appreciate it if you could kindly cite our papers: 59 | ``` 60 | @article{changen2, 61 | author={Zheng, Zhuo and Ermon, Stefano and Kim, Dongjun and Zhang, Liangpei and Zhong, Yanfei}, 62 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 63 | title={Changen2: Multi-Temporal Remote Sensing Generative Change Foundation Model}, 64 | year={2025}, 65 | volume={47}, 66 | number={2}, 67 | pages={725-741}, 68 | doi={10.1109/TPAMI.2024.3475824} 69 | } 70 | @inproceedings{changen, 71 | title={Scalable multi-temporal remote sensing change data generation via simulating stochastic change process}, 72 | author={Zheng, Zhuo and Tian, Shiqi and Ma, Ailong and Zhang, Liangpei and Zhong, Yanfei}, 73 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
74 | pages={21818--21827}, 75 | year={2023} 76 | } 77 | ``` -------------------------------------------------------------------------------- /torchange/models/changen2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from ._changestar_1x256 import * 7 | from .rsdit import * 8 | -------------------------------------------------------------------------------- /torchange/models/changen2/_changestar_1x256.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import torch 8 | import ever as er 9 | from torchange.models.changestar_1xd import ChangeStar1xd 10 | 11 | __all__ = [ 12 | 'changestar_1x256', 13 | 's1_init_s1c1_changestar_vitb_1x256', 14 | 's1_init_s1c1_changestar_vitl_1x256', 15 | 's9_init_s9c1_changestar_vitb_1x256', 16 | 's0_init_s1c1_changestar_vitb_1x256', 17 | 's0_init_s1c5_changestar_vitb_1x256', 18 | 's0_init_s9c1_changestar_vitb_1x256', 19 | ] 20 | 21 | 22 | def changestar_1x256(backbone_type, modeling_type, changen2_pretrained=None) -> ChangeStar1xd: 23 | import json 24 | from huggingface_hub import hf_hub_download 25 | from torchange.module.farseg import SAMEncoderFarSeg 26 | assert modeling_type in ['s1c1', 's9c1', 's1c5', ] 27 | assert backbone_type in ['vitb', 'vitl'] 28 | assert changen2_pretrained in [None, 's0', 's1', 's9'] 29 | pretrain_data = { 30 | None: None, 31 | 's0': 'Changen2-S0-1.2M', 32 | 's1': 'Changen2-S1-15k', 33 | 's9': 'Changen2-S9-27k' 34 | } 35 | 36 | model_name = f'{modeling_type}_cstar_{backbone_type}_1x256' 37 | # build model 
38 | package_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 39 | cfg = er.config.import_config(os.path.join(package_root, 'configs', 'changen2', model_name)) 40 | model: ChangeStar1xd = er.builder.make_model(cfg.model) 41 | 42 | # load Changen2 pre-trained weight 43 | if changen2_pretrained: 44 | available_config = hf_hub_download('EVER-Z/Changen2-ChangeStar1x256', 'config.json') 45 | with open(available_config, "r", encoding="utf-8") as reader: 46 | text = reader.read() 47 | available_config = json.loads(text) 48 | weight_name = f'{changen2_pretrained}_changestar_{backbone_type}_1x256' 49 | assert weight_name in available_config, f'{weight_name} is not available' 50 | weights = hf_hub_download('EVER-Z/Changen2-ChangeStar1x256', available_config[weight_name]) 51 | 52 | model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')), strict=changen2_pretrained != 's0') 53 | er.info(f'Load Changen2 pre-trained weight from EVER-Z/Changen2-ChangeStar1x256/{available_config[weight_name]}') 54 | 55 | er.info( 56 | f'architecture: changestar_1x256 | backbone: {backbone_type} | pre-trained data: {pretrain_data[changen2_pretrained]}') 57 | return model 58 | 59 | 60 | def s1_init_s1c1_changestar_vitb_1x256(): return changestar_1x256('vitb', 's1c1', 's1') 61 | 62 | 63 | def s1_init_s1c1_changestar_vitl_1x256(): return changestar_1x256('vitl', 's1c1', 's1') 64 | 65 | 66 | def s9_init_s9c1_changestar_vitb_1x256(): return changestar_1x256('vitb', 's9c1', 's9') 67 | 68 | 69 | def s0_init_s1c1_changestar_vitb_1x256(): return changestar_1x256('vitb', 's1c1', 's0') 70 | 71 | 72 | def s0_init_s9c1_changestar_vitb_1x256(): return changestar_1x256('vitb', 's9c1', 's0') 73 | 74 | 75 | def s0_init_s1c5_changestar_vitb_1x256(): return changestar_1x256('vitb', 's1c5', 's0') 76 | -------------------------------------------------------------------------------- /torchange/models/changen2/change_event_simulation.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from skimage.measure import label, regionprops 8 | from skimage.segmentation import find_boundaries 9 | from skimage.morphology import binary_opening, dilation, square 10 | import numpy as np 11 | import random 12 | 13 | MAXIMUM_TRY = 50 14 | 15 | 16 | class LC: 17 | Bareland = 1 18 | Rangeland = 2 19 | DevelopedSpace = 3 20 | Road = 4 21 | Tree = 5 22 | Water = 6 23 | AgricultureLand = 7 24 | Building = 8 25 | 26 | 27 | # OEM 28 | OEM_Transition = [ 29 | [i + 1 for i in range(8)], 30 | [LC.Rangeland, LC.DevelopedSpace, LC.Tree, LC.Water, LC.AgricultureLand], 31 | [LC.Bareland, LC.DevelopedSpace, LC.Tree, LC.Water, LC.AgricultureLand], 32 | [LC.Bareland, LC.Rangeland, LC.Tree, LC.Water, LC.AgricultureLand], 33 | [LC.Bareland, LC.DevelopedSpace, LC.Tree, LC.Water], 34 | [LC.Bareland, LC.Rangeland, LC.DevelopedSpace, LC.Water, LC.AgricultureLand], 35 | [LC.Bareland, LC.Rangeland, LC.DevelopedSpace, LC.Tree, LC.AgricultureLand], 36 | [LC.Bareland, LC.Rangeland, LC.DevelopedSpace, LC.Tree, LC.Water, LC.AgricultureLand], 37 | [LC.DevelopedSpace, LC.Tree, LC.Water] 38 | ] 39 | 40 | 41 | def object_proposal(mask): 42 | mask = (mask > 0).astype(np.uint8, copy=False) 43 | props = regionprops(label(mask)) 44 | return props 45 | 46 | 47 | def add_object(obj_mask, max_add_num_per_frame, min_add_num_per_frame=1): 48 | h, w = obj_mask.shape 49 | props = object_proposal(obj_mask) 50 | props = [p for p in props if p.area > 8 * 8] 51 | num_objs = random.randint(min_add_num_per_frame, max_add_num_per_frame) 52 | 53 | random.shuffle(props) 54 | props = props[:num_objs] 55 | 56 | new_obj_mask = (obj_mask > 0).astype(np.uint8) 57 | 58 | for obj in props: 59 | rr, cc = obj.coords.T 60 | 61 | ymin, 
xmin, ymax, xmax = obj.bbox 62 | 63 | for _ in range(MAXIMUM_TRY): 64 | # [-ymin, h - ymax) 65 | yscale = (h - ymax) + ymin 66 | yshift = -ymin 67 | yoffset = int(np.random.rand() * yscale + yshift) 68 | # [-xmin, w - xmax) 69 | xscale = (w - xmax) + xmin 70 | xshift = -xmin 71 | xoffset = int(np.random.rand() * xscale + xshift) 72 | 73 | candidate = new_obj_mask[rr + yoffset, cc + xoffset] 74 | if np.sum(candidate) == 0: 75 | new_obj_mask[rr + yoffset, cc + xoffset] = 1 76 | break 77 | 78 | return new_obj_mask 79 | 80 | 81 | def remove_object(obj_mask, max_rm_num_per_frame, min_rm_num_per_frame=1): 82 | props = object_proposal(obj_mask) 83 | 84 | props = [p for p in props if p.area > 8 * 8] 85 | 86 | num_objs = random.randint(min_rm_num_per_frame, max_rm_num_per_frame) 87 | num_objs = min(num_objs, len(props)) 88 | 89 | random.shuffle(props) 90 | props = props[:num_objs] 91 | 92 | obj_mask = obj_mask.copy() 93 | 94 | for obj in props: 95 | rr, cc = obj.coords.T 96 | obj_mask[rr, cc] = 0 97 | 98 | return obj_mask 99 | 100 | 101 | def remove_add_object(obj_mask, max_change_num_per_frame): 102 | obj_mask = remove_object(obj_mask, max_change_num_per_frame) 103 | obj_mask = add_object(obj_mask, max_change_num_per_frame) 104 | return obj_mask 105 | 106 | 107 | def random_transition(mask, num_classes, transition_kernel=None, p=0.3): 108 | if transition_kernel is None: 109 | transition_kernel = OEM_Transition 110 | eye = np.eye(num_classes) 111 | bin_masks = eye[mask] 112 | 113 | canvas = np.zeros_like(mask, dtype=np.int64) 114 | for i in range(num_classes): 115 | mask = bin_masks[:, :, i] 116 | if (mask == 0).all(): 117 | continue 118 | props = object_proposal(mask) 119 | props = [obj for obj in props if obj.area > 8 * 8] 120 | for obj in props: 121 | rr, cc = obj.coords.T 122 | if random.random() < p: 123 | canvas[rr, cc] = random.choice(transition_kernel[i]) 124 | else: 125 | canvas[rr, cc] = i 126 | return canvas 127 | 128 | 129 | # mainly for SAM masks 130 | def 
remove_instance(ins_mask, p=0.1): 131 | ins_mask = np.copy(ins_mask) 132 | for i in np.unique(ins_mask): 133 | if i == 0: 134 | continue 135 | if random.random() < p: 136 | ins_mask[ins_mask == i] = 0 137 | return ins_mask 138 | 139 | 140 | # Changen2, Sec 3.5, Fig.7 141 | def next_time_contour_gen(t1_mask, t2_mask): 142 | # compute change mask 143 | cmsk = ((t1_mask > 0) != (t2_mask > 0)).astype(np.uint8) 144 | cmsk = binary_opening(cmsk).astype(np.uint8) 145 | # compute t2 boundary 146 | bd1 = find_boundaries(t1_mask).astype(np.uint8) 147 | _cmsk = dilation(cmsk, square(3)) 148 | bd2 = bd1 * (1 - _cmsk) 149 | return bd2 150 | 151 | 152 | def generate_mask_seq(mask, seq_len=6, max_change_num_per_frame=5, mode='remove', seed=None, min_change_num_per_frame=1): 153 | random.seed(seed) 154 | if mode == 'remove': 155 | ds = [mask] 156 | for _ in range(seq_len - 1): 157 | ds.append(remove_object(ds[-1], max_change_num_per_frame, min_rm_num_per_frame=min_change_num_per_frame)) 158 | elif mode == 'add': 159 | ds = [mask] 160 | for _ in range(seq_len - 1): 161 | ds.append(add_object(ds[-1], max_change_num_per_frame, min_add_num_per_frame=min_change_num_per_frame)) 162 | elif mode == 'mix': 163 | ds = [mask] 164 | for _ in range(seq_len - 1): 165 | ds.append(remove_add_object(ds[-1], max_change_num_per_frame)) 166 | else: 167 | raise NotImplementedError 168 | return ds 169 | -------------------------------------------------------------------------------- /torchange/models/changestar_1xd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | import ever as er 8 | from einops import rearrange 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import ever.module as M 13 | import ever.module.loss as L 14 | 15 | CHANGE = 'change_prediction' 16 | T1SEM = 't1_semantic_prediction' 17 | T2SEM = 't2_semantic_prediction' 18 | 19 | 20 | def bitemporal_forward(module, x): 21 | x = rearrange(x, 'b (t c) h w -> (b t) c h w', t=2) 22 | features = module(x) 23 | if isinstance(features, list) or isinstance(features, tuple): 24 | t1_features, t2_features = [], [] 25 | for feat in features: 26 | t1_feat, t2_feat = rearrange(feat, '(b t) c h w -> t b c h w', t=2) 27 | t1_features.append(t1_feat) 28 | t2_features.append(t2_feat) 29 | else: 30 | t1_features, t2_features = rearrange(features, '(b t) c h w -> t b c h w', t=2) 31 | 32 | return t1_features, t2_features 33 | 34 | 35 | @torch.amp.autocast('cuda', dtype=torch.float32) 36 | def sc_mse_loss(s1_logit, s2_logit, gt_masks): 37 | c_gt = gt_masks[-1].to(torch.float32).unsqueeze(1) 38 | 39 | s1_p = s1_logit.log_softmax(dim=1).exp() 40 | s2_p = s2_logit.log_softmax(dim=1).exp() 41 | 42 | diff = (s1_p - s2_p) ** 2 43 | losses = (1 - c_gt) * diff + c_gt * (1 - diff) 44 | 45 | return losses.mean() 46 | 47 | 48 | @er.registry.MODEL.register() 49 | class ChangeStar1xd(er.ERModule): 50 | def __init__(self, config): 51 | super().__init__(config) 52 | self.encoder = er.registry.MODEL[self.config.encoder.type](self.config.encoder.params) 53 | 54 | self.cfg.head.in_channels = 2 * self.config.encoder.params.out_channels 55 | self.cfg.head.out_channels = self.config.encoder.params.out_channels 56 | 57 | self.head = ChangeMixinBiSupN1(**self.cfg.head) 58 | self.init_from_weight_file() 59 | 60 | def forward(self, x, y=None): 61 | if self.cfg.encoder.bitemporal_forward: 62 | bitemporal_features = bitemporal_forward(self.encoder, x) 63 | else: 64 | bitemporal_features = self.encoder(x) 65 | 66 | preds = self.head(*bitemporal_features) 
67 | 68 | if self.training: 69 | return self.loss(preds, y) 70 | 71 | return preds 72 | 73 | @torch.amp.autocast('cuda', dtype=torch.float32) 74 | def loss(self, preds, y): 75 | # masks[0] - cls, masks[1] - cls, masks[2] - change 76 | # masks[0] - cls, masks[1] - change 77 | # masks[0] - change 78 | gt_change = y['masks'][-1].to(torch.float32) 79 | change_logit = preds[CHANGE].to(torch.float32) 80 | 81 | loss_dict = dict() 82 | if hasattr(self.cfg.loss, 'change'): 83 | if ('bce' in self.cfg.loss.change) or ('ce' in self.cfg.loss.change): 84 | if change_logit.size(1) == 1: 85 | ls = self.cfg.loss.change.bce.get('ls', 0.0) 86 | loss = L.label_smoothing_binary_cross_entropy(change_logit, gt_change, eps=ls) 87 | loss_dict.update( 88 | c_bce_loss=loss 89 | ) 90 | else: 91 | ls = self.cfg.loss.change.ce.get('ls', 0.0) 92 | loss = F.cross_entropy(change_logit, gt_change.to(torch.int64), ignore_index=255, label_smoothing=ls) 93 | loss_dict.update( 94 | c_ce_loss=loss 95 | ) 96 | 97 | if 'dice' in self.cfg.loss.change: 98 | gamma = self.cfg.loss.change.dice.get('gamma', 1.0) 99 | if change_logit.size(1) == 1: 100 | loss_dict.update( 101 | c_dice_loss=L.tversky_loss_with_logits(change_logit, gt_change, alpha=0.5, beta=0.5, gamma=gamma), 102 | ) 103 | else: 104 | loss_dict.update( 105 | c_dice_loss=L.dice_loss_with_logits(change_logit, gt_change), 106 | ) 107 | 108 | if preds[T1SEM] is not None and 't1' in self.cfg.loss: 109 | gt_t1 = y['masks'][0] 110 | if preds[T1SEM].size(1) > 1: 111 | loss_dict.update(dict( 112 | t1_ce_loss=F.cross_entropy(preds[T1SEM], gt_t1.to(torch.int64), reduction='mean', ignore_index=255), 113 | t1_dice_loss=L.dice_loss_with_logits(preds[T1SEM], gt_t1.to(torch.int64)) 114 | )) 115 | else: 116 | gt_t1 = gt_t1.to(torch.float32) 117 | loss_dict.update(dict( 118 | t1_bce_loss=L.binary_cross_entropy_with_logits( 119 | preds[T1SEM], gt_t1.reshape_as(preds[T1SEM]), reduction='mean'), 120 | t1_dice_loss=L.dice_loss_with_logits(preds[T1SEM], gt_t1), 121 | )) 
122 | 123 | if preds[T2SEM] is not None and 't2' in self.cfg.loss: 124 | gt_t2 = y['masks'][1] 125 | if preds[T2SEM].size(1) > 1: 126 | loss_dict.update(dict( 127 | t2_ce_loss=F.cross_entropy(preds[T2SEM], gt_t2.to(torch.int64), reduction='mean', ignore_index=255), 128 | t2_dice_loss=L.dice_loss_with_logits(preds[T2SEM], gt_t2.to(torch.int64)), 129 | )) 130 | else: 131 | gt_t2 = gt_t2.to(torch.float32) 132 | loss_dict.update(dict( 133 | t2_bce_loss=F.binary_cross_entropy_with_logits( 134 | preds[T2SEM], gt_t2.reshape_as(preds[T2SEM]), reduction='mean'), 135 | t2_dice_loss=L.dice_loss_with_logits(preds[T2SEM], gt_t2), 136 | )) 137 | 138 | if 'sc' in self.cfg.loss: 139 | loss_dict.update(dict( 140 | sc_mse_loss=sc_mse_loss(preds[T1SEM], preds[T2SEM], y['masks']) 141 | )) 142 | 143 | return loss_dict 144 | 145 | def set_default_config(self): 146 | self.config.update(dict( 147 | encoder=dict(type=None, params=dict(), bitemporal_forward=False), 148 | head=dict( 149 | in_channels=-1, 150 | out_channels=-1, 151 | temporal_symmetric=True, 152 | num_semantic_classes=None, 153 | num_change_classes=None 154 | ), 155 | loss=dict( 156 | ) 157 | )) 158 | 159 | def log_info(self): 160 | return dict( 161 | encoder=self.encoder, 162 | head=self.head 163 | ) 164 | 165 | def custom_param_groups(self): 166 | param_groups = [] 167 | 168 | if isinstance(self.encoder, er.ERModule): 169 | param_groups += self.encoder.custom_param_groups() 170 | else: 171 | param_groups += [{'params': self.encoder.parameters()}] 172 | 173 | if isinstance(self.head, er.ERModule): 174 | param_groups += self.head.custom_param_groups() 175 | else: 176 | param_groups += [{'params': self.head.parameters()}] 177 | 178 | return param_groups 179 | 180 | 181 | class ChangeMixinBiSupN1(nn.Module): 182 | def __init__(self, in_channels, out_channels, temporal_symmetric=True, 183 | num_semantic_classes=None, num_change_classes=None): 184 | super().__init__() 185 | self.conv = nn.Sequential( 186 | nn.Conv2d(in_channels, 
out_channels, 3, 1, 1, bias=False), 187 | M.LayerNorm2d(out_channels), 188 | nn.GELU() 189 | ) 190 | if num_change_classes is None: 191 | num_change_classes = 1 192 | 193 | self.temporal_symmetric = temporal_symmetric 194 | self.change_conv = M.ConvUpsampling(out_channels, num_change_classes, scale_factor=4, kernel_size=1) 195 | self.num_semantic_classes = num_semantic_classes 196 | if isinstance(num_semantic_classes, int): 197 | self.semantic_conv = M.ConvUpsampling(out_channels, num_semantic_classes, scale_factor=4, kernel_size=1) 198 | elif isinstance(num_semantic_classes, (tuple, list)): 199 | self.semantic_conv = nn.ModuleList([ 200 | M.ConvUpsampling(out_channels, nc, scale_factor=4, kernel_size=1) 201 | for nc in num_semantic_classes 202 | ]) 203 | else: 204 | self.semantic_conv = nn.Identity() 205 | 206 | def forward(self, t1_feature, t2_feature): 207 | pre_logit = self.conv(torch.cat([t1_feature, t2_feature], dim=1)) 208 | if self.temporal_symmetric: 209 | pre_logit = pre_logit + self.conv(torch.cat([t2_feature, t1_feature], dim=1)) 210 | 211 | change_logit = self.change_conv(pre_logit) 212 | if isinstance(self.num_semantic_classes, int) or self.num_semantic_classes is None: 213 | t1_semantic_logit = self.semantic_conv(t1_feature) 214 | t2_semantic_logit = self.semantic_conv(t2_feature) 215 | else: 216 | t1_semantic_logit = self.semantic_conv[0](t1_feature) 217 | t2_semantic_logit = self.semantic_conv[1](t2_feature) 218 | 219 | if self.training: 220 | return { 221 | CHANGE: change_logit, 222 | T1SEM: t1_semantic_logit if self.num_semantic_classes else None, 223 | T2SEM: t2_semantic_logit if self.num_semantic_classes else None, 224 | } 225 | else: 226 | def _act(logit): 227 | if logit.size(1) > 1: 228 | return logit.softmax(dim=1) 229 | else: 230 | return logit.sigmoid() 231 | 232 | return { 233 | CHANGE: change_logit.sigmoid(), 234 | T1SEM: _act(t1_semantic_logit) if self.num_semantic_classes else None, 235 | T2SEM: _act(t2_semantic_logit) if 
self.num_semantic_classes else None, 236 | } 237 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/README.md: -------------------------------------------------------------------------------- 1 | # Segment Any Change (NeurIPS 2024) 2 | 3 | This is the official repository for the NeurIPS 2024 paper 4 | "_Segment Any Change_". 5 | 6 | Authors: 7 | [Zhuo Zheng](https://zhuozheng.top/) 8 | [Yanfei Zhong](http://rsidea.whu.edu.cn/) 9 | [Liangpei Zhang](http://www.lmars.whu.edu.cn/prof_web/zhangliangpei/rs/index.html) 10 | [Stefano Ermon](https://cs.stanford.edu/~ermon/). 11 | 12 | Abstract: Visual foundation models have achieved remarkable results in zero-shot image classification and segmentation, but zero-shot change detection remains an open problem. 13 | In this paper, we propose the segment any change models (AnyChange), a new type of change detection model that supports zero-shot prediction and generalization on unseen change types and data distributions. 14 | AnyChange is built on the segment anything model (SAM) via our training-free adaptation method, bitemporal latent matching. 15 | By revealing and exploiting intra-image and inter-image semantic similarities in SAM's latent space, bitemporal latent matching endows SAM with zero-shot change detection capabilities in a training-free way. 16 | We also propose a point query mechanism to enable AnyChange's zero-shot object-centric change detection capability. 
17 | 18 | ## Get Started 19 | ### Case 1: automatic mode (segment any change) 20 | ```python 21 | import matplotlib.pyplot as plt 22 | from skimage.io import imread 23 | from torchange.models.segment_any_change import AnyChange, show_change_masks 24 | 25 | # initialize AnyChange 26 | m = AnyChange('vit_h', sam_checkpoint='./sam_vit_h_4b8939.pth') 27 | # customize the hyperparameters of SAM's mask generator 28 | m.make_mask_generator( 29 | points_per_side=32, 30 | stability_score_thresh=0.95, 31 | ) 32 | # customize your AnyChange's hyperparameters 33 | m.set_hyperparameters( 34 | change_confidence_threshold=145, 35 | use_normalized_feature=True, 36 | bitemporal_match=True, 37 | ) 38 | 39 | img1 = imread('https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/main/demo_images/t1_img.png') 40 | img2 = imread('https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/main/demo_images/t2_img.png') 41 | 42 | changemasks, _, _ = m.forward(img1, img2) # automatic mode 43 | fig, axes = show_change_masks(img1, img2, changemasks) 44 | 45 | plt.show() 46 | ``` 47 | 48 | ### Case 2: point query mode (segment change of interest) 49 | ```python 50 | import matplotlib.pyplot as plt 51 | from skimage.io import imread 52 | from torchange.models.segment_any_change import AnyChange, show_change_masks 53 | 54 | # initialize AnyChange 55 | m = AnyChange('vit_h', sam_checkpoint='./sam_vit_h_4b8939.pth') 56 | # customize the hyperparameters of SAM's mask generator 57 | m.make_mask_generator( 58 | points_per_side=32, 59 | stability_score_thresh=0.95, 60 | ) 61 | # customize your AnyChange's hyperparameters 62 | m.set_hyperparameters( 63 | change_confidence_threshold=145, 64 | use_normalized_feature=True, 65 | bitemporal_match=True, 66 | object_sim_thresh=60, # for point query 67 | ) 68 | 69 | img1 = imread('https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/main/demo_images/t1_img.png') 70 | img2 = imread('https://raw.githubusercontent.com/Z-Zheng/pytorch-change-models/main/demo_images/t2_img.png')
71 | 72 | # parameter description: 73 | # xy: an absolute image coordinate. 74 | # temporal: indicate which time the point belongs to 75 | changemasks = m.single_point_match(xy=[926, 44], temporal=2, img1=img1, img2=img2) 76 | fig, axes = show_change_masks(img1, img2, changemasks) 77 | 78 | plt.show() 79 | ``` 80 | 81 | 82 | 83 | ## Citation 84 | If you find our project helpful, please cite our paper: 85 | ``` 86 | @inproceedings{ 87 | zheng2024anychange, 88 | title={Segment Any Change}, 89 | author={Zhuo Zheng and Yanfei Zhong and Liangpei Zhang and Stefano Ermon}, 90 | booktitle={Advances in Neural Information Processing Systems}, 91 | year={2024}, 92 | } 93 | ``` -------------------------------------------------------------------------------- /torchange/models/segment_any_change/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from torchange.models.segment_any_change.anychange import AnyChange 7 | from torchange.models.segment_any_change.viz import show_change_masks -------------------------------------------------------------------------------- /torchange/models/segment_any_change/anychange.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | 10 | import torch.nn.functional as F 11 | from torchange.models.segment_any_change.segment_anything.utils.amg import rle_to_mask, MaskData 12 | import copy 13 | from skimage.filters.thresholding import threshold_otsu 14 | from torchange.models.segment_any_change.base import SegmentAnyChange 15 | import math 16 | from torchvision.ops.boxes import batched_nms 17 | 18 | 19 | def angle2cosine(a): 20 | assert 0 <= a <= 180 21 | return math.cos(a / 180 * math.pi) 22 | 23 | 24 | def cosine2angle(c): 25 | assert -1 <= c <= 1 26 | return math.acos(c) * 180 / math.pi 27 | 28 | 29 | class AnyChange(SegmentAnyChange): 30 | def __init__(self, model_type='vit_b', sam_checkpoint='./sam_weights/sam_vit_b_01ec64.pth'): 31 | super().__init__(model_type, sam_checkpoint) 32 | 33 | layernorm = self.sam.image_encoder.neck[3] 34 | w = layernorm.weight.data 35 | b = layernorm.bias.data 36 | w = w.reshape(w.size(0), 1, 1) 37 | b = b.reshape(b.size(0), 1, 1) 38 | 39 | self.inv_transform = lambda e: (e - b) / w 40 | 41 | def set_hyperparameters( 42 | self, 43 | change_confidence_threshold=155, 44 | auto_threshold=False, 45 | use_normalized_feature=True, 46 | area_thresh=0.8, 47 | match_hist=False, 48 | object_sim_thresh=60, 49 | bitemporal_match=True, 50 | ): 51 | self.area_thresh = area_thresh 52 | self.match_hist = match_hist 53 | 54 | self.change_confidence_threshold = change_confidence_threshold 55 | self.auto_threshold = auto_threshold 56 | self.use_normalized_feature = use_normalized_feature 57 | self.object_sim_thresh = object_sim_thresh 58 | self.use_bitemporal_match = bitemporal_match 59 | 60 | def bitemporal_match(self, t1_mask_data, t1_image_embedding, t2_mask_data, t2_image_embedding) -> MaskData: 61 | t1_img_embed = t1_image_embedding 62 | t2_img_embed = t2_image_embedding 63 | h, w = self.embed_data1['original_size'] 64 | 65 | seq_img_embed = [t1_img_embed, t2_img_embed] 66 | seq_img_embed_data = [{'image_embedding': img_embed, 67 
| 'original_size': self.embed_data1['original_size']} 68 | for img_embed in seq_img_embed] 69 | 70 | seq_mask_data = [t1_mask_data, ] 71 | for img_embed_data in seq_img_embed_data[1:-1]: 72 | mask_data = self.maskgen.generate_with_image_embedding(**img_embed_data) 73 | mask_data.filter((mask_data['areas'] / (h * w)) < self.area_thresh) 74 | seq_mask_data.append(mask_data) 75 | 76 | seq_mask_data.append(t2_mask_data) 77 | 78 | if self.use_normalized_feature: 79 | t1_img_embed = self.inv_transform(t1_img_embed) 80 | t2_img_embed = self.inv_transform(t2_img_embed) 81 | 82 | t1_img_embed = F.interpolate(t1_img_embed, size=(h, w), mode='bilinear', align_corners=True) 83 | t2_img_embed = F.interpolate(t2_img_embed, size=(h, w), mode='bilinear', align_corners=True) 84 | t1_img_embed = t1_img_embed.squeeze_(0) 85 | t2_img_embed = t2_img_embed.squeeze_(0) 86 | 87 | if self.auto_threshold: 88 | cosv = -F.cosine_similarity(t1_img_embed, t2_img_embed, dim=0) 89 | cosv = cosv.reshape(-1).cpu().numpy() 90 | threshold = threshold_otsu(cosv, cosv.shape[0]) 91 | self.change_confidence_threshold = cosine2angle(threshold) 92 | 93 | def _latent_match(mask_data, t1_img_embed, t2_img_embed): 94 | change_confidence = torch.zeros(len(mask_data['rles']), dtype=torch.float32, device=self.device) 95 | for i, rle in enumerate(mask_data['rles']): 96 | bmask = torch.from_numpy(rle_to_mask(rle)).to(self.device) 97 | t1_mask_embed = torch.mean(t1_img_embed[:, bmask], dim=-1) 98 | t2_mask_embed = torch.mean(t2_img_embed[:, bmask], dim=-1) 99 | score = -F.cosine_similarity(t1_mask_embed, t2_mask_embed, dim=0) 100 | change_confidence[i] += score 101 | 102 | keep = change_confidence > angle2cosine(self.change_confidence_threshold) 103 | 104 | mask_data = copy.deepcopy(mask_data) 105 | mask_data['change_confidence'] = change_confidence 106 | mask_data.filter(keep) 107 | return mask_data 108 | 109 | changemasks = MaskData() 110 | if self.use_bitemporal_match: 111 | for i in range(2): 112 | cmasks = 
_latent_match(seq_mask_data[i], t1_img_embed, t2_img_embed) 113 | changemasks.cat(cmasks) 114 | else: 115 | cmasks = _latent_match(seq_mask_data[1], t1_img_embed, t2_img_embed) 116 | changemasks.cat(cmasks) 117 | del cmasks 118 | 119 | return changemasks 120 | 121 | def single_point_q_mask(self, xy, img): 122 | point = np.array(xy).reshape(1, 2) 123 | 124 | embed_data = self.maskgen.image_encoder(img) 125 | 126 | embed_data.update(dict(points=point)) 127 | mask_data = self.maskgen.embedding_point_to_mask(**embed_data) 128 | 129 | if len(mask_data['rles']) > 0: 130 | q_mask = torch.from_numpy(rle_to_mask(mask_data['rles'][0])) 131 | else: 132 | q_mask = torch.zeros(img.shape[0], img.shape[1]) 133 | return q_mask 134 | 135 | def single_point_match(self, xy, temporal, img1, img2): 136 | h, w = img1.shape[:2] 137 | point = np.array(xy).reshape(1, 2) 138 | 139 | embed_data1 = self.maskgen.image_encoder(img1) 140 | embed_data2 = self.maskgen.image_encoder(img2) 141 | self.embed_data1 = embed_data1 142 | self.embed_data2 = embed_data2 143 | 144 | mask_data1 = self.maskgen.generate_with_image_embedding(**embed_data1) 145 | mask_data2 = self.maskgen.generate_with_image_embedding(**embed_data2) 146 | mask_data1.filter((mask_data1['areas'] / (h * w)) < self.area_thresh) 147 | mask_data2.filter((mask_data2['areas'] / (h * w)) < self.area_thresh) 148 | 149 | if temporal == 1: 150 | embed_data1.update(dict(points=point)) 151 | mask_data = self.maskgen.embedding_point_to_mask(**embed_data1) 152 | elif temporal == 2: 153 | embed_data2.update(dict(points=point)) 154 | mask_data = self.maskgen.embedding_point_to_mask(**embed_data2) 155 | else: 156 | raise ValueError 157 | 158 | q_area = mask_data['areas'][0] 159 | q_mask = torch.from_numpy(rle_to_mask(mask_data['rles'][0])) 160 | 161 | image_embedding1 = F.interpolate(embed_data1['image_embedding'], (h, w), mode='bilinear', 162 | align_corners=True).squeeze_(0) 163 | image_embedding2 = F.interpolate(embed_data2['image_embedding'], 
(h, w), mode='bilinear', 164 | align_corners=True).squeeze_(0) 165 | 166 | if temporal == 1: 167 | q_mask_features = torch.mean(image_embedding1[:, q_mask], dim=-1) 168 | elif temporal == 2: 169 | q_mask_features = torch.mean(image_embedding2[:, q_mask], dim=-1) 170 | else: 171 | raise ValueError 172 | 173 | cosmap1 = torch.cosine_similarity(q_mask_features.reshape(-1, 1, 1), image_embedding1, dim=0) 174 | cosmap2 = torch.cosine_similarity(q_mask_features.reshape(-1, 1, 1), image_embedding2, dim=0) 175 | 176 | obj_map1 = cosmap1 > angle2cosine(self.object_sim_thresh) 177 | obj_map2 = cosmap2 > angle2cosine(self.object_sim_thresh) 178 | 179 | def _filter_obj(obj_map, mask_data): 180 | mask_data = copy.deepcopy(mask_data) 181 | keep = (q_area * 0.25 < mask_data['areas']) & (mask_data['areas'] < q_area * 4) 182 | mask_data.filter(keep) 183 | keep = [] 184 | for i, rle in enumerate(mask_data['rles']): 185 | mask = rle_to_mask(rle) 186 | keep.append(np.mean(obj_map[mask]) > 0.5) 187 | keep = torch.from_numpy(np.array(keep)).to(torch.bool) 188 | mask_data.filter(keep) 189 | return mask_data 190 | 191 | mask_data1 = _filter_obj(obj_map1.cpu().numpy(), mask_data1) 192 | mask_data2 = _filter_obj(obj_map2.cpu().numpy(), mask_data2) 193 | 194 | data = { 195 | 't1_mask_data': mask_data1, 196 | 't1_image_embedding': embed_data1['image_embedding'], 197 | 't2_mask_data': mask_data2, 198 | 't2_image_embedding': embed_data2['image_embedding'], 199 | } 200 | cmasks = self.bitemporal_match(**data) 201 | 202 | keep = batched_nms( 203 | cmasks["boxes"].float(), 204 | cmasks["iou_preds"], 205 | torch.zeros_like(cmasks["boxes"][:, 0]), 206 | iou_threshold=self.maskgen.box_nms_thresh, 207 | ) 208 | cmasks.filter(keep) 209 | if len(cmasks['rles']) > 1000: 210 | scores = cmasks['change_confidence'] 211 | sorted_scores, _ = torch.sort(scores, descending=True, stable=True) 212 | keep = scores > sorted_scores[1000] 213 | cmasks.filter(keep) 214 | 215 | return cmasks 216 | 217 | def 
multi_points_match(self, xyts, img1, img2): 218 | h, w = img1.shape[:2] 219 | 220 | embed_data1 = self.maskgen.image_encoder(img1) 221 | embed_data2 = self.maskgen.image_encoder(img2) 222 | self.embed_data1 = embed_data1 223 | self.embed_data2 = embed_data2 224 | 225 | mask_data1 = self.maskgen.generate_with_image_embedding(**embed_data1) 226 | mask_data2 = self.maskgen.generate_with_image_embedding(**embed_data2) 227 | mask_data1.filter((mask_data1['areas'] / (h * w)) < self.area_thresh) 228 | mask_data2.filter((mask_data2['areas'] / (h * w)) < self.area_thresh) 229 | 230 | image_embedding1 = F.interpolate(embed_data1['image_embedding'], (h, w), mode='bilinear', 231 | align_corners=True).squeeze_(0) 232 | image_embedding2 = F.interpolate(embed_data2['image_embedding'], (h, w), mode='bilinear', 233 | align_corners=True).squeeze_(0) 234 | 235 | q_areas = [] 236 | q_features = [] 237 | for xyt in xyts: 238 | t = xyt[-1] 239 | point = xyt[:2].reshape(1, 2) 240 | 241 | if t == 1: 242 | embed_data1.update(dict(points=point)) 243 | mask_data = self.maskgen.embedding_point_to_mask(**embed_data1) 244 | elif t == 2: 245 | embed_data2.update(dict(points=point)) 246 | mask_data = self.maskgen.embedding_point_to_mask(**embed_data2) 247 | else: 248 | raise ValueError 249 | 250 | q_area = mask_data['areas'][0] 251 | q_mask = torch.from_numpy(rle_to_mask(mask_data['rles'][0])) 252 | 253 | q_areas.append(q_area) 254 | 255 | if t == 1: 256 | q_mask_features = torch.mean(image_embedding1[:, q_mask], dim=-1) 257 | elif t == 2: 258 | q_mask_features = torch.mean(image_embedding2[:, q_mask], dim=-1) 259 | else: 260 | raise ValueError 261 | q_features.append(q_mask_features) 262 | 263 | q_area = sum(q_areas) / len(q_areas) 264 | q_mask_features = sum(q_features) / len(q_features) 265 | 266 | cosmap1 = torch.cosine_similarity(q_mask_features.reshape(-1, 1, 1), image_embedding1, dim=0) 267 | cosmap2 = torch.cosine_similarity(q_mask_features.reshape(-1, 1, 1), image_embedding2, dim=0) 268 
| 269 | obj_map1 = cosmap1 > angle2cosine(self.object_sim_thresh) 270 | obj_map2 = cosmap2 > angle2cosine(self.object_sim_thresh) 271 | 272 | def _filter_obj(obj_map, mask_data): 273 | mask_data = copy.deepcopy(mask_data) 274 | keep = (q_area * 0.25 < mask_data['areas']) & (mask_data['areas'] < q_area * 4) 275 | mask_data.filter(keep) 276 | keep = [] 277 | for i, rle in enumerate(mask_data['rles']): 278 | mask = rle_to_mask(rle) 279 | keep.append(np.mean(obj_map[mask]) > 0.5) 280 | keep = torch.from_numpy(np.array(keep)).to(torch.bool) 281 | mask_data.filter(keep) 282 | return mask_data 283 | 284 | mask_data1 = _filter_obj(obj_map1.cpu().numpy(), mask_data1) 285 | mask_data2 = _filter_obj(obj_map2.cpu().numpy(), mask_data2) 286 | 287 | data = { 288 | 't1_mask_data': mask_data1, 289 | 't1_image_embedding': embed_data1['image_embedding'], 290 | 't2_mask_data': mask_data2, 291 | 't2_image_embedding': embed_data2['image_embedding'], 292 | } 293 | cmasks = self.bitemporal_match(**data) 294 | 295 | keep = batched_nms( 296 | cmasks["boxes"].float(), 297 | cmasks["iou_preds"], 298 | torch.zeros_like(cmasks["boxes"][:, 0]), 299 | iou_threshold=self.maskgen.box_nms_thresh, 300 | ) 301 | cmasks.filter(keep) 302 | if len(cmasks['rles']) > 1000: 303 | scores = cmasks['change_confidence'] 304 | sorted_scores, _ = torch.sort(scores, descending=True, stable=True) 305 | keep = scores > sorted_scores[1000] 306 | cmasks.filter(keep) 307 | 308 | return cmasks 309 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Zhuo Zheng and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | 10 | from skimage.exposure import match_histograms 11 | 12 | from torchvision.ops.boxes import batched_nms 13 | from torchange.models.segment_any_change.simple_maskgen import SimpleMaskGenerator 14 | from torchange.models.segment_any_change.segment_anything import sam_model_registry 15 | from torchange.models.segment_any_change.segment_anything.utils.amg import MaskData 16 | from safetensors.torch import load_file 17 | 18 | 19 | class SegmentAnyChange: 20 | def __init__(self, model_type='vit_b', sam_checkpoint='./sam_weights/sam_vit_b_01ec64.pth'): 21 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 22 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 23 | self.sam = sam.to(self.device) 24 | self.maskgen = SimpleMaskGenerator(self.sam) 25 | 26 | self.set_hyperparameters() 27 | 28 | self.embed_data1 = None 29 | self.embed_data2 = None 30 | 31 | def set_hyperparameters(self, **kwargs): 32 | self.match_hist = kwargs.get('match_hist', False) 33 | self.area_thresh = kwargs.get('area_thresh', 0.8) 34 | 35 | def make_mask_generator(self, **kwargs): 36 | self.maskgen = SimpleMaskGenerator(self.sam, **kwargs) 37 | 38 | def extract_image_embedding(self, img1, img2): 39 | self.embed_data1 = self.maskgen.image_encoder(img1) 40 | self.embed_data2 = self.maskgen.image_encoder(img2) 41 | return self.embed_data1, self.embed_data2 42 | 43 | def set_cached_embedding(self, embedding): 44 | data = embedding 45 | oh, ow = data['original_size'].numpy() 46 | h, w = data['input_size'] 47 | self.embed_data1 = { 48 | 'image_embedding': data['t1'].to(self.device), 49 | 'original_size': (oh, ow), 50 | } 51 | 52 | self.embed_data2 = { 53 | 'image_embedding': data['t2'].to(self.device), 54 | 'original_size': (oh, ow), 55 | } 56 | self.maskgen.predictor.input_size = (h, w) 57 | self.maskgen.predictor.original_size = (oh, ow) 58 | 59 | def load_cached_embedding(self, filepath): 60 | data = 
load_file(filepath, device='cpu') 61 | self.set_cached_embedding(data) 62 | 63 | def clear_cached_embedding(self): 64 | self.embed_data1 = None 65 | self.embed_data2 = None 66 | self.maskgen.predictor.input_size = None 67 | self.maskgen.predictor.original_size = None 68 | 69 | def proposal(self, img1, img2): 70 | h, w = img1.shape[:2] 71 | if self.embed_data1 is None: 72 | self.extract_image_embedding(img1, img2) 73 | 74 | mask_data1 = self.maskgen.generate_with_image_embedding(**self.embed_data1) 75 | mask_data2 = self.maskgen.generate_with_image_embedding(**self.embed_data2) 76 | mask_data1.filter((mask_data1['areas'] / (h * w)) < self.area_thresh) 77 | mask_data2.filter((mask_data2['areas'] / (h * w)) < self.area_thresh) 78 | 79 | return { 80 | 't1_mask_data': mask_data1, 81 | 't1_image_embedding': self.embed_data1['image_embedding'], 82 | 't2_mask_data': mask_data2, 83 | 't2_image_embedding': self.embed_data2['image_embedding'], 84 | } 85 | 86 | def bitemporal_match(self, t1_mask_data, t1_image_embedding, t2_mask_data, t2_image_embedding) -> MaskData: 87 | return NotImplementedError 88 | 89 | def forward(self, img1, img2): 90 | h, w = img1.shape[:2] 91 | 92 | if self.match_hist: 93 | img2 = match_histograms(image=img2, reference=img1, channel_axis=-1).astype(np.uint8) 94 | 95 | data = self.proposal(img1, img2) 96 | 97 | changemasks = self.bitemporal_match(**data) 98 | 99 | keep = batched_nms( 100 | changemasks["boxes"].float(), 101 | changemasks["iou_preds"], 102 | torch.zeros_like(changemasks["boxes"][:, 0]), 103 | iou_threshold=self.maskgen.box_nms_thresh, 104 | ) 105 | changemasks.filter(keep) 106 | 107 | if len(changemasks['rles']) > 1000: 108 | scores = changemasks['change_confidence'] 109 | sorted_scores, _ = torch.sort(scores, descending=True, stable=True) 110 | keep = scores > sorted_scores[1000] 111 | changemasks.filter(keep) 112 | 113 | return changemasks, data['t1_mask_data'], data['t2_mask_data'] 114 | 115 | def to_eval_format_predictions(self, 
cmasks): 116 | boxes = cmasks['boxes'] 117 | rle_masks = cmasks['rles'] 118 | labels = torch.ones(boxes.size(0), dtype=torch.int64) 119 | scores = cmasks['change_confidence'] 120 | predictions = { 121 | 'boxes': boxes.to(torch.float32).cpu(), 122 | 'scores': scores.cpu(), 123 | 'labels': labels.cpu(), 124 | 'masks': rle_masks 125 | } 126 | return predictions 127 | 128 | def __call__(self, img1, img2): 129 | cmasks, t1_masks, t2_masks = self.forward(img1, img2) 130 | predictions = self.to_eval_format_predictions(cmasks) 131 | self.clear_cached_embedding() 132 | return predictions 133 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from ..segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | 
input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        # One extra token for the single-mask output (index 0, see forward).
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # Two kernel-2 / stride-2 transposed convs: 4x spatial upscaling and
        # transformer_dim -> transformer_dim // 8 channels.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        # One hypernetwork MLP per mask token, mapping it to transformer_dim // 8
        # weights that are dotted with the upscaled embedding.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        # Predicts one quality score per mask token from the IoU token.
        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        # (multimask: tokens 1..N; single mask: token 0 only)
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        # Token order per prompt: [iou_token, mask_tokens..., sparse prompt tokens...]
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        # (tokens.shape[0] is the number of prompts in the batch)
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        # Outputs follow the input token order: IoU token first, then the mask tokens.
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        # Per-pixel dot product between each token's weights and the upscaled features.
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    """Stack of ``num_layers`` Linear layers with ReLU between them (none after the last)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
-------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        # Added to padding points (label -1) in _embed_points.
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Input masks are 4x the embedding size; the two stride-2 convs below bring
        # them back down to the embedding grid.
        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding used when no mask prompt is given (see forward).
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        With ``pad`` (set by forward when no boxes are given), a dummy point with
        label -1 is appended to every prompt. Label -1 gets ``not_a_point_embed``
        instead of a positional encoding. Labels 0 and 1 add ``point_embeddings[0]``
        and ``[1]``; presumably background/foreground per SAM convention.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts as two corner tokens each (point_embeddings[2] and [3])."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            # Pad with a dummy point only when there are no box tokens.
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask prompt: broadcast the learned no-mask embedding over the grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ...
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 
61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 
96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 
29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = 
self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, 
dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from ..segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 
84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 
166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 
203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 
254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | 
"""Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx: idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | # import matplotlib.pyplot as plt 167 | # fig, axes = plt.subplots(4, 4, figsize=(12, 12)) 168 | # axes = axes.reshape(-1) 169 | # _mask = masks.numpy() 170 | # for i, ax in enumerate(axes): 171 | # ax.imshow(_mask[i]) 172 | # ax.axis('off') 173 | # plt.tight_layout() 174 | # plt.show() 175 | # exit(0) 176 | intersections = ( 177 | (masks > (mask_threshold + threshold_offset)) 178 | .sum(-1, dtype=torch.int16) 179 | .sum(-1, dtype=torch.int32) 180 | ) 181 | unions = ( 182 | (masks > (mask_threshold - threshold_offset)) 183 | .sum(-1, dtype=torch.int16) 184 | .sum(-1, dtype=torch.int32) 185 | ) 186 | return intersections / unions 187 | 188 | 189 | def build_point_grid(n_per_side: int) -> np.ndarray: 190 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 191 | offset = 1 / (2 * n_per_side) 192 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 193 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 194 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 195 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 196 | return points 197 | 198 | 199 | def build_all_layer_point_grids( 200 | n_per_side: int, n_layers: int, scale_per_layer: int 201 | ) -> List[np.ndarray]: 202 | """Generates point grids for all crop layers.""" 203 | points_by_layer = [] 204 | for i in range(n_layers + 1): 205 | n_points = int(n_per_side / (scale_per_layer ** i)) 206 | points_by_layer.append(build_point_grid(n_points)) 207 | return points_by_layer 208 | 209 | 210 | def generate_crop_boxes( 211 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 212 | ) -> Tuple[List[List[int]], List[int]]: 213 | """ 214 | Generates a list of crop boxes of different sizes. Each layer 215 | has (2**i)**2 boxes for the ith layer. 
216 | """ 217 | crop_boxes, layer_idxs = [], [] 218 | im_h, im_w = im_size 219 | short_side = min(im_h, im_w) 220 | 221 | # Original image 222 | crop_boxes.append([0, 0, im_w, im_h]) 223 | layer_idxs.append(0) 224 | 225 | def crop_len(orig_len, n_crops, overlap): 226 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 227 | 228 | for i_layer in range(n_layers): 229 | n_crops_per_side = 2 ** (i_layer + 1) 230 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 231 | 232 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 233 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 234 | 235 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 236 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 237 | 238 | # Crops in XYWH format 239 | for x0, y0 in product(crop_box_x0, crop_box_y0): 240 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 241 | crop_boxes.append(box) 242 | layer_idxs.append(i_layer + 1) 243 | 244 | return crop_boxes, layer_idxs 245 | 246 | 247 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 248 | x0, y0, _, _ = crop_box 249 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 250 | # Check if boxes has a channel dimension 251 | if len(boxes.shape) == 3: 252 | offset = offset.unsqueeze(1) 253 | return boxes + offset 254 | 255 | 256 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 257 | x0, y0, _, _ = crop_box 258 | offset = torch.tensor([[x0, y0]], device=points.device) 259 | # Check if points has a channel dimension 260 | if len(points.shape) == 3: 261 | offset = offset.unsqueeze(1) 262 | return points + offset 263 | 264 | 265 | def uncrop_masks( 266 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 267 | ) -> torch.Tensor: 268 | x0, y0, x1, y1 = crop_box 269 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 270 | return masks 271 | # 
Coordinate transform masks 272 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 273 | pad = (x0, pad_x - x0, y0, pad_y - y0) 274 | return torch.nn.functional.pad(masks, pad, value=0) 275 | 276 | 277 | def remove_small_regions( 278 | mask: np.ndarray, area_thresh: float, mode: str 279 | ) -> Tuple[np.ndarray, bool]: 280 | """ 281 | Removes small disconnected regions and holes in a mask. Returns the 282 | mask and an indicator of if the mask has been modified. 283 | """ 284 | import cv2 # type: ignore 285 | 286 | assert mode in ["holes", "islands"] 287 | correct_holes = mode == "holes" 288 | working_mask = (correct_holes ^ mask).astype(np.uint8) 289 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 290 | sizes = stats[:, -1][1:] # Row 0 is background label 291 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 292 | if len(small_regions) == 0: 293 | return mask, False 294 | fill_labels = [0] + small_regions 295 | if not correct_holes: 296 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 297 | # If every region is below threshold, keep largest 298 | if len(fill_labels) == 0: 299 | fill_labels = [int(np.argmax(sizes)) + 1] 300 | mask = np.isin(regions, fill_labels) 301 | return mask, True 302 | 303 | 304 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 305 | from pycocotools import mask as mask_utils # type: ignore 306 | 307 | h, w = uncompressed_rle["size"] 308 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 309 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 310 | return rle 311 | 312 | 313 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 314 | """ 315 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 316 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
317 | """ 318 | # torch.max below raises an error on empty inputs, just skip in this case 319 | if torch.numel(masks) == 0: 320 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 321 | 322 | # Normalize shape to CxHxW 323 | shape = masks.shape 324 | h, w = shape[-2:] 325 | if len(shape) > 2: 326 | masks = masks.flatten(0, -3) 327 | else: 328 | masks = masks.unsqueeze(0) 329 | 330 | # Get top and bottom edges 331 | in_height, _ = torch.max(masks, dim=-1) 332 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 333 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 334 | in_height_coords = in_height_coords + h * (~in_height) 335 | top_edges, _ = torch.min(in_height_coords, dim=-1) 336 | 337 | # Get left and right edges 338 | in_width, _ = torch.max(masks, dim=-2) 339 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 340 | right_edges, _ = torch.max(in_width_coords, dim=-1) 341 | in_width_coords = in_width_coords + w * (~in_width) 342 | left_edges, _ = torch.min(in_width_coords, dim=-1) 343 | 344 | # If the mask is empty the right edge will be to the left of the left edge. 345 | # Replace these boxes with [0, 0, 0, 0] 346 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 347 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 348 | out = out * (~empty_filter).unsqueeze(-1) 349 | 350 | # Return to original shape 351 | if len(shape) > 2: 352 | out = out.reshape(*shape[:-2], 4) 353 | else: 354 | out = out[0] 355 | 356 | return out 357 | -------------------------------------------------------------------------------- /torchange/models/segment_any_change/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned (a single best mask and/or
    extra metrics). See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        """
        Args:
            model: the Sam model whose prompt encoder / mask decoder are wrapped.
            return_single_mask: if True, ``forward`` keeps only one mask per
                prompt, chosen by ``select_masks``.
            use_stability_score: if True, the decoder's IoU scores are replaced
                by stability scores computed from the low-resolution masks.
            return_extra_metrics: if True, ``forward`` also returns stability
                scores and areas of the upscaled masks.
        """
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        # Logit offset passed to calculate_stability_score.
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """Scale a size tensor so its largest entry equals ``longest_side``.

        The result is rounded half-up (floor(x + 0.5)) and returned as int64.
        """
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """Embed point prompts without data-dependent indexing.

        A label of -1 selects ``not_a_point_embed``; a label ``i`` in
        ``[0, num_point_embeddings)`` adds ``point_embeddings[i]``. Selection is
        done by multiplying with boolean masks rather than by indexing, in line
        with the tracing-friendly rewrite described in the class docstring.
        """
        # NOTE(review): +0.5 then /img_size presumably maps pixel coordinates in
        # the img_size-resized frame to normalized pixel-center coordinates — confirm.
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        # Broadcast labels over the embedding dimension so they can act as masks.
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Zero out the positional encoding of -1 ("not a point") entries ...
        point_embedding = point_embedding * (point_labels != -1)
        # ... and add the learned "not a point" embedding to them instead.
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        # Add the per-label learned embedding for every label value.
        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        """Blend the downscaled mask prompt with ``no_mask_embed``.

        ``has_mask_input`` is used as a multiplicative weight, presumably a 0/1
        flag broadcastable against the embedding (verify against the export
        script): the mask prompt is used when it is 1 and ``no_mask_embed``
        when it is 0, without branching.
        """
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        """Upscale decoder masks to ``orig_im_size``.

        Steps: resize to ``img_size x img_size``, crop away the region beyond
        the longest-side-resized image size, then resize to ``orig_im_size``
        (read as ``(h, w)``).
        """
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        # Size of the image before padding to img_size x img_size.
        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep one mask (and its score) per batch entry, chosen by argmax.

        The first mask token's score is shifted by ``1000 * (num_points - 2.5)``:
        with 3 or more points it wins the argmax, and with fewer it is pushed
        below the others.
        """
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """Predict masks from a precomputed image embedding and prompts.

        Returns:
            ``(upscaled_masks, scores, masks)``. When ``return_extra_metrics``
            is set, it returns
            ``(upscaled_masks, scores, stability_scores, areas, masks)``
            instead, where ``areas`` counts pixels above ``mask_threshold`` in
            the upscaled masks. ``masks`` are the decoder outputs before
            upscaling.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            # point_coords.shape[1] is used as the number of points per prompt.
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks

# --------------------------------------------------------------------------------
# /torchange/models/segment_any_change/segment_anything/utils/transforms.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

# --------------------------------------------------------------------------------
# /torchange/models/segment_any_change/simple_maskgen.py:
# --------------------------------------------------------------------------------
# Copyright (c) Zhuo Zheng and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import numpy as np
from .segment_anything import SamPredictor
from .segment_anything.utils.amg import build_point_grid
from .segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    calculate_stability_score,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)
from torchvision.ops.boxes import batched_nms

__all__ = [
    'SimpleMaskGenerator',
]


class SimpleMaskGenerator:
    """Single-crop automatic mask generator built on a SAM predictor.

    The whole image is prompted with a grid of points (no multi-crop pyramid).
    Predicted masks are filtered by predicted IoU, stability score, crop-edge
    proximity and box NMS, and are stored as uncompressed RLEs in a MaskData.
    """

    def __init__(
        self,
        model,
        points_per_side=32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.5,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        point_grids=None,
        min_mask_region_area: int = 0,
    ):
        """
        Args:
            model: SAM model; wrapped in a ``SamPredictor``.
            points_per_side: number of grid points per image side, used to
                build the prompt grid. Exactly one of this and ``point_grids``
                must be given.
            points_per_batch: number of point prompts decoded per batch.
            pred_iou_thresh: masks with predicted IoU <= this value are
                dropped (filter disabled when <= 0).
            stability_score_thresh: masks with stability score < this value
                are dropped (filter disabled when <= 0).
            stability_score_offset: logit offset used for the stability score.
            box_nms_thresh: IoU threshold of the box NMS between masks.
            point_grids: explicit grid of normalized (x, y) points; it is
                multiplied by the image (w, h) to get pixel coordinates.
            min_mask_region_area: if > 0, disconnected regions and holes
                smaller than this area are removed from each mask.
        """
        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.min_mask_region_area = min_mask_region_area

        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_point_grid(points_per_side)
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

    @torch.no_grad()
    def image_encoder(self, image):
        """Set ``image`` on the predictor and return its embedding.

        ``image`` is presumably an HxWxC uint8 numpy array, the format
        ``ResizeLongestSide.apply_image`` expects — confirm with the caller.

        Returns:
            dict with ``'image_embedding'`` and ``'original_size'``
            (``image.shape[:2]``), directly usable as keyword arguments of
            :meth:`generate_with_image_embedding`.
        """
        orig_size = image.shape[:2]
        self.predictor.set_image(image)
        return {
            'image_embedding': self.predictor.get_image_embedding(),
            'original_size': orig_size,
        }

    @torch.no_grad()
    def generate_with_image_embedding(self, image_embedding, original_size):
        """Generate masks for an image from its precomputed embedding.

        Args:
            image_embedding: embedding as returned by :meth:`image_encoder`.
            original_size: (H, W) of the image the embedding was computed from.

        Returns:
            MaskData with ``rles``, ``boxes``, ``points``, ``iou_preds``,
            ``stability_score``, ``crop_boxes`` and ``areas``.
        """
        im_h, im_w = original_size
        # Scale the normalized (x, y) grid to pixel coordinates (x * w, y * h).
        points_scale = np.array(original_size)[None, ::-1]
        points_for_image = self.point_grids * points_scale

        data = MaskData()
        crop_box = [0, 0, im_w, im_h]
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(image_embedding, points, original_size, crop_box, original_size)
            data.cat(batch_data)
            del batch_data

        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        self._to_original_frame(data, crop_box)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            data.to_numpy()
            # Only one full-image crop is used, so the box NMS threshold is the
            # only NMS threshold that applies here (there is no crop-level NMS).
            data = self.postprocess_small_regions(
                data,
                self.min_mask_region_area,
                self.box_nms_thresh,
            )

        self._attach_areas(data)
        return data

    @torch.no_grad()
    def generate_with_points(self, image, points):
        """Encode ``image`` and predict masks for the given pixel ``points``."""
        image_embedding_data = self.image_encoder(image)
        image_embedding_data.update(dict(points=points))
        return self.embedding_point_to_mask(**image_embedding_data)

    @torch.no_grad()
    def embedding_point_to_mask(self, image_embedding, original_size, points):
        """Predict masks for pixel ``points`` on a precomputed embedding.

        No NMS or small-region post-processing is applied here.
        """
        h, w = original_size
        crop_box = [0, 0, w, h]
        data = self._process_batch(image_embedding, points, (h, w), crop_box, (h, w))

        self._to_original_frame(data, crop_box)
        self._attach_areas(data)
        return data

    @torch.no_grad()
    def generate(self, image, mask_output_mode='rle'):
        """Generate masks for ``image``.

        Args:
            image: input image passed to :meth:`image_encoder`.
            mask_output_mode: ``'rle'`` stores the RLE dicts in
                ``segmentations``; ``'binary_mask'`` stores an (N, H, W) bool
                array (a tensor when boxes are tensors).

        Returns:
            dict with ``'mask_data'`` (MaskData) and ``'image_embedding'``,
            the embedding bilinearly resized to the image size with the
            leading batch dimension squeezed.

        Raises:
            ValueError: if ``mask_output_mode`` is not supported. This is
                checked before any model work is done.
        """
        if mask_output_mode not in ('rle', 'binary_mask'):
            raise ValueError(
                f"Unsupported mask_output_mode: {mask_output_mode!r}; expected 'rle' or 'binary_mask'."
            )

        image_embedding_data = self.image_encoder(image)
        data = self.generate_with_image_embedding(**image_embedding_data)
        orig_size = image_embedding_data['original_size']

        if mask_output_mode == 'rle':
            data["segmentations"] = data["rles"]
        else:
            binary_masks = [rle_to_mask(rle) for rle in data["rles"]]
            if binary_masks:
                segmentations = np.stack(binary_masks, axis=0)
            else:
                # np.stack raises on an empty list; return an empty (0, H, W) stack instead.
                segmentations = np.zeros((0, *orig_size), dtype=bool)
            if isinstance(data['boxes'], torch.Tensor):
                segmentations = torch.from_numpy(segmentations)
            data["segmentations"] = segmentations

        image_embedding = image_embedding_data['image_embedding']
        image_embedding = F.interpolate(image_embedding, size=orig_size, mode='bilinear', align_corners=True)
        image_embedding = image_embedding.squeeze_(0)
        return {
            'mask_data': data,
            'image_embedding': image_embedding
        }

    @staticmethod
    def _to_original_frame(data: MaskData, crop_box) -> None:
        """Shift boxes/points from crop to image coordinates and record each mask's crop box."""
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        # repeat() keeps a (0, 4) shape when no masks are left.
        data["crop_boxes"] = torch.as_tensor(crop_box).reshape(1, 4).repeat(len(data["rles"]), 1)

    @staticmethod
    def _attach_areas(data: MaskData) -> None:
        """Store per-mask pixel areas computed from the RLEs (a tensor when boxes are tensors)."""
        areas = np.asarray([area_from_rle(rle) for rle in data['rles']])
        if isinstance(data['boxes'], torch.Tensor):
            areas = torch.from_numpy(areas)
        data['areas'] = areas

    def _process_batch(
        self,
        image_embedding,
        points: np.ndarray,
        im_size,
        crop_box,
        orig_size,
    ) -> MaskData:
        """Decode masks for one batch of pixel ``points`` and filter them.

        Args:
            image_embedding: embedding of the image of size ``im_size``.
            points: (N, 2) pixel (x, y) prompts in the ``im_size`` frame.
            im_size: (H, W) of the image the embedding was computed from.
            crop_box: [x0, y0, x1, y1] crop the prompts live in.
            orig_size: (H, W) of the full image.

        Returns:
            MaskData with ``iou_preds``, ``points``, ``stability_score``,
            ``boxes`` and ``rles`` (binary masks are dropped after RLE encoding).
        """
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self.predict_torch(
            self.predictor,
            image_embedding,
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
            original_size=im_size,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms.tolist():
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def predict_torch(
        self,
        predictor,
        image_embedding,
        point_coords,
        point_labels,
        boxes=None,
        mask_input=None,
        multimask_output: bool = True,
        return_logits: bool = False,
        original_size=None,
    ):
        """Run SAM's prompt encoder and mask decoder on a precomputed embedding.

        Args:
            predictor: SamPredictor providing the model and resize transform.
            image_embedding: image embedding from the image encoder.
            point_coords: point prompts in the resized model-input frame (as
                produced by ``predictor.transform.apply_coords``), or None.
            point_labels: labels for ``point_coords``.
            boxes: optional box prompts.
            mask_input: optional low-resolution mask prompts.
            multimask_output: whether the decoder returns several masks per prompt.
            return_logits: if False, masks are thresholded at ``mask_threshold``.
            original_size: (H, W) of the image ``image_embedding`` belongs to.
                If None, the predictor's stored ``input_size`` /
                ``original_size`` are used. Those presumably come from its most
                recent ``set_image`` call, so they are only correct for that
                image's embedding.

        Returns:
            (masks, iou_predictions, low_res_masks)
        """
        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = predictor.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = predictor.model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=predictor.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        if original_size is None:
            input_size, original_size = predictor.input_size, predictor.original_size
        else:
            # Derive the resized (pre-padding) input size from the embedding's own
            # image size instead of relying on the predictor's last set_image.
            h, w = original_size
            input_size = predictor.transform.get_preprocess_shape(h, w, predictor.transform.target_length)

        # Upscale the masks to the original image resolution
        masks = predictor.model.postprocess_masks(low_res_masks, input_size, original_size)

        if not return_logits:
            masks = masks > predictor.model.mask_threshold

        return masks, iou_predictions, low_res_masks

# --------------------------------------------------------------------------------
# /torchange/models/segment_any_change/viz.py:
# --------------------------------------------------------------------------------
# Copyright (c) Zhuo Zheng and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from skimage.segmentation import find_boundaries
from .segment_anything.utils.amg import (
    area_from_rle,
    box_xyxy_to_xywh,
    rle_to_mask,
    MaskData
)
import matplotlib.pyplot as plt
import numpy as np


def show_mask_data(mask_data, ax=None, rng=None):
    """Overlay every mask stored in a ``MaskData`` onto a matplotlib axis.

    Masks are decoded from their RLEs and painted largest-first (by RLE area),
    so smaller masks end up on top. Each mask is filled with a random RGB
    colour at alpha 0.35 and its boundary (``find_boundaries``) is drawn in
    cyan at alpha 0.8. The result is a single RGBA overlay passed to
    ``ax.imshow``.

    Args:
        mask_data: ``MaskData`` holding an ``"rles"`` list and, optionally,
            ``"boxes"`` (xyxy, converted to xywh here).
        ax: axis to draw on; defaults to ``plt.gca()``.
        rng: optional object with a ``random(size)`` method used to pick the
            fill colours, e.g. ``np.random.default_rng(seed)``. When ``None``
            the global ``np.random`` state is used (the previous behaviour).
            Passing identically seeded generators to several calls gives each
            mask the same colour in every call.

    Returns:
        None. Returns early, drawing nothing, when ``mask_data`` has no masks.
    """
    assert isinstance(mask_data, MaskData)
    draw_uniform = np.random.random if rng is None else rng.random

    has_boxes = 'boxes' in mask_data._stats
    anns = []
    for idx, rle in enumerate(mask_data["rles"]):
        ann_i = {
            "segmentation": rle_to_mask(rle),
            "area": area_from_rle(rle),
        }
        if has_boxes:
            ann_i['bbox'] = box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist()
        anns.append(ann_i)

    if not anns:
        return
    # Largest first so that small masks are painted over large ones.
    sorted_anns = sorted(anns, key=lambda a: a['area'], reverse=True)
    if ax is None:
        ax = plt.gca()
    ax.set_autoscale_on(False)

    height = sorted_anns[0]['segmentation'].shape[0]
    width = sorted_anns[0]['segmentation'].shape[1]
    # RGBA overlay that stays fully transparent where no mask is painted.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    color_boundary = np.array([0., 1., 1., 0.8])
    for ann in sorted_anns:
        m = ann['segmentation']
        boundary = find_boundaries(m)
        color_mask = np.concatenate([draw_uniform(3), [0.35]])
        img[m] = color_mask
        img[boundary] = color_boundary

        # NOTE(review): no annotation built above carries a 'label' key, so
        # this branch is currently never taken.
        if 'label' in ann:
            x, y, w, h = ann['bbox']
            ax.text(
                x + w / 2,
                y + h / 2,
                ann['label'],
                bbox={
                    'facecolor': 'black',
                    'alpha': 0.8,
                    'pad': 0.7,
                    'edgecolor': 'none'
                },
                color='red',
                fontsize=4,
                verticalalignment='top',
                horizontalalignment='left'
            )
    ax.imshow(img)


def show_change_masks(img1, img2, change_masks, seed=None):
    """Plot ``change_masks`` over the t1 image, the t2 image and a blank canvas.

    Args:
        img1: first (t1) image, anything ``Axes.imshow`` accepts.
        img2: second (t2) image.
        change_masks: ``MaskData`` drawn by :func:`show_mask_data` on each panel.
        seed: optional integer seeding the mask colours. Every panel gets a
            generator seeded with this value, so each mask keeps one colour
            across all three panels. When ``None`` a seed is drawn from the
            global ``np.random`` state.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots(1, 3, ...)``.
    """
    if seed is None:
        seed = np.random.randint(0, 2 ** 31 - 1)
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)
    # Third panel: constant 255 canvas (white for uint8 input; imshow clips
    # float input, presumably also rendering white -- confirm for your dtype).
    backgrounds = (img1, img2, 255 * np.ones_like(img1))
    for ax, background in zip(axes, backgrounds):
        ax.imshow(background)
        show_mask_data(change_masks, ax, rng=np.random.default_rng(seed))

    for ax in axes:
        ax.axis('off')

    return fig, axes

# --------------------------------------------------------------------------------
# /torchange/module/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Zhuo Zheng and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# /torchange/module/farseg.py:
# --------------------------------------------------------------------------------
# Copyright (c) Zhuo Zheng and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import ever as er
import ever.module as M

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchange.module._sam_vit import SAMEncoder


@er.registry.MODEL.register()
class FarSegEncoder(M.ResNetEncoder):
    """ResNet backbone followed by FPN, scene relation (``M.FSRelation``) and decoder.

    Config keys added on top of ``M.ResNetEncoder``'s:
        out_channels: channels produced by the decoder (default 96).
    """

    def __init__(self, config):
        super().__init__(config)
        # Channels of the deepest backbone stage; the earlier stages are
        # taken to halve it at each step (see the FPN input list below).
        if self.config.resnet_type in ['resnet18', 'resnet34']:
            max_channels = 512
        else:
            max_channels = 2048
        # Four levels, finest first: max/8, max/4, max/2, max.
        self.fpn = M.FPN([max_channels // (2 ** (3 - i)) for i in range(4)], 256)
        self.fsr = M.FSRelation(max_channels,
                                [256 for _ in range(4)],
                                256,
                                True)
        self.dec = M.AssymetricDecoder(256,
                                       self.config.out_channels)

    def forward(self, inputs):
        """Return the decoder output for a batch of images."""
        features = super().forward(inputs)
        coarsest_features = features[-1]
        # Global average pool of the deepest level -> [N, C, 1, 1] scene embedding.
        scene_embedding = F.adaptive_avg_pool2d(coarsest_features, 1)
        features = self.fpn(features)
        features = self.fsr(scene_embedding, features)
        features = self.dec(features)

        return features

    def set_default_config(self):
        super().set_default_config()
        self.config.update(dict(
            out_channels=96,
        ))


class FSRelationV3(nn.Module):
    """Scene-conditioned re-weighting of a feature pyramid.

    For every level the module:
      1. encodes the level into a content embedding (``content_encoders``);
      2. scores each pixel by the channel-wise dot product of that embedding
         with an encoded scene feature, squashed by a sigmoid;
      3. multiplies a re-encoded copy of the level (``feature_reencoders``)
         by that score;
      4. concatenates the result with the ORIGINAL level features and
         projects back to ``out_channels``.

    Args:
        scene_embedding_dim: channels of the scene feature given to ``forward``.
        in_channels_list: channels of each input level.
        out_channels: channels of every output level.
        scale_aware_proj: if True each level has its own scene encoder and
            projection; otherwise one of each is shared by all levels.

    Raises:
        ValueError: if ``scale_aware_proj`` is False and ``in_channels_list``
            mixes channel counts (a single shared projection cannot accept
            differently sized inputs).
    """

    def __init__(
            self,
            scene_embedding_dim,
            in_channels_list,
            out_channels,
            scale_aware_proj=False,
    ):
        super().__init__()
        self.scale_aware_proj = scale_aware_proj

        # The projection input is cat([re-weighted, original]) -> out + c
        # channels. When c == out_channels this equals the former
        # out_channels * 2, so parameter shapes (and checkpoints) are unchanged.
        if scale_aware_proj:
            self.scene_encoder = nn.ModuleList(
                [self._make_scene_encoder(scene_embedding_dim, out_channels)
                 for _ in range(len(in_channels_list))]
            )
            self.project = nn.ModuleList(
                [self._make_projection(out_channels + c, out_channels)
                 for c in in_channels_list]
            )
        else:
            if len(set(in_channels_list)) > 1:
                raise ValueError(
                    'FSRelationV3 with scale_aware_proj=False shares one projection '
                    'across levels, so all in_channels_list entries must be equal; '
                    f'got {list(in_channels_list)}'
                )
            level_channels = in_channels_list[0] if len(in_channels_list) > 0 else out_channels
            self.scene_encoder = self._make_scene_encoder(scene_embedding_dim, out_channels)
            self.project = self._make_projection(out_channels + level_channels, out_channels)

        self.content_encoders = nn.ModuleList()
        self.feature_reencoders = nn.ModuleList()
        for c in in_channels_list:
            self.content_encoders.append(self._make_pointwise(c, out_channels))
            self.feature_reencoders.append(self._make_pointwise(c, out_channels))

        self.normalizer = nn.Sigmoid()

    @staticmethod
    def _make_scene_encoder(in_channels, out_channels):
        """Two-layer 1x1-conv MLP (conv + LayerNorm2d + GELU, twice) for the scene feature."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            M.LayerNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, 1),
            M.LayerNorm2d(out_channels),
            nn.GELU(),
        )

    @staticmethod
    def _make_projection(in_channels, out_channels):
        """Bias-free 1x1 conv + LayerNorm2d + GELU + Dropout2d(0.1) output projection."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            M.LayerNorm2d(out_channels),
            nn.GELU(),
            nn.Dropout2d(p=0.1)
        )

    @staticmethod
    def _make_pointwise(in_channels, out_channels):
        """1x1 conv + LayerNorm2d + GELU used for content / feature re-encoding."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            M.LayerNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, scene_feature, features: list):
        """Re-weight each level of ``features`` by its relation to ``scene_feature``.

        Args:
            scene_feature: scene embedding with ``scene_embedding_dim`` channels;
                presumably [N, C, 1, 1] (FarSegMixin passes a global average
                pool) -- TODO confirm for other callers.
            features: one feature map per entry of ``in_channels_list``.

        Returns:
            list of refined maps, each with ``out_channels`` channels.
        """
        # [N, C, H, W]
        content_feats = [c_en(p_feat) for c_en, p_feat in zip(self.content_encoders, features)]
        if self.scale_aware_proj:
            scene_feats = [op(scene_feature) for op in self.scene_encoder]
            relations = [self.normalizer((sf * cf).sum(dim=1, keepdim=True))
                         for sf, cf in zip(scene_feats, content_feats)]
        else:
            # [N, C, 1, 1]
            scene_feat = self.scene_encoder(scene_feature)
            relations = [self.normalizer((scene_feat * cf).sum(dim=1, keepdim=True))
                         for cf in content_feats]

        p_feats = [op(p_feat) for op, p_feat in zip(self.feature_reencoders, features)]

        # Single-channel relation map gates the re-encoded features; the raw
        # level features are appended unchanged.
        refined_feats = [torch.cat([r * p, o], dim=1) for r, p, o in zip(relations, p_feats, features)]

        if self.scale_aware_proj:
            ffeats = [op(x) for op, x in zip(self.project, refined_feats)]
        else:
            ffeats = [self.project(x) for x in refined_feats]

        return ffeats


class FarSegMixin(nn.Module):
    """FarSeg-style head: FPN -> FSRelationV3 (scale-aware) -> AssymetricDecoder.

    Args:
        in_channels: channels of each input level, finest first; the last
            (coarsest) level is pooled into the scene embedding.
        fpn_channels: channels used inside the FPN and relation module.
        out_channels: channels produced by the decoder.
    """

    def __init__(self, in_channels, fpn_channels, out_channels):
        super().__init__()
        self.fpn = M.FPN(in_channels, fpn_channels)
        # One relation branch per FPN level rather than a hard-coded 4.
        # NOTE(review): AssymetricDecoder's defaults may still assume 4 levels.
        self.fsr = FSRelationV3(
            in_channels[-1],
            [fpn_channels] * len(in_channels),
            fpn_channels,
            scale_aware_proj=True
        )
        self.dec = M.AssymetricDecoder(
            fpn_channels,
            out_channels,
            norm_fn=M.LayerNorm2d
        )

    def forward(self, x):
        """Decode a list of multi-level features ``x`` (coarsest level last)."""
        # Scene embedding: global average pool of the coarsest input level.
        scene_embedding = F.adaptive_avg_pool2d(x[-1], 1)
        features = self.fpn(x)
        features = self.fsr(scene_embedding, features)
        features = self.dec(features)
        return features


@er.registry.MODEL.register()
class SAMEncoderFarSeg(SAMEncoder):
    """SAM image encoder followed by a :class:`FarSegMixin` head.

    Config keys added on top of ``SAMEncoder``'s:
        fpn_channels: channels inside the head (default 256).
    ``out_channels`` is read here but not defaulted in this class; presumably
    SAMEncoder's defaults provide it (TODO confirm).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Assumes SAMEncoder.forward yields 4 levels of self.out_channels each.
        in_channels = [self.out_channels for _ in range(4)]

        # Read through self.config, the same accessor set_default_config writes to.
        self.farseg = FarSegMixin(
            in_channels=in_channels,
            fpn_channels=self.config.fpn_channels,
            out_channels=self.config.out_channels,
        )

    def forward(self, x):
        """Encode ``x`` with SAMEncoder, then decode with the FarSeg head."""
        features = super().forward(x)
        features = self.farseg(features)

        return features

    def set_default_config(self):
        super().set_default_config()
        self.config.update(dict(
            fpn_channels=256,
        ))
# --------------------------------------------------------------------------------