├── .all-contributorsrc
├── .github
│   └── workflows
│       ├── CI.yml
│       └── pre-commit.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── assets
│   ├── decals.png
│   └── im_embedding.png
├── astroclip
│   ├── __init__.py
│   ├── astrodino
│   │   ├── __init__.py
│   │   ├── config.yaml
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── augmentations.py
│   │   │   ├── dataset.py
│   │   │   └── loaders.py
│   │   ├── distributed.py
│   │   ├── embed_legacysurvey
│   │   │   ├── embed_legacysurvey.py
│   │   │   └── launch_embedding.sh
│   │   ├── trainer.py
│   │   ├── training.sh
│   │   └── utils.py
│   ├── callbacks.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── datamodule.py
│   │   └── dataset.py
│   ├── env.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── astroclip.py
│   │   ├── loader.py
│   │   ├── moco_v2.py
│   │   └── specformer.py
│   ├── modules.py
│   ├── scheduler.py
│   └── trainer.py
├── configs
│   ├── astroclip.yaml
│   └── specformer.yaml
├── downstream_tasks
│   ├── morphology_classification
│   │   ├── README.md
│   │   ├── embed_galaxy_zoo.py
│   │   ├── morphology_classification.ipynb
│   │   └── morphology_utils
│   │       ├── cross_match.py
│   │       ├── models.py
│   │       └── plotting.py
│   ├── property_estimation
│   │   ├── README.md
│   │   ├── baselines
│   │   │   ├── README.md
│   │   │   ├── data.py
│   │   │   ├── modules.py
│   │   │   └── trainer.py
│   │   ├── embed_provabgs.py
│   │   ├── posterior_estimation.py
│   │   ├── property_estimation.ipynb
│   │   ├── property_utils
│   │   │   ├── cross_match.py
│   │   │   ├── models.py
│   │   │   └── plotting.py
│   │   └── redshift.ipynb
│   └── similarity_search
│       ├── README.md
│       ├── embed_astroclip.py
│       ├── plotting.py
│       └── similarity_search.ipynb
├── pyproject.toml
├── requirements.txt
├── scripts
│   ├── README.md
│   ├── cross_match_data.py
│   └── export_data.py
├── submit.sbatch
└── tutorial.ipynb

/.all-contributorsrc:
--------------------------------------------------------------------------------
1 | {
2 |   "projectName": "AstroCLIP",
3 |   "projectOwner": "PolymathicAI",
4 |   "repoType": "github",
5 |   "repoHost": "https://github.com",
6 |   "files": [
7 |     "README.md"
8 |   ],
9 |   "imageSize": 100,
10 |   "commit": false,
11 |   "commitConvention": "angular",
12 |   "contributors": [
13 |     {
14 |       "login": "lhparker1",
15 |       "name": "Liam Parker",
16 |       "avatar_url": "https://avatars.githubusercontent.com/u/86175266?v=4",
17 |       "profile": "https://github.com/lhparker1",
18 |       "contributions": [
19 |         "code"
20 |       ]
21 |     },
22 |     {
23 |       "login": "EiffL",
24 |       "name": "Francois Lanusse",
25 |       "avatar_url": "https://avatars.githubusercontent.com/u/861591?v=4",
26 |       "profile": "http://flanusse.net/",
27 |       "contributions": [
28 |         "code",
29 |         "data"
30 |       ]
31 |     },
32 |     {
33 |       "login": "golkar",
34 |       "name": "Siavash Golkar",
35 |       "avatar_url": "https://avatars.githubusercontent.com/u/35383824?v=4",
36 |       "profile": "https://github.com/golkar",
37 |       "contributions": [
38 |         "code"
39 |       ]
40 |     },
41 |     {
42 |       "login": "lsarra",
43 |       "name": "Leopoldo",
44 |       "avatar_url": "https://avatars.githubusercontent.com/u/66411731?v=4",
45 |       "profile": "https://users.flatironinstitute.org/~lsarra/",
46 |       "contributions": [
47 |         "code",
48 |         "tool"
49 |       ]
50 |     },
51 |     {
52 |       "login": "shirleysurelyho",
53 |       "name": "Shirley Ho",
54 |       "avatar_url": "https://avatars.githubusercontent.com/u/3279839?v=4",
55 |       "profile": "https://github.com/shirleysurelyho",
56 |       "contributions": [
57 |         "ideas",
58 |         "fundingFinding"
59 |       ]
60 |     },
61 |     {
62 |       "login": "MilesCranmer",
63 |       "name": "Miles Cranmer",
64 |       "avatar_url": "https://avatars.githubusercontent.com/u/7593028?v=4",
65 |       "profile": "https://github.com/MilesCranmer",
66 |       "contributions": [
67 |         "ideas",
68 |         "design"
69 |       ]
70 |     }
71 |   ],
72 |   "contributorsPerLine": 7,
73 |   "linkToUsage": false
74
| } 75 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: Linux 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | paths: 8 | - '.github/workflows/CI.yml' 9 | - 'astroclip/*' 10 | - 'setup.py' 11 | - 'requirements.txt' 12 | pull_request: 13 | branches: 14 | - '*' 15 | paths: 16 | - '*' 17 | permissions: 18 | contents: write 19 | checks: write 20 | pull-requests: write 21 | jobs: 22 | test: 23 | runs-on: ${{ matrix.os }} 24 | timeout-minutes: 60 25 | defaults: 26 | run: 27 | shell: bash 28 | strategy: 29 | matrix: 30 | python-version: ['3.10'] 31 | os: [ubuntu-latest] 32 | 33 | steps: 34 | - uses: actions/checkout@v4 35 | - name: "Set up Python" 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | cache: pip 40 | - name: "Clean up useless files" 41 | run: | 42 | echo "==============================================================================" 43 | echo "Freeing up disk space on CI system" 44 | echo "==============================================================================" 45 | 46 | echo "Listing 100 largest packages" 47 | dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 48 | df -h 49 | echo "Removing large packages" 50 | sudo apt-get remove -y '^dotnet-.*' 51 | sudo apt-get remove -y '^llvm-.*' 52 | sudo apt-get remove -y 'php.*' 53 | sudo apt-get remove -y azure-cli google-cloud-sdk google-chrome-stable firefox powershell mono-devel 54 | sudo apt-get autoremove -y 55 | sudo apt-get clean 56 | df -h 57 | echo "Removing large directories" 58 | # deleting 15GB 59 | rm -rf /usr/share/dotnet/ 60 | rm -rf /opt/hostedtoolcache 61 | df -h 62 | - name: "Install dependencies" 63 | run: | 64 | pip install --upgrade pip 65 | python -m pip install torch lightning[extra] pycairo # Extra dependency since we don't want to force user to use torch version 66 | pip install --extra-index-url https://pypi.nvidia.com cuml-cu11 67 | pip install --extra-index-url https://download.pytorch.org/whl/cu117 torch==2.0.0+cu117 68 | - name: "Install package" 69 | run: | 70 | pip install . 
71 | - name: "Check dependencies aren't broken" 72 | run: python -m pip check 73 | - name: "Check package can be imported" 74 | run: python -c "import astroclip" 75 | - name: "Run tests" 76 | run: | 77 | pip install pytest 78 | python -m pytest -k 'not _local' 79 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit 2 | on: 3 | push: 4 | branches: 5 | - '*' 6 | paths: 7 | - '*' 8 | pull_request: 9 | branches: 10 | - '*' 11 | paths: 12 | - '*' 13 | jobs: 14 | pre-commit: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip isort 24 | python -m pip install pre-commit 25 | - name: Run pre-commit 26 | run: python -m pre_commit run --all-files 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | astroclip/_version.py
163 |
164 |
165 | outputs
166 | wandb
167 | logs
168 | notebooks/dev
169 | notebooks/dev.ipynb
170 | *ckpt
171 | .vscode
172 | .local
173 | ceph_data
174 | lightning_logs
175 | supervised.sh
176 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   # General linting
3 |   - repo: https://github.com/pre-commit/pre-commit-hooks
4 |     rev: v4.5.0
5 |     hooks:
6 |       - id: trailing-whitespace
7 |       - id: end-of-file-fixer
8 |       - id: check-yaml
9 |       - id: check-added-large-files
10 |   # General formatting
11 |   - repo: https://github.com/psf/black
12 |     rev: 23.12.1
13 |     hooks:
14 |       - id: black
15 |       - id: black-jupyter
16 |   # Stripping notebooks
17 |   - repo: https://github.com/kynan/nbstripout
18 |     rev: 0.6.1
19 |     hooks:
20 |       - id: nbstripout
21 |   # Unused imports
22 |   - repo: https://github.com/hadialqattan/pycln
23 |     rev: "v2.4.0"
24 |     hooks:
25 |       - id: pycln
26 |   # Sorted imports
27 |   - repo: https://github.com/PyCQA/isort
28 |     rev: "5.13.2"
29 |     hooks:
30 |       - id: isort
31 |         additional_dependencies: [toml]
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Polymathic AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to
the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/decals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolymathicAI/AstroCLIP/9d8506970773fbbf4f0d445d2b52def77bd60f56/assets/decals.png -------------------------------------------------------------------------------- /assets/im_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolymathicAI/AstroCLIP/9d8506970773fbbf4f0d445d2b52def77bd60f56/assets/im_embedding.png -------------------------------------------------------------------------------- /astroclip/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data, models, modules 2 | from .callbacks import CustomSaveConfigCallback, CustomWandbLogger, PlotsCallback 3 | from .env import format_with_env 4 | from .scheduler import CosineAnnealingWithWarmupLR 5 | -------------------------------------------------------------------------------- /astroclip/astrodino/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolymathicAI/AstroCLIP/9d8506970773fbbf4f0d445d2b52def77bd60f56/astroclip/astrodino/__init__.py -------------------------------------------------------------------------------- /astroclip/astrodino/config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHTS: '' 3 | compute_precision: 4 | grad_scaler: true 5 | teacher: 6 | backbone: 7 | sharding_strategy: SHARD_GRAD_OP 8 | mixed_precision: 9 | param_dtype: fp16 10 | reduce_dtype: fp16 11 | buffer_dtype: fp32 12 | dino_head: 13 | sharding_strategy: SHARD_GRAD_OP 14 | mixed_precision: 15 | param_dtype: fp16 16 | reduce_dtype: fp16 17 | buffer_dtype: fp32 18 | ibot_head: 19 | sharding_strategy: SHARD_GRAD_OP 20 | mixed_precision: 21 | param_dtype: fp16 22 | reduce_dtype: fp16 23 | buffer_dtype: fp32 24 | student: 25 | backbone: 26 | sharding_strategy: SHARD_GRAD_OP 27 | mixed_precision: 28 | param_dtype: fp16 29 | reduce_dtype: fp16 30 | buffer_dtype: fp32 31 | dino_head: 32 | sharding_strategy: SHARD_GRAD_OP 33 | mixed_precision: 34 | param_dtype: fp16 35 | reduce_dtype: fp32 36 | buffer_dtype: fp32 37 | ibot_head: 38 | sharding_strategy: SHARD_GRAD_OP 39 | mixed_precision: 40 | param_dtype: fp16 41 | reduce_dtype: fp32 42 | buffer_dtype: fp32 43 | dino: 44 | loss_weight: 1.0 45 | head_n_prototypes: 65536 46 | head_bottleneck_dim: 256 47 | head_nlayers: 3 48 | head_hidden_dim: 2048 49 | koleo_loss_weight: 0.1 50 | ibot: 51 | loss_weight: 1.0 52 | mask_sample_probability: 0.5 53 | mask_ratio_min_max: 54 | - 0.1 55 | - 0.5 56 | separate_head: false 57 | head_n_prototypes: 65536 58 | head_bottleneck_dim: 
256
59 |   head_nlayers: 3
60 |   head_hidden_dim: 2048
61 | train:
62 |   batch_size_per_gpu: 72
63 |   dataset_path: LegacySurvey:split=train:root={ASTROCLIP_ROOT}/datasets/decals:extra=""
64 |   output_dir: .
65 |   saveckp_freq: 20
66 |   seed: 0
67 |   num_workers: 10
68 |   OFFICIAL_EPOCH_LENGTH: 1250
69 |   cache_dataset: true
70 |   centering: "centering" # or "sinkhorn_knopp"
71 | student:
72 |   arch: vit_large
73 |   patch_size: 12
74 |   drop_path_rate: 0.3
75 |   layerscale: 1.0e-05
76 |   drop_path_uniform: true
77 |   pretrained_weights: ''
78 |   ffn_layer: "mlp"
79 |   block_chunks: 4 # 0 is the default in the upstream config
80 |   qkv_bias: true
81 |   proj_bias: true
82 |   ffn_bias: true
83 | teacher:
84 |   momentum_teacher: 0.992
85 |   final_momentum_teacher: 1
86 |   warmup_teacher_temp: 0.04
87 |   teacher_temp: 0.07
88 |   warmup_teacher_temp_epochs: 30
89 | optim:
90 |   epochs: 200
91 |   weight_decay: 0.001
92 |   weight_decay_end: 0.01
93 |   base_lr: 2.0e-4 # learning rate for a batch size of 1024
94 |   lr: 0. # will be set after applying scaling rule
95 |   warmup_epochs: 32
96 |   min_lr: 1.0e-06
97 |   clip_grad: 3.0
98 |   freeze_last_layer_epochs: 1
99 |   scaling_rule: sqrt_wrt_1024
100 |   patch_embed_lr_mult: 0.2
101 |   layerwise_decay: 0.9
102 |   adamw_beta1: 0.9
103 |   adamw_beta2: 0.999
104 | crops:
105 |   global_crops_scale:
106 |   - 0.8
107 |   - 1.0
108 |   local_crops_number: 8
109 |   local_crops_scale:
110 |   - 0.4
111 |   - 0.6
112 |   global_crops_size: 144 # was 224
113 |   local_crops_size: 60 # was 96
114 | evaluation:
115 |   eval_period_iterations: 12500
116 |
--------------------------------------------------------------------------------
/astroclip/astrodino/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AstroCLIP/9d8506970773fbbf4f0d445d2b52def77bd60f56/astroclip/astrodino/data/__init__.py
--------------------------------------------------------------------------------
/astroclip/astrodino/data/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
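# A minimal sketch of how the `sqrt_wrt_1024` scaling rule in the config above
# is conventionally applied in DINOv2-style training; the values are taken from
# config.yaml and training.sh, but the variable names are illustrative, not
# code from this repository:
#
#     import math
#
#     base_lr = 2.0e-4            # optim.base_lr, defined for a batch size of 1024
#     batch_size_per_gpu = 72     # train.batch_size_per_gpu
#     world_size = 20             # e.g. 5 nodes x 4 GPUs, as in training.sh
#     lr = base_lr * math.sqrt(batch_size_per_gpu * world_size / 1024.0)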
5 | 6 | import logging 7 | from typing import List 8 | 9 | import numpy as np 10 | import skimage.filters 11 | import skimage.transform 12 | import torch 13 | from torchvision import transforms 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class DataAugmentationAstroDINO(object): 19 | def __init__( 20 | self, 21 | global_crops_scale, 22 | local_crops_scale, 23 | local_crops_number, 24 | global_crops_size=144, 25 | local_crops_size=60, 26 | ): 27 | self.global_crops_scale = global_crops_scale 28 | self.local_crops_scale = local_crops_scale 29 | self.local_crops_number = local_crops_number 30 | self.global_crops_size = global_crops_size 31 | self.local_crops_size = local_crops_size 32 | 33 | logger.info("###################################") 34 | logger.info("Using data augmentation parameters:") 35 | logger.info(f"global_crops_scale: {global_crops_scale}") 36 | logger.info(f"local_crops_scale: {local_crops_scale}") 37 | logger.info(f"local_crops_number: {local_crops_number}") 38 | logger.info(f"global_crops_size: {global_crops_size}") 39 | logger.info(f"local_crops_size: {local_crops_size}") 40 | logger.info("###################################") 41 | 42 | # random resized crop and flip 43 | self.geometric_augmentation_global = transforms.Compose( 44 | [ 45 | transforms.RandomCrop(global_crops_size), 46 | transforms.RandomHorizontalFlip(p=0.5), 47 | transforms.RandomVerticalFlip(p=0.5), 48 | ] 49 | ) 50 | 51 | self.geometric_augmentation_local = transforms.Compose( 52 | [ 53 | transforms.RandomCrop(local_crops_size), 54 | transforms.RandomHorizontalFlip(p=0.5), 55 | transforms.RandomVerticalFlip(p=0.5), 56 | ] 57 | ) 58 | 59 | global_transfo1_extra = transforms.Compose( 60 | [ 61 | RandomGaussianBlur(p=1.0), 62 | RandomGaussianNoise(p=1.0, im_dim=global_crops_size), 63 | ] 64 | ) 65 | 66 | global_transfo2_extra = transforms.Compose( 67 | [ 68 | RandomGaussianBlur(p=0.1), 69 | RandomGaussianNoise(p=0.1, im_dim=global_crops_size), 70 | ] 71 | ) 72 | 73 | local_transfo_extra = transforms.Compose( 74 | [ 75 | RandomGaussianBlur(p=0.5), 76 | RandomGaussianNoise(p=0.5, im_dim=local_crops_size), 77 | ] 78 | ) 79 | 80 | to_rgb = ToRGB() 81 | 82 | self.global_transfo1 = transforms.Compose([global_transfo1_extra, to_rgb]) 83 | self.global_transfo2 = transforms.Compose([global_transfo2_extra, to_rgb]) 84 | self.local_transfo = transforms.Compose([local_transfo_extra, to_rgb]) 85 | 86 | def __call__(self, image): 87 | output = {} 88 | 89 | # global crops: 90 | im1_base = np.array(self.geometric_augmentation_global(image)) 91 | global_crop_1 = torch.tensor(self.global_transfo1(im1_base)).permute(2, 0, 1) 92 | 93 | im2_base = np.array(self.geometric_augmentation_global(image)) 94 | global_crop_2 = torch.tensor(self.global_transfo2(im2_base)).permute(2, 0, 1) 95 | 96 | output["global_crops"] = [global_crop_1, global_crop_2] 97 | 98 | # global crops for teacher: 99 | output["global_crops_teacher"] = [global_crop_1, global_crop_2] 100 | 101 | # local crops: 102 | local_crops = [ 103 | torch.tensor( 104 | self.local_transfo(np.array(self.geometric_augmentation_local(image))) 105 | ).permute(2, 0, 1) 106 | for _ in range(self.local_crops_number) 107 | ] 108 | output["local_crops"] = local_crops 109 | output["offsets"] = () 110 | 111 | return output 112 | 113 | 114 | class RandomGaussianBlur(transforms.RandomApply): 115 | """Randomly apply Gaussian blur to the image.""" 116 | 117 | def __init__(self, *, p: float = 0.5): 118 | keep_p = 1 - p 119 | transform = GaussianBlur() 120 | 
super().__init__([transform], p=keep_p)  # NOTE: follows upstream DINOv2, which hands RandomApply keep_p = 1 - p
121 |
122 |
123 | class RandomGaussianNoise(transforms.RandomApply):
124 |     """Randomly apply Gaussian noise to the image."""
125 |
126 |     def __init__(self, *, im_dim=144, p: float = 0.5):
127 |         keep_p = 1 - p
128 |         transform = GaussianNoise(im_dim=im_dim)
129 |         super().__init__([transform], p=keep_p)
130 |
131 |
132 | class ToRGB:
133 |     """
134 |     Transformation from raw image data (nanomaggies) to the rgb values displayed
135 |     at the legacy viewer https://www.legacysurvey.org/viewer
136 |
137 |     Code copied from
138 |     https://github.com/legacysurvey/imagine/blob/master/map/views.py
139 |     """
140 |
141 |     def __init__(self, scales=None, m=0.03, Q=20, bands=["g", "r", "z"]):
142 |         rgb_scales = {
143 |             "u": (2, 1.5),
144 |             "g": (2, 6.0),
145 |             "r": (1, 3.4),
146 |             "i": (0, 1.0),
147 |             "z": (0, 2.2),
148 |         }
149 |         if scales is not None:
150 |             rgb_scales.update(scales)
151 |
152 |         self.rgb_scales = rgb_scales
153 |         self.m = m
154 |         self.Q = Q
155 |         self.bands = bands
156 |         self.axes, self.scales = zip(*[rgb_scales[bands[i]] for i in range(len(bands))])
157 |
158 |         # rearrange scales to correspond to image channels after swapping
159 |         self.scales = [self.scales[i] for i in self.axes]
160 |
161 |     def __call__(self, imgs):
162 |         # Check image shape and set to C x H x W
163 |         if imgs.shape[0] != len(self.bands):
164 |             imgs = np.transpose(imgs, (2, 0, 1))
165 |
166 |         I = 0
167 |         for img, band in zip(imgs, self.bands):
168 |             plane, scale = self.rgb_scales[band]
169 |             img = np.maximum(0, img * scale + self.m)
170 |             I = I + img
171 |         I /= len(self.bands)
172 |
173 |         Q = self.Q  # use the configured asinh softening (was hard-coded to 20)
174 |         fI = np.arcsinh(Q * I) / np.sqrt(Q)
175 |         I += (I == 0.0) * 1e-6
176 |         H, W = I.shape
177 |         rgb = np.zeros((H, W, 3), np.float32)
178 |         for img, band in zip(imgs, self.bands):
179 |             plane, scale = self.rgb_scales[band]
180 |             rgb[:, :, plane] = (img * scale + self.m) * fI / I
181 |
182 |         rgb = np.clip(rgb, 0, 1)
183 |         return rgb
184 |
185 |
186 | class GaussianNoise:
187 |     """
188 |     Augmentations tuned to the Legacy Survey Data (with minor modifications).
189 |
190 |     Code copied from
191 |     https://github.com/georgestein/ssl-legacysurvey/blob/main/ssl_legacysurvey/data_loaders/decals_augmentations.py#L296
192 |     """
193 |
194 |     def __init__(
195 |         self,
196 |         scaling: List = [1.0],
197 |         mean: float = 0,
198 |         im_dim: int = 144,
199 |         im_ch: int = 3,
200 |         decals: bool = True,
201 |         uniform: bool = False,
202 |     ):
203 |         self.mean = mean
204 |         self.decals = decals
205 |         self.im_ch = im_ch
206 |         self.im_dim = im_dim
207 |         self.uniform = uniform
208 |
209 |         # Log normal fit parameters
210 |         self.shape_dist = np.array([0.2264926, 0.2431146, 0.1334844])
211 |         self.loc_dist = np.array([-0.0006735, -0.0023663, -0.0143416])
212 |         self.scale_dist = np.array([0.0037602, 0.0067417, 0.0260779])
213 |
214 |         self.sigma_dist = np.log(self.scale_dist)
215 |
216 |         # noise in channels is uncorrelated, as images taken at different times/telescopes
217 |         self.noise_ch_min = np.array([0.001094, 0.001094, 0.001094])
218 |         self.noise_ch_max = np.array([0.013, 0.018, 0.061])
219 |
220 |     def __call__(self, image: np.ndarray):
221 |         # draw 'true' noise level of each channel from lognormal fits
222 |         self.sigma_true = (
223 |             np.random.lognormal(self.sigma_dist, self.shape_dist) + self.loc_dist
224 |         )
225 |
226 |         if self.uniform:
227 |             # draw desired augmented noise level from uniform, to target tails more
228 |             self.sigma_final = np.random.uniform(self.noise_ch_min, self.noise_ch_max)
229 |         else:
230 |             self.sigma_final = (
231 |                 np.random.lognormal(self.sigma_dist, self.shape_dist) + self.loc_dist
232 |             )
233 |
234 |         # Gaussian noise adds as c^2 = a^2 + b^2
235 |         self.sigma_augment = self.sigma_final**2 - self.sigma_true**2
236 |         self.sigma_augment[self.sigma_augment < 0.0] = 0.0
237 |         self.sigma_augment = np.sqrt(self.sigma_augment)
238 |
239 |         for i in range(self.im_ch):
240 |             if self.sigma_augment[i] > 0.0:
241 |                 image[i, :, :] += np.random.normal(
242 |                     self.mean, self.sigma_augment[i], size=(self.im_dim, self.im_dim)
243 |                 )
244 |
245 |         return image
246 |
247 |
248 | class GaussianBlur:
249 |     """
250 |     Augmentations tuned to the Legacy Survey Data (with minor modifications).
251 |
252 |     Code copied from
253 |     https://github.com/georgestein/ssl-legacysurvey/blob/main/ssl_legacysurvey/data_loaders/decals_augmentations.py#L296
254 |     """
255 |
256 |     def __init__(
257 |         self,
258 |         scaling: List = [1.0],
259 |         im_dim: int = 144,
260 |         im_ch: int = 3,
261 |         decals: bool = True,
262 |         uniform: bool = False,
263 |     ):
264 |         self.decals = decals
265 |         self.im_ch = im_ch
266 |         self.im_dim = im_dim
267 |         self.uniform = uniform
268 |
269 |         # Log normal fit parameters
270 |         self.shape_dist = np.array([0.2109966, 0.3008485, 0.3471172])
271 |         self.loc_dist = np.array([1.0807153, 1.2394326, 1.1928363])
272 |         self.scale_dist = np.array([1.3153171, 0.9164757, 0.8233702])
273 |
274 |         self.sigma_dist = np.log(self.scale_dist)
275 |
276 |         self.psf_ch_min = np.array([1.3233109, 1.2667341, 1.2126263])
277 |         self.psf_ch_max = np.array([5.0, 4.5, 4.25])
278 |
279 |     def __call__(self, image: np.ndarray):
280 |         # noise in channels is uncorrelated, as images taken at different times/telescopes
281 |         # draw 'true' noise level of each channel from lognormal fits
282 |         self.sigma_true = (
283 |             np.random.lognormal(self.sigma_dist, self.shape_dist) + self.loc_dist
284 |         )
285 |
286 |         if self.uniform:
287 |             # draw desired augmented noise level from uniform, to target tails more
288 |             self.sigma_final = np.random.uniform(self.psf_ch_min, self.psf_ch_max)
289 |         else:
290 |             self.sigma_final = (
291 |                 np.random.lognormal(self.sigma_dist, self.shape_dist) + self.loc_dist
292 |             )
293 |
294 |         # Gaussian PSF widths add in quadrature: c^2 = a^2 + b^2
295 |         self.sigma_augment = self.sigma_final**2 - self.sigma_true**2
296 |         self.sigma_augment[self.sigma_augment < 0.0] = 0.0
297 |         self.sigma_augment = np.sqrt(self.sigma_augment)
298 |
299 |         for i in range(self.im_ch):
300 |             if self.sigma_augment[i] > 0.0:
301 |                 image[i, :, :] = skimage.filters.gaussian(
302 |                     image[i, :, :], sigma=self.sigma_augment[i], mode="reflect"
303 |                 )
304 |
305 |         return image
306 |
--------------------------------------------------------------------------------
/astroclip/astrodino/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Dataset file for DESI Legacy Survey data
2 | import logging
3 | import os
4 | from enum import Enum
5 | from typing import Any, Callable, Optional, Tuple, Union
6 |
7 | import h5py
8 | import numpy as np
9 | import torch
10 | from PIL import Image as im
11 | from torchvision.datasets import VisionDataset
12 |
13 | logger = logging.getLogger("astrodino")
14 | _Target = float
15 |
16 |
17 | class _SplitFull(Enum):
18 |     TRAIN = "train"
19 |     VAL = "val"
20 |     TEST = "test"  # NOTE: torchvision does not support the test split
21 |
22 |     @property
23 |     def length(self) -> int:
24 |         split_lengths = {
25 |             _SplitFull.TRAIN: 74_500_000,
26 |             _SplitFull.VAL: 100_000,
27 |             _SplitFull.TEST: 400_000,
28 |         }
29 |         return split_lengths[self]
30 |
31 |
32 | class LegacySurvey(VisionDataset):
33 |     Target = Union[_Target]
34 |     Split = Union[_SplitFull]
35 |
36 |     def __init__(
37 |         self,
38 |         *,
39 |         split: "LegacySurvey.Split",
40 |         root: str,
41 |         extra: str = None,
42 |         transforms: Optional[Callable] = None,
43 |         transform: Optional[Callable] = None,
44 |         target_transform: Optional[Callable] = None,
45 |     ) -> None:
46 |         super().__init__(root, transforms, transform, target_transform)
47 |         self._extra_root = extra
48 |         self._split = split
49 |
50 |         # We start by opening the hdf5 files located at the root directory
51 |         self._files = [
52 |             h5py.File(
53 |                 os.path.join(
54 |                     root,
"north/images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1) 55 | ) 56 | ) 57 | for i in range(14) 58 | ] 59 | self._files += [ 60 | h5py.File( 61 | os.path.join( 62 | root, "south/images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1) 63 | ) 64 | ) 65 | for i in range(61) 66 | ] 67 | 68 | # Create randomized array of indices 69 | rng = np.random.default_rng(seed=42) 70 | self._indices = rng.permutation(int(7.5e7)) 71 | if split == LegacySurvey.Split.TRAIN.value: 72 | self._indices = self._indices[:74_500_000] 73 | elif split == LegacySurvey.Split.VAL.value: 74 | self._indices = self._indices[74_500_000:-400_000] 75 | else: 76 | self._indices = self._indices[-400_000:] 77 | 78 | @property 79 | def split(self) -> "LegacySurvey.Split": 80 | return self._split 81 | 82 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 83 | true_index = self._indices[index] 84 | image = self._files[true_index // int(1e6)]["images"][ 85 | true_index % int(1e6) 86 | ].astype("float32") 87 | image = torch.tensor(image) 88 | target = None 89 | 90 | if self.transforms is not None: 91 | image, target = self.transforms(image, target) 92 | 93 | return image, target 94 | 95 | def __len__(self) -> int: 96 | return len(self._indices) 97 | 98 | 99 | class _SplitNorth(Enum): 100 | TRAIN = "train" 101 | VAL = "val" 102 | TEST = "test" # NOTE: torchvision does not support the test split 103 | 104 | @property 105 | def length(self) -> int: 106 | split_lengths = { 107 | _SplitNorth.TRAIN: 13_500_000, 108 | _SplitNorth.VAL: 100_000, 109 | _SplitNorth.TEST: 400_000, 110 | } 111 | return split_lengths[self] 112 | 113 | 114 | class LegacySurveyNorth(VisionDataset): 115 | Target = Union[_Target] 116 | Split = Union[_SplitNorth] 117 | 118 | def __init__( 119 | self, 120 | *, 121 | split: "LegacySurvey.Split", 122 | root: str, 123 | extra: str = None, 124 | transforms: Optional[Callable] = None, 125 | transform: Optional[Callable] = None, 126 | target_transform: Optional[Callable] = None, 127 | ) -> None: 128 | super().__init__(root, transforms, transform, target_transform) 129 | self._extra_root = extra 130 | self._split = split 131 | 132 | # We start by opening the hdf5 files located at the root directory 133 | self._files = [ 134 | h5py.File( 135 | os.path.join( 136 | root, "north/images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1) 137 | ) 138 | ) 139 | for i in range(14) 140 | ] 141 | 142 | # Create randomized array of indices 143 | rng = np.random.default_rng(seed=42) 144 | self._indices = rng.permutation(int(1.4e7)) 145 | if split == LegacySurvey.Split.TRAIN.value: 146 | self._indices = self._indices[:13_500_000] 147 | elif split == LegacySurvey.Split.VAL.value: 148 | self._indices = self._indices[13_500_000:-400_000] 149 | else: 150 | self._indices = self._indices[-400_000:] 151 | 152 | @property 153 | def split(self) -> "LegacySurvey.Split": 154 | return self._split 155 | 156 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 157 | true_index = self._indices[index] 158 | image = self._files[true_index // int(1e6)]["images"][ 159 | true_index % int(1e6) 160 | ].astype("float32") 161 | image = torch.tensor(image) 162 | target = None 163 | 164 | if self.transforms is not None: 165 | image, target = self.transforms(image, target) 166 | 167 | return image, target 168 | 169 | def __len__(self) -> int: 170 | return len(self._indices) 171 | -------------------------------------------------------------------------------- /astroclip/astrodino/data/loaders.py: 
--------------------------------------------------------------------------------
1 | # Overriding default Dinov2 data loader function
2 |
3 | import logging
4 | from enum import Enum
5 | from typing import Any, Callable, List, Optional, TypeVar
6 |
7 | import torch
8 | from dinov2.data.samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
9 | from torch.utils.data import Sampler
10 |
11 | from .dataset import LegacySurvey, LegacySurveyNorth
12 |
13 | logger = logging.getLogger("dinov2")
14 |
15 |
16 | class SamplerType(Enum):
17 |     DISTRIBUTED = 0
18 |     EPOCH = 1
19 |     INFINITE = 2
20 |     SHARDED_INFINITE = 3
21 |     SHARDED_INFINITE_NEW = 4
22 |
23 |
24 | def _parse_dataset_str(dataset_str: str):
25 |     tokens = dataset_str.split(":")
26 |
27 |     name = tokens[0]
28 |     kwargs = {}
29 |
30 |     for token in tokens[1:]:
31 |         key, value = token.split("=")
32 |         assert key in ("root", "extra", "split")
33 |         kwargs[key] = value
34 |
35 |     if name == "LegacySurvey":
36 |         class_ = LegacySurvey
37 |         # NOTE: the split is kept as a plain string ("train"/"val"/"test");
38 |         # unlike upstream DINOv2, it is not converted to an enum here
39 |     elif name == "LegacySurveyNorth":
40 |         class_ = LegacySurveyNorth
41 |         # NOTE: the split is likewise kept as a plain string rather than
42 |         # converted to an enum
43 |     else:
44 |         raise ValueError(f'Unsupported dataset "{name}"')
45 |
46 |     return class_, kwargs
47 |
48 |
49 | def make_dataset(
50 |     *,
51 |     dataset_str: str,
52 |     transform: Optional[Callable] = None,
53 |     target_transform: Optional[Callable] = None,
54 | ):
55 |     """
56 |     Creates a dataset with the specified parameters.
57 |
58 |     Args:
59 |         dataset_str: A dataset string description (e.g. LegacySurvey:split=train).
60 |         transform: A transform to apply to images.
61 |         target_transform: A transform to apply to targets.
62 |
63 |     Returns:
64 |         The created dataset.
65 |     """
66 |     logger.info(f'using dataset: "{dataset_str}"')
67 |
68 |     class_, kwargs = _parse_dataset_str(dataset_str)
69 |     dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
70 |
71 |     logger.info(f"# of dataset samples: {len(dataset):,d}")
72 |
73 |     # Aggregated datasets do not expose (yet) these attributes, so add them.
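    # An illustrative call (the root path is a placeholder), mirroring the
    # train.dataset_path string in astrodino/config.yaml:
    #
    #     dataset = make_dataset(
    #         dataset_str="LegacySurvey:split=train:root=/path/to/decals",
    #         transform=DataAugmentationAstroDINO(
    #             global_crops_scale=(0.8, 1.0),
    #             local_crops_scale=(0.4, 0.6),
    #             local_crops_number=8,
    #         ),
    #     )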
74 | if not hasattr(dataset, "transform"): 75 | setattr(dataset, "transform", transform) 76 | if not hasattr(dataset, "target_transform"): 77 | setattr(dataset, "target_transform", target_transform) 78 | 79 | return dataset 80 | 81 | 82 | def _make_sampler( 83 | *, 84 | dataset, 85 | type: Optional[SamplerType] = None, 86 | shuffle: bool = False, 87 | seed: int = 0, 88 | size: int = -1, 89 | advance: int = 0, 90 | ) -> Optional[Sampler]: 91 | sample_count = len(dataset) 92 | 93 | if type == SamplerType.INFINITE: 94 | logger.info("sampler: infinite") 95 | if size > 0: 96 | raise ValueError("sampler size > 0 is invalid") 97 | return InfiniteSampler( 98 | sample_count=sample_count, 99 | shuffle=shuffle, 100 | seed=seed, 101 | advance=advance, 102 | ) 103 | elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): 104 | logger.info("sampler: sharded infinite") 105 | if size > 0: 106 | raise ValueError("sampler size > 0 is invalid") 107 | # TODO: Remove support for old shuffling 108 | use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW 109 | return ShardedInfiniteSampler( 110 | sample_count=sample_count, 111 | shuffle=shuffle, 112 | seed=seed, 113 | advance=advance, 114 | use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, 115 | ) 116 | elif type == SamplerType.EPOCH: 117 | logger.info("sampler: epoch") 118 | if advance > 0: 119 | raise NotImplementedError("sampler advance > 0 is not supported") 120 | size = size if size > 0 else sample_count 121 | logger.info(f"# of samples / epoch: {size:,d}") 122 | return EpochSampler( 123 | size=size, 124 | sample_count=sample_count, 125 | shuffle=shuffle, 126 | seed=seed, 127 | ) 128 | elif type == SamplerType.DISTRIBUTED: 129 | logger.info("sampler: distributed") 130 | if size > 0: 131 | raise ValueError("sampler size > 0 is invalid") 132 | if advance > 0: 133 | raise ValueError("sampler advance > 0 is invalid") 134 | return torch.utils.data.DistributedSampler( 135 | dataset=dataset, 136 | shuffle=shuffle, 137 | seed=seed, 138 | drop_last=False, 139 | ) 140 | 141 | logger.info("sampler: none") 142 | return None 143 | 144 | 145 | T = TypeVar("T") 146 | 147 | 148 | def make_data_loader( 149 | *, 150 | dataset, 151 | batch_size: int, 152 | num_workers: int, 153 | shuffle: bool = True, 154 | seed: int = 0, 155 | sampler_type: Optional[SamplerType] = SamplerType.INFINITE, 156 | sampler_size: int = -1, 157 | sampler_advance: int = 0, 158 | drop_last: bool = True, 159 | persistent_workers: bool = False, 160 | collate_fn: Optional[Callable[[List[T]], Any]] = None, 161 | ): 162 | """ 163 | Creates a data loader with the specified parameters. 164 | 165 | Args: 166 | dataset: A dataset (third party, LaViDa or WebDataset). 167 | batch_size: The size of batches to generate. 168 | num_workers: The number of workers to use. 169 | shuffle: Whether to shuffle samples. 170 | seed: The random seed to use. 171 | sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. 172 | sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. 173 | sampler_advance: How many samples to skip (when applicable). 174 | drop_last: Whether the last non-full batch of data should be dropped. 175 | persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. 
176 | collate_fn: Function that performs batch collation 177 | """ 178 | 179 | sampler = _make_sampler( 180 | dataset=dataset, 181 | type=sampler_type, 182 | shuffle=shuffle, 183 | seed=seed, 184 | size=sampler_size, 185 | advance=sampler_advance, 186 | ) 187 | 188 | logger.info("using PyTorch data loader") 189 | data_loader = torch.utils.data.DataLoader( 190 | dataset, 191 | sampler=sampler, 192 | batch_size=batch_size, 193 | num_workers=num_workers, 194 | pin_memory=True, 195 | drop_last=drop_last, 196 | persistent_workers=persistent_workers, 197 | collate_fn=collate_fn, 198 | ) 199 | 200 | try: 201 | logger.info(f"# of batches: {len(data_loader):,d}") 202 | except TypeError: # data loader has no length 203 | logger.info("infinite data loader") 204 | return data_loader 205 | -------------------------------------------------------------------------------- /astroclip/astrodino/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import random 8 | import re 9 | import socket 10 | from typing import Dict, List 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_RANK = -1 16 | _LOCAL_WORLD_SIZE = -1 17 | 18 | 19 | def is_enabled() -> bool: 20 | """ 21 | Returns: 22 | True if distributed training is enabled 23 | """ 24 | return dist.is_available() and dist.is_initialized() 25 | 26 | 27 | def get_global_size() -> int: 28 | """ 29 | Returns: 30 | The number of processes in the process group 31 | """ 32 | return dist.get_world_size() if is_enabled() else 1 33 | 34 | 35 | def get_global_rank() -> int: 36 | """ 37 | Returns: 38 | The rank of the current process within the global process group. 39 | """ 40 | return dist.get_rank() if is_enabled() else 0 41 | 42 | 43 | def get_local_rank() -> int: 44 | """ 45 | Returns: 46 | The rank of the current process within the local (per-machine) process group. 47 | """ 48 | if not is_enabled(): 49 | return 0 50 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 51 | return _LOCAL_RANK 52 | 53 | 54 | def get_local_size() -> int: 55 | """ 56 | Returns: 57 | The size of the per-machine process group, 58 | i.e. the number of processes per machine. 59 | """ 60 | if not is_enabled(): 61 | return 1 62 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 63 | return _LOCAL_WORLD_SIZE 64 | 65 | 66 | def is_main_process() -> bool: 67 | """ 68 | Returns: 69 | True if the current process is the main one. 
70 | """ 71 | return get_global_rank() == 0 72 | 73 | 74 | def _restrict_print_to_main_process() -> None: 75 | """ 76 | This function disables printing when not in the main process 77 | """ 78 | import builtins as __builtin__ 79 | 80 | builtin_print = __builtin__.print 81 | 82 | def print(*args, **kwargs): 83 | force = kwargs.pop("force", False) 84 | if is_main_process() or force: 85 | builtin_print(*args, **kwargs) 86 | 87 | __builtin__.print = print 88 | 89 | 90 | def _get_master_port(seed: int = 0) -> int: 91 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) 92 | 93 | master_port_str = os.environ.get("MASTER_PORT") 94 | if master_port_str is None: 95 | rng = random.Random(seed) 96 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) 97 | 98 | return int(master_port_str) 99 | 100 | 101 | def _get_available_port() -> int: 102 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 103 | # A "" host address means INADDR_ANY i.e. binding to all interfaces. 104 | # Note this is not compatible with IPv6. 105 | s.bind(("", 0)) 106 | port = s.getsockname()[1] 107 | return port 108 | 109 | 110 | _TORCH_DISTRIBUTED_ENV_VARS = ( 111 | "MASTER_ADDR", 112 | "MASTER_PORT", 113 | "RANK", 114 | "WORLD_SIZE", 115 | "LOCAL_RANK", 116 | "LOCAL_WORLD_SIZE", 117 | ) 118 | 119 | 120 | def _collect_env_vars() -> Dict[str, str]: 121 | return { 122 | env_var: os.environ[env_var] 123 | for env_var in _TORCH_DISTRIBUTED_ENV_VARS 124 | if env_var in os.environ 125 | } 126 | 127 | 128 | def _is_slurm_job_process() -> bool: 129 | return "SLURM_JOB_ID" in os.environ 130 | 131 | 132 | def _parse_slurm_node_list(s: str) -> List[str]: 133 | nodes = [] 134 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings 135 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") 136 | for m in p.finditer(s): 137 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] 138 | for suffix in suffixes.split(","): 139 | span = suffix.split("-") 140 | if len(span) == 1: 141 | nodes.append(prefix + suffix) 142 | else: 143 | width = len(span[0]) 144 | start, end = int(span[0]), int(span[1]) + 1 145 | nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) 146 | return [i for n in nodes for i in n.split(",")] 147 | 148 | 149 | def _check_env_variable(key: str, new_value: str): 150 | # Only check for difference with preset environment variables 151 | if key in os.environ and os.environ[key] != new_value: 152 | raise RuntimeError( 153 | f"Cannot export environment variables as {key} is already set" 154 | ) 155 | 156 | 157 | class _TorchDistributedEnvironment: 158 | def __init__(self): 159 | self.master_addr = "127.0.0.1" 160 | self.master_port = 0 161 | self.rank = -1 162 | self.world_size = -1 163 | self.local_rank = -1 164 | self.local_world_size = -1 165 | 166 | if _is_slurm_job_process(): 167 | return self._set_from_slurm_env() 168 | 169 | env_vars = _collect_env_vars() 170 | if not env_vars: 171 | # Environment is not set 172 | pass 173 | elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): 174 | # Environment is fully set 175 | return self._set_from_preset_env() 176 | else: 177 | # Environment is partially set 178 | collected_env_vars = ", ".join(env_vars.keys()) 179 | raise RuntimeError(f"Partially set environment: {collected_env_vars}") 180 | 181 | if torch.cuda.device_count() > 0: 182 | return self._set_from_local() 183 | 184 | raise RuntimeError("Can't initialize PyTorch distributed environment") 185 | 186 | # Slurm job created with sbatch, submitit, etc... 
187 | def _set_from_slurm_env(self): 188 | # logger.info("Initialization from Slurm environment") 189 | job_id = int(os.environ["SLURM_JOB_ID"]) 190 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) 191 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) 192 | assert len(nodes) == node_count, f"Expected {node_count} nodes, got {nodes}" 193 | 194 | self.master_addr = nodes[0] 195 | self.master_port = _get_master_port(seed=job_id) 196 | self.rank = int(os.environ["SLURM_PROCID"]) 197 | self.world_size = int(os.environ["SLURM_NTASKS"]) 198 | assert self.rank < self.world_size 199 | self.local_rank = int(os.environ["SLURM_LOCALID"]) 200 | self.local_world_size = self.world_size // node_count 201 | assert self.local_rank < self.local_world_size 202 | 203 | # Single node job with preset environment (i.e. torchrun) 204 | def _set_from_preset_env(self): 205 | # logger.info("Initialization from preset environment") 206 | self.master_addr = os.environ["MASTER_ADDR"] 207 | self.master_port = os.environ["MASTER_PORT"] 208 | self.rank = int(os.environ["RANK"]) 209 | self.world_size = int(os.environ["WORLD_SIZE"]) 210 | assert self.rank < self.world_size 211 | self.local_rank = int(os.environ["LOCAL_RANK"]) 212 | self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 213 | assert self.local_rank < self.local_world_size 214 | 215 | # Single node and GPU job (i.e. local script run) 216 | def _set_from_local(self): 217 | # logger.info("Initialization from local") 218 | self.master_addr = "127.0.0.1" 219 | self.master_port = _get_available_port() 220 | self.rank = 0 221 | self.world_size = 1 222 | self.local_rank = 0 223 | self.local_world_size = 1 224 | 225 | def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": 226 | # See the "Environment variable initialization" section from 227 | # https://pytorch.org/docs/stable/distributed.html for the complete list of 228 | # environment variables required for the env:// initialization method. 229 | env_vars = { 230 | "MASTER_ADDR": self.master_addr, 231 | "MASTER_PORT": str(self.master_port), 232 | "RANK": str(self.rank), 233 | "WORLD_SIZE": str(self.world_size), 234 | "LOCAL_RANK": str(self.local_rank), 235 | "LOCAL_WORLD_SIZE": str(self.local_world_size), 236 | } 237 | if not overwrite: 238 | for k, v in env_vars.items(): 239 | _check_env_variable(k, v) 240 | 241 | os.environ.update(env_vars) 242 | return self 243 | 244 | 245 | def enable( 246 | *, 247 | set_cuda_current_device: bool = True, 248 | overwrite: bool = False, 249 | allow_nccl_timeout: bool = False, 250 | ): 251 | """Enable distributed mode 252 | 253 | Args: 254 | set_cuda_current_device: If True, call torch.cuda.set_device() to set the 255 | current PyTorch CUDA device to the one matching the local rank. 256 | overwrite: If True, overwrites already set variables. Else fails. 
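        allow_nccl_timeout: If True, export NCCL_ASYNC_ERROR_HANDLING=1 so that
            torch.distributed timeouts are honored by the NCCL backend.

    Example (illustrative; the helpers are defined in this module):

        enable(set_cuda_current_device=True)
        if is_main_process():
            print(f"world size: {get_global_size()}")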
257 | """ 258 | 259 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 260 | if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: 261 | raise RuntimeError("Distributed mode has already been enabled") 262 | torch_env = _TorchDistributedEnvironment() 263 | torch_env.export(overwrite=overwrite) 264 | 265 | if set_cuda_current_device: 266 | torch.cuda.set_device(torch_env.local_rank) 267 | 268 | if allow_nccl_timeout: 269 | # This allows to use torch distributed timeout in a NCCL backend 270 | key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" 271 | if not overwrite: 272 | _check_env_variable(key, value) 273 | os.environ[key] = value 274 | 275 | dist.init_process_group(backend="nccl") 276 | dist.barrier() 277 | 278 | # Finalize setup 279 | _LOCAL_RANK = torch_env.local_rank 280 | _LOCAL_WORLD_SIZE = torch_env.local_world_size 281 | _restrict_print_to_main_process() 282 | -------------------------------------------------------------------------------- /astroclip/astrodino/embed_legacysurvey/embed_legacysurvey.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from multiprocessing import Pool 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | from datasets import load_dataset 9 | from torch import package 10 | from torchvision.transforms import CenterCrop, Compose, ToTensor 11 | from tqdm import tqdm 12 | 13 | # Set up dataset 14 | crop = CenterCrop(144) 15 | RGB_SCALES = { 16 | "u": (2, 1.5), 17 | "g": (2, 6.0), 18 | "r": (1, 3.4), 19 | "i": (0, 1.0), 20 | "z": (0, 2.2), 21 | } 22 | 23 | 24 | def decals_to_rgb(image, bands=["g", "r", "z"], scales=None, m=0.03, Q=20.0): 25 | axes, scales = zip(*[RGB_SCALES[bands[i]] for i in range(len(bands))]) 26 | scales = [scales[i] for i in axes] 27 | image = image.movedim(1, -1).flip(-1) 28 | scales = torch.tensor(scales, dtype=torch.float32).to(image.device) 29 | I = torch.sum(torch.clamp(image * scales + m, min=0), dim=-1) / len(bands) 30 | fI = torch.arcsinh(Q * I) / np.sqrt(Q) 31 | I += (I == 0.0) * 1e-6 32 | image = (image * scales + m) * (fI / I).unsqueeze(-1) 33 | image = torch.clamp(image, 0, 1) 34 | return image.movedim(-1, 1) 35 | 36 | 37 | def import_package(path: str, device: str = "cpu") -> torch.nn.Module: 38 | """Import a torch package from a given path""" 39 | importer = package.PackageImporter(path) 40 | model = importer.load_pickle("network", "network.pkl", map_location=device) 41 | return model 42 | 43 | 44 | def process_file(args) -> None: 45 | """Process a single file in the dataset""" 46 | file, save_dir, batch_size, gpu_id = args 47 | file_path = os.path.join(dset_root, file, "001-of-001.hdf5") 48 | 49 | # Set the GPU device for this process 50 | torch.cuda.set_device(gpu_id) 51 | 52 | # Load the model 53 | astrodino = import_package( 54 | "/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt" 55 | ).to(torch.device(f"cuda:{gpu_id}")) 56 | 57 | embeddings = [] 58 | with h5py.File(file_path, "r") as f: 59 | img_batch = [] 60 | for img in tqdm(f["image_array"]): 61 | # Convert to RGB 62 | img = crop(torch.tensor(img[[0, 1, 3]])) # get g,r,z 63 | 64 | # Append to batch 65 | img_batch.append(img) 66 | 67 | if len(img_batch) == batch_size: 68 | with torch.no_grad(): 69 | images = torch.stack(img_batch).cuda() 70 | images = decals_to_rgb(images) 71 | emb = astrodino(images) 72 | embeddings.append(emb.cpu().numpy()) 73 | im_batch = [] 74 | 75 | # Get ra, dec, obj_id 76 | ra = f["RA"][:] 77 | dec = f["DEC"][:] 78 | obj_id = f["object_id"][:] 79 | 80 | # Concatenate 
embeddings 81 | embeddings = np.concatenate(embeddings, axis=0) 82 | 83 | # Save embeddings 84 | save_dir = os.path.join(save_dir, file) 85 | if not os.path.exists(save_dir): 86 | os.makedirs(save_dir) 87 | 88 | save_path = os.path.join(save_dir, "001-of-001.hdf5") 89 | with h5py.File(save_path, "w") as f: 90 | f.create_dataset("embeddings", data=embeddings) 91 | f.create_dataset("RA", data=ra) 92 | f.create_dataset("DEC", data=dec) 93 | f.create_dataset("object_id", data=obj_id) 94 | 95 | 96 | def embed_legacysurvey( 97 | dset_root: str, save_dir: str, astrodino_dir: str, batch_size=512, num_gpus=4 98 | ): 99 | # List all files in the dataset 100 | files = os.listdir(dset_root) 101 | 102 | # Create arguments for each process 103 | args = [(f, save_dir, batch_size, i % num_gpus) for i, f in enumerate(files)] 104 | 105 | # Use multiprocessing to process files in parallel 106 | with Pool(processes=num_gpus) as pool: 107 | pool.map(process_file, args) 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("--dset_root", type=str, required=True) 113 | parser.add_argument("--save_dir", type=str, required=True) 114 | parser.add_argument( 115 | "--astrodino_dir", 116 | type=str, 117 | default="/mnt/ceph/users/polymathic/astroclip/pretrained", 118 | ) 119 | parser.add_argument("--batch_size", type=int, default=512) 120 | parser.add_argument("--num_gpus", type=int, default=4) 121 | args = parser.parse_args() 122 | 123 | # Run the embedding process 124 | embed_legacysurvey(dset_root, save_dir, astrodino_dir, batch_size, num_gpus) 125 | -------------------------------------------------------------------------------- /astroclip/astrodino/embed_legacysurvey/launch_embedding.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | #SBATCH -p gpu 4 | #SBATCH -N 1 5 | #SBATCH -C a100-80gb 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --gpus-per-node=4 8 | #SBATCH --cpus-per-gpu=1 9 | #SBATCH -t 168:00:00 10 | #SBATCH --output=logs/out-%j.log 11 | #SBATCH -J "embedding" 12 | 13 | module purge 14 | module load gcc 15 | 16 | $dset_root = "/mnt/ceph/users/polymathic/MultimodalUniverse/legacysurvey/dr10_south_21" 17 | $save_root = "/mnt/ceph/users/polymathic/MultimodalUniverse/astrodino_legacysurvey" 18 | 19 | export OMP_NUM_THREADS=${SLURM_CPUS_ON_NODE} 20 | 21 | # enable logging 22 | export CUDA_LAUNCH_BLOCKING=1. 
23 |
24 | source /mnt/home/lparker/python_envs/toto/bin/activate
25 |
26 | python embed_legacysurvey.py --dset_root $dset_root --save_dir $save_root --batch_size 512 --num_gpus $SLURM_GPUS_PER_NODE
27 |
--------------------------------------------------------------------------------
/astroclip/astrodino/training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 |
3 | #SBATCH -p gpu
4 | #SBATCH -t 48:00:00
5 | #SBATCH -C a100,ib
6 | #SBATCH -N 5
7 | #SBATCH --gpus=20
8 | #SBATCH --tasks-per-node=4
9 | #SBATCH --cpus-per-task=12
10 | #SBATCH --output=logs/astrodino-%j.log
11 |
12 | module purge
13 | module load python
14 | module load cuda
15 | module load gcc
16 |
17 | random_number=$(shuf -i 2000-65000 -n 1)
18 | run_name="astroclip_$random_number"
19 | config="astroclip/astrodino/config.yaml"
20 |
21 | export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
22 | source /mnt/home/lparker/python_envs/toto/bin/activate
23 |
24 | srun python -m astroclip.astrodino.trainer \
25 |     --config-file=$config --run-name=$run_name
26 |
--------------------------------------------------------------------------------
/astroclip/astrodino/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import datetime
7 | import json
8 | import logging
9 | import os
10 | import sys
11 | import time
12 | from collections import defaultdict, deque
13 | from pathlib import Path
14 |
15 | import dinov2.distributed as distributed
16 | import torch
17 | import wandb
18 | from dinov2.eval.setup import setup_and_build_model
19 |
20 | logger = logging.getLogger("dinov2")
21 |
22 |
23 | class MetricLogger(object):
24 |     def __init__(self, delimiter="\t", wandb=None, output_file=None):
25 |         self.meters = defaultdict(SmoothedValue)
26 |         self.delimiter = delimiter
27 |         self.output_file = output_file
28 |         self.wandb = wandb
29 |
30 |     def update(self, **kwargs):
31 |         for k, v in kwargs.items():
32 |             if isinstance(v, torch.Tensor):
33 |                 v = v.item()
34 |             assert isinstance(v, (float, int))
35 |             self.meters[k].update(v)
36 |
37 |     def __getattr__(self, attr):
38 |         if attr in self.meters:
39 |             return self.meters[attr]
40 |         if attr in self.__dict__:
41 |             return self.__dict__[attr]
42 |         raise AttributeError(
43 |             "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
44 |         )
45 |
46 |     def __str__(self):
47 |         loss_str = []
48 |         for name, meter in self.meters.items():
49 |             loss_str.append("{}: {}".format(name, str(meter)))
50 |         return self.delimiter.join(loss_str)
51 |
52 |     def synchronize_between_processes(self):
53 |         for meter in self.meters.values():
54 |             meter.synchronize_between_processes()
55 |
56 |     def add_meter(self, name, meter):
57 |         self.meters[name] = meter
58 |
59 |     def dump_in_output_file(self, iteration, iter_time, data_time):
60 |         if self.output_file is None or not distributed.is_main_process():
61 |             return
62 |
63 |         os.makedirs(Path(self.output_file).parent, exist_ok=True)
64 |         dict_to_dump = dict(
65 |             iteration=iteration,
66 |             iter_time=iter_time,
67 |             data_time=data_time,
68 |         )
69 |
70 |         metrics = {k: v.median for k, v in self.meters.items()}
71 |         global_rank = int(os.environ.get("RANK", 0))
72 |         if self.wandb is not None and global_rank == 0:
73 |             self.wandb.log(metrics, step=iteration)
74 |
75 |
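        # Each dump appends one JSON object per line to output_file, e.g.
        # (illustrative values and metric names):
        #   {"iteration": 100, "iter_time": 0.52, "data_time": 0.03, "total_loss": 11.2}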
dict_to_dump.update(metrics) 76 | with open(self.output_file, "a") as f: 77 | f.write(json.dumps(dict_to_dump) + "\n") 78 | pass 79 | 80 | def log_every( 81 | self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0 82 | ): 83 | i = start_iteration 84 | if not header: 85 | header = "" 86 | start_time = time.time() 87 | end = time.time() 88 | iter_time = SmoothedValue(fmt="{avg:.6f}") 89 | data_time = SmoothedValue(fmt="{avg:.6f}") 90 | 91 | if n_iterations is None: 92 | n_iterations = len(iterable) 93 | 94 | space_fmt = ":" + str(len(str(n_iterations))) + "d" 95 | 96 | log_list = [ 97 | header, 98 | "[{0" + space_fmt + "}/{1}]", 99 | "eta: {eta}", 100 | "{meters}", 101 | "time: {time}", 102 | "data: {data}", 103 | ] 104 | if torch.cuda.is_available(): 105 | log_list += ["max mem: {memory:.0f}"] 106 | 107 | log_msg = self.delimiter.join(log_list) 108 | MB = 1024.0 * 1024.0 109 | for obj in iterable: 110 | data_time.update(time.time() - end) 111 | yield obj 112 | iter_time.update(time.time() - end) 113 | if i % print_freq == 0 or i == n_iterations - 1: 114 | self.dump_in_output_file( 115 | iteration=i, iter_time=iter_time.avg, data_time=data_time.avg 116 | ) 117 | eta_seconds = iter_time.global_avg * (n_iterations - i) 118 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 119 | if torch.cuda.is_available(): 120 | logger.info( 121 | log_msg.format( 122 | i, 123 | n_iterations, 124 | eta=eta_string, 125 | meters=str(self), 126 | time=str(iter_time), 127 | data=str(data_time), 128 | memory=torch.cuda.max_memory_allocated() / MB, 129 | ) 130 | ) 131 | else: 132 | logger.info( 133 | log_msg.format( 134 | i, 135 | n_iterations, 136 | eta=eta_string, 137 | meters=str(self), 138 | time=str(iter_time), 139 | data=str(data_time), 140 | ) 141 | ) 142 | i += 1 143 | end = time.time() 144 | if i >= n_iterations: 145 | break 146 | total_time = time.time() - start_time 147 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 148 | logger.info( 149 | "{} Total time: {} ({:.6f} s / it)".format( 150 | header, total_time_str, total_time / n_iterations 151 | ) 152 | ) 153 | 154 | 155 | class SmoothedValue: 156 | """Track a series of values and provide access to smoothed values over a 157 | window or the global series average. 158 | """ 159 | 160 | def __init__(self, window_size=20, fmt=None): 161 | if fmt is None: 162 | fmt = "{median:.4f} ({global_avg:.4f})" 163 | self.deque = deque(maxlen=window_size) 164 | self.total = 0.0 165 | self.count = 0 166 | self.fmt = fmt 167 | 168 | def update(self, value, num=1): 169 | self.deque.append(value) 170 | self.count += num 171 | self.total += value * num 172 | 173 | def synchronize_between_processes(self): 174 | """ 175 | Distributed synchronization of the metric 176 | Warning: does not synchronize the deque! 
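Note (added): only ``count`` and ``total`` (and hence ``global_avg``) are
reduced across processes; window statistics such as ``median``, ``avg``,
``max``, and ``value`` remain local to each rank.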
177 | """
178 | if not distributed.is_enabled():
179 | return
180 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
181 | torch.distributed.barrier()
182 | torch.distributed.all_reduce(t)
183 | t = t.tolist()
184 | self.count = int(t[0])
185 | self.total = t[1]
186 |
187 | @property
188 | def median(self):
189 | d = torch.tensor(list(self.deque))
190 | return d.median().item()
191 |
192 | @property
193 | def avg(self):
194 | d = torch.tensor(list(self.deque), dtype=torch.float32)
195 | return d.mean().item()
196 |
197 | @property
198 | def global_avg(self):
199 | return self.total / self.count
200 |
201 | @property
202 | def max(self):
203 | return max(self.deque)
204 |
205 | @property
206 | def value(self):
207 | return self.deque[-1]
208 |
209 | def __str__(self):
210 | return self.fmt.format(
211 | median=self.median,
212 | avg=self.avg,
213 | global_avg=self.global_avg,
214 | max=self.max,
215 | value=self.value,
216 | )
217 |
218 |
219 | def setup_astrodino(
220 | astrodino_output_dir: str,
221 | astrodino_pretrained_weights: str,
222 | astrodino_config_file: str = "./astroclip/astrodino/config.yaml",
223 | ) -> torch.nn.Module:
224 | """Set up AstroDINO model"""
225 |
226 | # Set up config to pass to AstroDINO
227 | class config:
228 | output_dir = astrodino_output_dir
229 | config_file = astrodino_config_file
230 | pretrained_weights = astrodino_pretrained_weights
231 | opts = []
232 |
233 | sys.stdout = open(os.devnull, "w") # Redirect stdout to null
234 | astrodino, _ = setup_and_build_model(config())
235 | sys.stdout = sys.__stdout__ # Reset stdout (the original restored stderr, leaving stdout redirected)
236 | return astrodino
237 |
-------------------------------------------------------------------------------- /astroclip/callbacks.py: --------------------------------------------------------------------------------
1 | from typing import Any, Union
2 |
3 | import matplotlib.pyplot as plt
4 | import wandb
5 | from lightning import Callback, LightningModule, Trainer
6 | from lightning.pytorch.cli import SaveConfigCallback
7 | from lightning.pytorch.loggers import WandbLogger
8 | from omegaconf import OmegaConf
9 |
10 |
11 | def _safe_eval(s: str, max_len: int = 1024) -> Union[int, float]:
12 | """Safely evaluate an arithmetic expression.
13 |
14 | :param s: expression to evaluate; should only contain numbers, spaces, or the
15 | symbols `+, -, *, /, _, (, )`; exponential notation is supported
16 | :param max_len: maximum length string that will be evaluated; longer strings raise
17 | a `ValueError`
18 | """
19 | # XXX need to be smarter about this
20 | is_safe = all(ch in "e0123456789_+-*/(). " for ch in s)
21 | if not is_safe:
22 | raise ValueError(
23 | "Only simple arithmetic expressions involving digits, parentheses, "
24 | "the letter e, or the symbols '+-*/_.' are allowed"
25 | )
26 | if len(s) > max_len:
27 | raise ValueError(f"String length is {len(s)}, maximum allowed is {max_len}")
28 | return eval(s)
29 |
30 |
31 | # allow for the ${eval:...} resolver in the config file to perform simple arithmetic
32 | # XXX problem: accessing a node that involves resolving with `eval` is ~60x slower
33 | # than a simple numeric node (similar slowdowns for `oc` resolvers)
34 | OmegaConf.register_new_resolver("eval", _safe_eval, use_cache=True)
35 |
36 |
37 | class CustomWandbLogger(WandbLogger):
38 | # Disable unintended hyperparameter logging (already saved on init)
39 | def log_hyperparams(self, *args, **kwargs):
40 | ...
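# A minimal sketch (not part of the original file) of how the resolver above is
# used from a config; the keys ``lr`` and ``eta_min`` here are hypothetical:
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.create({"lr": 1e-4, "eta_min": "${eval:'${lr} / 500'}"})
#     assert cfg.eta_min == 2e-7  # _safe_eval receives the string "0.0001 / 500"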
41 | 42 | 43 | class CustomSaveConfigCallback(SaveConfigCallback): 44 | """Saves full training configuration 45 | Otherwise wandb won't log full configuration but only flattened module and data hyperparameters 46 | """ 47 | 48 | def save_config( 49 | self, trainer: Trainer, pl_module: LightningModule, stage: str 50 | ) -> None: 51 | for logger in trainer.loggers: 52 | if issubclass(type(logger), WandbLogger): 53 | logger.experiment.config.update(self.config.as_dict()) 54 | return super().save_config(trainer, pl_module, stage) 55 | 56 | 57 | class PlotsCallback(Callback): 58 | # TODO: Update with latest code 59 | def __init__(self) -> None: 60 | super().__init__() 61 | 62 | def plot_spectrum(self, batch, output): 63 | sample_id = 3 64 | 65 | bs = len(batch["spectrum"]) 66 | sp_rec = batch["target"][:, 1:, 2:99].reshape(bs, -1)[sample_id] 67 | in_rec = batch["input"][:, 1:, 2:99].reshape(bs, -1)[sample_id] 68 | out_rec = output[:, 1:, 2:99].reshape(bs, -1)[sample_id] 69 | 70 | # plot the moving average of the spectrum 71 | win = 20 72 | 73 | sp_rec = [sp_rec[i : i + win].mean().item() for i in range(0, len(sp_rec), win)] 74 | in_rec = [in_rec[i : i + win].mean().item() for i in range(0, len(in_rec), win)] 75 | out_rec = [ 76 | out_rec[i : i + win].mean().item() for i in range(0, len(out_rec), win) 77 | ] 78 | 79 | fig = plt.figure() 80 | plt.plot(sp_rec, label="original") 81 | plt.plot(in_rec, label="dropped") 82 | plt.plot(out_rec, label="reconstructed") 83 | plt.legend() 84 | return fig 85 | 86 | def on_validation_batch_start( 87 | self, 88 | trainer: Trainer, 89 | pl_module: LightningModule, 90 | batch: Any, 91 | batch_idx: int, 92 | dataloader_idx: int = 0, 93 | ) -> None: 94 | if batch_idx == 0: 95 | output = trainer.model(batch["input"]) 96 | fig = self.plot_spectrum(batch, output) 97 | for logger in trainer.loggers: 98 | # Check WandbLogger and Enabled 99 | if issubclass(type(logger), WandbLogger) and not issubclass( 100 | type(logger.experiment), wandb.sdk.lib.disabled.RunDisabled 101 | ): 102 | logger: WandbLogger = logger 103 | logger.experiment.log({f"plot/{pl_module.current_epoch:03d}": fig}) 104 | -------------------------------------------------------------------------------- /astroclip/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datamodule import AstroClipCollator, AstroClipDataloader 2 | from .dataset import AstroClipDataset 3 | -------------------------------------------------------------------------------- /astroclip/data/datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List 2 | 3 | import datasets 4 | import lightning as L 5 | import torch 6 | from torch import Tensor 7 | from torch.utils.data.dataloader import default_collate 8 | from torchvision.transforms import CenterCrop 9 | 10 | from ..astrodino.data.augmentations import ToRGB 11 | 12 | 13 | class AstroClipDataloader(L.LightningDataModule): 14 | def __init__( 15 | self, 16 | path: str, 17 | columns: List[str] = ["image", "spectrum"], 18 | batch_size: int = 512, 19 | num_workers: int = 10, 20 | collate_fn: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = None, 21 | ) -> None: 22 | super().__init__() 23 | self.save_hyperparameters() 24 | 25 | def setup(self, stage: str) -> None: 26 | self.dataset = datasets.load_from_disk(self.hparams.path) 27 | self.dataset.set_format(type="torch", columns=self.hparams.columns) 28 | 29 | def train_dataloader(self): 30 | return 
torch.utils.data.DataLoader( 31 | self.dataset["train"], 32 | batch_size=self.hparams.batch_size, 33 | shuffle=True, 34 | num_workers=self.hparams.num_workers, # NOTE: disable for debugging 35 | drop_last=True, 36 | collate_fn=self.hparams.collate_fn, 37 | ) 38 | 39 | def val_dataloader(self): 40 | return torch.utils.data.DataLoader( 41 | self.dataset["test"], 42 | batch_size=self.hparams.batch_size, 43 | num_workers=self.hparams.num_workers, # NOTE: disable for debugging 44 | drop_last=True, 45 | collate_fn=self.hparams.collate_fn, 46 | ) 47 | 48 | 49 | class AstroClipCollator: 50 | def __init__( 51 | self, 52 | center_crop: int = 144, 53 | bands: List[str] = ["g", "r", "z"], 54 | m: float = 0.03, 55 | Q: int = 20, 56 | ): 57 | self.center_crop = CenterCrop(center_crop) 58 | self.to_rgb = ToRGB(bands=bands, m=m, Q=Q) 59 | 60 | def _process_images(self, images): 61 | # convert to rgb 62 | img_outs = [] 63 | for img in images: 64 | rgb_img = torch.tensor(self.to_rgb(img)[None, :, :, :]) 65 | img_outs.append(rgb_img) 66 | images = torch.concatenate(img_outs) 67 | 68 | images = self.center_crop(images.permute(0, 3, 2, 1)) 69 | return images 70 | 71 | def __call__(self, samples): 72 | # collate and handle dimensions 73 | samples = default_collate(samples) 74 | # process images 75 | samples["image"] = self._process_images(samples["image"]) 76 | return samples 77 | -------------------------------------------------------------------------------- /astroclip/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Joint dataset of DESI Legacy Survey and DESI Early Data Release.""" 15 | 16 | from aiohttp import ClientTimeout 17 | import datasets 18 | import h5py 19 | import numpy as np 20 | 21 | _CITATION = """ 22 | """ 23 | 24 | _DESCRIPTION = """\ 25 | This dataset is designed for cross-modal learning between images and spectra of galaxies 26 | contained in the DESI Early Data Release and the Legacy Survey DR9. It contains roughly 150k 27 | examples of images and spectra of galaxies, with their redshifts and targetids. 
28 | """
29 |
30 | _HOMEPAGE = ""
31 |
32 | _LICENSE = ""
33 |
34 | _URLS = {
35 | "joint": "https://users.flatironinstitute.org/~flanusse/astroclip_desi.1.1.5.h5",
36 | }
37 |
38 |
39 | class AstroClipDataset(datasets.GeneratorBasedBuilder):
40 | """Joint dataset of images from the DESI Legacy Survey and spectra from the DESI Early Data Release."""
41 |
42 | VERSION = datasets.Version("1.1.5")
43 |
44 | BUILDER_CONFIGS = [
45 | datasets.BuilderConfig(
46 | name="joint",
47 | version=VERSION,
48 | description="This part of the dataset covers examples from both spectral and image domains",
49 | ),
50 | ]
51 |
52 | DEFAULT_CONFIG_NAME = "joint"
53 |
54 | def _info(self):
55 | if self.config.name == "joint":
56 | features = datasets.Features(
57 | {
58 | "image": datasets.Array3D(shape=(152, 152, 3), dtype="float32"),
59 | "spectrum": datasets.Array2D(shape=(7781, 1), dtype="float32"),
60 | "redshift": datasets.Value("float32"),
61 | "targetid": datasets.Value("int64"),
62 | }
63 | )
64 | else:
65 | raise NotImplementedError(
66 | "Only the joint configuration is implemented for now"
67 | )
68 |
69 | return datasets.DatasetInfo(
70 | description=_DESCRIPTION,
71 | features=features,
72 | homepage=_HOMEPAGE,
73 | license=_LICENSE,
74 | citation=_CITATION,
75 | )
76 |
77 | def _split_generators(self, dl_manager):
78 | urls = _URLS[self.config.name]
79 | dl_manager.download_config.storage_options["timeout"] = ClientTimeout(total=5000, connect=1000)
80 | data_dir = dl_manager.download_and_extract(urls)
81 | return [
82 | datasets.SplitGenerator(
83 | name=datasets.Split.TRAIN,
84 | gen_kwargs={
85 | "filepath": data_dir,
86 | "split": "train",
87 | },
88 | ),
89 | datasets.SplitGenerator(
90 | name=datasets.Split.TEST,
91 | gen_kwargs={"filepath": data_dir, "split": "test"},
92 | ),
93 | ]
94 |
95 | def _generate_examples(self, filepath, split):
96 | """Yields examples."""
97 | with h5py.File(filepath) as d:
98 | for i in range(10):
99 | # Access the data
100 | images = d[str(i)]["images"]
101 | spectra = d[str(i)]["spectra"]
102 | redshifts = d[str(i)]["redshifts"]
103 | targetids = d[str(i)]["targetids"]
104 |
105 | dset_size = len(targetids)
106 |
107 | if split == "train":
108 | dset_range = (0, int(0.8 * dset_size))
109 | else:
110 | dset_range = (int(0.8 * dset_size), dset_size)
111 |
112 | for j in range(dset_range[0], dset_range[1]):
113 | yield str(targetids[j]), {
114 | "image": np.array(images[j]).astype("float32"),
115 | "spectrum": np.reshape(spectra[j], [-1, 1]).astype("float32"),
116 | "redshift": redshifts[j],
117 | "targetid": targetids[j],
118 | }
119 |
-------------------------------------------------------------------------------- /astroclip/env.py: --------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from pathlib import Path
3 | from tempfile import NamedTemporaryFile
4 | from typing import TypeVar
5 | from warnings import warn
6 |
7 | from dotenv import dotenv_values
8 |
9 | WARN_ONCE = True
10 |
11 |
12 | # TODO: change here the defaults
13 | ASTROCLIP_ROOT = "/mnt/ceph/users/polymathic/astroclip"
14 | WANDB_ENTITY_NAME = "flatiron-scipt"
15 |
16 |
17 | def default_dotenv_values():
18 | """Use a default .env but tell the user how to create their own."""
19 |
20 | env_dir = Path(__file__).parent
21 | env_path = env_dir / ".env"
22 |
23 | if env_path.exists():
24 | return dotenv_values(env_path)
25 |
26 | with NamedTemporaryFile(mode="w+") as f:
27 | global WARN_ONCE
28 |
29 | # TODO: these should be replaced with a folder in the project's root
30 | f.write(f"ASTROCLIP_ROOT={ASTROCLIP_ROOT}\n")  # f-strings: write the actual defaults, not literal placeholders
31 | f.write(f'WANDB_ENTITY_NAME="{WANDB_ENTITY_NAME}"\n')
32 | f.flush()
33 |
34 | if WARN_ONCE:
35 | f.seek(0)
36 | warn(
37 | f"No .env file found in {env_dir}. "
38 | "Using default environment variables for rusty. "
39 | f"To suppress this warning, create {env_dir}/.env with, e.g., the following content:\n"
40 | f"{f.read()}"
41 | )
42 | WARN_ONCE = False
43 |
44 | return dotenv_values(f.name)
45 |
46 |
47 | T = TypeVar("T")
48 |
49 |
50 | def format_with_env(s: T) -> T:
51 | if isinstance(s, str):
52 | for k, v in default_dotenv_values().items():
53 | s = s.replace("{" + k + "}", v)
54 | return s
55 | elif isinstance(s, dict):
56 | return {k: format_with_env(v) for k, v in s.items()}
57 | elif isinstance(s, list):
58 | return [format_with_env(v) for v in s]
59 | elif isinstance(s, Namespace):
60 | return type(s)(**{k: format_with_env(v) for k, v in s.__dict__.items()})
61 | else:
62 | return s
63 |
-------------------------------------------------------------------------------- /astroclip/models/__init__.py: --------------------------------------------------------------------------------
1 | from . import astroclip
2 | from .astroclip import AstroClipModel
3 | from .loader import load_model
4 | from .moco_v2 import Moco_v2
5 | from .specformer import SpecFormer
6 |
-------------------------------------------------------------------------------- /astroclip/models/astroclip.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Tuple
4 |
5 | import lightning as L
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from dinov2.eval.setup import setup_and_build_model
11 |
12 | from ..modules import MLP, CrossAttentionHead
13 | from .specformer import SpecFormer
14 |
15 |
16 | class AstroClipModel(L.LightningModule):
17 | def __init__(
18 | self,
19 | image_encoder: nn.Module,
20 | spectrum_encoder: nn.Module,
21 | temperature: float = 15.5,
22 | lr: float = 1e-4,
23 | weight_decay: float = 0.05,
24 | epochs: int = 100,
25 | eta_min: float = 5e-7,
26 | logit_scale: float = 15.5,
27 | learnable_logit_scale: bool = False,
28 | ):
29 | """
30 | The AstroCLIP model that takes an image and a spectrum and embeds them into a common space using CLIP loss.
31 | Note that you must provide the image and spectrum encoders to be used for the embedding.
32 |
33 | Args:
34 | image_encoder (nn.Module): The image encoder to be used for embedding.
35 | spectrum_encoder (nn.Module): The spectrum encoder to be used for embedding.
36 | temperature (float): The temperature parameter for the CLIP loss.
37 | lr (float): The learning rate for the optimizer.
38 | weight_decay (float): The weight decay for the optimizer.
39 | epochs (int): The number of epochs for training.
40 | eta_min (float): The minimum learning rate for the scheduler.
41 | logit_scale (float): The logit scale for the CLIP loss.
42 | learnable_logit_scale (bool): Whether the logit scale should be learnable.
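Note: with the defaults above, `temperature` and `logit_scale` are equal
(both 15.5); the `*_withlogit` losses are computed with `temperature` and
the `*_nologit` losses with the fixed `logit_scale`.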
43 | """
44 | super().__init__()
45 | self.save_hyperparameters()
46 |
47 | # Define the image and spectrum encoder
48 | self.image_encoder = image_encoder
49 | self.spectrum_encoder = spectrum_encoder
50 |
51 | # Store the log of the logit scale; it stays fixed unless learnable_logit_scale is set
52 | if not learnable_logit_scale:
53 | self.logit_scale = np.log(logit_scale)
54 | else:
55 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(logit_scale))
56 |
57 | # Use CLIP loss
58 | self.criterion = CLIPLoss()
59 |
60 | def forward(
61 | self,
62 | input: torch.Tensor,
63 | input_type: str,
64 | ):
65 | if input_type == "image":
66 | return self.image_encoder(input)
67 |
68 | elif input_type == "spectrum":
69 | return self.spectrum_encoder(input)
70 |
71 | else:
72 | raise ValueError("Input type must be either 'image' or 'spectrum'")
73 |
74 | def training_step(self, batch, batch_idx):
75 | im, sp = batch["image"], batch["spectrum"]
76 |
77 | # Get the image and spectrum features
78 | image_features = self.image_encoder(im)
79 | spectrum_features = self.spectrum_encoder(sp)
80 |
81 | # Calculate the CLIP loss
82 | loss_withlogit = self.criterion(
83 | image_features, spectrum_features, self.hparams.temperature
84 | )
85 | loss_nologit = self.criterion(
86 | image_features, spectrum_features, self.hparams.logit_scale
87 | )
88 |
89 | # Log the losses
90 | self.log("train_loss_withlogit", loss_withlogit)
91 | self.log("train_loss_nologit", loss_nologit)
92 | self.log("scale", self.logit_scale)
93 |
94 | # Return the loss
95 | return loss_withlogit
96 |
97 | def validation_step(self, batch, batch_idx):
98 | im, sp = batch["image"], batch["spectrum"]
99 |
100 | # Get the image and spectrum features
101 | image_features = self.image_encoder(im)
102 | spectrum_features = self.spectrum_encoder(sp)
103 |
104 | # Calculate the CLIP loss
105 | val_loss_nologit = self.criterion(
106 | image_features, spectrum_features, self.hparams.logit_scale
107 | )
108 | val_loss_withlogit = self.criterion(
109 | image_features, spectrum_features, self.hparams.temperature
110 | )
111 |
112 | # Log the losses
113 | self.log("val_loss_nologit", val_loss_nologit)
114 | self.log("val_loss_withlogit", val_loss_withlogit)
115 |
116 |
117 | class CLIPLoss(nn.Module):
118 | def get_logits(
119 | self,
120 | image_features: torch.FloatTensor,
121 | spectrum_features: torch.FloatTensor,
122 | logit_scale: float,
123 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
124 | # Normalize image features
125 | image_features = F.normalize(image_features, dim=-1, eps=1e-3)
126 |
127 | # Normalize spectrum features
128 | spectrum_features = F.normalize(spectrum_features, dim=-1, eps=1e-3)
129 |
130 | # Calculate the logits for the image and spectrum features
131 | logits_per_image = logit_scale * image_features @ spectrum_features.T
132 | return logits_per_image, logits_per_image.T
133 |
134 | def forward(
135 | self,
136 | image_features: torch.FloatTensor,
137 | spectrum_features: torch.FloatTensor,
138 | logit_scale: float,
139 | output_dict: bool = False,
140 | ) -> torch.FloatTensor:
141 | # Get the logits for the image and spectrum features
142 | logits_per_image, logits_per_spectrum = self.get_logits(
143 | image_features, spectrum_features, logit_scale
144 | )
145 |
146 | # Calculate the contrastive loss
147 | labels = torch.arange(
148 | logits_per_image.shape[0], device=image_features.device, dtype=torch.long
149 | )
150 | total_loss = (
151 | F.cross_entropy(logits_per_image, labels)
152 | + F.cross_entropy(logits_per_spectrum, labels)
153 | ) / 2
154 | return {"contrastive_loss": total_loss} if output_dict else total_loss
155 |
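# A minimal usage sketch (not part of the original file): for a batch of N
# paired embeddings, CLIPLoss builds an N x N similarity matrix and treats the
# diagonal (matching image/spectrum pairs) as the correct class in both
# directions. The shapes below are illustrative only:
#
#     loss_fn = CLIPLoss()
#     image_emb = torch.randn(8, 512)     # hypothetical image embeddings
#     spectrum_emb = torch.randn(8, 512)  # hypothetical spectrum embeddings
#     loss = loss_fn(image_emb, spectrum_emb, logit_scale=15.5)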
156 |
157 | class ImageHead(nn.Module):
158 | def __init__(
159 | self,
160 | config: str,
161 | model_weights: str,
162 | save_directory: str,
163 | embed_dim: int = 1024,
164 | n_head: int = 4,
165 | model_embed_dim: int = 1024,
166 | dropout: float = 0.1,
167 | freeze_backbone: bool = True,
168 | ):
169 | """
170 | Cross-attention image module that takes token outputs from the AstroDINO model and passes them through a
171 | cross-attention mechanism and MLP to get the final embedding.
172 |
173 | Args:
174 | config (str): Path to the configuration file of the AstroDINO model.
175 | model_weights (str): Path to the weights of the AstroDINO model.
176 | save_directory (str): Path to the directory containing the AstroDINO model.
177 | embed_dim (int): Dimension of the AstroCLIP embedding.
178 | n_head (int): Number of heads in the multihead attention.
179 | model_embed_dim (int): Dimension of the AstroDINO embedding.
180 | dropout (float): Dropout rate for MLP layers.
181 | freeze_backbone (bool): Whether to freeze the backbone of the AstroDINO model.
182 | """
183 | super().__init__()
184 |
185 | # Define DINO config (renamed to avoid shadowing the `config` argument)
186 | class _Config:
187 | output_dir = save_directory
188 | config_file = config
189 | pretrained_weights = model_weights
190 | opts = []
191 |
192 | # Define DINO model
193 | sys.stdout = open(os.devnull, "w") # Redirect stdout to null
194 | self.backbone, _ = setup_and_build_model(_Config())
195 | sys.stdout = sys.__stdout__ # Reset stdout
196 |
197 | # Freeze backbone if necessary
198 | self.freeze_backbone = freeze_backbone
199 | if self.freeze_backbone:
200 | for param in self.backbone.parameters():
201 | param.requires_grad = False
202 |
203 | # Set up cross-attention
204 | self.cross_attention = CrossAttentionHead(
205 | embed_dim=embed_dim,
206 | n_head=n_head,
207 | model_embed_dim=model_embed_dim,
208 | dropout=dropout,
209 | )
210 |
211 | # Set up MLP
212 | self.mlp = MLP(
213 | in_features=embed_dim,
214 | hidden_features=4 * embed_dim,
215 | dropout=dropout,
216 | )
217 |
218 | def forward(self, x: torch.Tensor, return_weights: bool = False):
219 | # Pass through the backbone
220 | with torch.set_grad_enabled(not self.freeze_backbone):
221 | x = self.backbone.patch_embed(x)
222 | for blk in self.backbone.blocks:
223 | x = blk(x)
224 | embedding = self.backbone.norm(x)
225 |
226 | # Pass through cross-attention
227 | x, attentions = self.cross_attention(embedding)
228 |
229 | # Pass through MLP (unlike SpectrumHead, no residual connection is applied here)
230 | x = self.mlp(x)
231 |
232 | if return_weights:
233 | return x.squeeze(), attentions[1]
234 |
235 | return x.squeeze()
236 |
237 |
238 | class SpectrumHead(nn.Module):
239 | def __init__(
240 | self,
241 | model_path: str,
242 | embed_dim: int = 1024,
243 | n_head: int = 4,
244 | model_embed_dim: int = 768,
245 | dropout: float = 0.1,
246 | freeze_backbone: bool = True,
247 | load_pretrained_weights=True,
248 | ):
249 | """
250 | Cross-attention spectrum module that takes a spectrum and passes it through a pretrained SpecFormer model and
251 | then through a cross-attention mechanism and MLP to get the final embedding.
252 |
253 | Args:
254 | model_path (str): Path to the checkpoint of the SpecFormer model.
255 | embed_dim (int): Dimension of the AstroCLIP embedding.
256 | n_head (int): Number of heads in the multihead attention.
257 | model_embed_dim (int): Dimension of the SpecFormer embedding.
258 | dropout (float): Dropout rate for MLP layers.
259 | freeze_backbone (bool): Whether to freeze the backbone of the SpecFormer model.
260 | """
261 | super().__init__()
262 | # Load the model from the checkpoint
263 | checkpoint = torch.load(model_path)
264 | self.backbone = SpecFormer(**checkpoint["hyper_parameters"])
265 | if load_pretrained_weights:
266 | self.backbone.load_state_dict(checkpoint["state_dict"])
267 |
268 | # Freeze backbone if necessary
269 | self.freeze_backbone = freeze_backbone
270 | if self.freeze_backbone:
271 | for param in self.backbone.parameters():
272 | param.requires_grad = False
273 |
274 | # Set up cross-attention
275 | self.cross_attention = CrossAttentionHead(
276 | embed_dim=embed_dim,
277 | n_head=n_head,
278 | model_embed_dim=model_embed_dim,
279 | dropout=dropout,
280 | )
281 |
282 | # Set up MLP
283 | self.mlp = MLP(
284 | in_features=embed_dim,
285 | hidden_features=4 * embed_dim,
286 | dropout=dropout,
287 | )
288 |
289 | def forward(
290 | self, x: torch.Tensor, y: torch.Tensor = None, return_weights: bool = False
291 | ):
292 | # Embed the spectrum using the pretrained model
293 | with torch.set_grad_enabled(not self.freeze_backbone):
294 | embedding = self.backbone(x)["embedding"]
295 |
296 | # Pass through cross-attention
297 | x, attentions = self.cross_attention(embedding)
298 |
299 | # Pass through MLP and residual connection
300 | x = x + self.mlp(x)
301 |
302 | if return_weights:
303 | return x.squeeze(), attentions[1]
304 |
305 | return x.squeeze()
306 |
-------------------------------------------------------------------------------- /astroclip/models/loader.py: --------------------------------------------------------------------------------
1 | import joblib
2 | from huggingface_hub import hf_hub_download
3 |
4 |
5 | def load_model(repo_id, filename):
6 | model = joblib.load(hf_hub_download(repo_id=repo_id, filename=filename))
7 | return model
8 |
-------------------------------------------------------------------------------- /astroclip/models/moco_v2.py: --------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import Union
3 |
4 | import lightning as L
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from lightning import Trainer
9 | from torch import nn
10 |
11 |
12 | class Moco_v2(L.LightningModule):
13 | """PyTorch Lightning implementation of `Moco <https://arxiv.org/abs/2003.04297>`_
14 |
15 | Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.
16 |
17 | Code adapted from `facebookresearch/moco <https://github.com/facebookresearch/moco>`_ to Lightning by:
18 | - `William Falcon <https://github.com/williamFalcon>`_
19 | """
20 |
21 | def __init__(
22 | self,
23 | base_encoder: Union[str, torch.nn.Module] = "resnet18",
24 | emb_dim: int = 128,
25 | num_negatives: int = 65536,
26 | encoder_momentum: float = 0.999,
27 | softmax_temperature: float = 0.07,
28 | learning_rate: float = 0.03,
29 | momentum: float = 0.9,
30 | weight_decay: float = 1e-4,
31 | data_dir: str = "./",
32 | batch_size: int = 256,
33 | use_mlp: bool = False,
34 | num_workers: int = 8,
35 | *args,
36 | **kwargs
37 | ):
38 | super().__init__()
39 | self.save_hyperparameters()
40 |
41 | # create the encoders
42 | # num_classes is the output fc dimension
43 | self.encoder_q, self.encoder_k = self.init_encoders(base_encoder)
44 |
45 | if use_mlp: # hack: brute-force replacement
46 | dim_mlp = self.encoder_q.fc.weight.shape[1]
47 | self.encoder_q.fc = nn.Sequential(
48 | nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
49 | )
50 | self.encoder_k.fc = nn.Sequential(
51 | nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
52 | )
53 |
54 | for param_q, param_k in zip(
55 | self.encoder_q.parameters(), self.encoder_k.parameters()
56 | ):
57 | param_k.data.copy_(param_q.data) # initialize
58 | param_k.requires_grad = False # not update by gradient
59 |
60 | # create the queue
61 | self.register_buffer("queue", torch.randn(emb_dim, num_negatives))
62 | self.queue = nn.functional.normalize(self.queue, dim=0)
63 |
64 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
65 |
66 | # create the validation queue
67 | self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
68 | self.val_queue = nn.functional.normalize(self.val_queue, dim=0)
69 |
70 | self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
71 |
72 | def init_encoders(self, base_encoder):
73 | """Override to add your own encoders."""
74 |
75 | template_model = getattr(torchvision.models, base_encoder)
76 | encoder_q = template_model(num_classes=self.hparams.emb_dim)
77 | encoder_k = template_model(num_classes=self.hparams.emb_dim)
78 |
79 | return encoder_q, encoder_k
80 |
81 | @torch.no_grad()
82 | def _momentum_update_key_encoder(self):
83 | """Momentum update of the key encoder."""
84 | for param_q, param_k in zip(
85 | self.encoder_q.parameters(), self.encoder_k.parameters()
86 | ):
87 | em = self.hparams.encoder_momentum
88 | param_k.data = param_k.data * em + param_q.data * (1.0 - em)
89 |
90 | @torch.no_grad()
91 | def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
92 | # gather keys before updating queue
93 | if self._use_ddp_or_ddp2(self.trainer):
94 | keys = concat_all_gather(keys)
95 |
96 | batch_size = keys.shape[0]
97 |
98 | ptr = int(queue_ptr)
99 | assert self.hparams.num_negatives % batch_size == 0 # for simplicity
100 |
101 | # replace the keys at ptr (dequeue and enqueue)
102 | queue[:, ptr : ptr + batch_size] = keys.T
103 | ptr = (ptr + batch_size) % self.hparams.num_negatives # move pointer
104 |
105 | queue_ptr[0] = ptr
106 |
107 | @torch.no_grad()
108 | def _batch_shuffle_ddp(self, x): # pragma: no cover
109 | """Batch shuffle, for making use of BatchNorm.
110 |
111 | *** Only support DistributedDataParallel (DDP) model.
*** 112 | """ 113 | # gather from all gpus 114 | batch_size_this = x.shape[0] 115 | x_gather = concat_all_gather(x) 116 | batch_size_all = x_gather.shape[0] 117 | 118 | num_gpus = batch_size_all // batch_size_this 119 | 120 | # random shuffle index 121 | idx_shuffle = torch.randperm(batch_size_all).cuda() 122 | 123 | # broadcast to all gpus 124 | torch.distributed.broadcast(idx_shuffle, src=0) 125 | 126 | # index for restoring 127 | idx_unshuffle = torch.argsort(idx_shuffle) 128 | 129 | # shuffled index for this gpu 130 | gpu_idx = torch.distributed.get_rank() 131 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 132 | 133 | return x_gather[idx_this], idx_unshuffle 134 | 135 | @torch.no_grad() 136 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no cover 137 | """Undo batch shuffle. 138 | 139 | *** Only support DistributedDataParallel (DDP) model. *** 140 | """ 141 | # gather from all gpus 142 | batch_size_this = x.shape[0] 143 | x_gather = concat_all_gather(x) 144 | batch_size_all = x_gather.shape[0] 145 | 146 | num_gpus = batch_size_all // batch_size_this 147 | 148 | # restored index for this gpu 149 | gpu_idx = torch.distributed.get_rank() 150 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 151 | 152 | return x_gather[idx_this] 153 | 154 | def forward(self, img_q, img_k, queue): 155 | """ 156 | Input: 157 | im_q: a batch of query images 158 | im_k: a batch of key images 159 | queue: a queue from which to pick negative samples 160 | Output: 161 | logits, targets 162 | """ 163 | 164 | # compute query features 165 | q = self.encoder_q(img_q) # queries: NxC 166 | q = nn.functional.normalize(q, dim=1) 167 | 168 | # compute key features 169 | with torch.no_grad(): # no gradient to keys 170 | # shuffle for making use of BN 171 | if self._use_ddp_or_ddp2(self.trainer): 172 | img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k) 173 | 174 | k = self.encoder_k(img_k) # keys: NxC 175 | k = nn.functional.normalize(k, dim=1) 176 | 177 | # undo shuffle 178 | if self._use_ddp_or_ddp2(self.trainer): 179 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 180 | 181 | # compute logits 182 | # Einstein sum is more intuitive 183 | # positive logits: Nx1 184 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 185 | # negative logits: NxK 186 | l_neg = torch.einsum("nc,ck->nk", [q, queue.clone().detach()]) 187 | 188 | # logits: Nx(1+K) 189 | logits = torch.cat([l_pos, l_neg], dim=1) 190 | 191 | # apply temperature 192 | logits /= self.hparams.softmax_temperature 193 | 194 | # labels: positive key indicators 195 | labels = torch.zeros(logits.shape[0], dtype=torch.long) 196 | labels = labels.type_as(logits) 197 | 198 | return logits, labels, k 199 | 200 | def training_step(self, batch, batch_idx): 201 | # in STL10 we pass in both lab+unl for online ft 202 | if self.trainer.datamodule.name == "stl10": 203 | # labeled_batch = batch[1] 204 | unlabeled_batch = batch[0] 205 | batch = unlabeled_batch 206 | 207 | (img_1, img_2), _ = batch 208 | 209 | self._momentum_update_key_encoder() # update the key encoder 210 | output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue) 211 | self._dequeue_and_enqueue( 212 | keys, queue=self.queue, queue_ptr=self.queue_ptr 213 | ) # dequeue and enqueue 214 | 215 | loss = F.cross_entropy(output.float(), target.long()) 216 | log = {"train_loss": loss} 217 | self.log_dict(log) 218 | return loss 219 | 220 | def validation_step(self, batch, batch_idx): 221 | # in STL10 we pass in both lab+unl for online ft 222 | if self.trainer.datamodule.name == 
"stl10": 223 | # labeled_batch = batch[1] 224 | unlabeled_batch = batch[0] 225 | batch = unlabeled_batch 226 | 227 | (img_1, img_2), labels = batch 228 | 229 | output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue) 230 | self._dequeue_and_enqueue( 231 | keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr 232 | ) # dequeue and enqueue 233 | 234 | loss = F.cross_entropy(output, target.long()) 235 | results = {"val_loss": loss} 236 | return results 237 | 238 | def validation_epoch_end(self, outputs): 239 | log = {"val_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5} 240 | self.log_dict(log) 241 | 242 | def configure_optimizers(self): 243 | optimizer = torch.optim.SGD( 244 | self.parameters(), 245 | self.hparams.learning_rate, 246 | momentum=self.hparams.momentum, 247 | weight_decay=self.hparams.weight_decay, 248 | ) 249 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 250 | optimizer, 251 | self.trainer.max_epochs, 252 | ) 253 | return [optimizer], [scheduler] 254 | 255 | @staticmethod 256 | def add_model_specific_args(parent_parser): 257 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 258 | parser.add_argument("--base_encoder", type=str, default="resnet18") 259 | parser.add_argument("--emb_dim", type=int, default=128) 260 | parser.add_argument("--num_workers", type=int, default=8) 261 | parser.add_argument("--num_negatives", type=int, default=65536) 262 | parser.add_argument("--encoder_momentum", type=float, default=0.999) 263 | parser.add_argument("--softmax_temperature", type=float, default=0.07) 264 | parser.add_argument("--learning_rate", type=float, default=0.03) 265 | parser.add_argument("--momentum", type=float, default=0.9) 266 | parser.add_argument("--weight_decay", type=float, default=1e-4) 267 | parser.add_argument("--data_dir", type=str, default="./") 268 | parser.add_argument( 269 | "--dataset", 270 | type=str, 271 | default="cifar10", 272 | choices=["cifar10", "imagenet2012", "stl10"], 273 | ) 274 | parser.add_argument("--batch_size", type=int, default=256) 275 | parser.add_argument("--use_mlp", action="store_true") 276 | parser.add_argument( 277 | "--meta_dir", default=".", type=str, help="path to meta.bin for imagenet" 278 | ) 279 | return parser 280 | 281 | @staticmethod 282 | def _use_ddp_or_ddp2(trainer: Trainer) -> bool: 283 | return isinstance(trainer.training_type_plugin, (DDPPlugin, DDP2Plugin)) 284 | -------------------------------------------------------------------------------- /astroclip/models/specformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import lightning as L 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | from ..modules import LayerNorm, TransformerBlock, _init_by_depth 11 | 12 | 13 | class SpecFormer(L.LightningModule): 14 | def __init__( 15 | self, 16 | input_dim: int, 17 | embed_dim: int, 18 | num_layers: int, 19 | num_heads: int, 20 | max_len: int, 21 | mask_num_chunks: int = 6, 22 | mask_chunk_width: int = 50, 23 | slice_section_length: int = 20, 24 | slice_overlap: int = 10, 25 | dropout: float = 0.1, 26 | norm_first: bool = False, 27 | ): 28 | super().__init__() 29 | self.save_hyperparameters() 30 | 31 | self.data_embed = nn.Linear(input_dim, embed_dim) 32 | self.position_embed = nn.Embedding(max_len, embed_dim) 33 | self.dropout = nn.Dropout(dropout) 34 | self.blocks = nn.ModuleList( 35 | [ 36 | TransformerBlock( 37 | 
embedding_dim=embed_dim,
38 | num_heads=num_heads,
39 | causal=False,
40 | dropout=dropout,
41 | bias=True,
42 | )
43 | for _ in range(num_layers)
44 | ]
45 | )
46 | self.final_layernorm = LayerNorm(embed_dim, bias=True)
47 | self.head = nn.Linear(embed_dim, input_dim, bias=True)
48 |
49 | self._reset_parameters_datapt()
50 |
51 | def forward(self, x: Tensor) -> torch.Tensor:
52 | """Forward pass through the model."""
53 | x = self.preprocess(x)
54 | return self.forward_without_preprocessing(x)
55 |
56 | def forward_without_preprocessing(self, x: Tensor):
57 | """Forward pass through the model.
58 | The training step performs masking before preprocessing,
59 | thus samples should not be preprocessed again as in forward()"""
60 |
61 | t = x.shape[1]
62 | if t > self.hparams.max_len:
63 | raise ValueError(
64 | f"Cannot forward sequence of length {t}, "
65 | f"block size is only {self.hparams.max_len}"
66 | )
67 | pos = torch.arange(0, t, dtype=torch.long, device=x.device) # shape (t)
68 |
69 | # forward the GPT model itself
70 | data_emb = self.data_embed(x) # to shape (b, t, embedding_dim)
71 | pos_emb = self.position_embed(pos) # to shape (t, embedding_dim)
72 |
73 | x = self.dropout(data_emb + pos_emb)
74 | for block in self.blocks:
75 | x = block(x)
76 | x = self.final_layernorm(x)
77 |
78 | reconstructions = self.head(x)
79 |
80 | return {"reconstructions": reconstructions, "embedding": x}
81 |
82 | def training_step(self, batch):
83 | # slice the input and copy
84 | input = self.preprocess(batch["spectrum"])
85 | target = torch.clone(input)
86 |
87 | # mask parts of the input
88 | input = self.mask_sequence(input)
89 | # forward pass
90 | output = self.forward_without_preprocessing(input)["reconstructions"]
91 |
92 | # find the mask locations
93 | locs = (input != target).type_as(output)
94 | loss = F.mse_loss(output * locs, target * locs, reduction="mean") / locs.mean()
95 | self.log("training_loss", loss, prog_bar=True)
96 | return loss
97 |
98 | def validation_step(self, batch):
99 | # slice the input and copy
100 | input = self.preprocess(batch["spectrum"])
101 | target = torch.clone(input)
102 |
103 | # mask parts of the input
104 | input = self.mask_sequence(input)
105 |
106 | # forward pass
107 | output = self.forward_without_preprocessing(input)["reconstructions"]
108 |
109 | # find the mask locations
110 | locs = (input != target).type_as(output)
111 | loss = F.mse_loss(output * locs, target * locs, reduction="mean") / locs.mean()
112 | self.log("val_training_loss", loss, prog_bar=True)
113 | return loss
114 |
115 | def mask_sequence(self, x: Tensor):
116 | """Mask batched sequence"""
117 | return torch.stack([self._mask_seq(el) for el in x])
118 |
119 | def preprocess(self, x):
120 | std, mean = x.std(1, keepdim=True).clip_(0.2), x.mean(1, keepdim=True)
121 | x = (x - mean) / std
122 | x = self._slice(x)
123 | x = F.pad(x, pad=(2, 0, 1, 0), mode="constant", value=0)
124 | x[:, 0, 0] = (mean.squeeze() - 2) / 2
125 | x[:, 0, 1] = (std.squeeze() - 2) / 8
126 | return x
127 |
128 | def _reset_parameters_datapt(self):
129 | # not scaling the initial embeddings.
130 | for emb in [self.data_embed, self.position_embed]: 131 | std = 1 / math.sqrt(self.hparams.embed_dim) 132 | nn.init.trunc_normal_(emb.weight, std=std, a=-3 * std, b=3 * std) 133 | 134 | # transformer block weights 135 | self.blocks.apply(lambda m: _init_by_depth(m, self.hparams.num_layers)) 136 | self.head.apply(lambda m: _init_by_depth(m, 1 / 2)) 137 | 138 | def _slice(self, x): 139 | start_indices = np.arange( 140 | 0, 141 | x.shape[1] - self.hparams.slice_overlap, 142 | self.hparams.slice_section_length - self.hparams.slice_overlap, 143 | ) 144 | sections = [ 145 | x[:, start : start + self.hparams.slice_section_length].transpose(1, 2) 146 | for start in start_indices 147 | ] 148 | 149 | # If the last section is not of length 'section_length', you can decide whether to keep or discard it 150 | if sections[-1].shape[1] < self.hparams.slice_section_length: 151 | sections.pop(-1) # Discard the last section 152 | 153 | return torch.cat(sections, 1) 154 | 155 | def _mask_seq(self, seq: torch.Tensor) -> torch.Tensor: 156 | """Randomly masks contiguous sections of an unbatched sequence, 157 | ensuring separation between chunks is at least chunk_width.""" 158 | len_ = seq.shape[0] 159 | num_chunks = self.hparams.mask_num_chunks 160 | chunk_width = self.hparams.mask_chunk_width 161 | 162 | # Ensure there's enough space for the chunks and separations 163 | total_width_needed = num_chunks * chunk_width + (num_chunks - 1) * chunk_width 164 | if total_width_needed > len_: 165 | raise ValueError("Sequence is too short to mask") 166 | 167 | masked_seq = seq.clone() 168 | 169 | for i in range(num_chunks): 170 | start = (i * len_) // num_chunks 171 | loc = torch.randint(0, len_ // num_chunks - chunk_width, (1,)).item() 172 | masked_seq[loc + start : loc + start + chunk_width] = 0 173 | 174 | return masked_seq 175 | -------------------------------------------------------------------------------- /astroclip/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import LRScheduler 5 | 6 | 7 | class CosineAnnealingWithWarmupLR(LRScheduler): 8 | """A cosine-annealing learning rate scheduler with initial warmup. 9 | 10 | Currently this cuts off after one cycle. The interface is otherwise compatible with 11 | :class:`~torch.optim.lr_scheduler.CosineAnnealingLR`. 12 | 13 | :param optimizer: wrapped optimizer 14 | :param T_max: maximum number of iterations; unlike 15 | :class:`~torch.optim.lr_scheduler.CosineAnnealingLR`, this scheduler fixes the 16 | learning rate to :attr:`eta_min` after :attr:`T_max` iterations 17 | :param T_warmup: number of steps during which to use linear warmup 18 | :param eta_min: minimum learning rate 19 | :param last_epoch: index of last epoch 20 | :param verbose: whether to print a message to `stdout` for each update 21 | """ 22 | 23 | def __init__( 24 | self, 25 | optimizer: torch.optim.Optimizer, 26 | T_max: int, 27 | T_warmup: int = 0, 28 | eta_min: float = 0, 29 | last_epoch: int = -1, 30 | verbose: bool = False, 31 | ): 32 | self.T_max = T_max 33 | self.T_warmup = T_warmup 34 | self.eta_min = eta_min 35 | 36 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) 37 | 38 | def get_lr(self) -> float: 39 | if not self._get_lr_called_within_step: 40 | print( 41 | "To get the last learning rate computed by the scheduler, " 42 | "please use `get_last_lr()`." 
43 | ) 44 | 45 | if self.last_epoch < self.T_warmup: 46 | # linear warmup 47 | # T_warmup > last_epoch >= 0 so no division by zero 48 | return [ 49 | base_lr * self.last_epoch / self.T_warmup for base_lr in self.base_lrs 50 | ] 51 | elif self.last_epoch >= self.T_max: 52 | return [self.eta_min for _ in self.base_lrs] 53 | else: 54 | i = self.last_epoch - self.T_warmup 55 | n = self.T_max - self.T_warmup 56 | decay_ratio = i / n 57 | 58 | assert 0 <= decay_ratio <= 1 59 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 60 | 61 | # coeff is between 0 and 1 so lr is between eta_min and base_lr 62 | return [ 63 | self.eta_min + coeff * (base_lr - self.eta_min) 64 | for base_lr in self.base_lrs 65 | ] 66 | -------------------------------------------------------------------------------- /astroclip/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from typing import Any, Optional 3 | 4 | import matplotlib.pyplot as plt 5 | import wandb 6 | from lightning import Callback, LightningModule, Trainer 7 | from lightning.pytorch.cli import ( 8 | ArgsType, 9 | LightningArgumentParser, 10 | LightningCLI, 11 | LRSchedulerTypeUnion, 12 | ) 13 | from lightning.pytorch.loggers import WandbLogger 14 | from torch.optim import Optimizer 15 | 16 | from astroclip import format_with_env 17 | from astroclip.callbacks import CustomSaveConfigCallback 18 | 19 | 20 | class WrappedLightningCLI(LightningCLI): 21 | def before_instantiate_classes(self) -> None: 22 | self.config = format_with_env(self.config) 23 | 24 | # Changing the lr_scheduler interval to step instead of epoch 25 | @staticmethod 26 | def configure_optimizers( 27 | lightning_module: LightningModule, 28 | optimizer: Optimizer, 29 | lr_scheduler: Optional[LRSchedulerTypeUnion] = None, 30 | ) -> Any: 31 | optimizer_list, lr_scheduler_list = LightningCLI.configure_optimizers( 32 | lightning_module, optimizer=optimizer, lr_scheduler=lr_scheduler 33 | ) 34 | 35 | for idx in range(len(lr_scheduler_list)): 36 | if not isinstance(lr_scheduler_list[idx], dict): 37 | lr_scheduler_list[idx] = { 38 | "scheduler": lr_scheduler_list[idx], 39 | "interval": "step", 40 | } 41 | return optimizer_list, lr_scheduler_list 42 | 43 | 44 | def main_cli(args: ArgsType = None, run: bool = True): 45 | cli = WrappedLightningCLI( 46 | save_config_kwargs={"overwrite": True}, 47 | save_config_callback=CustomSaveConfigCallback, 48 | parser_kwargs={"parser_mode": "omegaconf"}, 49 | args=args, 50 | run=run, 51 | ) 52 | return cli 53 | 54 | 55 | if __name__ == "__main__": 56 | main_cli(run=True) 57 | -------------------------------------------------------------------------------- /configs/astroclip.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | default_root_dir: "{ASTROCLIP_ROOT}/outputs" 4 | enable_checkpointing: true 5 | gradient_clip_val: 1. 
6 | max_epochs: 100
7 | precision: null
8 | callbacks:
9 | - class_path: LearningRateMonitor
10 | init_args:
11 | logging_interval: "step"
12 | - class_path: ModelCheckpoint
13 | init_args:
14 | save_last: link
15 | save_top_k: 2
16 | every_n_epochs: 1
17 | monitor: "val_loss_nologit"
18 | logger:
19 | class_path: CustomWandbLogger
20 | init_args:
21 | project: "astroclip-alignment"
22 | entity: "{WANDB_ENTITY_NAME}"
23 | save_dir: ${trainer.default_root_dir}
24 | detect_anomaly: True
25 | model:
26 | class_path: AstroClipModel
27 | init_args:
28 | image_encoder:
29 | class_path: astroclip.models.astroclip.ImageHead
30 | init_args:
31 | config: "astroclip/astrodino/config.yaml"
32 | model_weights: "{ASTROCLIP_ROOT}/pretrained/astrodino.ckpt"
33 | save_directory: "{ASTROCLIP_ROOT}/outputs/astrodino"
34 | spectrum_encoder:
35 | class_path: astroclip.models.astroclip.SpectrumHead
36 | init_args:
37 | model_path: "{ASTROCLIP_ROOT}/pretrained/specformer.ckpt"
38 | data:
39 | class_path: AstroClipDataloader
40 | init_args:
41 | path: "{ASTROCLIP_ROOT}/datasets/astroclip_file/"
42 | columns:
43 | - image
44 | - spectrum
45 | batch_size: 256
46 | num_workers: 8
47 | collate_fn:
48 | class_path: astroclip.data.AstroClipCollator
49 | init_args:
50 | center_crop: 144
51 | optimizer:
52 | class_path: torch.optim.adamw.AdamW
53 | init_args:
54 | lr: 1e-4
55 | weight_decay: 0.05
56 | lr_scheduler:
57 | class_path: astroclip.CosineAnnealingWithWarmupLR
58 | init_args:
59 | T_max: 10_000
60 | T_warmup: 1_000
61 | eta_min: ${eval:'${optimizer.init_args.lr} / 500'} # true division; `//` would floor 1e-4/500 to 0
-------------------------------------------------------------------------------- /configs/specformer.yaml: --------------------------------------------------------------------------------
1 | seed_everything: 42
2 | trainer:
3 | default_root_dir: "{ASTROCLIP_ROOT}/outputs"
4 | enable_checkpointing: true
5 | gradient_clip_val: 1.
6 | max_steps: 500_000
7 | precision: null
8 | limit_val_batches: 100
9 | callbacks:
10 | - class_path: LearningRateMonitor
11 | init_args:
12 | logging_interval: "step"
13 | - class_path: ModelCheckpoint
14 | init_args:
15 | save_last: link
16 | # TODO: This needs to be updated with latest code
17 | #- class_path: PlotsCallback
18 | logger:
19 | class_path: CustomWandbLogger
20 | init_args:
21 | project: "astroclip-spectrum"
22 | entity: "{WANDB_ENTITY_NAME}"
23 | save_dir: ${trainer.default_root_dir}
24 | model:
25 | class_path: SpecFormer
26 | init_args:
27 | input_dim: 22
28 | embed_dim: 768
29 | num_layers: 6
30 | num_heads: 6
31 | max_len: 800
32 | dropout: 0.
33 | data:
34 | class_path: AstroClipDataloader
35 | init_args:
36 | path: "{ASTROCLIP_ROOT}/datasets/astroclip_file/"
37 | columns:
38 | - spectrum
39 | batch_size: 64
40 | num_workers: 0
41 | optimizer:
42 | class_path: torch.optim.adamw.AdamW
43 | init_args:
44 | lr: 1e-5
45 | weight_decay: 1e-1
46 | betas:
47 | - 0.9
48 | - 0.95
49 | lr_scheduler:
50 | class_path: CosineAnnealingWithWarmupLR
51 | init_args:
52 | T_max: ${trainer.max_steps}
53 | T_warmup: 2000
54 | eta_min: ${eval:'${optimizer.init_args.lr} / 100'} # true division; `//` would floor to 0
-------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/README.md: --------------------------------------------------------------------------------
1 | ## Morphology Classification
2 | We demonstrate morphology classification using the GalaxyZoo DECaLS dataset. In particular, we use the classifications from GZD-5 (Walmsley, et al.
(2022)), which includes over 7.5 million volunteer response classifications for roughly 314,000 galaxies on a variety of questions, including morphological T-types, strong bars, arm curvature, etc.
3 |
4 | ### Cross-Matching
5 | Cross-matching between GalaxyZoo DECaLS and the full DESI-LS survey is performed by running
6 | ```python
7 | python morphology_utils/cross_match.py
8 | ```
9 | This creates a cross-matched table containing the preprocessed DESI-LS survey images and their corresponding GalaxyZoo DECaLS volunteer classifications. Note that this assumes that the Legacy Survey images have been downloaded and correctly formatted in
10 | ```bash
11 | {ASTROCLIP_ROOT}/datasets/decals
12 | ```
13 | If they are stored elsewhere, this can be specified using the `--root_dir` flag.
14 |
15 | ### Embedding
16 | The images are then embedded using
17 | ```python
18 | python embed_galaxy_zoo.py
19 | ```
20 | This creates a table containing the embedded DESI-LS survey images and their corresponding GalaxyZoo DECaLS volunteer classifications. Note that this assumes that the cross-matching step above has already been performed. It also assumes that all models have been downloaded and stored in the following directory:
21 | ```bash
22 | {ASTROCLIP_ROOT}/pretrained
23 | ```
24 |
25 | ### Classification
26 | Once the embedded table has been generated, classification is performed in the `morphology_classification.ipynb` notebook; see that file for more details.
27 |
-------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/embed_galaxy_zoo.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("../..")
5 |
6 | from argparse import ArgumentParser
7 | from typing import Dict
8 |
9 | import numpy as np
10 | import torch
11 | from astropy.table import Table
12 | from dinov2.eval.setup import setup_and_build_model
13 | from tqdm import tqdm
14 |
15 | from astroclip.astrodino.utils import setup_astrodino
16 | from astroclip.env import format_with_env
17 | from astroclip.models import AstroClipModel, Moco_v2, SpecFormer
18 |
19 |
20 | def get_embeddings(
21 | image_models: Dict[str, torch.nn.Module],
22 | images: torch.Tensor,
23 | batch_size: int = 512,
24 | ) -> dict:
25 | """Get embeddings for images using models"""
26 | model_embeddings = {key: [] for key in image_models.keys()}
27 | im_batch = []
28 |
29 | for image in tqdm(images):
30 | # Load images, already preprocessed
31 | im_batch.append(torch.tensor(image, dtype=torch.float32)[None, :, :, :])
32 |
33 | # Get embeddings for batch
34 | if len(im_batch) == batch_size:
35 | with torch.no_grad():
36 | batch = torch.cat(im_batch).cuda()  # avoid rebinding `images` mid-iteration
37 | for key in image_models.keys():
38 | model_embeddings[key].append(image_models[key](batch))
39 |
40 | im_batch = []
41 |
42 | # Get embeddings for last batch
43 | if len(im_batch) > 0:
44 | with torch.no_grad():
45 | batch = torch.cat(im_batch).cuda()
46 | for key in image_models.keys():
47 | model_embeddings[key].append(image_models[key](batch))
48 |
49 | model_embeddings = {
50 | key: np.concatenate(model_embeddings[key]) for key in model_embeddings.keys()
51 | }
52 | return model_embeddings
53 |
54 |
55 | def embed_galaxy_zoo(
56 | galaxy_zoo_file: str,
57 | pretrained_dir: str,
58 | batch_size: int = 128,
59 | ):
60 | # Get directories
61 | astrodino_output_dir = os.path.join(pretrained_dir, "astrodino_output_dir")
62 |
63 | 
pretrained_weights = {} 64 | for model in ["astroclip", "stein", "astrodino", "specformer"]: 65 | pretrained_weights[model] = os.path.join(pretrained_dir, f"{model}.ckpt") 66 | 67 | # Set up AstroCLIP 68 | astroclip = AstroClipModel.load_from_checkpoint( 69 | checkpoint_path=pretrained_weights["astroclip"], 70 | ) 71 | 72 | # Set up Stein, et al. model 73 | stein = Moco_v2.load_from_checkpoint( 74 | checkpoint_path=pretrained_weights["stein"], 75 | ).encoder_q 76 | 77 | # Set up AstroDINO model 78 | astrodino = setup_astrodino(astrodino_output_dir, pretrained_weights["astrodino"]) 79 | 80 | # Set up model dict 81 | image_models = { 82 | "astrodino": lambda x: astrodino(x).cpu().numpy(), 83 | "stein": lambda x: stein(x).cpu().numpy(), 84 | "astroclip": lambda x: astroclip(x, input_type="image").cpu().numpy(), 85 | } 86 | print("Models are correctly set up!") 87 | 88 | # Get embeddings 89 | galaxy_zoo = Table.read(galaxy_zoo_file) 90 | images = galaxy_zoo["image"] 91 | embeddings = get_embeddings(image_models, images, batch_size) 92 | 93 | # Remove images and replace with embeddings 94 | galaxy_zoo.remove_column("image") 95 | for key in embeddings.keys(): 96 | assert len(embeddings[key]) == len(galaxy_zoo), "Embeddings incorrect length" 97 | galaxy_zoo[f"{key}_embeddings"] = embeddings[key] 98 | 99 | # Save embeddings 100 | galaxy_zoo.write(galaxy_zoo_file.replace(".h5", "_embeddings.h5"), overwrite=True) 101 | 102 | 103 | if __name__ == "__main__": 104 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 105 | parser = ArgumentParser() 106 | parser.add_argument( 107 | "--galaxy_zoo_file", 108 | type=str, 109 | default=f"{ASTROCLIP_ROOT}/datasets/galaxy_zoo/gz5_decals_crossmatched.h5", 110 | ) 111 | parser.add_argument( 112 | "--pretrained_dir", 113 | type=str, 114 | default=f"{ASTROCLIP_ROOT}/pretrained", 115 | ) 116 | parser.add_argument("--batch_size", type=int, default=1024) 117 | args = parser.parse_args() 118 | 119 | embed_galaxy_zoo( 120 | args.galaxy_zoo_file, 121 | args.pretrained_dir, 122 | args.batch_size, 123 | ) 124 | -------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/morphology_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, sys\n", 10 | "\n", 11 | "sys.path.append(\"../..\")\n", 12 | "\n", 13 | "import torch\n", 14 | "import numpy as np\n", 15 | "from astropy.table import Table\n", 16 | "\n", 17 | "from astroclip.env import format_with_env\n", 18 | "from morphology_utils.models import train_eval_on_question\n", 19 | "from morphology_utils.plotting import plot_radar\n", 20 | "\n", 21 | "ASTROCLIP_ROOT = format_with_env(\"{ASTROCLIP_ROOT}\")\n", 22 | "\n", 23 | "\n", 24 | "# Load the data\n", 25 | "galaxy_zoo = Table.read(\n", 26 | " f\"{ASTROCLIP_ROOT}/datasets/galaxy_zoo/gz5_decals_crossmatched_embeddings.h5\"\n", 27 | ")\n", 28 | "\n", 29 | "# Remove the galaxies with fewer than 3 votes\n", 30 | "galaxy_zoo = galaxy_zoo[galaxy_zoo[\"smooth-or-featured_total-votes\"] >= 3]\n", 31 | "\n", 32 | "# Get the embeddings\n", 33 | "X = {\n", 34 | " \"AstroCLIP\": torch.tensor(galaxy_zoo[\"astroclip_embeddings\"]),\n", 35 | " \"AstroDINO\": torch.tensor(galaxy_zoo[\"astrodino_embeddings\"]),\n", 36 | " \"Stein\": torch.tensor(galaxy_zoo[\"stein_embeddings\"]),\n", 37 | "}\n", 38 | "\n", 39 | "# Get the names 
of the columns\n", 40 | "names = [\n", 41 | " \"smooth\",\n", 42 | " \"disk-edge-on\",\n", 43 | " \"spiral-arms\",\n", 44 | " \"bar\",\n", 45 | " \"bulge-size\",\n", 46 | " \"how-rounded\",\n", 47 | " \"edge-on-bulge\",\n", 48 | " \"spiral-winding\",\n", 49 | " \"spiral-arm-count\",\n", 50 | " \"merging\",\n", 51 | "]\n", 52 | "\n", 53 | "# Get the labels\n", 54 | "galaxy_zoo.remove_columns(\n", 55 | " [\"astroclip_embeddings\", \"astrodino_embeddings\", \"stein_embeddings\"]\n", 56 | ")\n", 57 | "classifications = galaxy_zoo\n", 58 | "\n", 59 | "# Get the key list\n", 60 | "keys = {\n", 61 | " name: {\n", 62 | " \"target\": [\n", 63 | " key\n", 64 | " for key in classifications.colnames\n", 65 | " if name in key and \"debiased\" in key and \"mask\" not in key\n", 66 | " ],\n", 67 | " \"counts\": [\n", 68 | " key\n", 69 | " for key in classifications.colnames\n", 70 | " if name in key and \"total-votes\" in key\n", 71 | " ][0],\n", 72 | " }\n", 73 | " for name in names\n", 74 | "}" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# Select first 80% for train and last 20% for test\n", 84 | "train_indices = int(0.8 * len(classifications))\n", 85 | "\n", 86 | "X_train, X_test = {}, {}\n", 87 | "for key in X.keys():\n", 88 | " X_train[key] = X[key][:train_indices]\n", 89 | " X_test[key] = X[key][train_indices:]\n", 90 | "\n", 91 | "classifications_train, classifications_test = (\n", 92 | " classifications[:train_indices],\n", 93 | " classifications[train_indices:],\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# This is the total number of possible votes\n", 104 | "total_counts_train = classifications_train[keys[\"smooth\"][\"counts\"]].data\n", 105 | "\n", 106 | "# Get accuracy and F1 score on each question\n", 107 | "outputs = {key: {} for key in X.keys()}\n", 108 | "for name in names:\n", 109 | " question, num_classes = name, len(keys[name][\"target\"])\n", 110 | "\n", 111 | " # Keep all train samples (the 50%-answered cut below is disabled)\n", 112 | " counts_train = classifications_train[keys[name][\"counts\"]].data\n", 113 | " # train_mask = np.where(counts_train / total_counts_train > 0.5)[0]\n", 114 | " train_mask = [True] * len(counts_train)\n", 115 | "\n", 116 | " # Get the test samples above 34 answers\n", 117 | " counts_test = classifications_test[keys[name][\"counts\"]].data\n", 118 | " test_mask = np.where(counts_test > 34)[0]\n", 119 | "\n", 120 | " # Get train and test\n", 121 | " y_train = torch.tensor(\n", 122 | " classifications_train[keys[name][\"target\"]].to_pandas().values\n", 123 | " )[train_mask]\n", 124 | " y_test = torch.tensor(\n", 125 | " classifications_test[keys[name][\"target\"]].to_pandas().values\n", 126 | " )[test_mask]\n", 127 | "\n", 128 | " train_nan_mask = torch.isnan(y_train).any(axis=1)\n", 129 | " test_nan_mask = torch.isnan(y_test).any(axis=1)\n", 130 | "\n", 131 | " # Train and evaluate on each model (dropping NaN-labeled rows from both X and y)\n", 132 | " print(f\"Training on question: {question}...\")\n", 133 | " for model in X.keys():\n", 134 | " X_train_local = X_train[model][train_mask][~train_nan_mask]\n", 135 | " X_test_local = X_test[model][test_mask][~test_nan_mask]\n", 136 | " outputs[model][name] = train_eval_on_question(\n", 137 | " X_train_local,\n", 138 | " X_test_local,\n", 139 | " y_train[~train_nan_mask],\n", 140 | " y_test[~test_nan_mask],\n", 141 | " X_train_local.shape[1],\n", 142 | " num_classes=num_classes,\n", 143 |
" MLP_dim=256,\n", 144 | " epochs=25,\n", 145 | " dropout=0.2,\n", 146 | " )\n", 147 | " print(\n", 148 | " f\"Model: {model}, Accuracy: {outputs[model][name]['Accuracy']:.4f}, F1: {outputs[model][name]['F1 Score']:.4f}\"\n", 149 | " )\n", 150 | " print(\"Done!\")" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# Clean up labels\n", 160 | "outputs[\"Unaligned Transformer\"] = outputs.pop(\"AstroDINO\")\n", 161 | "outputs[\"Stein, et al.\"] = outputs.pop(\"Stein\")\n", 162 | "\n", 163 | "# Plot radar plots\n", 164 | "plot_radar(outputs, metric=\"Accuracy\", file_path=f\"./outputs/radar_accuracy.png\")\n", 165 | "plot_radar(outputs, metric=\"F1 Score\", file_path=f\"./outputs/radar_f1_score.png\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "toto", 179 | "language": "python", 180 | "name": "toto" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/morphology_utils/cross_match.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append("../..") 5 | 6 | import argparse 7 | from typing import List 8 | 9 | import h5py 10 | import numpy as np 11 | import requests 12 | import torch 13 | from astropy import units as u 14 | from astropy.coordinates import SkyCoord 15 | from astropy.table import Table 16 | from torchvision.transforms import CenterCrop, Compose 17 | from tqdm import tqdm 18 | 19 | from astroclip.astrodino.data.augmentations import ToRGB 20 | from astroclip.env import format_with_env 21 | 22 | gz_5_link = ( 23 | "https://zenodo.org/records/4573248/files/gz_decals_volunteers_5.csv?download=1" 24 | ) 25 | 26 | 27 | def _generate_catalog(files: List[str]) -> Table: 28 | """Generate a catalog from a list of files.""" 29 | ra_list, dec_list = [], [] 30 | index_list, file_list = [], [] 31 | print("Generating catalogs", flush=True) 32 | for i, file in enumerate(tqdm(files)): 33 | with h5py.File(file, "r") as f: 34 | ra = f["ra"][:] 35 | dec = f["dec"][:] 36 | 37 | # Append data to lists 38 | ra_list.extend(ra) 39 | dec_list.extend(dec) 40 | file_list.extend([file] * len(ra)) 41 | index_list.extend(range(0, len(ra))) 42 | 43 | # Create astropy table 44 | return Table( 45 | [ra_list, dec_list, index_list, file_list], names=("ra", "dec", "index", "file") 46 | ) 47 | 48 | 49 | def _cross_match_tables( 50 | table1: Table, table2: Table, max_sep: float = 0.5 51 | ) -> tuple[Table, Table]: 52 | """Cross-match two tables.""" 53 | 54 | # Create SkyCoord objects 55 | coords1 = SkyCoord(ra=table1["ra"] * u.degree, dec=table1["dec"] * u.degree) 56 | coords2 = SkyCoord(ra=table2["ra"] * u.degree, dec=table2["dec"] * u.degree) 57 | 58 | print("Matching coordinates", flush=True) 59 | 60 | # Match coordinates 61 | idx, d2d, _ = coords1.match_to_catalog_sky(coords2) 62 | 63 | # Define separation constraint and apply it 64 | max_sep = max_sep * u.arcsec 65 | sep_constraint = d2d < max_sep 66 | 67 | print(f"Total number of matches: {np.sum(sep_constraint)} \n", flush=True) 68 | return table1[sep_constraint], table2[idx[sep_constraint]] 69 | 70 | 71 | def _get_images(files: list[str], classifications: 
Table) -> Table: 72 | """Get images from files.""" 73 | 74 | # Set up transforms 75 | transform = Compose([CenterCrop(144), ToRGB()]) 76 | 77 | # Add images to catalog 78 | print("Adding images to catalog", flush=True) 79 | images = np.zeros((len(classifications), 3, 144, 144)) 80 | for idx, file in enumerate(files): 81 | print(f"Processing file: {idx}", flush=True) 82 | with h5py.File(file, "r") as f: 83 | for k, entry in tqdm(enumerate(classifications)): 84 | if entry["file"] != file: 85 | continue 86 | index = entry["index"] 87 | image = transform(torch.tensor(f["images"][index])).T 88 | images[k] = np.array(image) 89 | classifications["image"] = images 90 | return classifications 91 | 92 | 93 | def _download_gz5_decals(survey_path: str) -> None: 94 | """Download Galaxy Zoo 5 classifications.""" 95 | response = requests.get(gz_5_link) 96 | with open(survey_path, "wb") as f: 97 | f.write(response.content) 98 | 99 | 100 | def _get_file_location(root_dir: str) -> List[str]: 101 | """Get the locations of the Legacy Survey image files.""" 102 | north_path = os.path.join(root_dir, "north") 103 | south_path = os.path.join(root_dir, "south") 104 | 105 | files_north = [ 106 | os.path.join( 107 | north_path, 108 | "images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1), 109 | ) 110 | for i in range(14) 111 | ] 112 | files_south = [ 113 | os.path.join( 114 | south_path, 115 | "images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1), 116 | ) 117 | for i in range(62) 118 | ] 119 | 120 | files = files_north + files_south 121 | return files 122 | 123 |
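As a sanity check on the shard-naming template in `_get_file_location` above (an editor's sketch; the template string is copied verbatim from the function, and the shard counts are 14 north / 62 south as in the code):

```python
# Expand the first two northern shard names produced by _get_file_location
for i in range(2):
    print("images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1))
# -> images_npix152_000000000_001000000.h5
# -> images_npix152_001000000_002000000.h5
```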
124 | def cross_match_galaxy_zoo( 125 | root_dir: str, save_path: str, survey_path: str = None 126 | ) -> None: 127 | """ 128 | Pairs Galaxy Zoo classifications with DECaLS images in an Astropy table. 129 | 130 | Args: 131 | root_dir (str): Root directory of DECaLS images. 132 | save_path (str): Path to save the cross-matched table. 133 | survey_path (str): Path to the Galaxy Zoo classifications CSV. 134 | Returns: 135 | None: The paired table is written to save_path. 136 | """ 137 | 138 | # Get file locations 139 | files = _get_file_location(root_dir) 140 | 141 | # Load morphology classifications 142 | if not os.path.exists(survey_path): 143 | _download_gz5_decals(survey_path) 144 | morphologies = Table.read(survey_path, format="ascii") 145 | 146 | # Generate catalog of ra, dec, index, file from files 147 | positions = _generate_catalog(files) 148 | 149 | # Cross-match positions with morphology classifications 150 | classifications, positions_matched = _cross_match_tables(morphologies, positions) 151 | 152 | # Update classifications with index and file 153 | classifications["index"] = np.array(positions_matched["index"]) 154 | classifications["file"] = np.array(positions_matched["file"]) 155 | 156 | # Get images and add them to classifications 157 | classifications = _get_images(files, classifications) 158 | 159 | # Save classifications 160 | print(f"Saving paired classifications to {save_path}", flush=True) 161 | classifications.write(save_path, overwrite=True, format="hdf5") 162 | 163 | 164 | if __name__ == "__main__": 165 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument( 168 | "--root_dir", 169 | type=str, 170 | default=f"{ASTROCLIP_ROOT}/datasets/decals", 171 | help="Root directory of DECaLS images.", 172 | ) 173 | parser.add_argument( 174 | "--save_path", 175 | type=str, 176 | default=f"{ASTROCLIP_ROOT}/datasets/galaxy_zoo/gz5_decals_crossmatched.h5", 177 | help="Path to save the cross-matched table.", 178 | ) 179 | parser.add_argument( 180 | "--survey_path", 181 | type=str, 182 | default=f"{ASTROCLIP_ROOT}/datasets/galaxy_zoo/gz_decals_volunteers_5.csv", 183 | help="Path to Galaxy Zoo survey.", 184 | ) 185 | 186 | args = parser.parse_args() 187 | cross_match_galaxy_zoo(args.root_dir, args.save_path, args.survey_path) 188 | -------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/morphology_utils/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 6 | from sklearn.model_selection import train_test_split 7 | from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler 8 | from tqdm import tqdm 9 | 10 | 11 | class MLP(nn.Module): 12 | """A simple feedforward neural network with 3 hidden layers.""" 13 | 14 | def __init__( 15 | self, input_dim: int, num_classes: int, hidden_dim: int, dropout_rate: float 16 | ): 17 | super().__init__() 18 | 19 | self.layers = nn.Sequential( 20 | nn.Linear(input_dim, hidden_dim), 21 | nn.Dropout(dropout_rate), 22 | nn.ReLU(), 23 | nn.Linear(hidden_dim, hidden_dim), 24 | nn.Dropout(dropout_rate), 25 | nn.ReLU(), 26 | nn.Linear(hidden_dim, hidden_dim), 27 | nn.Dropout(dropout_rate), 28 | nn.ReLU(), 29 | nn.Linear(hidden_dim, num_classes), 30 | ) 31 | 32 | self.softmax = nn.Softmax(dim=1) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | x = self.layers(x) 36 | x = self.softmax(x) 37 | return x.squeeze() 38 | 39 |
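A quick shape check for the `MLP` probe defined above (editor's sketch; the 512-dimensional embeddings and 3 answer classes are illustrative stand-ins):

```python
import torch

mlp = MLP(input_dim=512, num_classes=3, hidden_dim=256, dropout_rate=0.2)
probs = mlp(torch.randn(8, 512))
print(probs.shape)       # torch.Size([8, 3])
print(probs.sum(dim=1))  # rows sum to ~1, since the head is a softmax
```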
"""Function to train and evaluate a simple feedforward neural network on a dataset.""" 54 | X_train, X_val, y_train, y_val = train_test_split( 55 | X_train, y_train, test_size=0.1, random_state=42 56 | ) 57 | train_dataset = TensorDataset(X_train, y_train) 58 | val_dataset = TensorDataset(X_val, y_val) 59 | 60 | # Compute class weights 61 | samples_weight = y_train.max(dim=1).values # Taking max fraction as the weight 62 | 63 | # Create a DataLoader 64 | train_loader = DataLoader( 65 | train_dataset, 66 | batch_size=batch_size, 67 | sampler=WeightedRandomSampler(samples_weight, len(samples_weight)), 68 | ) 69 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 70 | 71 | # Set up model 72 | mlp = MLP(embed_dim, num_classes, MLP_dim, dropout).cuda() 73 | criterion = nn.CrossEntropyLoss() # Suitable for multi-label classification 74 | optimizer = optim.Adam(mlp.parameters(), lr=lr) 75 | 76 | # Training loop 77 | best_val_loss = float("inf") 78 | best_metrics = None 79 | for epoch in range(epochs): 80 | mlp.train() 81 | train_loss = 0 82 | for data, target in train_loader: 83 | optimizer.zero_grad() 84 | output = mlp(data.cuda()) 85 | loss = criterion(output.squeeze(), target.squeeze().cuda()) 86 | loss.backward() 87 | optimizer.step() 88 | train_loss += loss.item() 89 | 90 | train_loss /= len(train_loader) 91 | 92 | # Validation loop 93 | mlp.eval() 94 | val_loss = 0 95 | with torch.no_grad(): 96 | for data, target in val_loader: 97 | output = mlp(data.cuda()) 98 | loss = criterion(output.squeeze(), target.squeeze().cuda()) 99 | val_loss += loss.item() 100 | 101 | val_loss /= len(val_loader) 102 | 103 | # Save best model based on validation loss 104 | if val_loss < best_val_loss: 105 | best_val_loss = val_loss 106 | best_model = mlp.state_dict() 107 | 108 | # Get the best model 109 | mlp.load_state_dict(best_model) 110 | y_pred = mlp(X_test.cuda()).detach().cpu() 111 | 112 | # Discretize the predictions 113 | y_pred = (y_pred == torch.max(y_pred, dim=1, keepdim=True).values).int() 114 | y_true = (y_test == torch.max(y_test, dim=1, keepdim=True).values).int() 115 | 116 | # Compute and return the metrics 117 | accuracy = accuracy_score(y_true.numpy(), y_pred.numpy()) 118 | f1_score = precision_recall_fscore_support( 119 | y_true.numpy(), y_pred.numpy(), average="weighted", zero_division=0 120 | )[2] 121 | return {"Accuracy": accuracy, "F1 Score": f1_score} 122 | -------------------------------------------------------------------------------- /downstream_tasks/morphology_classification/morphology_utils/plotting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import pi 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | def plot_radar(outputs: dict, metric: str, file_path: str, fontsize: int = 18): 9 | """Functionality for plotting radar chart""" 10 | questions = {} 11 | for key in outputs.keys(): 12 | questions[key] = [ 13 | outputs[key][question][metric] for question in outputs[key].keys() 14 | ] 15 | labels = outputs[key].keys() 16 | 17 | # Add Zoobot scores 18 | questions["ZooBot Reported"] = [ 19 | zoobot_scores[question][metric] for question in zoobot_scores.keys() 20 | ] 21 | 22 | # Create radar chart 23 | angles = np.linspace(0, 2 * pi, len(questions[key]), endpoint=False).tolist() 24 | angles += angles[:1] # complete the loop 25 | 26 | fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True)) 27 | 28 | colors = ["red", "red", "black", "blue"] 29 | styles = 
["solid", "dashed", "solid", "solid"] 30 | 31 | # Plot each array on the radar chart 32 | for key in questions.keys(): 33 | stats = [questions[key][i] for i in range(len(questions[key]))] 34 | stats += stats[:1] 35 | ax.plot( 36 | angles, 37 | stats, 38 | label=key, 39 | linewidth=2, 40 | linestyle=styles.pop(0), 41 | color=colors.pop(0), 42 | ) 43 | 44 | # capitalize labels 45 | labels = [label.capitalize() for label in labels] 46 | 47 | # Add labels with specific fontsize 48 | ax.set_theta_offset(pi / 2) 49 | ax.set_theta_direction(-1) 50 | 51 | # Change r label to fontsize 52 | ax.tick_params(axis="y", labelsize=fontsize) 53 | ax.set_xticks(angles[:-1], labels, fontsize=fontsize, color="black") 54 | 55 | # make theta labels not overlap with plot 56 | ax.set_ylim(0, 1.0) 57 | 58 | # Add legend 59 | legend = plt.legend(loc="upper right", bbox_to_anchor=(1.1, 1.1)) 60 | plt.setp( 61 | legend.get_texts(), fontsize=fontsize 62 | ) # Explicitly set fontsize for legend 63 | 64 | # Save fig 65 | if not os.path.exists(os.path.dirname(file_path)): 66 | os.makedirs(os.path.dirname(file_path)) 67 | plt.savefig(file_path) 68 | plt.close() 69 | 70 | 71 | # ZooBot scores taken from Walmsley et al. (2021), https://arxiv.org/pdf/2102.08414 72 | zoobot_scores = { 73 | "smooth": {"Accuracy": 0.94, "F1 Score": 0.94}, 74 | "disk-edge-on": {"Accuracy": 0.99, "F1 Score": 0.99}, 75 | "spiral-arms": {"Accuracy": 0.93, "F1 Score": 0.94}, 76 | "bar": {"Accuracy": 0.82, "F1 Score": 0.81}, 77 | "bulge-size": {"Accuracy": 0.84, "F1 Score": 0.84}, 78 | "how-rounded": {"Accuracy": 0.93, "F1 Score": 0.93}, 79 | "edge-on-bulge": {"Accuracy": 0.91, "F1 Score": 0.90}, 80 | "spiral-winding": {"Accuracy": 0.78, "F1 Score": 0.79}, 81 | "spiral-arm-count": {"Accuracy": 0.77, "F1 Score": 0.76}, 82 | "merging": {"Accuracy": 0.88, "F1 Score": 0.85}, 83 | } 84 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/README.md: -------------------------------------------------------------------------------- 1 | ## Property Estimation 2 | We demonstrate physical property estimation using the PROVABGS dataset from Hahn, et al. (2022). This includes best-fit parameters from galaxy spectroscopy of stellar mass, age, metallicity, and star formation rate generated with a state-of-the-art bayesian SED modeling framework. 3 | 4 | ### Cross-Matching 5 | Cross-matching between PROVABGS and the DESI x DESI-LS dataset is performed by running 6 | ```python 7 | python property_utils/cross_match.py 8 | ``` 9 | This creates a cross-matched table with containing the preprocessed DESI x DESI-LS survey images and spectra and their corresponding PROVABGS physical properties. 10 | 11 | ### Embedding 12 | Embedding of the images and spectra is then performed with 13 | ```python 14 | python embed_provabgs.py 15 | ``` 16 | This creates a table containing the embedded DESI x DESI-LS images and spectra and their corresponding PROVABGS physical properties. Note that this assumes that the cross-matching in the above step has already been performed. 
20 | 21 | ### Property Estimation 22 | Once the embedded table has been generated: 23 | - Redshift estimation is performed in `redshift.ipynb` 24 | - Property estimation is performed in `property_estimation.ipynb` 25 | - Posterior estimation is performed with `posterior_estimation.py` 26 | 27 | ### Baselines 28 | The baseline models used in the paper are trained in the `baselines` directory. Refer to the README therein for more details. 29 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/baselines/README.md: -------------------------------------------------------------------------------- 1 | ### Property Estimation Baselines 2 | We include a variety of supervised benchmarks to perform physical property estimation from either images, spectra, or photometry. Baseline training can be run with 3 | ```bash 4 | python trainer.py --modality [modality] --model_name [model name] --properties [properties] 5 | ``` 6 | This automatically trains and evaluates the model on the held-out test set, reporting R-squared metrics on the properties of interest. 7 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/baselines/data.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import numpy as np 3 | import torch 4 | from astropy.table import Table 5 | from torch.utils.data import DataLoader, TensorDataset, random_split 6 | 7 | 8 | class SupervisedDataModule(L.LightningDataModule): 9 | def __init__( 10 | self, 11 | train_data: Table, 12 | test_data: Table, 13 | modality: str, 14 | properties: list, 15 | batch_size: int = 128, 16 | train_size: float = 0.8, 17 | ): 18 | super().__init__() 19 | # Load the data 20 | self.train_data = train_data 21 | self.test_data = test_data 22 | 23 | # Set the modality and properties 24 | self.modality = modality 25 | self.properties = properties 26 | self.batch_size = batch_size 27 | self.train_size = train_size 28 | 29 | # Check the modality 30 | if modality not in ["image", "spectrum", "photometry"]: 31 | raise ValueError("Invalid modality") 32 | 33 | def setup(self, stage=None): 34 | if self.modality == "image": 35 | # Load the data 36 | X_train, X_test = torch.tensor( 37 | self.train_data[self.modality], dtype=torch.float32 38 | ), torch.tensor(self.test_data[self.modality], dtype=torch.float32) 39 | elif self.modality == "spectrum": 40 | X_train = torch.tensor( 41 | self.train_data[self.modality], dtype=torch.float32 42 | ).squeeze(-1) 43 | X_test = torch.tensor( 44 | self.test_data[self.modality], dtype=torch.float32 45 | ).squeeze(-1) 46 | 47 | elif self.modality == "photometry": 48 | # Load the photometry data 49 | X_train = torch.tensor( 50 | np.stack( 51 | [ 52 | self.train_data["MAG_G"], 53 | self.train_data["MAG_R"], 54 | self.train_data["MAG_Z"], 55 | ] 56 | ), 57 | dtype=torch.float32, 58 | ).permute(1, 0) 59 | X_test = torch.tensor( 60 | np.stack( 61 | [ 62 | self.test_data["MAG_G"], 63 | self.test_data["MAG_R"], 64 | self.test_data["MAG_Z"], 65 | ] 66 | ), 67 | dtype=torch.float32, 68 | ).permute(1, 0) 69 | 70 | # Normalize photometry 71 | mean, std = X_train.mean(), X_train.std() 72 | X_train = (X_train - mean) / std 73 | X_test = (X_test - mean) / std 74 | 75 | # Set up the property data 76 | property_data, scale = {}, {} 77 | for p 
in self.properties: 78 | data = torch.tensor(self.train_data[p].data, dtype=torch.float32) 79 | mean, std = data.mean(), data.std() 80 | property_data[p] = ((data - mean) / std).squeeze() 81 | scale[p] = {"mean": mean.numpy(), "std": std.numpy()} 82 | y_train = torch.stack([property_data[p] for p in self.properties], dim=1) 83 | 84 | # Split the data into training, validation, and test sets 85 | total_size = len(X_train) 86 | train_size = int(self.train_size * total_size) 87 | self.train_dataset, self.val_dataset = random_split( 88 | TensorDataset(X_train, y_train), [train_size, total_size - train_size] 89 | ) 90 | 91 | self.test_dataset = TensorDataset(X_test) 92 | self.scale = scale 93 | 94 | def train_dataloader(self): 95 | return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True) 96 | 97 | def val_dataloader(self): 98 | return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False) 99 | 100 | def test_dataloader(self): 101 | return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False) 102 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/baselines/modules.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from torch import nn, optim 4 | from torchvision import models 5 | from torchvision.transforms import ( 6 | Compose, 7 | GaussianBlur, 8 | RandomHorizontalFlip, 9 | RandomVerticalFlip, 10 | ) 11 | 12 | from astroclip.env import format_with_env 13 | 14 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 15 | 16 | 17 | class SupervisedModel(L.LightningModule): 18 | def __init__( 19 | self, 20 | model_name, 21 | modality, 22 | properties, 23 | scale, 24 | num_epochs, 25 | lr=1e-3, 26 | save_dir=None, 27 | ): 28 | super().__init__() 29 | self.model_name = model_name 30 | self.modality = modality 31 | self.properties = properties 32 | self.scale = scale 33 | self.lr = lr 34 | self.num_epochs = num_epochs 35 | self.criterion = nn.MSELoss() 36 | self.save_dir = save_dir 37 | self._initialize_model(model_name) 38 | self.image_transforms = Compose( 39 | [ 40 | RandomHorizontalFlip(), 41 | RandomVerticalFlip(), 42 | GaussianBlur(kernel_size=3), 43 | ] 44 | ) 45 | 46 | def _initialize_model(self, model_name): 47 | if model_name == "resnet18": 48 | self.model = ResNet18(n_out=len(self.properties)) 49 | elif model_name == "conv+att": 50 | self.model = SpectrumEncoder(n_latent=len(self.properties)) 51 | elif model_name == "mlp": 52 | self.model = MLP( 53 | n_in=3, 54 | n_out=len(self.properties), 55 | n_hidden=(64, 64), 56 | act=[nn.ReLU()] * 3, 57 | ) 58 | else: 59 | raise ValueError("Invalid model name") 60 | 61 | def forward(self, x): 62 | return self.model(x).squeeze() 63 | 64 | def training_step(self, batch, batch_idx): 65 | X_batch, y_batch = batch 66 | if self.modality == "image": 67 | X_batch = self.image_transforms(X_batch) 68 | y_pred = self(X_batch) 69 | loss = self.criterion(y_pred, y_batch.squeeze()) 70 | self.log("train_loss", loss, prog_bar=True, on_epoch=True) 71 | return loss 72 | 73 | def validation_step(self, batch, batch_idx): 74 | X_batch, y_batch = batch 75 | y_pred = self(X_batch) 76 | loss = self.criterion(y_pred, y_batch.squeeze()) 77 | self.log("val_loss", loss, prog_bar=True, on_epoch=True) 78 | return loss 79 | 80 | def configure_optimizers(self): 81 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 82 | scheduler = 
optim.lr_scheduler.CosineAnnealingLR(optimizer, self.num_epochs) 83 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 84 | 85 | 86 | class ResNet18(nn.Module): 87 | """Modified ResNet18.""" 88 | 89 | def __init__(self, n_out=1): 90 | super(ResNet18, self).__init__() 91 | self.resnet = models.resnet18(weights=None) 92 | self.resnet.conv1 = nn.Conv2d( 93 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False 94 | ) 95 | self.resnet.fc = nn.Linear(512, n_out) 96 | 97 | def forward(self, x): 98 | return self.resnet(x) 99 | 100 | 101 | class MLP(nn.Sequential): 102 | """MLP model""" 103 | 104 | def __init__(self, n_in, n_out, n_hidden=(16, 16, 16), act=None, dropout=0): 105 | if act is None: 106 | act = [ 107 | nn.LeakyReLU(), 108 | ] * (len(n_hidden) + 1) 109 | assert len(act) == len(n_hidden) + 1 110 | 111 | layer = [] 112 | n_ = [n_in, *n_hidden, n_out] 113 | for i in range(len(n_) - 2): 114 | layer.append(nn.Linear(n_[i], n_[i + 1])) 115 | layer.append(act[i]) 116 | layer.append(nn.Dropout(p=dropout)) 117 | layer.append(nn.Linear(n_[-2], n_[-1])) 118 | super(MLP, self).__init__(*layer) 119 | 120 |
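For instance, the photometry baseline that `_initialize_model` builds is just this `MLP` applied to (g, r, z) magnitudes; a stand-alone shape check (editor's sketch; the batch size and 4 target properties are illustrative):

```python
import torch
from torch import nn

mlp = MLP(n_in=3, n_out=4, n_hidden=(64, 64), act=[nn.ReLU()] * 3)
out = mlp(torch.randn(16, 3))  # 16 galaxies x (g, r, z)
print(out.shape)               # torch.Size([16, 4])
```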
121 | class SpectrumEncoder(nn.Module): 122 | """Spectrum encoder 123 | 124 | Modified version of the encoder by Serrà et al. (2018), which combines a 3 layer CNN 125 | with a dot-product attention module. This encoder adds a MLP to further compress the 126 | attended values into a low-dimensional latent space. 127 | 128 | Paper: Serrà et al., https://arxiv.org/abs/1805.03908 129 | """ 130 | 131 | def __init__(self, n_latent, n_hidden=(32, 32), act=None, dropout=0): 132 | super(SpectrumEncoder, self).__init__() 133 | self.n_latent = n_latent 134 | 135 | filters = [8, 16, 16, 32] 136 | sizes = [5, 10, 20, 40] 137 | self.conv1, self.conv2, self.conv3, self.conv4 = self._conv_blocks( 138 | filters, sizes, dropout=dropout 139 | ) 140 | self.n_feature = filters[-1] // 2 141 | 142 | # pools and softmax work for spectra and weights 143 | self.pool1, self.pool2, self.pool3 = tuple( 144 | nn.MaxPool1d(s, padding=s // 2) for s in sizes[:3] 145 | ) 146 | self.softmax = nn.Softmax(dim=-1) 147 | 148 | # small MLP to go from CNN features to latents 149 | if act is None: 150 | act = [nn.PReLU(n) for n in n_hidden] 151 | # last activation identity to have latents centered around 0 152 | act.append(nn.Identity()) 153 | self.mlp = MLP( 154 | self.n_feature, self.n_latent, n_hidden=n_hidden, act=act, dropout=dropout 155 | ) 156 | 157 | def _conv_blocks(self, filters, sizes, dropout=0): 158 | convs = [] 159 | for i in range(len(filters)): 160 | f_in = 1 if i == 0 else filters[i - 1] 161 | f = filters[i] 162 | s = sizes[i] 163 | p = s // 2 164 | conv = nn.Conv1d( 165 | in_channels=f_in, 166 | out_channels=f, 167 | kernel_size=s, 168 | padding=p, 169 | ) 170 | norm = nn.InstanceNorm1d(f) 171 | act = nn.PReLU(f) 172 | drop = nn.Dropout(p=dropout) 173 | convs.append(nn.Sequential(conv, norm, act, drop)) 174 | return tuple(convs) 175 | 176 | def _downsample(self, x): 177 | # compression 178 | x = x.unsqueeze(1) 179 | x = self.pool1(self.conv1(x)) 180 | x = self.pool2(self.conv2(x)) 181 | x = self.pool3(self.conv3(x)) 182 | x = self.conv4(x) 183 | C = x.shape[1] // 2 184 | # split half channels into attention value and key 185 | h, a = torch.split(x, [C, C], dim=1) 186 | 187 | return h, a 188 | 189 | def forward(self, y): 190 | # run through CNNs 191 | h, a = self._downsample(y) 192 | # softmax attention 193 | a = self.softmax(a) 194 | 195 | # attach hook to extract backward gradient of a scalar prediction 196 | # for Grad-FAM (Feature Activation Map) 197 | if not self.training and a.requires_grad: 198 | a.register_hook(self._attention_hook) 199 | 200 | # apply attention 201 | x = torch.sum(h * a, dim=2) 202 | 203 | # run attended features into MLP for final latents 204 | x = self.mlp(x) 205 | return x 206 | 207 | @property 208 | def n_parameters(self): 209 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 210 | 211 | def _attention_hook(self, grad): 212 | self._attention_grad = grad 213 | 214 | @property 215 | def attention_grad(self): 216 | if hasattr(self, "_attention_grad"): 217 | return self._attention_grad 218 | else: 219 | return None 220 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/baselines/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import lightning as L 5 | import numpy as np 6 | import torch 7 | from astropy.table import Table 8 | from sklearn.metrics import r2_score 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | 12 | from astroclip.env import format_with_env 13 | from data import SupervisedDataModule 14 | from modules import SupervisedModel 15 | 16 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 17 | 18 | 19 | def _get_predictions(model, test_loader, test_provabgs, scale, device="cuda"): 20 | """Use model to get predictions""" 21 | test_pred = [] 22 | with torch.no_grad(): 23 | for X_batch in test_loader: 24 | y_pred = model(X_batch[0].to(device)).squeeze().detach().cpu() 25 | test_pred.append(y_pred) 26 | test_pred = torch.cat(test_pred).numpy() 27 | 28 | pred_dict = {} 29 | for i, p in enumerate(scale.keys()): 30 | if len(test_pred.shape) > 1: 31 | pred_dict[p] = (test_pred[:, i] * scale[p]["std"]) + scale[p]["mean"] 32 | else: 33 | pred_dict[p] = (test_pred * scale[p]["std"]) + scale[p]["mean"] 34 | print(f"{p} R^2: {r2_score(test_provabgs[p], pred_dict[p])}") 35 | 36 | return pred_dict 37 | 38 | 39 | def train_baseline( 40 | train_dataset: str, 41 | test_dataset: str, 42 | save_path: str, 43 | modality: str, 44 | model_name: str, 45 | num_epochs: int = 100, 46 | learning_rate: float = 5e-4, 47 | properties: str = None, 48 | accelerator: str = "gpu", 49 | ): 50 | # Load the data 51 | train_provabgs = Table.read(train_dataset) 52 | test_provabgs = Table.read(test_dataset) 53 | 54 | # Define output directory avoiding collisions 55 | save_dir_base = os.path.join(save_path, modality, model_name, properties) 56 | save_dir = save_dir_base 57 | v_int = 0 # Suffix to add in case of collisions 58 | while os.path.exists(save_dir): 59 | print(f"Directory {save_dir} already exists, adding suffix") 60 | v_int += 1 61 | save_dir = f"{save_dir_base}-v{v_int}" 62 | 63 | # Define the properties to predict 64 | if properties == "redshift": 65 | property_list = ["Z_HP"] 66 | elif properties == "global_properties": 67 | property_list = ["LOG_MSTAR", "Z_MW", "TAGE_MW", "sSFR"] 68 | elif properties == "all_properties": 69 | property_list = ["Z_HP", "LOG_MSTAR", "Z_MW", "TAGE_MW", "sSFR"] 70 | else: 71 | raise ValueError( 72 | "Invalid properties, choose from 'redshift', 'global_properties', or 'all_properties'."
73 | ) 74 | 75 | # Get the data loaders & normalization 76 | data_module = SupervisedDataModule( 77 | train_provabgs, 78 | test_provabgs, 79 | modality, 80 | properties=property_list, 81 | ) 82 | data_module.setup(stage="fit") 83 | 84 | # Get the model 85 | model = SupervisedModel( 86 | model_name=model_name, 87 | modality=modality, 88 | properties=property_list, 89 | scale=data_module.scale, 90 | lr=learning_rate, 91 | num_epochs=num_epochs, 92 | save_dir=save_dir, 93 | ) 94 | 95 | # Set up val loss checkpoint 96 | checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint( 97 | monitor="val_loss", 98 | dirpath=save_dir, 99 | filename="best-checkpoint", 100 | save_top_k=1, 101 | mode="min", 102 | ) 103 | 104 | # Train and fit the model 105 | trainer = L.Trainer( 106 | accelerator=accelerator, max_epochs=num_epochs, callbacks=[checkpoint_callback] 107 | ) 108 | trainer.fit(model, datamodule=data_module) 109 | 110 | # Test the model 111 | best_model_path = checkpoint_callback.best_model_path 112 | model = SupervisedModel.load_from_checkpoint( 113 | best_model_path, 114 | model_name=model_name, 115 | modality=modality, 116 | properties=property_list, 117 | scale=data_module.scale, 118 | lr=learning_rate, 119 | num_epochs=num_epochs, 120 | save_dir=save_dir, 121 | ) 122 | 123 | # Get the predictions 124 | pred_dict = _get_predictions( 125 | model.model, 126 | data_module.test_dataloader(), 127 | test_provabgs, 128 | data_module.scale, 129 | device="cuda" if accelerator == "gpu" else "cpu", 130 | ) 131 | 132 | # Save the model and the predictions 133 | print(f"Saving in {save_dir}") 134 | os.makedirs(save_dir, exist_ok=True) 135 | torch.save(model.state_dict(), os.path.join(save_dir, "model.pt")) 136 | torch.save(pred_dict, os.path.join(save_dir, "test_pred.pt")) 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = ArgumentParser() 141 | parser.add_argument( 142 | "--train_dataset", 143 | type=str, 144 | help="Path to the training dataset", 145 | default=f"{ASTROCLIP_ROOT}/datasets/provabgs/provabgs_paired_train.hdf5", 146 | ) 147 | parser.add_argument( 148 | "--test_dataset", 149 | type=str, 150 | help="Path to the test dataset", 151 | default=f"{ASTROCLIP_ROOT}/datasets/provabgs/provabgs_paired_test.hdf5", 152 | ) 153 | parser.add_argument( 154 | "--save_dir", 155 | type=str, 156 | help="Directory to save the model and predictions", 157 | default=f"{ASTROCLIP_ROOT}/supervised/", 158 | ) 159 | parser.add_argument( 160 | "--modality", 161 | type=str, 162 | help="Modality of the data ('image', 'spectrum', 'photometry')", 163 | default="image", 164 | ) 165 | parser.add_argument( 166 | "--model_name", 167 | type=str, 168 | help="Model to use (e.g. 
'resnet18', 'conv+att', or 'mlp')", 169 | default="none", 170 | ) 171 | parser.add_argument( 172 | "--num_epochs", 173 | type=int, 174 | help="Number of epochs to train the model", 175 | default=50, 176 | ) 177 | parser.add_argument( 178 | "--learning_rate", 179 | type=float, 180 | help="Learning rate for the optimizer", 181 | default=5e-4, 182 | ) 183 | parser.add_argument( 184 | "--properties", 185 | type=str, 186 | help="Properties to predict ('redshift', 'global_properties', or 'all_properties')", 187 | default="global_properties", 188 | ) 189 | args = parser.parse_args() 190 | 191 | # Infer model_name if missing 192 | if args.model_name == "none": 193 | if args.modality == "image": 194 | model_name = "resnet18" 195 | elif args.modality == "spectrum": 196 | model_name = "conv+att" 197 | elif args.modality == "photometry": 198 | model_name = "mlp" 199 | else: 200 | model_name = args.model_name 201 | 202 | print( 203 | f"Training {model_name} on {args.modality} data for {args.properties} prediction" 204 | ) 205 | 206 | train_baseline( 207 | train_dataset=args.train_dataset, 208 | test_dataset=args.test_dataset, 209 | save_path=args.save_dir, 210 | modality=args.modality, 211 | model_name=model_name, 212 | num_epochs=args.num_epochs, 213 | learning_rate=args.learning_rate, 214 | properties=args.properties, 215 | ) 216 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/embed_provabgs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | from astropy.table import Table 8 | from dinov2.eval.setup import setup_and_build_model 9 | from torchvision.transforms import CenterCrop, Compose 10 | from tqdm import tqdm 11 | 12 | from astroclip.astrodino.utils import setup_astrodino 13 | from astroclip.env import format_with_env 14 | from astroclip.models import AstroClipModel, Moco_v2, SpecFormer 15 | 16 | 17 | def get_embeddings( 18 | image_models: Dict[str, torch.nn.Module], 19 | spectrum_models: Dict[str, torch.nn.Module], 20 | images: torch.Tensor, 21 | spectra: torch.Tensor, 22 | batch_size: int = 512, 23 | ) -> dict: 24 | """Get embeddings for images and spectra using the provided models.""" 25 | full_keys = set(image_models.keys()).union(spectrum_models.keys()) 26 | model_embeddings = {key: [] for key in full_keys} 27 | im_batch, sp_batch = [], [] 28 | 29 | assert len(images) == len(spectra) 30 | for image, spectrum in tqdm(zip(images, spectra)): 31 | # Load images, already preprocessed 32 | im_batch.append(torch.tensor(image, dtype=torch.float32)[None, :, :, :]) 33 | sp_batch.append(torch.tensor(spectrum, dtype=torch.float32)[None, :, :]) 34 | 35 | # Get embeddings for batch 36 | if len(im_batch) == batch_size: 37 | with torch.no_grad(): 38 | spectra, images = torch.cat(sp_batch).cuda(), torch.cat(im_batch).cuda() # note: rebinds the outer names to the batch; the zip iterator above is unaffected 39 | 40 | for key in image_models.keys(): 41 | model_embeddings[key].append(image_models[key](images)) 42 | 43 | for key in spectrum_models.keys(): 44 | model_embeddings[key].append(spectrum_models[key](spectra)) 45 | 46 | im_batch, sp_batch = [], [] 47 | 48 | # Get embeddings for last batch 49 | if len(im_batch) > 0: 50 | with torch.no_grad(): 51 | spectra, images = torch.cat(sp_batch).cuda(), torch.cat(im_batch).cuda() 52 | 53 | # Get embeddings 54 | for key in image_models.keys(): 55 | model_embeddings[key].append(image_models[key](images)) 56 | 57 | for key in 
spectrum_models.keys(): 58 | model_embeddings[key].append(spectrum_models[key](spectra)) 59 | 60 | model_embeddings = { 61 | key: np.concatenate(model_embeddings[key]) for key in model_embeddings.keys() 62 | } 63 | return model_embeddings 64 | 65 | 66 | def embed_provabgs( 67 | provabgs_file_train: str, 68 | provabgs_file_test: str, 69 | pretrained_dir: str, 70 | batch_size: int = 512, 71 | ): 72 | # Get directories 73 | astrodino_output_dir = os.path.join(pretrained_dir, "astrodino_output_dir") 74 | 75 | pretrained_weights = {} 76 | for model in ["astroclip", "stein", "astrodino", "specformer"]: 77 | pretrained_weights[model] = os.path.join(pretrained_dir, f"{model}.ckpt") 78 | 79 | # Set up AstroCLIP 80 | astroclip = AstroClipModel.load_from_checkpoint( 81 | checkpoint_path=pretrained_weights["astroclip"], 82 | ) 83 | 84 | # Set up Stein, et al. model 85 | stein = Moco_v2.load_from_checkpoint( 86 | checkpoint_path=pretrained_weights["stein"], 87 | ).encoder_q 88 | 89 | # Set up SpecFormer model 90 | checkpoint = torch.load(pretrained_weights["specformer"]) 91 | specformer = SpecFormer(**checkpoint["hyper_parameters"]) 92 | specformer.load_state_dict(checkpoint["state_dict"]) 93 | specformer.cuda() 94 | 95 | # Set up AstroDINO model 96 | astrodino = setup_astrodino(astrodino_output_dir, pretrained_weights["astrodino"]) 97 | 98 | # Set up model dict 99 | image_models = { 100 | "astrodino": lambda x: astrodino(x).cpu().numpy(), 101 | "stein": lambda x: stein(x).cpu().numpy(), 102 | "astroclip_image": lambda x: astroclip(x, input_type="image").cpu().numpy(), 103 | } 104 | 105 | spectrum_models = { 106 | "astroclip_spectrum": lambda x: astroclip(x, input_type="spectrum") 107 | .cpu() 108 | .numpy(), 109 | "specformer": lambda x: np.mean( 110 | specformer(x)["embedding"].cpu().numpy(), axis=1 111 | ), 112 | } 113 | print("Models are correctly set up!") 114 | 115 | # Load data 116 | files = [provabgs_file_test, provabgs_file_train] 117 | for f in files: 118 | provabgs = Table.read(f) 119 | images, spectra = provabgs["image"], provabgs["spectrum"] 120 | 121 | # Get embeddings 122 | embeddings = get_embeddings( 123 | image_models, spectrum_models, images, spectra, batch_size 124 | ) 125 | 126 | # Remove images and replace with embeddings 127 | provabgs.remove_column("image") 128 | provabgs.remove_column("spectrum") 129 | for key in embeddings.keys(): 130 | assert len(embeddings[key]) == len(provabgs), "Embeddings incorrect length" 131 | provabgs[f"{key}_embeddings"] = embeddings[key] 132 | 133 | # Save embeddings 134 | provabgs.write(f.replace(".hdf5", "_embeddings.hdf5"), overwrite=True) 135 | 136 | 137 | if __name__ == "__main__": 138 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 139 | parser = ArgumentParser() 140 | parser.add_argument( 141 | "--provabgs_file_train", 142 | type=str, 143 | default=f"{ASTROCLIP_ROOT}/datasets/provabgs/provabgs_paired_train.hdf5", 144 | ) 145 | parser.add_argument( 146 | "--provabgs_file_test", 147 | type=str, 148 | default=f"{ASTROCLIP_ROOT}/datasets/provabgs/provabgs_paired_test.hdf5", 149 | ) 150 | parser.add_argument( 151 | "--pretrained_dir", 152 | type=str, 153 | default=f"{ASTROCLIP_ROOT}/pretrained/", 154 | ) 155 | parser.add_argument("--batch_size", type=int, default=512) 156 | args = parser.parse_args() 157 | 158 | embed_provabgs( 159 | args.provabgs_file_train, 160 | args.provabgs_file_test, 161 | args.pretrained_dir, 162 | args.batch_size, 163 | ) 164 | -------------------------------------------------------------------------------- 
/downstream_tasks/property_estimation/property_estimation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, sys\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "from astropy.table import Table\n", 13 | "from sklearn.preprocessing import StandardScaler\n", 14 | "from sklearn.metrics import r2_score\n", 15 | "\n", 16 | "sys.path.append(\"../..\")\n", 17 | "\n", 18 | "from astroclip.env import format_with_env\n", 19 | "from property_utils.models import few_shot, zero_shot\n", 20 | "from property_utils.plotting import plot_scatter\n", 21 | "\n", 22 | "ASTROCLIP_ROOT = format_with_env(\"{ASTROCLIP_ROOT}\")\n", 23 | "\n", 24 | "PROVABGS_ROOT = f\"{ASTROCLIP_ROOT}/datasets/provabgs/\"\n", 25 | "SUPERVISED_ROOT = f\"{ASTROCLIP_ROOT}/supervised/\"\n", 26 | "\n", 27 | "# Define models in embeddings\n", 28 | "image_models = [\"astroclip_image\", \"astrodino\", \"stein\"]\n", 29 | "spectrum_models = [\"astroclip_spectrum\", \"specformer\"]\n", 30 | "\n", 31 | "# Set up the paths\n", 32 | "train_path = os.path.join(PROVABGS_ROOT, \"provabgs_paired_train_embeddings.hdf5\")\n", 33 | "test_path = os.path.join(PROVABGS_ROOT, \"provabgs_paired_test_embeddings.hdf5\")\n", 34 | "\n", 35 | "# Get embeddings and PROVABGS table\n", 36 | "train_provabgs = Table.read(train_path)\n", 37 | "test_provabgs = Table.read(test_path)\n", 38 | "\n", 39 | "# Get properties and scale\n", 40 | "properties = [\"Z_MW\", \"LOG_MSTAR\", \"TAGE_MW\", \"sSFR\"]\n", 41 | "y_train = np.stack([train_provabgs[prop].data.squeeze() for prop in properties]).T\n", 42 | "y_test = np.stack([test_provabgs[prop].data.squeeze() for prop in properties]).T\n", 43 | "scaler = {\"mean\": y_train.mean(axis=0), \"std\": y_train.std(axis=0)}\n", 44 | "y_train = (y_train - scaler[\"mean\"]) / scaler[\"std\"]\n", 45 | "\n", 46 | "print(\n", 47 | " \"Size of training set:\",\n", 48 | " len(train_provabgs),\n", 49 | " \"\\nSize of test set:\",\n", 50 | " len(test_provabgs),\n", 51 | ")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "# Galaxy Property Prediction from Image Embeddings" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Get data\n", 68 | "data = {}\n", 69 | "for model in image_models:\n", 70 | " data[model] = {}\n", 71 | " X_train, X_test = (\n", 72 | " train_provabgs[model + \"_embeddings\"],\n", 73 | " test_provabgs[model + \"_embeddings\"],\n", 74 | " )\n", 75 | " embedding_scaler = StandardScaler().fit(X_train)\n", 76 | " data[model][\"train\"] = embedding_scaler.transform(X_train)\n", 77 | " data[model][\"test\"] = embedding_scaler.transform(X_test)" 78 | ] 79 | },
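For reference, the `zero_shot` and `few_shot` helpers used in the next cell come from `property_utils/models.py` (reproduced later in this dump): the former is a distance-weighted k-NN regressor on the embeddings, the latter a small MLP. A minimal stand-alone sketch of the k-NN path, with random stand-in data (shapes are illustrative):

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

X_train, y_train = np.random.randn(1000, 512), np.random.randn(1000, 4)
X_test = np.random.randn(100, 512)

# Mirrors property_utils.models.zero_shot
neigh = KNeighborsRegressor(weights="distance", n_neighbors=64)
neigh.fit(X_train, y_train)
preds = neigh.predict(X_test)  # shape (100, 4)
```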
scaler[\"mean\"]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# Make a table of r^2 scores\n", 105 | "knn_r2 = {key: [] for key in preds_knn.keys()}\n", 106 | "mlp_r2 = {key: [] for key in preds_mlp.keys()}\n", 107 | "\n", 108 | "for key in preds_knn.keys():\n", 109 | " for i, prop in enumerate(properties):\n", 110 | " knn_r2[key].append(r2_score(y_test[:, i], preds_knn[key][:, i]))\n", 111 | " mlp_r2[key].append(r2_score(y_test[:, i], preds_mlp[key][:, i]))\n", 112 | "\n", 113 | "knn_r2[\"properties\"] = properties\n", 114 | "mlp_r2[\"properties\"] = properties" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "Table(knn_r2)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "Table(mlp_r2)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# Get predictions from supervised models\n", 142 | "resnet_preds = torch.load(\n", 143 | " os.path.join(SUPERVISED_ROOT, \"image/ResNet18/global_properties/test_pred.pt\")\n", 144 | ")\n", 145 | "photometry_preds = torch.load(\n", 146 | " os.path.join(SUPERVISED_ROOT, \"photometry/MLP/global_properties/test_pred.pt\")\n", 147 | ")\n", 148 | "\n", 149 | "# Add predictions to dictionary\n", 150 | "preds_supervised = {\n", 151 | " \"resnet18\": np.stack([resnet_preds[prop].squeeze() for prop in properties]).T,\n", 152 | " \"photometry\": np.stack([photometry_preds[prop].squeeze() for prop in properties]).T,\n", 153 | "}\n", 154 | "\n", 155 | "supervised_r2 = {key: [] for key in preds_supervised.keys()}\n", 156 | "for key in preds_supervised.keys():\n", 157 | " for i, prop in enumerate(properties):\n", 158 | " supervised_r2[key].append(r2_score(y_test[:, i], preds_supervised[key][:, i]))\n", 159 | "\n", 160 | "supervised_r2[\"properties\"] = properties\n", 161 | "Table(supervised_r2)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "# Galaxy Property Prediction from Spectrum Embeddings" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# Get data\n", 178 | "data = {}\n", 179 | "for model in spectrum_models:\n", 180 | " data[model] = {}\n", 181 | " X_train, X_test = (\n", 182 | " train_provabgs[model + \"_embeddings\"],\n", 183 | " test_provabgs[model + \"_embeddings\"],\n", 184 | " )\n", 185 | " embedding_scaler = StandardScaler().fit(X_train)\n", 186 | " data[model][\"train\"] = embedding_scaler.transform(X_train)\n", 187 | " data[model][\"test\"] = embedding_scaler.transform(X_test)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "# Perfrom knn and mlp\n", 197 | "preds_knn, preds_mlp = {}, {}\n", 198 | "for key in data.keys():\n", 199 | " print(f\"Evaluating {key} model...\")\n", 200 | " raw_preds_knn = zero_shot(data[key][\"train\"], y_train, data[key][\"test\"])\n", 201 | " raw_preds_mlp = few_shot(\n", 202 | " model, data[key][\"train\"], y_train, data[key][\"test\"]\n", 203 | " ).squeeze()\n", 204 | " preds_knn[key] = raw_preds_knn * scaler[\"std\"] + scaler[\"mean\"]\n", 205 | " preds_mlp[key] = 
206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "# Make a table of r^2 scores\n", 215 | "knn_r2 = {key: [] for key in preds_knn.keys()}\n", 216 | "mlp_r2 = {key: [] for key in preds_mlp.keys()}\n", 217 | "\n", 218 | "for key in preds_knn.keys():\n", 219 | " for i, prop in enumerate(properties):\n", 220 | " knn_r2[key].append(r2_score(y_test[:, i], preds_knn[key][:, i]))\n", 221 | " mlp_r2[key].append(r2_score(y_test[:, i], preds_mlp[key][:, i]))\n", 222 | "\n", 223 | "knn_r2[\"properties\"] = properties\n", 224 | "mlp_r2[\"properties\"] = properties" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "Table(knn_r2)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "Table(mlp_r2)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "# Get predictions from supervised models\n", 252 | "spectrum_preds = torch.load(\n", 253 | " os.path.join(SUPERVISED_ROOT, \"spectrum/conv+att/global_properties/test_pred.pt\")\n", 254 | ")\n", 255 | "\n", 256 | "# Add predictions to dictionary\n", 257 | "preds_supervised = {\n", 258 | " \"conv+att\": np.stack([spectrum_preds[prop].squeeze() for prop in properties]).T,\n", 259 | "}\n", 260 | "\n", 261 | "supervised_r2 = {key: [] for key in preds_supervised.keys()}\n", 262 | "for key in preds_supervised.keys():\n", 263 | " for i, prop in enumerate(properties):\n", 264 | " supervised_r2[key].append(r2_score(y_test[:, i], preds_supervised[key][:, i]))\n", 265 | "\n", 266 | "supervised_r2[\"properties\"] = properties\n", 267 | "Table(supervised_r2)" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "toto", 274 | "language": "python", 275 | "name": "toto" 276 | } 277 | }, 278 | "nbformat": 4, 279 | "nbformat_minor": 2 280 | } 281 | -------------------------------------------------------------------------------- /downstream_tasks/property_estimation/property_utils/cross_match.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append("../..") 5 | 6 | from argparse import ArgumentParser 7 | 8 | import numpy as np 9 | from astropy.table import Table, join 10 | from datasets import load_from_disk 11 | from provabgs import models as Models 12 | from torchvision.transforms import CenterCrop, Compose 13 | from tqdm import tqdm 14 | 15 | from astroclip.data.datamodule import AstroClipCollator, AstroClipDataloader 16 | from astroclip.env import format_with_env 17 | 18 | provabgs_file = "https://data.desi.lbl.gov/public/edr/vac/edr/provabgs/v1.0/BGS_ANY_full.provabgs.sv3.v0.hdf5" 19 | 20 | 21 | def _download_data(save_path: str): 22 | """Download the PROVABGS data from the web and save it to the specified file path.""" 23 | # Check that the parent directory exists 24 | if not os.path.exists(os.path.dirname(save_path)): 25 | os.makedirs(os.path.dirname(save_path)) 26 | 27 | # Download the PROVABGS file directly to save_path so the caller can read it back 28 | local_path = save_path 29 | if not os.path.exists(local_path): 30 | print("Downloading PROVABGS data...") 31 | os.system(f"wget {provabgs_file} -O {local_path}") 32 | print("Downloaded PROVABGS data successfully!") 33 | else: 34 | print("PROVABGS data already exists!") 35 | 36 | 
37 | def _get_best_fit(provabgs: Table): 38 | """Get the best fit model for each galaxy.""" 39 | m_nmf = Models.NMF(burst=True, emulator=True) 40 | 41 | # Filter out galaxies with no best fit model 42 | provabgs = provabgs[ 43 | (provabgs["PROVABGS_LOGMSTAR_BF"] > 0) 44 | * (provabgs["MAG_G"] > 0) 45 | * (provabgs["MAG_R"] > 0) 46 | * (provabgs["MAG_Z"] > 0) 47 | ] 48 | 49 | # Get the thetas and redshifts for each galaxy 50 | thetas = provabgs["PROVABGS_THETA_BF"][:, :12] 51 | zreds = provabgs["Z_HP"] 52 | 53 | Z_mw = [] # Stellar metallicity 54 | tage_mw = [] # Age 55 | avg_sfr = [] # Average star formation rate 56 | 57 | print("Calculating best-fit properties using the PROVABGS model...") 58 | for i in tqdm(range(len(thetas))): 59 | theta = thetas[i] 60 | zred = zreds[i] 61 | 62 | # Calculate properties using the PROVABGS model 63 | Z_mw.append(m_nmf.Z_MW(theta, zred=zred)) 64 | tage_mw.append(m_nmf.tage_MW(theta, zred=zred)) 65 | avg_sfr.append(m_nmf.avgSFR(theta, zred=zred)) 66 | 67 | # Add the properties to the table 68 | provabgs["Z_MW"] = np.array(Z_mw) 69 | provabgs["TAGE_MW"] = np.array(tage_mw) 70 | provabgs["AVG_SFR"] = np.array(avg_sfr) 71 | return provabgs 72 | 73 | 74 | def cross_match_provabgs( 75 | astroclip_path: str, 76 | provabgs_path: str, 77 | save_path: str = None, 78 | batch_size: int = 128, 79 | num_workers: int = 20, 80 | ): 81 | """Cross-match the AstroCLIP and PROVABGS datasets.""" 82 | 83 | # Download the PROVABGS data if it doesn't exist 84 | if not os.path.exists(provabgs_path): 85 | _download_data(provabgs_path) 86 | 87 | # Load the AstroCLIP dataset 88 | dataloader = AstroClipDataloader( 89 | astroclip_path, 90 | batch_size=batch_size, 91 | num_workers=num_workers, 92 | collate_fn=AstroClipCollator(), 93 | columns=["image", "targetid", "spectrum"], 94 | ) 95 | dataloader.setup("fit") 96 | 97 | # Process the images 98 | train_images, train_spectra, train_targetids = [], [], [] 99 | for batch in tqdm(dataloader.train_dataloader(), desc="Processing train images"): 100 | train_images.append(batch["image"]) 101 | train_spectra.append(batch["spectrum"]) 102 | train_targetids.append(batch["targetid"]) 103 | 104 | test_images, test_spectra, test_targetids = [], [], [] 105 | for batch in tqdm(dataloader.val_dataloader(), desc="Processing test images"): 106 | test_images.append(batch["image"]) 107 | test_spectra.append(batch["spectrum"]) 108 | test_targetids.append(batch["targetid"]) 109 | 110 | print(f"Shape of images is {np.concatenate(train_images).shape[1:]}", flush=True) 111 | 112 | # Create tables for the train and test datasets 113 | train_table = Table( 114 | { 115 | "targetid": np.concatenate(train_targetids), 116 | "image": np.concatenate(train_images), 117 | "spectrum": np.concatenate(train_spectra), 118 | } 119 | ) 120 | test_table = Table( 121 | { 122 | "targetid": np.concatenate(test_targetids), 123 | "image": np.concatenate(test_images), 124 | "spectrum": np.concatenate(test_spectra), 125 | } 126 | ) 127 | 128 | # Load the PROVABGS dataset 129 | provabgs = Table.read(provabgs_path) 130 | 131 | # Filter out galaxies with no best fit model 132 | provabgs = provabgs[ 133 | (provabgs["PROVABGS_LOGMSTAR_BF"] > 0) 134 | * (provabgs["MAG_G"] > 0) 135 | * (provabgs["MAG_R"] > 0) 136 | * (provabgs["MAG_Z"] > 0) 137 | ] 138 | 139 | # Get the best fit model for each galaxy 140 | print("Getting property best fit with PROVABGS SED model") 141 | provabgs = _get_best_fit(provabgs) 142 | 
Scale the properties
144 |     provabgs["LOG_MSTAR"] = provabgs["PROVABGS_LOGMSTAR_BF"].data
145 |     provabgs["sSFR"] = np.log(provabgs["AVG_SFR"].data) - np.log(provabgs["Z_MW"].data)
146 |     provabgs["Z_MW"] = np.log(provabgs["Z_MW"].data)
147 | 
148 |     # Join the PROVABGS and AstroCLIP datasets
149 |     train_provabgs = join(
150 |         provabgs, train_table, keys_left="TARGETID", keys_right="targetid"
151 |     )
152 |     test_provabgs = join(
153 |         provabgs, test_table, keys_left="TARGETID", keys_right="targetid"
154 |     )
155 |     print("Number of galaxies in train:", len(train_provabgs))
156 |     print("Number of galaxies in test:", len(test_provabgs))
157 | 
158 |     # Save the paired datasets (by default, alongside the PROVABGS file)
159 |     if save_path is None:
160 |         save_path = os.path.dirname(provabgs_path)
161 |     train_provabgs.write(
162 |         os.path.join(save_path, "provabgs_paired_train.hdf5"),
163 |         overwrite=True,
164 |     )
165 |     test_provabgs.write(
166 |         os.path.join(save_path, "provabgs_paired_test.hdf5"),
167 |         overwrite=True,
168 |     )
169 | 
170 | 
171 | if __name__ == "__main__":
172 |     ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}")
173 |     parser = ArgumentParser()
174 |     parser.add_argument(
175 |         "--astroclip_path",
176 |         type=str,
177 |         default=f"{ASTROCLIP_ROOT}/datasets/astroclip_file/",
178 |         help="Path to the AstroCLIP dataset.",
179 |     )
180 |     parser.add_argument(
181 |         "--provabgs_path",
182 |         type=str,
183 |         default=f"{ASTROCLIP_ROOT}/datasets/provabgs/provabgs.hdf5",
184 |         help="Path to the PROVABGS dataset.",
185 |     )
186 |     parser.add_argument(
187 |         "--save_path",
188 |         type=str,
189 |         default=None,
190 |         help="Directory in which to save the paired datasets.",
191 |     )
192 | 
193 |     args = parser.parse_args()
194 |     cross_match_provabgs(
195 |         astroclip_path=args.astroclip_path,
196 |         provabgs_path=args.provabgs_path,
197 |         save_path=args.save_path,
198 |     )
199 | 
--------------------------------------------------------------------------------
/downstream_tasks/property_estimation/property_utils/models.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | import pyro.distributions as dist
3 | import pyro.distributions.transforms as T
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision
7 | from lightning import Trainer
8 | from numpy import ndarray
9 | from sklearn.neighbors import KNeighborsRegressor
10 | from torch import nn
11 | from torch.utils.data import DataLoader, TensorDataset
12 | from tqdm import tqdm
13 | 
14 | 
15 | def few_shot(
16 |     model: nn.Module,
17 |     X_train: ndarray,
18 |     y_train: ndarray,
19 |     X_test: ndarray,
20 |     max_epochs: int = 10,
21 |     hidden_dims: list[int] = [64, 64],
22 |     lr: float = 1e-3,
23 | ) -> ndarray:
24 |     """Train a few-shot MLP regressor on the embeddings (note: `model` is replaced by a freshly initialized MLP)"""
25 |     train_dataset = TensorDataset(
26 |         torch.tensor(X_train, dtype=torch.float32),
27 |         torch.tensor(y_train, dtype=torch.float32),
28 |     )
29 |     train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
30 | 
31 |     num_features = y_train.shape[1] if len(y_train.shape) > 1 else 1
32 |     model = MLP(
33 |         n_in=X_train.shape[1],
34 |         n_out=num_features,
35 |         n_hidden=hidden_dims,
36 |         act=[nn.ReLU()] * (len(hidden_dims) + 1),
37 |         dropout=0.1,
38 |     )
39 | 
40 |     # Set up the model
41 |     criterion = nn.MSELoss()
42 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
43 | 
44 |     # Train the model
45 |     model.cuda()
46 |     model.train()
47 |     for epoch in range(max_epochs):
48 |         running_loss = 0.0
49 |         for i, data in enumerate(train_loader, 0):
50 |             inputs, labels = data
51 |             optimizer.zero_grad()
52 |                 outputs = model(inputs.cuda()).squeeze()
53 |                 loss = criterion(outputs, labels.cuda())
54 |                 loss.backward()
55 |                 optimizer.step()
56 |                 running_loss += loss.item()
57 | 
58 |     # Make predictions
59 |     model.eval()
60 |     with torch.no_grad():
61 |         preds = model(torch.tensor(X_test, dtype=torch.float32).cuda()).cpu().numpy()
62 |     return preds
63 | 
64 | 
65 | def zero_shot(
66 |     X_train: ndarray, y_train: ndarray, X_test: ndarray, n_neighbors: int = 64
67 | ) -> ndarray:
68 |     """Zero-shot prediction with k-nearest-neighbours regression on the embeddings"""
69 |     neigh = KNeighborsRegressor(weights="distance", n_neighbors=n_neighbors)
70 |     neigh.fit(X_train, y_train)
71 |     preds = neigh.predict(X_test)
72 |     return preds
73 | 
74 | 
75 | class MLP(nn.Sequential):
76 |     """MLP model"""
77 | 
78 |     def __init__(self, n_in, n_out, n_hidden=(16, 16, 16), act=None, dropout=0):
79 |         if act is None:
80 |             act = [
81 |                 nn.LeakyReLU(),
82 |             ] * (len(n_hidden) + 1)
83 |         assert len(act) == len(n_hidden) + 1
84 | 
85 |         layer = []
86 |         n_ = [n_in, *n_hidden, n_out]
87 |         for i in range(len(n_) - 2):
88 |             layer.append(nn.Linear(n_[i], n_[i + 1]))
89 |             layer.append(act[i])
90 |             layer.append(nn.Dropout(p=dropout))
91 |         layer.append(nn.Linear(n_[-2], n_[-1]))
92 |         super().__init__(*layer)
93 | 
94 | 
95 | class ConditionalFlowStack(dist.conditional.ConditionalComposeTransformModule):
96 |     """Normalizing flow stack for conditional distribution"""
97 | 
98 |     def __init__(
99 |         self,
100 |         input_dim: int,
101 |         context_dim: int,
102 |         hidden_dims: int,
103 |         num_flows: int,
104 |         device: str = "cuda",
105 |     ):
106 |         coupling_transforms = [
107 |             T.conditional_spline(
108 |                 input_dim,
109 |                 context_dim,
110 |                 count_bins=8,
111 |                 hidden_dims=hidden_dims,
112 |                 order="quadratic",
113 |             ).to(device)
114 |             for _ in range(num_flows)
115 |         ]
116 | 
117 |         super().__init__(coupling_transforms, cache_size=1)
--------------------------------------------------------------------------------
/downstream_tasks/property_estimation/property_utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import seaborn as sns
4 | from sklearn.metrics import r2_score
5 | 
6 | 
7 | def plot_scatter(
8 |     preds: dict,
9 |     z_test: np.ndarray,
10 |     data_lower_lim: float = 0.0,
11 |     data_upper_lim: float = 0.6,
12 |     save_loc: str = "scatter.png",
13 | ) -> None:
14 |     """Functionality to plot redshift scatter plots for different models."""
15 |     fig, ax = plt.subplots(2, len(preds.keys()), figsize=(16, 10))
16 | 
17 |     for i, name in enumerate(preds.keys()):
18 |         sns.scatterplot(ax=ax[0, i], x=z_test, y=preds[name], s=5, color=".15")
19 |         sns.histplot(
20 |             ax=ax[0, i], x=z_test, y=preds[name], bins=50, pthresh=0.1, cmap="mako"
21 |         )
22 |         sns.kdeplot(
23 |             ax=ax[0, i], x=z_test, y=preds[name], levels=5, color="k", linewidths=1
24 |         )
25 | 
26 |         ax[0, i].plot(  # y = x reference line
27 |             [data_lower_lim, data_upper_lim * 1.1],
28 |             [data_lower_lim, data_upper_lim * 1.1],
29 |             "--",
30 |             linewidth=1.5,
31 |             alpha=0.5,
32 |             color="grey",
33 |         )
34 |         ax[0, i].set_xlim(data_lower_lim, data_upper_lim)
35 |         ax[0, i].set_ylim(data_lower_lim, data_upper_lim)
36 |         ax[0, i].text(
37 |             0.9,
38 |             0.1,
39 |             "$R^2$ score: %0.2f" % r2_score(z_test, preds[name]),
40 |             horizontalalignment="right",
41 |             verticalalignment="top",
42 |             fontsize=22,
43 |             transform=ax[0, i].transAxes,
44 |         )
45 |         ax[0, i].set_title(name, fontsize=25)
46 | 
47 |     ax[0, 0].set_ylabel("$Z_{pred}$", fontsize=25)
48 | 
49 |     for i, name in enumerate(preds.keys()):
50 |         x = z_test
51 |         y = (z_test - 
preds[name]) / (1 + z_test)
52 | 
53 |         bins = np.linspace(data_lower_lim, data_upper_lim * 1.05, 20)
54 |         x_binned = np.digitize(x, bins)
55 |         y_avg = [y[x_binned == i].mean() for i in range(1, len(bins))]
56 |         y_std = [y[x_binned == i].std() for i in range(1, len(bins))]
57 | 
58 |         sns.scatterplot(ax=ax[1, i], x=x, y=y, s=2, alpha=0.3, color="black")
59 |         sns.lineplot(ax=ax[1, i], x=bins[:-1], y=y_std, color="r", label="std")
60 | 
61 |         # horizontal line on y = 0
62 |         ax[1, i].axhline(0, color="grey", linewidth=1.5, alpha=0.5, linestyle="--")
63 | 
64 |         # sns.scatterplot(ax=ax[1,i], x=bins[:-1], y=y_avg, s=15, color='.15')
65 |         ax[1, i].set_xlim(data_lower_lim, data_upper_lim)
66 |         ax[1, i].set_ylim(-data_upper_lim / 2, data_upper_lim / 2)
67 |         ax[1, i].set_xlabel("$Z_{true}$", fontsize=25)
68 |         ax[1, i].legend(fontsize=15, loc="upper right")
69 | 
70 |     ax[1, 0].set_ylabel("$(Z_{true}-Z_{pred})/(1+Z_{true})$", fontsize=25)
71 | 
72 |     plt.tight_layout(rect=[0, 0, 1, 0.97])
73 |     plt.savefig(save_loc, dpi=300)
74 | 
--------------------------------------------------------------------------------
/downstream_tasks/similarity_search/README.md:
--------------------------------------------------------------------------------
1 | ## In-Modal and Cross-Modal Retrieval
2 | AstroCLIP retrieves the galaxies most similar to a query galaxy by computing the cosine similarity between galaxy embeddings. Because AstroCLIP's embedding space is shared between galaxy images and optical spectra, retrieval can be performed both in-modal (image-image, spectrum-spectrum) and cross-modal (image-spectrum, spectrum-image).
3 | 
4 | ### Embedding the dataset
5 | To perform retrieval on the held-out validation set, first generate AstroCLIP embeddings of the galaxy images and spectra:
6 | ```bash
7 | python embed_astroclip.py [save_path]
8 | ```
9 | 
10 | ### Similarity Search
11 | Once the embeddings have been generated, the ```similarity_search.ipynb``` Jupyter notebook provides a brief tutorial demonstrating the retrieval abilities of the model.
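
For reference, the retrieval step itself reduces to a cosine-similarity nearest-neighbour search over the saved embeddings. Below is a minimal sketch (the HDF5 keys match those written by `embed_astroclip.py`; the file path and query index are illustrative placeholders):

```python
import h5py
import numpy as np

# Load the precomputed embeddings (placeholder path)
with h5py.File("embedded_astroclip.hdf5", "r") as f:
    im_emb = f["image_embeddings"][:]
    sp_emb = f["spectrum_embeddings"][:]

# L2-normalize so that dot products equal cosine similarities
im_emb /= np.linalg.norm(im_emb, axis=-1, keepdims=True)
sp_emb /= np.linalg.norm(sp_emb, axis=-1, keepdims=True)

query = 0  # index of the query galaxy (placeholder)

# In-modal (image-image) and cross-modal (image-spectrum) similarities
im_sim = im_emb[query] @ im_emb.T
x_sim = im_emb[query] @ sp_emb.T

# Indices of the 8 most similar galaxies (the query itself ranks first in-modal)
top_im = np.argsort(im_sim)[::-1][:8]
top_x = np.argsort(x_sim)[::-1][:8]
```

Spectrum-to-spectrum and spectrum-to-image searches work the same way, with the roles of the two embedding sets swapped.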
12 | -------------------------------------------------------------------------------- /downstream_tasks/similarity_search/embed_astroclip.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from astroclip.data.datamodule import AstroClipCollator, AstroClipDataloader 9 | from astroclip.env import format_with_env 10 | from astroclip.models.astroclip import AstroClipModel 11 | 12 | 13 | def embed_astroclip( 14 | model_path: str, 15 | dataset_path: str, 16 | save_path: str, 17 | max_size: int = None, 18 | batch_size: int = 256, 19 | loader_type: str = "val", 20 | ): 21 | """Extract embeddings from the AstroClip model and save them to a file""" 22 | # Load the model 23 | astroclip = AstroClipModel.load_from_checkpoint(model_path) 24 | 25 | # Get the dataloader 26 | loader = AstroClipDataloader( 27 | path=dataset_path, 28 | batch_size=batch_size, 29 | num_workers=0, 30 | collate_fn=AstroClipCollator(), 31 | columns=["image", "spectrum", "targetid"], 32 | ) 33 | loader.setup("fit") 34 | 35 | # Set up loader 36 | if loader_type == "train": 37 | loader = loader.train_dataloader() 38 | elif loader_type == "val": 39 | loader = loader.val_dataloader() 40 | else: 41 | raise ValueError("loader must be either 'train' or 'val'") 42 | 43 | # Get the embeddings over the dataset 44 | im_embeddings, sp_embeddings, images, spectra, obj_ids = [], [], [], [], [] 45 | with torch.no_grad(): 46 | for idx, batch_test in tqdm(enumerate(loader), desc="Extracting embeddings"): 47 | # Break if max_size is reached 48 | if max_size is not None and idx * batch_size >= max_size: 49 | break 50 | 51 | # Append the image and spectrum to the list 52 | obj_ids.append(batch_test["targetid"]) 53 | 54 | # Extract the embeddings 55 | im_embeddings.append( 56 | astroclip(batch_test["image"].cuda(), input_type="image") 57 | .detach() 58 | .cpu() 59 | .numpy() 60 | ) 61 | sp_embeddings.append( 62 | astroclip(batch_test["spectrum"].cuda(), input_type="spectrum") 63 | .detach() 64 | .cpu() 65 | .numpy() 66 | ) 67 | images.append(batch_test["image"]) 68 | spectra.append(batch_test["spectrum"]) 69 | 70 | # Save as an HDF5 file 71 | with h5py.File(save_path, "w") as f: 72 | f.create_dataset("image_embeddings", data=np.concatenate(im_embeddings)) 73 | f.create_dataset("spectrum_embeddings", data=np.concatenate(sp_embeddings)) 74 | f.create_dataset("object_id", data=np.concatenate(obj_ids)) 75 | f.create_dataset("image", data=np.concatenate(images)) 76 | f.create_dataset("spectrum", data=np.concatenate(spectra)) 77 | print(f"Embeddings saved to {save_path}") 78 | 79 | 80 | if __name__ == "__main__": 81 | ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}") 82 | parser = ArgumentParser() 83 | parser.add_argument( 84 | "save_path", 85 | type=str, 86 | help="Path to save the embeddings", 87 | ) 88 | parser.add_argument( 89 | "--model_path", 90 | type=str, 91 | help="Path to the model", 92 | default=f"{ASTROCLIP_ROOT}/pretrained/astroclip.ckpt", 93 | ) 94 | parser.add_argument( 95 | "--dataset_path", 96 | type=str, 97 | help="Path to the dataset", 98 | default=f"{ASTROCLIP_ROOT}/datasets/astroclip_file/", 99 | ) 100 | parser.add_argument( 101 | "--batch_size", 102 | type=int, 103 | help="Batch size", 104 | default=256, 105 | ) 106 | parser.add_argument( 107 | "--max_size", 108 | type=int, 109 | help="Maximum number of samples to use", 110 | default=None, 111 | ) 112 | 
parser.add_argument(
113 |         "--loader_type",
114 |         type=str,
115 |         help="Which loader to use (train or val)",
116 |         default="val",
117 |     )
118 |     args = parser.parse_args()
119 |     embed_astroclip(
120 |         args.model_path,
121 |         args.dataset_path,
122 |         args.save_path,
123 |         args.max_size,
124 |         args.batch_size,
125 |         args.loader_type,
126 |     )
127 | 
--------------------------------------------------------------------------------
/downstream_tasks/similarity_search/plotting.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | from scipy.ndimage import gaussian_filter1d
6 | 
7 | 
8 | def plot_similar_images(
9 |     query_images: list,
10 |     sims: dict,
11 |     similarity_type: str = "im_sim",
12 |     num_retrievals: int = 8,
13 |     save_dir: str = None,
14 | ):
15 |     """Functionality for plotting retrieved galaxy images"""
16 |     plt.figure(figsize=[19.4, 6.1])
17 |     for n, img in enumerate(query_images):
18 |         plt.subplot(len(query_images), 13, n * 13 + 1)
19 |         plt.imshow(img.T)
20 |         plt.axis("off")
21 |         for j in range(num_retrievals):
22 |             plt.subplot(len(query_images), 13, n * 13 + j + 1 + 1)
23 |             plt.imshow(sims[n][similarity_type][j].T)
24 |             plt.axis("off")
25 |     # Tighten the spacing between panels
26 |     plt.subplots_adjust(wspace=0.00, hspace=0.01)
27 | 
28 |     if save_dir is not None:
29 |         if not os.path.exists(save_dir):
30 |             os.makedirs(save_dir)
31 |         plt.savefig(os.path.join(save_dir, f"retrieval_{similarity_type}.png"))
32 | 
33 | 
34 | def plot_similar_spectra(
35 |     query_spectra: list,
36 |     query_images: list,
37 |     sims: dict,
38 |     similarity_type: str = "im_sim",
39 |     num_retrievals: int = 5,
40 |     save_dir: str = None,
41 | ):
42 |     """Functionality for plotting retrieved galaxy spectra"""
43 |     l = np.linspace(3586.7408577, 10372.89543574, query_spectra[0].shape[0])
44 |     plt.figure(figsize=[15, 5])
45 |     colors = ["r", "b", "g", "y", "m"]
46 |     for n, sp in enumerate(query_spectra):
47 |         plt.subplot(1, len(query_spectra), n + 1)
48 |         plt.plot(
49 |             l,
50 |             gaussian_filter1d(sp[:, 0], 5),
51 |             color=colors[n],
52 |             lw=1,
53 |             label="spectrum of query image",
54 |         )
55 | 
56 |         for j in range(num_retrievals):
57 |             if j == 0:
58 |                 plt.plot(
59 |                     l,
60 |                     gaussian_filter1d(sims[n][similarity_type][j + 1][:, 0], 5),
61 |                     alpha=0.5,
62 |                     lw=1,
63 |                     color="gray",
64 |                     label="retrieved spectra",
65 |                 )
66 |             else:
67 |                 plt.plot(
68 |                     l,
69 |                     gaussian_filter1d(sims[n][similarity_type][j + 1][:, 0], 5),
70 |                     alpha=0.5,
71 |                     lw=1,
72 |                     color="gray",
73 |                 )
74 |         # set y lim
75 |         plt.ylim(1.1 * min(sp[:, 0]), 1.1 * max(sp[:, 0]))
76 | 
77 |         plt.xlabel(r"$\lambda$", fontsize=18)
78 |         plt.ylabel("flux", fontsize=18)
79 |         plt.legend(fontsize=18, loc="lower right")
80 | 
81 |         # Add the query image as an inset
82 |         axins = plt.gca().inset_axes([0, 0.55, 0.4, 0.4])
83 |         image_data = query_images[n]
84 |         axins.imshow(image_data.T)
85 |         axins.axis("off")
86 | 
87 |     if save_dir is not None:
88 |         if not os.path.exists(save_dir):
89 |             os.makedirs(save_dir)
90 |         plt.savefig(os.path.join(save_dir, f"retrieval_{similarity_type}.png"))
91 | 
--------------------------------------------------------------------------------
/downstream_tasks/similarity_search/similarity_search.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "%pylab 
inline\n", 10 | "import os, sys\n", 11 | "\n", 12 | "sys.path.append(\"../..\")\n", 13 | "import numpy as np\n", 14 | "import h5py\n", 15 | "from tqdm import tqdm\n", 16 | "\n", 17 | "from astroclip.env import format_with_env\n", 18 | "from plotting import plot_similar_images, plot_similar_spectra\n", 19 | "\n", 20 | "ASTROCLIP_ROOT = format_with_env(\"{ASTROCLIP_ROOT}\")\n", 21 | "\n", 22 | "# Load the embeddings\n", 23 | "embedding_loc = f\"{ASTROCLIP_ROOT}/datasets/embeded_astroclip.hdf5\"\n", 24 | "with h5py.File(embedding_loc, \"r\") as f:\n", 25 | " images = f[\"image\"][:]\n", 26 | " spectra = f[\"spectrum\"][:]\n", 27 | " im_embeddings = f[\"image_embeddings\"][:]\n", 28 | " sp_embeddings = f[\"spectrum_embeddings\"][:]\n", 29 | " obj_ids = f[\"object_id\"][:]\n", 30 | "\n", 31 | "# Normalize the embeddings\n", 32 | "image_features_normed = im_embeddings / np.linalg.norm(\n", 33 | " im_embeddings, axis=-1, keepdims=True\n", 34 | ")\n", 35 | "spectrum_features_normed = sp_embeddings / np.linalg.norm(\n", 36 | " sp_embeddings, axis=-1, keepdims=True\n", 37 | ")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# Look at some randomly selected galaxies\n", 47 | "figure(figsize=[15, 15])\n", 48 | "for i in range(15):\n", 49 | " for j in range(15):\n", 50 | " subplot(15, 15, i * 15 + j + 1)\n", 51 | " imshow(images[i * 15 + j + 1000].T)\n", 52 | " title(i * 15 + j + 1000)\n", 53 | " axis(\"off\")\n", 54 | "plt.subplots_adjust(wspace=0.1, hspace=0.11)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Plot retrieved galaxy images" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "# Choose some galaxies to search for similar galaxies\n", 71 | "ind_query = [7, 354, 526, 300]\n", 72 | "\n", 73 | "# Find the indices of the galaxies in the dataset\n", 74 | "im_sims = []\n", 75 | "\n", 76 | "for ind in ind_query:\n", 77 | " # Compute the similarity between the query galaxy and all other galaxies\n", 78 | " sp_sim = spectrum_features_normed[ind] @ spectrum_features_normed.T\n", 79 | " im_sim = image_features_normed[ind] @ image_features_normed.T\n", 80 | " x_im_sim = image_features_normed[ind] @ spectrum_features_normed.T\n", 81 | " x_sp_sim = spectrum_features_normed[ind] @ image_features_normed.T\n", 82 | "\n", 83 | " # Find the 8 most similar galaxies (images)\n", 84 | " im_sims.append(\n", 85 | " {\n", 86 | " \"sp_sim\": [images[i] for i in argsort(sp_sim)[::-1][:8]],\n", 87 | " \"im_sim\": [images[i] for i in argsort(im_sim)[::-1][:8]],\n", 88 | " \"x_im_sim\": [images[i] for i in argsort(x_im_sim)[::-1][:8]],\n", 89 | " \"x_sp_sim\": [images[i] for i in argsort(x_sp_sim)[::-1][:8]],\n", 90 | " }\n", 91 | " )" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Image-image similarity\n", 101 | "plot_similar_images(\n", 102 | " [images[i] for i in ind_query],\n", 103 | " im_sims,\n", 104 | " similarity_type=\"im_sim\",\n", 105 | " num_retrievals=8,\n", 106 | " save_dir=\"../outputs/image_retrieval/\",\n", 107 | ")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Spectrum-spectrum similarity\n", 117 | "plot_similar_images(\n", 118 | " [images[i] for i in 
ind_query],\n", 119 | " im_sims,\n", 120 | " similarity_type=\"sp_sim\",\n", 121 | " num_retrievals=8,\n", 122 | " save_dir=\"../outputs/image_retrieval/\",\n", 123 | ")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# Image-spectrum similarity\n", 133 | "plot_similar_images(\n", 134 | " [images[i] for i in ind_query],\n", 135 | " im_sims,\n", 136 | " similarity_type=\"x_im_sim\",\n", 137 | " num_retrievals=8,\n", 138 | " save_dir=\"../outputs/image_retrieval/\",\n", 139 | ")" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# Spectrum-image similarity\n", 149 | "plot_similar_images(\n", 150 | " [images[i] for i in ind_query],\n", 151 | " im_sims,\n", 152 | " similarity_type=\"x_sp_sim\",\n", 153 | " num_retrievals=8,\n", 154 | " save_dir=\"../outputs/image_retrieval/\",\n", 155 | ")" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "# Plot retrieved galaxy spectra" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# Choose some galaxies to search for similar galaxies\n", 172 | "ind_query = [7, 77]\n", 173 | "\n", 174 | "# Find the indices of the galaxies in the dataset\n", 175 | "sp_sims = []\n", 176 | "\n", 177 | "for ind in ind_query:\n", 178 | " # Compute the similarity between the query galaxy and all other galaxies\n", 179 | " sp_sim = spectrum_features_normed[ind] @ spectrum_features_normed.T\n", 180 | " im_sim = image_features_normed[ind] @ image_features_normed.T\n", 181 | " x_im_sim = image_features_normed[ind] @ spectrum_features_normed.T\n", 182 | " x_sp_sim = spectrum_features_normed[ind] @ image_features_normed.T\n", 183 | "\n", 184 | " # Find the 8 most similar galaxies (images)\n", 185 | " sp_sims.append(\n", 186 | " {\n", 187 | " \"sp_sim\": [spectra[i] for i in argsort(sp_sim)[::-1][:8]],\n", 188 | " \"im_sim\": [spectra[i] for i in argsort(im_sim)[::-1][:8]],\n", 189 | " \"x_im_sim\": [spectra[i] for i in argsort(x_im_sim)[::-1][:8]],\n", 190 | " \"x_sp_sim\": [spectra[i] for i in argsort(x_sp_sim)[::-1][:8]],\n", 191 | " }\n", 192 | " )" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# Image-image similarity\n", 202 | "plot_similar_spectra(\n", 203 | " [spectra[i] for i in ind_query],\n", 204 | " [images[i] for i in ind_query],\n", 205 | " sp_sims,\n", 206 | " similarity_type=\"im_sim\",\n", 207 | " save_dir=\"./outputs/spectrum_retrieval/\",\n", 208 | ")" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# Spectrum-spectrum similarity\n", 218 | "plot_similar_spectra(\n", 219 | " [spectra[i] for i in ind_query],\n", 220 | " [images[i] for i in ind_query],\n", 221 | " sp_sims,\n", 222 | " similarity_type=\"sp_sim\",\n", 223 | " save_dir=\"./outputs/spectrum_retrieval/\",\n", 224 | ")" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# Image-spectrum similarity\n", 234 | "plot_similar_spectra(\n", 235 | " [spectra[i] for i in ind_query],\n", 236 | " [images[i] for i in ind_query],\n", 237 | " sp_sims,\n", 238 | " 
similarity_type=\"x_im_sim\",\n", 239 | " save_dir=\"./outputs/spectrum_retrieval/\",\n", 240 | ")" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# Spectrum-image similarity\n", 250 | "plot_similar_spectra(\n", 251 | " [spectra[i] for i in ind_query],\n", 252 | " [images[i] for i in ind_query],\n", 253 | " sp_sims,\n", 254 | " similarity_type=\"x_sp_sim\",\n", 255 | " save_dir=\"./outputs/spectrum_retrieval/\",\n", 256 | ")" 257 | ] 258 | } 259 | ], 260 | "metadata": { 261 | "kernelspec": { 262 | "display_name": "toto", 263 | "language": "python", 264 | "name": "python3" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 3 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython3", 276 | "version": "3.10.10" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } 282 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "AstroClip" 7 | authors = [ 8 | {name = "Liam Parker", email = "lparker@flatironinstitute.org"}, 9 | {name = "Leopoldo Sarra", email = "lsarra@flatironinstitute.prg"}, 10 | {name = "Francois Lanusse", email = "flanusse@flatironinstitute.org"}, 11 | {name = "Siavash Golkar", email = "sgolkar@flatironinstitute.org"}, 12 | {name = "Miles Cranmer", email = "mc2473@cam.ac.uk"}, 13 | ] 14 | description = "AstroCLIP: Cross-Modal Pre-Training for Astronomical Foundation Models" 15 | readme = "README.md" 16 | license = {text = "MIT"} 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | ] 21 | dynamic = ["dependencies", "version"] 22 | 23 | [tool.setuptools_scm] 24 | version_file = "astroclip/_version.py" 25 | 26 | [project.scripts] 27 | spectrum_trainer = "astroclip.trainer:main_cli" 28 | astroclip_trainer = "astroclip.trainer:main_cli" 29 | image_trainer = "astroclip.astrodino.trainer:main_cli" 30 | 31 | [tool.setuptools] 32 | packages = ["astroclip"] 33 | 34 | [tool.setuptools.dynamic] 35 | dependencies = { file = "requirements.txt" } 36 | 37 | [tool.isort] 38 | profile = "black" 39 | src_paths = ["astroclip"] 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astropy 2 | datasets 3 | dinov2 @ git+https://github.com/facebookresearch/dinov2.git@2302b6bf46953431b969155307b9bed152754069 4 | huggingface_hub 5 | jaxtyping 6 | lightning[extra] 7 | plotly 8 | provabgs @ git+https://github.com/changhoonhahn/provabgs.git 9 | pycairo 10 | pyro-ppl 11 | python-dotenv 12 | scikit-image 13 | scikit-learn 14 | torchvision==0.15.0 15 | wandb 16 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Dataset generation 3 | 4 | The following scripts are used to generate the datasets used in the paper: 5 | 6 | - `cross_match_data.py`: Finds spectra for objects in the Legacy Survey 7 | data prepared by George Stein 
(https://github.com/georgestein/ssl-legacysurvey/tree/main)
8 | 
9 | - `export_data.py`: Exports the combination of images and spectra into
10 |   a single HDF5 file.
11 | 
12 | In principle, you should not need to run these scripts: the resulting
13 | datasets are already available on HuggingFace. The scripts are included
14 | here for reproducibility.
15 | 
--------------------------------------------------------------------------------
/scripts/cross_match_data.py:
--------------------------------------------------------------------------------
1 | # This script cross-matches the cutouts from the Legacy Survey
2 | # with spectra from the DESI EDR
3 | import glob
4 | 
5 | import h5py
6 | import numpy as np
7 | import pandas as pd
8 | from astropy.table import Table, join, vstack
9 | from dl import authClient as ac
10 | from dl import queryClient as qc
11 | from sparcl.client import SparclClient
12 | from tqdm import tqdm
13 | 
14 | DATA_DIR = "/mnt/home/flanusse/ceph"
15 | 
16 | client = SparclClient()
17 | inc = [
18 |     "specid",
19 |     "redshift",
20 |     "flux",
21 |     "ra",
22 |     "dec",
23 |     "wavelength",
24 |     "spectype",
25 |     "specprimary",
26 |     "survey",
27 |     "program",
28 |     "targetid",
29 |     "coadd_fiberstatus",
30 | ]
31 | 
32 | 
33 | print("Retrieving all objects in the DESI data release...")
34 | query = """
35 | SELECT phot.targetid, phot.brickid, phot.brick_objid, phot.release, zpix.healpix
36 | FROM desi_edr.photometry AS phot
37 | INNER JOIN desi_edr.zpix ON phot.targetid = zpix.targetid
38 | WHERE (zpix.coadd_fiberstatus = 0 AND zpix.sv_primary)
39 | """
40 | cat = qc.query(sql=query, fmt="table")
41 | print("done")
42 | # Building search key based on brick ids
43 | cat["key"] = [
44 |     "%d_%d_%d" % (cat["release"][i], cat["brickid"][i], cat["brick_objid"][i])
45 |     for i in range(len(cat))
46 | ]
47 | 
48 | merged_cat = None
49 | 
50 | # Looping over the downloaded image files
51 | for file in tqdm(glob.glob(DATA_DIR + "/*.h5")):
52 |     try:
53 |         with h5py.File(file) as d:
54 |             # search key
55 |             d_key = np.array(
56 |                 [
57 |                     "%d_%d_%d" % (d["release"][i], d["brickid"][i], d["objid"][i])
58 |                     for i in range(len(d["brickid"]))
59 |                 ]
60 |             )
61 |             t = Table(data=[d["inds"][:], d_key], names=["inds", "key"])
62 |     except Exception:  # skip image files that cannot be read
63 |         continue
64 |     file_cat = join(cat, t, keys=["key"])
65 |     file_cat["image_file"] = file
66 |     file_cat.sort("healpix")
67 | 
68 |     # Retrieving spectra associated with this file, in batches of 500
69 |     target_ids = [int(i) for i in file_cat["targetid"]]
70 |     records = None
71 |     for i in tqdm(range((len(target_ids) + 499) // 500)):
72 |         start = i * 500
73 |         end = min((i + 1) * 500, len(target_ids))  # do not drop the last target id
74 | 
75 |         res = client.retrieve_by_specid(
76 |             specid_list=target_ids[start:end], include=inc, dataset_list=["DESI-EDR"]
77 |         )
78 |         if records is None:
79 |             records = Table.from_pandas(pd.DataFrame.from_records(res.records))
80 |         else:
81 |             r = Table.from_pandas(pd.DataFrame.from_records(res.records))
82 |             records = vstack([records, r])
83 | 
84 |     # Merging catalogs
85 |     file_cat = join(file_cat, records, keys=["targetid"])
86 | 
87 |     if merged_cat is None:
88 |         merged_cat = file_cat
89 |     else:
90 |         merged_cat = vstack([merged_cat, file_cat])
91 | 
92 | # Saving the results
93 | merged_cat.to_pandas().to_parquet("matched_catalog.pq")
94 | 
--------------------------------------------------------------------------------
/scripts/export_data.py:
--------------------------------------------------------------------------------
1 | # This script exports the data needed for the 
dataset into a single file.
2 | #
3 | import h5py
4 | import numpy as np
5 | import pandas as pd
6 | from astropy.table import Table, join
7 | from tqdm import tqdm
8 | 
9 | DATA_DIR = "/mnt/home/flanusse/ceph"
10 | 
11 | # Open matched catalog
12 | joint_cat = pd.read_parquet(DATA_DIR + "/matched_catalog.pq").drop_duplicates(
13 |     subset=["key"]
14 | )
15 | 
16 | # Create randomized indices to shuffle the dataset
17 | rng = np.random.default_rng(seed=42)
18 | indices = rng.permutation(len(joint_cat))
19 | joint_cat = joint_cat.iloc[indices]
20 | 
21 | with h5py.File(DATA_DIR + "/exported_data.h5", "w") as f:
22 |     for i in range(10):
23 |         print("Processing file %d" % i)
24 |         # Considering only the objects that are in the current file
25 |         sub_cat = joint_cat[joint_cat["inds"] // 1000000 == i]
26 |         images = []
27 |         spectra = []
28 |         redshifts = []
29 |         targetids = []
30 |         with h5py.File(
31 |             DATA_DIR + "/images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1)
32 |         ) as d:
33 |             for j in tqdm(range(len(sub_cat))):
34 |                 images.append(
35 |                     np.array(d["images"][sub_cat["inds"].iloc[j] % 1000000]).T.astype(
36 |                         "float32"
37 |                     )
38 |                 )
39 |                 spectra.append(
40 |                     np.reshape(sub_cat["flux"].iloc[j], [-1, 1]).astype("float32")
41 |                 )
42 |                 redshifts.append(sub_cat["redshift"].iloc[j])
43 |                 targetids.append(sub_cat["targetid"].iloc[j])
44 |         f.create_group(str(i))
45 |         f[str(i)].create_dataset("images", data=images)
46 |         f[str(i)].create_dataset("spectra", data=spectra)
47 |         f[str(i)].create_dataset("redshifts", data=redshifts)
48 |         f[str(i)].create_dataset("targetids", data=targetids)
49 | 
--------------------------------------------------------------------------------
/submit.sbatch:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | 
3 | # Example usage (entry points are defined in pyproject.toml; the config path is illustrative):
4 | #   sbatch submit.sbatch spectrum_trainer fit --config configs/specformer.yaml
5 | 
6 | #SBATCH -p gpu
7 | ##SBATCH -C "h100"
8 | #SBATCH -J "astroclip-job"
9 | #SBATCH --nodes=1
10 | #SBATCH --ntasks-per-node=4
11 | #SBATCH --gpus-per-node=4
12 | #SBATCH --cpus-per-gpu=2
13 | #SBATCH --mem=200G
14 | #SBATCH --output=logs/astroclip-%j.log
15 | #SBATCH --time=48:00:00
16 | 
17 | module load gcc
18 | module load nccl
19 | 
20 | export num_workers=$(expr $SLURM_JOB_CPUS_PER_NODE - 1)
21 | export OMP_NUM_THREADS=${SLURM_CPUS_ON_NODE}
22 | 
23 | # Debugging settings (note: CUDA_LAUNCH_BLOCKING=1 serializes kernel launches and slows training)
24 | export WANDB_START_METHOD=thread
25 | export NCCL_DEBUG=INFO
26 | export CUDA_LAUNCH_BLOCKING=1
27 | 
28 | # Load running environment
29 | source /mnt/home/lparker/python_envs/toto/bin/activate
30 | 
31 | srun "$@" \
32 |     --data.num_workers=8 \
33 |     --trainer.num_nodes=${SLURM_NNODES} \
34 |     --trainer.devices=${SLURM_GPUS_PER_NODE} \
35 |     --trainer.strategy='ddp_find_unused_parameters_true'
--------------------------------------------------------------------------------