├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── fig_0_1_2.png ├── inference.py ├── models ├── backbone.py ├── dpt_head.py └── regressor.py ├── notebooks ├── README.md ├── chm_demo.yml ├── data │ └── example_polygon.geojson ├── download_chm.ipynb └── run_chm_model.ipynb ├── pl_modules └── normnet_module.py └── src ├── raster_utils.py └── transforms.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . 
All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to HighResCanopyHeight 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DINOv2, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. 
For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 
136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # High Resolution Canopy Height Maps 2 | 3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** 4 | 5 | Jamie Tolan, 6 | Hung-I Yang, 7 | Benjamin Nosarzewski, 8 | Guillaume Couairon, 9 | Huy V. Vo, 10 | John Brandt, 11 | Justine Spore, 12 | Sayantan Majumdar, 13 | Daniel Haziza, 14 | Janaki Vamaraju, 15 | Théo Moutakanni, 16 | Piotr Bojanowski, 17 | Tracy Johns, 18 | Brian White, 19 | Tobias Tiecke, 20 | Camille Couprie 21 | 22 | [[`Paper`](https://doi.org/10.1016/j.rse.2023.113888)][[`ArxiV [same content]`](https://arxiv.org/abs/2304.07213)] [[`Blog`](https://research.facebook.com/blog/2023/4/every-tree-counts-large-scale-mapping-of-canopy-height-at-the-resolution-of-individual-trees/)] [[`BibTeX`](#citing-HighResCanopyHeight)] 23 | 24 | 25 | 26 | PyTorch implementation and pretrained models for High resolution Canopy Height Prediction inference. For details, see the paper: 27 | **[Very high resolution canopy height maps from RGB imagery using self-supervised vision transformer and convolutional decoder trained on Aerial Lidar](https://arxiv.org/abs/2304.07213)**. 28 | 29 | In collaboration with the Physical Modeling and Sustainability teams at Meta, and the World Resource Institute, we applied [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) to the canopy height map (CHM) prediction problem. We used this technique to pretrain a backbone on about 18 millions satellite images around the globe. We then trained a CHM predictor on a modest sized training dataset covering a few thousand square kilometers of forests in the United States, with this backbone as feature extractor. 30 | We demonstrate in our paper quantitatively and qualitatively the advantages of large-scale self-supervised learning, the versatility of obtained representations allowing generalization to different geographic regions and input imagery. 31 | 32 | The maps obtained with this model are available at https://wri-datalab.earthengine.app/view/submeter-canopyheight. 
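The pipeline described above keeps the SSL-pretrained vision transformer as a fixed feature extractor and trains only the convolutional decoder on aerial-lidar canopy heights. A minimal sketch of that composition, using the `SSLVisionTransformer` and `DPTHead` classes from `models/`, is shown below; the checkpoint path, optimizer, loss and data handling are illustrative placeholders, not the released training code.

```
import torch
import torch.nn as nn

from models.backbone import SSLVisionTransformer
from models.dpt_head import DPTHead

# SSL-pretrained encoder reused as a frozen feature extractor
# (illustrative backbone weights path).
encoder = SSLVisionTransformer(pretrained='saved_checkpoints/ssl_backbone.pth')
for p in encoder.parameters():
    p.requires_grad = False

# Convolutional (DPT) decoder predicting a per-pixel canopy height map.
decoder = DPTHead(classify=True, n_bins=256)

optimizer = torch.optim.AdamW(decoder.parameters(), lr=1e-4)  # placeholder hyperparameters
loss_fn = nn.L1Loss()                                         # placeholder objective

def training_step(rgb, chm):
    """rgb: (B, 3, H, W) normalized images; chm: (B, 1, H, W) lidar canopy heights."""
    with torch.no_grad():
        feats = encoder(rgb)          # ViT features, kept fixed
    pred = decoder(feats)             # predicted canopy height map
    loss = loss_fn(pred, chm)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```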
33 | 34 | 35 | ![alt text](https://github.com/facebookresearch/HighResCanopyHeight/blob/main/fig_0_1_2.png) 36 | 37 | ## Requirements 38 | 39 | pytorch, 40 | pytorch lightning, 41 | pandas 42 | 43 | Example of successful environment creation for inference 44 | 45 | ``` 46 | conda create -n hrch python=3.9 -y 47 | conda activate hrch 48 | conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia 49 | pip install pytorch_lightning==1.7 50 | pip install pandas 51 | pip install matplotlib 52 | pip install torchmetrics==0.11.4 53 | ``` 54 | 55 | 56 | ## Data and pretrained models 57 | 58 | You can download the data and saved checkpoints from 59 | ``` 60 | s3://dataforgood-fb-data/forests/v1/models/ 61 | ``` 62 | 63 | ### Data 64 | 65 | To prepare the data, in the cloned repository, run these commands: 66 | ``` 67 | aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ . 68 | unzip data.zip 69 | rm data.zip 70 | ``` 71 | 72 | Although our method is designed to work from satellite images, it can also estimate canopy height from aerial images. 73 | 74 | We share aerial images for the Neon test set we created for the paper in data.zip. 75 | 76 | To automate the color balancing without the need of Maxar images, we trained a network from aerial images (Neon train) to predict the 95th and 5th percentiles of the corresponding maxar images : saved_checkpoints/aerial_normalization_quantiles_predictor.ckpt 77 | 78 | ### SSL Pretrained models 79 | 80 | In the saved_checkpoints directory, there are: 81 | 82 | SSLhuge_satellite.pth (2.9G): encoder trained on satellite images, decoder trained on satellite images. Use this model for inference on GPUs. Best results using RGB satellite images in input. 83 | 84 | compressed_SSLhuge.pth (749M): SSLhuge_satellite.pth quantized. Model used in the evaluations of the paper. 85 | 86 | compressed_SSLhuge_aerial.pth (749M): encoder trained on satellite images, decoder trained on aerial images. 87 | 88 | compressed_SSLlarge.pth (400M): ablation using a large model. 89 | 90 | ## Evaluation 91 | 92 | ``` 93 | python inference.py --checkpoint saved_checkpoints/SSLhuge_satellite.pth 94 | ``` 95 | ``` 96 | mae 3.15 97 | r2_block 0.51 98 | Bias: -1.60 99 | ``` 100 | 101 | Here are the performance on aerial images to expect with the different models released. Please note that the 3 first models in this table are trained exclusively on satellite data and are evaluated here in an out of domain context. 102 | 103 | | | SSL large | SSL huge | compressed SSL huge | SSL aerial| 104 | | --- | ---| --- | --- | ---| 105 | | MAE| 3.31 | 3.15 | 3.08 | 2.5 | 106 | | R2 block | 0.37 | 0.51 | 0.54 | 0.7 | 107 | | Bias | -1.4| -1.6 | -1.6 | -2.1 | 108 | 109 | ## Notes 110 | 111 | We do not include the GEDI correction step in this code release. 112 | 113 | The folder "models" contains code borrowed from the Dinov2 team, we thank all contributors. 114 | 115 | The inference using compressed models has not been tested using GPUs (CPU only). 116 | 117 | The backbone weights are the same for all SSL models. The backbone has been trained on images filtered to contain mainly vegetation. 118 | 119 | ## License 120 | 121 | HighResCanopyHeight code and model weights are released under the Apache License 2.0. See [LICENSE](LICENSE) for additional details. 122 | 123 | ## Contributing 124 | 125 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 
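As a complement to the `inference.py` command above, the released checkpoints can also be applied to a single RGB tile programmatically. The sketch below reuses `SSLModule` and the normalization constants from `inference.py`; the image path and the 256x256 crop are illustrative, and aerial imagery would additionally require the color normalization described in the Data section.

```
import torch
import torchvision.transforms as T
from PIL import Image

from inference import SSLModule

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the satellite-trained model (compressed checkpoints are CPU only).
model = SSLModule(ssl_path='saved_checkpoints/SSLhuge_satellite.pth').eval().to(device)

# Per-channel normalization applied before the encoder (same values as inference.py).
norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))

# Load an RGB tile and take a 256x256 crop, matching the evaluation setup.
img = T.ToTensor()(Image.open('data/images/example_tile.png').convert('RGB'))[:, :256, :256]

with torch.no_grad():
    chm = model(norm(img).unsqueeze(0).to(device))  # (1, 1, H, W) canopy height in meters

print(f'mean / max canopy height: {chm.mean().item():.1f} m / {chm.max().item():.1f} m')
```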
126 | 127 | ## Citing HighResCanopyHeight 128 | 129 | If you find this repository useful, please consider giving a star :star: and citation :t-rex:: 130 | 131 | ``` 132 | @article{tolan2024very, 133 | title={Very high resolution canopy height maps from RGB imagery using self-supervised vision transformer and convolutional decoder trained on aerial lidar}, 134 | author={Tolan, Jamie and Yang, Hung-I and Nosarzewski, Benjamin and Couairon, Guillaume and Vo, Huy V and Brandt, John and Spore, Justine and Majumdar, Sayantan and Haziza, Daniel and Vamaraju, Janaki and others}, 135 | journal={Remote Sensing of Environment}, 136 | volume={300}, 137 | pages={113888}, 138 | year={2024} 139 | } 140 | ``` 141 | 142 | -------------------------------------------------------------------------------- /fig_0_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/HighResCanopyHeight/84971392fb3a017e53d19e28ee0409d0de5d5d61/fig_0_1_2.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import os 8 | import torch 9 | import pandas as pd 10 | import numpy as np 11 | import torchvision.transforms as T 12 | import matplotlib.pyplot as plt 13 | import torchmetrics 14 | from pathlib import Path 15 | import torch.nn as nn 16 | from tqdm import tqdm 17 | from PIL import Image 18 | import math 19 | import torchvision.transforms.functional as TF 20 | import torchvision 21 | from torchvision.utils import save_image 22 | 23 | from models.backbone import SSLVisionTransformer 24 | from models.dpt_head import DPTHead 25 | import pytorch_lightning as pl 26 | from models.regressor import RNet 27 | 28 | class SSLAE(nn.Module): 29 | def __init__(self, pretrained=None, classify=True, n_bins=256, huge=False): 30 | super().__init__() 31 | if huge == True: 32 | self.backbone = SSLVisionTransformer( 33 | embed_dim=1280, 34 | num_heads=20, 35 | out_indices=(9, 16, 22, 29), 36 | depth=32, 37 | pretrained=pretrained 38 | ) 39 | self.decode_head = DPTHead( 40 | classify=classify, 41 | in_channels=(1280, 1280, 1280, 1280), 42 | embed_dims=1280, 43 | post_process_channels=[160, 320, 640, 1280], 44 | ) 45 | else: 46 | self.backbone = SSLVisionTransformer(pretrained=pretrained) 47 | self.decode_head = DPTHead(classify=classify,n_bins=256) 48 | 49 | def forward(self, x): 50 | x = self.backbone(x) 51 | x = self.decode_head(x) 52 | return x 53 | 54 | class SSLModule(pl.LightningModule): 55 | def __init__(self, 56 | ssl_path="compressed_SSLbaseline.pth"): 57 | super().__init__() 58 | 59 | if 'huge' in ssl_path: 60 | self.chm_module_ = SSLAE(classify=True, huge=True).eval() 61 | else: 62 | self.chm_module_ = SSLAE(classify=True, huge=False).eval() 63 | 64 | if 'compressed' in ssl_path: 65 | ckpt = torch.load(ssl_path, map_location='cpu') 66 | self.chm_module_ = torch.quantization.quantize_dynamic( 67 | self.chm_module_, 68 | {torch.nn.Linear,torch.nn.Conv2d, torch.nn.ConvTranspose2d}, 69 | dtype=torch.qint8) 70 | self.chm_module_.load_state_dict(ckpt, strict=False) 71 | else: 72 | ckpt = torch.load(ssl_path) 73 | state_dict = ckpt['state_dict'] 74 | self.chm_module_.load_state_dict(state_dict) 75 | 
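        # The raw decoder output is rescaled by a factor of 10 below, so that the
        # predictions compared against the CHM ground truth are expressed in meters.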
76 | self.chm_module = lambda x: 10*self.chm_module_(x) 77 | def forward(self, x): 78 | x = self.chm_module(x) 79 | return x 80 | 81 | class NeonDataset(torch.utils.data.Dataset): 82 | path = './data/images/' 83 | root_dir = Path(path) 84 | df_path = './data/neon_test_data.csv' 85 | 86 | def __init__(self, model_norm, new_norm, src_img='maxar', 87 | trained_rgb= False, no_norm = False, 88 | **kwargs): 89 | 90 | self.no_norm = no_norm 91 | self.model_norm = model_norm 92 | self.new_norm = new_norm 93 | self.trained_rgb = trained_rgb 94 | self.size = 256 95 | self.df = pd.read_csv(self.df_path, index_col=0) 96 | self.src_img = src_img 97 | 98 | # number of times crops can be used horizontally 99 | self.size_multiplier = 6 100 | 101 | def __len__(self): 102 | if self.src_img == 'neon': 103 | return 30 * len(self.df) 104 | return len(self.df) 105 | 106 | 107 | def __getitem__(self, i): 108 | n = self.size_multiplier 109 | ix, jx, jy = i//(n**2), (i%(n**2))// n, (i% (n**2)) % n 110 | if self.src_img == 'neon': 111 | l = self.df.iloc[ix] 112 | x = list(range(l.bord_x, l.imsize-l.bord_x-self.size, self.size))[jx] 113 | y = list(range(l.bord_y, l.imsize-l.bord_y-self.size, self.size))[jy] 114 | img = TF.to_tensor(Image.open(self.root_dir / l[self.src_img]).crop((x, y, x+self.size, y+self.size))) 115 | chm = TF.to_tensor(Image.open(self.root_dir / l.chm).crop((x, y, x+self.size, y+self.size))) 116 | chm[chm<0] = 0 117 | 118 | if not self.trained_rgb: 119 | if self.src_img == 'neon': 120 | if self.no_norm: 121 | normIn = img 122 | else: 123 | if self.new_norm: 124 | # image image normalization using learned quantiles of pairs of Maxar/Neon images 125 | x = torch.unsqueeze(img, dim=0) 126 | norm_img = self.model_norm(x).detach() 127 | p5I = [norm_img[0][0].item(), norm_img[0][1].item(), norm_img[0][2].item()] 128 | p95I = [norm_img[0][3].item(), norm_img[0][4].item(), norm_img[0][5].item()] 129 | else: 130 | # apply image normalization to aerial images, matching color intensity of maxar images 131 | I = TF.to_tensor(Image.open(self.root_dir / l['maxar']).crop((x, y, x+s, y+s))) 132 | p5I = [np.percentile(I[i,:,:].flatten(),5) for i in range(3)] 133 | p95I = [np.percentile(I[i,:,:].flatten(),95) for i in range(3)] 134 | p5In = [np.percentile(img[i,:,:].flatten(),5) for i in range(3)] 135 | 136 | p95In = [np.percentile(img[i,:,:].flatten(),95) for i in range(3)] 137 | normIn = img.clone() 138 | for i in range(3): 139 | normIn[i,:,:] = (img[i,:,:]-p5In[i]) * ((p95I[i]-p5I[i])/(p95In[i]-p5In[i])) + p5I[i] 140 | 141 | return {'img': normIn, 142 | 'img_no_norm': img, 143 | 'chm': chm, 144 | 'lat':torch.Tensor([l.lat]).nan_to_num(0), 145 | 'lon':torch.Tensor([l.lon]).nan_to_num(0), 146 | } 147 | 148 | def evaluate(model, 149 | norm, 150 | model_norm, 151 | name, 152 | bs=32, 153 | trained_rgb=False, 154 | normtype=2, 155 | device = 'cuda:0', 156 | no_norm = False, 157 | display = False): 158 | 159 | dataset_key = 'neon_aerial' 160 | 161 | print("normtype", normtype) 162 | 163 | # choice of the normalization of aerial images. 
164 | # i- For inference on satellite images args.normtype should be set to 0; 165 | # ii- For inference on aerial images, if corresponding Maxar quantiles at the 166 | # same coordinates are known, args.normtype should be set to 1; 167 | # iii- For inference on aerial images, an automatic normalization using a pretrained 168 | # network on aerial and satellite images on Neon can be used: args.normtype should be set to 2 (default); 169 | 170 | new_norm=True 171 | no_norm=False 172 | if normtype == 0: 173 | no_norm=True 174 | elif normtype == 1: 175 | new_norm=False 176 | elif normtype == 2: 177 | new_norm=True 178 | 179 | ds = NeonDataset( model_norm, new_norm, domain='test', src_img='neon', trained_rgb=trained_rgb, no_norm=no_norm) 180 | dataloader = torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, num_workers=10) 181 | 182 | Path('../reports').joinpath(name).mkdir(parents=True, exist_ok=True) 183 | Path('../reports/'+name).joinpath('results_for_fig_'+dataset_key).mkdir(parents=True, exist_ok=True) 184 | metrics = {} 185 | 186 | # canopy height metrics 187 | metric_classes = dict( 188 | mae = torchmetrics.MeanAbsoluteError(), 189 | rmse = torchmetrics.MeanSquaredError(squared= False), 190 | r2 = torchmetrics.R2Score(), 191 | r2_block = torchmetrics.R2Score()) 192 | 193 | downsampler = nn.AvgPool2d(50) 194 | bd = 3 195 | 196 | preds, chms = [], [] 197 | chm_block_means, pred_block_means = [], [] 198 | 199 | fig_batch_ind = 0 200 | 201 | for batch in tqdm(dataloader): 202 | chm = batch['chm'].detach() 203 | batch = {k:v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)} 204 | pred = model(norm(batch['img'])) 205 | pred = pred.cpu().detach().relu() 206 | 207 | if display == True: 208 | # display Predicted CHM 209 | for ind in range(pred.shape[0]): 210 | fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(20, 5)) 211 | plt.subplots_adjust(hspace=0.5) 212 | img_no_norm = batch['img_no_norm'][ind].cpu() 213 | Inn = np.moveaxis(img_no_norm.numpy(), 0, 2) 214 | img = batch['img'][ind].cpu() 215 | I = np.moveaxis(img.numpy(), 0, 2) 216 | gt = batch['chm'][ind].cpu() 217 | GT = np.moveaxis(gt.numpy(), 0, 2) 218 | ax[0].imshow(Inn) 219 | ax[0].set_title(f"Image",fontsize=12) 220 | ax[0].set_xlabel('meters') 221 | ax[1].imshow(I) 222 | ax[1].set_title(f"Normalized Image ",fontsize=12) 223 | ax[1].set_xlabel('meters') 224 | combined_data = np.concatenate((batch['chm'][ind].cpu().numpy(), pred[ind].detach().numpy()), axis=0) 225 | _min, _max = np.amin(combined_data), np.amax(combined_data) 226 | pltim = ax[2].imshow(pred[ind][0].detach().numpy(), vmin = _min, vmax = _max) 227 | ax[2].set_title(f"Pred CHM",fontsize=12) 228 | ax[2].set_xlabel('meters') 229 | pltim = ax[3].imshow(GT, vmin = _min, vmax = _max) 230 | ax[3].set_title(f"GT CHM",fontsize=12) 231 | ax[3].set_xlabel('meters') 232 | cax = fig.add_axes([0.95, 0.15, 0.02, 0.7]) 233 | fig.colorbar(pltim, cax=cax, orientation="vertical") 234 | cax.set_title("meters", fontsize=12) 235 | plt.savefig(f"{name}/fig_{fig_batch_ind}_{ind}_{normtype}.png", dpi=300) 236 | 237 | fig_batch_ind = fig_batch_ind + 1 238 | 239 | chm_block_mean = downsampler(chm[..., bd:, bd:]) 240 | pred_block_mean = downsampler(pred[..., bd:, bd:]) 241 | 242 | metric_classes['mae'].update(pred, chm) 243 | metric_classes['rmse'].update(pred, chm) 244 | metric_classes['r2'].update(pred.flatten(), chm.flatten()) 245 | metric_classes['r2_block'].update(pred_block_mean.flatten(), chm_block_mean.flatten()) 246 | 247 | preds.append(pred), chms.append(chm) 248 
| chm_block_means.append(chm_block_mean) 249 | pred_block_means.append(pred_block_mean) 250 | if display: 251 | break 252 | preds, chms = torch.cat(preds), torch.cat(chms) 253 | 254 | metrics = {k:v.compute() for k, v in metric_classes.items()} 255 | torch.save(metrics, f'{name}/metrics.pt') 256 | 257 | #print metrics 258 | for k, v in metrics.items(): 259 | print(f'{k} {v.item():.2f}') 260 | print(f"Bias: {(preds.flatten()-chms.flatten()).numpy().mean():.2f}") 261 | 262 | 263 | def parse_args(): 264 | parser = argparse.ArgumentParser( 265 | description='test a model') 266 | parser.add_argument('--checkpoint', type=str, help='CHM pred checkpoint file', default='saved_checkpoints/compressed_SSLlarge.pth') 267 | parser.add_argument('--name', type=str, help='run name', default='output_inference') 268 | parser.add_argument('--trained_rgb', type=str, help='True if model was finetuned on aerial data') 269 | parser.add_argument('--normnet', type=str, help='path to a normalization network', default='saved_checkpoints/aerial_normalization_quantiles_predictor.ckpt') 270 | parser.add_argument('--normtype', type=int, help='0: no norm; 1: old norm, 2: new norm', default=2) 271 | parser.add_argument('--display', type=bool, help='saving outputs in images') 272 | args = parser.parse_args() 273 | return args 274 | 275 | 276 | 277 | def main(): 278 | # 0- read args 279 | args = parse_args() 280 | if 'compressed' in args.checkpoint: 281 | device='cpu' 282 | else: 283 | device='cuda:0' 284 | 285 | os.system("mkdir "+args.name) 286 | 287 | # 1- load network and its weight to normalize aerial images to match intensities from satellite images. 288 | norm_path = args.normnet 289 | ckpt = torch.load(norm_path, map_location='cpu') 290 | state_dict = ckpt['state_dict'] 291 | for k in list(state_dict.keys()): 292 | if 'backbone.' in k: 293 | new_k = k.replace('backbone.','') 294 | state_dict[new_k] = state_dict.pop(k) 295 | 296 | model_norm = RNet(n_classes=6) 297 | model_norm = model_norm.eval() 298 | model_norm.load_state_dict(state_dict) 299 | 300 | # 2- load SSL model 301 | model = SSLModule(ssl_path = args.checkpoint) 302 | model.to(device) 303 | model = model.eval() 304 | 305 | # 3- image normalization for each image going through the encoder 306 | norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143)) 307 | norm = norm.to(device) 308 | 309 | # 4- evaluation 310 | evaluate(model, norm, model_norm, name=args.name, bs=16, trained_rgb=args.trained_rgb, normtype=args.normtype, device=device, display=args.display) 311 | 312 | if __name__ == '__main__': 313 | main() 314 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
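# Random mask sampling and the boolean mask array in MaskingGenerator below rely on these.
import random

import numpy as np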
5 | 6 | import torch 7 | from torch import nn 8 | import torchvision 9 | from torch.nn.modules.batchnorm import _BatchNorm 10 | from torch.nn.modules.utils import _pair as to_2tuple 11 | import math 12 | import warnings 13 | from collections import OrderedDict 14 | from torch import Tensor 15 | 16 | import torch.nn.functional as F 17 | from typing import Callable, Optional, Tuple, Union 18 | from functools import partial 19 | import pdb 20 | 21 | class MaskingGenerator: 22 | def __init__( 23 | self, 24 | input_size, 25 | num_masking_patches=None, 26 | min_num_patches=4, 27 | max_num_patches=None, 28 | min_aspect=0.3, 29 | max_aspect=None, 30 | ): 31 | if not isinstance(input_size, tuple): 32 | input_size = (input_size,) * 2 33 | self.height, self.width = input_size 34 | 35 | self.num_patches = self.height * self.width 36 | self.num_masking_patches = num_masking_patches 37 | 38 | self.min_num_patches = min_num_patches 39 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 40 | 41 | max_aspect = max_aspect or 1 / min_aspect 42 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 43 | 44 | def __repr__(self): 45 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 46 | self.height, 47 | self.width, 48 | self.min_num_patches, 49 | self.max_num_patches, 50 | self.num_masking_patches, 51 | self.log_aspect_ratio[0], 52 | self.log_aspect_ratio[1], 53 | ) 54 | return repr_str 55 | 56 | def get_shape(self): 57 | return self.height, self.width 58 | 59 | def _mask(self, mask, max_mask_patches): 60 | delta = 0 61 | for attempt in range(10): 62 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 63 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 64 | h = int(round(math.sqrt(target_area * aspect_ratio))) 65 | w = int(round(math.sqrt(target_area / aspect_ratio))) 66 | if w < self.width and h < self.height: 67 | top = random.randint(0, self.height - h) 68 | left = random.randint(0, self.width - w) 69 | 70 | num_masked = mask[top : top + h, left : left + w].sum() 71 | # Overlap 72 | if 0 < h * w - num_masked <= max_mask_patches: 73 | for i in range(top, top + h): 74 | for j in range(left, left + w): 75 | if mask[i, j] == 0: 76 | mask[i, j] = 1 77 | delta += 1 78 | 79 | if delta > 0: 80 | break 81 | return delta 82 | 83 | def __call__(self, num_masking_patches=0): 84 | mask = np.zeros(shape=self.get_shape(), dtype=np.bool) 85 | mask_count = 0 86 | while mask_count < num_masking_patches: 87 | max_mask_patches = num_masking_patches - mask_count 88 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 89 | 90 | delta = self._mask(mask, max_mask_patches) 91 | if delta == 0: 92 | break 93 | else: 94 | mask_count += delta 95 | 96 | return mask 97 | 98 | 99 | def resize(input, 100 | size=None, 101 | scale_factor=None, 102 | mode='nearest', 103 | align_corners=None, 104 | warning=False): 105 | if warning: 106 | if size is not None and align_corners: 107 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 108 | output_h, output_w = tuple(int(x) for x in size) 109 | if output_h > input_h or output_w > output_h: 110 | if ((output_h > 1 and output_w > 1 and input_h > 1 111 | and input_w > 1) and (output_h - 1) % (input_h - 1) 112 | and (output_w - 1) % (input_w - 1)): 113 | warnings.warn( 114 | f'When align_corners={align_corners}, ' 115 | 'the output would more aligned if ' 116 | f'input size {(input_h, input_w)} is `x+1` and ' 117 | f'out size {(output_h, output_w)} is `nx+1`') 118 | 
119 | return F.interpolate(input, size, scale_factor, mode, align_corners) 120 | 121 | 122 | class Mlp(nn.Module): 123 | def __init__( 124 | self, 125 | in_features: int, 126 | hidden_features: Optional[int] = None, 127 | out_features: Optional[int] = None, 128 | act_layer: Callable[..., nn.Module] = nn.GELU(), 129 | drop: float = 0.0, 130 | ) -> None: 131 | super().__init__() 132 | out_features = out_features or in_features 133 | hidden_features = hidden_features or in_features 134 | self.fc1 = nn.Linear(in_features, hidden_features) 135 | self.act = act_layer() 136 | self.fc2 = nn.Linear(hidden_features, out_features) 137 | self.drop = nn.Dropout(drop) 138 | 139 | def forward(self, x: Tensor) -> Tensor: 140 | x = self.fc1(x) 141 | x = self.act(x) 142 | x = self.drop(x) 143 | x = self.fc2(x) 144 | x = self.drop(x) 145 | return x 146 | 147 | 148 | class Attention(nn.Module): 149 | def __init__( 150 | self, 151 | dim: int, 152 | num_heads: int = 8, 153 | qkv_bias: bool = False, 154 | attn_drop: float = 0.0, 155 | proj_drop: float = 0.0, 156 | ) -> None: 157 | super().__init__() 158 | self.num_heads = num_heads 159 | head_dim = dim // num_heads 160 | self.scale = head_dim**-0.5 161 | 162 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 163 | self.attn_drop = nn.Dropout(attn_drop) 164 | self.proj = nn.Linear(dim, dim) 165 | self.proj_drop = nn.Dropout(proj_drop) 166 | 167 | def forward(self, x: Tensor) -> Tensor: 168 | B, N, C = x.shape 169 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 170 | 171 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 172 | attn = q @ k.transpose(-2, -1) 173 | 174 | attn = attn.softmax(dim=-1) 175 | attn = self.attn_drop(attn) 176 | 177 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 178 | x = self.proj(x) 179 | x = self.proj_drop(x) 180 | return x 181 | 182 | 183 | 184 | class LayerScale(nn.Module): 185 | def __init__( 186 | self, 187 | dim: int, 188 | init_values: Union[float, Tensor] = 1e-5, 189 | inplace: bool = False, 190 | ) -> None: 191 | super().__init__() 192 | self.inplace = inplace 193 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 194 | 195 | def forward(self, x: Tensor) -> Tensor: 196 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 197 | 198 | 199 | class Block(nn.Module): 200 | def __init__( 201 | self, 202 | dim: int, 203 | num_heads: int, 204 | mlp_ratio: float = 4.0, 205 | qkv_bias: bool = False, 206 | drop: float = 0.0, 207 | attn_drop: float = 0.0, 208 | init_values=None, 209 | drop_path: float = 0.0, 210 | act_layer: Callable[..., nn.Module] = nn.GELU(), 211 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 212 | attn_class: Callable[..., nn.Module] = Attention, 213 | ffn_layer: Callable[..., nn.Module] = Mlp, 214 | ) -> None: 215 | super().__init__() 216 | self.norm1 = norm_layer(dim) 217 | self.attn = attn_class( 218 | dim, 219 | num_heads=num_heads, 220 | qkv_bias=qkv_bias, 221 | attn_drop=attn_drop, 222 | proj_drop=drop, 223 | ) 224 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 225 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 226 | 227 | self.norm2 = norm_layer(dim) 228 | mlp_hidden_dim = int(dim * mlp_ratio) 229 | self.mlp = ffn_layer( 230 | in_features=dim, 231 | hidden_features=mlp_hidden_dim, 232 | act_layer=act_layer, 233 | drop=drop, 234 | ) 235 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 236 | self.drop_path2 = DropPath(drop_path) if 
drop_path > 0.0 else nn.Identity() 237 | 238 | self.sample_drop_ratio = drop_path 239 | 240 | def forward(self, x: Tensor) -> Tensor: 241 | #pdb.set_trace() 242 | def attn_residual_func(x: Tensor) -> Tensor: 243 | return self.ls1(self.attn(self.norm1(x))) 244 | 245 | def ffn_residual_func(x: Tensor) -> Tensor: 246 | return self.ls2(self.mlp(self.norm2(x))) 247 | 248 | if self.training and self.sample_drop_ratio > 0.1: 249 | x = drop_add_residual_stochastic_depth( 250 | x, 251 | residual_func=attn_residual_func, 252 | sample_drop_ratio=self.sample_drop_ratio, 253 | ) 254 | x = drop_add_residual_stochastic_depth( 255 | x, 256 | residual_func=ffn_residual_func, 257 | sample_drop_ratio=self.sample_drop_ratio, 258 | ) 259 | elif self.training and self.sample_drop_ratio > 0.0: 260 | x = x + self.drop_path1(attn_residual_func(x)) 261 | x = x + self.drop_path1(ffn_residual_func(x)) 262 | else: 263 | x = x + attn_residual_func(x) 264 | x = x + ffn_residual_func(x) 265 | return x 266 | 267 | 268 | def make_2tuple(x): 269 | if isinstance(x, tuple): 270 | assert len(tuple) == 2 271 | return x 272 | 273 | assert isinstance(x, int) 274 | return (x, x) 275 | 276 | 277 | class PatchEmbed(nn.Module): 278 | """ 279 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 280 | 281 | Args: 282 | img_size: Image size. 283 | patch_size: Patch token size. 284 | in_chans: Number of input image channels. 285 | embed_dim: Number of linear projection output channels. 286 | norm_layer: Normalization layer. 287 | """ 288 | 289 | def __init__( 290 | self, 291 | img_size: Union[int, Tuple[int, int]] = 224, 292 | patch_size: Union[int, Tuple[int, int]] = 16, 293 | in_chans: int = 3, 294 | embed_dim: int = 768, 295 | norm_layer: Optional[Callable] = None, 296 | ) -> None: 297 | super().__init__() 298 | 299 | image_HW = make_2tuple(img_size) 300 | patch_HW = make_2tuple(patch_size) 301 | patch_grid_size = ( 302 | image_HW[0] // patch_HW[0], 303 | image_HW[1] // patch_HW[1], 304 | ) 305 | 306 | self.img_size = image_HW 307 | self.patch_size = patch_HW 308 | self.patches_resolution = patch_grid_size 309 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 310 | 311 | self.in_chans = in_chans 312 | self.embed_dim = embed_dim 313 | 314 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 315 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 316 | 317 | 318 | def forward(self, x: Tensor) -> Tensor: 319 | _, _, H, W = x.shape 320 | patch_H, patch_W = self.patch_size 321 | 322 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 323 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 324 | 325 | x = self.proj(x) 326 | x = x.flatten(2).transpose(1, 2) 327 | x = self.norm(x) 328 | return x 329 | 330 | def flops(self) -> float: 331 | Ho, Wo = self.patches_resolution 332 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 333 | if self.norm is not None: 334 | flops += Ho * Wo * self.embed_dim 335 | return flops 336 | 337 | 338 | class DinoVisionTransformer(nn.Module): 339 | """Vision Transformer 340 | 341 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 342 | - https://arxiv.org/abs/2010.11929 343 | """ 344 | 345 | def __init__( 346 | self, 347 | img_size=224, 348 | patch_size=16, 349 | in_chans=3, 350 | num_classes=0, 351 | global_pool="token", 352 | embed_dim=1024, 353 | depth=24, 354 | num_heads=16, 
355 | mlp_ratio=4.0, 356 | qkv_bias=True, 357 | representation_size=None, 358 | drop_rate=0.0, 359 | attn_drop_rate=0.0, 360 | drop_path_rate=0.0, 361 | weight_init="", 362 | init_values=1., 363 | embed_layer=PatchEmbed, 364 | norm_layer=None, 365 | act_layer=None, 366 | block_fn=Block, 367 | ffn_layer="mlp", 368 | drop_path_uniform=False, 369 | patch_drop=0.0, 370 | sin_cos_embeddings=False, 371 | local_crops_size=96, 372 | multiple_pos_embeddings=False, 373 | ): 374 | """ 375 | Args: 376 | img_size (int, tuple): input image size 377 | patch_size (int, tuple): patch size 378 | in_chans (int): number of input channels 379 | num_classes (int): number of classes for classification head 380 | global_pool (str): type of global pooling for final sequence (default: 'token') 381 | embed_dim (int): embedding dimension 382 | depth (int): depth of transformer 383 | num_heads (int): number of attention heads 384 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 385 | qkv_bias (bool): enable bias for qkv if True 386 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 387 | drop_rate (float): dropout rate 388 | attn_drop_rate (float): attention dropout rate 389 | drop_path_rate (float): stochastic depth rate 390 | weight_init: (str): weight init scheme 391 | init_values: (float): layer-scale init values 392 | embed_layer (nn.Module): patch embedding layer 393 | norm_layer: (nn.Module): normalization layer 394 | act_layer: (nn.Module): MLP activation layer 395 | """ 396 | super().__init__() 397 | assert global_pool in ("", "avg", "token") 398 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 399 | act_layer = act_layer or nn.GELU 400 | 401 | self.num_classes = num_classes 402 | self.global_pool = global_pool 403 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 404 | self.num_tokens = 1 405 | self.grad_checkpointing = False 406 | self.sin_cos_embeddings = sin_cos_embeddings 407 | self.multiple_pos_embeddings = multiple_pos_embeddings 408 | 409 | self.patch_embed = embed_layer( 410 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim 411 | ) 412 | num_patches = self.patch_embed.num_patches 413 | 414 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 415 | if self.sin_cos_embeddings: 416 | self.pos_embed = torch.Tensor(()) 417 | logger.info("using sin-cos fixed embeddings") 418 | pass 419 | elif self.multiple_pos_embeddings: 420 | logger.info("using multiple position embeddings (one for global one for local)") 421 | self.pos_embeds = nn.ParameterDict() 422 | self.pos_embeds[str(img_size)] = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 423 | n_local_patches = (local_crops_size // patch_size) ** 2 424 | self.pos_embeds[str(local_crops_size)] = nn.Parameter(torch.zeros(1, n_local_patches, embed_dim)) 425 | self.pos_embed = None 426 | else: 427 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 428 | self.pos_drop = nn.Dropout(p=drop_rate) 429 | 430 | if drop_path_uniform is True: 431 | dpr = [drop_path_rate] * depth 432 | else: 433 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 434 | 435 | if ffn_layer == "mlp": 436 | #print("using MLP layer as FFN") 437 | ffn_layer = Mlp 438 | elif ffn_layer == "swiglu": 439 | #print("using SwiGLU layer as FFN") 440 | ffn_layer = SwiGLUFFN 441 | elif ffn_layer == "identity": 442 | #print("using Identity layer as 
FFN") 443 | def f(*args, **kwargs): 444 | return nn.Identity() 445 | ffn_layer = f 446 | else: 447 | raise NotImplementedError 448 | 449 | self.blocks = nn.ModuleList( 450 | [ 451 | block_fn( 452 | dim=embed_dim, 453 | num_heads=num_heads, 454 | mlp_ratio=mlp_ratio, 455 | qkv_bias=qkv_bias, 456 | drop=drop_rate, 457 | attn_drop=attn_drop_rate, 458 | drop_path=dpr[i], 459 | norm_layer=norm_layer, 460 | act_layer=act_layer, 461 | ffn_layer=ffn_layer, 462 | init_values=init_values, 463 | ) 464 | for i in range(depth) 465 | ] 466 | ) 467 | 468 | use_fc_norm = self.global_pool == "avg" 469 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 470 | 471 | # Representation layer. Used for original ViT models w/ in21k pretraining. 472 | self.representation_size = representation_size 473 | self.pre_logits = nn.Identity() 474 | if representation_size: 475 | self._reset_representation(representation_size) 476 | 477 | # Classifier Head 478 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 479 | final_chs = self.representation_size if self.representation_size else self.embed_dim 480 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 481 | 482 | self.mask_generator = MaskingGenerator( 483 | input_size=(img_size // patch_size, img_size // patch_size), 484 | max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, 485 | ) 486 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 487 | 488 | # if weight_init != "skip": 489 | # self.init_weights(weight_init) 490 | 491 | def _reset_representation(self, representation_size): 492 | self.representation_size = representation_size 493 | if self.representation_size: 494 | self.pre_logits = nn.Sequential( 495 | OrderedDict([("fc", nn.Linear(self.embed_dim, self.representation_size)), ("act", nn.Tanh())]) 496 | ) 497 | else: 498 | self.pre_logits = nn.Identity() 499 | 500 | def init_weights(self, mode=""): 501 | assert mode in ("jax", "jax_nlhb", "moco", "") 502 | head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 503 | if self.pos_embed is not None: 504 | trunc_normal_(self.pos_embed, std=0.02) 505 | elif self.pos_embeds: 506 | for v in self.pos_embeds.values(): 507 | trunc_normal_(v, std=0.02) 508 | nn.init.normal_(self.cls_token, std=1e-6) 509 | named_apply(get_init_weights_vit(mode, head_bias), self) 510 | 511 | def _init_weights(self, m): 512 | # this fn left here for compat with downstream users 513 | init_weights_vit_timm(m) 514 | 515 | @torch.jit.ignore() 516 | def load_pretrained(self, checkpoint_path, prefix=""): 517 | _load_weights(self, checkpoint_path, prefix) 518 | 519 | @torch.jit.ignore 520 | def no_weight_decay(self): 521 | return {"pos_embed", "cls_token", "dist_token"} 522 | 523 | @torch.jit.ignore 524 | def group_matcher(self, coarse=False): 525 | return dict( 526 | stem=r"^cls_token|pos_embed|patch_embed", # stem and embed 527 | blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], 528 | ) 529 | 530 | @torch.jit.ignore 531 | def set_grad_checkpointing(self, enable=True): 532 | self.grad_checkpointing = enable 533 | 534 | @torch.jit.ignore 535 | def get_classifier(self): 536 | return self.head 537 | 538 | def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): 539 | self.num_classes = num_classes 540 | if global_pool is not None: 541 | assert global_pool in ("", "avg", "token") 542 | self.global_pool = global_pool 543 | if representation_size is not None: 544 | 
self._reset_representation(representation_size) 545 | final_chs = self.representation_size if self.representation_size else self.embed_dim 546 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 547 | 548 | def forward_head(self, x, pre_logits: bool = False): 549 | if self.global_pool: 550 | x = x[:, 1:].mean(dim=1) if self.global_pool == "avg" else x[:, 0] 551 | x = self.fc_norm(x) 552 | x = self.pre_logits(x) 553 | return x if pre_logits else self.head(x) 554 | 555 | def interpolate_pos_encoding(self, x, w, h): 556 | if self.sin_cos_embeddings: 557 | 558 | w0 = w // self.patch_embed.patch_size[0] 559 | step_coef = (w0-1) / 3.14 560 | omega_coef = 10000 561 | sin_cos_embed = get_2d_sincos_pos_embed_cached_device( 562 | embed_dim=x.shape[-1], grid_size=w0, step_coef=step_coef, omega_coef=omega_coef, device=x.device, cls_token=True 563 | ) 564 | 565 | return sin_cos_embed 566 | elif self.multiple_pos_embeddings: 567 | 568 | _m = sum((v.mean() * 0 for v in self.pos_embeds.values())) 569 | pos_embed = self.pos_embeds[str(w)] + _m 570 | class_pos_embed = torch.zeros_like(pos_embed[:1,:1]) 571 | return torch.cat((class_pos_embed, pos_embed), dim=1) 572 | else: 573 | npatch = x.shape[1] - 1 574 | N = self.pos_embed.shape[1] - 1 575 | if npatch == N and w == h: 576 | return self.pos_embed 577 | class_pos_embed = self.pos_embed[:, 0] 578 | patch_pos_embed = self.pos_embed[:, 1:] 579 | dim = x.shape[-1] 580 | w0 = w // self.patch_embed.patch_size[0] 581 | h0 = h // self.patch_embed.patch_size[0] 582 | # we add a small number to avoid floating point error in the interpolation 583 | # see discussion at https://github.com/facebookresearch/dino/issues/8 584 | w0, h0 = w0 + 0.1, h0 + 0.1 585 | 586 | patch_pos_embed = nn.functional.interpolate( 587 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 588 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 589 | mode="bicubic", align_corners=True, recompute_scale_factor=True 590 | ) 591 | 592 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 593 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 594 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 595 | 596 | def mask_patches_with_probability_p(self, x, mask_ratio_tuple, p): 597 | B, N, _ = x.shape 598 | n_samples_masked = int(B * p) 599 | mask_ratio_min, mask_ratio_max = mask_ratio_tuple 600 | masks = torch.stack( 601 | [ 602 | torch.BoolTensor(self.mask_generator(int(N * random.uniform(mask_ratio_min, mask_ratio_max)))) 603 | for _ in range(0, n_samples_masked) 604 | ] 605 | + [torch.BoolTensor(self.mask_generator(0)) for _ in range(n_samples_masked, B)] 606 | ).to( 607 | x.device 608 | ) 609 | masks = masks[torch.randperm(B, device=x.device)].flatten(1) 610 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 611 | 612 | return x, masks 613 | 614 | def mask_patches_with_probability_p_upperbound(self, x, mask_ratio_tuple, p): 615 | B, N, _ = x.shape 616 | n_samples_masked = int(B * p) 617 | probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) 618 | upperbound = 0 619 | masks_list = [] 620 | for i in range(0, n_samples_masked): 621 | prob_min = probs[i] 622 | prob_max = probs[i+1] 623 | masks_list.append(torch.BoolTensor(self.mask_generator(int(N * random.uniform(prob_min, prob_max))))) 624 | upperbound += int(N * prob_max) 625 | for i in range(n_samples_masked, B): 626 | 
masks_list.append(torch.BoolTensor(self.mask_generator(0))) 627 | masks = torch.stack(masks_list).to(x.device) 628 | masks = masks[torch.randperm(B, device=x.device)].flatten(1) 629 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 630 | 631 | return x, masks, upperbound 632 | 633 | def prepare_tokens(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False): 634 | B, nc, w, h = x.shape 635 | x = self.patch_embed(x) 636 | masks = None 637 | n_masked_patches_upperbound = None 638 | cls_token = self.cls_token 639 | do_ibot = max(mask_ratio_tuple) > 0.0 and mask_sample_probability > 0.0 640 | if do_ibot: 641 | if ibot_balanced_masking: 642 | logger.debug("using balanced masking") 643 | x, masks, n_masked_patches_upperbound = self.mask_patches_with_probability_p_upperbound( 644 | x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability 645 | ) 646 | else: 647 | logger.debug("not using balanced masking") 648 | x, masks = self.mask_patches_with_probability_p( 649 | x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability 650 | ) 651 | else: 652 | cls_token = cls_token + 0 * self.mask_token # hack to use the mask_token param to not crash ddp... 653 | 654 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 655 | x = self.pos_drop(x + self.interpolate_pos_encoding(x, w, h)) 656 | 657 | return x, masks, n_masked_patches_upperbound 658 | 659 | def forward_features(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False): 660 | x, masks, n_masked_patches_upperbound = self.prepare_tokens(x, mask_ratio_tuple, mask_sample_probability, ibot_balanced_masking) 661 | 662 | for blk in self.blocks: 663 | x = blk(x) 664 | 665 | x_norm = self.norm(x) 666 | return { 667 | "x_norm_clstoken": x_norm[:, 0], 668 | "x_norm_patchtokens": x_norm[:, 1:], 669 | "x_prenorm": x, 670 | "masks": masks, 671 | "n_masked_patches_upperbound": n_masked_patches_upperbound, 672 | } 673 | 674 | def get_intermediate_layers(self, x, n=1): 675 | x, _, _ = self.prepare_tokens(x) 676 | # we return the output tokens from the `n` last blocks 677 | output = [] 678 | for i, blk in enumerate(self.blocks): 679 | x = blk(x) 680 | if len(self.blocks) - i <= n: 681 | output.append(self.norm(x)) 682 | return output 683 | 684 | def forward(self, *args, is_training=False, **kwargs): 685 | ret = self.forward_features(*args, **kwargs) 686 | if is_training: 687 | return ret 688 | else: 689 | return ret["x_norm_clstoken"] 690 | 691 | 692 | 693 | class AdaptivePadding(nn.Module): 694 | """Applies padding to input (if needed) so that input can get fully covered 695 | by filter you specified. It support two modes "same" and "corner". The 696 | "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around 697 | input. The "corner" mode would pad zero to bottom right. 698 | Args: 699 | kernel_size (int | tuple): Size of the kernel: 700 | stride (int | tuple): Stride of the filter. Default: 1: 701 | dilation (int | tuple): Spacing between kernel elements. 702 | Default: 1. 703 | padding (str): Support "same" and "corner", "corner" mode 704 | would pad zero to bottom right, and "same" mode would 705 | pad zero around input. Default: "corner". 
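The masking path above (mask_patches_with_probability_p and its balanced variant) swaps a random subset of patch embeddings for a single learned mask token before the transformer blocks run. A minimal standalone sketch of that substitution step; the batch/patch/embedding sizes and the uniform Bernoulli mask are illustrative assumptions, not the repository's MaskingGenerator:

```python
import torch

B, N, D = 2, 196, 768                      # batch, patches, embed dim (assumed)
x = torch.randn(B, N, D)                   # patch embeddings after patch_embed
mask_token = torch.zeros(1, D)             # a learned nn.Parameter in the real model
masks = torch.rand(B, N) < 0.5             # boolean mask per patch (illustrative)

# Broadcast the mask token into every masked position, keep the rest unchanged.
x_masked = torch.where(masks.unsqueeze(-1), mask_token.to(x.dtype).unsqueeze(0), x)
assert x_masked.shape == (B, N, D)
```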
706 | Example: 707 | >>> kernel_size = 16 708 | >>> stride = 16 709 | >>> dilation = 1 710 | >>> input = torch.rand(1, 1, 15, 17) 711 | >>> adap_pad = AdaptivePadding( 712 | >>> kernel_size=kernel_size, 713 | >>> stride=stride, 714 | >>> dilation=dilation, 715 | >>> padding="corner") 716 | >>> out = adap_pad(input) 717 | >>> assert (out.shape[2], out.shape[3]) == (16, 32) 718 | >>> input = torch.rand(1, 1, 16, 17) 719 | >>> out = adap_pad(input) 720 | >>> assert (out.shape[2], out.shape[3]) == (16, 32) 721 | """ 722 | 723 | def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): 724 | 725 | super(AdaptivePadding, self).__init__() 726 | 727 | assert padding in ('same', 'corner') 728 | 729 | kernel_size = to_2tuple(kernel_size) 730 | stride = to_2tuple(stride) 731 | dilation = to_2tuple(dilation) 732 | 733 | self.padding = padding 734 | self.kernel_size = kernel_size 735 | self.stride = stride 736 | self.dilation = dilation 737 | 738 | def get_pad_shape(self, input_shape): 739 | input_h, input_w = input_shape 740 | kernel_h, kernel_w = self.kernel_size 741 | stride_h, stride_w = self.stride 742 | output_h = math.ceil(input_h / stride_h) 743 | output_w = math.ceil(input_w / stride_w) 744 | pad_h = max((output_h - 1) * stride_h + 745 | (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) 746 | pad_w = max((output_w - 1) * stride_w + 747 | (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) 748 | return pad_h, pad_w 749 | 750 | def forward(self, x): 751 | pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) 752 | if pad_h > 0 or pad_w > 0: 753 | if self.padding == 'corner': 754 | x = F.pad(x, [0, pad_w, 0, pad_h]) 755 | elif self.padding == 'same': 756 | x = F.pad(x, [ 757 | pad_w // 2, pad_w - pad_w // 2, pad_h // 2, 758 | pad_h - pad_h // 2 759 | ]) 760 | return x 761 | 762 | 763 | 764 | class SSLVisionTransformer(DinoVisionTransformer): 765 | """Vision Transformer. 
766 | """ 767 | 768 | def __init__(self, 769 | interpolate_mode='bicubic', 770 | init_cfg=None, 771 | pretrained=None, 772 | img_size=224, 773 | patch_size=16, 774 | #embed_dim=1024, 775 | #depth=24, 776 | #num_heads=16, 777 | mlp_ratio=4, 778 | qkv_bias=True, 779 | init_values=1., 780 | out_indices=(4, 11, 17, 23), 781 | final_norm=False, 782 | with_cls_token=True, 783 | output_cls_token=True, 784 | frozen_stages=100, 785 | *args, **kwargs): 786 | super(SSLVisionTransformer, self).__init__(*args, **kwargs) 787 | 788 | if output_cls_token: 789 | assert with_cls_token is True, f'with_cls_token must be True if' \ 790 | f'set output_cls_token to True, but got {with_cls_token}' 791 | 792 | assert not (init_cfg and pretrained), \ 793 | 'init_cfg and pretrained cannot be set at the same time' 794 | if isinstance(pretrained, str): 795 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 796 | 'please use "init_cfg" instead') 797 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 798 | elif pretrained is not None: 799 | raise TypeError('pretrained must be a str or None') 800 | 801 | 802 | if len(self.blocks)==1: 803 | self.blocks = self.blocks[0] 804 | if isinstance(out_indices, int): 805 | if out_indices == -1: 806 | out_indices = len(self.blocks) - 1 807 | self.out_indices = [out_indices] 808 | elif isinstance(out_indices, list) or isinstance(out_indices, tuple): 809 | self.out_indices = out_indices 810 | else: 811 | raise TypeError('out_indices must be type of int, list or tuple') 812 | 813 | self.interpolate_mode = interpolate_mode 814 | self.pretrained = pretrained 815 | self.frozen_stages = frozen_stages 816 | self.detach = False 817 | self.with_cls_token = with_cls_token 818 | self.output_cls_token = output_cls_token 819 | self.final_norm = final_norm 820 | self.patch_size = self.patch_embed.patch_size 821 | self.adapad = AdaptivePadding(kernel_size=self.patch_size, stride=self.patch_size, padding='same') 822 | if pretrained: 823 | self.init_weights(pretrained) 824 | 825 | self._freeze_stages() 826 | 827 | @staticmethod 828 | def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): 829 | """Resize pos_embed weights. 830 | Resize pos_embed using bicubic interpolate method. 831 | Args: 832 | pos_embed (torch.Tensor): Position embedding weights. 833 | input_shpae (tuple): Tuple for (downsampled input image height, 834 | downsampled input image width). 835 | pos_shape (tuple): The resolution of downsampled origin training 836 | image. 837 | mode (str): Algorithm used for upsampling: 838 | ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | 839 | ``'trilinear'``. 
Default: ``'nearest'`` 840 | Return: 841 | torch.Tensor: The resized pos_embed of shape [B, L_new, C] 842 | """ 843 | assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' 844 | pos_h, pos_w = pos_shape 845 | cls_token_weight = pos_embed[:, 0] 846 | pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] 847 | pos_embed_weight = pos_embed_weight.reshape( 848 | 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) 849 | pos_embed_weight = resize( 850 | pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) 851 | cls_token_weight = cls_token_weight.unsqueeze(1) 852 | pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) 853 | pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) 854 | return pos_embed 855 | 856 | def init_weights(self, pretrained): 857 | print("init_weights", pretrained) 858 | if (isinstance(self.init_cfg, dict) 859 | and self.init_cfg.get('type') == 'Pretrained'): 860 | 861 | checkpoint = torch.load(pretrained, map_location='cpu') 862 | if 'state_dict' in checkpoint: 863 | # timm checkpoint 864 | state_dict = checkpoint['state_dict'] 865 | elif 'model' in checkpoint: 866 | # deit checkpoint 867 | state_dict = checkpoint['model'] 868 | elif 'teacher' in checkpoint: 869 | # dino eval checkpoint 870 | state_dict = checkpoint['teacher'] 871 | else: 872 | state_dict = checkpoint 873 | 874 | if len([k for k in state_dict.keys() if 'teacher.backbone.' in k]) > 0: 875 | state_dict = {k.replace('teacher.backbone.', ''):v for k,v in state_dict.items() if 'teacher.backbone' in k} 876 | if len([k for k in state_dict.keys() if 'backbone.' in k]) > 0: 877 | state_dict = {k.replace('backbone.', ''):v for k,v in state_dict.items()} 878 | 879 | if 'pos_embed' in state_dict.keys(): 880 | if self.pos_embed.shape != state_dict['pos_embed'].shape: 881 | print(f'Resize the pos_embed shape from ' 882 | f'{state_dict["pos_embed"].shape} to ' 883 | f'{self.pos_embed.shape}') 884 | h, w = (224, 224) # self.img_size 885 | pos_size = int( 886 | math.sqrt(state_dict['pos_embed'].shape[1] - 1)) 887 | state_dict['pos_embed'] = self.resize_pos_embed( 888 | state_dict['pos_embed'], 889 | (h // self.patch_size[0], w // self.patch_size[1]), 890 | (pos_size, pos_size), self.interpolate_mode) 891 | self.load_state_dict(state_dict) 892 | else: 893 | super(SSLVisionTransformer, self).init_weights() 894 | 895 | 896 | def forward(self, x): 897 | 898 | with torch.set_grad_enabled(not self.detach): 899 | _, _, old_w, old_h = x.shape 900 | xx = self.adapad(x) 901 | 902 | x = F.pad(x, (0, xx.shape[-1] - x.shape[-1], 0, xx.shape[-2] - x.shape[-2])) 903 | B, nc, w, h = x.shape 904 | 905 | x, _, _ = self.prepare_tokens(x) 906 | # we return the output tokens from the `n` last blocks 907 | outs = [] 908 | for i, blk in enumerate(self.blocks): 909 | x = blk(x) 910 | if i in self.out_indices: 911 | if self.with_cls_token: 912 | out = x[:, 1:] 913 | else: 914 | out = x 915 | B, _, C = out.shape 916 | out = out.reshape(B, w // self.patch_size[0], h // self.patch_size[1], 917 | C).permute(0, 3, 1, 2).contiguous() 918 | if self.output_cls_token: 919 | out = [out, x[:, 0]] 920 | else: 921 | out = [out] 922 | if self.final_norm: 923 | out = [self.norm(o) for o in out] 924 | if self.detach: 925 | out = [o.detach() for o in out] 926 | outs.append(out) 927 | return tuple(outs) 928 | 929 | def train(self, mode=True): 930 | super(SSLVisionTransformer, self).train(mode) 931 | self.detach = False 932 | self._freeze_stages() 933 | 934 | def _freeze_stages(self): 935 | """Freeze 
stages param and norm stats.""" 936 | if self.frozen_stages >= 0: 937 | self.patch_embed.eval() 938 | for m in [self.patch_embed]: 939 | for param in m.parameters(): 940 | param.requires_grad = False 941 | self.cls_token.requires_grad = False 942 | self.pos_embed.requires_grad = False 943 | self.mask_token.requires_grad = False 944 | 945 | if self.frozen_stages >= len(self.blocks) - 1: 946 | self.norm.eval() 947 | for param in self.norm.parameters(): 948 | param.requires_grad = False 949 | self.detach = True 950 | 951 | for i, layer in enumerate(self.blocks): 952 | if i <= self.frozen_stages: 953 | layer.eval() 954 | for param in layer.parameters(): 955 | param.requires_grad = False 956 | 957 | 958 | -------------------------------------------------------------------------------- /models/dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | from torch import nn 9 | import torchvision 10 | 11 | from models.backbone import resize 12 | 13 | def kaiming_init(module: nn.Module, 14 | a: float = 0, 15 | mode: str = 'fan_out', 16 | nonlinearity: str = 'relu', 17 | bias: float = 0, 18 | distribution: str = 'normal') -> None: 19 | assert distribution in ['uniform', 'normal'] 20 | if hasattr(module, 'weight') and module.weight is not None: 21 | if distribution == 'uniform': 22 | nn.init.kaiming_uniform_( 23 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 24 | else: 25 | nn.init.kaiming_normal_( 26 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 27 | if hasattr(module, 'bias') and module.bias is not None: 28 | nn.init.constant_(module.bias, bias) 29 | 30 | class ConvModule(nn.Module): 31 | """A conv block that bundles conv/norm/activation layers. 32 | This block simplifies the usage of convolution layers, which are commonly 33 | used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). 34 | It is based upon three build methods: `build_conv_layer()`, 35 | `build_norm_layer()` and `build_activation_layer()`. 36 | Besides, we add some additional features in this module. 37 | 1. Automatically set `bias` of the conv layer. 38 | 2. Spectral norm is supported. 39 | 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only 40 | supports zero and circular padding, and we add "reflect" padding mode. 41 | Args: 42 | in_channels (int): Number of channels in the input feature map. 43 | Same as that in ``nn._ConvNd``. 44 | out_channels (int): Number of channels produced by the convolution. 45 | Same as that in ``nn._ConvNd``. 46 | kernel_size (int | tuple[int]): Size of the convolving kernel. 47 | Same as that in ``nn._ConvNd``. 48 | stride (int | tuple[int]): Stride of the convolution. 49 | Same as that in ``nn._ConvNd``. 50 | padding (int | tuple[int]): Zero-padding added to both sides of 51 | the input. Same as that in ``nn._ConvNd``. 52 | dilation (int | tuple[int]): Spacing between kernel elements. 53 | Same as that in ``nn._ConvNd``. 54 | groups (int): Number of blocked connections from input channels to 55 | output channels. Same as that in ``nn._ConvNd``. 56 | bias (bool | str): If specified as `auto`, it will be decided by the 57 | norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise 58 | False. Default: "auto". 
59 | conv_cfg (dict): Config dict for convolution layer. Default: None, 60 | which means using conv2d. 61 | norm_cfg (dict): Config dict for normalization layer. Default: None. 62 | act_cfg (dict): Config dict for activation layer. 63 | Default: dict(type='ReLU'). 64 | inplace (bool): Whether to use inplace mode for activation. 65 | Default: True. 66 | with_spectral_norm (bool): Whether use spectral norm in conv module. 67 | Default: False. 68 | padding_mode (str): If the `padding_mode` has not been supported by 69 | current `Conv2d` in PyTorch, we will use our own padding layer 70 | instead. Currently, we support ['zeros', 'circular'] with official 71 | implementation and ['reflect'] with our own implementation. 72 | Default: 'zeros'. 73 | order (tuple[str]): The order of conv/norm/activation layers. It is a 74 | sequence of "conv", "norm" and "act". Common examples are 75 | ("conv", "norm", "act") and ("act", "conv", "norm"). 76 | Default: ('conv', 'norm', 'act'). 77 | """ 78 | 79 | _abbr_ = 'conv_block' 80 | 81 | def __init__(self, 82 | in_channels, 83 | out_channels, 84 | kernel_size, 85 | stride = 1, 86 | padding = 0, 87 | dilation = 1, 88 | groups = 1, 89 | bias = 'auto', 90 | conv_cfg = None, 91 | norm_cfg = None, 92 | act_cfg = dict(type='ReLU'), 93 | inplace= True, 94 | with_spectral_norm = False, 95 | padding_mode = 'zeros', 96 | order = ('conv', 'norm', 'act')): 97 | super().__init__() 98 | assert conv_cfg is None or isinstance(conv_cfg, dict) 99 | assert norm_cfg is None or isinstance(norm_cfg, dict) 100 | assert act_cfg is None or isinstance(act_cfg, dict) 101 | official_padding_mode = ['zeros', 'circular'] 102 | self.conv_cfg = conv_cfg 103 | self.norm_cfg = norm_cfg 104 | self.act_cfg = act_cfg 105 | self.inplace = inplace 106 | self.with_spectral_norm = with_spectral_norm 107 | self.with_explicit_padding = padding_mode not in official_padding_mode 108 | self.order = order 109 | assert isinstance(self.order, tuple) and len(self.order) == 3 110 | assert set(order) == {'conv', 'norm', 'act'} 111 | 112 | self.with_norm = norm_cfg is not None 113 | self.with_activation = act_cfg is not None 114 | # if the conv layer is before a norm layer, bias is unnecessary. 
115 | if bias == 'auto': 116 | bias = not self.with_norm 117 | self.with_bias = bias 118 | 119 | if self.with_explicit_padding: 120 | pad_cfg = dict(type=padding_mode) 121 | self.padding_layer = build_padding_layer(pad_cfg, padding) 122 | # to do Camille put back 123 | 124 | # reset padding to 0 for conv module 125 | conv_padding = 0 if self.with_explicit_padding else padding 126 | # build convolution layer 127 | self.conv = nn.Conv2d( #build_conv_layer(#conv_cfg, 128 | in_channels, 129 | out_channels, 130 | kernel_size, 131 | stride=stride, 132 | padding=conv_padding, 133 | dilation=dilation, 134 | groups=groups, 135 | bias=bias) 136 | # export the attributes of self.conv to a higher level for convenience 137 | self.in_channels = self.conv.in_channels 138 | self.out_channels = self.conv.out_channels 139 | self.kernel_size = self.conv.kernel_size 140 | self.stride = self.conv.stride 141 | self.padding = padding 142 | self.dilation = self.conv.dilation 143 | self.transposed = self.conv.transposed 144 | self.output_padding = self.conv.output_padding 145 | self.groups = self.conv.groups 146 | 147 | if self.with_spectral_norm: 148 | self.conv = nn.utils.spectral_norm(self.conv) 149 | 150 | self.norm_name = None # type: ignore 151 | 152 | # build activation layer 153 | if self.with_activation: 154 | act_cfg_ = act_cfg.copy() # type: ignore 155 | # nn.Tanh has no 'inplace' argument 156 | if act_cfg_['type'] not in [ 157 | 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU' 158 | ]: 159 | act_cfg_.setdefault('inplace', inplace) 160 | self.activate = nn.ReLU() # build_activation_layer(act_cfg_) 161 | 162 | # Use msra init by default 163 | torch.manual_seed(1) 164 | self.init_weights() 165 | 166 | @property 167 | def norm(self): 168 | if self.norm_name: 169 | return getattr(self, self.norm_name) 170 | else: 171 | return None 172 | 173 | def init_weights(self): 174 | # 1. It is mainly for customized conv layers with their own 175 | # initialization manners by calling their own ``init_weights()``, 176 | # and we do not want ConvModule to override the initialization. 177 | # 2. For customized conv layers without their own initialization 178 | # manners (that is, they don't have their own ``init_weights()``) 179 | # and PyTorch's conv layers, they will be initialized by 180 | # this method with default ``kaiming_init``. 181 | # Note: For PyTorch's conv layers, they will be overwritten by our 182 | # initialization implementation using default ``kaiming_init``. 
183 | if not hasattr(self.conv, 'init_weights'): 184 | if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': 185 | nonlinearity = 'leaky_relu' 186 | a = self.act_cfg.get('negative_slope', 0.01) 187 | else: 188 | nonlinearity = 'relu' 189 | a = 0 190 | kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) 191 | if self.with_norm: 192 | constant_init(self.norm, 1, bias=0) 193 | 194 | def forward(self, 195 | x: torch.Tensor, 196 | activate: bool = True, 197 | norm: bool = True, 198 | debug: bool = False) -> torch.Tensor: 199 | 200 | for layer in self.order: 201 | if debug==True: 202 | breakpoint() 203 | if layer == 'conv': 204 | if self.with_explicit_padding: 205 | x = self.padding_layer(x) 206 | x = self.conv(x) 207 | elif layer == 'norm' and norm and self.with_norm: 208 | x = self.norm(x) 209 | elif layer == 'act' and activate and self.with_activation: 210 | x = self.activate(x) 211 | return x 212 | 213 | 214 | class Interpolate(nn.Module): 215 | def __init__(self, scale_factor, mode, align_corners=False): 216 | super(Interpolate, self).__init__() 217 | self.interp = nn.functional.interpolate 218 | self.scale_factor = scale_factor 219 | self.mode = mode 220 | self.align_corners = align_corners 221 | 222 | def forward(self, x): 223 | x = self.interp( 224 | x, 225 | scale_factor=self.scale_factor, 226 | mode=self.mode, 227 | align_corners=self.align_corners) 228 | return x 229 | 230 | class HeadDepth(nn.Module): 231 | def __init__(self, features, classify=False, n_bins=256): 232 | super(HeadDepth, self).__init__() 233 | self.head = nn.Sequential( 234 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 235 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 236 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 237 | nn.ReLU(), 238 | nn.Conv2d(32, 1 if not classify else n_bins, kernel_size=1, stride=1, padding=0), 239 | ) 240 | def forward(self, x): 241 | x = self.head(x) 242 | return x 243 | 244 | 245 | class ReassembleBlocks(nn.Module): 246 | """ViTPostProcessBlock, process cls_token in ViT backbone output and 247 | rearrange the feature vector to feature map. 248 | Args: 249 | in_channels (int): ViT feature channels. Default: 768. 250 | out_channels (List): output channels of each stage. 251 | Default: [96, 192, 384, 768]. 252 | readout_type (str): Type of readout operation. Default: 'ignore'. 253 | patch_size (int): The patch size. Default: 16. 254 | init_cfg (dict, optional): Initialization config dict. Default: None. 
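A usage sketch for the ConvModule defined above (toy channel sizes, purely illustrative): with norm_cfg=None and the default act_cfg, the block reduces to a biased 3x3 convolution followed by a ReLU, applied in the default ('conv', 'norm', 'act') order.

```python
import torch

# ConvModule is defined above; with norm_cfg=None the bias is enabled
# automatically and the 'norm' step is skipped at forward time.
conv = ConvModule(in_channels=8, out_channels=16, kernel_size=3, padding=1)
y = conv(torch.randn(1, 8, 32, 32))
print(y.shape)               # torch.Size([1, 16, 32, 32])
print(bool((y >= 0).all()))  # True, because the ReLU runs last
```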
255 | """ 256 | def __init__(self, 257 | in_channels=1024, #768, 258 | out_channels=[128, 256, 512, 1024], #[96, 192, 384, 768], 259 | readout_type='project', # 'ignore', 260 | patch_size=16): 261 | super(ReassembleBlocks, self).__init__()#init_cfg) 262 | 263 | assert readout_type in ['ignore', 'add', 'project'] 264 | self.readout_type = readout_type 265 | self.patch_size = patch_size 266 | 267 | self.projects = nn.ModuleList([ 268 | ConvModule( 269 | in_channels=in_channels, 270 | out_channels=out_channel, 271 | kernel_size=1, 272 | act_cfg=None, 273 | ) for out_channel in out_channels 274 | ]) 275 | 276 | self.resize_layers = nn.ModuleList([ 277 | nn.ConvTranspose2d( 278 | in_channels=out_channels[0], 279 | out_channels=out_channels[0], 280 | kernel_size=4, 281 | stride=4, 282 | padding=0), 283 | nn.ConvTranspose2d( 284 | in_channels=out_channels[1], 285 | out_channels=out_channels[1], 286 | kernel_size=2, 287 | stride=2, 288 | padding=0), 289 | nn.Identity(), 290 | nn.Conv2d( 291 | in_channels=out_channels[3], 292 | out_channels=out_channels[3], 293 | kernel_size=3, 294 | stride=2, 295 | padding=1) 296 | ]) 297 | if self.readout_type == 'project': 298 | self.readout_projects = nn.ModuleList() 299 | for _ in range(len(self.projects)): 300 | self.readout_projects.append( 301 | nn.Sequential( 302 | nn.Linear(2 * in_channels, in_channels), 303 | nn.GELU())) 304 | #build_activation_layer(dict(type='GELU')))) 305 | 306 | def forward(self, inputs): 307 | assert isinstance(inputs, list) 308 | out = [] 309 | for i, x in enumerate(inputs): 310 | assert len(x) == 2 311 | x, cls_token = x[0], x[1] 312 | feature_shape = x.shape 313 | if self.readout_type == 'project': 314 | x = x.flatten(2).permute((0, 2, 1)) 315 | readout = cls_token.unsqueeze(1).expand_as(x) 316 | x = self.readout_projects[i](torch.cat((x, readout), -1)) 317 | x = x.permute(0, 2, 1).reshape(feature_shape) 318 | elif self.readout_type == 'add': 319 | x = x.flatten(2) + cls_token.unsqueeze(-1) 320 | x = x.reshape(feature_shape) 321 | else: 322 | pass 323 | x = self.projects[i](x) 324 | x = self.resize_layers[i](x) 325 | out.append(x) 326 | return out 327 | 328 | 329 | class PreActResidualConvUnit(nn.Module): 330 | """ResidualConvUnit, pre-activate residual unit. 331 | Args: 332 | in_channels (int): number of channels in the input feature map. 333 | act_cfg (dict): dictionary to construct and config activation layer. 334 | norm_cfg (dict): dictionary to construct and config norm layer. 335 | stride (int): stride of the first block. Default: 1 336 | dilation (int): dilation rate for convs layers. Default: 1. 337 | init_cfg (dict, optional): Initialization config dict. Default: None. 
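The 'project' readout in ReassembleBlocks above folds the cls token into every patch feature: the token is broadcast, concatenated channel-wise, and projected back to the original width with a Linear + GELU. A small standalone sketch of that operation (shapes are illustrative):

```python
import torch
from torch import nn

B, C, H, W = 1, 1024, 14, 14
feat = torch.randn(B, C, H, W)          # one ViT stage reshaped to a feature map
cls_token = torch.randn(B, C)
project = nn.Sequential(nn.Linear(2 * C, C), nn.GELU())

x = feat.flatten(2).permute(0, 2, 1)            # (B, H*W, C)
readout = cls_token.unsqueeze(1).expand_as(x)   # (B, H*W, C)
x = project(torch.cat((x, readout), dim=-1))    # (B, H*W, C)
x = x.permute(0, 2, 1).reshape(B, C, H, W)      # back to a (B, C, H, W) map
print(x.shape)  # torch.Size([1, 1024, 14, 14])
```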
338 | """ 339 | 340 | def __init__(self, 341 | in_channels, 342 | act_cfg, 343 | norm_cfg, 344 | stride=1, 345 | dilation=1, 346 | init_cfg=None): 347 | super(PreActResidualConvUnit, self).__init__()#init_cfg) 348 | self.conv1 = ConvModule( 349 | in_channels, 350 | in_channels, 351 | 3, 352 | stride=stride, 353 | padding=dilation, 354 | dilation=dilation, 355 | norm_cfg=norm_cfg, 356 | act_cfg=act_cfg, 357 | bias=False, 358 | order=('act', 'conv', 'norm')) 359 | self.conv2 = ConvModule( 360 | in_channels, 361 | in_channels, 362 | 3, 363 | padding=1, 364 | norm_cfg=norm_cfg, 365 | act_cfg=act_cfg, 366 | bias=False, 367 | order=('act', 'conv', 'norm')) 368 | def forward(self, inputs): 369 | inputs_ = inputs.clone() 370 | x = self.conv1(inputs) 371 | x = self.conv2(x) 372 | return x + inputs_ 373 | 374 | 375 | class FeatureFusionBlock(nn.Module): 376 | """FeatureFusionBlock, merge feature map from different stages. 377 | Args: 378 | in_channels (int): Input channels. 379 | act_cfg (dict): The activation config for ResidualConvUnit. 380 | norm_cfg (dict): Config dict for normalization layer. 381 | expand (bool): Whether expand the channels in post process block. 382 | Default: False. 383 | align_corners (bool): align_corner setting for bilinear upsample. 384 | Default: True. 385 | init_cfg (dict, optional): Initialization config dict. Default: None. 386 | """ 387 | 388 | def __init__(self, 389 | in_channels, 390 | act_cfg, 391 | norm_cfg, 392 | expand=False, 393 | align_corners=True, 394 | init_cfg=None): 395 | super(FeatureFusionBlock, self).__init__()#init_cfg) 396 | self.in_channels = in_channels 397 | self.expand = expand 398 | self.align_corners = align_corners 399 | self.out_channels = in_channels 400 | if self.expand: 401 | self.out_channels = in_channels // 2 402 | self.project = ConvModule( 403 | self.in_channels, 404 | self.out_channels, 405 | kernel_size=1, 406 | act_cfg=None, 407 | bias=True) 408 | self.res_conv_unit1 = PreActResidualConvUnit( 409 | in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) 410 | self.res_conv_unit2 = PreActResidualConvUnit( 411 | in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) 412 | 413 | def forward(self, *inputs): 414 | x = inputs[0] 415 | 416 | if len(inputs) == 2: 417 | if x.shape != inputs[1].shape: 418 | res = resize( 419 | inputs[1], 420 | size=(x.shape[2], x.shape[3]), 421 | mode='bilinear', 422 | align_corners=False) 423 | else: 424 | res = inputs[1] 425 | x = x + self.res_conv_unit1(res) 426 | x = self.res_conv_unit2(x) 427 | x = resize( x, scale_factor=2, mode='bilinear', align_corners=self.align_corners) 428 | x = self.project(x) 429 | return x 430 | 431 | class DPTHead(nn.Module): 432 | """Vision Transformers for Dense Prediction. 433 | This head is implemented of `DPT `_. 434 | Args: 435 | embed_dims (int): The embed dimension of the ViT backbone. 436 | Default: 768. 437 | post_process_channels (List): Out channels of post process conv 438 | layers. Default: [96, 192, 384, 768]. 439 | readout_type (str): Type of readout operation. Default: 'ignore'. 440 | patch_size (int): The patch size. Default: 16. 441 | expand_channels (bool): Whether expand the channels in post process 442 | block. Default: False. 
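A quick usage sketch for the pre-activation residual unit defined above (toy sizes, norm disabled): each inner ConvModule runs in ('act', 'conv', 'norm') order and the input is added back, so channel count and spatial shape are preserved.

```python
import torch

# PreActResidualConvUnit is defined above; norm_cfg=None disables the norm layers.
unit = PreActResidualConvUnit(in_channels=8, act_cfg={'type': 'ReLU'}, norm_cfg=None)
x = torch.randn(1, 8, 16, 16)
print(unit(x).shape)  # torch.Size([1, 8, 16, 16]), residual output, same shape
```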
443 | """ 444 | 445 | def __init__(self, 446 | in_channels=(1024, 1024, 1024, 1024), 447 | channels=256, 448 | embed_dims=1024, 449 | post_process_channels=[128, 256, 512, 1024], 450 | readout_type='project', 451 | patch_size=16, 452 | expand_channels=False, 453 | min_depth = 0.001, 454 | classify=False, 455 | n_bins=256, 456 | **kwargs): 457 | super(DPTHead, self).__init__(**kwargs) 458 | torch.manual_seed(1) 459 | self.channels = channels 460 | self.norm_cfg = None 461 | self.min_depth = min_depth 462 | self.max_depth = 10 463 | self.n_bins = n_bins 464 | self.classify = classify 465 | self.in_channels = in_channels 466 | self.expand_channels = expand_channels 467 | self.reassemble_blocks = ReassembleBlocks(in_channels=embed_dims, # Camille 23-06-26 468 | out_channels=post_process_channels) # Camille 23-06-26 469 | 470 | self.post_process_channels = [ 471 | channel * math.pow(2, i) if expand_channels else channel 472 | for i, channel in enumerate(post_process_channels) 473 | ] 474 | self.convs = nn.ModuleList() 475 | for channel in self.post_process_channels: 476 | self.convs.append( 477 | ConvModule( 478 | channel, 479 | self.channels, 480 | kernel_size=3, 481 | padding=1, 482 | act_cfg=None, 483 | bias=False)) 484 | self.fusion_blocks = nn.ModuleList() 485 | self.act_cfg = {'type': 'ReLU'} 486 | for _ in range(len(self.convs)): 487 | self.fusion_blocks.append( 488 | FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) 489 | self.fusion_blocks[0].res_conv_unit1 = None 490 | torch.manual_seed(1) 491 | self.project = ConvModule( 492 | self.channels, 493 | self.channels, 494 | kernel_size=3, 495 | padding=1, 496 | norm_cfg=None) 497 | self.num_fusion_blocks = len(self.fusion_blocks) 498 | self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) 499 | self.num_post_process_channels = len(self.post_process_channels) 500 | assert self.num_fusion_blocks == self.num_reassemble_blocks 501 | assert self.num_reassemble_blocks == self.num_post_process_channels 502 | #self.conv_depth = HeadDepth(self.channels) 503 | self.conv_depth = HeadDepth(self.channels, self.classify, self.n_bins) 504 | self.relu = nn.ReLU() 505 | self.sigmoid = nn.Sigmoid() 506 | 507 | 508 | def forward(self, inputs): 509 | 510 | assert len(inputs) == self.num_reassemble_blocks 511 | x = [inp for inp in inputs] 512 | 513 | x = self.reassemble_blocks(x) 514 | x = [self.convs[i](feature) for i, feature in enumerate(x)] 515 | out = self.fusion_blocks[0](x[-1]) 516 | 517 | for i in range(1, len(self.fusion_blocks)): 518 | out = self.fusion_blocks[i](out, x[-(i + 1)]) 519 | 520 | out = self.project(out) 521 | if self.classify: 522 | logit = self.conv_depth(out) 523 | 524 | #if self.bins_strategy == 'UD': 525 | bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=inputs[0][0].device) 526 | #linear strategy 527 | logit = torch.relu(logit) 528 | eps = 0.1 529 | logit = logit + eps 530 | logit = logit / logit.sum(dim=1, keepdim=True) 531 | out = torch.einsum('ikmn,k->imn', [logit, bins]).unsqueeze(dim=1) #+ self.min_depth 532 | else: 533 | out = self.relu(self.conv_depth(out)) + self.min_depth 534 | 535 | return out 536 | 537 | -------------------------------------------------------------------------------- /models/regressor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
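When classify=True, the DPTHead above turns the per-pixel bin logits into a height value by normalising them into weights and taking the expectation over a uniform set of bin centres. A standalone sketch of that step (toy tensor sizes; the bin count and depth range mirror the defaults above):

```python
import torch

n_bins, min_depth, max_depth = 256, 0.001, 10
logit = torch.relu(torch.randn(2, n_bins, 64, 64)) + 0.1   # eps keeps the per-pixel sum strictly positive
logit = logit / logit.sum(dim=1, keepdim=True)             # per-pixel weights over the bins
bins = torch.linspace(min_depth, max_depth, n_bins)
depth = torch.einsum('ikmn,k->imn', logit, bins).unsqueeze(1)
print(depth.shape)  # torch.Size([2, 1, 64, 64])
```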
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | 10 | class RNet(nn.Module): 11 | def __init__( 12 | self, 13 | n_channels=3, 14 | n_classes=13, 15 | n_pix=256, 16 | filters=(8, 16, 32, 64, 64, 128), 17 | pool=(2, 2), 18 | kernel_size=(3, 3), 19 | n_meta=0, 20 | ) -> None: 21 | super(RNet, self).__init__() 22 | 23 | def conv_block(in_filters, out_filters, kernel_size): 24 | layers = nn.Sequential( 25 | # first conv is across channels, size=1 26 | nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"), 27 | nn.BatchNorm2d(out_filters), 28 | nn.ReLU(), 29 | nn.Conv2d( 30 | out_filters, out_filters, kernel_size=kernel_size, padding="same" 31 | ), 32 | ) 33 | return layers 34 | 35 | def fc_block(in_features, out_features): 36 | layers = nn.Sequential( 37 | nn.Linear(in_features=in_features, out_features=out_features), 38 | #nn.BatchNorm1d(out_features), 39 | #nn.InstanceNorm1d(out_features), 40 | nn.ReLU(), 41 | ) 42 | return layers 43 | 44 | self.pool = nn.MaxPool2d(2, 2) 45 | self.input_layer = conv_block(n_channels, filters[0], kernel_size) 46 | self.conv_block1 = conv_block(filters[0], filters[1], kernel_size) 47 | self.conv_block2 = conv_block(filters[1], filters[2], kernel_size) 48 | self.conv_block3 = conv_block(filters[2], filters[3], kernel_size) 49 | self.conv_block4 = conv_block(filters[3], filters[4], kernel_size) 50 | self.conv_block5 = conv_block(filters[4], filters[5], kernel_size) 51 | n_pool = 5 52 | self.fc1 = fc_block(in_features= int(filters[5] * (n_pix / 2**n_pool) ** 2), out_features=64) 53 | self.fc2 = fc_block(in_features=64 + n_meta, out_features=64) 54 | self.fc3 = fc_block(in_features=64, out_features=32) 55 | self.fc4 = nn.Linear(in_features=32, out_features=n_classes) 56 | 57 | def forward(self, x): 58 | x1 = self.pool(self.input_layer(x)) 59 | x2 = self.pool(self.conv_block1(x1)) 60 | x3 = self.pool(self.conv_block2(x2)) 61 | x4 = self.pool(self.conv_block3(x3)) 62 | x4b = self.pool(self.conv_block4(x4)) 63 | x5 = self.conv_block5(x4b) 64 | x6 = torch.flatten(x5, 1) # flatten all dimensions except batch 65 | x7 = self.fc1(x6) 66 | x9 = self.fc2(x7) 67 | x10 = self.fc3(x9) 68 | x11 = self.fc4(x10) 69 | return x11 -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # High Resolution Canopy Height Maps Notebooks 2 | 3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** 4 | 5 | ## Setup 6 | 7 | Install Conda (https://docs.conda.io/projects/conda/en/latest/index.html#) 8 | 9 | ## Example of successful environment creation for notebooks 10 | ``` 11 | conda create -n chm_demo python=3.9 12 | conda activate chm_demo 13 | conda install pytorch==2.0.1 -c pytorch 14 | conda install torchvision -c pytorch 15 | conda install conda-forge::pytorch-lightning==1.7 16 | conda install torchmetrics==0.11.4 17 | conda install geopandas jupyter rasterio boto3 scikit-image 18 | 19 | jupyter notebook 20 | ``` 21 | 22 | Alternatively, create conda env from yaml file: 23 | ``` 24 | conda env create -f chm_demo.yml 25 | ``` -------------------------------------------------------------------------------- /notebooks/chm_demo.yml: -------------------------------------------------------------------------------- 1 | name: chm_demo2 2 | channels: 3 | - 
conda-forge 4 | - defaults 5 | dependencies: 6 | - affine=2.3.0=pyhd3eb1b0_0 7 | - anyio=4.6.2=py39hca03da5_0 8 | - aom=3.6.0=h313beb8_0 9 | - appnope=0.1.2=py39hca03da5_1001 10 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 11 | - argon2-cffi-bindings=21.2.0=py39h80987f9_1 12 | - asttokens=2.0.5=pyhd3eb1b0_0 13 | - async-lru=2.0.4=py39hca03da5_0 14 | - attrs=24.3.0=py39hca03da5_0 15 | - babel=2.11.0=py39hca03da5_0 16 | - backcall=0.2.0=pyhd3eb1b0_0 17 | - beautifulsoup4=4.12.3=py39hca03da5_0 18 | - blas=1.0=openblas 19 | - bleach=6.2.0=py39hca03da5_0 20 | - blosc=1.21.3=h313beb8_0 21 | - boost-cpp=1.82.0=h48ca7d4_2 22 | - boto3=1.34.154=py39hca03da5_0 23 | - botocore=1.34.154=py39hca03da5_0 24 | - bottleneck=1.4.2=py39hbda83bc_0 25 | - branca=0.6.0=py39hca03da5_0 26 | - brotli=1.0.9=h80987f9_9 27 | - brotli-bin=1.0.9=h80987f9_9 28 | - brotli-python=1.0.9=py39h313beb8_9 29 | - brunsli=0.1=hc377ac9_1 30 | - bzip2=1.0.8=h80987f9_6 31 | - c-ares=1.19.1=h80987f9_0 32 | - ca-certificates=2025.2.25=hca03da5_0 33 | - cairo=1.16.0=h302bd0f_5 34 | - certifi=2025.1.31=py39hca03da5_0 35 | - cffi=1.17.1=py39h3eb5a62_1 36 | - cfitsio=3.470=h7f6438f_7 37 | - charls=2.2.0=hc377ac9_0 38 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 39 | - click=8.1.7=py39hca03da5_0 40 | - click-plugins=1.1.1=pyhd3eb1b0_0 41 | - cligj=0.7.2=pyhd3eb1b0_0 42 | - comm=0.2.1=py39hca03da5_0 43 | - contourpy=1.2.0=py39h48ca7d4_0 44 | - cryptography=43.0.3=py39h8026fc7_1 45 | - cycler=0.11.0=pyhd3eb1b0_0 46 | - cyrus-sasl=2.1.28=h9131b1a_1 47 | - dav1d=1.2.1=h80987f9_0 48 | - debugpy=1.8.11=py39h313beb8_0 49 | - decorator=5.1.1=pyhd3eb1b0_0 50 | - defusedxml=0.7.1=pyhd3eb1b0_0 51 | - exceptiongroup=1.2.0=py39hca03da5_0 52 | - executing=0.8.3=pyhd3eb1b0_0 53 | - expat=2.6.4=h313beb8_0 54 | - filelock=3.13.1=py39hca03da5_0 55 | - fiona=1.9.5=py39h46d7db6_0 56 | - folium=0.14.0=py39hca03da5_0 57 | - fontconfig=2.14.1=h6402c1e_3 58 | - fonttools=4.51.0=py39h80987f9_0 59 | - freetype=2.12.1=h1192e45_0 60 | - freexl=2.0.0=ha3de405_0 61 | - fsspec=2024.12.0=py39hca03da5_0 62 | - gdal=3.6.2=py39h8924233_9 63 | - geopandas=0.14.2=py39hca03da5_0 64 | - geopandas-base=0.14.2=py39hca03da5_0 65 | - geos=3.10.6=h313beb8_0 66 | - geotiff=1.7.0=h41f0982_3 67 | - gettext=0.21.0=hbdbcc25_2 68 | - giflib=5.2.2=h80987f9_0 69 | - glib=2.78.4=h313beb8_0 70 | - glib-tools=2.78.4=h313beb8_0 71 | - gmp=6.2.1=hc377ac9_3 72 | - gmpy2=2.1.2=py39h8c48613_0 73 | - gst-plugins-base=1.14.1=h313beb8_1 74 | - gstreamer=1.14.1=h80987f9_1 75 | - h11=0.14.0=py39hca03da5_0 76 | - hdf4=4.2.13=h5e329fb_3 77 | - hdf5=1.12.1=h05c076b_3 78 | - httpcore=1.0.2=py39hca03da5_0 79 | - httpx=0.27.0=py39hca03da5_0 80 | - icu=73.1=h313beb8_0 81 | - idna=3.7=py39hca03da5_0 82 | - imagecodecs=2024.9.22=py39hb4ce95a_0 83 | - imageio=2.37.0=py39hca03da5_0 84 | - importlib-metadata=8.5.0=py39hca03da5_0 85 | - importlib_metadata=8.5.0=hd3eb1b0_0 86 | - importlib_resources=6.4.0=py39hca03da5_0 87 | - ipykernel=6.29.5=py39hca03da5_0 88 | - ipython=8.15.0=py39hca03da5_0 89 | - ipywidgets=8.1.5=py39hca03da5_0 90 | - jedi=0.19.2=py39hca03da5_0 91 | - jinja2=3.1.4=py39hca03da5_1 92 | - jmespath=1.0.1=py39hca03da5_0 93 | - joblib=1.4.2=py39hca03da5_0 94 | - jpeg=9e=h80987f9_3 95 | - json-c=0.16=h1a28f6b_0 96 | - json5=0.9.25=py39hca03da5_0 97 | - jsonschema=4.23.0=py39hca03da5_0 98 | - jsonschema-specifications=2023.7.1=py39hca03da5_0 99 | - jupyter=1.0.0=py39hca03da5_9 100 | - jupyter-lsp=2.2.0=py39hca03da5_0 101 | - jupyter_client=8.6.0=py39hca03da5_0 102 | - jupyter_console=6.6.3=py39hca03da5_0 103 | - 
jupyter_core=5.7.2=py39hca03da5_0 104 | - jupyter_events=0.10.0=py39hca03da5_0 105 | - jupyter_server=2.14.1=py39hca03da5_0 106 | - jupyter_server_terminals=0.4.4=py39hca03da5_1 107 | - jupyterlab=4.2.5=py39hca03da5_0 108 | - jupyterlab_pygments=0.1.2=py_0 109 | - jupyterlab_server=2.27.3=py39hca03da5_0 110 | - jupyterlab_widgets=3.0.13=py39hca03da5_0 111 | - jxrlib=1.1=h1a28f6b_2 112 | - kealib=1.5.0=hba2eb73_1 113 | - kiwisolver=1.4.4=py39h313beb8_0 114 | - krb5=1.20.1=hf3e1bf2_1 115 | - lazy_loader=0.4=py39hca03da5_0 116 | - lcms2=2.16=he93ba84_0 117 | - lerc=4.0.0=h313beb8_0 118 | - libabseil=20240116.2=cxx17_h313beb8_0 119 | - libaec=1.1.3=h313beb8_0 120 | - libavif=1.1.1=h80987f9_0 121 | - libboost=1.82.0=h0bc93f9_2 122 | - libbrotlicommon=1.0.9=h80987f9_9 123 | - libbrotlidec=1.0.9=h80987f9_9 124 | - libbrotlienc=1.0.9=h80987f9_9 125 | - libclang=14.0.6=default_h1b80db6_2 126 | - libclang13=14.0.6=default_h24352ff_2 127 | - libcurl=8.11.1=hde089ae_0 128 | - libcxx=14.0.6=h848a8c0_0 129 | - libdeflate=1.22=h80987f9_0 130 | - libedit=3.1.20230828=h80987f9_0 131 | - libev=4.33=h1a28f6b_1 132 | - libffi=3.4.4=hca03da5_1 133 | - libgdal=3.6.2=h0e880fd_9 134 | - libgfortran=5.0.0=11_3_0_hca03da5_28 135 | - libgfortran5=11.3.0=h009349e_28 136 | - libglib=2.78.4=h0a96307_0 137 | - libiconv=1.16=h80987f9_3 138 | - libjpeg-turbo=2.0.0=h1a28f6b_0 139 | - libkml=1.3.0=hc4d7c42_7 140 | - libllvm14=14.0.6=h19fdd8a_4 141 | - libnetcdf=4.8.1=h0fce390_4 142 | - libnghttp2=1.57.0=h62f6fdd_0 143 | - libopenblas=0.3.21=h269037a_0 144 | - libpng=1.6.39=h80987f9_0 145 | - libpq=17.2=h02f6b3c_0 146 | - libprotobuf=3.20.3=h514c7bf_0 147 | - libsodium=1.0.18=h1a28f6b_0 148 | - libspatialindex=1.9.3=hc377ac9_0 149 | - libspatialite=5.1.0=h4e90699_2 150 | - libssh2=1.11.1=h3e2b118_0 151 | - libtiff=4.5.1=hc9ead59_1 152 | - libuv=1.48.0=h80987f9_0 153 | - libwebp=1.3.2=ha3663a8_0 154 | - libwebp-base=1.3.2=h80987f9_1 155 | - libxml2=2.13.5=h0b34f26_0 156 | - libzip=1.8.0=h62fee54_1 157 | - libzopfli=1.0.3=hc377ac9_0 158 | - lightning-utilities=0.11.9=py39hca03da5_0 159 | - llvm-openmp=14.0.6=hc6e5704_0 160 | - lz4-c=1.9.4=h313beb8_1 161 | - mapclassify=2.5.0=py39hca03da5_0 162 | - markupsafe=2.1.3=py39h80987f9_1 163 | - matplotlib-base=3.9.2=py39h7ef442a_1 164 | - matplotlib-inline=0.1.6=py39hca03da5_0 165 | - minizip=4.0.3=ha89c15f_0 166 | - mistune=2.0.4=py39hca03da5_0 167 | - mpc=1.1.0=h8c48613_1 168 | - mpfr=4.0.2=h695f6f0_1 169 | - mpmath=1.3.0=py39hca03da5_0 170 | - mysql=8.4.0=hbfabb4d_0 171 | - nbclient=0.8.0=py39hca03da5_0 172 | - nbconvert=7.16.4=py39hca03da5_0 173 | - nbformat=5.10.4=py39hca03da5_0 174 | - ncurses=6.4=h313beb8_0 175 | - nest-asyncio=1.6.0=py39hca03da5_0 176 | - networkx=3.2.1=py39hca03da5_0 177 | - ninja=1.12.1=hca03da5_0 178 | - ninja-base=1.12.1=h48ca7d4_0 179 | - notebook=7.2.2=py39hca03da5_1 180 | - notebook-shim=0.2.3=py39hca03da5_0 181 | - nspr=4.35=h313beb8_0 182 | - nss=3.89.1=h313beb8_0 183 | - numexpr=2.10.1=py39h5d9532f_0 184 | - numpy=1.26.4=py39h3b2db8e_0 185 | - numpy-base=1.26.4=py39ha9811e2_0 186 | - openjpeg=2.5.2=h54b8e55_0 187 | - openldap=2.6.4=he7ef289_0 188 | - openssl=3.0.16=h02f6b3c_0 189 | - overrides=7.4.0=py39hca03da5_0 190 | - packaging=24.2=py39hca03da5_0 191 | - pandas=2.2.3=py39hcf29cfe_0 192 | - pandocfilters=1.5.0=pyhd3eb1b0_0 193 | - parso=0.8.4=py39hca03da5_0 194 | - pcre2=10.42=hb066dcc_1 195 | - pexpect=4.8.0=pyhd3eb1b0_3 196 | - pickleshare=0.7.5=pyhd3eb1b0_1003 197 | - pillow=11.0.0=py39h84e58ab_1 198 | - pip=24.2=py39hca03da5_0 199 | - 
pixman=0.40.0=h1a28f6b_0 200 | - platformdirs=3.10.0=py39hca03da5_0 201 | - ply=3.11=py39hca03da5_0 202 | - poppler=24.09.0=h092caa6_1 203 | - poppler-data=0.4.11=hca03da5_1 204 | - proj=9.3.1=h805f6d4_0 205 | - prometheus_client=0.21.0=py39hca03da5_0 206 | - prompt-toolkit=3.0.43=py39hca03da5_0 207 | - prompt_toolkit=3.0.43=hd3eb1b0_0 208 | - psutil=5.9.0=py39h80987f9_1 209 | - ptyprocess=0.7.0=pyhd3eb1b0_2 210 | - pure_eval=0.2.2=pyhd3eb1b0_0 211 | - pybind11-abi=4=hd3eb1b0_1 212 | - pycparser=2.21=pyhd3eb1b0_0 213 | - pygments=2.15.1=py39hca03da5_1 214 | - pyopenssl=24.2.1=py39hca03da5_0 215 | - pyparsing=3.2.0=py39hca03da5_0 216 | - pyproj=3.6.1=py39h6070227_0 217 | - pyqt=5.15.10=py39h313beb8_0 218 | - pyqt5-sip=12.13.0=py39h80987f9_0 219 | - pysocks=1.7.1=py39hca03da5_0 220 | - python=3.9.21=hb885b13_1 221 | - python-dateutil=2.9.0post0=py39hca03da5_2 222 | - python-fastjsonschema=2.20.0=py39hca03da5_0 223 | - python-json-logger=3.2.1=py39hca03da5_0 224 | - python-tzdata=2023.3=pyhd3eb1b0_0 225 | - pytorch=2.0.1=gpu_mps_py39h20d1048_0 226 | - pytorch-lightning=2.3.3=pyhd8ed1ab_0 227 | - pytz=2024.1=py39hca03da5_0 228 | - pyyaml=6.0.2=py39h80987f9_0 229 | - pyzmq=26.2.0=py39h313beb8_0 230 | - qhull=2020.2=h48ca7d4_2 231 | - qt-main=5.15.2=h0917680_11 232 | - qtconsole=5.6.0=py39hca03da5_0 233 | - qtpy=2.4.1=py39hca03da5_0 234 | - rasterio=1.3.10=py39he3d1bc6_0 235 | - readline=8.2=h1a28f6b_0 236 | - referencing=0.30.2=py39hca03da5_0 237 | - requests=2.32.3=py39hca03da5_1 238 | - rfc3339-validator=0.1.4=py39hca03da5_0 239 | - rfc3986-validator=0.1.1=py39hca03da5_0 240 | - rpds-py=0.22.3=py39h2aea54e_0 241 | - rtree=1.0.1=py39hca03da5_0 242 | - s3transfer=0.10.1=py39hca03da5_0 243 | - scikit-image=0.24.0=py39h46d7db6_0 244 | - scikit-learn=1.5.2=py39h313beb8_0 245 | - scipy=1.13.1=py39hd336fd7_1 246 | - send2trash=1.8.2=py39hca03da5_1 247 | - setuptools=75.1.0=py39hca03da5_0 248 | - shapely=2.0.6=py39hc1b91b6_0 249 | - sip=6.7.12=py39h313beb8_1 250 | - six=1.16.0=pyhd3eb1b0_1 251 | - sleef=3.5.1=h80987f9_2 252 | - snappy=1.2.1=h313beb8_0 253 | - sniffio=1.3.0=py39hca03da5_0 254 | - snuggs=1.4.7=pyhd3eb1b0_0 255 | - soupsieve=2.5=py39hca03da5_0 256 | - sqlite=3.45.3=h80987f9_0 257 | - stack_data=0.2.0=pyhd3eb1b0_0 258 | - sympy=1.13.3=py39hca03da5_0 259 | - terminado=0.17.1=py39hca03da5_0 260 | - threadpoolctl=3.5.0=py39h33ce5c2_0 261 | - tifffile=2023.4.12=py39hca03da5_0 262 | - tiledb=2.3.3=hb4a6b97_3 263 | - tinycss2=1.2.1=py39hca03da5_0 264 | - tk=8.6.14=h6ba3021_0 265 | - tomli=2.0.1=py39hca03da5_0 266 | - torchmetrics=1.4.0.post0=py39hca03da5_0 267 | - torchvision=0.15.2=cpu_py39h31aa045_0 268 | - tornado=6.4.2=py39h80987f9_0 269 | - tqdm=4.66.5=py39h33ce5c2_0 270 | - traitlets=5.14.3=py39hca03da5_0 271 | - typing-extensions=4.12.2=py39hca03da5_0 272 | - typing_extensions=4.12.2=py39hca03da5_0 273 | - tzdata=2024b=h04d1e81_0 274 | - unicodedata2=15.1.0=py39h80987f9_1 275 | - uriparser=0.9.7=h80987f9_0 276 | - urllib3=1.26.19=py39hca03da5_0 277 | - wcwidth=0.2.5=pyhd3eb1b0_0 278 | - webencodings=0.5.1=py39hca03da5_1 279 | - websocket-client=1.8.0=py39hca03da5_0 280 | - wheel=0.44.0=py39hca03da5_0 281 | - widgetsnbextension=4.0.13=py39hca03da5_0 282 | - xerces-c=3.2.4=h313beb8_1 283 | - xyzservices=2022.9.0=py39hca03da5_1 284 | - xz=5.4.6=h80987f9_1 285 | - yaml=0.2.5=h1a28f6b_0 286 | - zeromq=4.3.5=h313beb8_0 287 | - zfp=1.0.0=h313beb8_0 288 | - zipp=3.21.0=py39hca03da5_0 289 | - zlib=1.2.13=h18a0788_1 290 | - zlib-ng=2.0.7=h80987f9_0 291 | - zstd=1.5.6=hfb09047_0 292 | prefix: 
/opt/homebrew/anaconda3/envs/chm_demo2 293 | -------------------------------------------------------------------------------- /notebooks/data/example_polygon.geojson: -------------------------------------------------------------------------------- 1 | {"type":"FeatureCollection","features":[{"type":"Feature","properties":{},"geometry":{"type":"Polygon","coordinates":[[[-120.38784027099608,39.403570770090795],[-120.34106254577637,39.403570770090795],[-120.34106254577637,39.42936481248126],[-120.38784027099608,39.42936481248126],[-120.38784027099608,39.403570770090795]]]}}]} -------------------------------------------------------------------------------- /pl_modules/normnet_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import pytorch_lightning as pl 7 | import torch.nn as nn 8 | from torchmetrics import R2Score, MeanAbsoluteError 9 | import torchvision 10 | import torch 11 | import torchvision.transforms as T 12 | 13 | import sys 14 | from pathlib import Path 15 | 16 | ROOT = str(Path(__file__).parent.parent) 17 | sys.path.append(ROOT) 18 | 19 | def l1_loss(x, y, f=lambda x:x): 20 | return ((f(x) - f(y)).abs()).mean() 21 | 22 | class NormNetModule(pl.LightningModule): 23 | def __init__(self, backbone_cls, 24 | opt = None, 25 | sched = None, 26 | loss = l1_loss, 27 | **kwargs): 28 | super().__init__() 29 | 30 | self.__dict__.update(**locals()) 31 | self.cls = self.__class__ 32 | self.save_hyperparameters(ignore=["self"]) 33 | 34 | self.backbone = backbone_cls(n_classes=6) 35 | 36 | self.l1 = nn.ModuleDict(dict(_train=MeanAbsoluteError(compute_on_step=False), 37 | _val=MeanAbsoluteError(compute_on_step=False))) 38 | 39 | def step(self, batch, batch_nb, domain='train'): 40 | pred = self.backbone(batch['img']) 41 | loss = self.loss(pred, batch['percs']) 42 | log_args = dict(sync_dist = (domain !='train')) 43 | self.log(f'{domain}_loss', loss, **log_args) 44 | return loss 45 | 46 | def epoch_end(self, outputs, domain='train'): 47 | l1 = self.l1['_'+domain].compute() 48 | self.log(f'{domain}_l1', l1, sync_dist=(domain !='train')) 49 | 50 | def training_step(self, batch, batch_nb): 51 | return self.step(batch, batch_nb, domain='train') 52 | 53 | def validation_step(self, batch, batch_nb): 54 | return self.step(batch, batch_nb, domain='val') 55 | 56 | def training_epoch_end(self, outputs): 57 | self.epoch_end(outputs, domain='train') 58 | 59 | def validation_epoch_end(self, outputs): 60 | self.epoch_end(outputs, domain='val') 61 | 62 | def configure_optimizers(self): 63 | opt = self.opt(self.parameters()) 64 | sched = self.sched(opt) 65 | return [opt], [sched] -------------------------------------------------------------------------------- /src/raster_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
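In NormNetModule above, configure_optimizers treats opt and sched as factories: opt(parameters) must return an optimizer and sched(optimizer) a scheduler. A hedged sketch of how such factories could be built with functools.partial; the optimizer choice and hyperparameters are illustrative assumptions, not values from the repository:

```python
from functools import partial

import torch
import torch.nn as nn

opt = partial(torch.optim.AdamW, lr=1e-4)
sched = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)

# The same pattern configure_optimizers applies to self.parameters():
model = nn.Linear(4, 2)
optimizer = opt(model.parameters())
scheduler = sched(optimizer)
```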
5 | 6 | import numpy as np 7 | from skimage.util import view_as_windows 8 | from scipy.ndimage import gaussian_filter 9 | 10 | 11 | def create_apodization(window_size, pad_int=None): 12 | if not pad_int: 13 | pad_int = int(np.mean(window_size) / 16) 14 | yr = window_size[0] - int(window_size[0] / pad_int) * 2 15 | xr = window_size[1] - int(window_size[1] / pad_int) * 2 16 | padx = int((window_size[0] - xr) / 2) 17 | pady = int((window_size[1] - yr) / 2) 18 | pad = ((padx, padx), (pady, pady)) 19 | 20 | weight = np.ones((xr, yr)) 21 | weight = np.pad(weight, pad, mode="constant") 22 | weight = gaussian_filter(weight, np.mean(pad) / 2) 23 | weight[weight < 1e-14] = 1e-14 24 | return weight 25 | 26 | 27 | def blocking(arr, block_shape, step): 28 | """ 29 | Create a stacked array, with thumbnail images of size block_shape, 30 | which have an overlap specified by step. (if step = block_shape 31 | there is no overlap) 32 | """ 33 | view = view_as_windows(arr, block_shape, step=step) 34 | vx = view.shape[0] 35 | vy = view.shape[1] 36 | flatten_view = view.reshape(-1, 37 | view.shape[-3], view.shape[-2], view.shape[-1]) 38 | return flatten_view, vx, vy 39 | 40 | 41 | def inverse_blocking(block_arr, out_size, winsize, weight, step, vx, vy): 42 | """ 43 | Given a stacked array of image thumbnails, expand back to unstacked 44 | image of size out_size, given a step size and winsize. 45 | Uses the inverse of weight to coadd the thumbnail images 46 | """ 47 | vview = block_arr.reshape((vx, vy) + block_arr.shape[1:]) 48 | out = np.zeros(out_size + (block_arr.shape[-1],)) 49 | weights = np.zeros(out_size) + 1e-14 50 | w2 = int(winsize / 2) 51 | for i in range(vview.shape[0]): 52 | if i == 0: 53 | ci = w2 54 | else: 55 | ci = w2 + i * step 56 | for j in range(vview.shape[1]): 57 | if j == 0: 58 | cj = w2 59 | else: 60 | cj = w2 + j * step 61 | exist = out[ci - w2: ci + w2, cj - w2: cj + w2, :] 62 | new = np.einsum("ijk,ij->ijk", vview[i, j], weight) 63 | out[ci - w2: ci + w2, cj - w2: cj + w2, :] = new + exist 64 | wexist = weights[ci - w2: ci + w2, cj - w2: cj + w2] 65 | weights[ci - w2: ci + w2, cj - w2: cj + w2] = weight + wexist 66 | final = np.einsum("ijk,ij->ijk", out, 1 / weights) 67 | return final, weights 68 | -------------------------------------------------------------------------------- /src/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
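The helpers above implement tiled processing: blocking cuts an (H, W, C) array into window-sized thumbnails, and inverse_blocking coadds (possibly model-processed) thumbnails back together using the apodization weights. A round-trip sketch with toy sizes; when step equals the window size and the array tiles evenly, the reconstruction matches the input up to floating-point error:

```python
import numpy as np

H, W, C, win = 128, 128, 3, 64
arr = np.random.rand(H, W, C)

weight = create_apodization((win, win))
blocks, vx, vy = blocking(arr, (win, win, C), step=win)
recon, weights = inverse_blocking(blocks, (H, W), win, weight, step=win, vx=vx, vy=vy)
print(np.allclose(recon, arr))  # True
```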
5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | import torchvision.transforms.functional as TF 10 | 11 | 12 | class Norm: 13 | def get_trans(self, nmean, nstd): 14 | self.nmean = nmean 15 | self.nstd = nstd 16 | self.Trans = transforms.Compose( 17 | [ 18 | transforms.ToTensor(), 19 | # type: ignore 20 | transforms.Normalize(mean=self.nmean, std=self.nstd), 21 | ] 22 | ) 23 | self.Normalize = transforms.Normalize(mean=self.nmean, std=self.nstd) 24 | self.invTrans = transforms.Compose( 25 | [ 26 | transforms.Normalize( 27 | mean=[0.0, 0.0, 0.0], 28 | std=[1 / self.nstd[0], 1 / self.nstd[1], 1 / self.nstd[2]], 29 | ), 30 | transforms.Normalize( 31 | mean=[-self.nmean[0], -self.nmean[1], -self.nmean[2]], 32 | std=[1.0, 1.0, 1.0], 33 | ), 34 | ] 35 | ) 36 | 37 | 38 | class SSLNorm(Norm): 39 | def __init__(self): 40 | super().__init__() 41 | nmean = [0.430, 0.411, 0.296] 42 | nstd = [0.213, 0.156, 0.143] 43 | self.get_trans(nmean, nstd) 44 | --------------------------------------------------------------------------------
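A usage sketch for the SSLNorm transform defined above (toy input, values illustrative): Trans converts an HxWx3 float array in [0, 1] to a normalized CHW tensor, and invTrans undoes the normalization.

```python
import numpy as np
import torch

norm = SSLNorm()
img = np.random.rand(224, 224, 3).astype(np.float32)   # HWC image in [0, 1]
x = norm.Trans(img)                                     # CHW, channel-normalized
x_back = norm.invTrans(x)                               # approximately the original
print(torch.allclose(x_back, torch.from_numpy(img).permute(2, 0, 1), atol=1e-6))  # True
```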