├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── demo.ipynb ├── satlasnet.png ├── satlaspretrain_models ├── __init__.py ├── model.py ├── models │ ├── __init__.py │ ├── backbones.py │ ├── fpn.py │ └── heads.py └── utils.py ├── setup.py ├── tests ├── test_pretrained_models.py └── test_randomly_initialized_models.py └── torchgeo_demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.swp 3 | *.pyc 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | SatlasPretrain Models: Foundation models for satellite and aerial imagery. 
2 | -------------------------------------------------------------------------- 3 | 4 | **SatlasPretrain** is a large-scale pre-training dataset for remote sensing image understanding. This work 5 | was published at ICCV 2023. Details and download links for the dataset can be found 6 | [here](https://github.com/allenai/satlas/blob/main/SatlasPretrain.md). 7 | 8 | This repository contains `satlaspretrain_models`, a lightweight library to easily load pretrained SatlasPretrain models for: 9 | - Sentinel-2 10 | - Sentinel-1 11 | - Landsat 8/9 12 | - 0.5-2 m/pixel aerial imagery 13 | 14 | These models can be fine-tuned on downstream tasks that use these image sources, leading to faster training 15 | and improved performance compared to training from other initializations. 16 | 17 | Model Structure and Usage 18 | ------------------------- 19 | The SatlasPretrain models consist of three main components: backbone, feature pyramid network (FPN), and prediction head. 20 | 21 | ![SatlasPretrain model architecture diagram, described in the next paragraph.](satlasnet.png) 22 | 23 | For models trained on *multi-image* input, the backbone is applied on each individual image, and then max pooling is applied 24 | in the temporal dimension, i.e., across the multiple aligned images. *Single-image* models input an individual image. 25 | 26 | This package allows you to load the backbone or backbone+FPN using a model checkpoint ID from the tables below. 27 | 28 | ```python 29 | MODEL_CHECKPOINT_ID = "Sentinel2_SwinB_SI_RGB" 30 | model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID) 31 | model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True) 32 | ``` 33 | 34 | The output of the model is the multi-scale feature map (either from the backbone or from the FPN). 35 | 36 | For a complete fine-tuning example, [see our tutorial on fine-tuning the pre-trained model on EuroSAT](https://github.com/allenai/satlaspretrain_models/blob/main/demo.ipynb). 37 | 38 | Installation 39 | -------------- 40 | ``` 41 | conda create --name satlaspretrain python==3.9 42 | conda activate satlaspretrain 43 | pip install satlaspretrain-models 44 | ``` 45 | 46 | Available Pretrained Models 47 | --------------------------- 48 | 49 | The tables below list available model checkpoint IDs (like `Sentinel2_SwinB_SI_RGB`). 50 | Checkpoints are released under [ODC-BY](https://github.com/allenai/satlas/blob/main/DataLicense). 51 | This package will download model checkpoints automatically, but you can download them directly using the links below if desired. 
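The checkpoint IDs in the tables below are the keys of the `SatlasPretrain_weights` dictionary defined in `satlaspretrain_models/utils.py`, so you can also enumerate them programmatically. A minimal sketch (note that the dictionary is imported from the `utils` submodule, not re-exported from the package root):

```python
# List every available checkpoint ID with its backbone, expected input channel
# count, and whether it is a single-image or multi-image model.
from satlaspretrain_models.utils import SatlasPretrain_weights

for model_id, info in SatlasPretrain_weights.items():
    mode = 'multi-image' if info['multi_image'] else 'single-image'
    print(f"{model_id}: {info['backbone'].name}, {info['num_channels']} channels, {mode}")
```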
52 | 53 | #### Sentinel-2 Pretrained Models 54 | | | Single-image, RGB | Multi-image, RGB | 55 | | ---------- | ------------ | ------------ | 56 | | **Swin-v2-Base** | [Sentinel2_SwinB_SI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_rgb.pth?download=true) | [Sentinel2_SwinB_MI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_mi_rgb.pth?download=true) | 57 | | **Swin-v2-Tiny** | [Sentinel2_SwinT_SI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_si_rgb.pth?download=true) | [Sentinel2_SwinT_MI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_mi_rgb.pth?download=true) | 58 | | **Resnet50** | [Sentinel2_Resnet50_SI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_si_rgb.pth?download=true) | [Sentinel2_Resnet50_MI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_mi_rgb.pth?download=true) | 59 | | **Resnet152** | [Sentinel2_Resnet152_SI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_si_rgb.pth?download=true) | [Sentinel2_Resnet152_MI_RGB](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_mi_rgb.pth?download=true) | 60 | 61 | | | Single-image, MS | Multi-image, MS | 62 | | ---------- | ------------ | ------------ | 63 | | **Swin-v2-Base** | [Sentinel2_SwinB_SI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_ms.pth?download=true) | [Sentinel2_SwinB_MI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_mi_ms.pth?download=true) | 64 | | **Swin-v2-Tiny** | [Sentinel2_SwinT_SI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_si_ms.pth?download=true) | [Sentinel2_SwinT_MI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_mi_ms.pth?download=true) | 65 | | **Resnet50** | [Sentinel2_Resnet50_SI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_si_ms.pth?download=true) | [Sentinel2_Resnet50_MI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_mi_ms.pth?download=true) | 66 | | **Resnet152** | [Sentinel2_Resnet152_SI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_si_ms.pth?download=true) | [Sentinel2_Resnet152_MI_MS](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_mi_ms.pth?download=true) | 67 | 68 | #### Sentinel-1 Pretrained Models 69 | | | Single-image, VH+VV | Multi-image, VH+VV | 70 | | ---------- | ------------ | ------------ | 71 | | **Swin-v2-Base** | [Sentinel1_SwinB_SI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_si.pth?download=true) | [Sentinel1_SwinB_MI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_mi.pth?download=true) | 72 | 73 | #### Landsat 8/9 Pretrained Models 74 | | | Single-image, all bands | Multi-image, all bands | 75 | | ---------- | ------------ | ------------ | 76 | | **Swin-v2-Base** | [Landsat_SwinB_SI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_si.pth?download=true) | [Landsat_SwinB_MI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_mi.pth?download=true) | 77 | 78 | #### Aerial (0.5-2m/px high-res imagery) Pretrained Models 79 | | | Single-image, RGB | Multi-image, RGB | 80 | | ---------- | ------------ | ------------ | 81 |
**Swin-v2-Base** | [Aerial_SwinB_SI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth?download=true) | [Aerial_SwinB_MI](https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_mi.pth?download=true) | 82 | 83 | 84 | *Single-image* models learn strong representations for individual satellite or aerial images, while *multi-image* models use multiple image captures of the same location for added robustness when making predictions about static objects. In *multi-image* models, feature maps from the backbone are passed through temporal max pooling, so the backbone itself is still applied on individual images, but is trained to provide strong representations after the temporal max pooling step. See [ModelArchitecture.md](https://github.com/allenai/satlas/blob/main/ModelArchitecture.md) for more details. 85 | 86 | Sentinel-2 *RGB* models input the B4, B3, and B2 bands only, while the multi-spectral (*MS*) models input 9 bands. The aerial (0.5-2m/px high-res imagery) models input *RGB* NAIP and other high-res images, and we have found them to be effective on aerial imagery from a variety of sources. Landsat models input B1-B11 (*all bands*). Sentinel-1 models input *VV and VH* bands. 87 | See [Normalization.md](https://github.com/allenai/satlas/blob/main/Normalization.md) for details on how pixel values should be normalized for input to the pre-trained models. 88 | 89 | Usage Examples 90 | -------------- 91 | First initialize a `Weights` instance: 92 | 93 | ```python 94 | import satlaspretrain_models 95 | import torch 96 | weights_manager = satlaspretrain_models.Weights() 97 | ``` 98 | 99 | Then choose a **model_identifier** from the tables above to specify the pretrained model you want to load. 100 | Below are examples showing how to load in a few of the available models. 101 | 102 | #### Pretrained single-image Sentinel-2 RGB model, backbone only: 103 | ```python 104 | model = weights_manager.get_pretrained_model(model_identifier="Sentinel2_SwinB_SI_RGB") 105 | 106 | # Expected input is a portion of a Sentinel-2 L1C TCI image. 107 | # The 0-255 pixel values should be divided by 255 so they are 0-1. 108 | # tensor = tci_image[None, :, :, :] / 255 109 | tensor = torch.zeros((1, 3, 512, 512), dtype=torch.float32) 110 | 111 | # Since we only loaded the backbone, it outputs feature maps from the Swin-v2-Base backbone. 112 | output = model(tensor) 113 | print([feature_map.shape for feature_map in output]) 114 | # [torch.Size([1, 128, 128, 128]), torch.Size([1, 256, 64, 64]), torch.Size([1, 512, 32, 32]), torch.Size([1, 1024, 16, 16])] 115 | ``` 116 | 117 | #### Pretrained single-image Sentinel-1 model, backbone+FPN 118 | ```python 119 | model = weights_manager.get_pretrained_model("Sentinel1_SwinB_SI", fpn=True) 120 | 121 | # Expected input is a portion of a Sentinel-1 vh+vv image (in that order). 122 | # The 16-bit pixel values should be divided by 255 and clipped to 0-1 (any pixel values greater than 255 become 1). 123 | # tensor = torch.clip(torch.stack([vh_image, vv_image], dim=0)[None, :, :, :] / 255, 0, 1) 124 | tensor = torch.zeros((1, 2, 512, 512), dtype=torch.float32) 125 | 126 | # The model outputs feature maps from the FPN. 
127 | output = model(tensor) 128 | print([feature_map.shape for feature_map in output]) 129 | # [torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 64, 64]), torch.Size([1, 128, 32, 32]), torch.Size([1, 128, 16, 16])] 130 | ``` 131 | 132 | #### Prediction heads 133 | 134 | Although the checkpoints include prediction head parameters, these heads are task-specific, so loading the head parameters is not supported in this repository. 135 | Computing outputs from the pre-trained prediction heads is supported [in the dataset codebase](https://github.com/allenai/satlas/blob/main/SatlasPretrain.md#visualizing-outputs-on-new-images). 136 | 137 | For convenience when fine-tuning on certain types of tasks, though, `satlaspretrain_models` supports attaching several heads (initialized randomly) to the pre-trained model: 138 | 139 | ```python 140 | # Backbone and FPN parameters initialized from checkpoint, head parameters initialized randomly. 141 | model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2) 142 | ``` 143 | 144 | The following head architectures are available: 145 | - *Segmentation*: U-Net Decoder w/ Cross Entropy loss 146 | - *Detection*: Faster R-CNN Decoder 147 | - *Instance Segmentation*: Mask R-CNN Decoder 148 | - *Regression*: U-Net Decoder w/ L1 loss 149 | - *Classification*: Pooling + Linear layers 150 | - *Multi-label Classification*: Pooling + Linear layers 151 | 152 | #### Pretrained multi-image aerial model, backbone + FPN + classification head: 153 | ```python 154 | # num_categories is the number of categories to predict. 155 | # All heads are randomly initialized and provided for convenience when fine-tuning. 156 | model = weights_manager.get_pretrained_model("Aerial_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2) 157 | 158 | # Expected input is 8-bit (0-255) aerial images at 0.5 - 2 m/pixel. 159 | # The 0-255 pixel values should be divided by 255 so they are 0-1. 160 | # This multi-image model is trained to input 4 images but should perform well with different numbers of images. 161 | # tensor = torch.stack([rgb_image1, rgb_image2], dim=0)[None, :, :, :, :] / 255 162 | tensor = torch.zeros((1, 4, 3, 512, 512), dtype=torch.float32) 163 | 164 | # The head needs to be fine-tuned on a downstream classification task. 165 | # It outputs classification probabilities. 166 | model.eval() 167 | output = model(tensor.reshape(1, 4*3, 512, 512)) 168 | print(output) 169 | # tensor([[0.0266, 0.9734]]) 170 | ``` 171 | 172 | #### Pretrained multi-image Landsat model, backbone + FPN + detection head 173 | ```python 174 | # num_categories is the number of bounding box detection categories. 175 | # All heads are randomly initialized and provided for convenience when fine-tuning. 176 | model = weights_manager.get_pretrained_model("Landsat_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.DETECT, num_categories=5) 177 | 178 | # Expected input is Landsat B1-B11 stacked in order. 179 | # This multi-image model is trained to input 8 images but should perform well with different numbers of images. 180 | # The 16-bit pixel values are normalized as follows: 181 | # landsat_images = torch.stack([landsat_image1, landsat_image2], dim=0) 182 | # tensor = torch.clip((landsat_images[None, :, :, :, :]-4000)/16320, 0, 1) 183 | tensor = torch.zeros((1, 8, 11, 512, 512), dtype=torch.float32) 184 | 185 | # The head needs to be fine-tuned on a downstream object detection task.
186 | # It outputs bounding box detections. 187 | model.eval() 188 | output = model(tensor.reshape(1, 8*11, 512, 512)) 189 | print(output) 190 | #[{'boxes': tensor([[ 67.0772, 239.2646, 95.6874, 16.3644], ...]), 191 | # 'labels': tensor([3, ...]), 192 | # 'scores': tensor([0.5443, ...])}] 193 | ``` 194 | 195 | Demos 196 | ----- 197 | We provide a [demo](https://github.com/allenai/satlaspretrain_models/blob/main/demo.ipynb) showing how to finetune a 198 | SatlasPretrain Sentinel-2 model on the EuroSAT classification task. 199 | 200 | We also provide a [torchgeo demo](https://github.com/allenai/satlaspretrain_models/blob/main/torchgeo_demo.ipynb), 201 | showing how to load SatlasPretrain weights into a model, download a dataset, initialize a trainer, and finetune the model on the UCMerced classification task. 202 | *Note*: a separate conda environment must be initialized to run this demo, see details in the notebook. 203 | 204 | 205 | Tests 206 | ----- 207 | There are tests to test loading pretrained models and one to test randomly initialized models. 208 | 209 | To run the tests, run the following command from the root directory: 210 | `pytest tests/` 211 | 212 | Contact 213 | ------- 214 | If you have any questions, please email `satlas@allenai.org` or open an issue [here](https://github.com/allenai/satlaspretrain_models/issues/new). 215 | -------------------------------------------------------------------------------- /satlasnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/satlaspretrain_models/7b5cd45adc3cad70b3834d65956974af6f6bffd0/satlasnet.png -------------------------------------------------------------------------------- /satlaspretrain_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Weights, Model 2 | from .utils import Head, Backbone 3 | -------------------------------------------------------------------------------- /satlaspretrain_models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import requests 4 | from io import BytesIO 5 | 6 | from satlaspretrain_models.utils import Backbone, Head, SatlasPretrain_weights, adjust_state_dict_prefix 7 | from satlaspretrain_models.models import * 8 | 9 | class Weights: 10 | def __init__(self): 11 | """ 12 | Class to manage downloading weights and formatting them to be loaded into SatlasPretrain models. 13 | """ 14 | super(Weights, self).__init__() 15 | 16 | def get_pretrained_model(self, model_identifier, fpn=False, head=None, num_categories=None, device='cuda'): 17 | """ 18 | Find and load pretrained SatlasPretrain weights, based on the model_identifier argument. 19 | Option to load pretrained FPN and/or a randomly initialized head. 20 | 21 | Args: 22 | model_identifier: 23 | fpn (bool): Whether or not to load a pretrained FPN along with the Backbone. 24 | head (enum): If specified, a randomly initialized Head will be created along with the 25 | Backbone and [optionally] the FPN. 26 | num_categories (int): Number of categories to be included in output from prediction head. 27 | """ 28 | # Validate that the model identifier is supported. 29 | if not model_identifier in SatlasPretrain_weights.keys(): 30 | raise ValueError("Invalid model_identifier. 
See utils.SatlasPretrain_weights.") 31 | 32 | if head and (num_categories is None): 33 | raise ValueError("Must specify num_categories if head is desired.") 34 | 35 | model_info = SatlasPretrain_weights[model_identifier] 36 | 37 | # Use the hardcoded Hugging Face URL to download weights. 38 | weights_url = model_info['url'] 39 | response = requests.get(weights_url) 40 | if response.status_code == 200: 41 | weights_file = BytesIO(response.content) 42 | else: 43 | raise Exception(f"Failed to download weights from {weights_url}") 44 | 45 | if device == 'cpu': 46 | weights = torch.load(weights_file, map_location=torch.device('cpu')) 47 | else: 48 | weights = torch.load(weights_file) 49 | 50 | # Initialize a pretrained model using the Model() class. 51 | model = Model(model_info['num_channels'], model_info['multi_image'], model_info['backbone'], fpn=fpn, head=head, 52 | num_categories=num_categories, weights=weights) 53 | return model 54 | 55 | 56 | class Model(torch.nn.Module): 57 | def __init__(self, num_channels=3, multi_image=False, backbone=Backbone.SWINB, fpn=False, head=None, num_categories=None, weights=None): 58 | """ 59 | Initializes a model, based on desired imagery source and model components. This class can be used directly to 60 | create a randomly initialized model (if weights=None) or can be called from the Weights class to initialize a 61 | SatlasPretrain pretrained foundation model. 62 | 63 | Args: 64 | num_channels (int): Number of input channels that the backbone model should expect. 65 | multi_image (bool): Whether the model should expect multi-image (True) or single-image (False) input. 66 | backbone (Backbone): The architecture of the pretrained backbone. All image sources support SwinTransformer. 67 | fpn (bool): Whether or not to feed imagery through the pretrained Feature Pyramid Network after the backbone. 68 | head (Head): If specified, a randomly initialized head will be included in the model. 69 | num_categories (int): If a Head is being returned as part of the model, must specify how many outputs are wanted. 70 | weights (torch weights): Weights to be loaded into the model. Defaults to None (random initialization) unless 71 | initialized using the Weights class. 72 | """ 73 | super(Model, self).__init__() 74 | 75 | # Validate user-provided arguments. 76 | if not isinstance(backbone, Backbone): 77 | raise ValueError("Invalid backbone.") 78 | if head and not isinstance(head, Head): 79 | raise ValueError("Invalid head.") 80 | if head and (num_categories is None): 81 | raise ValueError("Must specify num_categories if head is desired.") 82 | 83 | self.backbone = self._initialize_backbone(num_channels, backbone, multi_image, weights) 84 | 85 | if fpn: 86 | self.fpn = self._initialize_fpn(self.backbone.out_channels, weights) 87 | self.upsample = Upsample(self.fpn.out_channels) 88 | else: 89 | self.fpn = None 90 | 91 | if head: 92 | self.head = self._initialize_head(head, self.fpn.out_channels, num_categories) if fpn else self._initialize_head(head, self.backbone.out_channels, num_categories) 93 | else: 94 | self.head = None 95 | 96 | def _initialize_backbone(self, num_channels, backbone_arch, multi_image, weights): 97 | # Load backbone model according to specified architecture.
98 | if backbone_arch == Backbone.SWINB: 99 | backbone = SwinBackbone(num_channels, arch='swinb') 100 | elif backbone_arch == Backbone.SWINT: 101 | backbone = SwinBackbone(num_channels, arch='swint') 102 | elif backbone_arch == Backbone.RESNET50: 103 | backbone = ResnetBackbone(num_channels, arch='resnet50') 104 | elif backbone_arch == Backbone.RESNET152: 105 | backbone = ResnetBackbone(num_channels, arch='resnet152') 106 | else: 107 | raise ValueError("Unsupported backbone architecture.") 108 | 109 | # If using a multi-image model, the AggregationBackbone needs to wrap the underlying backbone model. 110 | prefix, prefix_allowed_count = None, None 111 | if backbone_arch in [Backbone.RESNET50, Backbone.RESNET152]: 112 | prefix_allowed_count = 0 113 | elif multi_image: 114 | backbone = AggregationBackbone(num_channels, backbone) 115 | prefix_allowed_count = 2 116 | else: 117 | prefix_allowed_count = 1 118 | 119 | # Load pretrained weights into the initialized backbone if weights were specified. 120 | if weights is not None: 121 | state_dict = adjust_state_dict_prefix(weights, 'backbone', 'backbone.', prefix_allowed_count) 122 | backbone.load_state_dict(state_dict) 123 | 124 | return backbone 125 | 126 | def _initialize_fpn(self, backbone_channels, weights): 127 | fpn = FPN(backbone_channels) 128 | 129 | # Load pretrained weights into the initialized FPN if weights were specified. 130 | if weights is not None: 131 | state_dict = adjust_state_dict_prefix(weights, 'fpn', 'intermediates.0.', 0) 132 | fpn.load_state_dict(state_dict) 133 | return fpn 134 | 135 | def _initialize_head(self, head, backbone_channels, num_categories): 136 | # Initialize the head (classification, detection, etc.) if specified 137 | if head == Head.CLASSIFY: 138 | return SimpleHead('classification', backbone_channels, num_categories) 139 | elif head == Head.MULTICLASSIFY: 140 | return SimpleHead('multi-label-classification', backbone_channels, num_categories) 141 | elif head == Head.SEGMENT: 142 | return SimpleHead('segment', backbone_channels, num_categories) 143 | elif head == Head.BINSEGMENT: 144 | return SimpleHead('bin_segment', backbone_channels, num_categories) 145 | elif head == Head.REGRESS: 146 | return SimpleHead('regress', backbone_channels, num_categories) 147 | elif head == Head.DETECT: 148 | return FRCNNHead('detect', backbone_channels, num_categories) 149 | elif head == Head.INSTANCE: 150 | return FRCNNHead('instance', backbone_channels, num_categories) 151 | return None 152 | 153 | def forward(self, imgs, targets=None): 154 | # Define forward pass 155 | x = self.backbone(imgs) 156 | if self.fpn: 157 | x = self.fpn(x) 158 | x = self.upsample(x) 159 | if self.head: 160 | x, loss = self.head(imgs, x, targets) 161 | return x, loss 162 | return x 163 | 164 | 165 | if __name__ == "__main__": 166 | weights_manager = Weights() 167 | 168 | # Test loading in all available pretrained backbone models, without FPN or Head. 169 | # Test feeding in a random tensor as input.
180 | for model_id in SatlasPretrain_weights.keys(): 181 | print("Attempting to load ...", model_id, " with pretrained FPN.") 182 | model_info = SatlasPretrain_weights[model_id] 183 | model = weights_manager.get_pretrained_model(model_id, fpn=True) 184 | rand_img = torch.rand((8, model_info['num_channels'], 128, 128)) 185 | output = model(rand_img) 186 | print("Successfully initialized the pretrained model with ID:", model_id, " with FPN.") 187 | 188 | # Test loading in all available pretrained backbones, with FPN and with every possible Head. 189 | # Test feeding in a random tensor as input. Randomly generated targets are fed into detection/instance heads. 190 | for model_id in SatlasPretrain_weights.keys(): 191 | model_info = SatlasPretrain_weights[model_id] 192 | for head in Head: 193 | print("Attempting to load ...", model_id, " with pretrained FPN and randomly initialized ", head, " Head.") 194 | model = weights_manager.get_pretrained_model(model_id, fpn=True, head=head, num_categories=2) 195 | rand_img = torch.rand((1, model_info['num_channels'], 128, 128)) 196 | 197 | rand_targets = None 198 | if head == Head.DETECT: 199 | rand_targets = [{ 200 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 201 | 'labels': torch.tensor([0,1], dtype=torch.int64) 202 | }] 203 | elif head == Head.INSTANCE: 204 | rand_targets = [{ 205 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 206 | 'labels': torch.tensor([0,1], dtype=torch.int64), 207 | 'masks': torch.zeros_like(rand_img) 208 | }] 209 | elif head in [Head.SEGMENT, Head.BINSEGMENT, Head.REGRESS]: 210 | rand_targets = torch.zeros_like((rand_img)) 211 | 212 | output, loss = model(rand_img, rand_targets) 213 | print("Successfully initialized the pretrained model with ID:", model_id, " with FPN and randomly initialized ", head, " Head.") 214 | -------------------------------------------------------------------------------- /satlaspretrain_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import SwinBackbone, ResnetBackbone, AggregationBackbone 2 | from .fpn import FPN, Upsample 3 | from .heads import SimpleHead, FRCNNHead 4 | -------------------------------------------------------------------------------- /satlaspretrain_models/models/backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | import torchvision 3 | 4 | class SwinBackbone(torch.nn.Module): 5 | def __init__(self, num_channels, arch): 6 | super(SwinBackbone, self).__init__() 7 | 8 | if arch == 'swinb': 9 | self.backbone = torchvision.models.swin_v2_b() 10 | self.out_channels = [ 11 | [4, 128], 12 | [8, 256], 13 | [16, 512], 14 | [32, 1024], 15 | ] 16 | elif arch == 'swint': 17 | self.backbone = torchvision.models.swin_v2_t() 18 | self.out_channels = [ 19 | [4, 96], 20 | [8, 192], 21 | [16, 384], 22 | [32, 768], 23 | ] 24 | else: 25 | raise ValueError("Backbone architecture not supported.") 26 | 27 | self.backbone.features[0][0] = torch.nn.Conv2d(num_channels, self.backbone.features[0][0].out_channels, kernel_size=(4, 4), stride=(4, 4)) 28 | 29 | def forward(self, x): 30 | outputs = [] 31 | for layer in self.backbone.features: 32 | x = layer(x) 33 | outputs.append(x.permute(0, 3, 1, 2)) 34 | return [outputs[-7], outputs[-5], outputs[-3], outputs[-1]] 35 | 36 | 37 | class ResnetBackbone(torch.nn.Module): 38 | def __init__(self, num_channels, arch='resnet50'): 39 | 
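# ResNet backbone wrapper: conv1 is replaced below so the network accepts num_channels input bands, and forward() returns the feature maps from the four residual stages (strides 4, 8, 16, 32).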
super(ResnetBackbone, self).__init__() 40 | 41 | if arch == 'resnet50': 42 | self.resnet = torchvision.models.resnet.resnet50(weights=None) 43 | ch = [256, 512, 1024, 2048] 44 | elif arch == 'resnet152': 45 | self.resnet = torchvision.models.resnet.resnet152(weights=None) 46 | ch = [256, 512, 1024, 2048] 47 | else: 48 | raise ValueError("Backbone architecture not supported.") 49 | 50 | self.resnet.conv1 = torch.nn.Conv2d(num_channels, self.resnet.conv1.out_channels, kernel_size=7, stride=2, padding=3, bias=False) 51 | self.out_channels = [ 52 | [4, ch[0]], 53 | [8, ch[1]], 54 | [16, ch[2]], 55 | [32, ch[3]], 56 | ] 57 | 58 | def train(self, mode=True): 59 | super(ResnetBackbone, self).train(mode) 60 | 61 | def forward(self, x): 62 | x = self.resnet.conv1(x) 63 | x = self.resnet.bn1(x) 64 | x = self.resnet.relu(x) 65 | x = self.resnet.maxpool(x) 66 | 67 | layer1 = self.resnet.layer1(x) 68 | layer2 = self.resnet.layer2(layer1) 69 | layer3 = self.resnet.layer3(layer2) 70 | layer4 = self.resnet.layer4(layer3) 71 | 72 | return [layer1, layer2, layer3, layer4] 73 | 74 | 75 | class AggregationBackbone(torch.nn.Module): 76 | def __init__(self, num_channels, backbone): 77 | super(AggregationBackbone, self).__init__() 78 | 79 | # Number of channels to pass to underlying backbone. 80 | self.image_channels = num_channels 81 | 82 | # Prepare underlying backbone. 83 | self.backbone = backbone 84 | 85 | # Features from images within each group are aggregated separately. 86 | # Then the output is the concatenation across groups. 87 | # e.g. [[0], [1, 2]] to compare first image against the others 88 | self.groups = [[0, 1, 2, 3, 4, 5, 6, 7]] 89 | 90 | ngroups = len(self.groups) 91 | self.out_channels = [(depth, ngroups*count) for (depth, count) in self.backbone.out_channels] 92 | 93 | self.aggregation_op = 'max' 94 | 95 | def forward(self, x): 96 | # First get features of each image. 97 | all_features = [] 98 | for i in range(0, x.shape[1], self.image_channels): 99 | features = self.backbone(x[:, i:i+self.image_channels, :, :]) 100 | all_features.append(features) 101 | 102 | # Now compute aggregation over each group. 103 | # We handle each depth separately. 104 | l = [] 105 | for feature_idx in range(len(all_features[0])): 106 | aggregated_features = [] 107 | for group in self.groups: 108 | group_features = [] 109 | for image_idx in group: 110 | # We may input fewer than the maximum number of images. 111 | # So here we skip image indices in the group that aren't available. 112 | if image_idx >= len(all_features): 113 | continue 114 | 115 | group_features.append(all_features[image_idx][feature_idx]) 116 | # Resulting group features are (depth, batch, C, height, width). 117 | group_features = torch.stack(group_features, dim=0) 118 | 119 | if self.aggregation_op == 'max': 120 | group_features = torch.amax(group_features, dim=0) 121 | 122 | aggregated_features.append(group_features) 123 | 124 | # Finally we concatenate across groups. 
125 | aggregated_features = torch.cat(aggregated_features, dim=1) 126 | 127 | l.append(aggregated_features) 128 | 129 | return l 130 | -------------------------------------------------------------------------------- /satlaspretrain_models/models/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | import torchvision 4 | 5 | 6 | class FPN(torch.nn.Module): 7 | def __init__(self, backbone_channels): 8 | super(FPN, self).__init__() 9 | 10 | out_channels = 128 11 | in_channels_list = [ch[1] for ch in backbone_channels] 12 | self.fpn = torchvision.ops.FeaturePyramidNetwork(in_channels_list=in_channels_list, out_channels=out_channels) 13 | 14 | self.out_channels = [[ch[0], out_channels] for ch in backbone_channels] 15 | 16 | def forward(self, x): 17 | inp = collections.OrderedDict([('feat{}'.format(i), el) for i, el in enumerate(x)]) 18 | output = self.fpn(inp) 19 | output = list(output.values()) 20 | 21 | return output 22 | 23 | 24 | class Upsample(torch.nn.Module): 25 | # Computes an output feature map at 1x the input resolution. 26 | # It just applies a series of transpose convolution layers on the 27 | # highest resolution features from the backbone (FPN should be applied first). 28 | 29 | def __init__(self, backbone_channels): 30 | super(Upsample, self).__init__() 31 | self.in_channels = backbone_channels 32 | 33 | out_channels = backbone_channels[0][1] 34 | self.out_channels = [(1, out_channels)] + backbone_channels 35 | 36 | layers = [] 37 | depth, ch = backbone_channels[0] 38 | while depth > 1: 39 | next_ch = max(ch//2, out_channels) 40 | layer = torch.nn.Sequential( 41 | torch.nn.Conv2d(ch, ch, 3, padding=1), 42 | torch.nn.ReLU(inplace=True), 43 | torch.nn.ConvTranspose2d(ch, next_ch, 4, stride=2, padding=1), 44 | torch.nn.ReLU(inplace=True), 45 | ) 46 | layers.append(layer) 47 | ch = next_ch 48 | depth /= 2 49 | 50 | self.layers = torch.nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | output = self.layers(x[0]) 54 | return [output] + x 55 | -------------------------------------------------------------------------------- /satlaspretrain_models/models/heads.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision 7 | 8 | 9 | class NoopTransform(torch.nn.Module): 10 | def __init__(self): 11 | super(NoopTransform, self).__init__() 12 | 13 | self.transform = torchvision.models.detection.transform.GeneralizedRCNNTransform( 14 | min_size=800, 15 | max_size=800, 16 | image_mean=[], 17 | image_std=[], 18 | ) 19 | 20 | def forward(self, images, targets): 21 | images = self.transform.batch_images(images, size_divisible=32) 22 | image_sizes = [(image.shape[1], image.shape[2]) for image in images] 23 | image_list = torchvision.models.detection.image_list.ImageList(images, image_sizes) 24 | return image_list, targets 25 | 26 | def postprocess(self, detections, image_sizes, orig_sizes): 27 | return detections 28 | 29 | 30 | class FRCNNHead(torch.nn.Module): 31 | def __init__(self, task, backbone_channels, num_categories=2): 32 | super(FRCNNHead, self).__init__() 33 | 34 | self.task_type = task 35 | self.use_layers = list(range(len(backbone_channels))) 36 | num_channels = backbone_channels[self.use_layers[0]][1] 37 | featmap_names = ['feat{}'.format(i) for i in range(len(self.use_layers))] 38 | num_classes = num_categories 39 | 40 | 
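# NoopTransform (defined above) only batches and pads the images; it skips the resizing, normalization, and box rescaling that torchvision's default detection transform would apply.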
self.noop_transform = NoopTransform() 41 | 42 | # RPN 43 | anchor_sizes = [[32], [64], [128], [256]] 44 | aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) 45 | rpn_anchor_generator = torchvision.models.detection.anchor_utils.AnchorGenerator(anchor_sizes, aspect_ratios) 46 | rpn_head = torchvision.models.detection.rpn.RPNHead(num_channels, rpn_anchor_generator.num_anchors_per_location()[0]) 47 | rpn_fg_iou_thresh = 0.7 48 | rpn_bg_iou_thresh = 0.3 49 | rpn_batch_size_per_image = 256 50 | rpn_positive_fraction = 0.5 51 | rpn_pre_nms_top_n = dict(training=2000, testing=2000) 52 | rpn_post_nms_top_n = dict(training=2000, testing=2000) 53 | rpn_nms_thresh = 0.7 54 | self.rpn = torchvision.models.detection.rpn.RegionProposalNetwork( 55 | rpn_anchor_generator, 56 | rpn_head, 57 | rpn_fg_iou_thresh, 58 | rpn_bg_iou_thresh, 59 | rpn_batch_size_per_image, 60 | rpn_positive_fraction, 61 | rpn_pre_nms_top_n, 62 | rpn_post_nms_top_n, 63 | rpn_nms_thresh, 64 | ) 65 | 66 | # ROI 67 | box_roi_pool = torchvision.ops.MultiScaleRoIAlign(featmap_names=featmap_names, output_size=7, sampling_ratio=2) 68 | box_head = torchvision.models.detection.faster_rcnn.TwoMLPHead(backbone_channels[0][1] * box_roi_pool.output_size[0] ** 2, 1024) 69 | box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(1024, num_classes) 70 | box_fg_iou_thresh = 0.5 71 | box_bg_iou_thresh = 0.5 72 | box_batch_size_per_image = 512 73 | box_positive_fraction = 0.25 74 | bbox_reg_weights = None 75 | box_score_thresh = 0.05 76 | box_nms_thresh = 0.5 77 | box_detections_per_img = 100 78 | self.roi_heads = torchvision.models.detection.roi_heads.RoIHeads( 79 | box_roi_pool, 80 | box_head, 81 | box_predictor, 82 | box_fg_iou_thresh, 83 | box_bg_iou_thresh, 84 | box_batch_size_per_image, 85 | box_positive_fraction, 86 | bbox_reg_weights, 87 | box_score_thresh, 88 | box_nms_thresh, 89 | box_detections_per_img, 90 | ) 91 | 92 | if self.task_type == 'instance': 93 | # Use Mask R-CNN stuff. 
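# The instance-segmentation variant reuses the Faster R-CNN RPN and box components configured above, and adds the mask RoI pooler, mask head, and mask predictor below.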
94 | self.roi_heads.mask_roi_pool = torchvision.ops.MultiScaleRoIAlign(featmap_names=featmap_names, output_size=14, sampling_ratio=2) 95 | 96 | mask_layers = (256, 256, 256, 256) 97 | mask_dilation = 1 98 | self.roi_heads.mask_head = torchvision.models.detection.mask_rcnn.MaskRCNNHeads(backbone_channels[0][1], mask_layers, mask_dilation) 99 | 100 | mask_predictor_in_channels = 256 101 | mask_dim_reduced = 256 102 | self.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes) 103 | 104 | def forward(self, image_list, raw_features, targets=None): 105 | device = image_list[0].device 106 | images, targets = self.noop_transform(image_list, targets) 107 | 108 | features = collections.OrderedDict() 109 | for i, idx in enumerate(self.use_layers): 110 | features['feat{}'.format(i)] = raw_features[idx] 111 | 112 | proposals, proposal_losses = self.rpn(images, features, targets) 113 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 114 | 115 | losses = {'base': torch.tensor(0, device=device, dtype=torch.float32)} 116 | losses.update(proposal_losses) 117 | losses.update(detector_losses) 118 | 119 | loss = sum(x for x in losses.values()) 120 | return detections, loss 121 | 122 | 123 | class SimpleHead(torch.nn.Module): 124 | def __init__(self, task, backbone_channels, num_categories=2): 125 | super(SimpleHead, self).__init__() 126 | 127 | self.task_type = task 128 | 129 | use_channels = backbone_channels[0][1] 130 | num_layers = 2 131 | self.num_outputs = num_categories 132 | if self.num_outputs is None: 133 | if self.task_type == 'regress': 134 | self.num_outputs = 1 135 | else: 136 | self.num_outputs = 2 137 | 138 | layers = [] 139 | for _ in range(num_layers-1): 140 | layer = torch.nn.Sequential( 141 | torch.nn.Conv2d(use_channels, use_channels, 3, padding=1), 142 | torch.nn.ReLU(inplace=True), 143 | ) 144 | layers.append(layer) 145 | 146 | if self.task_type == 'segment': 147 | layers.append(torch.nn.Conv2d(use_channels, self.num_outputs, 3, padding=1)) 148 | self.loss_func = lambda logits, targets: torch.nn.functional.cross_entropy(logits, targets, reduction='none') 149 | 150 | elif self.task_type == 'bin_segment': 151 | layers.append(torch.nn.Conv2d(use_channels, self.num_outputs, 3, padding=1)) 152 | def loss_func(logits, targets): 153 | targets = targets.argmax(dim=1) 154 | return torch.nn.functional.cross_entropy(logits, targets, reduction='none')[:, None, :, :] 155 | self.loss_func = loss_func 156 | 157 | elif self.task_type == 'regress': 158 | layers.append(torch.nn.Conv2d(use_channels, self.num_outputs, 3, padding=1)) 159 | self.loss_func = lambda outputs, targets: torch.square(outputs - targets) 160 | 161 | elif self.task_type == 'classification': 162 | self.extra = torch.nn.Linear(use_channels, self.num_outputs) 163 | self.loss_func = lambda logits, targets: torch.nn.functional.cross_entropy(logits, targets, reduction='none') 164 | 165 | elif self.task_type == 'multi-label-classification': 166 | self.extra = torch.nn.Linear(use_channels, self.num_outputs) 167 | self.loss_func = lambda logits, targets: torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none') 168 | 169 | self.layers = torch.nn.Sequential(*layers) 170 | 171 | def forward(self, image_list, raw_features, targets=None): 172 | raw_outputs = self.layers(raw_features[0]) 173 | loss = None 174 | 175 | if self.task_type == 'segment': 176 | outputs =
torch.nn.functional.softmax(raw_outputs, dim=1) 177 | 178 | if targets is not None: 179 | task_targets = torch.stack([target for target in targets], dim=0).long() 180 | loss = self.loss_func(raw_outputs, task_targets) 181 | loss = loss.mean() 182 | 183 | elif self.task_type == 'bin_segment': 184 | outputs = torch.nn.functional.softmax(raw_outputs, dim=1) 185 | 186 | if targets is not None: 187 | task_targets = torch.stack([target for target in targets], dim=0).long() 188 | loss = self.loss_func(raw_outputs, task_targets) 189 | loss = loss.mean() 190 | 191 | elif self.task_type == 'regress': 192 | raw_outputs = raw_outputs[:, 0, :, :] 193 | outputs = 255*raw_outputs 194 | 195 | if targets is not None: 196 | task_targets = torch.stack([target for target in targets], dim=0).long() 197 | loss = self.loss_func(raw_outputs, task_targets.float()/255) 198 | loss = loss.mean() 199 | 200 | elif self.task_type == 'classification': 201 | features = torch.amax(raw_outputs, dim=(2,3)) 202 | logits = self.extra(features) 203 | outputs = torch.nn.functional.softmax(logits, dim=1) 204 | 205 | if targets is not None: 206 | task_targets = torch.stack([target for target in targets], dim=0).long() 207 | loss = self.loss_func(logits, task_targets) 208 | loss = loss.mean() 209 | 210 | elif self.task_type == 'multi-label-classification': 211 | features = torch.amax(raw_outputs, dim=(2,3)) 212 | logits = self.extra(features) 213 | outputs = torch.sigmoid(logits) 214 | 215 | if targets is not None: 216 | task_targets = torch.stack([target for target in targets], dim=0).long() 217 | loss = self.loss_func(logits, task_targets) 218 | loss = loss.mean() 219 | 220 | return outputs, loss 221 | 222 | -------------------------------------------------------------------------------- /satlaspretrain_models/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | class Backbone(Enum): 4 | SWINB = auto() 5 | SWINT = auto() 6 | RESNET50 = auto() 7 | RESNET152 = auto() 8 | 9 | class Head(Enum): 10 | CLASSIFY = auto() 11 | MULTICLASSIFY = auto() 12 | DETECT = auto() 13 | INSTANCE = auto() 14 | SEGMENT = auto() 15 | BINSEGMENT = auto() 16 | REGRESS = auto() 17 | 18 | # Dictionary of arguments needed to load in each SatlasPretrain pretrained model. 
19 | SatlasPretrain_weights = { 20 | 'Sentinel2_SwinB_SI_RGB': { 21 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_rgb.pth?download=true', 22 | 'backbone': Backbone.SWINB, 23 | 'num_channels': 3, 24 | 'multi_image': False 25 | }, 26 | 'Sentinel2_SwinB_MI_RGB': { 27 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_mi_rgb.pth?download=true', 28 | 'backbone': Backbone.SWINB, 29 | 'num_channels': 3, 30 | 'multi_image': True 31 | }, 32 | 'Sentinel2_SwinB_SI_MS': { 33 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_ms.pth?download=true', 34 | 'backbone': Backbone.SWINB, 35 | 'num_channels': 9, 36 | 'multi_image': False 37 | }, 38 | 'Sentinel2_SwinB_MI_MS': { 39 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_mi_ms.pth?download=true', 40 | 'backbone': Backbone.SWINB, 41 | 'num_channels': 9, 42 | 'multi_image': True 43 | }, 44 | 'Sentinel1_SwinB_SI': { 45 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_si.pth?download=true', 46 | 'backbone': Backbone.SWINB, 47 | 'num_channels': 2, 48 | 'multi_image': False 49 | }, 50 | 'Sentinel1_SwinB_MI': { 51 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_mi.pth?download=true', 52 | 'backbone': Backbone.SWINB, 53 | 'num_channels': 2, 54 | 'multi_image': True 55 | }, 56 | 'Landsat_SwinB_SI': { 57 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_si.pth?download=true', 58 | 'backbone': Backbone.SWINB, 59 | 'num_channels': 11, 60 | 'multi_image': False 61 | }, 62 | 'Landsat_SwinB_MI': { 63 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_mi.pth?download=true', 64 | 'backbone': Backbone.SWINB, 65 | 'num_channels': 11, 66 | 'multi_image': True 67 | }, 68 | 'Aerial_SwinB_SI': { 69 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth?download=true', 70 | 'backbone': Backbone.SWINB, 71 | 'num_channels': 3, 72 | 'multi_image': False 73 | }, 74 | 'Aerial_SwinB_MI': { 75 | 'url':'https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_mi.pth?download=true', 76 | 'backbone': Backbone.SWINB, 77 | 'num_channels': 3, 78 | 'multi_image': True 79 | }, 80 | 'Sentinel2_SwinT_SI_RGB': { 81 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_si_rgb.pth?download=true', 82 | 'backbone': Backbone.SWINT, 83 | 'num_channels': 3, 84 | 'multi_image': False 85 | }, 86 | 'Sentinel2_SwinT_SI_MS': { 87 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_si_ms.pth?download=true', 88 | 'backbone': Backbone.SWINT, 89 | 'num_channels': 9, 90 | 'multi_image': False 91 | }, 92 | 'Sentinel2_SwinT_MI_RGB': { 93 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_mi_rgb.pth?download=true', 94 | 'backbone': Backbone.SWINT, 95 | 'num_channels': 3, 96 | 'multi_image': True 97 | }, 98 | 'Sentinel2_SwinT_MI_MS': { 99 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swint_mi_ms.pth?download=true', 100 | 'backbone': Backbone.SWINT, 101 | 'num_channels': 9, 102 | 'multi_image': True 103 | }, 104 | 'Sentinel2_Resnet50_SI_RGB': { 105 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_si_rgb.pth?download=true', 106 | 'backbone': Backbone.RESNET50, 107 | 'num_channels': 3, 108 | 'multi_image': 
False 109 | }, 110 | 'Sentinel2_Resnet50_SI_MS': { 111 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_si_ms.pth?download=true', 112 | 'backbone': Backbone.RESNET50, 113 | 'num_channels': 9, 114 | 'multi_image': False 115 | }, 116 | 'Sentinel2_Resnet50_MI_RGB': { 117 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_mi_rgb.pth?download=true', 118 | 'backbone': Backbone.RESNET50, 119 | 'num_channels': 3, 120 | 'multi_image': True 121 | }, 122 | 'Sentinel2_Resnet50_MI_MS': { 123 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet50_mi_ms.pth?download=true', 124 | 'backbone': Backbone.RESNET50, 125 | 'num_channels': 9, 126 | 'multi_image': True 127 | }, 128 | 'Sentinel2_Resnet152_SI_RGB': { 129 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_si_rgb.pth?download=true', 130 | 'backbone': Backbone.RESNET152, 131 | 'num_channels': 3, 132 | 'multi_image': False 133 | }, 134 | 'Sentinel2_Resnet152_SI_MS': { 135 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_si_ms.pth?download=true', 136 | 'backbone': Backbone.RESNET152, 137 | 'num_channels': 9, 138 | 'multi_image': False 139 | }, 140 | 'Sentinel2_Resnet152_MI_RGB': { 141 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_mi_rgb.pth?download=true', 142 | 'backbone': Backbone.RESNET152, 143 | 'num_channels': 3, 144 | 'multi_image': True 145 | }, 146 | 'Sentinel2_Resnet152_MI_MS': { 147 | 'url': 'https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_resnet152_mi_ms.pth?download=true', 148 | 'backbone': Backbone.RESNET152, 149 | 'num_channels': 9, 150 | 'multi_image': True 151 | }, 152 | } 153 | 154 | 155 | def adjust_state_dict_prefix(state_dict, needed, prefix=None, prefix_allowed_count=None): 156 | """ 157 | Filters the state dictionary down to the keys needed for one model component and trims extra occurrences of a prefix so the keys match what the model expects. 158 | 159 | Args: 160 | state_dict (dict): original state dictionary; needed (str): substring a key must contain to be kept; prefix (str) / prefix_allowed_count (int): a kept key may contain `prefix` at most `prefix_allowed_count` times, extra leading occurrences are stripped. 161 | 162 | Returns: 163 | dict: Modified state dictionary with filtered keys and corrected prefixes. 164 | """ 165 | new_state_dict = {} 166 | for key, value in state_dict.items(): 167 | # Ensure we're only keeping keys that we need for the current model component. 168 | if needed not in key: 169 | continue 170 | 171 | # Update the key prefixes to match what the model expects. 
172 | if prefix is not None: 173 | while key.count(prefix) > prefix_allowed_count: 174 | key = key.replace(prefix, '', 1) 175 | 176 | new_state_dict[key] = value 177 | return new_state_dict 178 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="satlaspretrain_models", 5 | version="0.3.1", 6 | author="Satlas @ AI2", 7 | author_email="satlas@allenai.org", 8 | description="A simple package that makes it easy to load remote sensing foundation models for downstream use cases.", 9 | long_description=open('README.md').read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/allenai/satlaspretrain_models", 12 | packages=find_packages(), 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: Apache Software License", 16 | "Operating System :: OS Independent", 17 | ], 18 | python_requires='>=3.9', 19 | install_requires=[ 20 | 'torch>=2.1.0', 21 | 'torchvision>=0.16.0', 22 | 'requests', 23 | 'matplotlib' 24 | ], 25 | ) 26 | 27 | -------------------------------------------------------------------------------- /tests/test_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from satlaspretrain_models.model import Weights 5 | from satlaspretrain_models.utils import SatlasPretrain_weights, Head 6 | 7 | # Fixture for weights manager 8 | @pytest.fixture(scope="module") 9 | def weights_manager(): 10 | return Weights() 11 | 12 | # Test loading pretrained backbone models without FPN or Head 13 | @pytest.mark.parametrize("model_id", SatlasPretrain_weights.keys()) 14 | def test_pretrained_backbone(weights_manager, model_id): 15 | model_info = SatlasPretrain_weights[model_id] 16 | model = weights_manager.get_pretrained_model(model_id) 17 | rand_img = torch.rand((8, model_info['num_channels'], 128, 128)).float() 18 | output = model(rand_img) 19 | assert output is not None 20 | 21 | # Test loading pretrained backbone models with FPN, without Head 22 | @pytest.mark.parametrize("model_id", SatlasPretrain_weights.keys()) 23 | def test_pretrained_backbone_with_fpn(weights_manager, model_id): 24 | model_info = SatlasPretrain_weights[model_id] 25 | model = weights_manager.get_pretrained_model(model_id, fpn=True) 26 | rand_img = torch.rand((8, model_info['num_channels'], 128, 128)).float() 27 | output = model(rand_img) 28 | assert output is not None 29 | 30 | # Test loading pretrained backbones with FPN and every possible Head 31 | @pytest.mark.parametrize("model_id,head", [(model_id, head) for model_id in SatlasPretrain_weights.keys() for head in Head]) 32 | def test_pretrained_backbone_with_fpn_and_head(weights_manager, model_id, head): 33 | model_info = SatlasPretrain_weights[model_id] 34 | model = weights_manager.get_pretrained_model(model_id, fpn=True, head=head, num_categories=2) 35 | rand_img = torch.rand((1, model_info['num_channels'], 128, 128)).float() 36 | 37 | rand_targets = None 38 | if head == Head.DETECT: 39 | rand_targets = [{ 40 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 41 | 'labels': torch.tensor([0,1], dtype=torch.int64) 42 | }] 43 | elif head == Head.INSTANCE: 44 | rand_targets = [{ 45 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 46 | 'labels': 
torch.tensor([0,1], dtype=torch.int64), 47 | 'masks': torch.zeros_like(rand_img) 48 | }] 49 | elif head == Head.BINSEGMENT: 50 | rand_targets = torch.zeros((1, 2, 32, 32)) 51 | elif head == Head.REGRESS: 52 | rand_targets = torch.zeros((1, 2, 32, 32)).float() 53 | elif head == Head.CLASSIFY: 54 | rand_targets = torch.tensor([1]) 55 | 56 | # TODO: add rand_targets for SEGMENT and MULTICLASSIFY 57 | 58 | if rand_targets is not None: 59 | output, loss = model(rand_img, rand_targets) 60 | assert output is not None 61 | assert loss is not None 62 | else: 63 | output = model(rand_img) 64 | assert output is not None 65 | -------------------------------------------------------------------------------- /tests/test_randomly_initialized_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from satlaspretrain_models.model import Model 5 | from satlaspretrain_models.utils import Backbone, Head 6 | 7 | # Test loading randomly initialized backbone models without FPN or Head 8 | @pytest.mark.parametrize("backbone", [backbone for backbone in Backbone]) 9 | def test_random_backbone(backbone): 10 | model = Model(num_channels=3, multi_image=False, backbone=backbone, fpn=False, head=None, num_categories=None, weights=None) 11 | rand_img = torch.rand((8, 3, 128, 128)).float() 12 | output = model(rand_img) 13 | assert output is not None 14 | 15 | # Test loading randomly initialized backbone models with FPN, without Head 16 | @pytest.mark.parametrize("backbone", [backbone for backbone in Backbone]) 17 | def test_random_backbone_with_fpn(backbone): 18 | model = Model(num_channels=3, multi_image=False, backbone=backbone, fpn=True, head=None, num_categories=None, weights=None) 19 | rand_img = torch.rand((8, 3, 128, 128)).float() 20 | output = model(rand_img) 21 | assert output is not None 22 | 23 | # Test loading pretrained backbones with FPN and every possible Head 24 | @pytest.mark.parametrize("backbone,head", [(backbone, head) for backbone in Backbone for head in Head]) 25 | def test_random_backbone_with_fpn_and_head(backbone, head): 26 | model = Model(num_channels=3, multi_image=False, backbone=backbone, fpn=True, head=head, num_categories=2, weights=None) 27 | rand_img = torch.rand((1, 3, 128, 128)).float() 28 | 29 | rand_targets = None 30 | if head == Head.DETECT: 31 | rand_targets = [{ 32 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 33 | 'labels': torch.tensor([0,1], dtype=torch.int64) 34 | }] 35 | elif head == Head.INSTANCE: 36 | rand_targets = [{ 37 | 'boxes': torch.tensor([[100, 100, 110, 110], [30, 30, 40, 40]], dtype=torch.float32), 38 | 'labels': torch.tensor([0,1], dtype=torch.int64), 39 | 'masks': torch.zeros_like(rand_img) 40 | }] 41 | elif head == Head.BINSEGMENT: 42 | rand_targets = torch.zeros((1, 2, 32, 32)) 43 | elif head == Head.REGRESS: 44 | rand_targets = torch.zeros((1, 2, 32, 32)).float() 45 | elif head == Head.CLASSIFY: 46 | rand_targets = torch.tensor([1]) 47 | 48 | # TODO: add rand_targets for SEGMENT and MULTICLASSIFY 49 | 50 | if rand_targets is not None: 51 | output, loss = model(rand_img, rand_targets) 52 | assert output is not None 53 | assert loss is not None 54 | else: 55 | output = model(rand_img) 56 | assert output is not None 57 | -------------------------------------------------------------------------------- /torchgeo_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"code", 5 | "execution_count": 1, 6 | "id": "2d0f533d-16af-4b24-91f0-31b338c91fcf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Demo using the torchgeo package to initialize a SatlasPretrain model and finetune\n", 11 | "# on the UCMerced dataset.\n", 12 | "#\n", 13 | "# SETUP - this demo requires a DIFFERENT conda environment than the SatlasPretrain demo\n", 14 | "# conda create --name torchgeodemo python=3.12\n", 15 | "# conda activate torchgeodemo\n", 16 | "# NOTE: Satlas weights will be a part of the 0.6.0 release and the current version is 0.5.1, so install from git for now.\n", 17 | "# pip install git+https://github.com/microsoft/torchgeo " 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "291fd9c9-66cd-4643-a387-b6a8a7ec7754", 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 31 | " from .autonotebook import tqdm as notebook_tqdm\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "import os\n", 37 | "import torch\n", 38 | "import tempfile\n", 39 | "from typing import Optional\n", 40 | "from lightning.pytorch import Trainer\n", 41 | "\n", 42 | "from torchgeo.models import Swin_V2_B_Weights, swin_v2_b\n", 43 | "from torchgeo.datamodules import UCMercedDataModule\n", 44 | "from torchgeo.trainers import ClassificationTask" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "3ed5b26a-e461-4258-936a-d26c79d1e4f9", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Experiment Arguments\n", 55 | "batch_size = 8\n", 56 | "num_workers = 2\n", 57 | "max_epochs = 10\n", 58 | "fast_dev_run = False" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "id": "4bb46a02-8fa4-4e29-8f0d-1c0ce3de10f4", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Torchgeo lightning datamodule to initialize dataset\n", 69 | "root = os.path.join(tempfile.gettempdir(), \"ucm\")\n", 70 | "datamodule = UCMercedDataModule(\n", 71 | " root=root, batch_size=batch_size, num_workers=num_workers, download=True\n", 72 | ")" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 11, 78 | "id": "aac6a7a3-684a-4f8a-aece-6d4eaa55772e", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# Custom ClassificationTask to load in the SatlasPretrain model\n", 83 | "class SatlasClassificationTask(ClassificationTask):\n", 84 | " def configure_models(self):\n", 85 | " weights = Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS\n", 86 | " self.model = swin_v2_b(weights)\n", 87 | "\n", 88 | " # Replace first layer's input channels with the task's number input channels.\n", 89 | " first_layer = self.model.features[0][0]\n", 90 | " self.model.features[0][0] = torch.nn.Conv2d(3,\n", 91 | " first_layer.out_channels,\n", 92 | " kernel_size=first_layer.kernel_size,\n", 93 | " stride=first_layer.stride,\n", 94 | " padding=first_layer.padding,\n", 95 | " bias=(first_layer.bias is not None))\n", 96 | "\n", 97 | " # Replace last layer's output features with the number classes.\n", 98 | " self.model.head = torch.nn.Linear(in_features=1024, out_features=self.hparams[\"num_classes\"], bias=True)\n", 99 | "\n", 100 | " def on_validation_epoch_end(self):\n", 101 | " # Accessing metrics logged 
during the current validation epoch\n", 102 | " val_loss = self.trainer.callback_metrics.get('val_loss', 'N/A')\n", 103 | " val_acc = self.trainer.callback_metrics.get('val_OverallAccuracy', 'N/A')\n", 104 | " print(f\"Epoch {self.current_epoch} Validation - Loss: {val_loss}, Accuracy: {val_acc}\")\n", 105 | "\n", 106 | " def on_validation_epoch_end(self):\n", 107 | " # Accessing metrics logged during the current validation epoch\n", 108 | " val_loss = self.trainer.callback_metrics.get('val_loss', 'N/A')\n", 109 | " val_acc = self.trainer.callback_metrics.get('val_OverallAccuracy', 'N/A')\n", 110 | " print(f\"Epoch {self.current_epoch} Validation - Loss: {val_loss}, Accuracy: {val_acc}\")\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 12, 116 | "id": "b4c09027-6c47-42f6-8d06-e3fbeb9d8cab", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# Initialize the Classifcation Task\n", 121 | "task = SatlasClassificationTask(num_classes=21)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 13, 127 | "id": "f88e3e92-d30e-48d0-9c74-20efc91f4459", 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stderr", 132 | "output_type": "stream", 133 | "text": [ 134 | "GPU available: False, used: False\n", 135 | "TPU available: False, using: 0 TPU cores\n", 136 | "IPU available: False, using: 0 IPUs\n", 137 | "HPU available: False, using: 0 HPUs\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "# Initialize the training code.\n", 143 | "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", 144 | "default_root_dir = os.path.join(tempfile.gettempdir(), \"experiments\")\n", 145 | "\n", 146 | "trainer = Trainer(\n", 147 | " accelerator=accelerator,\n", 148 | " default_root_dir=default_root_dir,\n", 149 | " fast_dev_run=fast_dev_run,\n", 150 | " log_every_n_steps=1,\n", 151 | " min_epochs=1,\n", 152 | " max_epochs=max_epochs,\n", 153 | ")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 14, 159 | "id": "aa6a6684-cf14-4fa2-b801-f3d35b94519d", 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stderr", 164 | "output_type": "stream", 165 | "text": [ 166 | "\n", 167 | " | Name | Type | Params\n", 168 | "---------------------------------------------------\n", 169 | "0 | criterion | CrossEntropyLoss | 0 \n", 170 | "1 | train_metrics | MetricCollection | 0 \n", 171 | "2 | val_metrics | MetricCollection | 0 \n", 172 | "3 | test_metrics | MetricCollection | 0 \n", 173 | "4 | model | SwinTransformer | 86.9 M\n", 174 | "---------------------------------------------------\n", 175 | "86.9 M Trainable params\n", 176 | "0 Non-trainable params\n", 177 | "86.9 M Total params\n", 178 | "347.709 Total estimated model params size (MB)\n" 179 | ] 180 | }, 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "Sanity Checking DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 0.60it/s]Epoch 0 Validation - Loss: 3.1346988677978516, Accuracy: 0.0\n", 186 | " \r" 187 | ] 188 | }, 189 | { 190 | "name": "stderr", 191 | "output_type": "stream", 192 | "text": [ 193 | "/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.\n" 194 | ] 195 | }, 196 | 
{ 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "Epoch 0: 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 124/158 [10:06<02:46, 0.20it/s, v_num=1]" 201 | ] 202 | }, 203 | { 204 | "name": "stderr", 205 | "output_type": "stream", 206 | "text": [ 207 | "/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "# Train\n", 213 | "trainer.fit(model=task, datamodule=datamodule)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "3034caca-1f6b-434d-9e01-6cadf98fb1ac", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "Python 3 (ipykernel)", 228 | "language": "python", 229 | "name": "python3" 230 | }, 231 | "language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 | "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.12.0" 242 | } 243 | }, 244 | "nbformat": 4, 245 | "nbformat_minor": 5 246 | } 247 | --------------------------------------------------------------------------------
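Usage sketch: the snippet below mirrors the loading flow exercised in tests/test_pretrained_models.py; the model identifier, head choice, target value, and input size are illustrative picks, not the only supported ones, and the checkpoint is expected to be fetched from the matching Hugging Face URL in SatlasPretrain_weights.

import torch
from satlaspretrain_models.model import Weights
from satlaspretrain_models.utils import Head

# Weights() is the pretrained-weights manager used throughout tests/test_pretrained_models.py.
weights_manager = Weights()

# Single-image Sentinel-2 RGB Swin-B backbone with an FPN and a 2-class classification head.
# 'Sentinel2_SwinB_SI_RGB' is one of the keys in SatlasPretrain_weights.
model = weights_manager.get_pretrained_model('Sentinel2_SwinB_SI_RGB', fpn=True, head=Head.CLASSIFY, num_categories=2)

# A random 3-channel image and a dummy class label, matching the shapes used in the tests.
rand_img = torch.rand((1, 3, 128, 128)).float()
rand_targets = torch.tensor([1])

# When targets are provided, the model returns both predictions and a loss, as in the tests.
output, loss = model(rand_img, rand_targets)

For a backbone-only or backbone-plus-FPN model without a head, the tests instead call get_pretrained_model(model_id) or get_pretrained_model(model_id, fpn=True) and run a plain forward pass on the image tensor.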