├── .binder └── environment.yml ├── .github └── workflows │ └── deploy-docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── cluster └── ml-cluster.yaml.template ├── configs ├── classify_eurosat.yaml ├── config.yaml ├── metadata.yaml ├── regression_biomasters.yaml └── segment_chesapeake.yaml ├── copy_data.sh ├── docs ├── _config.yml ├── _toc.yml ├── clay-v0 │ ├── changelog-v0.1.md │ ├── clay-v0-interpolation.ipynb │ ├── clay-v0-location-embeddings.ipynb │ ├── clay-v0-reconstruction.ipynb │ ├── data_datacube.md │ ├── data_labels.md │ ├── model_embeddings.md │ ├── model_finetuning.md │ ├── partial-inputs-flood-tutorial.ipynb │ ├── partial-inputs.ipynb │ ├── patch_level_cloud_cover.ipynb │ ├── run_region.md │ ├── specification-v0.md │ ├── tutorial_digital_earth_pacific_patch_level.ipynb │ └── worldcover-embeddings.md ├── finetune │ ├── classify.md │ ├── finetune-on-embeddings.ipynb │ ├── regression.md │ └── segment.md ├── getting-started │ ├── basic_use.md │ └── installation.md ├── index.md ├── logo.png ├── release-notes │ ├── changelog-v1.0.md │ ├── data_sampling.md │ └── specification.md └── tutorials │ ├── embeddings.ipynb │ ├── inference.ipynb │ ├── reconstruction.ipynb │ └── wall-to-wall.ipynb ├── environment.yml ├── finetune ├── classify │ ├── classify.py │ ├── eurosat_datamodule.py │ ├── eurosat_model.py │ └── factory.py ├── embedder │ ├── factory.py │ └── how-to-embed.ipynb ├── regression │ ├── biomasters_datamodule.py │ ├── biomasters_inference.ipynb │ ├── biomasters_model.py │ ├── factory.py │ ├── preprocess_data.py │ └── regression.py └── segment │ ├── chesapeake_datamodule.py │ ├── chesapeake_inference.ipynb │ ├── chesapeake_model.py │ ├── factory.py │ ├── preprocess_data.py │ └── segment.py ├── ruff.toml ├── src ├── __init__.py ├── backbone.py ├── callbacks.py ├── callbacks_wandb.py ├── datamodule.py ├── factory.py ├── model.py ├── module.py ├── mrl.py └── utils.py ├── train_clay.sh ├── trainer.py └── utils ├── check_data_sanity.py └── split_npz.py /.binder/environment.yml: -------------------------------------------------------------------------------- 1 | name: claymodel 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nodefaults 6 | dependencies: 7 | - conda-lock~=2.5.6 8 | - einops~=0.7.0 9 | - fiona~=1.9.5 10 | - geopandas-base~=0.14.1 11 | - h5netcdf~=1.3.0 12 | - jupyter-book~=1.0.0 13 | - jupyterlab~=4.0.7 14 | - jsonargparse~=4.27.0 15 | - lightning~=2.1.0 16 | - matplotlib-base~=3.8.2 17 | - planetary-computer~=1.0.0 18 | - python-box~=7.1.0 19 | - pytorch~=2.1.0 20 | - cpuonly 21 | - python~=3.11.0 22 | - pyarrow~=16.1.0 23 | - rioxarray~=0.15.0 24 | - rasterio~=1.3.10 25 | - s3fs~=2024.3.1 26 | - scikit-image~=0.22.0 27 | - scikit-learn~=1.4.0 28 | - stackstac~=0.5.0 29 | - timm~=0.9.16 30 | - torchdata~=0.7.1 31 | - torchvision~=0.16.1 32 | - transformers~=4.35.2 33 | - typeshed-client~=2.4.0 34 | - vit-pytorch~=1.6.4 35 | - wandb~=0.15.12 36 | - zarr~=2.16.1 37 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | # https://jupyterbook.org/en/stable/publish/gh-pages.html#automatically-host-your-book-with-github-actions 2 | name: deploy-book 3 | 4 | # Only run this when the main branch changes 5 | on: 6 | # Uncomment the 'pull_request' line below to manually re-build Jupyter Book 7 | # pull_request: 8 | push: 9 | branches: 10 | - main 11 | # If your git repository 
has the Jupyter Book within some-subfolder next to 12 | # unrelated files, you can make this run only if a file within that specific 13 | # folder has been modified. 14 | paths: 15 | - '.github/workflows/deploy-book.yml' 16 | - 'docs/**' 17 | 18 | # Allow one concurrent deployment 19 | concurrency: 20 | group: "pages" 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | # This job installs dependencies, builds the book, and pushes it to `gh-pages` 25 | build-book: 26 | runs-on: ubuntu-22.04 27 | defaults: 28 | run: 29 | shell: bash -l {0} 30 | 31 | steps: 32 | # Checkout current git repository 33 | - name: Checkout 34 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 35 | 36 | # Enable GitHub Pages and extract various metadata about a site 37 | - name: Setup Pages 38 | uses: actions/configure-pages@1f0c5cde4bc74cd7e1254d0cb4de8d49e9068c7d # v4.0.0 39 | 40 | # Install Micromamba with conda-forge dependencies 41 | - name: Setup Micromamba 42 | uses: mamba-org/setup-micromamba@e820223f89c8720d6c740ca154a7adf32fcd278a # v1.7.3 43 | with: 44 | environment-name: claymodel 45 | condarc: | 46 | channels: 47 | - conda-forge 48 | - nodefaults 49 | create-args: >- 50 | python=3.11 51 | jupyter-book=0.15.1 52 | 53 | # Build the Jupyter book 54 | - name: Build the book 55 | run: jupyter-book build docs/ 56 | 57 | # Upload the built HTML to GitHub Artifacts 58 | - name: Upload artifact 59 | uses: actions/upload-pages-artifact@v3 60 | with: 61 | path: docs/_build/html 62 | 63 | # This job downloads the built HTML artifacts and deploys the webpage 64 | deploy-pages: 65 | runs-on: ubuntu-22.04 66 | # Add a dependency to the build job 67 | needs: build-book 68 | # Grant GITHUB_TOKEN the permissions required to make a Pages deployment 69 | permissions: 70 | pages: write # to deploy to Pages 71 | id-token: write # to verify the deployment originates from an appropriate source 72 | # Deploy to the github-pages environment 73 | environment: 74 | name: github-pages 75 | url: ${{ steps.deployment.outputs.page_url }} 76 | 77 | steps: 78 | # Push the book's HTML to github-pages 79 | - name: Deploy to GitHub Pages 80 | id: deployment 81 | uses: actions/deploy-pages@v4 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Batch package 2 | batch-fetch-and-run.zip 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Checkpoints and logs 13 | checkpoints/**/*.ckpt 14 | lightning_logs/ 15 | wandb/ 16 | logs/ 17 | 18 | # Data files and folders 19 | data/** 20 | !data/**/ 21 | **/*/.gif 22 | **/*/.png 23 | 24 | # Distribution / packaging 25 | .Python 26 | env/ 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | 52 | # Sphinx documentation 53 | docs/_build/ 54 | 55 | # DotEnv configuration 56 | .env 57 | 58 | # Jupyter Book 59 | /docs/_build/ 60 | 61 | # Jupyter Notebook 62 | .ipynb_checkpoints/ 63 | 64 | # Mac OS-specific storage files 65 | .DS_Store 66 | datadisk/ 67 | 68 | # vscode 69 | .vscode 70 | -------------------------------------------------------------------------------- 
/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-added-large-files 8 | args: [ '--maxkb=512', '--enforce-all' ] 9 | exclude: '^docs/tutorials/.*\.ipynb$' 10 | - id: check-yaml 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.6.9 15 | hooks: 16 | - id: ruff # Run the linter 17 | args: [ --fix ] 18 | types_or: [ python, pyi ] 19 | - id: ruff # Run the linter for Jupyter notebooks with the PLR0913 rule ignored 20 | args: [ --fix, --ignore=PLR0913 ] 21 | types: [ jupyter ] 22 | - id: ruff-format # Run the formatter 23 | types_or: [ python, pyi, jupyter ] 24 | 25 | # https://pre-commit.ci/#configuration 26 | ci: 27 | autofix_prs: true 28 | autoupdate_schedule: quarterly 29 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 
51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official email address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at bruno@clay.foundation or dan@clay.foundation . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 
127 |
128 | [homepage]: https://www.contributor-covenant.org
129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
130 | [Mozilla CoC]: https://github.com/mozilla/diversity
131 | [FAQ]: https://www.contributor-covenant.org/faq
132 | [translations]: https://www.contributor-covenant.org/translations
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Clay Foundation Model
2 |
3 | [![Jupyter Book Badge](https://jupyterbook.org/badge.svg)](https://clay-foundation.github.io/model)
4 | [![Deploy Book Status](https://github.com/Clay-foundation/model/actions/workflows/deploy-docs.yml/badge.svg)](https://github.com/Clay-foundation/model/actions/workflows/deploy-docs.yml)
5 |
6 | An open source AI model and interface for Earth.
7 |
8 | ## License
9 |
10 | Clay Model is licensed under the [Apache License](LICENSE). This applies to the source code as well as the trained model weights.
11 |
12 | The documentation is licensed under the [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) license.
13 |
14 | ## Quickstart
15 |
16 | Launch into a [JupyterLab](https://jupyterlab.readthedocs.io) environment on
17 |
18 | | [Binder](https://mybinder.readthedocs.io/en/latest) | [SageMaker Studio Lab](https://studiolab.sagemaker.aws) |
19 | |:--:|:--:|
20 | | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Clay-foundation/model/main) | [![Open in SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/Clay-foundation/model/blob/main/docs/tutorials/wall-to-wall.ipynb) |
21 |
22 | ## Installation
23 |
24 | ### Basic
25 |
26 | To help out with development, start by cloning this [repo-url](/../../)
27 |
28 |     git clone <repo-url>
29 |     cd model
30 |
31 | Then we recommend [using mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html)
32 | to install the dependencies. A virtual environment will also be created with Python and
33 | [JupyterLab](https://github.com/jupyterlab/jupyterlab) installed.
34 |
35 |     mamba env create --file environment.yml
36 |
37 | > [!NOTE]
38 | > The command above has been tested on Linux devices with CUDA GPUs.
39 |
40 | Activate the virtual environment first.
41 |
42 |     mamba activate claymodel
43 |
44 | Finally, double-check that the libraries have been installed.
45 |
46 |     mamba list
47 |
48 |
49 | ## Usage
50 |
51 | ### Running jupyter lab
52 |
53 |     mamba activate claymodel
54 |     python -m ipykernel install --user --name claymodel  # to install virtual env properly
55 |     jupyter kernelspec list --json  # see if kernel is installed
56 |     jupyter lab &
57 |
58 |
59 | ### Running the model
60 |
61 | The neural network model can be run via
62 | [LightningCLI v2](https://pytorch-lightning.medium.com/introducing-lightningcli-v2supercharge-your-training-c070d43c7dd6).
63 | To check out the different options available, and look at the hyperparameter 64 | configurations, run: 65 | 66 | python trainer.py --help 67 | 68 | To quickly test the model on one batch in the validation set: 69 | 70 | python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml --trainer.fast_dev_run=True 71 | 72 | To train the model: 73 | 74 | python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml 75 | 76 | More options can be found using `python trainer.py fit --help`, or at the 77 | [LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). 78 | 79 | ## Contributing 80 | 81 | ### Writing documentation 82 | 83 | Our Documentation uses [Jupyter Book](https://jupyterbook.org/intro.html). 84 | 85 | Install it with: 86 | ```bash 87 | pip install -U jupyter-book 88 | ``` 89 | 90 | Then build it with: 91 | ```bash 92 | jupyter-book build docs/ 93 | ``` 94 | 95 | You can preview the site locally with: 96 | ```bash 97 | python -m http.server --directory _build/html 98 | ``` 99 | 100 | There is a GitHub Action on `./github/workflows/deploy-docs.yml` that builds the site and pushes it to GitHub Pages. 101 | -------------------------------------------------------------------------------- /cluster/ml-cluster.yaml.template: -------------------------------------------------------------------------------- 1 | Region: us-east-2 2 | 3 | # DL AMI 4 | Image: 5 | Os: ubuntu2004 6 | CustomAmi: 7 | 8 | # FSx LUSTRE SHARED STORAGE 9 | SharedStorage: 10 | - MountDir: /fsx 11 | Name: fsx 12 | StorageType: FsxLustre 13 | FsxLustreSettings: 14 | FileSystemId: 15 | 16 | # HEAD NODE 17 | HeadNode: 18 | InstanceType: c5.12xlarge 19 | Networking: 20 | SubnetId: 21 | SecurityGroups: 22 | - # EFA enabled SG 23 | Ssh: 24 | KeyName: 25 | LocalStorage: 26 | RootVolume: 27 | Size: 200 28 | Iam: 29 | S3Access: 30 | - BucketName: 31 | EnableWriteAccess: false 32 | - BucketName: 33 | EnableWriteAccess: true 34 | 35 | 36 | # SCHEDULER 37 | Scheduling: 38 | Scheduler: slurm 39 | SlurmQueues: 40 | - Name: gpu-queue 41 | ComputeResources: 42 | - Name: 43 | Instances: 44 | - InstanceType: 45 | MinCount: 0 46 | MaxCount: 8 47 | Efa: 48 | Enabled: true 49 | Networking: 50 | SubnetIds: 51 | - 52 | SecurityGroups: 53 | - # EFA enabled SG 54 | PlacementGroup: 55 | Enabled: true 56 | Iam: 57 | S3Access: 58 | - BucketName: 59 | EnableWriteAccess: false 60 | - BucketName: 61 | EnableWriteAccess: true 62 | -------------------------------------------------------------------------------- /configs/classify_eurosat.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.2 2 | seed_everything: 42 3 | data: 4 | metadata_path: configs/metadata.yaml 5 | batch_size: 128 6 | num_workers: 8 7 | model: 8 | num_classes: 10 9 | ckpt_path: checkpoints/clay-v1.5.ckpt 10 | lr: 5e-5 11 | wd: 0.05 12 | b1: 0.9 13 | b2: 0.95 14 | trainer: 15 | accelerator: auto 16 | strategy: ddp 17 | devices: auto 18 | num_nodes: 1 19 | precision: bf16-mixed 20 | log_every_n_steps: 5 21 | max_epochs: 20 22 | accumulate_grad_batches: 1 23 | default_root_dir: checkpoints/classify 24 | fast_dev_run: False 25 | num_sanity_val_steps: 0 26 | logger: 27 | - class_path: lightning.pytorch.loggers.WandbLogger 28 | init_args: 29 | entity: developmentseed 30 | project: clay-classify 31 | group: v1.5-test 32 | log_model: false 33 | callbacks: 34 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 35 | init_args: 36 | 
dirpath: checkpoints/classify 37 | auto_insert_metric_name: False 38 | filename: eurosat_epoch-{epoch:02d}_val-score-{val/score:.3f} 39 | monitor: val/score 40 | mode: max 41 | save_last: True 42 | save_top_k: 2 43 | save_weights_only: True 44 | verbose: True 45 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 46 | init_args: 47 | logging_interval: step 48 | plugins: 49 | - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO 50 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 108 2 | data: 3 | data_dir: /fsx 4 | size: 256 5 | metadata_path: configs/metadata.yaml 6 | platforms: 7 | - landsat-c2l1 8 | - landsat-c2l2-sr 9 | - linz 10 | - modis 11 | - naip 12 | - sentinel-1-rtc 13 | - sentinel-2-l2a 14 | batch_size: 1 15 | num_workers: 12 16 | model: 17 | model_size: large 18 | mask_ratio: 0.75 19 | norm_pix_loss: False 20 | patch_size: 8 21 | shuffle: True 22 | metadata_path: configs/metadata.yaml 23 | teacher: vit_large_patch14_reg4_dinov2.lvd142m 24 | dolls: [16, 32, 64, 128, 256, 768, 1024] 25 | doll_weights: [1, 1, 1, 1, 1, 1, 1] 26 | lr: 5e-6 27 | wd: 0.05 28 | b1: 0.9 29 | b2: 0.95 30 | embeddings_level: mean 31 | trainer: 32 | accelerator: gpu 33 | strategy: ddp 34 | devices: 8 35 | num_nodes: 48 36 | precision: bf16-mixed 37 | log_every_n_steps: 1 38 | max_epochs: 1000 39 | accumulate_grad_batches: 1 40 | default_root_dir: checkpoints/v1.5.0/ 41 | fast_dev_run: False 42 | num_sanity_val_steps: 0 43 | use_distributed_sampler: False 44 | limit_train_batches: 0.99 45 | limit_val_batches: 0.99 46 | logger: 47 | - class_path: lightning.pytorch.loggers.WandbLogger 48 | init_args: 49 | entity: developmentseed 50 | project: clay 51 | group: v1.5-nomrl-dinov2 52 | id: 0uy3in7l 53 | resume: must 54 | log_model: false 55 | callbacks: 56 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 57 | init_args: 58 | dirpath: checkpoints/v1.5.0/ 59 | auto_insert_metric_name: False 60 | filename: mae_v1.5.0_epoch-{epoch:02d}_val-loss-{val/loss:.4f} 61 | monitor: val/loss 62 | mode: min 63 | save_last: True 64 | save_top_k: 2 65 | save_weights_only: False 66 | verbose: True 67 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 68 | init_args: 69 | logging_interval: step 70 | - class_path: src.callbacks_wandb.LogIntermediatePredictions 71 | plugins: 72 | - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO 73 | ckpt_path: checkpoints/v1.5.0/last.ckpt 74 | -------------------------------------------------------------------------------- /configs/metadata.yaml: -------------------------------------------------------------------------------- 1 | sentinel-2-l2a: 2 | band_order: 3 | - blue 4 | - green 5 | - red 6 | - rededge1 7 | - rededge2 8 | - rededge3 9 | - nir 10 | - nir08 11 | - swir16 12 | - swir22 13 | rgb_indices: 14 | - 2 15 | - 1 16 | - 0 17 | gsd: 10 18 | bands: 19 | mean: 20 | blue: 1105. 21 | green: 1355. 22 | red: 1552. 23 | rededge1: 1887. 24 | rededge2: 2422. 25 | rededge3: 2630. 26 | nir: 2743. 27 | nir08: 2785. 28 | swir16: 2388. 29 | swir22: 1835. 30 | std: 31 | blue: 1809. 32 | green: 1757. 33 | red: 1888. 34 | rededge1: 1870. 35 | rededge2: 1732. 36 | rededge3: 1697. 37 | nir: 1742. 38 | nir08: 1648. 39 | swir16: 1470. 40 | swir22: 1379. 
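# NOTE: the per-band mean/std statistics listed in this file are the normalization
# constants used by the data pipeline, i.e. pixels are scaled as (x - mean) / std
# before they are passed to the model.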
41 | wavelength: 42 | blue: 0.493 43 | green: 0.56 44 | red: 0.665 45 | rededge1: 0.704 46 | rededge2: 0.74 47 | rededge3: 0.783 48 | nir: 0.842 49 | nir08: 0.865 50 | swir16: 1.61 51 | swir22: 2.19 52 | planetscope-sr: 53 | band_order: 54 | - coastal_blue 55 | - blue 56 | - green_i 57 | - green 58 | - yellow 59 | - red 60 | - rededge 61 | - nir 62 | rgb_indices: 63 | - 5 64 | - 3 65 | - 1 66 | gsd: 5 67 | bands: 68 | mean: 69 | coastal_blue: 1720. 70 | blue: 1715. 71 | green_i: 1913. 72 | green: 2088. 73 | yellow: 2274. 74 | red: 2290. 75 | rededge: 2613. 76 | nir: 3970. 77 | std: 78 | coastal_blue: 747. 79 | blue: 698. 80 | green_i: 739. 81 | green: 768. 82 | yellow: 849. 83 | red: 868. 84 | rededge: 849. 85 | nir: 914. 86 | wavelength: 87 | coastal_blue: 0.443 88 | blue: 0.490 89 | green_i: 0.531 90 | green: 0.565 91 | yellow: 0.610 92 | red: 0.665 93 | rededge: 0.705 94 | nir: 0.865 95 | landsat-c2l1: 96 | band_order: 97 | - red 98 | - green 99 | - blue 100 | - nir08 101 | - swir16 102 | - swir22 103 | rgb_indices: 104 | - 0 105 | - 1 106 | - 2 107 | gsd: 30 108 | bands: 109 | mean: 110 | red: 10678. 111 | green: 10563. 112 | blue: 11083. 113 | nir08: 14792. 114 | swir16: 12276. 115 | swir22: 10114. 116 | std: 117 | red: 6025. 118 | green: 5411. 119 | blue: 5468. 120 | nir08: 6746. 121 | swir16: 5897. 122 | swir22: 4850. 123 | wavelength: 124 | red: 0.65 125 | green: 0.56 126 | blue: 0.48 127 | nir08: 0.86 128 | swir16: 1.6 129 | swir22: 2.2 130 | landsat-c2l2-sr: 131 | band_order: 132 | - red 133 | - green 134 | - blue 135 | - nir08 136 | - swir16 137 | - swir22 138 | rgb_indices: 139 | - 0 140 | - 1 141 | - 2 142 | gsd: 30 143 | bands: 144 | mean: 145 | red: 13705. 146 | green: 13310. 147 | blue: 12474. 148 | nir08: 17801. 149 | swir16: 14615. 150 | swir22: 12701. 151 | std: 152 | red: 9578. 153 | green: 9408. 154 | blue: 10144. 155 | nir08: 8277. 156 | swir16: 5300. 157 | swir22: 4522. 158 | wavelength: 159 | red: 0.65 160 | green: 0.56 161 | blue: 0.48 162 | nir08: 0.86 163 | swir16: 1.6 164 | swir22: 2.2 165 | naip: 166 | band_order: 167 | - red 168 | - green 169 | - blue 170 | - nir 171 | rgb_indices: 172 | - 0 173 | - 1 174 | - 2 175 | gsd: 1.0 176 | bands: 177 | mean: 178 | red: 110.16 179 | green: 115.41 180 | blue: 98.15 181 | nir: 139.04 182 | std: 183 | red: 47.23 184 | green: 39.82 185 | blue: 35.43 186 | nir: 49.86 187 | wavelength: 188 | red: 0.65 189 | green: 0.56 190 | blue: 0.48 191 | nir: 0.842 192 | linz: 193 | band_order: 194 | - red 195 | - green 196 | - blue 197 | rgb_indices: 198 | - 0 199 | - 1 200 | - 2 201 | gsd: 0.5 202 | bands: 203 | mean: 204 | red: 89.96 205 | green: 99.46 206 | blue: 89.51 207 | std: 208 | red: 41.83 209 | green: 36.96 210 | blue: 31.45 211 | wavelength: 212 | red: 0.635 213 | green: 0.555 214 | blue: 0.465 215 | sentinel-1-rtc: 216 | band_order: 217 | - vv 218 | - vh 219 | gsd: 10 220 | bands: 221 | mean: 222 | vv: -12.113 223 | vh: -18.673 224 | std: 225 | vv: 8.314 226 | vh: 8.017 227 | wavelength: 228 | vv: 3.5 229 | vh: 4.0 230 | modis: 231 | band_order: 232 | - sur_refl_b01 233 | - sur_refl_b02 234 | - sur_refl_b03 235 | - sur_refl_b04 236 | - sur_refl_b05 237 | - sur_refl_b06 238 | - sur_refl_b07 239 | rgb_indices: 240 | - 0 241 | - 3 242 | - 2 243 | gsd: 500 244 | bands: 245 | mean: 246 | sur_refl_b01: 1072. 247 | sur_refl_b02: 1624. 248 | sur_refl_b03: 931. 249 | sur_refl_b04: 1023. 250 | sur_refl_b05: 1599. 251 | sur_refl_b06: 1404. 252 | sur_refl_b07: 1051. 253 | std: 254 | sur_refl_b01: 1643. 255 | sur_refl_b02: 1878. 
256 | sur_refl_b03: 1449. 257 | sur_refl_b04: 1538. 258 | sur_refl_b05: 1763. 259 | sur_refl_b06: 1618. 260 | sur_refl_b07: 1396. 261 | wavelength: 262 | sur_refl_b01: .645 263 | sur_refl_b02: .858 264 | sur_refl_b03: .469 265 | sur_refl_b04: .555 266 | sur_refl_b05: 1.240 267 | sur_refl_b06: 1.640 268 | sur_refl_b07: 2.130 269 | -------------------------------------------------------------------------------- /configs/regression_biomasters.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.2 2 | seed_everything: 42 3 | data: 4 | metadata_path: configs/metadata.yaml 5 | batch_size: 25 6 | num_workers: 8 7 | train_chip_dir: data/biomasters/train_cube 8 | train_label_dir: data/biomasters/train_agbm 9 | val_chip_dir: data/biomasters/test_cube 10 | val_label_dir: data/biomasters/test_agbm 11 | model: 12 | ckpt_path: checkpoints/clay-v1.5.ckpt 13 | lr: 1e-2 14 | wd: 0.05 15 | b1: 0.9 16 | b2: 0.95 17 | trainer: 18 | accelerator: auto 19 | strategy: ddp 20 | devices: auto 21 | num_nodes: 1 22 | precision: bf16-mixed 23 | log_every_n_steps: 5 24 | max_epochs: 100 25 | default_root_dir: checkpoints/regression 26 | fast_dev_run: False 27 | num_sanity_val_steps: 0 28 | limit_train_batches: 0.25 29 | limit_val_batches: 0.25 30 | accumulate_grad_batches: 1 31 | logger: 32 | - class_path: lightning.pytorch.loggers.WandbLogger 33 | init_args: 34 | entity: developmentseed 35 | project: clay-regression 36 | log_model: false 37 | group: v1.5 38 | callbacks: 39 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 40 | init_args: 41 | dirpath: checkpoints/regression 42 | auto_insert_metric_name: False 43 | filename: biomasters_epoch-{epoch:02d}_val-score-{val/score:.3f} 44 | monitor: val/score 45 | mode: min 46 | save_last: False 47 | save_top_k: 2 48 | save_weights_only: True 49 | verbose: True 50 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 51 | init_args: 52 | logging_interval: step 53 | # - class_path: src.callbacks.LayerwiseFinetuning 54 | # init_args: 55 | # phase: 10 56 | # train_bn: True 57 | plugins: 58 | - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO 59 | -------------------------------------------------------------------------------- /configs/segment_chesapeake.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.2 2 | seed_everything: 42 3 | data: 4 | train_chip_dir: data/cvpr/ny/train/chips/ 5 | train_label_dir: data/cvpr/ny/train/labels/ 6 | val_chip_dir: data/cvpr/ny/val/chips/ 7 | val_label_dir: data/cvpr/ny/val/labels/ 8 | metadata_path: configs/metadata.yaml 9 | batch_size: 16 10 | num_workers: 8 11 | platform: naip 12 | model: 13 | num_classes: 7 14 | ckpt_path: checkpoints/clay-v1.5.ckpt 15 | lr: 1e-5 16 | wd: 0.05 17 | b1: 0.9 18 | b2: 0.95 19 | trainer: 20 | accelerator: auto 21 | strategy: ddp 22 | devices: auto 23 | num_nodes: 1 24 | precision: bf16-mixed 25 | log_every_n_steps: 5 26 | max_epochs: 100 27 | accumulate_grad_batches: 1 28 | default_root_dir: checkpoints/segment 29 | fast_dev_run: False 30 | num_sanity_val_steps: 0 31 | logger: 32 | - class_path: lightning.pytorch.loggers.WandbLogger 33 | init_args: 34 | entity: developmentseed 35 | project: clay-segment 36 | group: v1.5 37 | log_model: false 38 | callbacks: 39 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 40 | init_args: 41 | dirpath: checkpoints/segment 42 | auto_insert_metric_name: False 43 | filename: 
chesapeake-7class-segment_epoch-{epoch:02d}_val-iou-{val/iou:.4f} 44 | monitor: val/iou 45 | mode: max 46 | save_last: True 47 | save_top_k: 2 48 | save_weights_only: True 49 | verbose: True 50 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 51 | init_args: 52 | logging_interval: step 53 | plugins: 54 | - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO 55 | -------------------------------------------------------------------------------- /copy_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define source and destination directories 4 | src="/fsx" 5 | dest="data/pretrain" 6 | 7 | # Create the destination directory if it doesn't exist 8 | mkdir -p "$dest" 9 | 10 | # Find all directories in the source directory 11 | find "$src" -type d -print0 | while IFS= read -r -d '' dir; do 12 | # Create corresponding directory in the destination 13 | newdir="$dest${dir#$src}" 14 | mkdir -p "$newdir" 15 | 16 | # Copy the first 100 files from the source directory to the new directory 17 | find "$dir" -maxdepth 1 -type f -print0 | head -z -n 100 | xargs -0 -I{} cp {} "$newdir" 18 | done 19 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: Clay Foundation Model 5 | author: Clay Foundation 6 | logo: logo.png 7 | only_build_toc_files: true 8 | 9 | # Only execution notebooks with no output cells on each build. 10 | # See https://jupyterbook.org/content/execute.html 11 | execute: 12 | execute_notebooks: cache 13 | exclude_patterns: 14 | - wall-to-wall.ipynb 15 | - embeddings.ipynb 16 | - reconstruction.ipynb 17 | 18 | # Define the name of the latex output file for PDF builds 19 | latex: 20 | latex_documents: 21 | targetname: clay_foundation_model.tex 22 | 23 | # Add a bibtex file so that we can create citations 24 | bibtex_bibfiles: 25 | - references.bib 26 | 27 | # Launch button settings 28 | launch_buttons: 29 | notebook_interface: jupyterlab 30 | binderhub_url: https://mybinder.org 31 | 32 | # Information about where the book exists on the web 33 | repository: 34 | url: https://github.com/Clay-foundation/model # Online location of your book 35 | path_to_book: docs # Optional path to your book, relative to the repository root 36 | branch: main # Which branch of the repository should be used when creating links (optional) 37 | 38 | # Add GitHub buttons to your book 39 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 40 | html: 41 | use_edit_page_button: true 42 | use_issues_button: true 43 | use_repository_button: true 44 | extra_footer: | 45 | Creative Commons License
46 | This content is licensed under a 47 | Creative Commons Attribution 4.0 International License. 48 | 49 | sphinx: 50 | config: 51 | html_show_copyright: false 52 | extra_extensions: 53 | - 'sphinx.ext.autodoc' 54 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: index 6 | parts: 7 | - caption: Release notes 8 | chapters: 9 | - title: Model release notes 10 | file: release-notes/specification 11 | - title: Data sampling strategy 12 | file: release-notes/data_sampling 13 | - title: Software release notes v1.0 14 | file: release-notes/changelog-v1.0 15 | - caption: Getting Started 16 | chapters: 17 | - title: Installation 18 | file: getting-started/installation 19 | - title: Basic Use 20 | file: getting-started/basic_use 21 | - caption: Tutorials 22 | chapters: 23 | - title: Clay wall-to-wall example 24 | file: tutorials/wall-to-wall 25 | - title: Explore embeddings from Clay Encoder 26 | file: tutorials/embeddings 27 | - title: Clay MAE reconstruction 28 | file: tutorials/reconstruction 29 | - caption: Finetune examples 30 | chapters: 31 | - title: Segmentation using Chesapeake 32 | file: finetune/segment 33 | - title: Classification using Eurosat 34 | file: finetune/classify 35 | - title: Regression using Biomasters 36 | file: finetune/regression 37 | - title: Finetuning on an embeddings database 38 | file: finetune/finetune-on-embeddings 39 | - caption: About Clay 40 | chapters: 41 | - title: GitHub 42 | url: https://github.com/Clay-foundation 43 | - title: LinkedIn 44 | url: https://www.linkedin.com/company/made-with-clay 45 | - title: Website 46 | url: https://madewithclay.org 47 | - caption: Older Versions 48 | chapters: 49 | - file: clay-v0/specification-v0 50 | sections: 51 | - title: Run over a region 52 | file: clay-v0/run_region 53 | - title: Generating embeddings 54 | file: clay-v0/model_embeddings 55 | - title: Finetuning 56 | file: clay-v0/model_finetuning 57 | - title: Embeddings for Contiguous US 58 | file: clay-v0/worldcover-embeddings 59 | - title: Interpolation example for v0 60 | file: clay-v0/clay-v0-interpolation 61 | - title: Creating datacubes 62 | file: clay-v0/data_datacube 63 | - title: Benchmark dataset labels 64 | file: clay-v0/data_labels 65 | - file: clay-v0-interpolation 66 | - title: Software release notes v0.1 67 | file: clay-v0/changelog-v1.0 68 | -------------------------------------------------------------------------------- /docs/clay-v0/changelog-v0.1.md: -------------------------------------------------------------------------------- 1 | (software_release_v0)= 2 | # Code Model release v0.0.1 3 | 4 | This changelog is a summary of the changes to the source code of the Clay model. 5 | Released on 2024/01/12. 
6 | 7 | > For release notes for the trained model, see [](model_release) 8 | 9 | 10 | ### 💫 Highlights 11 | 12 | * 🎉 **First release of Clay Foundation Model** 🎉 13 | * Implement MAE with support for position, time, latlon & channel embeddings ([#47](https://github.com/Clay-foundation/model/pull/47)) 14 | 15 | ### 🚀 Model 16 | 17 | * Generate embeddings from CLAYModule trained with latlon/time encodings ([#96](https://github.com/Clay-foundation/model/pull/96)) 18 | * Add multigpu support & UI to test embeddings ([#109](https://github.com/Clay-foundation/model/pull/109)) 19 | * Refactor model for multi-device usage and easier disabling of masking ([#95](https://github.com/Clay-foundation/model/pull/95)) 20 | * Rename embeddings file to include MGRS code and store GeoTIFF source_url ([#86](https://github.com/Clay-foundation/model/pull/86)) 21 | * Callback function to log Masked Autoencoder reconstructions to WandB ([#88](https://github.com/Clay-foundation/model/pull/88)) 22 | * Adapt model to load 512x512 images from s3 bucket ([#85](https://github.com/Clay-foundation/model/pull/85)) 23 | * Save embeddings with spatiotemporal metadata to GeoParquet ([#73](https://github.com/Clay-foundation/model/pull/73)) 24 | * Generate embeddings via prediction loop ([#56](https://github.com/Clay-foundation/model/pull/56)) 25 | * Initial Vision Transformer architecture with MAE decoder ([#37](https://github.com/Clay-foundation/model/pull/37)) 26 | 27 | ### 🗃️ Data Pipeline 28 | 29 | * Adapted sampling strategy ([#81](https://github.com/Clay-foundation/model/pull/81)) 30 | * Allow ClayDataModule to load GeoTIFF files directly from s3 ([#92](https://github.com/Clay-foundation/model/pull/92)) 31 | * Let ClayDataModule return same spatiotemporal fields as GeoTIFFDataModule ([#91](https://github.com/Clay-foundation/model/pull/91)) 32 | * Improve date handling for data pipeline ([#76](https://github.com/Clay-foundation/model/pull/76)) 33 | * Let LightningDataModule return spatiotemporal metadata ([#66](https://github.com/Clay-foundation/model/pull/66)) 34 | * check for no data on a tile level in sentinel 1 vv and vh, sentinel 2 and DEM ([#60](https://github.com/Clay-foundation/model/pull/60)) 35 | * Batch setup ([#54](https://github.com/Clay-foundation/model/pull/54)) 36 | * LightningDataModule to load GeoTIFF files ([#52](https://github.com/Clay-foundation/model/pull/52)) 37 | * Ready for batch ([#44](https://github.com/Clay-foundation/model/pull/44)) 38 | * Tiler module ([#41](https://github.com/Clay-foundation/model/pull/41)) 39 | * Landcover based sampling strategy ([#29](https://github.com/Clay-foundation/model/pull/29)) 40 | * Datacube ([#27](https://github.com/Clay-foundation/model/pull/27)) 41 | 42 | ### 📖 Documentation 43 | 44 | * Document how the benchmark dataset labels were prepared ([#100](https://github.com/Clay-foundation/model/pull/100)) 45 | * Document how to finetune pretrained model on downstream task ([#99](https://github.com/Clay-foundation/model/pull/99)) 46 | * Document how to generate vector embeddings ([#98](https://github.com/Clay-foundation/model/pull/98)) 47 | * Document how to run the datacube pipeline with a batch job ([#97](https://github.com/Clay-foundation/model/pull/97)) 48 | * Initialize Jupyter Book documentation ([#89](https://github.com/Clay-foundation/model/pull/89)) 49 | * Setting the model license to OpenRail-M ([#63](https://github.com/Clay-foundation/model/pull/63)) 50 | * Create CODE_OF_CONDUCT.md ([#53](https://github.com/Clay-foundation/model/pull/53)) 51 | 52 | ### 
🧰 Maintenance 53 | 54 | * Bump pytorch from 2.0.0 to 2.1.0, CUDA from 11.8 to 12.0 ([#51](https://github.com/Clay-foundation/model/pull/51)) 55 | * Add pre-commit hooks with ruff formatter/linter rules ([#26](https://github.com/Clay-foundation/model/pull/26)) 56 | * Setup GitHub Actions Continuous Integration tests ([#25](https://github.com/Clay-foundation/model/pull/25)) 57 | * Setup LightningCLI trainer script ([#24](https://github.com/Clay-foundation/model/pull/24)) 58 | * Initial conda environment and binder links ([#15](https://github.com/Clay-foundation/model/pull/15)) 59 | 60 | ### 🧑‍🤝‍🧑 Contributors 61 | 62 | * [@brunosan](https://github.com/brunosan) 63 | * [@lillythomas](https://github.com/lillythomas) 64 | * [@srmsoumya](https://github.com/srmsoumya) 65 | * [@yellowcap](https://github.com/yellowcap) 66 | * [@weiji14](https://github.com/weiji14) 67 | 68 | **Full Changelog**: https://github.com/Clay-foundation/model/compare/v0.0.0...v0.0.1 69 | -------------------------------------------------------------------------------- /docs/clay-v0/clay-v0-interpolation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "76ed0078-447f-4374-b6ba-a8b4a366188d", 6 | "metadata": {}, 7 | "source": [ 8 | "# CLAY v0 - Interpolation between images" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "ea0176a6-97a1-4af6-af75-b9e52e52fbaf", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "\n", 20 | "sys.path.append(\"../\")" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "5ea314d0-176a-4ee3-b738-6152d27275d9", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import os\n", 31 | "from pathlib import Path\n", 32 | "\n", 33 | "import imageio\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import numpy as np\n", 36 | "import torch\n", 37 | "from einops import rearrange\n", 38 | "from PIL import Image\n", 39 | "\n", 40 | "from src.datamodule import ClayDataModule, ClayDataset\n", 41 | "from src.model_clay import CLAYModule" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "37f4a735-18e6-48d7-9b58-e8d188e96b54", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# data directory for all chips\n", 52 | "DATA_DIR = \"../data/02\"\n", 53 | "# path of best model checkpoint for Clay v0\n", 54 | "CKPT_PATH = \"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "4c300730-b0b0-4c3d-8a0d-d5e3ac018641", 60 | "metadata": {}, 61 | "source": [ 62 | "## Load Model & DataModule" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "4c5f2abf-5e9c-4def-88d9-38136307b420", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# Load the model & set in eval mode\n", 73 | "model = CLAYModule.load_from_checkpoint(\n", 74 | " CKPT_PATH, mask_ratio=0.0, shuffle=False\n", 75 | ") # No masking or shuffling of patches\n", 76 | "model.eval();" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "348c0573-7670-47a6-9e13-c6de36493b58", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "data_dir = Path(DATA_DIR)\n", 87 | "\n", 88 | "# Load the Clay DataModule\n", 89 | "ds = ClayDataset(chips_path=list(data_dir.glob(\"**/*.tif\")))\n", 90 | "dm = 
ClayDataModule(data_dir=str(data_dir), batch_size=2)\n", 91 | "dm.setup(stage=\"fit\")\n", 92 | "\n", 93 | "# Load the train DataLoader\n", 94 | "trn_dl = iter(dm.train_dataloader())" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "af3b38f7-f193-4b7d-aada-b4c8abcce7ed", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# Load the first batch of chips\n", 105 | "batch = next(trn_dl)\n", 106 | "batch.keys()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "59f9a028-a789-40c9-8f23-eb2a0e1c66eb", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "batch[\"pixels\"].shape, batch[\"latlon\"].shape, batch[\"timestep\"].shape" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "5e4c53bf-43a0-4e19-93da-bc47bb446e74", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def show(sample, idx=None, save=False):\n", 127 | " Path(\"animate\").mkdir(exist_ok=True)\n", 128 | " sample = rearrange(sample, \"c h w -> h w c\")\n", 129 | " denorm_sample = sample * torch.as_tensor(dm.STD) + torch.as_tensor(dm.MEAN)\n", 130 | " rgb = denorm_sample[..., [2, 1, 0]]\n", 131 | " plt.imshow((rgb - rgb.min()) / (rgb.max() - rgb.min()))\n", 132 | " plt.axis(\"off\")\n", 133 | " if save:\n", 134 | " plt.savefig(f\"animate/chip_{idx}.png\")" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "12e22e6b-127e-4948-b3e1-12b81164a417", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "sample1, sample2 = batch[\"pixels\"]" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "9648f698-7970-4f42-a9d5-093661139994", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "show(sample1)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "cbd75d9a-b42a-4e92-b81c-705a98a07d41", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "show(sample2)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "4d52f962-d20d-4966-af4d-fd5509520356", 170 | "metadata": {}, 171 | "source": [ 172 | "Each batch has chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coords & normalized timestep information as `year`, `month` & `day`." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "f5e6dfc4-57fe-4296-9a57-9bd384ec61af", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# Save a copy of batch to visualize later\n", 183 | "_batch = batch[\"pixels\"].detach().clone().cpu().numpy()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "id": "e7c46bb7-3e25-454d-b345-1ca3ec1efb69", 189 | "metadata": {}, 190 | "source": [ 191 | "## Pass data through the CLAY model" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "092d1ded-427f-424f-82ec-63bf0bccfdcc", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# Pass the pixels through the encoder & decoder of CLAY\n", 202 | "with torch.no_grad():\n", 203 | " # Move data from to the device of model\n", 204 | " batch[\"pixels\"] = batch[\"pixels\"].to(model.device)\n", 205 | " batch[\"timestep\"] = batch[\"timestep\"].to(model.device)\n", 206 | " batch[\"latlon\"] = batch[\"latlon\"].to(model.device)\n", 207 | "\n", 208 | " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", 209 | " (\n", 210 | " unmasked_patches,\n", 211 | " unmasked_indices,\n", 212 | " masked_indices,\n", 213 | " masked_matrix,\n", 214 | " ) = model.model.encoder(batch)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "5c14323c-dc02-4255-aa1a-3ec3c8799e3f", 220 | "metadata": {}, 221 | "source": [ 222 | "### Create an image based on interpolation of the embedding values between 2 images\n", 223 | "*Images are saved inside `./animate`*" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "6386a7a1-d225-4c1e-8228-23f5ce4b87e4", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "for idx, alpha in enumerate(np.linspace(0, 1, 20)):\n", 234 | " patch_break = 128\n", 235 | " l1, l2 = unmasked_patches\n", 236 | " l3 = alpha * l1 + (1 - alpha) * l2\n", 237 | " l4 = torch.vstack((l1[:patch_break, :], l2[patch_break:, :]))\n", 238 | "\n", 239 | " # Pass the unmasked_patches through the decoder to reconstruct the pixel space\n", 240 | " with torch.no_grad():\n", 241 | " pixels = model.model.decoder(\n", 242 | " rearrange(l3, \"gl d -> 1 gl d\"), unmasked_indices[[0]], masked_indices[[0]]\n", 243 | " )\n", 244 | "\n", 245 | " image = rearrange(pixels, \"b c (h w) (p1 p2) -> b c (h p1) (w p2)\", h=16, p1=32)\n", 246 | " _image = image[0].detach().cpu()\n", 247 | " show(_image, idx, save=True)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "ea7d8627-252d-42e2-af29-0fccb611121e", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "fig, axs = plt.subplots(2, 10, figsize=(20, 4))\n", 258 | "for ax, idx in zip(axs.flatten(), range(20)):\n", 259 | " ax.imshow(Image.open(f\"./animate/chip_{idx}.png\"))\n", 260 | " ax.set_title(f\"Seq {idx}\")\n", 261 | " ax.set_axis_off()\n", 262 | "plt.tight_layout()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "id": "91606fa2-deec-485a-815c-d8c7ba07dec3", 268 | "metadata": {}, 269 | "source": [ 270 | "#### Create a GIF of the interpolation of images" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "id": "91d9aeb7-edc3-492d-a6f1-fd5309e2ab40", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "img_paths = [f\"./animate/chip_{idx}.png\" for idx in range(20)]\n", 281 | "\n", 282 | "with 
imageio.get_writer(\"animate/sample.gif\", mode=\"I\", duration=100) as writer:\n", 283 | " for img_path in img_paths:\n", 284 | " img = imageio.imread(img_path)\n", 285 | " writer.append_data(img)\n", 286 | "\n", 287 | "# Delete the images\n", 288 | "for img_path in img_paths:\n", 289 | " os.remove(img_path)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "0e458cd3-d1a8-41b2-9e85-7612c153a7ea", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "from IPython.display import Image, display\n", 300 | "\n", 301 | "display(Image(filename=\"./animate/sample.gif\"))" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "730bbb55-344d-49e4-a81a-59a9da014770", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3 (ipykernel)", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | "codemirror_mode": { 321 | "name": "ipython", 322 | "version": 3 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython3", 329 | "version": "3.11.7" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 5 334 | } 335 | -------------------------------------------------------------------------------- /docs/clay-v0/data_datacube.md: -------------------------------------------------------------------------------- 1 | # Creating datacubes 2 | 3 | ## How to create a datacube 4 | 5 | The `datacube.py` script collects Sentinel-2, Sentinel-1, and DEM data over individual MGRS tiles. The source list of the MGRS tiles to be processed is provided in an input file with MGRS geometries. Each run of the script will collect data for one of the MGRS tiles in the source file. The tile to be processed is based on the row index number provided as input. The MGRS tile ID is expected to be in the `name` property of the input file. 6 | 7 | For the target MGRS tile, the script loops through the years between 2017 and 2023 in random order. For each year, it will search for the least cloudy Sentinel-2 scene. Based on the date of the selected Sentinel-2 scene, it will search for the Sentinel-1 scenes that are the closest match to that date, with a maximum of +/- 3 days of difference. It will include multiple Sentinel-1 scenes until the full MGRS tile is covered. If no matching Sentinel-1 scenes can be found, the script moves to the next year. The script stops when 3 matching datasets have been collected for 3 different years. Finally, the script will also select the intersecting part of the Copernicus Digital Elevation Model (DEM). 8 | 9 | The script will then download the Sentinel-2 scene and match the data cube with the corresponding Sentinel-1 and DEM data. The scene-level data is then split into smaller chips of a fixed size of 512x512 pixels. The Sentinel-2, Sentinel-1 and DEM bands are then packed together in a single TIFF file for each chip. These are saved locally and synced to a S3 bucket at the end of the script. The bucket name can be specified as input. 10 | 11 | For testing and debugging, the data size can be reduced by specifying a pixel window using the `subset` parameter. Data will then be requested only for the specified pixel window. This will reduce the data size considerably which speeds up the processing during testing. 
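To make the Sentinel-1/Sentinel-2 matching step above concrete, here is a minimal sketch of the idea using `pystac_client`: pick the least cloudy Sentinel-2 item for a given year, then search for Sentinel-1 acquired within +/- 3 days of it. The STAC endpoint, collection names, and `bbox` below are illustrative assumptions for the example and are not taken from `datacube.py` itself.

```python
# Minimal sketch (not the production pipeline): least-cloudy Sentinel-2 scene
# for one year, then Sentinel-1 items within +/- 3 days of its acquisition date.
from datetime import timedelta

import pystac_client

STAC_API = "https://earth-search.aws.element84.com/v1"  # assumed endpoint
bbox = [7.5, 46.0, 8.5, 47.0]  # placeholder bounds of one MGRS tile (lon/lat)

catalog = pystac_client.Client.open(STAC_API)

# Least cloudy Sentinel-2 L2A scene for the year
s2_items = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=bbox,
    datetime="2022-01-01/2022-12-31",
).item_collection()
s2_best = min(s2_items, key=lambda item: item.properties.get("eo:cloud_cover", 100))

# Sentinel-1 scenes within +/- 3 days of the selected Sentinel-2 date
start = (s2_best.datetime - timedelta(days=3)).strftime("%Y-%m-%d")
end = (s2_best.datetime + timedelta(days=3)).strftime("%Y-%m-%d")
s1_items = catalog.search(
    collections=["sentinel-1-grd"],  # assumed collection name
    bbox=bbox,
    datetime=f"{start}/{end}",
).item_collection()

print(s2_best.id, [item.id for item in s1_items])
```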
12 | 13 | The example run below will search for data for the geometry with row index 1 in a local MGRS sample file for a 1000x1000 pixel window. 14 | 15 | ```bash 16 | python datacube.py --sample /home/user/Desktop/mgrs_sample.fgb --bucket "my-bucket" --subset "1000,1000,2000,2000" --index 1 17 | ``` 18 | 19 | ## Running the datacube pipeline as a batch job 20 | 21 | This section describes how to containerize the data pipeline and run it on AWS Batch Spot instances using 22 | a [fetch-and-run](https://aws.amazon.com/blogs/compute/creating-a-simple-fetch-and-run-aws-batch-job/) 23 | approach. 24 | 25 | ### Prepare docker image in ECR 26 | 27 | Build the docker image and push it to a ecr repository. 28 | 29 | ```bash 30 | ecr_repo_id=12345 31 | cd scripts/pipeline/batch 32 | docker build -t $ecr_repo_id.dkr.ecr.us-east-1.amazonaws.com/fetch-and-run . 33 | 34 | aws ecr get-login-password --profile clay --region us-east-1 | docker login --username AWS --password-stdin $ecr_repo_id.dkr.ecr.us-east-1.amazonaws.com 35 | 36 | docker push $ecr_repo_id.dkr.ecr.us-east-1.amazonaws.com/fetch-and-run 37 | ``` 38 | 39 | ### Prepare AWS batch 40 | 41 | To prepare a batch, we need to create a compute environment, job queue, and job 42 | definition. 43 | 44 | Example configurations for the compute environment and the job definition are 45 | provided in the `batch` directory. 46 | 47 | The `submit.py` script contains a loop for submitting jobs to the queue. An 48 | alternative to these individual job submissions would be to use array jobs, but 49 | for now the individual submissions are simpler and failures are easier to track. 50 | 51 | ### Create ZIP file with the package to execute 52 | 53 | Package the model and the inference script into a zip file. The `datacube.py` 54 | script is the one that will be executed on the instances. 55 | 56 | Put the scripts in a zip file and upload the zip package into S3 so that 57 | the batch fetch-and-run can use it. 58 | 59 | ```bash 60 | zip -FSrj "batch-fetch-and-run.zip" ./scripts/pipeline* -x "scripts/pipeline*.pyc" 61 | 62 | aws s3api put-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run.zip" --body "batch-fetch-and-run.zip" 63 | ``` 64 | 65 | ### Submit job 66 | 67 | We can now submit a batch job to run the pipeline. The `submit.py` file 68 | provides an example on how to submit jobs in python. 69 | -------------------------------------------------------------------------------- /docs/clay-v0/data_labels.md: -------------------------------------------------------------------------------- 1 | # Benchmark dataset labels 2 | 3 | A benchmark dataset is a collection of data used for evaluating the performance 4 | of algorithms, models, or systems in a specific field of study. These datasets 5 | are crucial for providing common ground for comparing different approaches, 6 | allowing researchers to assess the strengths and weaknesses of various methods. 7 | For Clay, we evaluate our model on benchmark datasets that have suitable downstream 8 | tasks. 9 | 10 | For our initial benchmark dataset, we've implemented the 11 | [Cloud to Street - Microsoft flood dataset](https://beta.source.coop/repositories/c2sms/c2smsfloods/description). 12 | It is what we will use in our initial linear probing experiments and 13 | evaluation of finetuning on a downstream task. The task itself is 14 | [segmentation](https://paperswithcode.com/task/semantic-segmentation) of water 15 | pixels associated with recorded flood events. 
16 | 17 | The original dataset consists of two out of three of our Foundation model's datacube inputs 18 | (Sentinel-1 and Sentinel-2) along with raster water mask labels for both 19 | sensors. Each image is 512x512 pixels. The 20 | original Sentinel-2 images are L1C, which is Top-of-Atmosphere reflectance. We train 21 | Clay with surface reflectance, however, so we ultimately used the geospatial bounds 22 | from the GeoTIFF and image timestamp (from the granule name) to query 23 | [Microsoft Planetary Computer's STAC API for L2A (Bottom-of-Atmosphere a.k.a. "surface reflectance") Sentinel-2](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a) 24 | scenes in the same time and space, with the same channels expected by Clay. We 25 | then followed the same `datacube` creation logic to generate datacubes with 26 | Sentinel-1 VV and VH and the Copernicus Digital Elevation Model (DEM). We also 27 | ensured that the Sentinel-1 data was within a +/- 3 day interval of each 28 | reference Sentinel-2 scene (same method used by the benchmark dataset authors) 29 | and that the Sentinel-1 data was indeed already included in the bechmark 30 | dataset's list of granules. The datacubes generated have all three inputs 31 | matching the exact specs of the Foundation model's training data, at 512x512 32 | pixels. 33 | 34 | Here is an example of a datacube we generated for the dataset: 35 | 36 | ![datacube](https://github.com/Clay-foundation/model/assets/23487320/94dffcf5-4075-4c17-ac96-01c11bcb299b) 37 | 38 | The images, left to right, show a true-color representation of the Sentinel-2 39 | scene, the Sentinel-1 VH polarization, and the Digital Elevation Model. 40 | 41 | ![gt](https://github.com/Clay-foundation/model/assets/23487320/4ac92af7-6931-4249-a920-7d29453b9b31) 42 | 43 | Here we have something similar, but this time just the Sentinel-1 and 44 | Sentinel-2 scenes with the Sentinel-1 water mask (ground truth) overlaid. 45 | 46 | Last note on this benchmark dataset that we've adapted for Clay: we made sure 47 | to preserve the metadata for timestamp and geospatial coordinates in the 48 | datacube such that we can embed information in the way that the Clay Foundation 49 | model expects. We also preserve the flood event information for analysis 50 | during finetuning. 51 | 52 | The script for generating these datacubes is at 53 | https://github.com/Clay-foundation/model/blob/c2smsfloods_benchmark_datapipeline/scripts/datacube_benchmark.py. 54 | You'll need an AWS account and Microsoft Planetary Computer API Key to run 55 | this. The data is queried from Microsoft Planetary Computer STAC APIs, read and 56 | processed in memory, and the datacubes are written directly to AWS S3. 57 | -------------------------------------------------------------------------------- /docs/clay-v0/model_embeddings.md: -------------------------------------------------------------------------------- 1 | # Generating vector embeddings 2 | 3 | Once you have a pretrained model, it is possible to pass some input images 4 | into the encoder part of the Vision Transformer and produce vector embeddings 5 | which contain a semantic representation of the image. 6 | 7 | ## Producing embeddings from the pretrained model 8 | 9 | Step-by-step instructions to create embeddings for a single MGRS tile location 10 | (e.g. 27WXN): 11 | 12 | 1. Ensure that you can access the 13-band GeoTIFF data files. 
13 | 14 | ``` 15 | aws s3 ls s3://clay-tiles-02/02/27WXN/ 16 | ``` 17 | 18 | This should report a list of filepaths if you have the correct permissions. 19 | Otherwise, please set up authentication before continuing. 20 | 21 | 2. Download the pretrained model weights and put them in the `checkpoints/` 22 | folder: 23 | 24 | ```bash 25 | aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/ 26 | ``` 27 | 28 | ```{tip} 29 | For running model inference on a large scale (hundreds or thousands of MGRS 30 | tiles), it is recommended to have a cloud VM instance with: 31 | 32 | 1. A high bandwidth network (>25Gbps) to speed up data transfer from the S3 33 | bucket to the compute device. 34 | 2. An NVIDIA Ampere generation GPU (e.g. A10G) or newer, which would allow 35 | for efficient bfloat16 dtype calculations. 36 | 37 | For example, an AWS g5.4xlarge instance would be a cost effective option. 38 | ``` 39 | 40 | 3. Run model inference to generate the embeddings: 41 | 42 | ```bash 43 | python trainer.py predict --ckpt_path=checkpoints/clay-small-70MT-1100T-10E.ckpt \ 44 | --trainer.precision=bf16-mixed \ 45 | --data.data_dir=s3://clay-tiles-02/02/27WXN \ 46 | --data.batch_size=32 \ 47 | --data.num_workers=16 \ 48 | --model.embeddings_level=group 49 | ``` 50 | 51 | This should output a GeoParquet file containing the embeddings for MGRS tile 52 | 27WXN (recall that each 10000x10000 pixel MGRS tile contains hundreds of 53 | smaller 512x512 chips), saved to the `data/embeddings/` folder. See the next 54 | subsection for details about the embeddings file. 55 | 56 | The `embeddings_level` flag determines how the embeddings are calculated. 57 | The default is `mean`, resulting in one average embedding per MGRS tile of 58 | size 768. If set to `patch`, the embeddings will be kept at the patch level. 59 | The embedding array will be of size 16 * 16 * 768, representing one 60 | embedding per patch. The third option `group` will keep the full 61 | dimensionality of the encoder output, including the band group 62 | dimension. The array size of those embeddings is 6 * 16 * 16 * 768. 63 | 64 | The embeddings are flattened into one-dimensional arrays because pandas 65 | does not allow for multidimensional arrays. This makes it necessary to 66 | reshape the flattened arrays to access the patch-level embeddings. 67 | 68 | ```{note} 69 | For those interested in how the embeddings were computed, the predict step 70 | above does the following: 71 | 72 | 1. Pass the 13-band GeoTIFF input into the Vision Transformer's encoder, to 73 | produce raw embeddings of shape (B, 1538, 768), where B is the batch_size, 74 | 1538 is the patch dimension and 768 is the embedding length. The patch 75 | dimension itself is a concatenation of 1536 (6 band groups x 16x16 76 | spatial patches of size 32x32 pixels each in a 512x512 image) + 2 (latlon 77 | embedding and time embedding) = 1538. 78 | 2. By default, the mean or average is taken across the 1536 patch dimension, 79 | yielding an output embedding of shape (B, 768). If patch embeddings are 80 | requested, the shape is (B, 16 * 16 * 768), one embedding per patch. 81 | 82 | More details of how this is implemented can be found by inspecting the 83 | `predict_step` method in the `model_clay.py` file. 84 | ``` 85 | 86 | 87 | ## Format of the embeddings file 88 | 89 | The vector embeddings are stored in a single column within a 90 | [GeoParquet](https://geoparquet.org) file (*.gpq), with other columns 91 | containing spatiotemporal metadata. 
This file format is built on top of the 92 | popular Apache Parquet columnar storage format designed for fast analytics, 93 | and it is highly interoperable across different tools like QGIS, 94 | GeoPandas (Python), sfarrow (R), and more. 95 | 96 | ### Filename convention 97 | 98 | The embeddings file utilizes the following naming convention: 99 | 100 | ``` 101 | {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq 102 | ``` 103 | 104 | Example: `27WXN_20200101_20231231_v001.gpq` 105 | 106 | | Variable | Description | 107 | |--|--| 108 | | MGRS | The spatial location of the file's contents in the [Military Grid Reference System (MGRS)](https://en.wikipedia.org/wiki/Military_Grid_Reference_System), given as a 5-character string | 109 | | MINDATE | The minimum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format | 110 | | MAXDATE | The maximum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format | 111 | | VERSION | Version of the generated embeddings, given as a 3-digit number | 112 | 113 | 114 | ### Table schema 115 | 116 | Each row within the GeoParquet table is generated from a 512x512 pixel image 117 | and contains a record of the embeddings, spatiotemporal metadata, and a link to 118 | the GeoTIFF file used as the source image for the embedding. The table looks 119 | something like this: 120 | 121 | Embedding size is 768 by default, 16 * 16 * 768 for patch level embeddings, and 122 | 6 * 16 * 16 * 768 for group level embeddings. 123 | 124 | 125 | | source_url | date | embeddings | geometry | 126 | |-----------------------------|------------|------------------|--------------| 127 | | s3://.../.../claytile_*.tif | 2021-01-01 | [0.1, 0.4, ... ] | POLYGON(...) | 128 | | s3://.../.../claytile_*.tif | 2021-06-30 | [0.2, 0.5, ... ] | POLYGON(...) | 129 | | s3://.../.../claytile_*.tif | 2021-12-31 | [0.3, 0.6, ... ] | POLYGON(...) | 130 | 131 | Details of each column are as follows: 132 | 133 | - `source_url` ([string](https://arrow.apache.org/docs/python/generated/pyarrow.string.html)) - The full URL to the 13-band GeoTIFF image the embeddings were derived from. 134 | - `date` ([date32](https://arrow.apache.org/docs/python/generated/pyarrow.date32.html)) - Acquisition date of the Sentinel-2 image used to generate the embeddings, in YYYY-MM-DD format. 135 | - `embeddings` ([FixedShapeTensorArray](https://arrow.apache.org/docs/python/generated/pyarrow.FixedShapeTensorArray.html)) - The vector embeddings given as a 1-D tensor or list with a length of 768. 136 | - `geometry` ([binary](https://arrow.apache.org/docs/python/generated/pyarrow.binary.html)) - The spatial bounding box of the 13-band image, provided in a [WKB](https://en.wikipedia.org/wiki/Well-known_text_representation_of_geometry#Well-known_binary) Polygon representation. 137 | 138 | 139 | ```{note} 140 | Additional technical details of the GeoParquet file: 141 | - GeoParquet specification [v1.0.0](https://geoparquet.org/releases/v1.0.0) 142 | - Coordinate reference system of the geometries is `OGC:CRS84`.
143 | ``` 144 | 145 | ## Reading the embeddings 146 | 147 | Sample code to read the GeoParquet embeddings file using 148 | [`geopandas.read_parquet`](https://geopandas.org/en/stable/docs/reference/api/geopandas.read_parquet.html) 149 | 150 | ```{code} 151 | import geopandas as gpd 152 | 153 | gpq_file = "data/embeddings/27WXN_20200101_20231231_v001.gpq" 154 | geodataframe = gpd.read_parquet(path=gpq_file) 155 | print(geodataframe) 156 | ``` 157 | 158 | ```{seealso} 159 | Further reading: 160 | - https://guide.cloudnativegeo.org/geoparquet 161 | - https://cloudnativegeo.org/blog/2023/10/the-geoparquet-ecosystem-at-1.0.0 162 | ``` 163 | 164 | ## Converting to patch-level embeddings 165 | 166 | In the case where patch-level embeddings are requested, the resulting array 167 | will have all patch embeddings ravelled in one row. Each row represents a 168 | 512x512 pixel image, and contains 16x16 patch embeddings. 169 | 170 | To convert each row into patch level embeddings, the embedding array has to 171 | be unravelled into 256 patches like so 172 | 173 | ```{code} 174 | # This assumes embeddings levels set to "patch" 175 | ravelled_patch_embeddings = geodataframe.embeddings[0] 176 | patch_embeddings = ravelled_patch_embeddings.reshape(16, 16, 768) 177 | ``` 178 | -------------------------------------------------------------------------------- /docs/clay-v0/model_finetuning.md: -------------------------------------------------------------------------------- 1 | # Finetuning 2 | 3 | Fine-tuning refers to a process in machine learning where a pre-trained model 4 | is further trained on a specific dataset to adapt its parameters to a 5 | downstream task characterized by a relevent domain. It's distinct from training 6 | a model from scratch using the downstream task dataset exclusively. 7 | 8 | Related to finetuning in the field of training Foundation models is linear 9 | probing, which refers to a technique used to analyze or explore the 10 | representations learned by a Foundation model as it trains. When a large-scale 11 | model (like a vision transformer model) is pre-trained on a vast corpus of 12 | data, it learns rich and complex representations of patterns within the data. 13 | Linear probing involves examining or probing these learned representations by 14 | periodically (e.g. every few epochs of the Foundation model's training cycle) 15 | finetuning a small downstream task on top of the pre-trained model's layers or 16 | embeddings. 17 | 18 | We use full finetuning and linear probing in Clay to evaluate the usefulness of 19 | the Foundation model both during its pre-training and afterwards. 20 | 21 | Let's take a look at how we are finetuning on the benchmark datacube-adapted 22 | [Cloud to Street - Microsoft Flood Dataset](https://beta.source.coop/repositories/c2sms/c2smsfloods). 23 | As a reminder, that is a downstream 24 | segmentation task for identifiying water pixels in recorded flood events. It's 25 | a binary segmentation problem, specifically. 26 | 27 | We process the datacubes into batches formatted in the way the pretrained Clay 28 | model expects, with the addition of information for label images as well. 
29 | Here's an example subset of a batch dictionary: 30 | 31 | ``` 32 | {'labels': tensor([[[[0., 0., 0., ..., 0., 0., 0.], 33 | [0., 0., 0., ..., 0., 0., 0.], 34 | [0., 0., 0., ..., 0., 0., 0.], 35 | ..., 36 | [0., 0., 0., ..., 0., 0., 0.], 37 | [0., 0., 0., ..., 0., 0., 0.], 38 | [0., 0., 0., ..., 0., 0., 0.]]]), 39 | 'pixels': tensor([[[[-0.5994, -0.6108, -0.6034, ..., -0.5610, -0.5590, -0.5614], 40 | [-0.5767, -0.5950, -0.6004, ..., -0.5619, -0.5536, -0.5610], 41 | [-0.5841, -0.5762, -0.5930, ..., -0.5491, -0.5304, -0.5373], 42 | ..., 43 | [-0.5087, -0.5447, -0.4351, ..., -0.6162, -0.6083, -0.6044], 44 | [-0.4184, -0.5432, -0.5003, ..., -0.6108, -0.6128, -0.6073], 45 | [-0.2496, -0.5348, -0.5225, ..., -0.6137, -0.6167, -0.6128]], 46 | 47 | [[-0.6371, -0.6435, -0.6425, ..., -0.5834, -0.5898, -0.5923], 48 | [-0.6296, -0.6410, -0.6385, ..., -0.5794, -0.5983, -0.5958], 49 | [-0.6167, -0.6177, -0.6182, ..., -0.5545, -0.5913, -0.5834], 50 | ..., 51 | [-0.4800, -0.5153, -0.4308, ..., -0.6525, -0.6410, -0.6331], 52 | [-0.4104, -0.5034, -0.4318, ..., -0.6331, -0.6226, -0.6087], 53 | [-0.2404, -0.5222, -0.4522, ..., -0.6231, -0.6241, -0.6177]], 54 | 55 | [[-0.7068, -0.7217, -0.7101, ..., -0.6118, -0.6178, -0.6290], 56 | [-0.7087, -0.7022, -0.6924, ..., -0.6141, -0.6146, -0.6234], 57 | [-0.7017, -0.6998, -0.6831, ..., -0.5927, -0.6085, -0.6104], 58 | ..., 59 | [-0.5563, -0.5480, -0.4571, ..., -0.7106, -0.7045, -0.6933], 60 | [-0.4725, -0.5526, -0.4781, ..., -0.6975, -0.6789, -0.6807], 61 | [-0.3117, -0.4995, -0.5000, ..., -0.6952, -0.6835, -0.6845]], 62 | 63 | ..., 64 | ]), 65 | 'bbox': tensor([[ 661415., 5369305., 666535., 5374425.]], dtype=torch.float64), 66 | 'epsg': tensor([32633], dtype=torch.int32), 67 | 'date': ['2020-10-20'], 68 | 'latlon': tensor([[-0.8192, -0.7854]]), 69 | 'timestep': tensor([[-1.2217, 2.7132, -2.4086]]), 70 | 'source_url': ['S2A_L2A_20201022T100051_N0209_R122_T33UXP_20201022T111023_06144-02560_S1B_IW_GRDH_1SDV_20201020T164222_20201020T164247_023899_02D6C4_rtc']} 71 | ``` 72 | 73 | Batches of dictionaries like this run through the Clay model's encoder to 74 | generate embeddings, such as this: 75 | 76 | ![embedding_ex](https://github.com/Clay-foundation/model/assets/23487320/375c9e83-d539-4730-b923-3b0b61ea689c) 77 | 78 | from batches with image bands such as: 79 | 80 | ![band_red_ex](https://github.com/Clay-foundation/model/assets/23487320/0c254dbf-9589-4fbf-ab32-e3774fbd2f1a) 81 | 82 | and labels: 83 | 84 | ![labels_ex](https://github.com/Clay-foundation/model/assets/23487320/a92eb8e7-9268-46e5-a254-132205cbc498) 85 | 86 | These embeddings are reshaped from shape 87 | `batch size * (band groups length * number of patches) * embedding size` to 88 | `batch size * (band groups length * embedding size) * patch height * patch width` 89 | before being passed to a series of 2D convolutional transpose and ReLU layers 90 | in a downstream decoder network. 91 | 92 | That decoder network is the core of the downstream task. In a forward pass, it 93 | ingests the embeddings, runs them through those layers and computes a loss 94 | value with respect to the labels. The loss is back-propagated and the decoder 95 | gradually finetunes itself to the downstream dataset. 
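As a concrete illustration of that reshape, here is a standalone sketch. It assumes the v0 setup of 6 band groups, a 16x16 grid of patches and 768-dimensional embeddings; the exact ordering of the flattened patch dimension is an assumption made for this example.

```python
import torch
from einops import rearrange

B, G, H, W, E = 2, 6, 16, 16, 768  # batch, band groups, patch grid, embedding size

# Encoder output: batch size * (band groups * number of patches) * embedding size
embeddings = torch.randn(B, G * H * W, E)

# Decoder input: batch size * (band groups * embedding size) * patch height * patch width
feature_map = rearrange(embeddings, "b (g h w) e -> b (g e) h w", g=G, h=H, w=W)

print(feature_map.shape)  # torch.Size([2, 4608, 16, 16]), matching Conv2d(4608, 64, ...) below
```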
Here's a peek at the 96 | decoder layers: 97 | 98 | ``` 99 | Model( 100 | (decoder): Sequential( 101 | (0): Conv2d(4608, 64, kernel_size=(1, 1), stride=(1, 1)) 102 | (1): Upsample(scale_factor=2.0, mode='nearest') 103 | (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 104 | (3): ReLU(inplace=True) 105 | (4): Upsample(scale_factor=2.0, mode='nearest') 106 | (5): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 107 | (6): ReLU(inplace=True) 108 | (7): Upsample(scale_factor=2.0, mode='nearest') 109 | (8): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 110 | (9): ReLU(inplace=True) 111 | (10): Upsample(scale_factor=2.0, mode='nearest') 112 | (11): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 113 | (12): Upsample(scale_factor=2.0, mode='nearest') 114 | ) 115 | ) 116 | ``` 117 | 118 | Note the absence of an encoder. That is important as this is a finetuning 119 | architecture in which the encoder is replaced by the embeddings from the 120 | pre-trained Clay model. 121 | 122 | In comparison, the network we are using to train the downstream task from 123 | scratch looks notably different: 124 | 125 | ``` 126 | Model( 127 | (encoder): Sequential( 128 | (0): Conv2d(13, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 129 | (1): ReLU(inplace=True) 130 | (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 131 | (3): ReLU(inplace=True) 132 | (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 133 | ) 134 | (decoder): Sequential( 135 | (0): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 136 | (1): ReLU(inplace=True) 137 | (2): ConvTranspose2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 138 | (3): ReLU(inplace=True) 139 | (4): Upsample(scale_factor=2.0, mode='nearest') 140 | (5): Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 141 | ) 142 | ) 143 | ``` 144 | In this architecture, there is a defined encoder since the embeddings aren't 145 | doing the purpose of encoding latent information. 146 | 147 | For both the finetuning and "from scratch" architectures, we use a 148 | `binary_cross_entropy_with_logits` loss function as this is a binary 149 | segmentation problem, and on the predictions, we run sigmoid and max functions 150 | to obtain final segmentation results. 151 | 152 | The way we measure relative performance between the finetuned and 153 | "from scratch" model variants happens through calculation of evalution metrics 154 | common for segmentation, such as Dice coefficient, Intersection over Union, F1 155 | score, precision and recall. 156 | 157 | ### Linear probing 158 | 159 | For linear probing, we implement the finetuned architecture in a 160 | [PyTorch callback](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html) 161 | that will execute every `n` epochs during the Foundation model's training. 162 | -------------------------------------------------------------------------------- /docs/clay-v0/run_region.md: -------------------------------------------------------------------------------- 1 | # How to run clay over custom AOIs 2 | 3 | This section shows in a few simple steps how the clay model can be run for 4 | custom AOIs and over custom date ranges. 
5 | 6 | ## Prepare folder strucutre for data 7 | 8 | ```bash 9 | # Move into the model repository 10 | cd /path/to/repository/model/ 11 | 12 | # Ensure data sub-directories exist 13 | mkdir data/mgrs 14 | mkdir data/chips 15 | mkdir data/embeddings 16 | ``` 17 | 18 | ## Download global list of MGRS tiles 19 | 20 | The first step is to download a list of globally available MGRS tiles. A full 21 | list of MGRS tiles has been created as part of the landcover based sampling 22 | strategy. The file is [sourced from a complete MGRS tile list](https://github.com/Clay-foundation/model/blob/main/scripts/landcover.sh#L7), 23 | and then [itersected with the WorldCover landcover](https://github.com/Clay-foundation/model/blob/main/scripts/landcover.py) 24 | layer, outputting the `mgrs_full.fgb` file that is used below. 25 | 26 | ```bash 27 | wget https://clay-mgrs-samples.s3.amazonaws.com/mgrs_full.fgb -O data/mgrs/mgrs_full.fgb 28 | ``` 29 | 30 | ## Create a Geopandas dataframe with MGRS tiles over the AOI 31 | 32 | This example uses a bounding box over the area around Puri, India, to 33 | filter the global list of MGRS tiles. The intersected MGRS tiles are 34 | then stored into a new dataset with the reduced list. The reduced list 35 | will be used by the `datacube.py` script for creating imagery chips. 36 | 37 | ```python 38 | import geopandas as gpd 39 | import pandas as pd 40 | from shapely import box 41 | 42 | 43 | mgrs = gpd.read_file("data/mgrs/mgrs_full.fgb") 44 | print(f"Loaded {len(mgrs)} MGRS grid cells.") 45 | 46 | aoi = gpd.GeoDataFrame( 47 | pd.DataFrame(["Puri"], columns=["Region"]), 48 | crs="EPSG:4326", 49 | geometry=[box(85.0503, 19.4949, 86.1042, 20.5642)], 50 | ) 51 | mgrs_aoi = mgrs.overlay(aoi) 52 | 53 | # Rename the name column to use lowercase letters for the datacube script to 54 | # pick upthe MGRS tile name. 55 | mgrs_aoi = mgrs_aoi.rename(columns={"Name": "name"}) 56 | 57 | print(f"Found {len(mgrs_aoi)} matching MGRS tiles over the AOI.") 58 | 59 | mgrs_aoi.to_file("data/mgrs/mgrs_aoi.fgb") 60 | ``` 61 | 62 | ## Use the datacube.py script to download imagery 63 | 64 | This will select the MGRS tiles that intersect with your AOI. The processing 65 | will then happen for each of the MGRS tiles. This will most likely provide 66 | slightly more data than the AOI itself, as the whole tile data will be downloaded 67 | for each matched MGRS tile. 68 | 69 | Each run of the datacube script will take an index as input, which is the index 70 | of the MGRS tile within the input file. This is why we need to download the 71 | data in a loop. 72 | 73 | A list of date ranges can be specified. The script will look for the least 74 | cloudy Sentinel-2 scene for each date range and match Sentinel-1 dates near 75 | the identified Sentinel-2 dates. 76 | 77 | The output folder can be specified as a local folder or a bucket can be 78 | specified if you want to upload the data to S3. 79 | 80 | Note that for the script to run, a Microsoft Planetary Computer token needs 81 | to be set up. Consult the [Planetary Computer SDK](https://github.com/microsoft/planetary-computer-sdk-for-python) 82 | documentation on how to set up the token. 83 | 84 | By default, the datacube script will download all the data available for each 85 | MGRS tile it processes, so the output might include imagery chips that are 86 | outside of the AOI specified. 87 | 88 | To speed up processing in the example below, we use the subset argument to 89 | reduce each MGRS tile to a small pixel window. 
When subsetting, the script 90 | will only download a fraction of each MGRS tile. This will lead to discontinous 91 | datasets and should not be used in a real use case. Remove the subset argument 92 | when using the script for a real world application, where all the data should 93 | be downloaded for each MGRS tile. 94 | 95 | ```bash 96 | for i in {0..5}; do 97 | 98 | python scripts/pipeline/datacube.py \ 99 | --sample data/mgrs/mgrs_aoi.fgb \ 100 | --localpath data/chips \ 101 | --index $i \ 102 | --dateranges 2020-01-01/2020-04-01,2021-06-01/2021-09-15 \ 103 | --subset 1500,1500,2524,2524; 104 | 105 | done 106 | ``` 107 | 108 | ## Create the embeddings for each training chip 109 | 110 | The checkpoints can be accessed directly from Hugging Face 111 | at https://huggingface.co/made-with-clay/Clay. 112 | 113 | The following command will run the model to create the embeddings 114 | and automatically download and cache the model weights. 115 | 116 | ```bash 117 | wandb disabled 118 | python trainer.py predict \ 119 | --ckpt_path=https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt \ 120 | --trainer.precision=16-mixed \ 121 | --data.data_dir=/home/tam/Desktop/aoitiles \ 122 | --data.batch_size=2 \ 123 | --data.num_workers=8 124 | ``` 125 | -------------------------------------------------------------------------------- /docs/clay-v0/worldcover-embeddings.md: -------------------------------------------------------------------------------- 1 | # Running embeddings for Worldcover Sentinel-2 Composites 2 | This package is made to generate embeddings from the [ESA Worldcover](https://esa-worldcover.org/en/data-access) 3 | Sentinel-2 annual composites. The target region is all of the 4 | Contiguous United States. 5 | 6 | We ran this script for 2020 and 2021. 7 | 8 | ## The algorithm 9 | 10 | The `run.py` script will run through a column of image chips of 512x512 pixels. 11 | Each run is a column that spans the Contiguous United States from north to 12 | south. For each chip in that column, embeddings are generated and stored 13 | together in one geoparquet file. These files are then uploaded to the 14 | `clay-worldcover-embeddings` bucket on S3. 15 | 16 | There are 1359 such columns to process in order to cover all of the Conus US. 17 | 18 | The embeddings are stored alongside with the bbox of the data chip used for 19 | generating the embedding. To visualize the underlying data or an embedding 20 | the WMS and WMTS endpoints provided by the ESA Worldcover project can be used. 21 | 22 | So the geoparquet files only have the following two columns 23 | 24 | | embeddings | bbox | 25 | |------------------|--------------| 26 | | [0.1, 0.4, ... ] | POLYGON(...) | 27 | | [0.2, 0.5, ... ] | POLYGON(...) | 28 | | [0.3, 0.6, ... ] | POLYGON(...) | 29 | 30 | ## Exploring results 31 | 32 | The `embeddings_db.py` script provides a way to locally explore the embeddings. 33 | It will create a `lancedb` database and allow for search. The search results are 34 | visualizded by requesting the RGB image from the WMS endpoint for the bbox of 35 | each search result. 36 | 37 | ## Running on Batch 38 | 39 | ### Upload package to fetch and run bucket 40 | This snippet will create the zip package that is used for the fetch-and-run 41 | instance in our ECR registry. 
42 | 43 | ```bash 44 | # Add clay src and scripts to zip file 45 | zip -FSr batch-fetch-and-run-wc.zip src scripts -x *.pyc -x scripts/worldcover/wandb/**\* 46 | 47 | # Add run to home dir, so that fetch-and-run can see it. 48 | zip -uj batch-fetch-and-run-wc.zip scripts/worldcover/run.py 49 | 50 | # Upload fetch-and-run package to S3 51 | aws s3api put-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run-wc.zip" --body "batch-fetch-and-run-wc.zip" 52 | ``` 53 | 54 | ### Push array job 55 | This command will send the array job to AWS batch to run all of the 56 | 1359 jobs to cover the US. 57 | 58 | ```python 59 | import boto3 60 | 61 | batch = boto3.client("batch", region_name="us-east-1") 62 | year = 2020 63 | job = { 64 | "jobName": f"worldcover-conus-{year}", 65 | "jobQueue": "fetch-and-run", 66 | "jobDefinition": "fetch-and-run", 67 | "containerOverrides": { 68 | "command": ["run.py"], 69 | "environment": [ 70 | {"name": "BATCH_FILE_TYPE", "value": "zip"}, 71 | { 72 | "name": "BATCH_FILE_S3_URL", 73 | "value": "s3://clay-fetch-and-run-packages/batch-fetch-and-run-wc.zip", 74 | }, 75 | {"name": "YEAR", "value": f"{year}"} 76 | ], 77 | "resourceRequirements": [ 78 | {"type": "MEMORY", "value": "7500"}, 79 | {"type": "VCPU", "value": "4"}, 80 | # {"type": "GPU", "value": "1"}, 81 | ], 82 | }, 83 | "arrayProperties": { 84 | "size": int((125 - 67) * 12000 / 512) 85 | }, 86 | "retryStrategy": { 87 | "attempts": 5, 88 | "evaluateOnExit": [ 89 | {"onStatusReason": "Host EC2*", "action": "RETRY"}, 90 | {"onReason": "*", "action": "EXIT"} 91 | ] 92 | }, 93 | } 94 | 95 | print(batch.submit_job(**job)) 96 | ``` 97 | -------------------------------------------------------------------------------- /docs/finetune/classify.md: -------------------------------------------------------------------------------- 1 | # Classification head fine tuning 2 | 3 | We have built an example for training a classification head on top of 4 | the class token embeddings from the frozen Clay encoder. 5 | 6 | All the code for this example can be found in the 7 | [classify finetuning folder](https://github.com/Clay-foundation/model/blob/main/finetune/classify) 8 | of this repository. 9 | 10 | ## Classifier 11 | 12 | The `Classifier` class is designed for classification tasks, utilizing the Clay Encoder for feature extraction and adding a classification head on top of it. 13 | 14 | ### Parameters 15 | 16 | - `num_classes (int, optional)`: The number of classes for classification. Defaults to 10. 17 | - `ckpt_path (str, optional)`: Path to the Clay MAE pretrained model checkpoint. Defaults to None. 18 | 19 | ### Example 20 | 21 | In this example, we will use the `Classifier` class to classify images from the [EuroSAT MS dataset](https://github.com/phelber/EuroSAT). The implementation includes data preprocessing, data loading, and model training workflow using [PyTorch Lightning](https://lightning.ai/) & [TorchGeo](https://github.com/microsoft/torchgeo). 22 | 23 | In this example we freeze the Clay encoder and only train a very simple 2 layer MLP head for classification. The MLP head recieves as input the Clay class token embedding, which already contains the essence of the image as seen by Clay. The model for classification can then be kept very simple while still guaranteeing high quality results. 24 | 25 | Notice that the EuroSAT dataset comes without date stamps or location information. The Clay model requires encoded versions of a date stamp and a latitude and longitude information. 
These values can be set to zero if they are not available, which is what we are doing in the datamodule script. 26 | 27 | ## Dataset citation 28 | 29 | If you have used the EuroSAT dataset, please cite the following papers: 30 | 31 | [1] Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification. Patrick Helber, Benjamin Bischke, Andreas Dengel, Damian Borth. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 2019. 32 | 33 | [2] Introducing EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification. Patrick Helber, Benjamin Bischke, Andreas Dengel. 2018 IEEE International Geoscience and Remote Sensing Symposium, 2018. 34 | 35 | Dataset URL: [EuroSAT Dataset](https://madm.dfki.de/files/sentinel/EuroSATallBands.zip) 36 | 37 | ## Setup 38 | 39 | Follow the instructions in the [README](../../README.md) to install the required dependencies. 40 | 41 | ```bash 42 | git clone 43 | cd model 44 | mamba env create --file environment.yml 45 | mamba activate claymodel 46 | ``` 47 | 48 | ## Usage 49 | 50 | ### Downloading the Dataset 51 | 52 | TorchGeo will automatically download the EuroSAT dataset when you run the `classify.py` script. 53 | 54 | Directory structure: 55 | ``` 56 | data/ds 57 | └── images 58 | └── remote_sensing 59 | └── otherDatasets 60 | └── sentinel_2 61 | └── tif 62 | ├── AnnualCrop 63 | ├── Forest 64 | ├── HerbaceousVegetation 65 | ├── Highway 66 | ├── Industrial 67 | ├── Pasture 68 | ├── PermanentCrop 69 | ├── Residential 70 | ├── River 71 | └── SeaLake 72 | ``` 73 | 74 | 75 | ### Training the Classifcation Head 76 | 77 | The model can be run via LightningCLI using configurations in `finetune/classify/configs/classify_eurosat.yaml`. 78 | 79 | 1. Download the Clay model checkpoint from [Huggingface model hub](https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt) and save it in the `checkpoints/` directory. 80 | 81 | 2. Modify the batch size, learning rate, and other hyperparameters in the configuration file as needed: 82 | ```yaml 83 | data: 84 | batch_size: 128 85 | num_workers: 8 86 | model: 87 | num_classes: 10 88 | ckpt_path: checkpoints/clay-v1.5.ckpt 89 | lr: 1e-4 90 | wd: 0.05 91 | b1: 0.9 92 | b2: 0.95 93 | ``` 94 | 95 | 3. Update the [WandB logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) configuration in the configuration file with your WandB details or use [CSV Logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.CSVLogger.html#lightning.pytorch.loggers.CSVLogger) if you don't want to log to WandB: 96 | ```yaml 97 | logger: 98 | - class_path: lightning.pytorch.loggers.WandbLogger 99 | init_args: 100 | entity: 101 | project: 102 | group: 103 | log_model: false 104 | ``` 105 | 106 | 4. Train the classification head: 107 | 108 | To ensure that the imports work properly, ensure that the root of 109 | the repository is in the python path before running the script. 110 | 111 | ```bash 112 | python -m finetune.classify.classify fit --config configs/classify_eurosat.yaml 113 | ``` 114 | 115 | ## Acknowledgments 116 | 117 | This implementation uses the TorchGeo package for dataset handling and the EuroSAT dataset for training and evaluation. Special thanks to the contributors of [TorchGeo](https://github.com/microsoft/torchgeo) and [EuroSAT](https://github.com/phelber/EuroSAT). 
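To make the frozen-encoder-plus-MLP-head setup described in this guide more concrete, here is a minimal, self-contained sketch of a two-layer classification head operating on class token embeddings. The embedding size of 1024 and the hidden size of 256 are illustrative assumptions only; the actual implementation lives in the classify finetuning folder linked at the top of this page.

```python
import torch
from torch import nn

# Minimal sketch of a 2-layer MLP head on top of the (frozen) Clay class token
# embedding. Embedding and hidden sizes are assumptions for illustration.
class ClassificationHead(nn.Module):
    def __init__(self, embedding_dim: int = 1024, num_classes: int = 10):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, class_token: torch.Tensor) -> torch.Tensor:
        # class_token: (batch, embedding_dim) produced by the frozen encoder
        return self.mlp(class_token)

head = ClassificationHead()
logits = head(torch.randn(8, 1024))  # a batch of 8 class token embeddings
print(logits.shape)  # torch.Size([8, 10]), one logit per EuroSAT class
```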
118 | -------------------------------------------------------------------------------- /docs/finetune/regression.md: -------------------------------------------------------------------------------- 1 | # Regression head fine tuning 2 | 3 | We have built an example for training a regression head on top of 4 | feature maps extracted from the frozen Clay encoder. 5 | 6 | All the code for this example can be found in the 7 | [regression finetuning folder](https://github.com/Clay-foundation/model/blob/main/finetune/regression) 8 | of this repository. 9 | 10 | ## Download data 11 | 12 | The data comes as multifile zip, it can be downloaded from the 13 | [BioMassters](https://huggingface.co/datasets/nascetti-a/BioMassters/) 14 | huggingface repository. Grab a coffee, this is about 250GB in size. 15 | 16 | The next step is to unzip training data. The data comes in a multi-file 17 | zip archive. So it needs to be unzipped using a library that can handle 18 | the format. 7z works quite well in this case. Grabb another coffee, this 19 | will take a while. 20 | 21 | ```bash 22 | sudo apt install p7zip-full 23 | ``` 24 | 25 | ### Extract train feature 26 | 27 | 28 | ```bash 29 | 7z e -o/home/tam/Desktop/biomasters/train_features/ /datadisk/biomasters/raw/train_features.zip 30 | ``` 31 | 32 | Should look something like this 33 | 34 | ``` 35 | 7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21 36 | p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,16 CPUs Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz (A0652),ASM,AES-NI) 37 | 38 | Scanning the drive for archives: 39 | 1 file, 10247884383 bytes (9774 MiB) 40 | 41 | Extracting archive: /datadisk/biomasters/raw/train_features.zip 42 | -- 43 | Path = /datadisk/biomasters/raw/train_features.zip 44 | Type = zip 45 | Physical Size = 10247884383 46 | Embedded Stub Size = 4 47 | 64-bit = + 48 | Total Physical Size = 149834321503 49 | Multivolume = + 50 | Volume Index = 13 51 | Volumes = 14 52 | 53 | Everything is Ok 54 | 55 | Folders: 1 56 | Files: 189078 57 | Size: 231859243932 58 | Compressed: 149834321503 59 | ``` 60 | 61 | ### Extract train AGBM 62 | 63 | ```bash 64 | 7z e -o/home/tam/Desktop/biomasters/train_agbm/ /datadisk/biomasters/raw/train_agbm.zip 65 | ``` 66 | 67 | Should look something like this 68 | 69 | ``` 70 | 7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21 71 | p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,16 CPUs Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz (A0652),ASM,AES-NI) 72 | 73 | Scanning the drive for archives: 74 | 1 file, 575973495 bytes (550 MiB) 75 | 76 | Extracting archive: /datadisk/biomasters/raw/train_agbm.zip 77 | -- 78 | Path = /datadisk/biomasters/raw/train_agbm.zip 79 | Type = zip 80 | Physical Size = 575973495 81 | 82 | Everything is Ok 83 | 84 | Folders: 1 85 | Files: 8689 86 | Size: 2280706098 87 | Compressed: 575973495 88 | ``` 89 | 90 | ### Extract test features 91 | 92 | ```bash 93 | 7z e -o/home/tam/Desktop/biomasters/test_features/ /datadisk/biomasters/raw/test_features_splits.zip 94 | ``` 95 | 96 | Should look something like this 97 | 98 | ``` 99 | 7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21 100 | p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,16 CPUs Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz (A0652),ASM,AES-NI) 101 | 102 | Scanning the drive for archives: 103 | 1 file, 6912625480 bytes (6593 MiB) 104 | 105 | Extracting archive: /datadisk/biomasters/raw/test_features_splits.zip 106 | -- 107 | Path 
= /datadisk/biomasters/raw/test_features_splits.zip 108 | Type = zip 109 | Physical Size = 6912625480 110 | Embedded Stub Size = 4 111 | 64-bit = + 112 | Total Physical Size = 49862298440 113 | Multivolume = + 114 | Volume Index = 4 115 | Volumes = 5 116 | 117 | Everything is Ok 118 | 119 | Folders: 1 120 | Files: 63348 121 | Size: 78334396224 122 | Compressed: 49862298440 123 | ``` 124 | 125 | ### Extract test AGBM 126 | 127 | ```bash 128 | 7z e -o/home/tam/Desktop/biomasters/test_agbm/ /datadisk/biomasters/raw/test_agbm.tar 129 | ``` 130 | 131 | Should look something like this 132 | 133 | ``` 134 | 7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21 135 | p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,16 CPUs Intel(R) Core(TM) i7-10875H CPU @ 2.30GHz (A0652),ASM,AES-NI) 136 | 137 | Scanning the drive for archives: 138 | 1 file, 729766400 bytes (696 MiB) 139 | 140 | Extracting archive: /datadisk/biomasters/raw/test_agbm.tar 141 | -- 142 | Path = /datadisk/biomasters/raw/test_agbm.tar 143 | Type = tar 144 | Physical Size = 729766400 145 | Headers Size = 1421312 146 | Code Page = UTF-8 147 | 148 | Everything is Ok 149 | 150 | Folders: 1 151 | Files: 2773 152 | Size: 727862586 153 | Compressed: 729766400 154 | ``` 155 | 156 | ## Prepare data 157 | 158 | This will take the average of all timesteps available for each tile. 159 | The time steps for Sentinel-2 are not complete, not all months are 160 | provided for all tiles. In addtion, the Clay model does not take time 161 | series as input. So aggregating the time element is simplifying but 162 | ok for the purpose of this example. 163 | 164 | **In addition, we skip the one orbit because it nodata most of the time** 165 | 166 | 167 | ### Prepare training features 168 | 169 | ```bash 170 | python finetune/regression/preprocess_data.py \ 171 | --features=/home/tam/Desktop/biomasters/train_features/ \ 172 | --cubes=/home/tam/Desktop/biomasters/train_cubes/ \ 173 | --processes=12 \ 174 | --sample=1 \ 175 | --overwrite 176 | ``` 177 | 178 | ### Prepare test features 179 | 180 | ```bash 181 | python finetune/regression/preprocess_data.py \ 182 | --features=/home/tam/Desktop/biomasters/test_features/ \ 183 | --cubes=/home/tam/Desktop/biomasters/test_cubes/ \ 184 | --processes=12 \ 185 | --sample=1 \ 186 | --overwrite 187 | ``` 188 | 189 | ### Train the model 190 | 191 | To ensure that the imports work properly, ensure that the root of 192 | the repository is in the python path before running the script. 193 | 194 | 195 | ```bash 196 | python -m finetune.regression.regression fit --config configs/regression_biomasters.yaml 197 | ``` 198 | -------------------------------------------------------------------------------- /docs/finetune/segment.md: -------------------------------------------------------------------------------- 1 | # Segmentation head fine tuning 2 | 3 | We have built an example for training a segmentation head on top of 4 | feature map extracted from the frozen Clay encoder. 5 | 6 | All the code for this example can be found in the 7 | [segment finetuning folder](https://github.com/Clay-foundation/model/blob/main/finetune/segment) 8 | of this repository. 9 | 10 | ## Segmentor 11 | 12 | The `Segmentor` class extracts the final feature map from the frozen Clay encoder. It then upsamples the feature map to the original image size using a series of convolution & pixel shuffle operations. 13 | 14 | ### Parameters 15 | 16 | - `num_classes (int)`: Number of classes to segment. 
17 | - `ckpt_path (str)`: Path to the Clay model checkpoint. 18 | 19 | ### Example 20 | 21 | In this example, we will use the `Segmentor` class to segment Land Use Land Cover (LULC) classes for the Chesapeake Bay CVPR dataset. The implementation includes data preprocessing, data loading, and model training workflow using PyTorch Lightning. 22 | 23 | ## Dataset citation 24 | 25 | If you use this dataset, please cite the associated manuscript: 26 | 27 | Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N. 28 | Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data. 29 | Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019). 30 | 31 | Dataset URL: [Chesapeake Bay Land Cover Dataset](https://lila.science/datasets/chesapeakelandcover) 32 | 33 | ## Setup 34 | 35 | Follow the instructions in the [README](../../README.md) to install the required dependencies. 36 | 37 | ```bash 38 | git clone 39 | cd model 40 | mamba env create --file environment.yml 41 | mamba activate claymodel 42 | ``` 43 | 44 | ## Usage 45 | 46 | ### Preparing the Dataset 47 | 48 | Download the Chesapeake Bay Land Cover dataset and organize your dataset directory as recommended. 49 | 50 | 1. Copy `*_lc.tif` and `*_naip-new.tif` files for segmentation downstream tasks using s5cmd: 51 | ```bash 52 | # train 53 | s5cmd --no-sign-request cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" data/cvpr/files/train/ 54 | 55 | # val 56 | s5cmd --no-sign-request cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" data/cvpr/files/val/ 57 | ``` 58 | 59 | 2. Create chips of size `256 x 256` to feed them to the model: 60 | ```bash 61 | python finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 256 62 | ``` 63 | 64 | Directory structure: 65 | ``` 66 | data/ 67 | └── cvpr/ 68 | ├── files/ 69 | │ ├── train/ 70 | │ └── val/ 71 | └── ny/ 72 | ├── train/ 73 | │ ├── chips/ 74 | │ └── labels/ 75 | └── val/ 76 | ├── chips/ 77 | └── labels/ 78 | ``` 79 | 80 | ### Training the Model 81 | 82 | The model can be run via LightningCLI using configurations in `finetune/segment/configs/segment_chesapeake.yaml`. 83 | 84 | 1. Download the Clay model checkpoint from [Huggingface model hub](https://huggingface.co/made-with-clay/Clay/blob/main/v1.5/clay-v1.5.ckpt) and save it in the `checkpoints/` directory. 85 | 86 | 2. Modify the batch size, learning rate, and other hyperparameters in the configuration file as needed: 87 | ```yaml 88 | data: 89 | batch_size: 16 90 | num_workers: 8 91 | model: 92 | num_classes: 7 93 | ckpt_path: checkpoints/clay-v1.5.ckpt 94 | lr: 1e-5 95 | wd: 0.05 96 | b1: 0.9 97 | b2: 0.95 98 | ``` 99 | 100 | 3. 
Update the [WandB logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) configuration in the configuration file with your WandB details or use [CSV Logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.CSVLogger.html#lightning.pytorch.loggers.CSVLogger) if you don't want to log to WandB: 101 | ```yaml 102 | logger: 103 | - class_path: lightning.pytorch.loggers.WandbLogger 104 | init_args: 105 | entity: 106 | project: 107 | group: 108 | log_model: false 109 | ``` 110 | 111 | 4. Train the model: 112 | 113 | To ensure that the imports work properly, ensure that the root of 114 | the repository is in the python path before running the script. 115 | 116 | ```bash 117 | python -m finetune.segment.segment fit --config configs/segment_chesapeake.yaml 118 | ``` 119 | -------------------------------------------------------------------------------- /docs/getting-started/basic_use.md: -------------------------------------------------------------------------------- 1 | # Basic Use 2 | 3 | ## Running jupyter lab 4 | 5 | mamba activate claymodel 6 | python -m ipykernel install --user --name claymodel # to install virtual env properly 7 | jupyter kernelspec list --json # see if kernel is installed 8 | jupyter lab & 9 | 10 | 11 | ## Running the model 12 | The neural network model can be ran via 13 | [LightningCLI v2](https://pytorch-lightning.medium.com/introducing-lightningcli-v2supercharge-your-training-c070d43c7dd6). 14 | To check out the different options available, and look at the hyperparameter 15 | configurations, run: 16 | 17 | python trainer.py --help 18 | 19 | To quickly test the model on one batch in the validation set: 20 | 21 | python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml --trainer.fast_dev_run=True 22 | 23 | To train the model: 24 | 25 | python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml 26 | 27 | More options can be found using `python trainer.py fit --help`, or at the 28 | [LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). 29 | -------------------------------------------------------------------------------- /docs/getting-started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Cloud Environments 4 | 5 | Launch into a [JupyterLab](https://jupyterlab.readthedocs.io) environment on 6 | 7 | | [Binder](https://mybinder.readthedocs.io/en/latest) | [SageMaker Studio Lab](https://studiolab.sagemaker.aws) | 8 | |:--:|:--:| 9 | | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Clay-foundation/model/main) | [![Open in SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/Clay-foundation/model/blob/main/docs/tutorials/clay-v1-wall-to-wall.ipynb) | 10 | 11 | ## Local Environments 12 | 13 | Start by cloning this [repo-url](https://github.com/Clay-foundation/model) 14 | 15 | git clone https://github.com/Clay-foundation/model 16 | cd model 17 | 18 | Then we recommend [using mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) 19 | to install the dependencies. A virtual environment will also be created with Python and 20 | [JupyterLab](https://github.com/jupyterlab/jupyterlab) installed. 
21 | 22 | mamba env create --file environment.yml 23 | 24 | ```{note} 25 | The command above has been tested on Linux devices with CUDA GPUs. 26 | ``` 27 | 28 | Activate the virtual environment first. 29 | 30 | mamba activate claymodel 31 | 32 | Finally, double-check that the libraries have been installed. 33 | 34 | mamba list 35 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Clay Foundation Model 2 | 3 | ## An open source AI model for Earth 4 | 5 | Clay is a [foundation model](https://www.adalovelaceinstitute.org/resource/foundation-models-explainer/) of Earth. Foundation models trained on earth observation (EO) data can efficiently distill and synthesize vast amounts of environmental data, allowing them to generalize this knowledge to specific, downstream applications. This makes them versatile and powerful tools for nature and climate applications. 6 | 7 | Clay’s model takes satellite imagery, along with information about location and time, as an input, and outputs embeddings, which are mathematical representations of a given area at a certain time on Earth’s surface. It uses a Vision Transformer architecture adapted to understand geospatial and temporal relations on Earth Observation data. The model is trained via Self-supervised learning (SSL) using a Masked Autoencoder (MAE) method. 8 | 9 | | Zero to Hero technical Guide (For developers) | Non-Technical User Guide (upcoming webapps) | 10 | |:--:|:--:| 11 | | [![](https://github.com/Clay-foundation/model/assets/434029/0cbfb109-c391-4f73-99f0-abc8769a7a14)](https://www.youtube.com/live/Zd3rbBj56P0?si=_ts3aWBcso0KEH7K) | [![](https://github.com/Clay-foundation/model/assets/434029/5cd209ec-3571-4ae7-98a1-1cef9a898f25)](https://youtu.be/gFjbrGaAL6w?si=ampWKQi9bArVoJsy) | 12 | 13 | ## Usage 14 | 15 | The Clay model can be used in three main ways: 16 | - **Generate semantic embeddings for any location and time.** You can use embeddings for a variety of tasks, including to: 17 | - _Find features:_ Locate objects or features, such as surface mines, aquaculture, or concentrated animal feeding operations. 18 | 19 | - **Fine-tune the model for downstream tasks such as classification, regression, and generative tasks.** Fine-tuning the model takes advantage of its pre-training to more efficiently classify types, predict values, or detect change than from-scratch methods. Embeddings can also be used to do the following, which require fine-tuning: 20 | - _Classify types or predict values of interest:_ Identify the types or classes of a given feature, such as crop type or land cover, or predict values of variables of interest, such as above ground biomass or agricultural productivity. 21 | - _Detect changes over time:_ Find areas that have experienced changes such as deforestation, wildfires, destruction from human conflict, flooding, or urban development. 22 | - This can be done by training a downstream model to take embeddings as input and output predicted classes/values. This could also include fine-tuning model weights to update the embeddings themselves. 23 | 24 | - **Use the model as a backbone for other models.** 25 | 26 | ## Where is what 27 | 28 | - Our **website** is [madewithclay.org](https://madewithclay.org). 29 | - The Clay model **code** lives on [Github](https://github.com/Clay-foundation/model). 30 | License: [Apache-2.0](https://github.com/Clay-foundation/model/blob/main/LICENSE). 
31 | The latest release is [v1.0](https://github.com/Clay-foundation/model/releases/tag/v1.0) 32 | - The Clay model **weights** on [Hugging Face](https://huggingface.co/made-with-clay/Clay/). 33 | License: [Apache-2.0](https://github.com/Clay-foundation/model/blob/main/LICENSE). 34 | - The Clay **documentation** [lives on this site](https://clay-foundation.github.io/model/index.html). 35 | License: [CC-BY](http://creativecommons.org/licenses/by/4.0/). 36 | - We release the **embeddings** of the used training data on [Source Cooperative](https://beta.source.coop/repositories/clay/clay-model-v0-embeddings). 37 | License: [ODC-BY](https://opendatacommons.org/licenses/by/). 38 | 39 | CLAY is a fiscal sponsored project of the 501c3 non-profit 40 | [Radiant Earth](https://www.radiant.earth). 41 | 42 | --- 43 | ### Table of Contents 44 | 45 | ```{tableofcontents} 46 | ``` 47 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clay-foundation/model/b817be2dcf31226959a132357d49e9bc5c89d05d/docs/logo.png -------------------------------------------------------------------------------- /docs/release-notes/changelog-v1.0.md: -------------------------------------------------------------------------------- 1 | (software_release_v1)= 2 | # Release v1.0 (2024/06/06) 3 | 4 | ## 💫 Highlights 5 | 6 | * 🎉 **Release of Clay Foundation Model v1.0** 🎉 7 | 8 | ## 🚀 Model 9 | 10 | * Datamodule & Model for Clay v1 by @srmsoumya in https://github.com/Clay-foundation/model/pull/253 11 | 12 | ## 🗃️ Data Pipeline 13 | 14 | * Moved data pipeline code into [stacchip](https://github.com/Clay-foundation/stacchip) a dedicated library for the data wrangling 15 | 16 | ## 📖 Documentation 17 | 18 | * Clay over aoi by @yellowcap in https://github.com/Clay-foundation/model/pull/116 19 | * Convert run-over-aoi ipynb to md by @yellowcap in https://github.com/Clay-foundation/model/pull/122 20 | * Static Documentation for v0.1 by @brunosan in https://github.com/Clay-foundation/model/pull/118 21 | * Document data sampling strategy by @yellowcap in https://github.com/Clay-foundation/model/pull/124 22 | * Updating location to download imagery script in docs by @MaxLenormand in https://github.com/Clay-foundation/model/pull/198 23 | * Change "Radiant Earth Foundation" to "Radiant Earth" in docs by @kbgg in https://github.com/Clay-foundation/model/pull/214 24 | * Update instructions to re-lock conda-lock.yml file by @weiji14 in https://github.com/Clay-foundation/model/pull/225 25 | * chore(docs): correct typos and improve stylistic consistency by @kelseyjosund in https://github.com/Clay-foundation/model/pull/232 26 | * Partial inputs - Pakistan flood tutorial by @lillythomas in https://github.com/Clay-foundation/model/pull/154 27 | * Obtain patch level metadata (e.g. 
geospatial bounds and cloud cover), save and demo DEP use case (sim search) by @lillythomas in https://github.com/Clay-foundation/model/pull/172 28 | * Obtain patch level cloud coverage percentages by @lillythomas in https://github.com/Clay-foundation/model/pull/184 29 | * Tutorial on burn scar analysis using embeddings from partial inputs by @yellowcap in https://github.com/Clay-foundation/model/pull/149 30 | * Doc updates by @brunosan in https://github.com/Clay-foundation/model/pull/230 31 | * Update specification.md by @brunosan in https://github.com/Clay-foundation/model/pull/250 32 | * Add notebook showing how to run v1 by @yellowcap in https://github.com/Clay-foundation/model/pull/254 33 | * Documented v1 sampling strategy by @yellowcap in https://github.com/Clay-foundation/model/pull/249 34 | * Updates the mamba install instructions link. by @dbonomo in https://github.com/Clay-foundation/model/pull/252 35 | * Remove clay-v0-tutorials & add visualize embeddings for v1 by @srmsoumya in https://github.com/Clay-foundation/model/pull/256 36 | * Minor spelling fix by @tylere in https://github.com/Clay-foundation/model/pull/260 37 | * Shorten comment line length by @yellowcap in https://github.com/Clay-foundation/model/pull/261 38 | * Refactor docs by moving v0 docs into separate section by @yellowcap in https://github.com/Clay-foundation/model/pull/262 39 | * Docs v1 continued by @yellowcap in https://github.com/Clay-foundation/model/pull/263 40 | * Documented metadata file for normalization and wavelenghts by @yellowcap in https://github.com/Clay-foundation/model/pull/266 41 | * [small change] add source.coop link by @brunosan in https://github.com/Clay-foundation/model/pull/137 42 | * Segmentation on Clay by @srmsoumya in https://github.com/Clay-foundation/model/pull/257 43 | 44 | ## 🧰 Maintenance 45 | 46 | * [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/Clay-foundation/model/pull/206 47 | * Bump conda-lock from 2.5.1 to 2.5.6 by @weiji14 in https://github.com/Clay-foundation/model/pull/183 48 | * Remove scripts folder by @yellowcap in https://github.com/Clay-foundation/model/pull/265 49 | * Unified v1 and v0 environments by @yellowcap in https://github.com/Clay-foundation/model/pull/264 50 | * Add osx-arm64 platform to conda-lock.yml file and GitHub Actions CI by @weiji14 in https://github.com/Clay-foundation/model/pull/164 51 | * Move pipeline scripts into dedicated folder by @yellowcap in https://github.com/Clay-foundation/model/pull/163 52 | * Memory buildup clean by @brunosan in https://github.com/Clay-foundation/model/pull/158 53 | * Add osx platform to conda lock file and CI by @chuckwondo in https://github.com/Clay-foundation/model/pull/162 54 | * Add scikit-learn by @weiji14 in https://github.com/Clay-foundation/model/pull/159 55 | * Remove float16 dype in favor of float32 by @yellowcap in https://github.com/Clay-foundation/model/pull/143 56 | * Remove stale bands argument from clay model classes by @yellowcap in https://github.com/Clay-foundation/model/pull/144 57 | 58 | ## :hammer_and_wrench: v0.2 updates 59 | 60 | * Add normalization parameters to documentation by @yellowcap in https://github.com/Clay-foundation/model/pull/152 61 | * Worldcover embeddings conus by @yellowcap in https://github.com/Clay-foundation/model/pull/153 62 | * Create data/minicubes directory if it doesn't already exist by @weiji14 in https://github.com/Clay-foundation/model/pull/160 63 | * Clay pipeline v04 by @yellowcap in 
https://github.com/Clay-foundation/model/pull/173 64 | * Shuffle patches only during training by @yellowcap in https://github.com/Clay-foundation/model/pull/135 65 | * Add option to output raw patch embeddings by @yellowcap in https://github.com/Clay-foundation/model/pull/133 66 | 67 | ## 🧑‍🤝‍🧑 New Contributors 68 | 69 | * @chuckwondo made their first contribution in https://github.com/Clay-foundation/model/pull/162 70 | * @MaxLenormand made their first contribution in https://github.com/Clay-foundation/model/pull/198 71 | * @kbgg made their first contribution in https://github.com/Clay-foundation/model/pull/214 72 | * @kelseyjosund made their first contribution in https://github.com/Clay-foundation/model/pull/232 73 | * @dbonomo made their first contribution in https://github.com/Clay-foundation/model/pull/252 74 | * @tylere made their first contribution in https://github.com/Clay-foundation/model/pull/260 75 | 76 | **Full Changelog**: https://github.com/Clay-foundation/model/compare/v0.0.1...v1.0 77 | -------------------------------------------------------------------------------- /docs/release-notes/specification.md: -------------------------------------------------------------------------------- 1 | # Pretrained Model release v1.5 2 | 3 | This changelog is a summary of the changes to the pretrained model weights for the Clay model. We follow the "[Stanford Foundation Model Transparency Index](https://github.com/stanford-crfm/fmti)" 4 | 5 | Model weights released on 2024/11/19. 6 | 7 | > For release notes for the source code, see [](software_release_v1.5) 8 | 9 | ## Summary 10 | 11 | Clay v1.5 is our MAE-based model designed to handle inputs from a variety of satellite sensors, including Sentinel-2, Landsat, Sentinel-1 SAR, LINZ, NAIP and MODIS. It supports inputs of any size and any number of bands. 12 | 13 | ### **Acknowledgments and Inspirations:** 14 | 15 | Clay v1.5 is based on the foundational work of several pioneering models and research papers. We owe a significant debt of gratitude to the following projects, which provided architectural inspiration and implementation guidance: 16 | 17 | - **DOFA**: [Code](https://github.com/zhu-xlab/DOFA), [Paper](https://arxiv.org/abs/2403.15356) 18 | - **GFM**: [Code](https://github.com/mmendiet/GFM), [Paper](https://arxiv.org/abs/2302.04476) 19 | - **Prithvi**: [Code](https://github.com/NASA-IMPACT/hls-foundation-os), [Paper](https://arxiv.org/abs/2310.18660) 20 | - **SatMAE**: [Project](https://sustainlab-group.github.io/SatMAE/) 21 | - **ScaleMAE**: [Project](https://ai-climate.berkeley.edu/scale-mae-website/) 22 | - **Spectral-GPT**: [Paper](https://arxiv.org/abs/2311.07113) 23 | 24 | ### **Components of Clay v1.5:** 25 | 26 | 1. **Dynamic Embedding Block**: This component generates patches for the chips from the number of bands and their wavelengths, which are then fed into the masked autoencoder (MAE). 27 | 2. **Position Encoding**: This component encodes spatial and temporal information by adding positional encoding to the model. This encoding is scaled according to the Ground Sampling Distance (GSD) and is combined with location information (latitude/longitude) and time step (week/hour). 28 | 3. **Masked Autoencoder (MAE)**: A VIT-based MAE is used to reconstruct the sensor data for all input bands. This contributes to 95% of the total loss, known as the reconstruction loss. 29 | 4. **Teacher**: DINOv2 is used as a teacher to compute the representation loss, which accounts for the remaining 5% of the total loss. 
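To make the 95%/5% split concrete, here is a minimal sketch of how the two terms could combine into the total loss. The function and the exact form of each term (pixel-wise MSE and a cosine-based distance to the teacher representation) are illustrative assumptions; the real implementation lives in [`src/model.py`](https://github.com/Clay-foundation/model/blob/main/src/model.py).

```python
import torch.nn.functional as F


def total_loss(pred_pixels, target_pixels, student_repr, teacher_repr):
    # Reconstruction term over the masked patches (95% of the total loss).
    # Pixel-wise MSE is assumed here.
    reconstruction = F.mse_loss(pred_pixels, target_pixels)
    # Representation term against the frozen DINOv2 teacher (5% of the total
    # loss). A cosine distance is assumed here.
    representation = 1 - F.cosine_similarity(student_repr, teacher_repr, dim=-1).mean()
    return 0.95 * reconstruction + 0.05 * representation
```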
30 | 31 | ### **Pre-training and Usage:** 32 | 33 | The pre-trained model can process stacks of geospatial data from different sensors with various resolutions and bands, and output vector embeddings. During pre-training, the model processes stacks of chips from different sensors along with metadata such as wavelengths, GSD, latitude/longitude, and time step. The task involves capturing spatial, temporal, and spectral information about Earth and representing these relationships in high-dimensional space. Each resulting embedding represents a specific area of Earth at a particular time. 34 | 35 | Clay v1.5 was trained on 70 million globally distributed chips of size 156x256, collected according to the land use/land cover (LULC) statistics of the globe. The training was conducted on AWS using 20 g6.48xlarge instances for ~100 epochs in Sep 2024. 36 | 37 | You can access the model weights on HuggingFace [here](https://huggingface.co/made-with-clay/Clay/tree/main/v1.5). 38 | 39 | ## Model Architecture 40 | ![Architecture](https://github.com/Clay-foundation/model/assets/8049519/f6a1e92c-3993-4148-98a2-e3805dae4414) 41 | 42 | Clay v1.5's architecture includes a dynamic embedding block for generating patches from multi-band inputs, position encoding to integrate spatial and temporal data, a Vision Transformer-based masked autoencoder (MAE) for reconstructing sensor data, and a DINOv2 teacher model to enhance representation learning. This architecture allows the model to process inputs from various satellite sensors of any size and number of bands, capturing complex geospatial information effectively. 43 | 44 | For more details, you can view the source code [here](https://github.com/Clay-foundation/model/blob/main/src/model.py). 45 | 46 | https://github.com/Clay-foundation/model/blob/main/LICENSE 47 | 48 | - Core Framework: [Lightning](https://lightning.ai/) and its dependencies, such as PyTorch, etc. 49 | - Input modalities: 50 | - A fixed specification of 10 bands from Sentinel-2, 6 bands from Landsat, 4 bands from NAIP, 3 bands from LINZ, 2 bands from Sentinel-1 data and 7 bands from MODIS. 51 | - Output modalities: 52 | - As a masked auto-encoder, it has a fixed specification of 10 bands from Sentinel-2, 6 bands from Landsat, 4 bands from NAIP, 3 bands from LINZ, 2 bands from Sentinel-1 data and 7 bands from MODIS to closely mimic the input. 53 | - Model size (in millions): 54 | - Number of parameters: `632M` 55 | - Encoder size: `311M` 56 | - Decoder size: `15M` 57 | - Teacher size: `304M` 58 | - Model size on disk (just the encoder): `1.25 GB`. 59 | - Model license: 60 | - Source code and Model weights: [Apache-2.0](https://github.com/Clay-foundation/model/blob/main/LICENSE) 61 | - Feedback and redress mechanisms: 62 | - Please open an issue or discussion on the [GitHub repository](https://github.com/Clay-foundation/model) or send an email to `bruno@madewithclay.org`. 63 | 64 | ## Model Card 65 | 66 | For Clay v1.5, we utilized the [`clay_mae_large`](https://github.com/Clay-foundation/model/blob/80012459793cf71f4482b6d0de5254da83f642c6/src/model.py#L608-L624) and the model weights can be found on Huggingface [here](https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt). 
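As a quick sanity check, the checkpoint can be downloaded and inspected as sketched below. The repo id and file name are inferred from the Hugging Face link above, and the `model.encoder.` prefix is the one the fine-tuning factories under `finetune/` strip before loading; treat this as an illustration rather than a supported API.

```python
import torch
from huggingface_hub import hf_hub_download

# Repo id and file name inferred from the Hugging Face link above (assumption).
ckpt_path = hf_hub_download(
    repo_id="made-with-clay/Clay", filename="v1.5/clay-v1.5.ckpt"
)
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]

# Encoder weights are stored under the "model.encoder." prefix.
encoder_keys = [k for k in state_dict if k.startswith("model.encoder.")]
print(f"Checkpoint holds {len(encoder_keys)} encoder tensors")
```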
67 | 68 | ``` 69 | MASKED PATCHES = 75% 70 | INPUT SIZE = 256 71 | 72 | NORM_PIX_LOSS = FALSE 73 | PATCH SIZE = 8 74 | 75 | OPTIMIZER 76 | AdamW 77 | Learning rate = 1e-5 78 | Weight decay = 0.05 79 | Beta 1 = 0.9 80 | Beta 2 = 0.95 81 | 82 | SCHEDULER 83 | CosineAnnealingWarmRestarts 84 | T_0 = 1000 85 | T_mult = 2 86 | eta_min = Learning rate * 100 87 | 88 | ENCODER 89 | dim = 1024 90 | depth = 24 91 | heads = 16 92 | dim_head = 64 93 | mlp_ratio = 4 94 | 95 | DECODER 96 | decoder_dim = 512 97 | decoder_depth = 4 98 | decoder_heads = 4 99 | decoder_dim_head = 64 100 | decoder_mlp_ratio = 4 101 | ``` 102 | 103 | ## Data Card 104 | 105 | The data used for this model is described in detail in the [](training-data) section. 106 | 107 | ## Normalization and Wavelength parameters 108 | 109 | The normalization parameters depend on the input system that is used. They are 110 | therefore not static values, but rather an input variable to the model. 111 | 112 | Similarly, the model takes the central wavelength of each input band as a variable. 113 | 114 | During training we used Sentinel-2, Sentinel-1, Landsat 8 and 9, NAIP, LINZ and MODIS data. For these we compiled normalization and wavelength values that were used 115 | during training. These can be used for inferencing when passing data from any of 116 | these systems. 117 | 118 | The normalization and wavelength parameters can be found in the following 119 | [metadata file](https://github.com/Clay-foundation/model/blob/main/configs/metadata.yaml). 120 | 121 | ## Training Card 122 | 123 | * Compute Resources: 124 | * 20 AWS EC2 g6.48xlarge with 8 NVIDIA L4 GPUs each 125 | * Training Time: 126 | * `100` epochs, each taking ~`8h` to train. 127 | * Carbon Emissions: 128 | * According to the "Customer Carbon Emission Tool", there were no Scope 1 or Scope 2 carbon emissions. Following the [documentation](https://docs.aws.amazon.com/awsaccountbilling/latest/aboutv2/ccft-estimation.html), we believe this is due to the usage of renewable energy sources. We are aware that Scope 3 emissions might be significant for data centers and that these are not included in the estimate. 129 | * Training stages: 130 | * While developing the model we run small tests locally and on the cloud. We estimate that all testing and development compute is less than the compute used for 1 epoch of training. 131 | * QA of the model is also done locally and on the cloud, and we estimate that it is less than the compute used for 1 epoch of training. 132 | * Release and distribution: 133 | * Model development happens in an open source repository on GitHub [here](https://github.com/Clay-foundation/model/). 134 | * We release the model weights on HuggingFace [here](https://huggingface.co/made-with-clay/Clay/). 135 | * We release the embeddings on Source Cooperative [here](https://beta.source.coop/clay/). 136 | * We do not have other distribution channels at this time. 137 | * Production use: 138 | * We support our partners to build applications with the model, and we expect them to use the model in production. 139 | 140 | 141 | ## Results 142 | 143 | As a foundation model, it is designed to be used as a building block for other models. We have documented examples of how the [embedding space](../tutorials/embeddings.ipynb) and the [reconstructions](../tutorials/reconstruction.ipynb) look like for the base model. 
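In practice, "building block" means the pattern used throughout the `finetune/` folder: load the pretrained encoder, freeze it, and attach a small task-specific head. A minimal sketch reusing the classification factory from this repository is shown below; the checkpoint path and the random input chips are placeholder assumptions.

```python
import torch

from finetune.classify.factory import Classifier

# Placeholder checkpoint path; point this at the downloaded v1.5 checkpoint.
model = Classifier(num_classes=10, ckpt_path="checkpoints/clay-v1.5.ckpt")

datacube = {
    "pixels": torch.randn(2, 10, 256, 256),  # batch of two random 10-band chips
    "time": torch.zeros(2, 4),               # placeholder time encoding
    "latlon": torch.zeros(2, 4),              # placeholder lat/lon encoding
    "gsd": torch.tensor(10.0),                # ground sampling distance in metres
    "waves": torch.tensor(
        [0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19]
    ),  # Sentinel-2 band wavelengths in micrometres
}

logits = model(datacube)  # [2, num_classes]; only the MLP head is trainable
```

The same frozen-encoder pattern backs the segmentation and regression heads under `finetune/segment` and `finetune/regression`.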
144 | 145 | 146 | ### Performance Metrics 147 | The model shows the following performance characteristics for its Masked Autoencoder objective: 148 | * Training loss: `0.165` 149 | * Validation loss: `0.165` 150 | 151 | ## Known Limitations and Biases 152 | 153 | - Training data for this model only covers land and coastal waters. 154 | - We only train on a very small sample of the source archives, both in terms of spatial coverage and time. 155 | - We do not train on the poles, and we do not train on open ocean, nor ocean nor atmospheric volumetric data. 156 | - We do not train on night time data. 157 | - We do not explicitly include extreme events in the training data. 158 | - We only train at most 6 different times per location. 159 | 160 | 161 | ## Ethical Considerations 162 | 163 | Our goal is to lower the barrier to use EO data for biodiversity and climate change mitigation and adaptation. We have designed our model to support this goal. 164 | 165 | We have also designed our model to be as open as possible, as modular as possible, as undifferentiated and general as possible, and as well documented as possible, so we can maximize the leverage of the resources needed for the creation of this model. 166 | 167 | As a fully open model, we cannot however control how it is used. We are aware that EO data can be used for harmful purposes, and we are committed to work with our partners to prevent this from happening. 168 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: claymodel 2 | channels: 3 | - conda-forge 4 | - nvidia 5 | - pytorch 6 | dependencies: 7 | - python=3.11 8 | - pip 9 | - pip: 10 | - --extra-index-url https://download.pytorch.org/whl/cu121 11 | - torch==2.4.0+cu121 12 | - torchvision==0.19.0+cu121 13 | - einops~=0.7.0 14 | - geopandas 15 | - jsonargparse[signatures]>=4.27.7 16 | - lightning 17 | - matplotlib 18 | - python-box 19 | - scikit-image 20 | - scikit-learn 21 | - timm 22 | - vit-pytorch 23 | - wandb 24 | -------------------------------------------------------------------------------- /finetune/classify/classify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface to run the neural network model! 3 | 4 | From the project root directory, do: 5 | 6 | python classify.py fit --config configs/classify_eurosat.yaml 7 | 8 | References: 9 | - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html 10 | - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 11 | """ 12 | 13 | from lightning.pytorch.cli import LightningCLI 14 | 15 | from finetune.classify.eurosat_datamodule import EuroSATDataModule # noqa: F401 16 | from finetune.classify.eurosat_model import EuroSATClassifier # noqa: F401 17 | 18 | 19 | # %% 20 | def cli_main(): 21 | """ 22 | Command-line inteface to run Clasifier model with EuroSATDataModule. 
23 | """ 24 | cli = LightningCLI( 25 | EuroSATClassifier, EuroSATDataModule, save_config_kwargs={"overwrite": True} 26 | ) 27 | return cli 28 | 29 | 30 | # %% 31 | if __name__ == "__main__": 32 | cli_main() 33 | 34 | print("Done!") 35 | -------------------------------------------------------------------------------- /finetune/classify/eurosat_datamodule.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | import yaml 4 | from box import Box 5 | from torch.utils.data import DataLoader 6 | from torchgeo.datasets import EuroSAT as TGEuroSAT 7 | from torchvision.transforms import v2 8 | 9 | S2_BANDS = [ 10 | "B02", 11 | "B03", 12 | "B04", 13 | "B05", 14 | "B06", 15 | "B07", 16 | "B08", 17 | "B08A", 18 | "B11", 19 | "B12", 20 | ] 21 | 22 | 23 | class EuroSAT(TGEuroSAT): 24 | """ 25 | Subclass of TGEuroSAT to customize the dataset loading and 26 | transformations. 27 | 28 | Args: 29 | root (str): Root directory of the dataset. 30 | split (str): Dataset split to use ('train' or 'val'). 31 | bands (list): List of spectral bands to use. 32 | transforms (callable): Transformations to apply to the samples. 33 | download (bool): If true, downloads the dataset. 34 | """ 35 | 36 | def __init__(self, root, split, bands, transforms, download): 37 | super().__init__(root, split, bands, transforms, download) 38 | 39 | def __getitem__(self, index): 40 | """ 41 | Override the __getitem__ method to apply custom transformations. 42 | 43 | Args: 44 | index (int): Index of the sample to retrieve. 45 | 46 | Returns: 47 | dict: A dictionary containing the image tensor, label, and 48 | additional metadata. 49 | """ 50 | image, label = self._load_image(index) 51 | 52 | image = torch.index_select(image, dim=0, index=self.band_indices).float() 53 | sample = { 54 | "pixels": image, 55 | "label": label, 56 | "time": torch.zeros(4), # Placeholder for time information 57 | "latlon": torch.zeros(4), # Placeholder for lat/lon information 58 | } 59 | 60 | if self.transforms is not None: 61 | sample = self.transforms(sample) 62 | 63 | return sample 64 | 65 | 66 | class EuroSATDataModule(L.LightningDataModule): 67 | """ 68 | Data module for loading and transforming the EuroSAT dataset. 69 | 70 | Args: 71 | batch_size (int): Batch size for the dataloaders. 72 | num_workers (int): Number of workers for data loading. 73 | metadata_path (str): Path to the metadata file for normalization 74 | statistics. 75 | """ 76 | 77 | def __init__(self, batch_size, num_workers, metadata_path): 78 | super().__init__() 79 | self.batch_size = batch_size 80 | self.num_workers = num_workers 81 | 82 | metadata = Box(yaml.safe_load(open(metadata_path)))["sentinel-2-l2a"] 83 | mean = list(metadata.bands.mean.values()) 84 | std = list(metadata.bands.std.values()) 85 | 86 | self.trn_tfm = v2.Compose( 87 | [ 88 | v2.RandomHorizontalFlip(), 89 | v2.RandomVerticalFlip(), 90 | v2.Normalize(mean, std), 91 | ] 92 | ) 93 | self.val_tfm = v2.Compose([v2.Normalize(mean, std)]) 94 | 95 | def setup(self, stage=None): 96 | """ 97 | Setup the datasets for training and validation. 98 | 99 | Args: 100 | stage (str): Stage of the training process ('fit', 'validate', 101 | etc.). 
102 | """ 103 | if stage in {"fit", None}: 104 | self.trn_ds = EuroSAT( 105 | root="data", 106 | split="train", 107 | bands=S2_BANDS, 108 | transforms=self.trn_tfm, 109 | download=True, 110 | ) 111 | self.val_ds = EuroSAT( 112 | root="data", 113 | split="val", 114 | bands=S2_BANDS, 115 | transforms=self.val_tfm, 116 | download=True, 117 | ) 118 | 119 | def train_dataloader(self): 120 | """ 121 | Returns the DataLoader for the training dataset. 122 | 123 | Returns: 124 | DataLoader: DataLoader for the training dataset. 125 | """ 126 | return DataLoader( 127 | self.trn_ds, 128 | batch_size=self.batch_size, 129 | shuffle=True, 130 | num_workers=self.num_workers, 131 | pin_memory=True, 132 | prefetch_factor=2, 133 | ) 134 | 135 | def val_dataloader(self): 136 | """ 137 | Returns the DataLoader for the validation dataset. 138 | 139 | Returns: 140 | DataLoader: DataLoader for the validation dataset. 141 | """ 142 | return DataLoader( 143 | self.val_ds, 144 | batch_size=self.batch_size * 2, 145 | shuffle=False, 146 | num_workers=self.num_workers, 147 | pin_memory=True, 148 | prefetch_factor=2, 149 | ) 150 | -------------------------------------------------------------------------------- /finetune/classify/eurosat_model.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from torch import nn, optim 4 | from torchmetrics import Accuracy 5 | 6 | from finetune.classify.factory import Classifier 7 | 8 | 9 | class EuroSATClassifier(L.LightningModule): 10 | """ 11 | LightningModule for training and evaluating a classifier on the EuroSAT 12 | dataset. 13 | 14 | Args: 15 | num_classes (int): Number of classes for classification. 16 | ckpt_path (str): Clay MAE pretrained checkpoint path. 17 | lr (float): Learning rate for the optimizer. 18 | wd (float): Weight decay for the optimizer. 19 | b1 (float): Beta1 parameter for the Adam optimizer. 20 | b2 (float): Beta2 parameter for the Adam optimizer. 21 | """ 22 | 23 | def __init__(self, num_classes, ckpt_path, lr, wd, b1, b2): # noqa: PLR0913 24 | super().__init__() 25 | self.save_hyperparameters() 26 | self.model = Classifier(num_classes=num_classes, ckpt_path=ckpt_path) 27 | self.loss_fn = nn.CrossEntropyLoss() 28 | self.accuracy = Accuracy(task="multiclass", num_classes=num_classes) 29 | 30 | def forward(self, datacube): 31 | """ 32 | Forward pass through the classifier. 33 | 34 | Args: 35 | datacube (dict): A dictionary containing the input datacube 36 | and meta information like time, latlon, gsd & wavelenths. 37 | 38 | Returns: 39 | torch.Tensor: The output logits from the classifier. 40 | """ 41 | # Wavelengths for Sentinel 2 bands of EuroSAT dataset 42 | waves = torch.tensor( 43 | [0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19] 44 | ) 45 | gsd = torch.tensor(10.0) 46 | 47 | return self.model( 48 | { 49 | "pixels": datacube["pixels"], 50 | "time": datacube["time"], 51 | "latlon": datacube["latlon"], 52 | "gsd": gsd, 53 | "waves": waves, 54 | } 55 | ) 56 | 57 | def configure_optimizers(self): 58 | """ 59 | Configure the optimizer and learning rate scheduler. 60 | 61 | Returns: 62 | dict: A dictionary containing the optimizer and learning rate 63 | scheduler. 
64 | """ 65 | optimizer = optim.AdamW( 66 | [ 67 | param 68 | for name, param in self.model.named_parameters() 69 | if param.requires_grad 70 | ], 71 | lr=self.hparams.lr, 72 | weight_decay=self.hparams.wd, 73 | betas=(self.hparams.b1, self.hparams.b2), 74 | ) 75 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 76 | optimizer, T_0=100, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1 77 | ) 78 | return { 79 | "optimizer": optimizer, 80 | "lr_scheduler": { 81 | "scheduler": scheduler, 82 | "interval": "step", 83 | }, 84 | } 85 | 86 | def shared_step(self, batch, batch_idx, phase): 87 | """ 88 | Perform a shared step for both training and validation. 89 | 90 | Args: 91 | batch (dict): A batch of data. 92 | batch_idx (int): The index of the batch. 93 | phase (str): The phase ('train' or 'val'). 94 | 95 | Returns: 96 | torch.Tensor: The computed loss for the batch. 97 | """ 98 | labels = batch["label"].long() 99 | logits = self(batch) 100 | loss = self.loss_fn(logits, labels) 101 | score = self.accuracy(logits, labels) 102 | 103 | self.log( 104 | f"{phase}/loss", 105 | loss, 106 | on_step=True, 107 | on_epoch=True, 108 | prog_bar=True, 109 | logger=True, 110 | sync_dist=True, 111 | ) 112 | self.log( 113 | f"{phase}/score", 114 | score, 115 | on_step=True, 116 | on_epoch=True, 117 | prog_bar=True, 118 | logger=True, 119 | sync_dist=True, 120 | ) 121 | return loss 122 | 123 | def training_step(self, batch, batch_idx): 124 | """ 125 | Perform a training step. 126 | 127 | Args: 128 | batch (dict): A batch of training data. 129 | batch_idx (int): The index of the batch. 130 | 131 | Returns: 132 | torch.Tensor: The computed loss for the batch. 133 | """ 134 | return self.shared_step(batch, batch_idx, "train") 135 | 136 | def validation_step(self, batch, batch_idx): 137 | """ 138 | Perform a validation step. 139 | 140 | Args: 141 | batch (dict): A batch of validation data. 142 | batch_idx (int): The index of the batch. 143 | 144 | Returns: 145 | torch.Tensor: The computed loss for the batch. 146 | """ 147 | return self.shared_step(batch, batch_idx, "val") 148 | -------------------------------------------------------------------------------- /finetune/classify/factory.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from src.model import Encoder 7 | 8 | 9 | class Classifier(nn.Module): 10 | """ 11 | Classifier class uses Clay Encoder for feature extraction and a head for 12 | classification. 13 | 14 | Attributes: 15 | clay_encoder (Encoder): The encoder for feature extraction. 16 | head (nn.Sequential): The head for classification. 17 | device (torch.device): The device to run the model on. 18 | """ 19 | 20 | def __init__(self, num_classes=10, ckpt_path=None): 21 | """ 22 | Initialize the Classifier. 23 | 24 | Args: 25 | num_classes (int, optional): The number of classes for 26 | classification. Defaults to 10. 27 | ckpt_path (str, optional): Clay MAE pretrained model checkpoint 28 | path. Defaults to None. 29 | """ 30 | super().__init__() 31 | 32 | # Initialize Clay Encoder with parameters from base model. Set 33 | # mask_ratio to 0.0 & shuffle to False for downstream tasks. 
34 | # self.clay_encoder = Encoder( 35 | # mask_ratio=0.0, 36 | # patch_size=8, 37 | # shuffle=False, 38 | # dim=768, 39 | # depth=12, 40 | # heads=12, 41 | # dim_head=64, 42 | # mlp_ratio=4.0, 43 | # ) 44 | self.clay_encoder = Encoder( 45 | mask_ratio=0.0, 46 | patch_size=8, 47 | shuffle=False, 48 | dim=1024, 49 | depth=24, 50 | heads=16, 51 | dim_head=64, 52 | mlp_ratio=4.0, 53 | # feature_maps=feature_maps, 54 | # ckpt_path=ckpt_path, 55 | ) 56 | 57 | # Simple 2 layer MLP head for classification 58 | self.head = nn.Sequential( 59 | nn.Linear(1024, 512), 60 | nn.ReLU(), 61 | nn.Dropout(0.25), 62 | nn.Linear(512, num_classes), 63 | ) 64 | 65 | # Determine the device to run the model on 66 | self.device = ( 67 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 68 | ) 69 | 70 | # Load Clay MAE pretrained weights for the Encoder 71 | self.load_clay_weights(ckpt_path) 72 | 73 | def load_clay_weights(self, ckpt_path): 74 | """ 75 | Load the weights for Clay MAE Encoder from a checkpoint file. 76 | 77 | Args: 78 | ckpt_path (str): Clay MAE pretrained model checkpoint path. 79 | """ 80 | # Load the checkpoint file 81 | ckpt = torch.load(ckpt_path, map_location=self.device) 82 | state_dict = ckpt.get("state_dict") 83 | 84 | # Remove model.encoder prefix for the clay encoder 85 | state_dict = { 86 | re.sub(r"^model\.encoder\.", "", name): param 87 | for name, param in state_dict.items() 88 | if name.startswith("model.encoder") 89 | } 90 | 91 | # Copy the weights from the state dict to the encoder 92 | for name, param in self.clay_encoder.named_parameters(): 93 | if name in state_dict and param.size() == state_dict[name].size(): 94 | param.data.copy_(state_dict[name]) # Copy the weights 95 | else: 96 | print(f"No matching parameter for {name} with size {param.size()}") 97 | 98 | # Freeze clay encoder 99 | for param in self.clay_encoder.parameters(): 100 | param.requires_grad = False 101 | 102 | # Set the encoder to evaluation mode 103 | self.clay_encoder.eval() 104 | 105 | def forward(self, datacube): 106 | """ 107 | Forward pass of the Classifier. 108 | 109 | Args: 110 | datacube (torch.Tensor): A dictionary containing the input datacube 111 | and meta information like time, latlon, gsd & wavelenths. 112 | 113 | Returns: 114 | torch.Tensor: The output logits. 115 | """ 116 | # Get the embeddings from the encoder 117 | embeddings, *_ = self.clay_encoder( 118 | datacube 119 | ) # embeddings: batch x (1 + row x col) x 768 120 | 121 | # Use only the first embedding i.e cls token 122 | embeddings = embeddings[:, 0, :] 123 | 124 | # Pass the embeddings through the head to get the logits 125 | logits = self.head(embeddings) 126 | 127 | return logits 128 | -------------------------------------------------------------------------------- /finetune/embedder/factory.py: -------------------------------------------------------------------------------- 1 | """Export the Clay model to ONNX and pytorch ExportedProgram format. 2 | 3 | This script exports the Clay model to ONNX and pytorch ExportedProgram format 4 | for deployment. The model is exported with dynamic shapes for inference. 5 | 6 | How to use: 7 | 8 | ```bash 9 | python -m finetune.embedder.factory \ 10 | --img_size 256 \ 11 | --ckpt_path checkpoints/clay_v1.5.ckpt \ 12 | --device cuda \ 13 | --name clay-v1.5-encoder.onnx \ 14 | --onnx 15 | # exports Clay encoder to ONNX format that can handle chips of size 256x256 16 | # for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ & Sentinel 1. 
17 | ``` 18 | 19 | ```bash 20 | python -m finetune.embedder.factory \ 21 | --img_size 224 \ 22 | --ckpt_path checkpoints/clay_v1.5.ckpt \ 23 | --device cuda \ 24 | --name clay-v1.5-encoder.pt2 \ 25 | --ep 26 | # exports Clay encoder to pytorch ExportedProgram format that can handle chips 27 | # of size 224x224 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ 28 | # & Sentinel 1. 29 | ``` 30 | 31 | """ 32 | 33 | import argparse 34 | import re 35 | import warnings 36 | from pathlib import Path 37 | 38 | import torch 39 | from einops import repeat 40 | from torch import nn 41 | from torch.export import Dim 42 | 43 | from src.model import Encoder 44 | from src.utils import posemb_sincos_2d_with_gsd 45 | 46 | warnings.filterwarnings("ignore", category=UserWarning) 47 | 48 | 49 | class EmbeddingEncoder(Encoder): 50 | """Clay Encoder without mask and shuffle.""" 51 | 52 | def __init__( # noqa: PLR0913 53 | self, 54 | img_size, 55 | patch_size, 56 | dim, 57 | depth, 58 | heads, 59 | dim_head, 60 | mlp_ratio, 61 | ): 62 | super().__init__( 63 | mask_ratio=0.0, 64 | shuffle=False, 65 | patch_size=patch_size, 66 | dim=dim, 67 | depth=depth, 68 | heads=heads, 69 | dim_head=dim_head, 70 | mlp_ratio=mlp_ratio, 71 | ) 72 | self.img_size = img_size 73 | 74 | # Using fixed grid size for inference 75 | self.grid_size = img_size // patch_size 76 | self.num_patches = self.grid_size**2 77 | 78 | def add_encodings(self, patches, time, latlon, gsd): 79 | """Add position encoding to the patches""" 80 | B, L, D = patches.shape 81 | 82 | grid_size = self.grid_size 83 | 84 | pos_encoding = ( 85 | posemb_sincos_2d_with_gsd( 86 | h=grid_size, 87 | w=grid_size, 88 | dim=(self.dim - 8), 89 | gsd=gsd, 90 | ) 91 | .to(patches.device) 92 | .detach() 93 | ) # [L (D - 8)] 94 | 95 | time_latlon = torch.hstack((time, latlon)).to(patches.device).detach() # [B 8] 96 | 97 | pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] 98 | time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] 99 | pos_metadata_encoding = torch.cat( 100 | (pos_encoding, time_latlon), dim=-1 101 | ) # [B L D] 102 | 103 | patches = patches + pos_metadata_encoding # [B L D] + [B L D] -> [B L D] 104 | return patches # [B L D] 105 | 106 | # def forward(self, cube, time, latlon, waves, gsd): 107 | def forward(self, datacube): 108 | cube, time, latlon, gsd, waves = ( 109 | datacube["pixels"], # [B C H W] 110 | datacube["time"], # [B 2] 111 | datacube["latlon"], # [B 2] 112 | datacube["gsd"], # 1 113 | datacube["waves"], # [N] 114 | ) # [B C H W] 115 | B, C, H, W = cube.shape 116 | 117 | patches, _ = self.to_patch_embed( 118 | cube, waves 119 | ) # [B L D] - patchify & create embeddings per patch 120 | 121 | # Add time & latlon as encoding to patches 122 | patches = self.add_encodings( 123 | patches, 124 | time, 125 | latlon, 126 | gsd, 127 | ) # [B L D] - add position encoding to the embeddings 128 | 129 | # Add class tokens 130 | cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] 131 | patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] 132 | 133 | # pass the patches through the transformer 134 | patches = self.transformer(patches) # [B (1 + L) D] 135 | 136 | # get the cls token 137 | embeddings = patches[:, 0, :] # [B D] 138 | 139 | return embeddings 140 | 141 | 142 | class Embedder(nn.Module): 143 | def __init__(self, img_size=256, ckpt_path=None, device="cpu"): 144 | super().__init__() 145 | self.clay_encoder = ( 146 | EmbeddingEncoder( # Default parameters for the Clay base model 147 
| img_size=img_size, 148 | patch_size=8, 149 | dim=1024, 150 | depth=24, 151 | heads=16, 152 | dim_head=64, 153 | mlp_ratio=4.0, 154 | ).to(device) 155 | ) 156 | self.img_size = img_size 157 | self.device = torch.device(device) 158 | self.load_clay_weights(ckpt_path) 159 | 160 | def load_clay_weights(self, ckpt_path): 161 | "Load the weights from the Clay model encoder." 162 | ckpt = torch.load(ckpt_path, map_location=self.device) 163 | state_dict = ckpt.get("state_dict") 164 | state_dict = { 165 | re.sub(r"^model\.encoder\.", "", name): param 166 | for name, param in state_dict.items() 167 | if name.startswith("model.encoder") 168 | } 169 | 170 | with torch.no_grad(): 171 | for name, param in self.clay_encoder.named_parameters(): 172 | if name in state_dict and param.size() == state_dict[name].size(): 173 | param.data.copy_(state_dict[name]) # Copy the weights 174 | else: 175 | print(f"No matching parameter for {name} with size {param.size()}") 176 | 177 | for param in self.clay_encoder.parameters(): 178 | param.requires_grad = False 179 | 180 | self.clay_encoder.eval() 181 | 182 | def forward(self, datacube): 183 | embeddings = self.clay_encoder(datacube) 184 | 185 | return embeddings 186 | 187 | def fake_datacube(self): 188 | "Generate a fake datacube for model export." 189 | dummy_datacube = { 190 | "pixels": torch.randn(2, 3, self.img_size, self.img_size), 191 | "time": torch.randn(2, 4), 192 | "latlon": torch.randn(2, 4), 193 | "waves": torch.randn(3), 194 | "gsd": torch.randn(1), 195 | } 196 | dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()} 197 | return dummy_datacube 198 | 199 | def export_to_onnx(self, name): 200 | "Save the model to ONNX format." 201 | 202 | datacube = self.fake_datacube() 203 | export_options = torch.onnx.ExportOptions(dynamic_shapes=True) 204 | 205 | # Export the model to ONNX format 206 | onnx_program = torch.onnx.dynamo_export( 207 | self.eval(), datacube, export_options=export_options 208 | ) 209 | 210 | # Save the exported model 211 | onnx_program.save(f"checkpoints/compiled/{name}") 212 | print(f"Model exported to ONNX format: checkpoints/compiled/{name}") 213 | 214 | return onnx_program 215 | 216 | def export_to_torchep(self, name): 217 | "Save the model to pytorch ExportedProgram format." 
218 | 219 | datacube = self.fake_datacube() 220 | 221 | # dynamic shapes for model export 222 | batch_size = Dim("batch_size", min=2, max=1000) 223 | channel_bands = Dim("channel_bands", min=1, max=10) 224 | dynamic_shapes = { 225 | "datacube": { 226 | "pixels": {0: batch_size, 1: channel_bands}, 227 | "time": {0: batch_size}, 228 | "latlon": {0: batch_size}, 229 | "waves": {0: channel_bands}, 230 | "gsd": {0: None}, 231 | } 232 | } 233 | 234 | # Export the model to pytorch ExportedProgram format 235 | ep = torch.export.export( 236 | self.eval(), 237 | (datacube,), 238 | dynamic_shapes=dynamic_shapes, 239 | strict=True, 240 | ) 241 | 242 | # Save the exported model 243 | torch.export.save(ep, f"checkpoints/compiled/{name}") 244 | print( 245 | f"Model exported to pytorch ExportedProgram format: checkpoints/compiled/{name}" # noqa: E501 246 | ) 247 | 248 | return ep 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = argparse.ArgumentParser(description="Export the Clay model.") 253 | parser.add_argument( 254 | "--img_size", 255 | type=int, 256 | default=256, 257 | help="Image size for the model", 258 | ) 259 | parser.add_argument( 260 | "--ckpt_path", 261 | type=str, 262 | default="checkpoints/clay-v1-base.ckpt", 263 | help="Path to the Clay model checkpoint", 264 | ) 265 | parser.add_argument( 266 | "--device", 267 | type=str, 268 | default="cuda", 269 | help="Device to use for the model", 270 | ) 271 | parser.add_argument( 272 | "--name", 273 | type=str, 274 | default="clay-base.pt", 275 | help="Name of the exported model", 276 | ) 277 | parser.add_argument( 278 | "--onnx", 279 | action="store_true", 280 | help="Export the model to ONNX format", 281 | ) 282 | parser.add_argument( 283 | "--ep", 284 | action="store_true", 285 | help="Export the model to pytorch ExportedProgram format", 286 | ) 287 | 288 | args = parser.parse_args() 289 | 290 | Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) 291 | embedder = Embedder( 292 | img_size=args.img_size, 293 | ckpt_path=args.ckpt_path, 294 | device=args.device, 295 | ) 296 | 297 | if args.onnx: 298 | embedder.export_to_onnx(args.name) 299 | elif args.ep: 300 | embedder.export_to_torchep(args.name) 301 | else: 302 | print("Please specify the format to export the model.") 303 | parser.print_help() 304 | -------------------------------------------------------------------------------- /finetune/regression/biomasters_datamodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | DataModule for the BioMasters dataset for a regression task. 3 | 4 | BioMassters: A Benchmark Dataset for Forest Biomass Estimation using 5 | Multi-modal Satellite Time-series https://nascetti-a.github.io/BioMasster/ 6 | 7 | This implementation provides a structured way to handle the data loading and 8 | preprocessing required for training and validating a regression model. 
9 | 10 | Citation: 11 | 12 | Andrea Nascetti, Ritu Yadav, Kirill Brodt, Qixun Qu, Hongwei Fan, Yuri 13 | Shendryk, Isha Shah, and Christine Chung, BioMassters: A Benchmark Dataset 14 | for Forest Biomass Estimation using Multi-modal Satellite Time-series, 15 | Thirty-seventh Conference on Neural Information Processing Systems Datasets 16 | and Benchmarks Track, 2023, https://openreview.net/forum?id=hrWsIC4Cmz 17 | """ 18 | 19 | from pathlib import Path 20 | 21 | import lightning as L 22 | import numpy as np 23 | import torch 24 | import yaml 25 | from box import Box 26 | from tifffile import imread 27 | from torch.utils.data import DataLoader, Dataset 28 | from torchvision.transforms import v2 29 | 30 | 31 | class BioMastersDataset(Dataset): 32 | """ 33 | Dataset class for the BioMasters regression dataset. 34 | 35 | Assumes band order 36 | vv, vh, vv, vh, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12 37 | 38 | Args: 39 | chip_dir (str): Directory containing the image chips. 40 | label_dir (str): Directory containing the labels. 41 | """ 42 | 43 | def __init__(self, chip_dir, label_dir, metadata): 44 | self.chip_dir = Path(chip_dir) 45 | self.label_dir = Path(label_dir) 46 | self.metadata = metadata 47 | 48 | # Load statistics from Clay metadata 49 | s2_mean = list(metadata["sentinel-2-l2a"].bands.mean.values()) 50 | s2_std = list(metadata["sentinel-2-l2a"].bands.std.values()) 51 | 52 | # Duplicate the S1 statistics so that the asc/desc orbit data 53 | # is handled correctly 54 | self.transform = self.create_transforms( 55 | mean=s2_mean, 56 | std=s2_std, 57 | ) 58 | # Load chip and label file names 59 | self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npz")] 60 | print(f"Found {len(self.chips)} chips to process for {chip_dir}") 61 | 62 | def create_transforms(self, mean, std): 63 | """ 64 | Create normalization transforms. 65 | 66 | Args: 67 | mean (list): Mean values for normalization. 68 | std (list): Standard deviation values for normalization. 69 | 70 | Returns: 71 | torchvision.transforms.Compose: A composition of transforms. 72 | """ 73 | return v2.Compose( 74 | [ 75 | v2.Normalize(mean=mean, std=std), 76 | ], 77 | ) 78 | 79 | def __len__(self): 80 | return len(self.chips) 81 | 82 | def __getitem__(self, idx): 83 | """ 84 | Get a sample from the dataset. 85 | 86 | Args: 87 | idx (int): Index of the sample. 88 | 89 | Returns: 90 | dict: A dictionary containing the image, label, and additional information. 91 | """ 92 | chip_name = self.chip_dir / self.chips[idx] 93 | label_name = self.label_dir / (chip_name.stem.split("_")[-1] + "_agbm.tif") 94 | 95 | chip = np.load(chip_name)["cube"][2:, ...].astype("float32") 96 | label = imread(label_name).astype("float32") 97 | label = np.expand_dims(label, 0) 98 | 99 | sample = { 100 | "pixels": self.transform(torch.from_numpy(chip)), 101 | "label": torch.from_numpy(label), 102 | "time": torch.zeros(4), # Placeholder for time information 103 | "latlon": torch.zeros(4), # Placeholder for latlon information 104 | } 105 | return sample 106 | 107 | 108 | class BioMastersDataModule(L.LightningDataModule): 109 | """ 110 | DataModule class for the Chesapeake Bay dataset. 111 | 112 | Args: 113 | train_chip_dir (str): Directory containing training image chips. 114 | train_label_dir (str): Directory containing training labels. 115 | val_chip_dir (str): Directory containing validation image chips. 116 | val_label_dir (str): Directory containing validation labels. 117 | metadata_path (str): Path to the metadata file. 
118 | batch_size (int): Batch size for data loading. 119 | num_workers (int): Number of workers for data loading. 120 | platform (str): Platform identifier used in metadata. 121 | """ 122 | 123 | def __init__( # noqa: PLR0913 124 | self, 125 | train_chip_dir, 126 | train_label_dir, 127 | val_chip_dir, 128 | val_label_dir, 129 | metadata_path, 130 | batch_size, 131 | num_workers, 132 | ): 133 | super().__init__() 134 | self.train_chip_dir = train_chip_dir 135 | self.train_label_dir = train_label_dir 136 | self.val_chip_dir = val_chip_dir 137 | self.val_label_dir = val_label_dir 138 | self.metadata = Box(yaml.safe_load(open(metadata_path))) 139 | self.batch_size = batch_size 140 | self.num_workers = num_workers 141 | 142 | def setup(self, stage=None): 143 | """ 144 | Setup datasets for training and validation. 145 | 146 | Args: 147 | stage (str): Stage identifier ('fit' or 'test'). 148 | """ 149 | if stage in {"fit", None}: 150 | self.trn_ds = BioMastersDataset( 151 | self.train_chip_dir, 152 | self.train_label_dir, 153 | self.metadata, 154 | ) 155 | self.val_ds = BioMastersDataset( 156 | self.val_chip_dir, 157 | self.val_label_dir, 158 | self.metadata, 159 | ) 160 | 161 | def train_dataloader(self): 162 | """ 163 | Create DataLoader for training data. 164 | 165 | Returns: 166 | DataLoader: DataLoader for training dataset. 167 | """ 168 | return DataLoader( 169 | self.trn_ds, 170 | batch_size=self.batch_size, 171 | shuffle=True, 172 | num_workers=self.num_workers, 173 | ) 174 | 175 | def val_dataloader(self): 176 | """ 177 | Create DataLoader for validation data. 178 | 179 | Returns: 180 | DataLoader: DataLoader for validation dataset. 181 | """ 182 | return DataLoader( 183 | self.val_ds, 184 | batch_size=self.batch_size, 185 | num_workers=self.num_workers, 186 | ) 187 | -------------------------------------------------------------------------------- /finetune/regression/biomasters_model.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, optim 5 | from torchmetrics import MeanSquaredError 6 | 7 | from finetune.regression.factory import Regressor 8 | 9 | 10 | class NoNaNRMSE(nn.Module): 11 | def __init__(self, threshold=400): 12 | super().__init__() 13 | 14 | self.threshold = threshold 15 | 16 | def forward(self, logits, target): 17 | not_nan = target < self.threshold 18 | 19 | # logits = logits.squeeze(1) 20 | diff = logits - target 21 | diff[~not_nan] = 0 22 | diff2 = torch.square(diff) 23 | diff2m = (diff2 / not_nan.sum((-1, -2, -3), keepdim=True)).sum((-1, -2, -3)) 24 | diff2msqrt = torch.sqrt(diff2m) 25 | 26 | rmse = diff2msqrt.mean() 27 | 28 | return rmse 29 | 30 | 31 | class BioMastersClassifier(L.LightningModule): 32 | """ 33 | LightningModule for training and evaluating a regression on the BioMasters 34 | dataset. 35 | 36 | Args: 37 | num_classes (int): Number of classes for classification. 38 | ckpt_path (str): Clay MAE pretrained checkpoint path. 39 | lr (float): Learning rate for the optimizer. 40 | wd (float): Weight decay for the optimizer. 41 | b1 (float): Beta1 parameter for the Adam optimizer. 42 | b2 (float): Beta2 parameter for the Adam optimizer. 
43 | """ 44 | 45 | def __init__(self, ckpt_path, lr, wd, b1, b2): # noqa: PLR0913 46 | super().__init__() 47 | self.save_hyperparameters() 48 | # self.model = Classifier(num_classes=1, ckpt_path=ckpt_path) 49 | self.model = Regressor(num_classes=1, ckpt_path=ckpt_path) 50 | self.loss_fn = NoNaNRMSE() 51 | self.score_fn = MeanSquaredError() 52 | 53 | def forward(self, datacube): 54 | """ 55 | Forward pass through the classifier. 56 | 57 | Args: 58 | datacube (dict): A dictionary containing the input datacube 59 | and meta information like time, latlon, gsd & wavelenths. 60 | 61 | Returns: 62 | torch.Tensor: The output logits from the classifier. 63 | """ 64 | # Wavelengths for S1 and S2 bands of BioMasters dataset 65 | waves = torch.tensor( 66 | [ 67 | # 3.5, # S1 68 | # 4.0, 69 | 0.493, # S2 70 | 0.56, 71 | 0.665, 72 | 0.704, 73 | 0.74, 74 | 0.783, 75 | 0.842, 76 | 0.865, 77 | 1.61, 78 | 2.19, 79 | ] 80 | ) 81 | gsd = torch.tensor(10.0) 82 | 83 | return self.model( 84 | { 85 | "pixels": datacube["pixels"], 86 | "time": datacube["time"], 87 | "latlon": datacube["latlon"], 88 | "gsd": gsd, 89 | "waves": waves, 90 | } 91 | ) 92 | 93 | def configure_optimizers(self): 94 | """ 95 | Configure the optimizer and learning rate scheduler. 96 | 97 | Returns: 98 | dict: A dictionary containing the optimizer and learning rate 99 | scheduler. 100 | """ 101 | optimizer = optim.AdamW( 102 | [ 103 | param 104 | for name, param in self.model.named_parameters() 105 | if param.requires_grad 106 | ], 107 | lr=self.hparams.lr, 108 | weight_decay=self.hparams.wd, 109 | betas=(self.hparams.b1, self.hparams.b2), 110 | ) 111 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) 112 | return { 113 | "optimizer": optimizer, 114 | "lr_scheduler": { 115 | "scheduler": scheduler, 116 | "interval": "epoch", 117 | }, 118 | } 119 | 120 | def shared_step(self, batch, batch_idx, phase): 121 | """ 122 | Perform a shared step for both training and validation. 123 | 124 | Args: 125 | batch (dict): A batch of data. 126 | batch_idx (int): The index of the batch. 127 | phase (str): The phase ('train' or 'val'). 128 | 129 | Returns: 130 | torch.Tensor: The computed loss for the batch. 131 | """ 132 | labels = batch["label"] 133 | logits = self(batch) 134 | logits = F.interpolate( 135 | logits, 136 | size=(256, 256), 137 | mode="bilinear", 138 | align_corners=False, 139 | ) # Resize to match labels size 140 | # print("Logits shape", logits.shape) 141 | # print("Labels shape", labels.shape) 142 | loss = self.loss_fn(logits, labels) 143 | score = self.score_fn(logits, labels) 144 | # Convert to RMSE 145 | score = torch.sqrt(score) 146 | 147 | self.log( 148 | f"{phase}/loss", 149 | loss, 150 | on_step=True, 151 | on_epoch=True, 152 | prog_bar=True, 153 | logger=True, 154 | sync_dist=True, 155 | ) 156 | self.log( 157 | f"{phase}/score", 158 | score, 159 | on_step=True, 160 | on_epoch=True, 161 | prog_bar=True, 162 | logger=True, 163 | sync_dist=True, 164 | ) 165 | return loss 166 | 167 | def training_step(self, batch, batch_idx): 168 | """ 169 | Perform a training step. 170 | 171 | Args: 172 | batch (dict): A batch of training data. 173 | batch_idx (int): The index of the batch. 174 | 175 | Returns: 176 | torch.Tensor: The computed loss for the batch. 177 | """ 178 | return self.shared_step(batch, batch_idx, "train") 179 | 180 | def validation_step(self, batch, batch_idx): 181 | """ 182 | Perform a validation step. 183 | 184 | Args: 185 | batch (dict): A batch of validation data. 
186 | batch_idx (int): The index of the batch. 187 | 188 | Returns: 189 | torch.Tensor: The computed loss for the batch. 190 | """ 191 | return self.shared_step(batch, batch_idx, "val") 192 | -------------------------------------------------------------------------------- /finetune/regression/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clay Regressor for semantic regression tasks using PixelShuffle. 3 | 4 | Attribution: 5 | Decoder inspired by PixelShuffle-based upsampling. 6 | """ 7 | 8 | import re 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from einops import rearrange, repeat 13 | from torch import nn 14 | 15 | from src.model import Encoder 16 | 17 | 18 | class RegressionEncoder(Encoder): 19 | """ 20 | Encoder class for regression tasks. 21 | 22 | Attributes: 23 | ckpt_path (str): Path to the clay checkpoint file. 24 | """ 25 | 26 | def __init__( # noqa: PLR0913 27 | self, 28 | mask_ratio, 29 | patch_size, 30 | shuffle, 31 | dim, 32 | depth, 33 | heads, 34 | dim_head, 35 | mlp_ratio, 36 | ckpt_path=None, 37 | ): 38 | super().__init__( 39 | mask_ratio, 40 | patch_size, 41 | shuffle, 42 | dim, 43 | depth, 44 | heads, 45 | dim_head, 46 | mlp_ratio, 47 | ) 48 | # Set device 49 | self.device = ( 50 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 51 | ) 52 | # Load model from checkpoint if provided 53 | self.load_from_ckpt(ckpt_path) 54 | 55 | def load_from_ckpt(self, ckpt_path): 56 | """ 57 | Load the model's state from a checkpoint file. 58 | 59 | Args: 60 | ckpt_path (str): The path to the checkpoint file. 61 | """ 62 | if ckpt_path: 63 | # Load checkpoint 64 | ckpt = torch.load(ckpt_path, map_location=self.device) 65 | state_dict = ckpt.get("state_dict") 66 | 67 | # Prepare new state dict with the desired subset and naming 68 | new_state_dict = { 69 | re.sub(r"^model\.encoder\.", "", name): param 70 | for name, param in state_dict.items() 71 | if name.startswith("model.encoder") 72 | } 73 | 74 | # Load the modified state dict into the model 75 | model_state_dict = self.state_dict() 76 | for name, param in new_state_dict.items(): 77 | if ( 78 | name in model_state_dict 79 | and param.size() == model_state_dict[name].size() 80 | ): 81 | model_state_dict[name].copy_(param) 82 | else: 83 | print(f"No matching parameter for {name} with size {param.size()}") 84 | 85 | # Freeze the loaded parameters 86 | for name, param in self.named_parameters(): 87 | if name in new_state_dict: 88 | param.requires_grad = False 89 | 90 | def forward(self, datacube): 91 | """ 92 | Forward pass of the RegressionEncoder. 93 | 94 | Args: 95 | datacube (dict): A dictionary containing the input datacube and 96 | meta information like time, latlon, gsd & wavelenths. 97 | 98 | Returns: 99 | torch.Tensor: The embeddings from the final layer. 
100 | """ 101 | cube, time, latlon, gsd, waves = ( 102 | datacube["pixels"], # [B C H W] 103 | datacube["time"], # [B 2] 104 | datacube["latlon"], # [B 2] 105 | datacube["gsd"], # 1 106 | datacube["waves"], # [N] 107 | ) 108 | 109 | B, C, H, W = cube.shape 110 | 111 | # Patchify and create embeddings per patch 112 | patches, waves_encoded = self.to_patch_embed(cube, waves) # [B L D] 113 | patches = self.add_encodings(patches, time, latlon, gsd) # [B L D] 114 | 115 | # Add class tokens 116 | cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] 117 | patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] 118 | 119 | # Transformer encoder 120 | patches = self.transformer(patches) 121 | 122 | # Remove class token 123 | patches = patches[:, 1:, :] # [B, L, D] 124 | 125 | return patches 126 | 127 | 128 | class Regressor(nn.Module): 129 | """ 130 | Clay Regressor class that combines the Encoder with PixelShuffle for regression. 131 | 132 | Attributes: 133 | num_classes (int): Number of output classes for regression. 134 | ckpt_path (str): Path to the checkpoint file. 135 | """ 136 | 137 | def __init__(self, num_classes, ckpt_path): 138 | super().__init__() 139 | # Initialize the encoder 140 | self.encoder = RegressionEncoder( 141 | mask_ratio=0.0, 142 | patch_size=8, 143 | shuffle=False, 144 | dim=1024, 145 | depth=24, 146 | heads=16, 147 | dim_head=64, 148 | mlp_ratio=4.0, 149 | ckpt_path=ckpt_path, 150 | ) 151 | 152 | # Freeze the encoder parameters 153 | for param in self.encoder.parameters(): 154 | param.requires_grad = False 155 | 156 | # Define layers after the encoder 157 | D = self.encoder.dim # embedding dimension 158 | hidden_dim = 512 159 | C_out = 64 160 | r = self.encoder.patch_size # upscale factor (patch_size) 161 | 162 | self.conv1 = nn.Conv2d(D, hidden_dim, kernel_size=3, padding=1) 163 | self.bn1 = nn.BatchNorm2d(hidden_dim) 164 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1) 165 | self.bn2 = nn.BatchNorm2d(hidden_dim) 166 | self.conv_ps = nn.Conv2d(hidden_dim, C_out * r * r, kernel_size=3, padding=1) 167 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor=r) 168 | self.conv_out = nn.Conv2d(C_out, num_classes, kernel_size=3, padding=1) 169 | 170 | def forward(self, datacube): 171 | """ 172 | Forward pass of the Regressor. 173 | 174 | Args: 175 | datacube (dict): A dictionary containing the input datacube and 176 | meta information like time, latlon, gsd & wavelenths. 177 | 178 | Returns: 179 | torch.Tensor: The regression output. 
180 | """ 181 | cube = datacube["pixels"] # [B C H_in W_in] 182 | B, C, H_in, W_in = cube.shape 183 | 184 | # Get embeddings from the encoder 185 | patches = self.encoder(datacube) # [B, L, D] 186 | 187 | # Reshape embeddings to [B, D, H', W'] 188 | H_patches = H_in // self.encoder.patch_size 189 | W_patches = W_in // self.encoder.patch_size 190 | x = rearrange(patches, "B (H W) D -> B D H W", H=H_patches, W=W_patches) 191 | 192 | # Pass through convolutional layers 193 | x = F.relu(self.bn1(self.conv1(x))) 194 | x = F.relu(self.bn2(self.conv2(x))) 195 | x = self.conv_ps(x) # [B, C_out * r^2, H', W'] 196 | 197 | # Upsample using PixelShuffle 198 | x = self.pixel_shuffle(x) # [B, C_out, H_in, W_in] 199 | 200 | # Final convolution to get desired output channels 201 | x = self.conv_out(x) # [B, num_outputs, H_in, W_in] 202 | 203 | return x 204 | -------------------------------------------------------------------------------- /finetune/regression/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from multiprocessing import Pool 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import click 7 | import numpy as np 8 | from tifffile import imread 9 | 10 | EXPECTED_NR_OF_FILES_PER_TILE = 24 11 | MONTHS = [ 12 | "00", 13 | "01", 14 | "02", 15 | "03", 16 | "04", 17 | "05", 18 | "06", 19 | "07", 20 | "08", 21 | "09", 22 | "10", 23 | "11", 24 | ] 25 | 26 | 27 | def list_uniqe_ids(src: Path) -> List[str]: 28 | ids = list(set(dat.name.split("_")[0] for dat in src.glob("*.tif"))) 29 | print(f"Found {len(ids)} unique tile ids") 30 | return ids 31 | 32 | 33 | def process_data_for_id( 34 | id: str, feature_path: Path, cubes_path: Path, overwrite: bool 35 | ) -> None: 36 | if not overwrite and (cubes_path / f"biomasters_cube_{id}.npz").exists(): 37 | print(f"Found existing file for {id}, skipping.") 38 | return 39 | data = [] 40 | for month in MONTHS: 41 | data_month = [] 42 | for platform in ["S1", "S2"]: 43 | feature_name = f"{id}_{platform}_{month}.tif" 44 | if not Path(feature_path / feature_name).exists(): 45 | continue 46 | file_data = ( 47 | imread(feature_path / feature_name).swapaxes(1, 2).swapaxes(0, 1) 48 | ) 49 | ND1 = 0 50 | ND2 = -9999 51 | if platform == "S1": 52 | # Limit to first orbit (the other is mostly nodata) 53 | file_data = file_data[:2] 54 | file_data = np.ma.array( 55 | file_data, mask=np.logical_or(file_data == ND1, file_data == ND2) 56 | ) 57 | else: 58 | file_data = file_data[:10] 59 | file_data = np.ma.array(file_data, mask=file_data == ND1) 60 | data_month.append(file_data) 61 | 62 | data_month = np.ma.vstack(data_month) 63 | NR_OF_BANDS_EXPECTED = 12 64 | if data_month.shape[0] != NR_OF_BANDS_EXPECTED: 65 | continue 66 | data.append(data_month) 67 | 68 | cube = np.ma.array(data) 69 | mean_cube = np.ma.mean(cube, axis=0) 70 | 71 | if np.sum(mean_cube.mask): 72 | print("Nodata", np.sum(mean_cube.mask)) 73 | NODATA_THRESHOLD = 1e5 74 | if np.sum(mean_cube.mask) > NODATA_THRESHOLD: 75 | print("Skipping due to lots of nodata") 76 | return 77 | 78 | np.savez_compressed(cubes_path / f"biomasters_cube_{id}.npz", cube=mean_cube) 79 | 80 | 81 | @click.command() 82 | @click.option( 83 | "--features", 84 | help="Folder with features (training or test)", 85 | type=click.Path(path_type=Path), 86 | ) 87 | @click.option( 88 | "--cubes", help="Folder to write the datacubes", type=click.Path(path_type=Path) 89 | ) 90 | @click.option( 91 | "--processes", 92 | default=1, 93 | help="How many processes to use 
for parallel processing", 94 | type=click.INT, 95 | ) 96 | @click.option( 97 | "--sample", 98 | default=0.05, 99 | help="Fraction of original data to sample", 100 | type=click.FloatRange(0, 1), 101 | ) 102 | @click.option( 103 | "--overwrite", 104 | is_flag=True, 105 | help="Overwrite existing cubes", 106 | ) 107 | def process(features, cubes, processes, sample, overwrite): 108 | """ 109 | Combine tiff files into npz datacubes. 110 | 111 | The datacubes will contain the S1 vv/vh bands for asc and desc orbits, 112 | stacked with the first 10 S2 bands. 113 | """ 114 | ids = list_uniqe_ids(features) 115 | 116 | if sample < 1: 117 | sample_length = int(len(ids) * sample) 118 | random.seed(42) 119 | ids = random.sample(ids, sample_length) 120 | 121 | print(f"Subsampled {len(ids)} tiles") 122 | 123 | if processes > 1: 124 | features = [features] * len(ids) 125 | cubes = [cubes] * len(ids) 126 | overwrite = [overwrite] * len(ids) 127 | with Pool(processes) as pl: 128 | pl.starmap(process_data_for_id, zip(ids, features, cubes, overwrite)) 129 | else: 130 | for id in ids: 131 | process_data_for_id(id, features, cubes, overwrite) 132 | 133 | 134 | if __name__ == "__main__": 135 | process() 136 | -------------------------------------------------------------------------------- /finetune/regression/regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface to run the neural network model! 3 | 4 | From the project root directory, do: 5 | 6 | python regression.py fit --config configs/regression_biomasters.yaml 7 | 8 | References: 9 | - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html 10 | - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 11 | """ 12 | 13 | from lightning.pytorch.cli import LightningCLI 14 | 15 | from finetune.regression.biomasters_datamodule import BioMastersDataModule # noqa: F401 16 | from finetune.regression.biomasters_model import BioMastersClassifier # noqa: F401 17 | 18 | 19 | # %% 20 | def cli_main(): 21 | """ 22 | Command-line inteface to run Regression with BioMastersDataModule. 23 | """ 24 | cli = LightningCLI( 25 | BioMastersClassifier, 26 | BioMastersDataModule, 27 | save_config_kwargs={"overwrite": True}, 28 | ) 29 | return cli 30 | 31 | 32 | # %% 33 | if __name__ == "__main__": 34 | cli_main() 35 | 36 | print("Done!") 37 | -------------------------------------------------------------------------------- /finetune/segment/chesapeake_datamodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | DataModule for the Chesapeake Bay dataset for segmentation tasks. 3 | 4 | This implementation provides a structured way to handle the data loading and 5 | preprocessing required for training and validating a segmentation model. 6 | 7 | Dataset citation: 8 | Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N. 9 | Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data. 10 | Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition 11 | (CVPR 2019). 
12 | 13 | Dataset URL: https://lila.science/datasets/chesapeakelandcover 14 | """ 15 | 16 | import re 17 | from pathlib import Path 18 | 19 | import lightning as L 20 | import numpy as np 21 | import torch 22 | import yaml 23 | from box import Box 24 | from torch.utils.data import DataLoader, Dataset 25 | from torchvision.transforms import v2 26 | 27 | 28 | class ChesapeakeDataset(Dataset): 29 | """ 30 | Dataset class for the Chesapeake Bay segmentation dataset. 31 | 32 | Args: 33 | chip_dir (str): Directory containing the image chips. 34 | label_dir (str): Directory containing the labels. 35 | metadata (Box): Metadata for normalization and other dataset-specific details. 36 | platform (str): Platform identifier used in metadata. 37 | """ 38 | 39 | def __init__(self, chip_dir, label_dir, metadata, platform): 40 | self.chip_dir = Path(chip_dir) 41 | self.label_dir = Path(label_dir) 42 | self.metadata = metadata 43 | self.transform = self.create_transforms( 44 | mean=list(metadata[platform].bands.mean.values()), 45 | std=list(metadata[platform].bands.std.values()), 46 | ) 47 | 48 | # Load chip and label file names 49 | self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")][ 50 | :1000 51 | ] 52 | self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips] 53 | 54 | def create_transforms(self, mean, std): 55 | """ 56 | Create normalization transforms. 57 | 58 | Args: 59 | mean (list): Mean values for normalization. 60 | std (list): Standard deviation values for normalization. 61 | 62 | Returns: 63 | torchvision.transforms.Compose: A composition of transforms. 64 | """ 65 | return v2.Compose( 66 | [ 67 | v2.Normalize(mean=mean, std=std), 68 | ], 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.chips) 73 | 74 | def __getitem__(self, idx): 75 | """ 76 | Get a sample from the dataset. 77 | 78 | Args: 79 | idx (int): Index of the sample. 80 | 81 | Returns: 82 | dict: A dictionary containing the image, label, and additional information. 83 | """ 84 | chip_name = self.chip_dir / self.chips[idx] 85 | label_name = self.label_dir / self.labels[idx] 86 | 87 | chip = np.load(chip_name).astype(np.float32) 88 | label = np.load(label_name) 89 | 90 | # Remap labels to match desired classes 91 | label_mapping = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6} 92 | remapped_label = np.vectorize(label_mapping.get)(label) 93 | 94 | sample = { 95 | "pixels": self.transform(torch.from_numpy(chip)), 96 | "label": torch.from_numpy(remapped_label[0]), 97 | "time": torch.zeros(4), # Placeholder for time information 98 | "latlon": torch.zeros(4), # Placeholder for latlon information 99 | } 100 | return sample 101 | 102 | 103 | class ChesapeakeDataModule(L.LightningDataModule): 104 | """ 105 | DataModule class for the Chesapeake Bay dataset. 106 | 107 | Args: 108 | train_chip_dir (str): Directory containing training image chips. 109 | train_label_dir (str): Directory containing training labels. 110 | val_chip_dir (str): Directory containing validation image chips. 111 | val_label_dir (str): Directory containing validation labels. 112 | metadata_path (str): Path to the metadata file. 113 | batch_size (int): Batch size for data loading. 114 | num_workers (int): Number of workers for data loading. 115 | platform (str): Platform identifier used in metadata. 
116 | """ 117 | 118 | def __init__( # noqa: PLR0913 119 | self, 120 | train_chip_dir, 121 | train_label_dir, 122 | val_chip_dir, 123 | val_label_dir, 124 | metadata_path, 125 | batch_size, 126 | num_workers, 127 | platform, 128 | ): 129 | super().__init__() 130 | self.train_chip_dir = train_chip_dir 131 | self.train_label_dir = train_label_dir 132 | self.val_chip_dir = val_chip_dir 133 | self.val_label_dir = val_label_dir 134 | self.metadata = Box(yaml.safe_load(open(metadata_path))) 135 | self.batch_size = batch_size 136 | self.num_workers = num_workers 137 | self.platform = platform 138 | 139 | def setup(self, stage=None): 140 | """ 141 | Setup datasets for training and validation. 142 | 143 | Args: 144 | stage (str): Stage identifier ('fit' or 'test'). 145 | """ 146 | if stage in {"fit", None}: 147 | self.trn_ds = ChesapeakeDataset( 148 | self.train_chip_dir, 149 | self.train_label_dir, 150 | self.metadata, 151 | self.platform, 152 | ) 153 | self.val_ds = ChesapeakeDataset( 154 | self.val_chip_dir, 155 | self.val_label_dir, 156 | self.metadata, 157 | self.platform, 158 | ) 159 | 160 | def train_dataloader(self): 161 | """ 162 | Create DataLoader for training data. 163 | 164 | Returns: 165 | DataLoader: DataLoader for training dataset. 166 | """ 167 | return DataLoader( 168 | self.trn_ds, 169 | batch_size=self.batch_size, 170 | shuffle=True, 171 | num_workers=self.num_workers, 172 | ) 173 | 174 | def val_dataloader(self): 175 | """ 176 | Create DataLoader for validation data. 177 | 178 | Returns: 179 | DataLoader: DataLoader for validation dataset. 180 | """ 181 | return DataLoader( 182 | self.val_ds, 183 | batch_size=self.batch_size, 184 | num_workers=self.num_workers, 185 | ) 186 | -------------------------------------------------------------------------------- /finetune/segment/chesapeake_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | LightningModule for training and validating a segmentation model using the 3 | Segmentor class. 4 | """ 5 | 6 | import lightning as L 7 | import segmentation_models_pytorch as smp 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import optim 11 | from torchmetrics.classification import F1Score, MulticlassJaccardIndex 12 | 13 | from finetune.segment.factory import Segmentor 14 | 15 | 16 | class ChesapeakeSegmentor(L.LightningModule): 17 | """ 18 | LightningModule for segmentation tasks, utilizing Clay Segmentor. 19 | 20 | Attributes: 21 | model (nn.Module): Clay Segmentor model. 22 | loss_fn (nn.Module): The loss function. 23 | iou (Metric): Intersection over Union metric. 24 | f1 (Metric): F1 Score metric. 25 | lr (float): Learning rate. 26 | """ 27 | 28 | def __init__( # # noqa: PLR0913 29 | self, 30 | num_classes, 31 | ckpt_path, 32 | lr, 33 | wd, 34 | b1, 35 | b2, 36 | ): 37 | super().__init__() 38 | self.save_hyperparameters() # Save hyperparameters for checkpointing 39 | self.model = Segmentor( 40 | num_classes=num_classes, 41 | ckpt_path=ckpt_path, 42 | ) 43 | 44 | self.loss_fn = smp.losses.FocalLoss(mode="multiclass") 45 | self.iou = MulticlassJaccardIndex( 46 | num_classes=num_classes, 47 | average="weighted", 48 | ) 49 | self.f1 = F1Score( 50 | task="multiclass", 51 | num_classes=num_classes, 52 | average="weighted", 53 | ) 54 | 55 | def forward(self, datacube): 56 | """ 57 | Forward pass through the segmentation model. 58 | 59 | Args: 60 | datacube (dict): A dictionary containing the input datacube and 61 | meta information like time, latlon, gsd & wavelenths. 
62 | 63 | Returns: 64 | torch.Tensor: The segmentation logits. 65 | """ 66 | waves = torch.tensor([0.65, 0.56, 0.48, 0.842]) # NAIP wavelengths 67 | gsd = torch.tensor(1.0) # NAIP GSD 68 | 69 | # Forward pass through the network 70 | return self.model( 71 | { 72 | "pixels": datacube["pixels"], 73 | "time": datacube["time"], 74 | "latlon": datacube["latlon"], 75 | "gsd": gsd, 76 | "waves": waves, 77 | }, 78 | ) 79 | 80 | def configure_optimizers(self): 81 | """ 82 | Configure the optimizer and learning rate scheduler. 83 | 84 | Returns: 85 | dict: A dictionary containing the optimizer and scheduler 86 | configuration. 87 | """ 88 | optimizer = optim.AdamW( 89 | [ 90 | param 91 | for name, param in self.model.named_parameters() 92 | if param.requires_grad 93 | ], 94 | lr=self.hparams.lr, 95 | weight_decay=self.hparams.wd, 96 | betas=(self.hparams.b1, self.hparams.b2), 97 | ) 98 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 99 | optimizer, 100 | T_0=100, 101 | T_mult=1, 102 | eta_min=self.hparams.lr * 100, 103 | last_epoch=-1, 104 | ) 105 | return { 106 | "optimizer": optimizer, 107 | "lr_scheduler": { 108 | "scheduler": scheduler, 109 | "interval": "step", 110 | }, 111 | } 112 | 113 | def shared_step(self, batch, batch_idx, phase): 114 | """ 115 | Shared step for training and validation. 116 | 117 | Args: 118 | batch (dict): A dictionary containing the batch data. 119 | batch_idx (int): The index of the batch. 120 | phase (str): The phase (train or val). 121 | 122 | Returns: 123 | torch.Tensor: The loss value. 124 | """ 125 | labels = batch["label"].long() 126 | outputs = self(batch) 127 | outputs = F.interpolate( 128 | outputs, 129 | size=(224, 224), 130 | mode="bilinear", 131 | align_corners=False, 132 | ) # Resize to match labels size 133 | 134 | loss = self.loss_fn(outputs, labels) 135 | iou = self.iou(outputs, labels) 136 | f1 = self.f1(outputs, labels) 137 | 138 | # Log metrics 139 | self.log( 140 | f"{phase}/loss", 141 | loss, 142 | on_step=True, 143 | on_epoch=True, 144 | prog_bar=True, 145 | logger=True, 146 | sync_dist=True, 147 | ) 148 | self.log( 149 | f"{phase}/iou", 150 | iou, 151 | on_step=True, 152 | on_epoch=True, 153 | prog_bar=True, 154 | logger=True, 155 | sync_dist=True, 156 | ) 157 | self.log( 158 | f"{phase}/f1", 159 | f1, 160 | on_step=True, 161 | on_epoch=True, 162 | prog_bar=True, 163 | logger=True, 164 | sync_dist=True, 165 | ) 166 | return loss 167 | 168 | def training_step(self, batch, batch_idx): 169 | """ 170 | Training step for the model. 171 | 172 | Args: 173 | batch (dict): A dictionary containing the batch data. 174 | batch_idx (int): The index of the batch. 175 | 176 | Returns: 177 | torch.Tensor: The loss value. 178 | """ 179 | return self.shared_step(batch, batch_idx, "train") 180 | 181 | def validation_step(self, batch, batch_idx): 182 | """ 183 | Validation step for the model. 184 | 185 | Args: 186 | batch (dict): A dictionary containing the batch data. 187 | batch_idx (int): The index of the batch. 188 | 189 | Returns: 190 | torch.Tensor: The loss value. 191 | """ 192 | return self.shared_step(batch, batch_idx, "val") 193 | -------------------------------------------------------------------------------- /finetune/segment/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clay Segmentor for semantic segmentation tasks. 
3 | 4 | Attribution: 5 | Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation 6 | with Transformers 7 | Paper URL: https://arxiv.org/abs/2105.15203 8 | """ 9 | 10 | import re 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from einops import rearrange, repeat 15 | from torch import nn 16 | 17 | from src.model import Encoder 18 | 19 | 20 | class SegmentEncoder(Encoder): 21 | """ 22 | Encoder class for segmentation tasks, incorporating a feature pyramid 23 | network (FPN). 24 | 25 | Attributes: 26 | feature_maps (list): Indices of layers to be used for generating 27 | feature maps. 28 | ckpt_path (str): Path to the clay checkpoint file. 29 | """ 30 | 31 | def __init__( # noqa: PLR0913 32 | self, 33 | mask_ratio, 34 | patch_size, 35 | shuffle, 36 | dim, 37 | depth, 38 | heads, 39 | dim_head, 40 | mlp_ratio, 41 | ckpt_path=None, 42 | ): 43 | super().__init__( 44 | mask_ratio, 45 | patch_size, 46 | shuffle, 47 | dim, 48 | depth, 49 | heads, 50 | dim_head, 51 | mlp_ratio, 52 | ) 53 | 54 | # Set device 55 | self.device = ( 56 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 57 | ) 58 | # Load model from checkpoint if provided 59 | self.load_from_ckpt(ckpt_path) 60 | 61 | def load_from_ckpt(self, ckpt_path): 62 | """ 63 | Load the model's state from a checkpoint file. 64 | 65 | Args: 66 | ckpt_path (str): The path to the checkpoint file. 67 | """ 68 | if ckpt_path: 69 | # Load checkpoint 70 | ckpt = torch.load(ckpt_path, map_location=self.device) 71 | state_dict = ckpt.get("state_dict") 72 | 73 | # Prepare new state dict with the desired subset and naming 74 | new_state_dict = { 75 | re.sub(r"^model\.encoder\.", "", name): param 76 | for name, param in state_dict.items() 77 | if name.startswith("model.encoder") 78 | } 79 | 80 | # Load the modified state dict into the model 81 | model_state_dict = self.state_dict() 82 | for name, param in new_state_dict.items(): 83 | if ( 84 | name in model_state_dict 85 | and param.size() == model_state_dict[name].size() 86 | ): 87 | model_state_dict[name].copy_(param) 88 | else: 89 | print(f"No matching parameter for {name} with size {param.size()}") 90 | 91 | # Freeze the loaded parameters 92 | for name, param in self.named_parameters(): 93 | if name in new_state_dict: 94 | param.requires_grad = False 95 | 96 | def forward(self, datacube): 97 | """ 98 | Forward pass of the SegmentEncoder. 99 | 100 | Args: 101 | datacube (dict): A dictionary containing the input datacube and 102 | meta information like time, latlon, gsd & wavelenths. 103 | 104 | Returns: 105 | list: A list of feature maps extracted from the datacube. 
106 | """ 107 | cube, time, latlon, gsd, waves = ( 108 | datacube["pixels"], # [B C H W] 109 | datacube["time"], # [B 2] 110 | datacube["latlon"], # [B 2] 111 | datacube["gsd"], # 1 112 | datacube["waves"], # [N] 113 | ) 114 | 115 | B, C, H, W = cube.shape 116 | 117 | # Patchify and create embeddings per patch 118 | patches, waves_encoded = self.to_patch_embed(cube, waves) # [B L D] 119 | patches = self.add_encodings(patches, time, latlon, gsd) # [B L D] 120 | 121 | # Add class tokens 122 | cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] 123 | patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] 124 | 125 | patches = self.transformer(patches) 126 | patches = patches[:, 1:, :] # [B L D] 127 | 128 | return patches 129 | 130 | 131 | class Segmentor(nn.Module): 132 | """ 133 | Clay Segmentor class that combines the Encoder with FPN layers for semantic 134 | segmentation. 135 | 136 | Attributes: 137 | num_classes (int): Number of output classes for segmentation. 138 | feature_maps (list): Indices of layers to be used for generating feature maps. 139 | ckpt_path (str): Path to the checkpoint file. 140 | """ 141 | 142 | def __init__(self, num_classes, ckpt_path): 143 | super().__init__() 144 | # Default values are for the clay mae base model. 145 | self.encoder = SegmentEncoder( 146 | mask_ratio=0.0, 147 | patch_size=8, 148 | shuffle=False, 149 | dim=1024, 150 | depth=24, 151 | heads=16, 152 | dim_head=64, 153 | mlp_ratio=4.0, 154 | ckpt_path=ckpt_path, 155 | ) 156 | 157 | # Freeze the encoder parameters 158 | for param in self.encoder.parameters(): 159 | param.requires_grad = False 160 | 161 | # Define layers after the encoder 162 | D = self.encoder.dim # embedding dimension 163 | hidden_dim = 512 164 | C_out = 64 165 | r = self.encoder.patch_size # upscale factor (patch_size) 166 | 167 | self.conv1 = nn.Conv2d(D, hidden_dim, kernel_size=3, padding=1) 168 | self.bn1 = nn.BatchNorm2d(hidden_dim) 169 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1) 170 | self.bn2 = nn.BatchNorm2d(hidden_dim) 171 | self.conv_ps = nn.Conv2d(hidden_dim, C_out * r * r, kernel_size=3, padding=1) 172 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor=r) 173 | self.conv_out = nn.Conv2d(C_out, num_classes, kernel_size=3, padding=1) 174 | 175 | def forward(self, datacube): 176 | """ 177 | Forward pass of the Segmentor. 178 | 179 | Args: 180 | datacube (dict): A dictionary containing the input datacube and 181 | meta information like time, latlon, gsd & wavelenths. 182 | 183 | Returns: 184 | torch.Tensor: The segmentation logits. 
185 | """ 186 | cube = datacube["pixels"] # [B C H_in W_in] 187 | B, C, H_in, W_in = cube.shape 188 | 189 | # Get embeddings from the encoder 190 | patches = self.encoder(datacube) # [B, L, D] 191 | 192 | # Reshape embeddings to [B, D, H', W'] 193 | H_patches = H_in // self.encoder.patch_size 194 | W_patches = W_in // self.encoder.patch_size 195 | x = rearrange(patches, "B (H W) D -> B D H W", H=H_patches, W=W_patches) 196 | 197 | # Pass through convolutional layers 198 | x = F.relu(self.bn1(self.conv1(x))) 199 | x = F.relu(self.bn2(self.conv2(x))) 200 | x = self.conv_ps(x) # [B, C_out * r^2, H', W'] 201 | 202 | # Upsample using PixelShuffle 203 | x = self.pixel_shuffle(x) # [B, C_out, H_in, W_in] 204 | 205 | # Final convolution to get desired output channels 206 | x = self.conv_out(x) # [B, num_outputs, H_in, W_in] 207 | 208 | return x 209 | -------------------------------------------------------------------------------- /finetune/segment/preprocess_data.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Chesapeake CVPR Data Processing Script 3 | ====================================== 4 | 5 | This script processes GeoTIFF files from the Chesapeake CVPR dataset to create 6 | image chips for segmentation tasks. 7 | 8 | Dataset Source: 9 | --------------- 10 | Chesapeake CVPR data from LILA: 11 | https://lila.science/datasets/chesapeakelandcover 12 | 13 | For this experiment, we will use images from NY. 14 | 15 | Notes: 16 | ------ 17 | 1. Only copy *_lc.tif & *_naip-new.tif files that we will use for our 18 | segmentation downstream task. 19 | Using s5cmd for this: https://github.com/peak/s5cmd 20 | - Train: 21 | s5cmd cp \ 22 | --no-sign-request \ 23 | --include "*_lc.tif" \ 24 | --include "*_naip-new.tif" \ 25 | "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \ 26 | data/cvpr/files/train/ 27 | - Val: 28 | s5cmd cp \ 29 | --no-sign-request \ 30 | --include "*_lc.tif" \ 31 | --include "*_naip-new.tif" \ 32 | "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \ 33 | data/cvpr/files/val/ 34 | 35 | 2. We will create chips of size `224 x 224` to feed them to the model, feel 36 | free to experiment with other chip sizes as well. 37 | Run the script as follows: 38 | python preprocess_data.py 39 | 40 | Example: 41 | python preprocess_data.py data/cvpr/files data/cvpr/ny 224 42 | """ # noqa E501 43 | 44 | import os 45 | import sys 46 | from pathlib import Path 47 | 48 | import numpy as np 49 | import rasterio as rio 50 | 51 | 52 | def read_and_chip(file_path, chip_size, output_dir): 53 | """ 54 | Reads a GeoTIFF file, creates chips of specified size, and saves them as 55 | numpy arrays. 56 | 57 | Args: 58 | file_path (str or Path): Path to the GeoTIFF file. 59 | chip_size (int): Size of the square chips. 60 | output_dir (str or Path): Directory to save the chips. 
61 | """ 62 | os.makedirs(output_dir, exist_ok=True) 63 | 64 | with rio.open(file_path) as src: 65 | data = src.read() 66 | 67 | n_chips_x = src.width // chip_size 68 | n_chips_y = src.height // chip_size 69 | 70 | chip_number = 0 71 | for i in range(n_chips_x): 72 | for j in range(n_chips_y): 73 | x1, y1 = i * chip_size, j * chip_size 74 | x2, y2 = x1 + chip_size, y1 + chip_size 75 | 76 | chip = data[:, y1:y2, x1:x2] 77 | chip_path = os.path.join( 78 | output_dir, 79 | f"{Path(file_path).stem}_chip_{chip_number}.npy", 80 | ) 81 | np.save(chip_path, chip) 82 | chip_number += 1 83 | 84 | 85 | def process_files(file_paths, output_dir, chip_size): 86 | """ 87 | Processes a list of files, creating chips and saving them. 88 | 89 | Args: 90 | file_paths (list of Path): List of paths to the GeoTIFF files. 91 | output_dir (str or Path): Directory to save the chips. 92 | chip_size (int): Size of the square chips. 93 | """ 94 | for file_path in file_paths: 95 | print(f"Processing: {file_path}") 96 | read_and_chip(file_path, chip_size, output_dir) 97 | 98 | 99 | def main(): 100 | """ 101 | Main function to process files and create chips. 102 | Expects three command line arguments: 103 | - data_dir: Directory containing the input GeoTIFF files. 104 | - output_dir: Directory to save the output chips. 105 | - chip_size: Size of the square chips. 106 | """ 107 | if len(sys.argv) != 4: # noqa: PLR2004 108 | print("Usage: python script.py ") 109 | sys.exit(1) 110 | 111 | data_dir = Path(sys.argv[1]) 112 | output_dir = Path(sys.argv[2]) 113 | chip_size = int(sys.argv[3]) 114 | 115 | train_image_paths = list((data_dir / "train").glob("*_naip-new.tif")) 116 | val_image_paths = list((data_dir / "val").glob("*_naip-new.tif")) 117 | train_label_paths = list((data_dir / "train").glob("*_lc.tif")) 118 | val_label_paths = list((data_dir / "val").glob("*_lc.tif")) 119 | 120 | process_files(train_image_paths, output_dir / "train/chips", chip_size) 121 | process_files(val_image_paths, output_dir / "val/chips", chip_size) 122 | process_files(train_label_paths, output_dir / "train/labels", chip_size) 123 | process_files(val_label_paths, output_dir / "val/labels", chip_size) 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /finetune/segment/segment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface to run the neural network model! 3 | 4 | From the project root directory, do: 5 | 6 | python segment.py fit --config configs/segment_chesapeake.yaml 7 | 8 | References: 9 | - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html 10 | - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 11 | """ 12 | 13 | from lightning.pytorch.cli import LightningCLI 14 | 15 | from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule # noqa: F401 16 | from finetune.segment.chesapeake_model import ChesapeakeSegmentor # noqa: F401 17 | 18 | 19 | # %% 20 | def cli_main(): 21 | """ 22 | Command-line inteface to run Segmentation Model with ChesapeakeDataModule. 
23 | """ 24 | cli = LightningCLI( 25 | ChesapeakeSegmentor, 26 | ChesapeakeDataModule, 27 | save_config_kwargs={"overwrite": True}, 28 | ) 29 | return cli 30 | 31 | 32 | # %% 33 | if __name__ == "__main__": 34 | cli_main() 35 | 36 | print("Done!") 37 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | [format] 2 | # https://docs.astral.sh/ruff/settings/#format 3 | line-ending = "lf" # Use UNIX `\n` line endings for all files 4 | 5 | [lint] 6 | # https://docs.astral.sh/ruff/rules/ 7 | select = [ 8 | "E", # pycodestyle errors 9 | "F", # pyflakes 10 | "I", # isort 11 | "NPY", # numpy 12 | "PL", # pylint 13 | "UP", # pyupgrade 14 | "W", # pycodestyle warnings 15 | ] 16 | 17 | [lint.per-file-ignores] 18 | "docs/clay_over_aoi.ipynb" = ["E501"] 19 | "finetune/embedder/how-to-embed.ipynb" = ["E501", "PLR2004"] 20 | "scripts/worldcover/worldcover_vrt.py" = ["E501"] 21 | 22 | [lint.pylint] 23 | max-args = 6 24 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clay-foundation/model/b817be2dcf31226959a132357d49e9bc5c89d05d/src/__init__.py -------------------------------------------------------------------------------- /src/backbone.py: -------------------------------------------------------------------------------- 1 | """Code for Transformer from Phil Wangs vit-pytorch library. 2 | Repository: https://github.com/lucidrains/vit-pytorch 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch import nn 9 | 10 | 11 | class FeedForward(nn.Module): 12 | def __init__(self, dim, hidden_dim): 13 | super().__init__() 14 | self.net = nn.Sequential( 15 | nn.LayerNorm(dim), 16 | nn.Linear(dim, hidden_dim), 17 | nn.GELU(), 18 | nn.Linear(hidden_dim, dim), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.net(x) 23 | 24 | 25 | class Attention(nn.Module): 26 | def __init__(self, dim, heads=8, dim_head=64, fused_attn=True): 27 | super().__init__() 28 | inner_dim = dim_head * heads 29 | self.heads = heads 30 | self.scale = dim_head**-0.5 31 | self.norm = nn.LayerNorm(dim) 32 | self.fused_attn = fused_attn 33 | 34 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 35 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 36 | 37 | def forward(self, x): 38 | x = self.norm(x) 39 | 40 | qkv = self.to_qkv(x).chunk(3, dim=-1) 41 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) 42 | 43 | if self.fused_attn: 44 | x = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) 45 | else: 46 | attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale 47 | attn = attn.softmax(dim=-1) 48 | x = torch.matmul(attn, v) 49 | 50 | x = rearrange(x, "b h n d -> b n (h d)") 51 | return self.to_out(x) 52 | 53 | 54 | class Transformer(nn.Module): 55 | def __init__( # noqa: PLR0913 56 | self, 57 | dim, 58 | depth, 59 | heads, 60 | dim_head, 61 | mlp_dim, 62 | fused_attn, 63 | ): 64 | super().__init__() 65 | self.norm = nn.LayerNorm(dim) 66 | self.layers = nn.ModuleList([]) 67 | for _ in range(depth): 68 | self.layers.append( 69 | nn.ModuleList( 70 | [ 71 | Attention( 72 | dim, heads=heads, dim_head=dim_head, fused_attn=fused_attn 73 | ), 74 | FeedForward(dim, mlp_dim), 75 | ] 76 | ) 77 | ) 78 | 79 | def forward(self, x): 80 | for attn, ff in self.layers: 81 
| x = attn(x) + x 82 | x = ff(x) + x 83 | return self.norm(x) 84 | -------------------------------------------------------------------------------- /src/callbacks.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.callbacks import Callback 2 | from lightning.pytorch.callbacks.finetuning import BaseFinetuning 3 | 4 | 5 | class ProgressiveResizing(Callback): 6 | def __init__(self): 7 | self.resize_schedule = { 8 | 0: {"batch_size": 4, "num_workers": 4, "size": 64}, 9 | 10: {"batch_size": 2, "num_workers": 2, "size": 128}, 10 | 20: {"batch_size": 1, "num_workers": 1, "size": 256}, 11 | } 12 | 13 | def on_train_epoch_start(self, trainer, pl_module): 14 | if trainer.current_epoch in self.resize_schedule: 15 | params = self.resize_schedule[trainer.current_epoch] 16 | 17 | trainer.datamodule.size = params["size"] 18 | trainer.datamodule.batch_size = params["batch_size"] 19 | trainer.datamodule.num_workers = params["num_workers"] 20 | 21 | trainer.datamodule.setup(stage="fit") 22 | 23 | def on_validation_epoch_start(self, trainer, pl_module): 24 | if trainer.current_epoch in self.resize_schedule: 25 | params = self.resize_schedule[trainer.current_epoch] 26 | 27 | trainer.datamodule.size = params["size"] 28 | trainer.datamodule.batch_size = params["batch_size"] 29 | trainer.datamodule.num_workers = params["num_workers"] 30 | 31 | trainer.datamodule.setup(stage="validate") 32 | 33 | 34 | class LayerwiseFinetuning(BaseFinetuning): 35 | def __init__(self, phase, train_bn=True): 36 | """Initializes with phase & batch-norm information. 37 | 38 | Args: 39 | phase (List): Phases of fine-tuning the backbone network. 40 | train_bn (bool, optional): Trains just the batch-norm layers even 41 | when all the other layers of the network are freezed. Defaults to True. 42 | """ 43 | super().__init__() 44 | self.phase = phase 45 | self.train_bn = train_bn 46 | 47 | def freeze_before_training(self, pl_module): 48 | """Freezes the encoder before starting the training.""" 49 | self.freeze( 50 | modules=[ 51 | pl_module.model.encoder.patch_embedding, 52 | pl_module.model.encoder.transformer, 53 | ], 54 | train_bn=self.train_bn, 55 | ) 56 | 57 | def finetune_function(self, pl_module, epoch, optimizer): 58 | if epoch == self.phase: 59 | """Unfreezes the encoder for training.""" 60 | print(f"In Phase {self.phase}: Full throttle") 61 | self.unfreeze_and_add_param_group( 62 | modules=[ 63 | pl_module.model.encoder.patch_embedding, 64 | pl_module.model.encoder.transformer, 65 | ], 66 | optimizer=optimizer, 67 | train_bn=self.train_bn, 68 | ) 69 | params = list(pl_module.parameters()) 70 | active = list(filter(lambda p: p.requires_grad, params)) 71 | print(f"active: {len(active)}, all: {len(params)}") 72 | -------------------------------------------------------------------------------- /src/factory.py: -------------------------------------------------------------------------------- 1 | """Dynamic Embedding from DOFA paper. 
2 | Reference: 3 | - https://arxiv.org/abs/2403.15356 4 | - https://github.com/zhu-xlab/DOFA 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from torch import nn 11 | 12 | from src.utils import posemb_sincos_1d 13 | 14 | 15 | class FCBlock(nn.Module): 16 | def __init__(self, size): 17 | super().__init__() 18 | self.l1 = nn.Linear(size, size) 19 | self.l2 = nn.Linear(size, size) 20 | 21 | def forward(self, x): 22 | y = F.gelu(self.l1(x)) 23 | y = F.gelu(self.l2(y)) 24 | return x + y 25 | 26 | 27 | class WavesTransformer(nn.Module): 28 | def __init__( # noqa: PLR0913 29 | self, 30 | wave_dim, 31 | output_dim, 32 | num_latent_tokens, 33 | embed_dim, 34 | is_decoder, 35 | num_heads=4, 36 | num_layers=1, 37 | ): 38 | super().__init__() 39 | self.num_latent_tokens = num_latent_tokens 40 | self.is_decoder = is_decoder 41 | layer = nn.TransformerEncoderLayer( 42 | d_model=wave_dim, 43 | nhead=num_heads, 44 | activation="gelu", 45 | dropout=0, 46 | norm_first=False, 47 | batch_first=True, 48 | ) 49 | self.encoder = nn.TransformerEncoder(layer, num_layers) 50 | 51 | self.fc_weight = nn.Linear(wave_dim, output_dim) 52 | self.fc_bias = None if self.is_decoder else nn.Linear(wave_dim, embed_dim) 53 | 54 | self.weight_tokens = nn.Parameter( 55 | torch.randn(self.num_latent_tokens, wave_dim) * 0.02 56 | ) 57 | self.bias_token = nn.Parameter(torch.randn(1, wave_dim) * 0.02) 58 | 59 | def forward(self, x): 60 | x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0) 61 | out = self.encoder(x) 62 | weights = self.fc_weight( 63 | out[self.num_latent_tokens : -1] + x[self.num_latent_tokens : -1] 64 | ) 65 | bias = None if self.is_decoder else self.fc_bias(out[-1]) 66 | return weights, bias 67 | 68 | 69 | class DynamicEmbedding(nn.Module): 70 | def __init__( 71 | self, 72 | wave_dim, 73 | num_latent_tokens, 74 | patch_size, 75 | embed_dim, 76 | is_decoder=False, 77 | ): 78 | super().__init__() 79 | self.wave_dim = wave_dim 80 | self.num_latent_tokens = num_latent_tokens 81 | self.patch_size = patch_size 82 | self.embed_dim = embed_dim 83 | self.is_decoder = is_decoder 84 | self.output_dim = (patch_size**2) * embed_dim 85 | 86 | self.weight_generator = WavesTransformer( 87 | wave_dim, 88 | self.output_dim, 89 | self.num_latent_tokens, 90 | self.embed_dim, 91 | is_decoder, 92 | ) 93 | self.fclayer = FCBlock(self.wave_dim) 94 | 95 | self.initialize_weights() 96 | 97 | def forward(self, batch, waves): 98 | waves = posemb_sincos_1d(waves, self.wave_dim) 99 | waves = waves.to(batch.device) 100 | waves = self.fclayer(waves) 101 | weight, bias = self.weight_generator(waves) 102 | 103 | if self.is_decoder: 104 | dynamic_weight = rearrange( 105 | weight, 106 | "cin (k1 k2 cout) -> (cin k1 k2) cout", 107 | k1=self.patch_size, 108 | k2=self.patch_size, 109 | cout=self.embed_dim, 110 | ) 111 | if bias is not None: 112 | bias = rearrange(bias, "b -> (b)") 113 | dynamic_out = F.linear(batch, dynamic_weight * 0.02, bias=bias) 114 | x = dynamic_out 115 | else: 116 | dynamic_weight = rearrange( 117 | weight, 118 | "cin (cout k1 k2) -> cout cin k1 k2", 119 | k1=self.patch_size, 120 | k2=self.patch_size, 121 | ) 122 | if bias is not None: 123 | bias = rearrange(bias, "b -> (b)") 124 | dynamic_out = F.conv2d( 125 | batch, dynamic_weight * 0.02, bias=bias, stride=self.patch_size 126 | ) 127 | x = rearrange(dynamic_out, "b c h w -> b (h w) c") 128 | 129 | return x, waves 130 | 131 | def initialize_weights(self): 132 | # Initialize weights using Xavier initialization 133 | 
for m in self.modules(): 134 | if isinstance(m, (nn.Linear, nn.Conv2d)): 135 | nn.init.xavier_uniform_(m.weight) 136 | if m.bias is not None: 137 | nn.init.constant_(m.bias, 0) 138 | -------------------------------------------------------------------------------- /src/module.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import lightning as L 4 | import torch 5 | import yaml 6 | from box import Box 7 | 8 | from src.model import clay_mae_base, clay_mae_large, clay_mae_small, clay_mae_tiny 9 | 10 | 11 | class ClayMAEModule(L.LightningModule): 12 | def __init__( # noqa: PLR0913 13 | self, 14 | model_size="base", 15 | mask_ratio=0.75, 16 | norm_pix_loss=False, 17 | patch_size=8, 18 | shuffle=False, 19 | metadata_path="configs/metadata.yaml", 20 | teacher="samvit_base_patch16.sa1b", 21 | dolls=[16, 32, 64, 128, 256, 768], 22 | doll_weights=[1, 1, 1, 1, 1, 1], 23 | lr=1e-5, 24 | wd=0.05, 25 | b1=0.9, 26 | b2=0.95, 27 | embeddings_level: Literal["mean", "patch", "group"] = "mean", 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters(logger=True) 31 | self.metadata = Box(yaml.safe_load(open(metadata_path))) 32 | model_map = { 33 | "tiny": clay_mae_tiny, 34 | "small": clay_mae_small, 35 | "base": clay_mae_base, 36 | "large": clay_mae_large, 37 | } 38 | if model_size in model_map: 39 | model_args = { 40 | "mask_ratio": mask_ratio, 41 | "patch_size": patch_size, 42 | "norm_pix_loss": norm_pix_loss, 43 | "shuffle": shuffle, 44 | "metadata": self.metadata, 45 | "teacher": teacher, 46 | "dolls": dolls, 47 | "doll_weights": doll_weights, 48 | } 49 | self.model = model_map[model_size](**model_args) 50 | # checkpoint_path = 'mae_v1.5.0_epoch-76_val-loss-0.1612.ckpt' 51 | # checkpoint = torch.load(checkpoint_path, map_location="cpu") 52 | # # Extract the state dictionary 53 | # state_dict = checkpoint['state_dict'] 54 | 55 | # # Modify the state dictionary 56 | # new_state_dict = OrderedDict() 57 | # for k, v in state_dict.items(): 58 | # # Remove 'model.' prefix if it exists 59 | # if k.startswith('model.'): 60 | # k = k[len('model.'):] 61 | # # Exclude keys related to the 'teacher' 62 | # if not (k.startswith('teacher') or k.startswith('mrl')): 63 | # new_state_dict[k] = v 64 | # with torch.no_grad(): 65 | # # Load the modified state dictionary into your model 66 | # missing_keys, unexpected_keys = ( 67 | # self.model.load_state_dict(new_state_dict, strict=False) 68 | # ) 69 | # # Optionally, print missing and unexpected keys 70 | # print(f"Missing keys: {missing_keys}") 71 | # print(f"Unexpected keys: {unexpected_keys}") 72 | else: 73 | raise ValueError( 74 | f"Invalid model size {model_size}. 
Expected one of {model_map.keys()}" 75 | ) 76 | 77 | def on_train_epoch_start(self): 78 | self.model.teacher.eval() 79 | 80 | def forward(self, datacube: dict[str, torch.Tensor]): 81 | return self.model(datacube) 82 | 83 | def configure_optimizers(self): 84 | optimizer = torch.optim.AdamW( 85 | self.parameters(), 86 | lr=self.hparams.lr, 87 | weight_decay=self.hparams.wd, 88 | betas=(self.hparams.b1, self.hparams.b2), 89 | fused=True, 90 | ) 91 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 92 | optimizer, T_0=5000, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1 93 | ) 94 | 95 | return { 96 | "optimizer": optimizer, 97 | "lr_scheduler": { 98 | "scheduler": scheduler, 99 | "interval": "step", 100 | }, 101 | } 102 | 103 | def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): 104 | platform = batch["platform"][0] 105 | loss, reconstruction_loss, representation_loss = self(batch) 106 | 107 | losses = { 108 | "loss": loss, 109 | "rec_loss": reconstruction_loss, 110 | "rep_loss": representation_loss, 111 | } 112 | 113 | for loss_name, loss_value in losses.items(): 114 | self.log( 115 | name=f"{phase}/{loss_name}", 116 | value=loss_value, 117 | on_step=True, 118 | on_epoch=True, 119 | prog_bar=True, 120 | logger=True, 121 | sync_dist=True, 122 | ) 123 | self.log( 124 | name=f"{phase}_{platform}/{loss_name}", 125 | value=loss_value, 126 | on_step=True, 127 | on_epoch=True, 128 | prog_bar=True, 129 | logger=True, 130 | sync_dist=True, 131 | ) 132 | 133 | return loss 134 | 135 | def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): 136 | return self.shared_step(batch, batch_idx, phase="train") 137 | 138 | def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): 139 | return self.shared_step(batch, batch_idx, phase="val") 140 | -------------------------------------------------------------------------------- /src/mrl.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class MRL(nn.Module): 5 | """ 6 | Matryoshka Representation Learning from the paper: https://arxiv.org/abs/2205.13147 7 | """ 8 | 9 | def __init__(self, features, dolls: list = [16, 32, 64, 128, 256, 768]) -> None: 10 | super().__init__() 11 | self.dolls = dolls 12 | self.layers = nn.ModuleDict() 13 | for doll in dolls: 14 | self.layers[f"mrl_{doll}"] = nn.Linear(doll, features) 15 | 16 | def forward(self, x): 17 | "x: (batch, features)" 18 | logits = [self.layers[f"mrl_{doll}"](x[:, :doll]) for doll in self.dolls] 19 | return logits 20 | 21 | 22 | class MRLLoss(nn.Module): 23 | def __init__(self, weights) -> None: 24 | super().__init__() 25 | self.weights = weights 26 | self.criterion = nn.CosineSimilarity(dim=1, eps=1e-6) 27 | 28 | def forward(self, representations, targets): 29 | """ 30 | representations: [(batch, features), ...] 
31 | targets: (batch, features) 32 | """ 33 | losses = [ 34 | self.weights[i] * (1 - self.criterion(rep, targets)).mean() 35 | for i, rep in enumerate(representations) 36 | ] 37 | return sum(losses) / len(losses) 38 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py 3 | 4 | """ 5 | 6 | import torch 7 | 8 | 9 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): 10 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 11 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 12 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 13 | omega = 1.0 / (temperature**omega) 14 | 15 | y = y.flatten()[:, None] * omega[None, :] 16 | x = x.flatten()[:, None] * omega[None, :] 17 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 18 | return pe.type(dtype) 19 | 20 | 21 | def posemb_sincos_2d_with_gsd( 22 | h, w, dim, gsd=1.0, temperature: int = 10000, dtype=torch.float32 23 | ): 24 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 25 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 26 | 27 | gsd = gsd.to(x.device) 28 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 29 | omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g 30 | 31 | y = y.flatten()[:, None] * omega[None, :] 32 | x = x.flatten()[:, None] * omega[None, :] 33 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 34 | return pe.type(dtype) 35 | 36 | 37 | def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32): 38 | assert ( 39 | dim % 2 == 0 40 | ), "Feature dimension must be a multiple of 2 for sincos embedding" 41 | waves = torch.arange(waves) if isinstance(waves, int) else waves 42 | 43 | omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1) 44 | omega = 1.0 / (temperature**omega) 45 | 46 | scaled_waves = waves[:, None] * omega[None, :] 47 | pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1) 48 | 49 | return pe.type(dtype) 50 | -------------------------------------------------------------------------------- /train_clay.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=clay-laucher 4 | #SBATCH --nodes=24 5 | #SBATCH --ntasks-per-node=8 # EDIT if it's not 8-gpus per node 6 | #SBATCH --cpus-per-task=12 # EDIT this to how many cpu cores the node has divided by num of gpus 7 | #SBATCH --gres=gpu:8 # EDIT this if it's not 8-gpus per node 8 | #SBATCH --time=0-00:00:00 # EDIT the desired runtime 9 | #SBATCH --exclusive 10 | #SBATCH --partition=gpu # EDIT to the desired partition name 11 | #SBATCH --nodelist=gpu-dy-g6-[1-12],gpu-dy-g5-[1-12] 12 | #SBATCH --output=%x-%j-%N.out 13 | 14 | echo "START TIME: $(date)" 15 | 16 | # auto-fail on any errors in this script 17 | set -eo pipefail 18 | 19 | # logging script's variables/commands for future debug needs 20 | set -x 21 | 22 | # EDIT the conda evn and any startup scripts 23 | # source /path/to/start-xxx-user # if you have something to preload before the job 24 | # Load any required modules (environments, libraries etc.) 
25 | eval "$(conda 'shell.bash' 'hook' 2> /dev/null)" 26 | 27 | # initialize conda 28 | conda activate /home/ubuntu/claymodel # if you have conda env to activate 29 | 30 | LOG_PATH="main_log.txt" 31 | 32 | # PTL doesn't need a special launcher 33 | LAUNCHER="python -u" 34 | 35 | # Capture the number of nodes allocated by Slurm 36 | NUM_NODES=$SLURM_JOB_NUM_NODES 37 | 38 | # EDIT the path+name of the python script and whatever args it needs 39 | PROGRAM="trainer.py fit --config configs/config.yaml --trainer.num_nodes=$NUM_NODES" 40 | 41 | export CMD="$LAUNCHER $PROGRAM" 42 | 43 | echo $CMD 44 | 45 | # EDIT if you want to redirect /tmp to /scratch (some local SSD path) since /tmp is tiny on compute nodes 46 | # export TMPDIR=/scratch 47 | 48 | # EDIT: useful for debug if needed 49 | # 50 | # to debug NCCL issues 51 | # export NCCL_DEBUG=INFO 52 | # 53 | # to unravel async errors w/o the correct traceback - potentially makes everything very slower 54 | # export CUDA_LAUNCH_BLOCKING=1 55 | # 56 | # to force crashing on nccl issues like hanging broadcast 57 | # export NCCL_ASYNC_ERROR_HANDLING=1 58 | 59 | # srun error handling: 60 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks 61 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code 62 | SRUN_ARGS=" \ 63 | --wait=60 \ 64 | --kill-on-bad-exit=1 \ 65 | --jobid $SLURM_JOB_ID \ 66 | " 67 | 68 | # bash -c is needed for the delayed interpolation of env vars to work 69 | srun $SRUN_ARGS bash -c "$CMD" 2>&1 | tee -a $LOG_PATH 70 | 71 | echo "END TIME: $(date)" 72 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface to run the neural network model! 3 | 4 | From the project root directory, do: 5 | 6 | python trainer.py fit 7 | 8 | References: 9 | - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html 10 | - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 11 | """ 12 | 13 | from lightning.pytorch.cli import LightningCLI 14 | 15 | from src.datamodule import ClayDataModule # noqa: F401 16 | from src.module import ClayMAEModule # noqa: F401 17 | 18 | 19 | # %% 20 | def cli_main(): 21 | """ 22 | Command-line inteface to run ClayMAE with ClayDataModule. 
23 | """ 24 | cli = LightningCLI( 25 | ClayMAEModule, ClayDataModule, save_config_kwargs={"overwrite": True} 26 | ) 27 | return cli 28 | 29 | 30 | # %% 31 | if __name__ == "__main__": 32 | cli_main() 33 | 34 | print("Done!") 35 | -------------------------------------------------------------------------------- /utils/check_data_sanity.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | 4 | import numpy as np 5 | 6 | 7 | def check_and_delete_npz(file_path): 8 | try: 9 | # Attempt to load the .npz file using numpy 10 | data = np.load(file_path) 11 | 12 | # Check if the 'pixel' key exists and has shape 128 in the 0th dimension 13 | if "pixels" in data: 14 | if data["pixels"].shape[0] != 128: # noqa: PLR2004 15 | os.remove(file_path) 16 | return ( 17 | None, 18 | f"Invalid shape (not 128 in 0th dim): {file_path} - Deleted", 19 | ) 20 | else: 21 | return f"Valid: {file_path}", None 22 | else: 23 | os.remove(file_path) 24 | return None, f"'pixels' key missing: {file_path} - Deleted" 25 | 26 | except Exception as e: 27 | os.remove(file_path) 28 | return None, f"Invalid (Exception): {file_path} - {str(e)} - Deleted" 29 | 30 | 31 | def process_directory_in_parallel(directory, max_workers=4): 32 | invalid_files = [] 33 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 34 | futures = [] 35 | for root, dirs, files in os.walk(directory): 36 | for file in files: 37 | if file.endswith(".npz"): 38 | file_path = os.path.join(root, file) 39 | futures.append(executor.submit(check_and_delete_npz, file_path)) 40 | 41 | for future in as_completed(futures): 42 | valid_msg, invalid_msg = future.result() 43 | if valid_msg: 44 | print(valid_msg) 45 | if invalid_msg: 46 | print(invalid_msg) 47 | invalid_files.append(invalid_msg) 48 | 49 | return invalid_files 50 | 51 | 52 | # Replace 'your_directory_path' with the path to the directory you want to check 53 | invalid_files = process_directory_in_parallel("/fsx", max_workers=24) 54 | 55 | if invalid_files: 56 | print("\nInvalid or corrupted .npz files found and deleted:") 57 | for file in invalid_files: 58 | print(file) 59 | else: 60 | print("\nAll .npz files are valid and meet the shape criteria for 'pixel' key.") 61 | -------------------------------------------------------------------------------- /utils/split_npz.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ProcessPoolExecutor 3 | 4 | import numpy as np 5 | 6 | 7 | def split_npz_file(file_path): 8 | # Load the .npz file 9 | with np.load(file_path) as data: 10 | # Check if the file has the required batch size of 128 11 | if "pixels" in data and data["pixels"].shape[0] == 128: # noqa: PLR2004 12 | # Extract all arrays 13 | keys = data.files 14 | arrays = {key: data[key] for key in keys} 15 | 16 | # Determine the batch size and the number of splits 17 | batch_size = 32 18 | num_splits = 4 # Since we want to split into 4 files, each with 32 samples 19 | 20 | # Split and save the smaller .npz files 21 | for i in range(num_splits): 22 | split_data = { 23 | key: value[i * batch_size : (i + 1) * batch_size] 24 | for key, value in arrays.items() 25 | } 26 | split_file_path = file_path.replace(".npz", f"_{i}.npz") 27 | np.savez(split_file_path, **split_data) 28 | print(f"Saved {split_file_path}") 29 | 30 | # Delete the original file 31 | os.remove(file_path) 32 | print(f"Deleted original file: {file_path}") 
33 | else: 34 | print(f"Skipped {file_path}: Does not have a batch size of 128") 35 | 36 | 37 | def process_directory(root_dir): 38 | # Collect all .npz files 39 | npz_files = [] 40 | for dirpath, _, filenames in os.walk(root_dir): 41 | for filename in filenames: 42 | if filename.endswith(".npz"): 43 | file_path = os.path.join(dirpath, filename) 44 | npz_files.append(file_path) 45 | 46 | # Process files in parallel 47 | with ProcessPoolExecutor() as executor: 48 | executor.map(split_npz_file, npz_files) 49 | 50 | 51 | # Example usage 52 | root_dir = "/fsx" 53 | process_directory(root_dir) 54 | --------------------------------------------------------------------------------
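The segmentation fine-tune above is normally launched through the Lightning CLI (python segment.py fit --config configs/segment_chesapeake.yaml) after chipping the Chesapeake data with finetune/segment/preprocess_data.py. As a rough illustration of how those pieces compose when wired manually, here is a minimal sketch: the chip and label directories follow the example output paths from the preprocessing docstring, while the batch size, worker count, "naip" platform key, checkpoint path, and optimizer settings are assumed placeholder values, not settings taken from configs/segment_chesapeake.yaml.

# Minimal sketch wiring the Chesapeake datamodule and segmentor directly.
# Paths, the platform key, and the hyperparameters are assumptions; the
# repository's actual values live in configs/segment_chesapeake.yaml.
import lightning as L

from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from finetune.segment.chesapeake_model import ChesapeakeSegmentor

datamodule = ChesapeakeDataModule(
    train_chip_dir="data/cvpr/ny/train/chips",    # output of preprocess_data.py
    train_label_dir="data/cvpr/ny/train/labels",
    val_chip_dir="data/cvpr/ny/val/chips",
    val_label_dir="data/cvpr/ny/val/labels",
    metadata_path="configs/metadata.yaml",
    batch_size=16,        # assumed
    num_workers=4,        # assumed
    platform="naip",      # assumed platform key in configs/metadata.yaml
)

model = ChesapeakeSegmentor(
    num_classes=7,                            # the label remapping yields 7 classes (0-6)
    ckpt_path="checkpoints/clay-base.ckpt",   # assumed location of a pretrained Clay checkpoint
    lr=1e-4, wd=0.05, b1=0.9, b2=0.95,        # assumed optimizer settings
)

trainer = L.Trainer(max_epochs=10, accelerator="auto", devices=1)
trainer.fit(model, datamodule=datamodule)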