├── .gitignore ├── README.md ├── assets ├── birds_teaser.png ├── src_01_5_crop.png ├── src_10_2_crop.png ├── trg_01_5_crop_blend.gif └── trg_10_2_crop_blend.gif ├── configs ├── classifier │ └── nearest_centroid.yaml ├── dataset │ ├── apk.yaml │ ├── cub.yaml │ ├── pascalparts.yaml │ ├── pfpascal.yaml │ └── spair.yaml ├── demo_data_loading.yaml ├── demo_keypoint_transfer.yaml ├── demo_segmenation_nearest_centroid.yaml ├── eval_all.yaml ├── eval_seg.yaml ├── eval_time_mem.yaml ├── feat_refine │ └── geosc.yaml ├── featurizer │ ├── dift_adm.yaml │ ├── dift_sd.yaml │ ├── dino.yaml │ ├── dinov2.yaml │ ├── dinov2lora.yaml │ ├── open_clip.yaml │ └── sd15ema_dinov2.yaml ├── store_feats.yaml ├── store_masks.yaml ├── train_pairs.yaml └── train_setting_default.yaml ├── demo_data_loading.ipynb ├── demo_keypoint_transfer.ipynb ├── demo_segmenation_nearest_centroid.ipynb ├── download_data ├── prepare_apk.sh ├── prepare_cub.sh ├── prepare_pascalparts.sh ├── prepare_pfpascal.sh ├── prepare_spair.sh ├── spair_keypoint_l_r_permutation.csv ├── spair_keypoint_names.csv └── spair_keypoint_permutation.csv ├── pretrained_weights └── geco │ └── last_lora_weights.pth ├── scripts ├── run_evaluation_all.py ├── run_evaluation_seg.py ├── run_evaluation_time_mem.py ├── store_feats.py ├── store_masks.py └── train_pairs.py ├── setup_env.sh └── src ├── __init__.py ├── dataset ├── apk_pairs.py ├── augmentations.py ├── cub_200.py ├── cub_200_pairs.py ├── pairwise_utils.py ├── pascalparts.py ├── pascalparts_part2ind.py ├── pfpascal_pairs.py ├── random_utils.py ├── spair.py ├── spair_single.py └── utils.py ├── evaluation ├── __init__.py ├── evaluation.py ├── pck.py ├── runtime_mem.py ├── segmentation.py └── segmentation_metrics.py ├── logging ├── log_results.py ├── visualization.py ├── visualization_pck.py └── visualization_seg.py ├── losses └── __init__.py ├── matcher ├── __init__.py ├── argmaxmatcher.py └── ot_matcher.py └── models ├── __init__.py ├── classifier ├── supervised │ └── nearest_centroid.py └── utils.py ├── featurizer ├── clip.py ├── dift_adm.py ├── dift_sd.py ├── dino.py ├── dinov2.py ├── dinov2_lora.py ├── sd15ema_dinov2.py └── utils.py ├── featurizer_refine ├── __init__.py ├── feat_refine_geosc.py └── resnet.py ├── pca.py └── segmentation └── sam.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.png 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | */.DS_Store 163 | .DS_Store 164 | 165 | guided-diffusion/ 166 | davis_results_sd/ 167 | davis_results_adm/ 168 | superpoint-1k/ 169 | hpatches_results/ 170 | superpoint-1k.zip 171 | SPair-71k.tar.gz 172 | SPair-71k/ 173 | ./guided-diffusion/models/256x256_diffusion_uncond.pt 174 | wandb/ 175 | outputs/ 176 | feats/ 177 | spair/ 178 | pretrained_weights/ 179 | runs/ 180 | pretrained_weights_backup/ 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦎 GECO: Geometrically Consistent Embedding with Lightspeed Inference (ICCV 2025) 2 | 3 | 4 | [ 🌐**Project Page**](https://reginehartwig.github.io/publications/geco/) • [📄 **Paper**](https://arxiv.org/pdf/2508.00746) 5 | 6 |
7 |  8 | 9 | 10 |
11 | 12 | This is the official repository for the ICCV 2025 paper: 13 | 14 | > **GECO: Geometrically Consistent Embedding with Lightspeed Inference**. 15 | > 16 | >[Regine Hartwig](https://reginehartwig.github.io/)1,2, 17 | >[Dominik Muhle](https://dominikmuhle.github.io/)1,2, 18 | >[Riccardo Marin](https://ricma.netlify.app/)1,2, 19 | >[Daniel Cremers](https://cvg.cit.tum.de/members/cremers)1,2, 20 | > 21 | > 1Technical University of Munich, 2MCML 22 | > 23 | > [**ICCV 2025** (arXiv)](https://arxiv.org/pdf/2508.00746) 24 | 25 | 26 | If you find our work useful, please consider citing our paper: 27 | ``` 28 | @inproceedings{hartwig2025geco, 29 | title={GECO: Geometrically Consistent Embedding with Lightspeed Inference}, 30 | author={Hartwig, Regine and Muhle, Dominik and Marin, Riccardo and Cremers, Daniel}, 31 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision}, 32 | year={2025} 33 | } 34 | ``` 35 | 36 | ## Intro 37 | 38 | We address the task of geometry-aware feature encoding. A common way to test geometric awareness is through keypoint matching: Given a source image with an annotated keypoint, the goal is to predict the keypoint in the target image by selecting the location with the highest feature similarity. Below are two examples: 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | 
48 | We introduce a training paradigm and a lightweight architecture for learning from image pairs with sparsely annotated keypoints. Additionally, we enhance the testing of features by introducing subdivisions of the commonly used PCK metric and a centroid clustering approach for evaluating the feature space more densely.
49 | ## 🔧 Environment Setup
50 | If you're using a Linux machine, set up the Python environment with:
51 | ```bash
52 | conda create --name geco python=3.10
53 | conda activate geco
54 | bash setup_env.sh
55 | ```
56 | To use [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything) for mask extraction:
57 | 
58 | ```bash
59 | pip install git+https://github.com/facebookresearch/segment-anything.git
60 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
61 | 
62 | ```
63 | 
64 | Install ODISE if you want to run the Geo baseline:
65 | 
66 | ```bash
67 | git clone git@github.com:NVlabs/ODISE.git
68 | cd ODISE
69 | pip install -e .
70 | ```
71 | 
72 | ## 🚀 Get Started
73 | 
74 | ### 📁 Prepare the Datasets
75 | 
76 | Run the following scripts to prepare each dataset:
77 | * APK:
78 | ```bash
79 | bash download_data/prepare_apk.sh
80 | wget https://github.com/Junyi42/GeoAware-SC/blob/master/prepare_ap10k.ipynb
81 | ```
82 | Then run the notebook from GeoAware-SC to preprocess the data.
83 | * CUB:
84 | ```bash
85 | bash download_data/prepare_cub.sh
86 | ```
87 | * PascalParts:
88 | ```bash
89 | bash download_data/prepare_pascalparts.sh
90 | ```
91 | * PFPascal:
92 | ```bash
93 | bash download_data/prepare_pfpascal.sh
94 | ```
95 | * SPair-71k:
96 | ```bash
97 | bash download_data/prepare_spair.sh
98 | ```
99 | 
100 | ### Extract the masks
101 | 1. Define `save_path_masks` in the dataset config files.
102 | 2. Define `path_model_seg` in `store_masks.yaml`, pointing to the directory where `sam_vit_h_4b8939.pth` is stored.
103 | 3. Select the datasets to process in `store_masks.yaml`.
104 | 4. Run:
105 | ```bash
106 | python scripts/store_masks.py --config-name=store_masks.yaml
107 | ```
108 | 
109 | ### Precompute the features (not recommended)
110 | Choose a path and define `save_path` in the dataset config files.
111 | Select the dataset you want to extract features for in `store_feats.yaml`, then run:
112 | ```bash
113 | python scripts/store_feats.py --config-name=store_feats.yaml
114 | ```
115 | 
116 | 
117 | ## 🎯 Pretrained Weights
118 | Pretrained weights are available in `pretrained_weights/geco`.
119 | 
120 | ## 🧪 Interactive Demos: Give it a Try!
121 | 
122 | We provide interactive Jupyter notebooks for testing.
123 | 
124 | * [📚 Data Loading Demo](demo_data_loading.ipynb)
125 | 
126 | Validate dataset preparation and path setup.
127 | 
128 | * [🎨 Segmentation Demo](demo_segmenation_nearest_centroid.ipynb)
129 | 
130 | Visualize part segmentation using a simple linear classifier.
131 | 
132 | * [📍 Keypoint Transfer Demo](demo_keypoint_transfer.ipynb)
133 | 
134 | Explore keypoint transfer and interactive attention maps. A minimal code sketch of the underlying matching step is shown below.
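For reference, the matching step behind keypoint transfer (pick the target location whose feature is most similar to the source keypoint's feature) can be sketched in a few lines. This is an illustrative, self-contained example, not the repository's actual API: the function name, argument layout, and the random toy inputs are assumptions.

```python
import torch
import torch.nn.functional as F

def transfer_keypoint(src_ft, trg_ft, src_kp_xy, src_imsize, trg_imsize):
    """Transfer one keypoint from source to target via feature similarity.

    src_ft, trg_ft: feature maps of shape (C, H, W), e.g. from one of the
        featurizers in src/models/featurizer (hypothetical inputs here).
    src_kp_xy: (x, y) keypoint in source-image pixel coordinates.
    src_imsize, trg_imsize: (width, height) of the source/target images.
    """
    C, Hs, Ws = src_ft.shape
    _, Ht, Wt = trg_ft.shape
    x, y = src_kp_xy
    # map the pixel coordinate into feature-map coordinates
    fx = int(round(x / src_imsize[0] * (Ws - 1)))
    fy = int(round(y / src_imsize[1] * (Hs - 1)))
    query = F.normalize(src_ft[:, fy, fx], dim=0)       # (C,)
    # cosine similarity between the query and every target location
    trg = F.normalize(trg_ft.reshape(C, -1), dim=0)     # (C, Ht*Wt)
    sim = query @ trg                                   # (Ht*Wt,)
    ty, tx = divmod(int(sim.argmax()), Wt)
    # map the best match back to target-image pixel coordinates
    return tx / (Wt - 1) * trg_imsize[0], ty / (Ht - 1) * trg_imsize[1]

# toy usage with random features standing in for real featurizer outputs
src_ft, trg_ft = torch.randn(768, 37, 37), torch.randn(768, 37, 37)
print(transfer_keypoint(src_ft, trg_ft, (100.0, 200.0), (518, 518), (518, 518)))
```

In the actual evaluation this is done for all annotated keypoints of an image pair, optionally after upsampling the feature maps (see the `upsample` flag in the evaluation configs).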
135 | 136 | 137 | ## 📊 Run Evaluation 138 | 139 | Run full evaluation: 140 | ```bash 141 | python scripts/run_evaluation_all.py --config-name=eval_all.yaml 142 | ``` 143 | Evaluate inference time and memory usage: 144 | ```bash 145 | python scripts/run_evaluation_time_mem.py --config-name=eval_time_mem.yaml 146 | ``` 147 | Evaluate segmentation metrics: 148 | ```bash 149 | python scripts/run_evaluation_seg.py --config-name=eval_seg.yaml 150 | ``` 151 | 152 | 153 | ## 🏋️ Train the Model 154 | Before training, comment out the following block in `configs/featurizer/dinov2lora.yaml`: 155 | 156 | ```yaml 157 | init: 158 | id: geco 159 | eval_last: True 160 | ``` 161 | Then run: 162 | ```bash 163 | python scripts/train_pairs.py --config-name=train_pairs.yaml 164 | ``` -------------------------------------------------------------------------------- /assets/birds_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/birds_teaser.png -------------------------------------------------------------------------------- /assets/src_01_5_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/src_01_5_crop.png -------------------------------------------------------------------------------- /assets/src_10_2_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/src_10_2_crop.png -------------------------------------------------------------------------------- /assets/trg_01_5_crop_blend.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/trg_01_5_crop_blend.gif -------------------------------------------------------------------------------- /assets/trg_10_2_crop_blend.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/trg_10_2_crop_blend.gif -------------------------------------------------------------------------------- /configs/classifier/nearest_centroid.yaml: -------------------------------------------------------------------------------- 1 | model: 'nearest_centroid_fg' 2 | num_pcaparts: 32 3 | num_samples: 100 -------------------------------------------------------------------------------- /configs/dataset/apk.yaml: -------------------------------------------------------------------------------- 1 | name: 'ap10k' 2 | dataset_path: '/ap-10k' #path to spair dataset 3 | save_path: '' 4 | save_path_masks: '' 5 | hydra: 6 | output_subdir: null 7 | sup: sup_augmented # ['sup_augmented', 'sup_original'] 8 | flip_aug: True 9 | cat: all 10 | subsample: null -------------------------------------------------------------------------------- /configs/dataset/cub.yaml: -------------------------------------------------------------------------------- 1 | name: 'cub' 2 | dataset_path: '' 3 | dataset_path_annotations: '' 4 | n_pairs: 10000 5 | save_path: '' 6 | save_path_masks: '' 7 | hydra: 8 | output_subdir: null 9 | sup: sup_augmented # ['sup_augmented', 'sup_original'] 10 | borders_cut: True 11 | flip_aug: True 12 | cat: bird 
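The `nearest_centroid.yaml` classifier config above drives the part-segmentation evaluation and demo: per-part centroids are computed from labeled foreground features, and each test feature is assigned to the most similar centroid. The sketch below is a generic version of that idea on toy data; the repository's `nearest_centroid.py` (the `nearest_centroid_fg` variant with `num_pcaparts` and `num_samples`) may differ in detail.

```python
import torch
import torch.nn.functional as F

def fit_centroids(feats, labels, num_parts):
    """feats: (N, C) foreground features; labels: (N,) part ids in [0, num_parts)."""
    centroids = torch.zeros(num_parts, feats.shape[1])
    for p in range(num_parts):
        members = feats[labels == p]
        if len(members) > 0:
            centroids[p] = members.mean(dim=0)   # mean feature of part p
    return F.normalize(centroids, dim=1)

def predict_parts(feats, centroids):
    """Assign every feature vector to the part with the most similar centroid."""
    sim = F.normalize(feats, dim=1) @ centroids.T   # (M, num_parts) cosine similarities
    return sim.argmax(dim=1)

# toy usage: 32 parts as in nearest_centroid.yaml, random features as stand-ins
train_feats, train_labels = torch.randn(1000, 768), torch.randint(0, 32, (1000,))
centroids = fit_centroids(train_feats, train_labels, num_parts=32)
print(predict_parts(torch.randn(5, 768), centroids))
```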
-------------------------------------------------------------------------------- /configs/dataset/pascalparts.yaml: -------------------------------------------------------------------------------- 1 | name: 'pascalparts' 2 | dataset_path: '/PASCAL-VOC/' #path to spair dataset 3 | save_path: '' 4 | hydra: 5 | output_subdir: null 6 | sup: sup_original # ['sup_augmented', 'sup_original'] -------------------------------------------------------------------------------- /configs/dataset/pfpascal.yaml: -------------------------------------------------------------------------------- 1 | name: 'pfpascal' 2 | dataset_path: '/PF-dataset-PASCAL' #path to spair dataset 3 | save_path: '' 4 | save_path_masks: '' 5 | hydra: 6 | output_subdir: null 7 | sup: sup_original # ['sup_augmented', 'sup_original'] 8 | cat: all 9 | -------------------------------------------------------------------------------- /configs/dataset/spair.yaml: -------------------------------------------------------------------------------- 1 | name: 'spair' 2 | dataset_path: '/SPair-71k' #path to spair dataset 3 | save_path: '' 4 | save_path_masks: '' 5 | hydra: 6 | output_subdir: null 7 | sup: sup_augmented # ['sup_augmented', 'sup_original'] 8 | flip_aug: True 9 | cat: all 10 | -------------------------------------------------------------------------------- /configs/demo_data_loading.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/spair # ['apk', 'cub', 'spair', 'pfpascal'] 3 | - featurizer/dinov2 # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 4 | 5 | model_seg_name: sam -------------------------------------------------------------------------------- /configs/demo_keypoint_transfer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/apk # ['apk', 'cub', 'spair', 'pfpascal'] 3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 4 | dataset: 5 | sup: sup_original # ['sup_augmented', 'sup_original'] -------------------------------------------------------------------------------- /configs/demo_segmenation_nearest_centroid.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/pascalparts 3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 4 | - classifier/nearest_centroid@sup_classifier -------------------------------------------------------------------------------- /configs/eval_all.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/spair # ['apk', 'cub', 'spair', 'pfpascal'] 3 | - dataset/apk@dataset1 4 | - dataset/cub@dataset2 5 | - dataset/pfpascal@dataset3 6 | - featurizer/dinov2lora # which model to use can be one of ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2, 'dinov2lora'] 7 | 8 | dataset: 9 | sup: sup_original # ['sup_augmented', 'sup_original'] 10 | dataset1: 11 | sup: sup_original # ['sup_augmented', 'sup_original'] 12 | dataset2: 13 | sup: sup_original # ['sup_augmented', 'sup_original'] 14 | dataset3: 15 | sup: sup_original # ['sup_augmented', 'sup_original'] 16 | 17 | # PCK 18 | upsample: True 19 | alpha_bbox: False 20 | n_pairs_eval_pck: 10000 21 | 
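The PCK settings at the end of `eval_all.yaml` above (`upsample`, `alpha_bbox`, `n_pairs_eval_pck`) control the keypoint-transfer metric: a transferred keypoint counts as correct if it lands within α times a reference length (the bounding-box size when `alpha_bbox` is true, the image size otherwise) of the ground-truth location. Below is a minimal, generic sketch of that computation; the repository's implementation lives in `src/evaluation/pck.py` and may differ in detail.

```python
import torch

def pck(pred_kps, gt_kps, ref_size, alpha=0.1):
    """Percentage of Correct Keypoints.

    pred_kps, gt_kps: (N, 2) predicted / ground-truth (x, y) locations.
    ref_size: per-pair normalisation length, e.g. max(bbox_w, bbox_h) when
        normalising by the bounding box, or max(img_w, img_h) otherwise.
    """
    dist = torch.linalg.norm(pred_kps - gt_kps, dim=1)   # (N,) pixel errors
    correct = dist <= alpha * ref_size                   # within the alpha threshold?
    return correct.float().mean().item()

# toy usage: 3 keypoints, reference size 200 -> threshold 20 px at alpha=0.1
pred = torch.tensor([[10., 10.], [50., 60.], [100., 100.]])
gt = torch.tensor([[12., 11.], [90., 60.], [101., 98.]])
print(pck(pred, gt, ref_size=200.0, alpha=0.1))  # -> 0.67 (2 of 3 keypoints correct)
```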
-------------------------------------------------------------------------------- /configs/eval_seg.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/pascalparts 3 | - featurizer/dinov2lora 4 | - classifier/nearest_centroid@sup_classifier -------------------------------------------------------------------------------- /configs/eval_time_mem.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset/apk # ['apk', 'cub', 'spair', 'pfpascal'] 3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 4 | 5 | num_imgs_time_mem: 1000 6 | dataset: 7 | sup: sup_original # ['sup_augmented', 'sup_original'] -------------------------------------------------------------------------------- /configs/feat_refine/geosc.yaml: -------------------------------------------------------------------------------- 1 | model: GeoSc 2 | model_out_path: '' 3 | feature_dims: [640,1280,1280,768] 4 | projection_dim: 768 5 | # in_dim: 3968 6 | feat_map_dropout: 0.2 7 | 8 | init: 9 | id: geosc 10 | epoch: 856 11 | url: "https://github.com/Junyi42/GeoAware-SC/blob/master/results_spair/best_856.PTH" -------------------------------------------------------------------------------- /configs/featurizer/dift_adm.yaml: -------------------------------------------------------------------------------- 1 | model: 'dift_adm' 2 | img_size: [512, 512] #in the order of [width, height], resize input image to [w, h] before fed into diffusion model, if set to 0, will stick to the original input size. 3 | t: 101 #t for diffusion 4 | up_ft_index: 4 #which upsampling block to extract the ft map 5 | ensemble_size: 1 #ensemble size for getting an image ft map -------------------------------------------------------------------------------- /configs/featurizer/dift_sd.yaml: -------------------------------------------------------------------------------- 1 | model: 'dift_sd' 2 | img_size: [768, 768] #in the order of [width, height], resize input image to [w, h] before fed into diffusion model, if set to 0, will stick to the original input size. by default is 768x768. 3 | t: 261 #t for diffusion 4 | up_ft_index: 2 #which upsampling block to extract the ft map 5 | ensemble_size: 1 #ensemble size for getting an image ft map 6 | all_cats: none -------------------------------------------------------------------------------- /configs/featurizer/dino.yaml: -------------------------------------------------------------------------------- 1 | model: 'dino' 2 | img_size: [224, 224] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size. 3 | up_ft_index: 4 #which upsampling block to extract the ft map 4 | log_bin: False -------------------------------------------------------------------------------- /configs/featurizer/dinov2.yaml: -------------------------------------------------------------------------------- 1 | model: 'dinov2' 2 | model_size: 'dinov2_vitb14' # ['dinov2_vits14', 'dinov2_vitb14'] 3 | img_size: [518, 518] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size. 
4 | up_ft_index: 1 #which upsampling block to extract the ft map 5 | log_bin: False -------------------------------------------------------------------------------- /configs/featurizer/dinov2lora.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dinov2 3 | model: 'dinov2lora' 4 | model_out_path: '' 5 | lora_rank: 10 6 | 7 | init: 8 | id: geco 9 | eval_last: True -------------------------------------------------------------------------------- /configs/featurizer/open_clip.yaml: -------------------------------------------------------------------------------- 1 | model: 'open_clip' 2 | img_size: [512, 512] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size. 3 | up_ft_index: 4 #which upsampling block to extract the ft map 4 | ensemble_size: 1 #ensemble size for getting an image ft map 5 | t: 101 #t for diffusion -------------------------------------------------------------------------------- /configs/featurizer/sd15ema_dinov2.yaml: -------------------------------------------------------------------------------- 1 | model: 'sd15ema_dinov2' 2 | crop_feats: True -------------------------------------------------------------------------------- /configs/store_feats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - featurizer/sd15ema_dinov2 # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 3 | - dataset/cub # ['apk', 'cub', 'spair', 'pfpascal'] 4 | 5 | dataset: 6 | split: 'train' -------------------------------------------------------------------------------- /configs/store_masks.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cub # ['apk', 'cub', 'spair', 'pfpascal'] 3 | 4 | dataset: 5 | split: 'train' 6 | path_model_seg: 7 | model_seg_name: sam -------------------------------------------------------------------------------- /configs/train_pairs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - featurizer: dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora'] 3 | - dataset: apk # ['apk', 'cub', 'spair', 'pfpascal'] 4 | - dataset: spair@dataset2 # ['apk', 'cub', 'spair', 'pfpascal'] 5 | - dataset: pfpascal@dataset3 # ['apk', 'cub', 'spair', 'pfpascal'] 6 | - dataset: pfpascal@datasetgeneralization # ['apk', 'cub', 'spair', 'pfpascal'] 7 | - dataset: spair@datasetgeneralization2 # ['apk', 'cub', 'spair', 'pfpascal'] 8 | - dataset: cub@datasetgeneralization3 # ['apk', 'cub', 'spair', 'pfpascal'] 9 | - train_setting_default 10 | 11 | # fix number of elements for each category in the dataset 12 | dataset: 13 | num_el: 800 14 | dataset2: 15 | num_el: 800 16 | dataset3: 17 | num_el: 800 18 | 19 | ## for evaluating PCK 20 | upsample: True 21 | alpha_bbox: False 22 | n_pairs_eval_pck: 10000 23 | 24 | -------------------------------------------------------------------------------- /configs/train_setting_default.yaml: -------------------------------------------------------------------------------- 1 | learning_rate: 0.0001 2 | epoch: 8 3 | batch_size: 6 4 | weight_decay: 0 5 | scheduler: !!null 6 | 7 | # losses 8 | losses: 9 | pos: 1 10 | bin: 1 11 | neg: 10 12 | neg_fg_bkg: 10 13 | model_seg_name: sam 
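The `losses` block in `train_setting_default.yaml` above sets the weights of the individual training loss terms. During training (`scripts/train_pairs.py`), each term returned by `PairwiseLoss` is averaged over the batch, scaled by its weight, and the scaled terms are summed into the total loss. Below is a minimal sketch of that combination step; the term names match the config, but the tensor values are illustrative toy numbers.

```python
import torch

# weights as in configs/train_setting_default.yaml
weights = {"pos": 1.0, "bin": 1.0, "neg": 10.0, "neg_fg_bkg": 10.0}

# hypothetical per-term losses, one value per element in the batch
raw_losses = {
    "pos": torch.tensor([0.30, 0.20]),
    "bin": torch.tensor([0.10, 0.15]),
    "neg": torch.tensor([0.02, 0.01]),
    "neg_fg_bkg": torch.tensor([0.05, 0.03]),
}

# average each term over the batch and apply its weight,
# mirroring `losses = {k: v.mean() * weights[k] ...}` in train_pairs.py
weighted = {k: v.mean() * weights[k] for k, v in raw_losses.items()}
total = sum(weighted.values())   # scalar used for backprop
print({k: float(v) for k, v in weighted.items()}, float(total))
```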
--------------------------------------------------------------------------------
/download_data/prepare_apk.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p 
3 | cd 
4 | gdown https://drive.google.com/uc?id=1-FNNGcdtAQRehYYkGY1y4wzFNg4iWNad
5 | unzip ap-10k.zip -d ap-10k
6 | rm ap-10k.zip
7 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_cub.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p 
3 | cd 
4 | wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
5 | if [ -f CUB_200_2011.tgz ]; then
6 | echo "File downloaded successfully"
7 | else
8 | echo "File download failed"
9 | fi
10 | tar -xvzf CUB_200_2011.tgz
11 | if [ -d CUB_200_2011 ]; then
12 | echo "File extracted successfully"
13 | rm CUB_200_2011.tgz
14 | else
15 | echo "File extraction failed"
16 | fi
17 | 
18 | mkdir -p 
19 | cd 
20 | apt-get install gdown
21 | gdown https://drive.google.com/drive/folders/1DnmpG8Owhv9Rmz_KFozqIKVTJCNSbWc9?usp=sharing
22 | unzip data.zip -d data-v2
23 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_pascalparts.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p 
3 | cd 
4 | 
5 | mkdir PASCAL-VOC
6 | cd PASCAL-VOC
7 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar
8 | tar -xf VOCtrainval_03-May-2010.tar
9 | if [ -d VOCdevkit ]; then
10 | echo "File extracted successfully"
11 | rm VOCtrainval_03-May-2010.tar
12 | else
13 | echo "File extraction failed"
14 | fi
15 | 
16 | mkdir Parts
17 | cd Parts
18 | wget https://roozbehm.info/pascal-parts/trainval.tar.gz
19 | tar -xf trainval.tar.gz
20 | 
21 | if [ -d Annotations_Part ]; then
22 | echo "File extracted successfully"
23 | rm trainval.tar.gz
24 | else
25 | echo "File extraction failed"
26 | fi
27 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_pfpascal.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p 
3 | cd 
4 | wget http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip
5 | wget http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf_pascal.csv
6 | wget http://www.di.ens.fr/willow/research/cnngeometric/other_resources/val_pairs_pf_pascal.csv
7 | gdown https://drive.google.com/uc?id=111tpXshLiJ4qudBHoGK3HbMSNr9vVRq9 # download the train_pairs_pf_pascal.csv
8 | unzip PF-dataset-PASCAL.zip -d .
9 | rm PF-dataset-PASCAL.zip 10 | rm __MACOSX -r 11 | rm PF-dataset-PASCAL/Annotations/.DS_Store 12 | mv test_pairs_pf_pascal.csv PF-dataset-PASCAL 13 | mv val_pairs_pf_pascal.csv PF-dataset-PASCAL 14 | cd ${CURRENT_DIR} -------------------------------------------------------------------------------- /download_data/prepare_spair.sh: -------------------------------------------------------------------------------- 1 | CURRENT_DIR=$(pwd) 2 | mkdir -p 3 | cd 4 | wget http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz 5 | tar -xf SPair-71k.tar.gz 6 | if [ -d SPair-71k ]; then 7 | echo "File extracted successfully" 8 | rm SPair-71k.tar.gz 9 | else 10 | echo "File extraction failed" 11 | fi 12 | cp ${CURRENT_DIR}/spair*.csv /SPair-71k 13 | cd ${CURRENT_DIR} -------------------------------------------------------------------------------- /download_data/spair_keypoint_l_r_permutation.csv: -------------------------------------------------------------------------------- 1 | "kp_id ",aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor 2 | "0",0,1,0,0,1,1,1,1,1,1,1,1,1,1,2,1,1,2 3 | "1",1,0,2,2,0,0,0,0,0,0,0,0,0,0,3,0,0,1 4 | "2",2,3,1,1,3,3,3,3,3,3,3,3,3,3,0,3,3,0 5 | "3",3,2,3,4,2,2,2,2,2,2,2,2,2,2,1,2,2,7 6 | "4",5,4,5,3,5,7,8,5,5,5,5,5,5,4,5,5,5,6 7 | "5",4,5,4,6,4,6,9,4,4,4,4,4,4,5,4,4,4,5 8 | "6",7,7,6,5,7,5,7,7,7,7,6,7,7,6,8,7,7,4 9 | "7",6,6,8,8,6,4,6,6,6,6,7,6,6,7,7,6,6,3 10 | "8",9,8,7,7,9,8,4,8,9,8,8,8,8,9,6,8,9,10 11 | "9",8,10,9,10,8,9,5,10,8,10,10,9,9,8,,10,8,9 12 | "10",11,9,11,9,,20,20,9,11,9,9,11,10,11,,9,11,8 13 | "11",10,11,10,12,,21,21,12,10,12,12,10,12,10,,12,10,15 14 | "12",13,12,13,11,,22,22,11,13,11,11,13,11,13,,11,13,14 15 | "13",12,13,12,13,,23,23,13,12,13,13,12,,12,,13,12,13 16 | "14",15,,15,,,24,24,14,,14,14,14,,15,,14,15,12 17 | "15",14,,14,,,25,25,,,16,15,15,,14,,16,14,11 18 | "16",17,,16,,,17,17,,,15,,17,,17,,15,17, 19 | "17",16,,,,,16,16,,,18,,16,,16,,18,16, 20 | "18",19,,,,,19,19,,,17,,19,,19,,17,, 21 | "19",18,,,,,18,18,,,20,,18,,18,,20,, 22 | "20",21,,,,,10,10,,,19,,,,,,19,, 23 | "21",20,,,,,11,11,,,,,,,,,,, 24 | "22",22,,,,,12,12,,,,,,,,,,, 25 | "23",23,,,,,13,13,,,,,,,,,,, 26 | "24",24,,,,,14,14,,,,,,,,,,, 27 | "25",,,,,,15,15,,,,,,,,,,, 28 | "26",,,,,,27,27,,,,,,,,,,, 29 | "27",,,,,,26,26,,,,,,,,,,, 30 | "28",,,,,,29,29,,,,,,,,,,, 31 | "29",,,,,,28,28,,,,,,,,,,, 32 | -------------------------------------------------------------------------------- /download_data/spair_keypoint_names.csv: -------------------------------------------------------------------------------- 1 | kp_id,aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor 2 | 0,nose,front wheel,crown,front deck,right cap,right mirror,right mirror,right ear,right front-seat,right ear,right ear,right ear,right mirror,right eye,right rim,right ear,top-left front,top-right frame-corner 3 | 1,windshield,back wheel,right wing,front-right deck,left cap,left mirror,left mirror,left ear,left front-seat,left ear,left ear,left ear,left mirror,left eye,back rim,left ear,top-right front,frame-top 4 | 2,cockpit,right handlebar,left wing,front-left deck,right neck,right headlight,right headlight,right ear-tip,bottom-right back,right ear-tip,right ear-tip,right ear-tip,right handlebar,right ear,left rim,right ear-tip,bottom-left front,top-left frame-corner 5 | 3,front wheel,left handlebar ,beak,mid-right deck,left neck ,left headlight,left headlight,left ear-tip,bottom-left 
back,left ear-tip,left ear-tip,left ear-tip,left handlebar ,left ear,front rim,left ear-tip,bottom-right front,left frame-side 6 | 4,right wheel,middle handlebar,right cheek,mid-left deck ,top-right body,front license-plate,front license-plate,right eye,front-right leg,right eye,right eye,right eye,front light,nose,right body,right eye,top-left back,bottom-left frame-corner 7 | 5,left wheel,front seat,left cheek ,back-right deck,top-left body,right tail-light,front emblem,left eye,front-left leg,left eye,left eye,left eye,back light,chin,left body,left eye,top-right back,frame-base 8 | 6,right engine-fan,back-right seat,forehead ,back-left deck,right-center body,left tail-light,right tail-light,right nostril,back-right leg,right nostril,nose ,right nostril,front wheel,nape,right base,right nostril,bottom-left back,bottom-right frame-corner 9 | 7,left engine-fan,back-left seat,right eye,front-right hull,left-center body,back license-plate,left tail-light,left nostril,back-left leg,left nostril,forehead,left nostril,back wheel,mouth,middle base,left nostril,bottom-right back,right frame-side 10 | 8,right wing-tip,seat-post,left eye,front-left hull,right-bottom body,N/A,back license-plate,mouth,top-right back,mouth,mouth,mouth,license-plate,right shoulder,left base,mouth,top-left windshield ,top-right screen-corner 11 | 9,left wing-tip,right pedal,nape,mid-right hull,left-bottom body,N/A,back emblem,front-right foot,top-left back,front-right foot,front-right foot,forehead,tail ,left shoulder,,front-right foot,top-right windshield,screen-top 12 | 10,right engine,left pedal ,right foot,mid-left hull ,,back-right wheel-back-liner,back-right wheel-back-liner,front-left foot,right armrest,front-left foot,front-left foot,front-right foot,exhaustpipe,right elbow,,front-left foot,bottom-left windshield ,top-left screen-corner 13 | 11,left engine ,N/A,left foot,back-right hull,,back-right wheel,back-right wheel,back-right foot,left armrest,back-right foot,back-right foot,front-left foot,front wheelhub,left elbow,,back-right foot,bottom-right windshield ,left screen-side 14 | 12,right wing-front,N/A,right knee,back-left hull,,back-right wheel-front-liner,back-right wheel-front-liner,back-left foot,right elbowrest,back-left foot,back-left foot,back-right foot,back wheelhub ,right hand,,back-left foot,top-left windshield-center,bottom-left screen-corner 15 | 13,left wing-front,chainring ,left knee,front hull,,front-right wheel-back-liner,front-right wheel-back-liner,tail-tip,left elbowrest,tail-tip,tail-tip,back-left foot,,left hand,,tail-tip,top-right windshield-center,screen-base 16 | 14,right wing-back,,right hip,,,front-right wheel,front-right wheel,tail-base,,tail-base,tail-base,tail-tip,,right knee,,tail-base,bottom-left windshield-center,bottom-right screen-corner 17 | 15,left wing-back,,left hip,,,front-right wheel-front-liner,front-right wheel-front-liner,,,front-right knee,nape,tail-base,,left knee,,front-right knee,bottom-right windshield-center,right screen-side 18 | 16,right stabilizer-tip,,tail,,,bottom-right bus-front,bottom-right front-windshield,,,front-left knee,,front-right knee,,right ankle,,front-left knee,left headlight, 19 | 17,left stabilizer-tip,,,,,bottom-left bus-front,bottom-left front-windshield,,,back-right knee,,front-left knee,,left ankle,,back-right knee,right headlight, 20 | 18,right stabilizer-front,,,,,top-right bus-front,top-right front-windshield,,,back-left knee,,back-right knee,,right foot,,back-left knee,, 21 | 19,left stabilizer-front,,,,,top-left 
bus-front,top-left front-windshield,,,right horn,,back-left knee,,left foot,,right horn,, 22 | 20,right stabilizer-back,,,,,back-left wheel-back-liner,back-left wheel-back-liner,,,left horn,,,,,,left horn,, 23 | 21,left stabilizer-back,,,,,back-left wheel,back-left wheel,,,,,,,,,,, 24 | 22,vertical stabilizer-tip,,,,,back-left wheel-front-liner,back-left wheel-front-liner,,,,,,,,,,, 25 | 23,vertical stabilizer-front,,,,,front-left wheel-back-liner,front-left wheel-back-liner,,,,,,,,,,, 26 | 24,vertical stabilizer-back,,,,,front-left wheel,front-left wheel,,,,,,,,,,, 27 | 25,,,,,,front-left wheel-front-liner,front-left wheel-front-liner,,,,,,,,,,, 28 | 26,,,,,,bus-back bottom-right,bottom-right back-windshield,,,,,,,,,,, 29 | 27,,,,,,bus-back bottom-left,bottom-left back-windshield,,,,,,,,,,, 30 | 28,,,,,,bus-back top-right,top-right back-windshield,,,,,,,,,,, 31 | 29,,,,,,bus-back top-left,top-left back-windshield,,,,,,,,,,, 32 | -------------------------------------------------------------------------------- /download_data/spair_keypoint_permutation.csv: -------------------------------------------------------------------------------- 1 | "kp_id ",aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor 2 | "0",0,1,0,0,1,1,1,1,1,1,1,1,1,1,2,1,1,2 3 | "1",1,0,2,2,0,0,0,0,0,0,0,0,0,0,3,0,0,1 4 | "2",2,3,1,1,3,3,3,3,3,3,3,3,3,3,0,3,3,0 5 | "3",3,2,3,4,2,2,2,2,2,2,2,2,2,2,1,2,2,7 6 | "4",5,4,5,3,5,7,8,5,5,5,5,5,5,4,5,5,5,6 7 | "5",4,5,4,6,4,6,9,4,4,4,4,4,4,5,4,4,4,5 8 | "6",7,7,6,5,7,5,7,7,7,7,6,7,7,6,8,7,7,4 9 | "7",6,6,8,8,6,4,6,6,6,6,7,6,6,7,7,6,6,3 10 | "8",9,8,7,7,9,8,4,8,9,8,8,8,8,9,6,8,9,10 11 | "9",8,10,9,10,8,9,5,10,8,10,10,9,9,8,,10,8,9 12 | "10",11,9,11,9,,20,20,9,11,9,9,11,10,11,,9,11,8 13 | "11",10,11,10,12,,21,21,12,10,12,12,10,12,10,,12,10,15 14 | "12",13,12,13,11,,22,22,11,13,11,11,13,11,13,,11,13,14 15 | "13",12,13,12,13,,23,23,13,12,13,13,12,,12,,13,12,13 16 | "14",15,,15,,,24,24,14,,14,14,14,,15,,14,15,12 17 | "15",14,,14,,,25,25,,,16,15,15,,14,,16,14,11 18 | "16",17,,16,,,17,17,,,15,,17,,17,,15,17, 19 | "17",16,,,,,16,16,,,18,,16,,16,,18,16, 20 | "18",19,,,,,19,19,,,17,,19,,19,,17,, 21 | "19",18,,,,,18,18,,,20,,18,,18,,20,, 22 | "20",21,,,,,10,10,,,19,,,,,,19,, 23 | "21",20,,,,,11,11,,,,,,,,,,, 24 | "22",22,,,,,12,12,,,,,,,,,,, 25 | "23",23,,,,,13,13,,,,,,,,,,, 26 | "24",24,,,,,14,14,,,,,,,,,,, 27 | "25",,,,,,15,15,,,,,,,,,,, 28 | "26",,,,,,27,27,,,,,,,,,,, 29 | "27",,,,,,26,26,,,,,,,,,,, 30 | "28",,,,,,29,29,,,,,,,,,,, 31 | "29",,,,,,28,28,,,,,,,,,,, 32 | -------------------------------------------------------------------------------- /pretrained_weights/geco/last_lora_weights.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/pretrained_weights/geco/last_lora_weights.pth -------------------------------------------------------------------------------- /scripts/run_evaluation_all.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | from src.logging.log_results import finish_wandb, init_wandb 7 | from src.evaluation.evaluation import Evaluation 8 | from omegaconf import DictConfig, OmegaConf 9 | import hydra 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | from src.models.featurizer_refine 
import load_checkpoint as load_checkpoint_refiner 12 | from src.models.featurizer.utils import get_featurizer 13 | import copy 14 | 15 | @hydra.main(config_path="../configs") 16 | def main(args: DictConfig): 17 | 18 | # load the model if feat_refine is part of the config 19 | if 'feat_refine' in args: 20 | model_refine = load_checkpoint_refiner(args.feat_refine) 21 | else: 22 | model_refine = None 23 | 24 | featurizer = get_featurizer(args.featurizer) 25 | 26 | cfg = OmegaConf.to_container(args) 27 | # convert hydra to dict 28 | if 'feat_refine' in args: 29 | init_wandb(cfg, f'eval {args.feat_refine.init.id}') 30 | elif 'init' in args.featurizer: 31 | init_wandb(cfg, f'eval {args.featurizer.init.id}') 32 | else: 33 | init_wandb(cfg, 'eval ') 34 | 35 | evaluation_0 = Evaluation(args, featurizer) 36 | args_0 = copy.deepcopy(args) 37 | args_0.dataset = args.dataset1 38 | evaluation_1 = Evaluation(args_0, featurizer) 39 | args_2 = copy.deepcopy(args) 40 | args_2.dataset = args.dataset2 41 | evaluation_2 = Evaluation(args_2, featurizer) 42 | args_3 = copy.deepcopy(args) 43 | args_3.dataset = args.dataset3 44 | evaluation_3 = Evaluation(args_3, featurizer) 45 | 46 | evaluation_0.evaluate_pck(model_refine=model_refine, suffix=evaluation_0.dataset_test_pck.name) 47 | evaluation_1.evaluate_pck(model_refine=model_refine, suffix=evaluation_1.dataset_test_pck.name) 48 | evaluation_2.evaluate_pck(model_refine=model_refine, suffix=evaluation_2.dataset_test_pck.name) 49 | evaluation_3.evaluate_pck(model_refine=model_refine, suffix=evaluation_3.dataset_test_pck.name) 50 | 51 | evaluation_val_0 = Evaluation(args, featurizer, split='val') 52 | args_val_0 = copy.deepcopy(args) 53 | args_val_0.dataset = args.dataset1 54 | evaluation_val_1 = Evaluation(args_val_0, featurizer, split='val') 55 | args_val_2 = copy.deepcopy(args) 56 | args_val_2.dataset = args.dataset2 57 | evaluation_val_2 = Evaluation(args_val_2, featurizer, split='val') 58 | args_val_3 = copy.deepcopy(args) 59 | args_val_3.dataset = args.dataset3 60 | evaluation_val_3 = Evaluation(args_val_3, featurizer, split='val') 61 | 62 | n_eval_imgs = 10 63 | evaluation_val_0.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_0.dataset_test_pck.name) 64 | evaluation_val_1.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_1.dataset_test_pck.name) 65 | evaluation_val_2.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_2.dataset_test_pck.name) 66 | evaluation_val_3.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_3.dataset_test_pck.name) 67 | 68 | finish_wandb() 69 | 70 | if __name__ == '__main__': 71 | main() -------------------------------------------------------------------------------- /scripts/run_evaluation_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.logging.log_results import finish_wandb, init_wandb 3 | from src.evaluation.evaluation import Evaluation 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | from src.models.featurizer_refine import load_checkpoint 8 | 9 | @hydra.main(config_path="../configs") 10 | def main(args: DictConfig): 11 | 12 | # load the model if feat_refine is part of the config 13 | if 'feat_refine' in args: 14 | model_refine = load_checkpoint(args.feat_refine) 15 | else: 16 | model_refine = None 17 | 
18 | cfg = OmegaConf.to_container(args) 19 | 20 | if 'feat_refine' in args: 21 | init_wandb(cfg, f'eval_seg {args.feat_refine.init.id}') 22 | elif 'init' in args.featurizer: 23 | init_wandb(cfg, f'eval_seg {args.featurizer.init.id}') 24 | else: 25 | init_wandb(cfg, 'eval_seg ') 26 | 27 | evaluation = Evaluation(args) 28 | evaluation.evaluate_seg(model_refine=model_refine) 29 | finish_wandb() 30 | 31 | if __name__ == '__main__': 32 | main() -------------------------------------------------------------------------------- /scripts/run_evaluation_time_mem.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | from src.logging.log_results import finish_wandb, init_wandb 7 | from src.evaluation.runtime_mem import evaluate 8 | from omegaconf import DictConfig, OmegaConf 9 | import hydra 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | from src.models.featurizer_refine import load_checkpoint 12 | from src.dataset.spair_single import SpairDatasetSingle 13 | from src.dataset.cub_200 import CUBDataset 14 | from src.dataset.apk_pairs import AP10KPairs 15 | from src.dataset.pfpascal_pairs import PFPascalPairs 16 | from src.models.featurizer.utils import get_featurizer 17 | import wandb 18 | 19 | @hydra.main(config_path="../configs") 20 | def main(args: DictConfig): 21 | 22 | # load the model if feat_refine is part of the config 23 | if 'feat_refine' in args: 24 | model_refine = load_checkpoint(args.feat_refine) 25 | else: 26 | model_refine = None 27 | 28 | cfg = OmegaConf.to_container(args) 29 | # convert hydra to dict 30 | if 'feat_refine' in args: 31 | init_wandb(cfg, f'eval_time_mem {args.feat_refine.init.id}') 32 | elif 'init' in args.featurizer: 33 | init_wandb(cfg, f'eval_time_mem {args.featurizer.init.id}') 34 | else: 35 | init_wandb(cfg, 'eval_time_mem ') 36 | if args.dataset.name == "spair": 37 | dataset_test_time_mem = SpairDatasetSingle(args.dataset, split="test") 38 | elif args.dataset.name == "cub": 39 | dataset_test_time_mem = CUBDataset(args.dataset, split="test") 40 | elif args.dataset.name == 'ap10k': 41 | dataset_test_time_mem = AP10KPairs(args.dataset, split="test") 42 | elif args.dataset.name == 'pfpascal': 43 | dataset_test_time_mem = PFPascalPairs(args.dataset, split="test") 44 | 45 | dataset_test_time_mem.featurizer = get_featurizer(args.featurizer) 46 | dataset_test_time_mem.featurizer_kwargs = args.featurizer 47 | results = evaluate(dataset_test_time_mem, args.num_imgs_time_mem, model_refine=model_refine) 48 | wandb.log(results) 49 | finish_wandb() 50 | 51 | if __name__ == '__main__': 52 | main() -------------------------------------------------------------------------------- /scripts/store_feats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import DictConfig, OmegaConf 3 | import hydra 4 | from src.dataset.cub_200 import CUBDatasetBordersCut, CUBDataset 5 | from src.dataset.spair_single import SpairDatasetSingle 6 | from src.dataset.pascalparts import PascalParts 7 | from src.dataset.apk_pairs import AP10KPairs 8 | from src.dataset.pfpascal_pairs import PFPascalPairs 9 | from src.models.featurizer.utils import get_featurizer 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | @hydra.main(config_path="../configs") 13 | def main(args: DictConfig): 14 | # convert hydra to dict 15 | 
cfg = OmegaConf.to_container(args) 16 | # init the dataset 17 | args.dataset.sup = "sup_original" 18 | torch.cuda.set_device(0) 19 | if args.dataset.name == 'spair': 20 | dataset = SpairDatasetSingle(args.dataset, split=args.dataset.split) 21 | elif args.dataset.name == 'cub': 22 | # dataset = CUBDatasetBordersCut(args.dataset, split=args.dataset.split) 23 | dataset = CUBDataset(args.dataset, split=args.dataset.split) 24 | elif args.dataset.name == 'pascalparts': 25 | dataset = PascalParts(args.dataset, split=args.dataset.split) 26 | elif args.dataset.name == 'ap10k': 27 | dataset = AP10KPairs(args.dataset, split=args.dataset.split) 28 | elif args.dataset.name == 'pfpascal': 29 | dataset = PFPascalPairs(args.dataset, split=args.dataset.split) 30 | 31 | if args.featurizer.model == 'dift_sd': 32 | args.featurizer.all_cats = dataset.all_cats 33 | featurizer = get_featurizer(args.featurizer) 34 | dataset.featurizer_name = featurizer.name 35 | dataset.featurizer = featurizer 36 | dataset.featurizer_kwargs = args.featurizer 37 | print(dataset.featurizer_name) 38 | 39 | overwrite=False 40 | dataset.store_feats(featurizer, overwrite, args.featurizer) 41 | 42 | if __name__ == '__main__': 43 | main() -------------------------------------------------------------------------------- /scripts/store_masks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | from omegaconf import DictConfig, OmegaConf 7 | import hydra 8 | from src.dataset.cub_200 import CUBDatasetBordersCut, CUBDataset 9 | from src.dataset.spair_single import SpairDatasetSingle 10 | from src.dataset.pascalparts import PascalParts 11 | from src.dataset.apk_pairs import AP10KPairs 12 | from src.dataset.pfpascal_pairs import PFPascalPairs 13 | from src.models.segmentation.sam import SAM 14 | 15 | @hydra.main(config_path="../configs") 16 | def main(args: DictConfig): 17 | # init the dataset 18 | args.dataset.sup = "sup_original" 19 | torch.cuda.set_device(0) 20 | if args.dataset.name == 'spair': 21 | dataset = SpairDatasetSingle(args.dataset, split=args.dataset.split) 22 | elif args.dataset.name == 'cub': 23 | # dataset = CUBDatasetBordersCut(args.dataset, split=args.dataset.split) 24 | dataset = CUBDataset(args.dataset, split=args.dataset.split) 25 | elif args.dataset.name == 'pascalparts': 26 | dataset = PascalParts(args.dataset, split=args.dataset.split) 27 | elif args.dataset.name == 'ap10k': 28 | dataset = AP10KPairs(args.dataset, split=args.dataset.split) 29 | elif args.dataset.name == 'pfpascal': 30 | dataset = PFPascalPairs(args.dataset, split=args.dataset.split) 31 | 32 | model_seg = SAM(args) 33 | dataset.model_seg = model_seg 34 | dataset.return_masks = True 35 | dataset.model_seg_name = model_seg.name 36 | 37 | overwrite=False 38 | dataset.store_masks(overwrite) 39 | 40 | if __name__ == '__main__': 41 | main() -------------------------------------------------------------------------------- /scripts/train_pairs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | from omegaconf import DictConfig, OmegaConf 7 | import hydra 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | import copy 10 | from src.models.featurizer.utils import get_featurizer 11 | from 
src.models.featurizer.utils import save_checkpoint as save_featurizer 12 | from src.logging.log_results import log_wandb_epoch, log_wandb_cfg, log_wandb_ram_usage, finish_wandb, log_wandb, init_wandb 13 | from src.dataset.cub_200_pairs import CUBPairDataset 14 | from src.dataset.spair import SpairDataset2 15 | from src.dataset.apk_pairs import AP10KPairs 16 | from src.dataset.pfpascal_pairs import PFPascalPairs 17 | from src.evaluation.evaluation import Evaluation 18 | from src.models.featurizer_refine import get_model_refine 19 | from src.models.featurizer_refine import save_checkpoint as save_model_refine 20 | from src.losses import PairwiseLoss 21 | import time 22 | from src.dataset.utils import get_multi_cat_dataset 23 | 24 | def set_seed(dataset_train, args, epoch): 25 | if args.dataset.cat == "all": 26 | # set seed of all the subdatasets of the concatenated dataset 27 | for dataset in dataset_train.datasets: 28 | dataset.seed_pairs = epoch 29 | else: 30 | dataset_train.seed_pairs = epoch # set the seed for the dataset to get the same pairs for each run, but different pairs for each epoch 31 | 32 | def forward_pass(model_refine, ft0, ft1): 33 | # get the new features 34 | if model_refine is not None: 35 | ft_new = [model_refine(f) for f in [ft0, ft1]] 36 | else: 37 | ft_new = [ft0, ft1] 38 | ft_new_ = [f.permute(0,2,3,1).flatten(1,-2) for f in ft_new] # each of shape (B, H*W, C) 39 | return ft_new_[0], ft_new_[1] 40 | 41 | def train_batch(data, model_refine, loss_fun): 42 | torch.cuda.empty_cache() 43 | ft0, ft1 = forward_pass(model_refine, data['src_ft'], data['trg_ft']) 44 | losses = loss_fun.get_loss(ft0, ft1, data) 45 | return losses 46 | 47 | def train_epoch(train_loader, model_refine, optimizer, weights, loss_fun, args, epoch, evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3, best_pck, featurizer=None): 48 | # iterate over the dataset 49 | I = len(train_loader) 50 | for i, data in enumerate(train_loader): 51 | start_loss = time.time() 52 | if model_refine is not None: 53 | model_refine.train() 54 | else: 55 | featurizer.train() 56 | losses = train_batch(data, model_refine, loss_fun) 57 | # compute the mean loss 58 | losses = {k: v.mean()*weights[k] for k,v in losses.items()} 59 | # backprop 60 | optimizer.zero_grad() 61 | loss = sum(losses.values()) 62 | loss.backward() 63 | optimizer.step() 64 | # measure runtime of loss computation 65 | timediff_loss = time.time()-start_loss 66 | log_wandb({'loss_time':timediff_loss}) 67 | # log the losses, ram usage and evaluation 68 | if ((epoch-1)*I+i+1) % round(1*(6/args.batch_size)) == 0: 69 | log_wandb_epoch(epoch-1+i/I) 70 | if ((epoch-1)*I+i+1) % round(50*(6/args.batch_size)) == 0: 71 | losses['epoch'] = epoch-1+i/I 72 | log_wandb(losses) 73 | print ('Epoch [{}/{}], Step [{}/{}] Losses saved'.format(epoch, args.epoch,i,I)) 74 | if ((epoch-1)*I+i+1) % round(1000*(6/args.batch_size)) == 0: 75 | if model_refine is not None: 76 | model_refine.eval() 77 | else: 78 | featurizer.eval() 79 | start_eval = time.time() 80 | log_wandb_ram_usage() 81 | print ('Epoch [{}/{}], Step [{}/{}] Eval start'.format(epoch, args.epoch,i,I)) 82 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 83 | for eva in [evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3]: 84 | eva.epoch = epoch-1+i/I 85 | print("Epoch [{}/{}], Step [{}/{}] Eval on val set".format(epoch, args.epoch,i,I)) 86 | n_eval_imgs = 10 87 | 
evaluation_val.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val.dataset_test_pck.name) 88 | evaluation_val_gen.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen.dataset_test_pck.name) 89 | evaluation_val_gen2.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen2.dataset_test_pck.name) 90 | evaluation_val_gen3.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen3.dataset_test_pck.name) 91 | log_wandb_ram_usage() 92 | print ('Epoch [{}/{}], Step [{}/{}] Eval end'.format(epoch, args.epoch,i,I)) 93 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 94 | # process checkpoint 95 | suffix = evaluation_val.dataset_test_pck.name 96 | val_pck =evaluation_val.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}'] 97 | suffix = evaluation_val_gen.dataset_test_pck.name 98 | val_pck+=evaluation_val_gen.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}'] 99 | suffix = evaluation_val_gen2.dataset_test_pck.name 100 | val_pck+=evaluation_val_gen2.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}'] 101 | suffix = evaluation_val_gen3.dataset_test_pck.name 102 | val_pck+=evaluation_val_gen3.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}'] 103 | timediff_eval = time.time()-start_eval 104 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 105 | print(f"Epoch [{epoch}/{args.epoch}], Step [{i}/{I}] Eval time: {timediff_eval:.2f} s, PCK@0.1: {val_pck:.2f}") 106 | log_wandb({'eval_time':timediff_eval}) 107 | if val_pck> best_pck: 108 | best_pck = val_pck 109 | if model_refine is not None: 110 | save_model_refine(model_refine, args.feat_refine, "best") 111 | else: 112 | save_featurizer(featurizer, args.featurizer, "best_lora_weights") 113 | 114 | if model_refine is not None: 115 | model_refine.train() 116 | else: 117 | featurizer.train() 118 | if ((epoch-1)*I+i+1) % round(1000*(6/args.batch_size)) == 0: 119 | if model_refine is not None: 120 | model_refine.epoch = epoch-1+i/I 121 | save_model_refine(model_refine, args.feat_refine, "last") 122 | else: 123 | featurizer.epoch = epoch-1+i/I 124 | save_featurizer(featurizer, args.featurizer, "last_lora_weights") 125 | if model_refine is not None: 126 | save_model_refine(model_refine, args.feat_refine, "last") 127 | else: 128 | save_featurizer(featurizer, args.featurizer, "last_lora_weights") 129 | return best_pck 130 | 131 | @hydra.main(config_path="../configs") 132 | def main(args: DictConfig): 133 | cfg = OmegaConf.to_container(args) 134 | init_wandb(cfg, 'train_pairs') 135 | torch.cuda.set_device(0) 136 | # init the dataset 137 | dataset_list = [] 138 | for dataset_name in ['dataset', 'dataset2', 'dataset3', 'dataset4']: 139 | if dataset_name in args: 140 | dataset_args = args[dataset_name] 141 | if dataset_args.name == 'spair': 142 | dataset_list.append(SpairDataset2(dataset_args, split='train')) 143 | elif dataset_args.name == 'cub': 144 | dataset_list.append(CUBPairDataset(dataset_args, split='train')) 145 | elif dataset_args.name == 'ap10k': 146 | dataset_list.append(AP10KPairs(dataset_args, split='train')) 147 | elif dataset_args.name == 'pfpascal': 148 | dataset_list.append(PFPascalPairs(dataset_args, split='train')) 149 | # init the featurizer 150 | log_wandb_ram_usage() 151 | featurizer = get_featurizer(args.featurizer) 152 | log_wandb_ram_usage() 153 | # init the evaluation 154 | evaluation_test = Evaluation(args, featurizer) 155 | evaluation_train = 
Evaluation(args, featurizer, split='train') 156 | evaluation_val = Evaluation(args, featurizer, split='val') 157 | log_wandb_ram_usage() 158 | if args.dataset.cat == "all" or 'dataset2' in args: 159 | # avoid deep copy of the featurizer 160 | dataset_train_multi = get_multi_cat_dataset(dataset_list, featurizer, args.featurizer, model_seg_name=args.model_seg_name) 161 | train_loader = torch.utils.data.DataLoader(dataset_train_multi, batch_size=args.batch_size, shuffle=True) 162 | for dataset in dataset_train_multi.datasets: 163 | print(f"Number of pairs in dataset {dataset.name} {dataset.cat}: {len(dataset)}") 164 | else: 165 | # We train for each category separately 166 | dataset_list[-1].featurizer = featurizer 167 | dataset_list[-1].featurizer_kwargs = args.featurizer 168 | dataset_list[-1].init_kps_cat(args.dataset.cat) 169 | dataset_list[-1].model_seg_name = args.model_seg_name 170 | dataset_list[-1].return_masks = True 171 | train_loader = torch.utils.data.DataLoader(dataset_list[-1], batch_size=args.batch_size, shuffle=True) 172 | log_wandb_ram_usage() 173 | weights = {'pos':args.losses.pos, 'bin':args.losses.bin, 'neg':args.losses.neg, 'neg_fg_bkg':args.losses.neg_fg_bkg} 174 | log_wandb_cfg({'weights':weights}) 175 | # init the model 176 | ep=0 177 | if 'feat_refine' in args: 178 | model_refine = get_model_refine(args.feat_refine) 179 | log_wandb_ram_usage() 180 | model_refine = model_refine.to(device) 181 | model_refine = model_refine.train() 182 | log_wandb_ram_usage() 183 | # define the optimizer 184 | # optimizer = torch.optim.Adam(list(model_refine.parameters())+list(matcher.parameters()), lr=args.learning_rate) 185 | optimizer = torch.optim.AdamW(list(model_refine.parameters()), lr=args.learning_rate, weight_decay=args.weight_decay) 186 | else: 187 | featurizer = featurizer.to(device) 188 | featurizer = featurizer.train() 189 | model_refine = None 190 | lora_params = [p for p in featurizer.parameters() if p.requires_grad] 191 | optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate, weight_decay=args.weight_decay) 192 | # define the scheduler 193 | if args.scheduler is not None: 194 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.learning_rate, steps_per_epoch=len(dataset_list[-1])//args.batch_size, epochs=args.epoch, pct_start=0.3) 195 | else: 196 | scheduler = None 197 | # add evaluation for different datasets 198 | argsgen = copy.deepcopy(args) 199 | argsgen.dataset = argsgen.datasetgeneralization 200 | evaluation_val_gen = Evaluation(argsgen, featurizer, split='val') 201 | argsgen2 = copy.deepcopy(args) 202 | argsgen2.dataset = argsgen.datasetgeneralization2 203 | evaluation_val_gen2 = Evaluation(argsgen2, featurizer, split='val') 204 | argsgen3 = copy.deepcopy(args) 205 | argsgen3.dataset = argsgen.datasetgeneralization3 206 | evaluation_val_gen3 = Evaluation(argsgen3, featurizer, split='val') 207 | # define the loss function 208 | loss_fun = PairwiseLoss(args) 209 | # start training 210 | best_pck = 0 211 | for epoch in range(ep+1, args.epoch+1): 212 | # set seed 213 | set_seed(train_loader.dataset, args, epoch) 214 | 215 | best_pck = train_epoch(train_loader, model_refine, optimizer, weights, loss_fun, args, epoch, evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3, best_pck, featurizer=featurizer) 216 | if scheduler is not None: 217 | scheduler.step() 218 | # evaluate the best model 219 | if model_refine is not None: 220 | OmegaConf.update(args, 
"feat_refine.init.load_pretrained", True, force_add=True) 221 | OmegaConf.update(args, "feat_refine.init.id", model_refine.id, force_add=True) 222 | model_refine = get_model_refine(args.feat_refine) 223 | evaluation_test.reset_cat() 224 | evaluation_test.evaluate_pck(model_refine) 225 | finish_wandb() 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 2 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118 # this deletes the torch installation, so we need to reinstall it and cleanup the nvidia installs 3 | pip install jupyterlab 4 | pip install -U matplotlib 5 | pip install transformers 6 | pip install ipympl 7 | pip install triton 8 | pip install open_clip_torch # this deletes the torch installation, so we need to reinstall it and cleanup the nvidia installs 9 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 10 | pip install hydra-core 11 | pip install -U scikit-learn 12 | pip install pandas 13 | pip install wandb 14 | pip install POT 15 | pip install --upgrade diffusers[torch] 16 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' 17 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/__init__.py -------------------------------------------------------------------------------- /src/dataset/augmentations.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from torchvision import transforms 4 | import torchvision.transforms.functional as F 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class GaussianBlur(transforms.RandomApply): 10 | """ 11 | Apply Gaussian Blur to the PIL image. 
12 | """ 13 | 14 | def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): 15 | # NOTE: torchvision is applying 1 - probability to return the original image 16 | keep_p = 1 - p 17 | transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) 18 | super().__init__(transforms=[transform], p=keep_p) 19 | 20 | 21 | class DataAugmentation(object): 22 | def __init__( 23 | self, 24 | crops_scale=None, 25 | crops_size=None, 26 | KP_LEFT_RIGHT_PERMUTATION=None, 27 | KP_WITH_ORIENTATION=None, 28 | color_aug=True, 29 | flip_aug=True, 30 | ): 31 | self.KP_LEFT_RIGHT_PERMUTATION = KP_LEFT_RIGHT_PERMUTATION 32 | self.KP_WITH_ORIENTATION = KP_WITH_ORIENTATION 33 | if self.KP_WITH_ORIENTATION is not None: 34 | self.KP_WITH_ORIENTATION = torch.tensor(self.KP_WITH_ORIENTATION).flatten() 35 | self.flip_aug = flip_aug 36 | if crops_scale!=None: 37 | # random resized crop 38 | self.crops_size = crops_size 39 | self.geometric_augmentation_1 = transforms.RandomResizedCrop( 40 | crops_size, scale=crops_scale, ratio=(0.9, 1.1), 41 | ) 42 | else: 43 | self.crops_size = None 44 | 45 | # color distorsions / blurring 46 | if color_aug: 47 | color_jittering = transforms.Compose( 48 | [ 49 | transforms.RandomApply( 50 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 51 | p=0.8, 52 | ) 53 | ] 54 | ) 55 | 56 | colors = transforms.Compose( 57 | [ 58 | GaussianBlur(p=0.1), 59 | transforms.RandomSolarize(threshold=128, p=0.2), 60 | ] 61 | ) 62 | 63 | self.color_augmentation = transforms.Compose([color_jittering, colors]) 64 | else: 65 | self.color_augmentation = None 66 | 67 | def augment_mask(self, mask): 68 | 69 | # Horizontal flip 70 | if self.flip_aug: 71 | hflip = np.random.binomial(1, p=0.5) 72 | else: 73 | hflip = False 74 | # hflip = False 75 | if hflip: 76 | if mask is not None: 77 | mask = F.hflip(mask) 78 | 79 | if self.crops_size!=None: 80 | i, j, h, w = self.geometric_augmentation_1.get_params(mask, scale=list(self.geometric_augmentation_1.scale), ratio=list(self.geometric_augmentation_1.ratio)) 81 | # crop the image 82 | if mask is not None: 83 | mask = F.resized_crop(mask, i, j, h, w, [self.crops_size, self.crops_size]) 84 | return mask 85 | 86 | def __call__(self, image, data): 87 | #{'ft': ft, 'imsize': imsize, 'kps': keypoints, 'kps_symm_only': kps_symm_only, 'bndbox':bndbox, 'idx':idx} 88 | 89 | # Horizontal flip 90 | if self.flip_aug: 91 | hflip = np.random.binomial(1, p=0.5) 92 | else: 93 | hflip = False 94 | 95 | # hflip = False 96 | if hflip: 97 | image = F.hflip(image) 98 | # data['bndbox'] is a list of 4 elements [xmin, ymin, xmax, ymax] 99 | xmin = data['bndbox'][0] 100 | xmax = data['bndbox'][2] 101 | data['bndbox'][0] = data['imsize'][1].item()-xmax 102 | data['bndbox'][2] = data['imsize'][1].item()-xmin 103 | data['hflip'] = True 104 | else: 105 | data['hflip'] = False 106 | 107 | if hflip: 108 | if self.KP_WITH_ORIENTATION is not None: 109 | # Horizontal flip 110 | data['kps'][:, 1] = data['imsize'][1].item() - data['kps'][:, 1] - 1 111 | data['kps_symm_only'] = data['kps'].clone() 112 | 113 | data['kps_symm_only'][~self.KP_WITH_ORIENTATION, 2] = 0 114 | data['kps'][self.KP_WITH_ORIENTATION, 2] = 0.5 115 | elif self.KP_LEFT_RIGHT_PERMUTATION is not None: 116 | for key in ['kps','kps_symm_only']: 117 | data[key][:, 1] = data['imsize'][1].item() - data[key][:, 1] - 1 118 | data[key] = data[key][self.KP_LEFT_RIGHT_PERMUTATION] 119 | 120 | # Crop 121 | if self.crops_size!=None: 122 | i, j, h, w = 
self.geometric_augmentation_1.get_params(image, scale=list(self.geometric_augmentation_1.scale), ratio=list(self.geometric_augmentation_1.ratio)) 123 | # i,j are the top left corner of the crop (first element for top, second for left) 124 | # h,w are the height and width of the crop 125 | # crop the image 126 | image = F.resized_crop(image, i, j, h, w, [self.crops_size, self.crops_size]) 127 | data['imsize'] = torch.tensor((self.crops_size, self.crops_size)) 128 | # data['bndbox'] is a list of 4 elements [xmin, ymin, xmax, ymax] 129 | data['bndbox'] = data['bndbox'] - np.array([j, i, j, i]) 130 | # Scale the bounding box to the new size 131 | data['bndbox'] = data['bndbox'] * np.array([self.crops_size / w, self.crops_size / h, self.crops_size / w, self.crops_size / h]) 132 | 133 | 134 | for key in ['kps','kps_symm_only']: 135 | keypoints = data[key] 136 | visible = keypoints[:, 2]>0.5 137 | if self.crops_size!=None: 138 | # Crop 139 | keypoints[visible, 0] -= i # top (row offset) 140 | keypoints[visible, 1] -= j # left (column offset) 141 | keypoints[visible, 0] *= self.crops_size/h 142 | keypoints[visible, 1] *= self.crops_size/w 143 | keypoints[visible, 2] *= torch.bitwise_and((keypoints[visible, 0] < data['imsize'][0]) , (keypoints[visible, 1] < data['imsize'][1]) ).float() 144 | keypoints[visible, 2] *= torch.bitwise_and((keypoints[visible, 0] >= 0 ) , (keypoints[visible, 1] >= 0 )).float() 145 | data[key] = keypoints 146 | 147 | 148 | # Color augmentation 149 | if self.color_augmentation is not None: 150 | image = self.color_augmentation(image) 151 | 152 | return image, data 153 | -------------------------------------------------------------------------------- /src/dataset/cub_200_pairs.py: -------------------------------------------------------------------------------- 1 | from src.dataset.cub_200 import CUBDataset, CUBDatasetBordersCut, CUBDatasetAugmented 2 | from src.dataset.pairwise_utils import random_pairs_2 3 | from src.dataset.random_utils import use_seed 4 | import torch 5 | import numpy as np 6 | 7 | def CUBPairDataset(args, **kwargs): 8 | if args.sup == 'sup_original': 9 | if args.borders_cut: 10 | return CUBPairDatasetOrigBC(args, **kwargs) 11 | else: 12 | return CUBPairDatasetOrig(args, **kwargs) 13 | elif args.sup == 'sup_augmented': 14 | return CUBPairDatasetAugmentedPadded(args, **kwargs) 15 | else: 16 | raise ValueError(f"Unknown supervision type {args.sup}") 17 | 18 | class CUBPairDatasetOrig(CUBDataset): 19 | pck_symm = True 20 | 21 | def __init__(self, args, **kwargs): 22 | # init the parent 23 | super().__init__(args, **kwargs) 24 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data)) 25 | self.pairs = list(zip(idx0, idx1)) 26 | self.return_imgs = False 27 | self.return_masks = False 28 | assert(args.borders_cut == False) 29 | 30 | def __len__(self): 31 | return len(self.pairs) 32 | 33 | def get_imgs(self, idx): 34 | idx0, idx1 = self.pairs[idx] 35 | img0 = super().get_img(idx0) 36 | img1 = super().get_img(idx1) 37 | return img0, img1 38 | 39 | def get_masks(self, idx): 40 | idx0, idx1 = self.pairs[idx] 41 | mask0 = super().get_mask(idx0) 42 | mask1 = super().get_mask(idx1) 43 | return mask0, mask1 44 | 45 | def __getitem__(self, idx): 46 | idx0, idx1 = self.pairs[idx] 47 | data0 = super().__getitem__(idx0) 48 | data1 = super().__getitem__(idx1) 49 | data = {'src_imsize': data0['imsize'], 50 | 'trg_imsize': data1['imsize'], 51 | 'src_kps': data0['kps'], 52 | 'trg_kps': data1['kps'], 53 | 'trg_kps_symm_only': data1['kps_symm_only'], 54 | 'src_kps_symm_only': data0['kps_symm_only'], 55 | 
'src_bndbox': data0['bndbox'], 56 | 'trg_bndbox': data1['bndbox'], 57 | 'cat': self.cat, 58 | 'idx': idx} 59 | 60 | data['numkp'] = data0['kps'].shape[0] 61 | if self.return_feats: 62 | data['src_ft'] = data0['ft'] 63 | data['trg_ft'] = data1['ft'] 64 | if self.return_imgs: 65 | img0,img1 = self.get_imgs(idx) 66 | data['src_img'] = np.array(img0) 67 | data['trg_img'] = np.array(img1) 68 | if self.return_masks: 69 | mask0,mask1 = self.get_masks(idx) 70 | data['src_mask'] = mask0 71 | data['trg_mask'] = mask1 72 | return data 73 | 74 | class CUBPairDatasetOrigBC(CUBDatasetBordersCut): 75 | pck_symm = True 76 | 77 | def __init__(self, args, **kwargs): 78 | # init the parent 79 | super().__init__(args, **kwargs) 80 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data)) 81 | self.pairs = list(zip(idx0, idx1)) 82 | self.return_imgs = False 83 | self.return_masks = False 84 | 85 | def __len__(self): 86 | return len(self.pairs) 87 | 88 | def get_imgs(self, idx): 89 | idx0, idx1 = self.pairs[idx] 90 | img0 = super().get_img(idx0) 91 | img1 = super().get_img(idx1) 92 | return img0, img1 93 | 94 | def __getitem__(self, idx): 95 | idx0, idx1 = self.pairs[idx] 96 | data0 = super().__getitem__(idx0) 97 | data1 = super().__getitem__(idx1) 98 | data = {'src_imsize': data0['imsize'], 99 | 'trg_imsize': data1['imsize'], 100 | 'src_kps': data0['kps'], 101 | 'trg_kps': data1['kps'], 102 | 'trg_kps_symm_only': data1['kps_symm_only'], 103 | 'src_kps_symm_only': data0['kps_symm_only'], 104 | 'src_bndbox': data0['bndbox'], 105 | 'trg_bndbox': data1['bndbox'], 106 | 'cat': self.cat, 107 | 'idx': idx} 108 | if self.return_feats: 109 | data['src_ft'] = data0['ft'] 110 | data['trg_ft'] = data1['ft'] 111 | if self.return_imgs: 112 | img0,img1 = self.get_imgs(idx) 113 | data['src_img'] = np.array(img0) 114 | data['trg_img'] = np.array(img1) 115 | data['numkp'] = data0['kps'].shape[0] 116 | if self.return_masks: 117 | mask0,mask1 = self.get_masks(idx) 118 | data['src_mask'] = mask0 119 | data['trg_mask'] = mask1 120 | return data 121 | 122 | class CUBPairDatasetAugmented(CUBDatasetAugmented): 123 | pck_symm = True 124 | 125 | def __init__(self, args, **kwargs): 126 | # init the parent 127 | super().__init__(args, **kwargs) 128 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data)) 129 | self.pairs = list(zip(idx0, idx1)) 130 | self.seed_pairs = 1 # this can be changed for different epochs 131 | self.return_imgs = False 132 | self.return_masks = False 133 | 134 | def __len__(self): 135 | return len(self.pairs) 136 | 137 | def get_imgs(self, idx): 138 | idx0, idx1 = self.pairs[idx] 139 | with use_seed(idx+self.seed_pairs): 140 | seed_offset0 = torch.randint(1000000, (1,)).item() 141 | seed_offset1 = torch.randint(1000000, (1,)).item() 142 | self.seed = self.seed+seed_offset0 143 | img0 = super().get_img(idx0) 144 | self.seed = self.seed-seed_offset0 145 | self.seed = self.seed+seed_offset1 146 | img1 = super().get_img(idx1) 147 | self.seed = self.seed-seed_offset1 148 | return img0, img1 149 | 150 | def get_masks(self, idx): 151 | idx0, idx1 = self.pairs[idx] 152 | with use_seed(idx+self.seed_pairs): 153 | seed_offset0 = torch.randint(1000000, (1,)).item() 154 | seed_offset1 = torch.randint(1000000, (1,)).item() 155 | self.seed = self.seed+seed_offset0 156 | mask0 = super().get_mask(idx0) 157 | self.seed = self.seed-seed_offset0 158 | self.seed = self.seed+seed_offset1 159 | mask1 = super().get_mask(idx1) 160 | self.seed = self.seed-seed_offset1 161 | return mask0, mask1 162 | 163 | def __getitem__(self, idx): 
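        # Two deterministic seed offsets are derived from (idx + seed_pairs) below, so the
        # source and target sample of a pair receive different but reproducible augmentations;
        # changing seed_pairs (e.g. per epoch) re-randomizes the augmentations for every pair.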
164 | idx0, idx1 = self.pairs[idx] 165 | with use_seed(idx+self.seed_pairs): 166 | seed_offset0 = torch.randint(1000000, (1,)).item() 167 | seed_offset1 = torch.randint(1000000, (1,)).item() 168 | self.seed = self.seed+seed_offset0 169 | data0 = super().__getitem__(idx0) 170 | self.seed = self.seed-seed_offset0 171 | self.seed = self.seed+seed_offset1 172 | data1 = super().__getitem__(idx1) 173 | self.seed = self.seed-seed_offset1 174 | data = {'src_imsize': data0['imsize'], 175 | 'trg_imsize': data1['imsize'], 176 | 'src_kps': data0['kps'], 177 | 'trg_kps': data1['kps'], 178 | 'trg_kps_symm_only': data1['kps_symm_only'], 179 | 'src_kps_symm_only': data0['kps_symm_only'], 180 | 'src_bndbox': data0['bndbox'], 181 | 'trg_bndbox': data1['bndbox'], 182 | 'src_hflip': data0['hflip'], 183 | 'trg_hflip': data1['hflip'], 184 | 'cat': self.cat, 185 | 'idx': idx} 186 | 187 | data['numkp'] = len(self.KP_NAMES) 188 | 189 | if self.return_feats: 190 | data['src_ft'] = data0['ft'] 191 | data['trg_ft'] = data1['ft'] 192 | if self.return_imgs: 193 | img0,img1 = self.get_imgs(idx) 194 | data['src_img'] = np.array(img0) 195 | data['trg_img'] = np.array(img1) 196 | if self.return_masks: 197 | mask0,mask1 = self.get_masks(idx) 198 | data['src_mask'] = mask0 199 | data['trg_mask'] = mask1 200 | return data 201 | 202 | class CUBPairDatasetAugmentedPadded(CUBPairDatasetAugmented): 203 | def __init__(self, args, **kwargs): 204 | super().__init__(args, **kwargs) 205 | self.n_kps = 100 206 | self.num_el = args.num_el if 'num_el' in args else None 207 | 208 | def total_len(self): 209 | return len(self.pairs) 210 | 211 | def __len__(self): 212 | if self.num_el is None: 213 | return self.total_len() 214 | return min(self.num_el, self.total_len()) 215 | 216 | def pad_kps(self, kps): 217 | kps = kps.clone() 218 | pad = self.n_kps - kps.shape[0] 219 | if pad > 0: 220 | kps = torch.cat([kps, torch.zeros(pad, 3)], dim=0) 221 | return kps 222 | 223 | def init_kps_cat(self, cat): 224 | super().init_kps_cat(cat) 225 | # remove the shuffeled_idx_list 226 | if hasattr(self, 'shuffeled_idx_list'): 227 | delattr(self, 'shuffeled_idx_list') 228 | 229 | def __getitem__(self, idx): 230 | totallen = self.total_len() 231 | subsetlen = self.__len__() 232 | # create new idx list excluding the idx that have been used 233 | if not hasattr(self, 'shuffeled_idx_list'): 234 | with use_seed(self.seed_pairs): 235 | self.shuffeled_idx_list = np.random.permutation(totallen) 236 | elif len(self.shuffeled_idx_list) < self.seed_pairs*subsetlen: 237 | with use_seed(self.seed_pairs+1): 238 | self.shuffeled_idx_list = np.append(self.shuffeled_idx_list, np.random.permutation(totallen)) 239 | idx_ = self.shuffeled_idx_list[idx+(self.seed_pairs-1)*subsetlen] 240 | 241 | data_out = super().__getitem__(int(idx_)) 242 | # pad the keypoints 243 | for k in data_out.keys(): 244 | if 'kps' in k: 245 | data_out[k] = self.pad_kps(data_out[k]) 246 | return data_out 247 | -------------------------------------------------------------------------------- /src/dataset/pairwise_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.dataset.random_utils import use_seed 3 | import numpy as np 4 | from src.dataset.utils import to_flattened_idx_torch 5 | import copy 6 | def random_pairs(n_pairs,n_imgs,permute_first=True): 7 | if permute_first: 8 | idx0 = torch.randint(n_imgs, (n_pairs,)) 9 | else: 10 | idx0 = torch.arange(n_imgs).repeat(n_pairs//n_imgs) 11 | idx0 = idx0[:n_pairs] 12 | idx1 = 
torch.randint(n_imgs, (n_pairs,)) 13 | for n in range(n_pairs): 14 | while idx0[n]==idx1[n]: 15 | idx1[n] = torch.randint(n_imgs, (1,)) 16 | return idx0, idx1 17 | 18 | def random_pairs_2(n_pairs, n_imgs): 19 | idx0, idx1 =[], [] 20 | while len(idx0) < n_pairs: 21 | with use_seed(len(idx0) + 123): 22 | indices = np.random.permutation(np.arange(0, n_imgs)) 23 | middle = n_imgs // 2 24 | idx0, idx1 = idx0 + indices[:middle].tolist(), idx1 + indices[-middle:].tolist() 25 | idx0, idx1 = idx0[:n_pairs], idx1[:n_pairs] 26 | return idx0, idx1 27 | 28 | def get_pos_pairs(data_in, b=0): 29 | src_kps_symm_only = data_in['src_kps_symm_only'][b] 30 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b] 31 | src_kps = data_in['src_kps'][b] 32 | trg_kps = data_in['trg_kps'][b] 33 | vis_both = torch.bitwise_and(src_kps[:,2] > 0.5, trg_kps[:,2] > 0.5) 34 | pos_src_kps = src_kps[vis_both] 35 | pos_trg_kps = trg_kps[vis_both] 36 | flag_11_src = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, src_kps[:,2] > 0.5) 37 | flag_11_trg = torch.bitwise_and(trg_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5) 38 | flags_11 = {'pos_src_11': flag_11_src[vis_both], 'pos_trg_11': flag_11_trg[vis_both]} 39 | return pos_src_kps, pos_trg_kps, flags_11 40 | 41 | def get_neg_pairs(data_in, b=0): 42 | src_kps_symm_only = data_in['src_kps_symm_only'][b] 43 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b] 44 | src_kps = data_in['src_kps'][b] 45 | trg_kps = data_in['trg_kps'][b] 46 | # match to other points in src 47 | vis_both = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5) 48 | neg_src_kps = src_kps_symm_only.clone() 49 | neg_trg_kps = trg_kps.clone() 50 | neg_src_kps[:,2] = vis_both 51 | neg_trg_kps[:,2] = vis_both 52 | flag_11_src = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, src_kps[:,2] > 0.5) 53 | # match to other points in trg 54 | vis_both = torch.bitwise_and(src_kps[:,2] > 0.5, trg_kps_symm_only[:,2] > 0.5) 55 | neg_src_kps[vis_both] = src_kps[vis_both] 56 | neg_trg_kps[vis_both] = trg_kps_symm_only[vis_both] 57 | neg_trg_kps[:,2] = torch.bitwise_or(neg_src_kps[:,2] > 0.5, vis_both) 58 | flag_11_trg = torch.bitwise_and(trg_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5) 59 | # only keep the vis_both keypoints 60 | flags_11 = {'neg_src_11': flag_11_src[neg_src_kps[:,2] > 0.5], 'neg_trg_11': flag_11_trg[neg_trg_kps[:,2] > 0.5]} 61 | neg_src_kps = neg_src_kps[neg_src_kps[:,2] > 0.5] 62 | neg_trg_kps = neg_trg_kps[neg_trg_kps[:,2] > 0.5] 63 | return neg_src_kps, neg_trg_kps, flags_11 64 | 65 | def get_bin_pairs(data_in, b=0): 66 | src_kps_symm_only = data_in['src_kps_symm_only'][b] 67 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b] 68 | src_kps = data_in['src_kps'][b] 69 | trg_kps = data_in['trg_kps'][b] 70 | occluded = torch.bitwise_xor(src_kps[:,2] == 0, trg_kps[:,2] == 0) 71 | vis_both = torch.bitwise_or(src_kps[:,2] > 0.5, trg_kps[:,2] > 0.5) 72 | bin_flag = torch.bitwise_and(occluded, vis_both) 73 | bin_src_kps = src_kps[bin_flag] 74 | bin_trg_kps = trg_kps[bin_flag] 75 | flag = bin_flag[bin_flag] 76 | flags_11 = {'bin_src_11': flag, 'bin_trg_11': flag} 77 | return bin_src_kps, bin_trg_kps, flags_11 78 | 79 | def get_matches(data_in, b=0): 80 | pos_src_kps, pos_trg_kps, pos_flag_11 = get_pos_pairs(data_in, b) 81 | neg_src_kps, neg_trg_kps, neg_flag_11 = get_neg_pairs(data_in, b) 82 | bin_src_kps, bin_trg_kps, bin_flag_11 = get_bin_pairs(data_in, b) 83 | # put them in a dictionary 84 | data_out = {'pos_src_kps': pos_src_kps, 85 | 'pos_trg_kps': pos_trg_kps, 86 | 'neg_src_kps': neg_src_kps, 
87 | 'neg_trg_kps': neg_trg_kps, 88 | 'bin_src_kps': bin_src_kps, 89 | 'bin_trg_kps': bin_trg_kps} 90 | data_out.update(pos_flag_11) 91 | data_out.update(neg_flag_11) 92 | data_out.update(bin_flag_11) 93 | return data_out 94 | 95 | def scale_to_feature_dims(data_out, data_in, b=0): 96 | scale_src = torch.tensor(data_in['src_ft'].shape[-2:])/data_in['src_imsize'][b]# get the scale for the feature image coordinates to the original image coordinates 97 | scale_trg = torch.tensor(data_in['trg_ft'].shape[-2:])/data_in['trg_imsize'][b] 98 | 99 | data_out_ft = copy.deepcopy(data_out) 100 | for prefix in ['pos', 'bin', 'neg']: 101 | src_kps = data_out[prefix+'_src_kps'].clone() 102 | trg_kps = data_out[prefix+'_trg_kps'].clone() 103 | # convert to feature image coordinates 104 | src_kps[:,:2] = (src_kps[:,:2]*scale_src).floor().long() 105 | trg_kps[:,:2] = (trg_kps[:,:2]*scale_trg).floor().long() 106 | data_out_ft[prefix+'_src_kps'] = src_kps 107 | data_out_ft[prefix+'_trg_kps'] = trg_kps 108 | return data_out_ft 109 | 110 | def get_y_mat_gt_assignment(kp0, kp1, ft_size0, ft_size1): 111 | size = (ft_size0.prod()+1, ft_size1.prod()+1) 112 | if len(kp0) > 0: 113 | # enter dummy values for the keypoints that are not vis_both to avoid error in assert of flattened index function 114 | kp0[kp0[:,2]==0,0] = 0 115 | kp0[kp0[:,2]==0,1] = 0 116 | kp1[kp1[:,2]==0,0] = 0 117 | kp1[kp1[:,2]==0,1] = 0 118 | kp0_idx = to_flattened_idx_torch(kp0[:,0], kp0[:,1], ft_size0[0].item(), ft_size0[1].item()) 119 | kp1_idx = to_flattened_idx_torch(kp1[:,0], kp1[:,1], ft_size1[0].item(), ft_size1[1].item()) 120 | 121 | kp0_idx[kp0[:,2]==0] = ft_size0.prod() # bin idx 122 | kp1_idx[kp1[:,2]==0] = ft_size1.prod() # bin idx 123 | 124 | indices = torch.stack([kp0_idx, kp1_idx]).unique(dim=1, sorted=False) 125 | values = torch.ones_like(indices[0]).float() 126 | matrix = torch.sparse_coo_tensor(indices, values, size).coalesce() 127 | 128 | else: 129 | # if no keypoints are annotated, return a matrix with only zeros 130 | matrix = torch.sparse_coo_tensor(torch.empty((2, 0), dtype=torch.long), [], size).coalesce() 131 | return matrix -------------------------------------------------------------------------------- /src/dataset/pascalparts.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data.dataset import Dataset as TorchDataset 6 | from scipy.io import loadmat 7 | import os 8 | from tqdm import tqdm 9 | from src.dataset.pascalparts_part2ind import part2ind 10 | from typing import Optional, Any 11 | OBJECT_CLASSES = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat', 12 | 'chair','cow','table','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor'] 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | class PascalParts(TorchDataset): 16 | name = 'pascalparts' 17 | name_this = 'pascalparts' 18 | 19 | def __init__(self, args, split='test'): 20 | self.cat = None 21 | self.root = args.dataset_path 22 | # join the path 23 | 24 | self.root_annotations = os.path.join(self.root,'Parts/Annotations_Part/') 25 | self.root_split = os.path.join(self.root,'VOCdevkit/VOC2010/ImageSets/Main/') 26 | self.root_imgs = os.path.join(self.root,'VOCdevkit/VOC2010/JPEGImages/') 27 | self.save_path = args.save_path 28 | self.split = split 29 | 30 | self.featurizer_name = None 31 | # Initialize attributes that may be set externally 32 | self.featurizer: 
Optional[Any] = None 33 | self.featurizer_kwargs: Optional[Any] = None 34 | self.model_seg: Optional[Any] = None 35 | self.model_seg_name: Optional[str] = None 36 | self.return_masks: bool = False 37 | 38 | self.all_cats = OBJECT_CLASSES 39 | self.all_cats_multipart = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'person'] # categories with multiple parts that are annotated in many images 40 | 41 | try: 42 | self.data = torch.load(os.path.join(self.save_path, 'pascalparts.pth')) 43 | except: 44 | self.structure_mat() 45 | self.data = torch.load(os.path.join(self.save_path, 'pascalparts.pth')) 46 | 47 | def init_kps_cat(self, cat): 48 | self.cat = cat 49 | cat_idx = OBJECT_CLASSES.index(self.cat)+1 # the index of the category in the pascal parts dataset 50 | part_dict, KP_LEFT_RIGHT_PERMUTATION, KP_WITH_ORIENTATION = part2ind(cat_idx) # get the part dictionary for the category 51 | 52 | if KP_LEFT_RIGHT_PERMUTATION is not None: 53 | self.KP_WITH_ORIENTATION = KP_WITH_ORIENTATION 54 | self.KP_LEFT_RIGHT_PERMUTATION = KP_LEFT_RIGHT_PERMUTATION 55 | else: 56 | self.KP_WITH_ORIENTATION = None 57 | self.KP_LEFT_RIGHT_PERMUTATION = None 58 | 59 | if len(part_dict.values()) == 0: 60 | part_dict = {'fg': 1} 61 | 62 | self.KP_NAMES = [k for k, v in sorted(part_dict.items(), key=lambda item: item[1])] 63 | 64 | def structure_mat(self): 65 | data = {} 66 | mat_filenames = os.listdir(self.root_annotations) 67 | for cat in OBJECT_CLASSES: 68 | data[cat] = [] 69 | 70 | for idx, annotation_filename in enumerate(mat_filenames): 71 | mat = loadmat(os.path.join(self.root_annotations, annotation_filename), struct_as_record=False, squeeze_me=True)['anno'] 72 | obj = mat.__dict__['objects'] 73 | # check if obj is array, we only consider images with one object 74 | if isinstance(obj, np.ndarray): 75 | continue 76 | else: 77 | mat_cat = obj.__dict__['class'] 78 | data[mat_cat].append(mat.__dict__['imname']) 79 | torch.save(data, os.path.join(self.save_path, 'pascalparts.pth')) 80 | 81 | def __len__(self): 82 | return len(self.data[self.cat]) 83 | 84 | def _get_feat(self, idx, featurizer, featurizer_kwargs): 85 | cat = self.cat 86 | img = self.get_img(idx) 87 | feat = featurizer.forward(img, 88 | category=cat, 89 | **featurizer_kwargs) 90 | return feat 91 | 92 | def store_feats(self, featurizer, overwrite, featurizer_kwargs): 93 | assert(self.name == self.name_this) 94 | print("saving all %s images' features..."%self.split) 95 | self.imsizes = {} 96 | path = os.path.join(self.save_path, self.name_this, featurizer.name) 97 | for cat in tqdm(self.all_cats): 98 | self.init_kps_cat(cat) 99 | for idx in range(len(self)): 100 | ft_name = self.data[self.cat][idx]+'.pth' 101 | if os.path.exists(os.path.join(path, ft_name)) and not overwrite: 102 | continue 103 | feat = self._get_feat(idx, featurizer, featurizer_kwargs) 104 | # make directory of parent directory 105 | os.makedirs(os.path.join(path, os.path.dirname(ft_name)), exist_ok=True) 106 | torch.save(feat.detach().cpu(), os.path.join(path, ft_name)) 107 | del feat 108 | torch.cuda.empty_cache() 109 | 110 | def get_img(self, idx): 111 | imname = self.data[self.cat][idx]+'.jpg' 112 | img_path = os.path.join(self.root_imgs, imname) 113 | img = Image.open(img_path).convert('RGB') 114 | return img 115 | 116 | def get_feat(self, idx): 117 | assert(self.featurizer_name is not None) 118 | assert(self.cat is not None) 119 | ft_name = self.data[self.cat][idx]+'.pth' 120 | ft = torch.load(os.path.join(self.save_path, self.name_this, self.featurizer_name, ft_name)).to(device) 
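        # Features are precomputed per image by store_feats() and loaded from
        # <save_path>/pascalparts/<featurizer_name>/<imname>.pth; if the file is missing,
        # __getitem__ falls back to computing them on the fly via _get_feat().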
121 | return ft 122 | 123 | 124 | def get_parts(self, idx): 125 | ''' 126 | Output: 127 | part_dict: dictionary with part names as keys and part indices as values 128 | parts_mask: tensor of shape (num_parts, H, W) with one hot encoding of the parts, 129 | where num_parts is the number of parts in the category without the background, 130 | i.e. part_mask.sum(dim=0) should be 0 for the background 131 | 132 | ''' 133 | name = self.data[self.cat][idx] 134 | mat = loadmat(os.path.join(self.root_annotations, name+'.mat'), struct_as_record=False, squeeze_me=True)['anno'] 135 | # check if there is only one object in the image 136 | if isinstance(mat.__dict__['objects'], np.ndarray): 137 | raise ValueError('Multiple objects in image') 138 | obj_mask = mat.__dict__['objects'].__dict__['mask'] 139 | assert(self.cat is not None) 140 | cat_idx = OBJECT_CLASSES.index(self.cat)+1 # the index of the category in the pascal parts dataset 141 | part_dict, _, _ = part2ind(cat_idx) # get the part dictionary for the category 142 | if len(part_dict.values()) == 0: 143 | num_parts = 1 144 | parts_mask = torch.zeros(num_parts, obj_mask.shape[0], obj_mask.shape[1]) 145 | parts_mask[0, :, :] = torch.tensor(obj_mask) 146 | part_dict = {'fg': 1} 147 | else: 148 | num_parts = max(part_dict.values()) # not really the number of parts, as part indices are not continuous 149 | parts_mask = torch.zeros(num_parts, obj_mask.shape[0], obj_mask.shape[1]) 150 | # check if there is only one part in the object 151 | parts = mat.__dict__['objects'].__dict__['parts'] 152 | if not isinstance(parts, np.ndarray): 153 | parts = [parts] 154 | for part in parts: 155 | part_idx = part_dict[part.__dict__['part_name']]-1 156 | parts_mask[part_idx, :, :] = torch.tensor(part.__dict__['mask']) 157 | 158 | def one_hot_parts(part_mask): 159 | # just take the occurence with highest value in dimension 1, i.e. 
the part with the highest index, pascalparts has multiple labels per pixel and we want to take the one with the highest index 160 | bkg = (part_mask.sum(dim=0, keepdim=True)==0).repeat(part_mask.shape[0],1,1) 161 | prt_reverse = torch.flip(part_mask, [0]) 162 | label = part_mask.shape[0] - prt_reverse.argmax(dim=0, keepdim=True) - 1 # undo the reverse 163 | prt = torch.zeros_like(part_mask) 164 | prt = prt.scatter_(0, label, 1) 165 | prt[bkg]=0 166 | return prt 167 | 168 | parts_mask = one_hot_parts(parts_mask) 169 | 170 | num_parts = max(part_dict.values()) # not really the number of parts, as part indices are not continuous 171 | self.parts = torch.unique( torch.tensor(list(part_dict.values())) -1) 172 | num_parts_ = len(self.parts) 173 | 174 | if num_parts_ != num_parts: 175 | parts_mask = parts_mask[self.parts] 176 | 177 | return part_dict, parts_mask 178 | 179 | def __getitem__(self, idx): 180 | part_dict, parts_mask = self.get_parts(idx) 181 | try: 182 | ft = self.get_feat(idx)[0] 183 | except: 184 | assert(self.featurizer is not None) 185 | assert(self.featurizer_kwargs is not None) 186 | ft = self._get_feat(idx, self.featurizer, self.featurizer_kwargs)[0] 187 | imsize = torch.tensor((parts_mask.shape[-2], parts_mask.shape[-1])) 188 | return {'ft': ft, 'imsize': imsize, 'parts_mask': parts_mask, 'part_dict': part_dict, 'idx': idx} 189 | -------------------------------------------------------------------------------- /src/dataset/random_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import wraps 3 | 4 | 5 | from numpy.random import seed as np_seed 6 | from numpy.random import get_state as np_get_state 7 | from numpy.random import set_state as np_set_state 8 | from random import seed as rand_seed 9 | from random import getstate as rand_get_state 10 | from random import setstate as rand_set_state 11 | import torch 12 | from torch import manual_seed as torch_seed 13 | from torch import get_rng_state as torch_get_state 14 | from torch import set_rng_state as torch_set_state 15 | 16 | class use_seed: 17 | def __init__(self, seed=None): 18 | if seed is not None: 19 | assert isinstance(seed, int) and seed >= 0 20 | self.seed = seed 21 | 22 | def __enter__(self): 23 | if self.seed is not None: 24 | self.rand_state = rand_get_state() 25 | self.np_state = np_get_state() 26 | self.torch_state = torch_get_state() 27 | self.torch_cudnn_deterministic = torch.backends.cudnn.deterministic 28 | rand_seed(self.seed) 29 | np_seed(self.seed) 30 | torch_seed(self.seed) 31 | torch.backends.cudnn.deterministic = True 32 | return self 33 | 34 | def __exit__(self, typ, val, _traceback): 35 | if self.seed is not None: 36 | rand_set_state(self.rand_state) 37 | np_set_state(self.np_state) 38 | torch_set_state(self.torch_state) 39 | torch.backends.cudnn.deterministic = self.torch_cudnn_deterministic 40 | 41 | def __call__(self, f): 42 | @wraps(f) 43 | def wrapper(*args, **kw): 44 | seed = self.seed if self.seed is not None else kw.pop('seed', None) 45 | with use_seed(seed): 46 | return f(*args, **kw) 47 | 48 | return wrapper 49 | -------------------------------------------------------------------------------- /src/dataset/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import copy 6 | from torch.utils.data import ConcatDataset 7 | 8 | @torch.no_grad() 9 | def get_feats(dataset, idx, model_refine=None, only_fg=False, 
feat_interpol=False, prt_interpol=True): 10 | try: 11 | ft = dataset[idx]['ft'][None] 12 | except: 13 | ft = dataset._get_feat(idx, dataset.featurizer, dataset.featurizer_kwargs) 14 | if model_refine is not None: 15 | ft = model_refine(ft) 16 | feats = ft.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, C 17 | 18 | if 'parts_mask' in dataset[idx].keys(): 19 | prt = dataset[idx]['parts_mask'][None] 20 | ftsize = torch.tensor(ft.shape[-2:]) 21 | prtsize = torch.tensor(prt.shape[-2:]) 22 | if feat_interpol: 23 | ft_interp = F.interpolate(ft, size=prtsize.tolist(), mode='bilinear', align_corners=False) 24 | feats = ft_interp.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, C 25 | if prt_interpol: 26 | prt_interp = F.interpolate(prt, size=ftsize.tolist(), mode='bilinear', align_corners=False) 27 | parts = prt_interp.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, num_parts 28 | else: 29 | parts = prt.permute(0,2,3,1).flatten(0,-2) 30 | 31 | if only_fg: 32 | mask = parts.sum(-1)>0 33 | feats = feats[mask] 34 | parts = parts[mask] 35 | 36 | elif 'kps' in dataset[idx].keys() and only_fg: 37 | 38 | # Extract keypoint information 39 | kp = dataset[idx]['kps'] 40 | imgsize = dataset[idx]['imsize'] 41 | 42 | 43 | # Filter keypoints based on visibility 44 | visible_mask = kp[:, 2] > 0.5 45 | kp_ = kp[visible_mask, :2] 46 | 47 | # Normalize keypoint coordinates to range [-1, 1] for grid sampling 48 | h, w = ft.shape[2:] 49 | kp_normalized = torch.stack([ 50 | 2 * kp_[:, 1] / (imgsize[1] - 1) - 1, 51 | 2 * kp_[:, 0] / (imgsize[0] - 1) - 1 52 | ], dim=-1).unsqueeze(0).unsqueeze(0).to(ft.device) # Shape: (1, num_kps, 2) 53 | 54 | # Use grid_sample to directly sample features at keypoint locations 55 | # Ensure kp_normalized and ft are on the same device 56 | feats = F.grid_sample(ft, kp_normalized, mode='bilinear', align_corners=False) 57 | feats = feats.squeeze(0).squeeze(1).t() # Shape: (num_vis_kps, C) 58 | 59 | parts = torch.eye(kp.shape[0])[visible_mask] # num_vis_kps, num_parts 60 | else: 61 | parts = None 62 | 63 | return feats, parts 64 | 65 | @torch.no_grad() 66 | def get_featpairs(dataset, idx, model_refine=None, only_fg=False, feat_interpol=False, prt_interpol=True): 67 | try: 68 | ft0 = dataset[idx]['src_ft'][None] 69 | ft1 = dataset[idx]['trg_ft'][None] 70 | except: 71 | raise ValueError('Need to have src_ft and tgt_ft in dataset') 72 | if model_refine is not None: 73 | ft0 = model_refine(ft0) 74 | ft1 = model_refine(ft1) 75 | feats = [ft0.permute(0,2,3,1).flatten(0,-2), ft1.permute(0,2,3,1).flatten(0,-2)]# B, C, H, W -> H*W, C 76 | parts = [] 77 | for i,(ft, prefix) in enumerate(zip([ft0, ft1], ['src', 'trg'])): 78 | if prefix+'_kps' in dataset[idx].keys(): 79 | 80 | kp = dataset[idx][prefix+'_kps'] 81 | imgsize = dataset[idx][prefix+'_imsize'] 82 | 83 | # Filter keypoints based on visibility 84 | visible_mask = kp[:, 2] > 0.5 85 | kp_ = kp[visible_mask, :2] 86 | 87 | # Normalize keypoint coordinates to range [-1, 1] for grid sampling 88 | h, w = ft.shape[2:] 89 | kp_normalized = torch.stack([ 90 | 2 * kp_[:, 1] / (imgsize[1] - 1) - 1, 91 | 2 * kp_[:, 0] / (imgsize[0] - 1) - 1 92 | ], dim=-1).unsqueeze(0).unsqueeze(0).to(ft.device) # Shape: (1, num_kps, 2) 93 | 94 | # Use grid_sample to directly sample features at keypoint locations 95 | # Ensure kp_normalized and ft are on the same device 96 | feats[i] = F.grid_sample(ft, kp_normalized, mode='bilinear', align_corners=False) 97 | feats[i] = feats[i].squeeze(0).squeeze(1).t() # Shape: (num_vis_kps, C) 98 | 99 | # free up memory 100 | 
torch.cuda.empty_cache() 101 | parts.append(torch.eye(kp.shape[0])[visible_mask]) 102 | else: 103 | parts.append(None) 104 | 105 | return feats[0], parts[0], feats[1], parts[1] 106 | 107 | @torch.no_grad() 108 | def get_init_feats_and_labels(dataset, N, model_refine=None, feat_interpol=False, prt_interpol=True, only_fg=False): 109 | # use seeds to make sure the same samples are used for all models 110 | np.random.seed(0) 111 | indices = np.random.permutation(np.arange(0, len(dataset)))[:N] 112 | feats = [] 113 | parts = [] 114 | for idx in indices: 115 | feats_idx, parts_idx = get_feats(dataset, idx, model_refine=model_refine, only_fg=only_fg, feat_interpol=feat_interpol, prt_interpol=prt_interpol) 116 | feats.append(feats_idx) 117 | parts.append(parts_idx) 118 | feats = torch.cat(feats) 119 | parts = torch.cat(parts) 120 | return feats, parts 121 | 122 | def to_flattened_idx_torch(x, y, x_width, y_width): 123 | idx = (y.round() + x.round()*y_width).int() 124 | x_, y_ = torch.unravel_index(idx, (x_width, y_width)) 125 | # assert that all the indices are the original ones (up to 0.5) 126 | assert torch.all(torch.abs(x_-x)<0.5) 127 | assert torch.all(torch.abs(y_-y)<0.5) 128 | return idx 129 | 130 | 131 | def get_multi_cat_dataset(dataset_list, featurizer, featurizer_kwargs, cat_list_=None, model_seg_name=None): 132 | datasets = [] 133 | for dataset in dataset_list: 134 | cat_list = dataset.all_cats if cat_list_ is None else cat_list_ 135 | for cat in cat_list: 136 | dataset.padding_kps = True 137 | dataset.init_kps_cat(cat) 138 | datasets.append(copy.deepcopy(dataset)) 139 | datasets[-1].featurizer = featurizer 140 | datasets[-1].featurizer_kwargs = featurizer_kwargs 141 | datasets[-1].model_seg_name = model_seg_name 142 | datasets[-1].return_masks = True 143 | return ConcatDataset(datasets) 144 | 145 | # def get_multi_cat_dataset(dataset_train, cat_list, featurizer, featurizer_kwargs): 146 | # datasets = [] 147 | # for cat in cat_list: 148 | # dataset_train.padding_kps = True 149 | # dataset_train.init_kps_cat(cat) 150 | # datasets.append(copy.deepcopy(dataset_train)) 151 | # datasets[-1].featurizer = featurizer 152 | # datasets[-1].featurizer_kwargs = featurizer_kwargs 153 | # return ConcatDataset(datasets) 154 | 155 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | STOREBASEPATH = '/storage/group/cvpr/regine/dataset_matching_symm/figures_new' -------------------------------------------------------------------------------- /src/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from src.evaluation.pck import evaluate as evaluate_pck_cat 4 | from src.evaluation.pck import get_per_point_pck 5 | from src.evaluation.segmentation import evaluate as evaluate_seg_cat 6 | from src.logging.log_results import log_wandb 7 | 8 | from src.dataset.cub_200_pairs import CUBPairDataset 9 | from src.dataset.apk_pairs import AP10KPairs 10 | from src.dataset.spair import SpairDataset2 11 | from src.dataset.pfpascal_pairs import PFPascalPairsOrig 12 | from src.dataset.pascalparts import PascalParts 13 | import copy 14 | 15 | from src.models.featurizer.utils import get_featurizer_name, get_featurizer 16 | class Evaluation(): 17 | def __init__(self, args, featurizer=None, split='test') -> None: 18 | self.epoch = None 19 | self.args = copy.deepcopy(args) 20 | self.split 
= split 21 | torch.cuda.set_device(0) 22 | self.args.dataset.sup = "sup_original" 23 | if self.args.dataset.name == 'spair': 24 | self.dataset_test_pck = SpairDataset2(self.args.dataset, split=split) 25 | precomputed = False 26 | elif self.args.dataset.name == 'cub': 27 | self.args.dataset.borders_cut = False 28 | self.dataset_test_pck = CUBPairDataset(self.args.dataset, split=split) 29 | self.args.dataset.borders_cut = True 30 | precomputed = False 31 | elif self.args.dataset.name == 'ap10k': 32 | if hasattr(self.args, 'upsample'): 33 | self.args.upsample = False 34 | self.dataset_test_pck = AP10KPairs(self.args.dataset, split=split) 35 | precomputed = False 36 | elif self.args.dataset.name == 'pfpascal': 37 | self.dataset_test_pck = PFPascalPairsOrig(self.args.dataset, split=split) 38 | precomputed = False 39 | else: 40 | self.dataset_test_pck = None 41 | self.dataset_test_ot_pairs = None 42 | 43 | if self.args.dataset.name == 'pascalparts': 44 | self.dataset_test_segmentation = PascalParts(self.args.dataset, split="test") 45 | self.dataset_train_segmentation = PascalParts(self.args.dataset, split="train") 46 | precomputed = False 47 | else: 48 | self.dataset_test_segmentation = None 49 | self.dataset_train_segmentation = None 50 | 51 | self.init_featurizer(precomputed, featurizer) 52 | self.reset_cat() 53 | 54 | def init_featurizer(self, precomputed, featurizer): 55 | if featurizer is not None: 56 | if self.args.featurizer.model == 'dift_sd': 57 | self.args.featurizer.all_cats = self.dataset_test_pck.all_cats 58 | featurizer_name = featurizer.name 59 | elif not precomputed: 60 | featurizer = get_featurizer(self.args.featurizer) 61 | featurizer_name = featurizer.name 62 | else: 63 | featurizer_name = get_featurizer_name(self.args.featurizer) 64 | # test datasets 65 | for dataset in [self.dataset_test_pck, self.dataset_test_segmentation, self.dataset_train_segmentation]: 66 | if dataset is not None: 67 | if not precomputed: 68 | dataset.featurizer_name = featurizer_name 69 | dataset.featurizer = featurizer 70 | dataset.featurizer_kwargs = self.args.featurizer 71 | else: 72 | dataset.featurizer_name = featurizer_name 73 | 74 | def reset_cat(self): 75 | for dataset in [ self.dataset_test_pck, self.dataset_test_segmentation, self.dataset_train_segmentation]: 76 | if dataset is not None: 77 | dataset.all_cats_eval = dataset.all_cats 78 | 79 | def add_suffix(self, result, suffix): 80 | return {k+suffix:v for k,v in result.items()} 81 | 82 | @torch.no_grad() 83 | def evaluate_pck(self, model_refine=None, n_pairs_eval_pck=None, suffix=''): 84 | results = {} 85 | if self.dataset_test_pck is None: 86 | return results 87 | print("evaluate PCK...") 88 | 89 | if n_pairs_eval_pck is None: 90 | n_pairs = self.args.n_pairs_eval_pck 91 | else: 92 | n_pairs = n_pairs_eval_pck 93 | 94 | # per point PCK (over all categories) 95 | n_total = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0,'11_underline': 0, '01_overline': 0, '11_tilde': 0} 96 | n_total_10 = n_total.copy() 97 | n_total_05 = n_total.copy() 98 | n_total_15 = n_total.copy() 99 | 100 | for cat in self.dataset_test_pck.all_cats_eval: 101 | print(f'... 
for cat: {cat}') 102 | # evaluate the PCK for each category 103 | self.dataset_test_pck.init_kps_cat(cat) 104 | if model_refine is not None: 105 | model_refine = model_refine.eval() 106 | cat_results, n_total_05_cat, n_total_10_cat, n_total_15_cat = evaluate_pck_cat(cat, self.dataset_test_pck, self.args.upsample, self.args.alpha_bbox, n_pairs=n_pairs, model_refine=model_refine) 107 | results.update(cat_results) 108 | 109 | # per point PCK (over all categories) 110 | for key in n_total.keys(): 111 | n_total_05[key] += n_total_05_cat[key] 112 | n_total_10[key] += n_total_10_cat[key] 113 | n_total_15[key] += n_total_15_cat[key] 114 | 115 | # per point PCK (over all categories) 116 | for n_total, alph in zip([n_total_10, n_total_05, n_total_15], ['0.1', '0.05', '0.15']): 117 | results_alph = get_per_point_pck(n_total, '', alph) 118 | results.update(results_alph) 119 | 120 | if self.split in ['train', 'val']: 121 | results = self.add_suffix(results, f'_{self.split}') 122 | if n_pairs_eval_pck is not None: 123 | results = self.add_suffix(results, f'_{n_pairs}pairs') 124 | if suffix != '': 125 | results = self.add_suffix(results, f'_{suffix}') 126 | if self.epoch is not None: 127 | results['epoch'] = self.epoch 128 | log_wandb(results) 129 | self.results_pck = results 130 | 131 | @torch.no_grad() 132 | def evaluate_seg(self, model_refine=None, suffix=''): 133 | results, results_mean = {}, {} 134 | if self.dataset_test_segmentation is None: 135 | return results 136 | print("evaluate Segmentation...") 137 | modeldict = {'model_refine': model_refine, 'dataset_test': self.dataset_test_segmentation, 'dataset_train': self.dataset_train_segmentation} 138 | for cat in self.dataset_test_segmentation.all_cats_eval: 139 | print(f'... for cat: {cat}') 140 | self.dataset_test_segmentation.init_kps_cat(cat) 141 | self.dataset_train_segmentation.init_kps_cat(cat) 142 | if model_refine is not None: 143 | model_refine = model_refine.eval() 144 | cat_results = evaluate_seg_cat(self.args, modeldict) 145 | for k,v in cat_results.items(): 146 | if not k in results_mean.keys(): 147 | results_mean[k] = [] 148 | results_mean[k].append(v) 149 | results.update({'segmentation'+cat+'_'+k: v for k, v in cat_results.items()}) 150 | 151 | for k,v in results_mean.items(): 152 | results_mean[k] = torch.tensor(v).mean().numpy() 153 | results.update({'segmentation_mean_'+k: v for k, v in results_mean.items()}) 154 | if self.split in ['train', 'val']: 155 | results = self.add_suffix(results, f'_{self.split}') 156 | if suffix != '': 157 | results = self.add_suffix(results, f'_{suffix}') 158 | if self.epoch is not None: 159 | results['epoch'] = self.epoch 160 | log_wandb(results) 161 | -------------------------------------------------------------------------------- /src/evaluation/pck.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | from src.matcher.argmaxmatcher import ArgmaxMatcher 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | import copy 7 | from src.evaluation import STOREBASEPATH 8 | NUM_VIZ = 15 9 | from src.logging.visualization_pck import plot_src, plot_trg 10 | 11 | def viz_pck_alpha(dataset, idx, data, src_to_trg_point, heatmap_, store_path, count_viz, key): 12 | if key == '00': 13 | return count_viz 14 | if count_viz[key] 0 else 0 101 | results['per point PCK@'+alph+' \hat{n}_11/n_11' + cat] = n_total['11_hat'] / n_total['11'] * 100 if n_total['11'] > 0 else 0 102 | results['per point 
PCK@'+alph+' \hat{n}_1x/n_1x' + cat] = n_total['1x_hat'] / n_total['1x'] * 100 if n_total['1x'] > 0 else 0 103 | results['per point PCK@'+alph+' \overline{n}_11/n_11 ' + cat] = n_total['11_overline'] / n_total['11'] * 100 if n_total['11'] > 0 else 0 104 | results['per point PCK@'+alph+' \tilde{n}_11/n_11 ' + cat] = n_total['11_tilde'] / n_total['11'] * 100 if n_total['11'] > 0 else 0 105 | # results['per point PCK@'+alph+' \overline{n}_01/n_01 ' + cat] = n_total['01_overline'] / n_total['01'] * 100 if n_total['01'] > 0 else 0 106 | return results 107 | 108 | @torch.no_grad() 109 | def evaluate(cat, dataset, upsample, bbox, n_pairs, model_refine=None, path=STOREBASEPATH+'/05_experiments/', visualize=False): 110 | if model_refine is not None: 111 | store_path = path+'pck/'+dataset.name+'/'+cat+'/'+model_refine.id+'/' 112 | else: 113 | store_path = path+'pck/'+dataset.name+'/'+cat+'/'+dataset.featurizer_name+'/' 114 | matcher = ArgmaxMatcher() 115 | results = {} 116 | count_viz = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0} 117 | # iterate over all categories 118 | dataset.init_kps_cat(cat) 119 | # init the counters 120 | n_total_10 = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0,'11_underline': 0, '01_overline': 0, '11_tilde': 0} 121 | n_total_05 = copy.deepcopy(n_total_10) 122 | n_total_15 = copy.deepcopy(n_total_10) 123 | alpha = 0.1 124 | 125 | # iterate over all test image pairs in the category 126 | perimg_PCK = [] 127 | for i, data in enumerate(tqdm(dataset)): 128 | if i == n_pairs: 129 | break 130 | 131 | # get the data for the pair 132 | src_ft = data['src_ft'][None].to(device) 133 | trg_ft = data['trg_ft'][None].to(device) 134 | if model_refine!=None: 135 | ft_orig = [src_ft, trg_ft] 136 | ft_new = [model_refine(f) for f in ft_orig] 137 | src_ft = ft_new[0] 138 | trg_ft = ft_new[1] 139 | 140 | # get the spatial resolution of the feature maps to match the original image size, where keypoints are annotated 141 | src_img_size = data['src_imsize'] 142 | trg_img_size = data['trg_imsize'] 143 | if bbox: 144 | trg_bndbox = data['trg_bndbox'] 145 | threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0]) 146 | else: 147 | threshold = max(trg_img_size[0], trg_img_size[1]) 148 | 149 | # init the per image counters 150 | n_img_10 = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0, '01_overline': 0, '11_tilde': 0} 151 | n_img_05 = copy.deepcopy(n_img_10) 152 | n_img_15 = copy.deepcopy(n_img_10) 153 | 154 | # iterate over all points in the pair and find the second point in the target image by argmax matching between query feature and all target features 155 | for idx in range(len(data['src_kps'])): 156 | 157 | # skip the points that are not annotated 158 | if data['src_kps'][idx][2] == 0: 159 | continue 160 | 161 | # match the keypoints 162 | src_point = data['src_kps'][idx] 163 | ft0, ft1, trg_ft_size = matcher.prepare_one_to_all(src_ft, trg_ft, src_point, src_img_size, trg_img_size, upsample) 164 | heatmap, prob = matcher(ft0, ft1) 165 | src_to_trg_point, heatmap_ = matcher.get_one_trg_point(heatmap[0], prob[0], trg_ft_size, trg_img_size) 166 | 167 | match_vis = data['trg_kps'][idx][2] > 0.5 168 | symm_vis = data['trg_kps_symm_only'][idx][2] > 0.5 if dataset.pck_symm else False 169 | 170 | # update counters 171 | n_img_10, key_10 = update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.10, match_vis, symm_vis, n_img_10) 172 | n_img_05, key_05 
= update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.05, match_vis, symm_vis, n_img_05) 173 | n_img_15, key_15 = update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.15, match_vis, symm_vis, n_img_15) 174 | # visualize the keypoints 175 | if visualize: 176 | count_viz = viz_pck_alpha(dataset, idx, data, src_to_trg_point, heatmap_, store_path, count_viz, key_10) 177 | 178 | # update the counters 179 | for key, value in n_img_10.items(): 180 | n_total_10[key] += value 181 | for key, value in n_img_05.items(): 182 | n_total_05[key] += value 183 | for key, value in n_img_15.items(): 184 | n_total_15[key] += value 185 | 186 | n_total_pck_denom = n_total_10['10'] + n_total_10['11'] + n_total_10['1x'] 187 | 188 | results_10 = get_per_point_pck(n_total_10, cat, '0.1') 189 | results.update(results_10) 190 | results_05 = get_per_point_pck(n_total_05, cat, '0.05') 191 | results.update(results_05) 192 | results_15 = get_per_point_pck(n_total_15, cat, '0.15') 193 | results.update(results_15) 194 | 195 | results['n_10/n ' + cat] = n_total_10['10'] / n_total_pck_denom if n_total_pck_denom > 0 else 0 196 | results['n_11/n ' + cat] = n_total_10['11'] / n_total_pck_denom if n_total_pck_denom > 0 else 0 197 | results['n_1x/n ' + cat] = n_total_10['1x'] / n_total_pck_denom if n_total_pck_denom > 0 else 0 198 | results['n ' + cat] = n_total_pck_denom 199 | 200 | return results, n_total_05, n_total_10, n_total_15 -------------------------------------------------------------------------------- /src/evaluation/runtime_mem.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from tqdm import tqdm 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | from src.evaluation import STOREBASEPATH 6 | import time 7 | 8 | def get_model_size(model): 9 | param_size = 0 10 | n_params = 0 11 | for param in model.parameters(): 12 | param_size += param.nelement() * param.element_size() 13 | n_params += param.nelement() 14 | buffer_size = 0 15 | n_buffers = 0 16 | for buffer in model.buffers(): 17 | buffer_size += buffer.nelement() * buffer.element_size() 18 | n_buffers += buffer.nelement() 19 | 20 | print('n_params:', n_params) 21 | print('n_buffers:', n_buffers) 22 | size_all_mb = (param_size + buffer_size) / 1024**2 23 | print('model size: {:.3f}MB'.format(size_all_mb)) 24 | return size_all_mb 25 | 26 | @torch.no_grad() 27 | def evaluate(dataset, n_imgs, model_refine=None, path=STOREBASEPATH+'/05_experiments/'): 28 | results = {} 29 | 30 | assert(dataset.featurizer is not None) 31 | # measure the runtime and memory usage 32 | # parameter sizes of dataset.featurizer model: 33 | models = dataset.featurizer.get_models() 34 | sizes = [get_model_size(model) for model in models] 35 | results["size featurizer (MB)"] = sum(sizes) 36 | 37 | if model_refine!=None: 38 | results["size refiner (MB)"] = get_model_size(model_refine) 39 | # measure the runtime 40 | i=0 41 | time_sum = 0 42 | for cat in tqdm(dataset.all_cats): 43 | dataset.init_kps_cat(cat) 44 | for idx_ in range(len(dataset)): 45 | if i == n_imgs: 46 | break 47 | 48 | img = dataset._get_img(idx_) 49 | 50 | starttime = time.time() 51 | ft_orig = dataset.featurizer.forward(img, 52 | category=cat, 53 | **dataset.featurizer_kwargs) 54 | if model_refine!=None: 55 | ft_new = model_refine(ft_orig) 56 | endtime = time.time() 57 | time_sum += endtime - starttime 58 | i+=1 59 | 60 | endtime = time.time() 61 | results["runtime"] = time_sum/i *1000 # in ms 62 
| return results 63 | -------------------------------------------------------------------------------- /src/evaluation/segmentation.py: -------------------------------------------------------------------------------- 1 | from src.evaluation.segmentation_metrics import confusion_matrix, accuracy, mean_precision, mean_recall, mean_iou 2 | from src.models.classifier.utils import train_classifier, forward_classifier 3 | from typing import Dict 4 | import torch.functional as F 5 | import torch 6 | from src.logging.visualization_seg import plot_assignment 7 | from src.evaluation import STOREBASEPATH 8 | 9 | def evaluate_img(dataset_test, model_seg, model_refine=None, full_resolution=False, idx=0): 10 | # get segmentation model 11 | prt_pred = forward_classifier(dataset_test, idx, model_seg, model_refine=model_refine) 12 | 13 | data = dataset_test[idx] 14 | prt_gt = data['parts_mask'][None] 15 | # rescale to fit to gt size 16 | if full_resolution: 17 | prt_pred = F.interpolate(prt_pred, size=prt_gt.shape[-2:], mode='bilinear', align_corners=False) 18 | y_max = len(dataset_test.KP_NAMES) 19 | 20 | indices = prt_gt.sum(1)>1e-20 # only consider foreground pixels 21 | # generate labels 22 | prt_gt = prt_gt.argmax(1)+1 23 | prt_pred = prt_pred.argmax(1)+1 24 | 25 | prt_pred[~indices] = 0 26 | prt_gt[~indices] = 0 27 | 28 | conf_matrix = confusion_matrix(prt_pred[indices].detach().cpu()-1, prt_gt[indices].detach().cpu()-1, y_max).detach().cpu() 29 | 30 | cat_metrics: Dict[str, float] = {} 31 | cat_metrics["acc"] = accuracy(conf_matrix) 32 | cat_metrics["m_prcn"] = mean_precision(conf_matrix) 33 | cat_metrics["m_rcll"] = mean_recall(conf_matrix) 34 | cat_metrics["m_iou"] = mean_iou(conf_matrix) 35 | 36 | return cat_metrics, prt_pred, prt_gt 37 | 38 | def evaluate(args, modeldict, full_resolution=False, path=STOREBASEPATH+'/05_experiments/'): 39 | 40 | y_max = len(modeldict['dataset_train'].KP_NAMES) 41 | if y_max<2: 42 | return {} 43 | modeldict['model_seg'] = train_classifier(args.sup_classifier, modeldict['dataset_train'], model_refine=modeldict['model_refine']) 44 | pre_list = [] 45 | gt_list = [] 46 | 47 | for idx in range(len(modeldict['dataset_test'])): 48 | _, prt_pred, prt_gt = evaluate_img(modeldict['dataset_test'], modeldict['model_seg'], model_refine=modeldict['model_refine'], full_resolution=full_resolution, idx=idx) 49 | pre_list.append(prt_pred.flatten()) 50 | gt_list.append(prt_gt.flatten()) 51 | 52 | pred = torch.cat(pre_list, 0) 53 | gt = torch.cat(gt_list, 0) 54 | conf_matrix_all = confusion_matrix(pred.detach().cpu(), gt.detach().cpu(), y_max+1).detach().cpu() 55 | assert conf_matrix_all.shape[0] == y_max+1 56 | 57 | def get_results_dict_cat(y_max, conf_matrix, gt, text="", oriented=None): 58 | results_dict_cat = {} 59 | # evaluate the confusion matrix for all parts with at least one pixel in the ground truth 60 | num_gt_pixels = torch.zeros(y_max+1).int() 61 | for i in range(0, y_max+1): 62 | num_gt_pixels[i] = torch.sum(gt==i).int() 63 | if oriented is not None: 64 | if not oriented[i]: 65 | num_gt_pixels[i] = 0 66 | 67 | index = num_gt_pixels==0 68 | 69 | conf_matrix = conf_matrix[~index][:,~index] 70 | results_dict_cat["acc"+text] = accuracy(conf_matrix) 71 | results_dict_cat["m_prcn"+text] = mean_precision(conf_matrix) 72 | results_dict_cat["m_rcll"+text] = mean_recall(conf_matrix) 73 | results_dict_cat["m_iou"+text] = mean_iou(conf_matrix) 74 | 75 | num_gt_pixels = num_gt_pixels[~index] 76 | conf_matrix_normalized = conf_matrix/num_gt_pixels[None] 77 | 
results_dict_cat["acc_precnorm"+text] = accuracy(conf_matrix_normalized) 78 | results_dict_cat["m_prcn_precnorm"+text] = mean_precision(conf_matrix_normalized) 79 | results_dict_cat["m_rcll_precnorm"+text] = mean_recall(conf_matrix_normalized) 80 | results_dict_cat["m_iou_precnorm"+text] = mean_iou(conf_matrix_normalized) 81 | return results_dict_cat, conf_matrix, conf_matrix_normalized, index 82 | 83 | results_dict_cat, conf_matrix, conf_matrix_normalized, index = get_results_dict_cat(y_max, conf_matrix_all, gt) 84 | visualize = True 85 | def visualization(conf_matrix, conf_matrix_normalized, index, modeldict, path, text=""): 86 | if modeldict['model_refine'] is not None: 87 | store_path = path+'seg/'+modeldict['dataset_test'].name+'/'+modeldict['dataset_test'].cat+'/'+modeldict['model_refine'].id+'/' 88 | else: 89 | store_path = path+'seg/'+modeldict['dataset_test'].name+'/'+modeldict['dataset_test'].cat+'/'+modeldict['dataset_test'].featurizer_name+'/' 90 | # index = assign_dict['conf_matrix_normalized']==torch.nan 91 | assign_dict = {'conf_matrix': conf_matrix, 92 | 'conf_matrix_normalized': conf_matrix_normalized} 93 | kpnames = modeldict['dataset_test'].KP_NAMES.copy() 94 | kpnames = ['bkg']+kpnames 95 | kpnames = [kpnames[i] for i in range(len(kpnames)) if i not in torch.where(index)[0]] 96 | # remove the background 97 | plot_dict = {k:v[1:,1:] for k,v in assign_dict.items()} 98 | plot_kpnames = kpnames[1:] 99 | plot_assignment(plot_dict, plot_kpnames, plot_cmap=True, path=store_path+f'/conf/'+modeldict['dataset_test'].split+"/"+text+"/", limits=[0, 1.0/len(kpnames)]) 100 | 101 | if visualize: 102 | visualization(conf_matrix, conf_matrix_normalized, index, modeldict, path, "all") 103 | 104 | oriented = modeldict['dataset_test'].KP_WITH_ORIENTATION 105 | # add bkg to oriented 106 | oriented = torch.cat([torch.tensor([True]), torch.tensor(oriented)]) 107 | results_dict_cat_geo, conf_matrix_geo, conf_matrix_normalized_geo, index_geo = get_results_dict_cat(y_max, conf_matrix_all, gt, "_geo", oriented) 108 | 109 | if visualize: 110 | if conf_matrix_geo.shape[0] > 2: 111 | visualization(conf_matrix_geo, conf_matrix_normalized_geo, index_geo, modeldict, path, "geo") 112 | 113 | results_dict_cat.update(results_dict_cat_geo) 114 | 115 | return results_dict_cat 116 | -------------------------------------------------------------------------------- /src/evaluation/segmentation_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from sklearn.metrics import confusion_matrix as sk_conf_matrix 5 | 6 | 7 | def pixelwise_accuracy(prediction: torch.Tensor, annotations: torch.Tensor) -> Tuple[int, int]: 8 | # prediction: B,H,W 9 | # annotations: B,H,W 10 | pred_correct = int(((prediction - annotations) == 0).sum().item()) 11 | total = int(prediction.view(-1).shape[0]) 12 | 13 | return pred_correct, total 14 | 15 | 16 | def confusion_matrix(prediction: torch.Tensor, annotations: torch.Tensor, num_classes: int) -> torch.Tensor: 17 | conf_matrix = torch.zeros(num_classes, num_classes).int() 18 | 19 | cm = torch.from_numpy(sk_conf_matrix(prediction.reshape(-1).cpu(), annotations.reshape(-1).cpu(), labels=range(num_classes))) 20 | conf_matrix[: cm.shape[0], : cm.shape[1]] = cm 21 | 22 | return conf_matrix 23 | 24 | 25 | def precision(true_positive: int, false_positive: int) -> float: 26 | return (true_positive) / (true_positive + false_positive) 27 | 28 | 29 | def recall(true_positive: int, 
false_negative: int) -> float: 30 | return (true_positive) / (true_positive + false_negative) 31 | 32 | 33 | def accuracy(conf_matrix: torch.Tensor) -> float: 34 | if torch.sum(conf_matrix) > 1e-20: 35 | return (torch.trace(conf_matrix) / torch.sum(conf_matrix)).item() 36 | else: 37 | return 0.0 38 | 39 | 40 | def mean_precision(conf_matrix: torch.Tensor, mask: Optional[List] = None): 41 | precisions = [] 42 | for idx, row in enumerate(conf_matrix): 43 | if mask is not None: 44 | if mask[idx] == 0: 45 | continue 46 | if row.sum() > 1e-20: 47 | precisions.append((row[idx] / row.sum()).item()) 48 | 49 | return sum(precisions) / len(precisions) if len(precisions) > 0 else 0 50 | 51 | 52 | def mean_recall(conf_matrix: torch.Tensor, mask: Optional[List] = None): 53 | recalls = [] 54 | for idx, col in enumerate(conf_matrix.transpose(0, 1)): 55 | if mask is not None: 56 | if mask[idx] == 0: 57 | continue 58 | if col.sum() > 1e-20: 59 | recalls.append((col[idx] / col.sum()).item()) 60 | 61 | return sum(recalls) / len(recalls) if len(recalls) > 0 else 0 62 | 63 | 64 | def mean_iou(conf_matrix: torch.Tensor, mask: Optional[List] = None): 65 | ious = [] 66 | for idx in range(conf_matrix.shape[0]): 67 | if mask is not None: 68 | if mask[idx] == 0: 69 | continue 70 | denom = (conf_matrix[idx, :].sum() + conf_matrix[:, idx].sum() - conf_matrix[idx, idx]) 71 | if denom > 1e-20: 72 | ious.append( 73 | ( 74 | conf_matrix[idx, idx] / denom 75 | ).sum().item() 76 | ) 77 | 78 | return sum(ious) / len(ious) if len(ious) > 0 else 0 79 | -------------------------------------------------------------------------------- /src/logging/log_results.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | def init_wandb(cfg, prefix=''): 4 | # start a new wandb run to track this script 5 | # check if feat model is in the cfg 6 | name = '' 7 | if 'featurizer' in cfg: 8 | name = name+cfg['featurizer']['model'] 9 | if 'feat_refine' in cfg: 10 | name = name+'+'+cfg['feat_refine']['model'] 11 | name = name+'+'+cfg['dataset']['name'] 12 | if 'train' in prefix: 13 | if 'sup' in cfg['dataset']: 14 | name = name+cfg['dataset']['sup'] 15 | if 'dataset2' in cfg: 16 | name = name+'+'+cfg['dataset2']['name'] 17 | if 'sup' in cfg['dataset2']: 18 | name = name+cfg['dataset2']['sup'] 19 | if 'dataset3' in cfg: 20 | name = name+'+'+cfg['dataset3']['name'] 21 | if 'sup' in cfg['dataset3']: 22 | name = name+cfg['dataset3']['sup'] 23 | if 'dataset4' in cfg: 24 | name = name+'+'+cfg['dataset4']['name'] 25 | if 'sup' in cfg['dataset4']: 26 | name = name+cfg['dataset4']['sup'] 27 | 28 | wandb.init(entity="", 29 | project="", 30 | name=prefix+' '+name, 31 | config=cfg, 32 | settings=wandb.Settings(code_dir=".")) 33 | 34 | def log_wandb(results): 35 | if wandb.run: 36 | wandb.log(results) 37 | 38 | def log_wandb_epoch(epoch): 39 | 40 | if wandb.run: 41 | results = {'epoch': epoch} 42 | wandb.log(results) 43 | 44 | def log_wandb_cfg(cfg): 45 | # init wandb if not already 46 | if wandb.run: 47 | wandb.config.update(cfg, allow_val_change=True) 48 | 49 | def log_wandb_ram_usage(): 50 | import psutil 51 | import os 52 | process = psutil.Process(os.getpid()) 53 | other = psutil.virtual_memory() 54 | 55 | if wandb.run: 56 | wandb.log({ 'ram_usage': process.memory_info().rss/1024/1024 57 | , 'ram_total': other.total/1024/1024 58 | , 'ram_free': other.available/1024/1024 59 | , 'ram_percent': other.percent}) 60 | print('ram_usage: ', process.memory_info().rss/1024/1024, 'MB') 61 | print('ram_total: ', 
other.total/1024/1024, 'MB') 62 | 63 | def finish_wandb(): 64 | if wandb.run: 65 | wandb.finish() -------------------------------------------------------------------------------- /src/logging/visualization.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import numpy as np 5 | from src.matcher.argmaxmatcher import ArgmaxMatcher 6 | from torch import nn 7 | from src.dataset.utils import to_flattened_idx_torch 8 | from PIL import Image 9 | 10 | class Demo: 11 | 12 | def __init__(self, imgs, ft, img_size): 13 | self.ft = ft # N+1, C, H, W 14 | # check if image is pil image 15 | if imgs[0].__class__.__name__ != 'PngImageFile': 16 | imgs = [Image.fromarray(img) for img in imgs] 17 | self.imgs = imgs 18 | self.num_imgs = len(imgs) 19 | self.img_size = img_size 20 | 21 | def plot_imgs_joint(self, fig_size=3): 22 | # concatenate the images and plot them 23 | # pad imgs to the same height 24 | max_h = max([img.size[1] for img in self.imgs]) 25 | imgs = [] 26 | for i in range(len(self.imgs)): 27 | if self.imgs[i].size[1] < max_h: 28 | # add zero padding using pillow 29 | img = Image.new('RGB', (self.imgs[i].size[0], max_h), (0, 0, 0)) 30 | img.paste(self.imgs[i], (0, 0)) 31 | else: 32 | img = self.imgs[i] 33 | imgs.append(img) 34 | 35 | img = np.concatenate([np.array(img) for img in imgs], axis=1) 36 | fig, ax = plt.subplots(1, 1, figsize=(fig_size*len(self.imgs), fig_size)) 37 | plt.tight_layout() 38 | ax.imshow(img) 39 | # no axis 40 | ax.axis('off') 41 | return fig, ax 42 | 43 | def plot_images(self, fig_size=3): 44 | # plot the source image and the target images with the heatmap corresponding to the source point 45 | fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size)) 46 | plt.tight_layout() 47 | for i in range(self.num_imgs): 48 | axes[i].imshow(self.imgs[i]) 49 | axes[i].axis('off') 50 | if i == 0: 51 | axes[i].set_title('Source image $S$') 52 | else: 53 | axes[i].set_title('Target image $T$') 54 | return fig, axes 55 | 56 | def plot_src(self, axes, src_point, scatter_size): 57 | # plot src image 58 | axes[0].clear() 59 | axes[0].imshow(self.imgs[0]) 60 | axes[0].axis('off') 61 | axes[0].scatter(src_point[1], src_point[0], c='r', s=scatter_size) 62 | # scale the point to the feature map size 63 | scale0 = torch.tensor(self.ft[0].shape[-2:]).float()/torch.tensor(self.img_size[0]).float() 64 | src_point_ = (src_point[:2]*scale0).clamp(torch.zeros(2), torch.tensor(self.ft[0].shape[-2:])-1) 65 | 66 | idx_src = to_flattened_idx_torch(src_point_[0], src_point_[1], self.ft[0].shape[-2], self.ft[0].shape[-1]) 67 | axes[0].set_title('source image\nidx_src: %d' % idx_src) 68 | 69 | def plot_heatmap(self, axes, i, heatmap, src_to_trg_point, alpha, scatter_size, limits=None): 70 | # plot the heatmap 71 | axes[i].clear() 72 | # heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1] 73 | # convert to grayscale 74 | img_i = self.imgs[i].copy() 75 | img_i = img_i.convert('L') 76 | axes[i].imshow(img_i, alpha=1-alpha, cmap=plt.get_cmap('gray')) 77 | if limits is None: 78 | vmax = np.max(heatmap)*255 79 | vmin = np.min(heatmap)*255 80 | else: 81 | vmin = limits[0]*255 82 | vmax = limits[1]*255 83 | axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis', vmin=vmin, vmax=vmax) 84 | axes[i].axis('off') 85 | if src_to_trg_point is not None: 86 | axes[i].scatter(src_to_trg_point[1], src_to_trg_point[0], c='r', s=scatter_size) 87 | 
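# Illustrative usage sketch for the Demo viewer defined in this file: given two
# images and their per-image feature maps, plot one source point and the matched
# heatmap in the target image. The images, feature maps, and sizes below are
# random placeholders with assumed shapes, not data from the repository.
import numpy as np
import torch
from src.logging.visualization import Demo

imgs = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(2)]   # source, target
ft = [torch.rand(1, 384, 32, 32) for _ in range(2)]                  # per-image feature maps
img_size = [torch.tensor([256, 256]), torch.tensor([256, 256])]
demo = Demo(imgs, ft, img_size)
demo.plot_img_pairs(src_point=torch.tensor([128, 128, 1]))           # (row, col, visibility)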
axes[i].set_title('target image\nmax value: %.2f' % (heatmap).max()) 88 | 89 | def plot_matched_heatmap(self, axes, src_points, alpha, scatter_size, upsample): 90 | matcher = ArgmaxMatcher() 91 | for i in range(1, self.num_imgs): 92 | # prepare the feature maps for matching 93 | ft0, ft1, trg_ft_size = matcher.prepare_one_to_all(self.ft[0], self.ft[i], src_points[0], self.img_size[0], self.img_size[i], upsample=upsample) 94 | # match the feature maps 95 | heatmap, prob = matcher(ft0, ft1) 96 | for j in range(1): # only plot the first source point, in case we have multiple source points 97 | src_to_trg_point, heatmap_ = matcher.get_one_trg_point(heatmap[j], prob[j], trg_ft_size, self.img_size[i]) 98 | self.plot_heatmap(axes, i, heatmap_, src_to_trg_point, alpha, scatter_size, limits=(0.5, 1)) 99 | del heatmap 100 | 101 | def plot_img_pairs_click(self, fig_size=3, alpha=0.45, scatter_size=70, upsample=True): 102 | fig, axes = self.plot_images(fig_size) 103 | def onclick(event): 104 | if event.inaxes == axes[0]: 105 | with torch.no_grad(): 106 | x, y = int(np.round(event.xdata)), int(np.round(event.ydata)) 107 | src_point = torch.tensor([y,x,1]) 108 | self.plot_src(axes, src_point, scatter_size) 109 | self.plot_matched_heatmap(axes, src_point[None], alpha, scatter_size, upsample) 110 | gc.collect() 111 | 112 | fig.canvas.mpl_connect('button_press_event', onclick) 113 | plt.show() 114 | 115 | def plot_img_pairs(self, src_point, fig_size=3, alpha=0.45, scatter_size=70, upsample=True): 116 | fig, axes = self.plot_images(fig_size) 117 | axes[0].clear() 118 | axes[0].imshow(self.imgs[0]) 119 | axes[0].axis('off') 120 | axes[0].scatter(int(src_point[1]), int(src_point[0]), c='r', s=scatter_size) # scatter needs flipped x and y 121 | axes[0].set_title('source image') 122 | # plot trg heatmap 123 | self.plot_matched_heatmap(axes, src_point[None], alpha, scatter_size, upsample) 124 | plt.show() 125 | 126 | def plot_matches(self, src_points, trg_points, fig_size=3, alpha=0.45, scatter_size=40, title='', path=None): 127 | # fig, axes = self.plot_images(fig_size) 128 | fig, axes = self.plot_imgs_joint(fig_size) 129 | #colours = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w'] 130 | colours = [ 131 | (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 0, 0), (255, 0, 255), 132 | (255, 255, 0), (0, 0, 255), (0, 128, 255), (128, 0, 255), (0, 128, 0), 133 | (128, 0, 0), (0, 0, 128), (128, 128, 0), (0, 128, 128), (128, 0, 128), 134 | ] 135 | 136 | # Normalize RGB values (0-255 -> 0-1) 137 | normalized_colours = [(r/255, g/255, b/255) for r, g, b in colours] 138 | num_colours = len(colours) 139 | # matplot colours are in the range [0, 1] 140 | colours = [(float(r)/255, float(g)/255, float(b)/255) for r, g, b in colours] 141 | for j in range(len(src_points)): 142 | src_point = src_points[j] 143 | if src_point[2]!=0: 144 | axes.scatter(int(src_point[1]), int(src_point[0]), color=normalized_colours[j%num_colours], s=scatter_size) # scatter needs flipped x and y 145 | for j in range(len(trg_points)): 146 | trg_point = trg_points[j] 147 | if trg_point[2]!=0: 148 | offset = self.img_size[0][1] 149 | axes.scatter(int(trg_point[1])+offset, int(trg_point[0]), color=normalized_colours[j%num_colours], s=scatter_size) 150 | c = 'r' if 'neg' in title else 'g' if 'pos' in title else 'b' 151 | if 'neg' in title or 'pos' in title: 152 | # plot red connections for negative matches 153 | for j in range(len(src_points)): 154 | src_point = src_points[j] 155 | trg_point = trg_points[j] 156 | if src_point[2]!=0 and trg_point[2]!=0: 157 | 
axes.plot([src_point[1], trg_point[1]+offset], [src_point[0], trg_point[0]], color=c, linewidth=2) 158 | if 'bin' in title: 159 | # plot lines from the points to the boundaries 160 | for j in range(len(src_points)): 161 | src_point = src_points[j] 162 | if src_point[2]!=0: 163 | axes.plot([src_point[1], 0], [src_point[0], src_point[0]], color=c, linewidth=2) 164 | for j in range(len(trg_points)): 165 | trg_point = trg_points[j] 166 | offset1 = self.img_size[1][1] 167 | if trg_point[2]!=0: 168 | axes.plot([trg_point[1]+offset, offset+offset1-1], [trg_point[0], trg_point[0]], color=c, linewidth=2) 169 | # fig.suptitle(title) 170 | plt.show() 171 | if path is not None: 172 | fig.savefig(path,bbox_inches='tight') 173 | plt.close('all') 174 | 175 | 176 | def helper_fun(self, src_point, heatmap, axes, alpha, scatter_size, limits): 177 | # scale the point to the feature map size for indexing 178 | scale0 = torch.tensor(self.ft[0].shape[-2:]).float()/torch.tensor(self.img_size[0]).float() 179 | src_point_ = (src_point[:2]*scale0).clamp(torch.zeros(2), torch.tensor(self.ft[0].shape[-2:])-1) # ft indexing 180 | 181 | idx_src = to_flattened_idx_torch(src_point_[0], src_point_[1], self.ft[0].shape[-2], self.ft[0].shape[-1]) 182 | heatmap_ = heatmap[idx_src.long().item()] 183 | heatmap_ = nn.Upsample(size=self.img_size[1].tolist(), mode='bilinear')(torch.tensor(heatmap_).view(1,1,self.ft[1].shape[-2],self.ft[1].shape[-1])).squeeze(0).squeeze(0).cpu().numpy() 184 | self.plot_heatmap(axes, 1, heatmap_, None, alpha, scatter_size, limits=limits) -------------------------------------------------------------------------------- /src/logging/visualization_pck.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import os 6 | import numpy as np 7 | import copy 8 | 9 | def crop_centered(img, kp_dict): 10 | if img.size[0] > img.size[1]: 11 | for k, v in kp_dict.items(): 12 | if v is not None: 13 | kp_dict[k][1] = kp_dict[k][1] - (img.size[0] - img.size[1]) // 2 # H,W -> W,H from PIL to numpy 14 | if kp_dict[k][1] < 0: 15 | kp_dict[k][1] = 0 16 | elif kp_dict[k][1] >= img.size[1]: 17 | kp_dict[k][1] = img.size[1]-1 18 | img = img.crop(((img.size[0] - img.size[1]) // 2, 0, (img.size[1] + img.size[0]) // 2, img.size[1])) 19 | else: 20 | for k, v in kp_dict.items(): 21 | if v is not None: 22 | kp_dict[k][0] = kp_dict[k][0] - (img.size[1] - img.size[0]) // 2 # H,W -> W,H from PIL to numpy 23 | if kp_dict[k][0] < 0: 24 | kp_dict[k][0] = 0 25 | elif kp_dict[k][0] >= img.size[0]: 26 | kp_dict[k][0] = img.size[0]-1 27 | img = img.crop((0, (img.size[1] - img.size[0]) // 2, img.size[0], (img.size[1] + img.size[0]) // 2)) 28 | return img, kp_dict 29 | 30 | def crop_centered_heatmap(heatmap): 31 | if heatmap.shape[0] > heatmap.shape[1]: 32 | heatmap = heatmap[(heatmap.shape[0] - heatmap.shape[1]) // 2:(heatmap.shape[1] + heatmap.shape[0]) // 2, :] 33 | else: 34 | heatmap = heatmap[:, (heatmap.shape[1] - heatmap.shape[0]) // 2:(heatmap.shape[1] + heatmap.shape[0]) // 2] 35 | return heatmap 36 | 37 | def plot_src(img_, src_point, alpha=0.2, path=None, scatter_size=70): 38 | img = img_ 39 | def plotting(): 40 | ax.clear() 41 | ax.imshow(img) 42 | ax.axis('off') 43 | ax.scatter(src_point[1], src_point[0], edgecolor='black', linewidth=1, facecolor='w', s=scatter_size, label="$Q$ query") 44 | if path is not None: 45 | ax.legend(fontsize="20", loc ="upper left") 46 | os.makedirs(os.path.dirname(path), 
exist_ok=True) 47 | plt.savefig(path, bbox_inches='tight', pad_inches=0) 48 | plt.close('all') 49 | ############################################## 50 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 51 | plotting() 52 | # ############################################## 53 | # crop to quadratic 54 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 55 | kp_dict = {'src': src_point} 56 | img,kp_dict = crop_centered(img, kp_dict) 57 | src_point = kp_dict['src'] 58 | if path is not None: 59 | # add padding to path 60 | path = path.split('.') 61 | path = path[0] + '_crop.' + path[1] 62 | plotting() 63 | 64 | def plot_trg(img_, heatmap, path=None, scatter_size=70, alpha_gt=0.2, alpha=0.45, limits=None, src_to_trg_point=None, trg_point=None, c='g', trg_point2=None, c2='r'): 65 | img = img_.convert('L').copy() 66 | if limits is None: 67 | vmax = np.max(heatmap)*255 68 | vmin = np.min(heatmap)*255 69 | else: 70 | vmin = limits[0]*255 71 | vmax = limits[1]*255 72 | 73 | def plotting_gt(): 74 | ax.clear() 75 | ax.imshow(img_, alpha=1-alpha_gt, cmap=plt.get_cmap('gray')) 76 | ax.axis('off') 77 | if path_plain is not None: 78 | os.makedirs(os.path.dirname(path_plain), exist_ok=True) 79 | plt.savefig(path_plain, bbox_inches='tight', pad_inches=0) 80 | # plt.close('all') 81 | if trg_point is not None and trg_point[2] > 0: 82 | ax.scatter(trg_point[1], trg_point[0], edgecolor='black', linewidth=1, facecolor=c, s=scatter_size, label="$Q_{GT}$ visible \u2714 (1)") 83 | # # plot text at the bottom of the image with a label for the target point 84 | # x, y = img.size[0] // 2, img.size[1] - 20 85 | # s = f'$Q$ \n $Q_s$' 86 | # plt.text(x, y, s, bbox=dict(fill=True, facecolor='w', linewidth=2)) 87 | else: 88 | # add to legend without scatter 89 | ax.scatter([], [], edgecolor='black', linewidth=1, facecolor=c, s=scatter_size, label="$Q_{GT}$ visible \u2718 (0)") 90 | if trg_point2 is not None and trg_point2[2] > 0: 91 | ax.scatter(trg_point2[1], trg_point2[0], edgecolor='black', linewidth=1, facecolor=c2, s=scatter_size, label="$Q_{Symm}$ visible \u2714 (1)") 92 | else: 93 | # add to legend without scatter 94 | ax.scatter([], [], edgecolor='black', linewidth=1, facecolor=c2, s=scatter_size, label="$Q_{Symm}$ visible \u2718 (0)") 95 | if path_gt is not None: 96 | ax.legend(fontsize="20", loc ="upper left") 97 | os.makedirs(os.path.dirname(path_gt), exist_ok=True) 98 | plt.savefig(path_gt, bbox_inches='tight', pad_inches=0) 99 | plt.close('all') 100 | 101 | def plotting_pred(): 102 | ax.clear() 103 | ax.imshow(img, alpha=1-alpha, cmap=plt.get_cmap('gray')) 104 | ax.imshow(255 * heatmap, alpha=alpha, cmap='viridis', vmin=vmin, vmax=vmax) 105 | ax.axis('off') 106 | if src_to_trg_point is not None: 107 | if heatmap[src_to_trg_point[0], src_to_trg_point[1]] > 0.3: 108 | ax.scatter(src_to_trg_point[1], src_to_trg_point[0], edgecolor='black', linewidth=1, facecolor='y', s=scatter_size) 109 | if path is not None: 110 | os.makedirs(os.path.dirname(path), exist_ok=True) 111 | plt.savefig(path, bbox_inches='tight', pad_inches=0) 112 | plt.close('all') 113 | 114 | ############################################## 115 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 116 | if path is not None: 117 | path_list = path.split('.') 118 | path_plain = path_list[0] + '_plain.' + path_list[1] 119 | path_gt = path_list[0] + '_gt.' 
+ path_list[1] 120 | plotting_gt() 121 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 122 | plotting_pred() 123 | 124 | ############################################## 125 | # crop to quadratic 126 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 127 | kp_dict = {'src': src_to_trg_point, 'trg': trg_point, 'trg2': trg_point2} 128 | img, kp_dict = crop_centered(img, kp_dict) 129 | img_, _ = crop_centered(img_, copy.deepcopy(kp_dict)) 130 | src_to_trg_point = kp_dict['src'] 131 | trg_point = kp_dict['trg'] 132 | trg_point2 = kp_dict['trg2'] 133 | heatmap = crop_centered_heatmap(heatmap) 134 | if path is not None: 135 | path_list = path.split('.') 136 | path_plain = path_list[0] + '_plain_crop.' + path_list[1] 137 | path_gt = path_list[0] + '_gt_crop.' + path_list[1] 138 | plotting_gt() 139 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 140 | if path is not None: 141 | path_list = path.split('.') 142 | path = path_list[0] + '_crop.' + path_list[1] 143 | plotting_pred() -------------------------------------------------------------------------------- /src/logging/visualization_seg.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import wandb 5 | 6 | COLOURS = [ 7 | (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 0, 0), (255, 0, 255), 8 | (255, 255, 0), (0, 0, 255), (0, 128, 255), (128, 0, 255), (0, 128, 0), 9 | (128, 0, 0), (0, 0, 128), (128, 128, 0), (0, 128, 128), (128, 0, 128), 10 | ] 11 | 12 | def plot_assignment(assignment_dict, list_of_kp_names=None, path=None, wandb_suffix='', plot_cmap=False, limits=None): 13 | import os 14 | if path is not None: 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | def plot_matrix(M_, title, limits): 18 | if path is not None and 'single' in path and "cos_sim" in title: 19 | M = M_.clone() 20 | limits = [0, 1] 21 | elif path is not None and 'single' in path and "diag" in path: 22 | M = M_.clone() 23 | else: 24 | M = M_.clone()/M_.sum() 25 | ratio = M.shape[1]/M.shape[0] 26 | plt.figure(figsize=(5*ratio,5)) 27 | if limits is None: 28 | plt.imshow(M.detach().cpu().numpy(), cmap='jet') 29 | else: 30 | plt.imshow(M.detach().cpu().numpy(), cmap='jet', vmin=limits[0], vmax=limits[1]) 31 | plt.yticks(()) 32 | plt.xticks(()) 33 | # title_suffix = 'max:%.4f' % M.max().item() 34 | # plt.title(title + ' - ' + title_suffix) 35 | plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) 36 | if list_of_kp_names is not None: 37 | # x ticks on top 38 | plt.xticks(range(len(list_of_kp_names)), list_of_kp_names, rotation=90) 39 | plt.yticks(range(len(list_of_kp_names)), list_of_kp_names) 40 | # adjust the plot such that the heading and the colorbar are not cut off 41 | plt.subplots_adjust(top=0.78) 42 | if path is not None: 43 | plt.savefig(path+'/matrix_'+title + ".png") 44 | if wandb.run is not None: 45 | # log to wandb 46 | wandb.log({wandb_suffix+'_'+title : wandb.Image(path+'/matrix_'+title + ".png")}) 47 | if plot_cmap: 48 | # plot the jet colormap 49 | plt.gca().set_visible(False) 50 | cax = plt.axes((0.0, 0.0, 1.0, 0.05)) 51 | plt.colorbar(orientation="horizontal", cax=cax) 52 | plt.savefig(path+'/colourbar_'+title + ".png", bbox_inches='tight') 53 | 54 | plt.close('all') 55 | 56 | for key, value in assignment_dict.items(): 57 | # if value.sum() > 0: 58 | # check for sparse matrix 59 | if hasattr(value, 'to_dense'): 60 | plot_matrix(value.to_dense(), key, limits) 61 | else: 62 | plot_matrix(value, key, limits) 
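# Illustrative usage sketch for plot_assignment above: plot a small assignment
# matrix with per-keypoint labels. The matrix values, keypoint names, and output
# path are placeholders chosen for illustration only.
import torch
from src.logging.visualization_seg import plot_assignment

conf = torch.rand(4, 4)
plot_assignment({'conf_matrix': conf},
                list_of_kp_names=['head', 'back', 'tail', 'legs'],
                path='/tmp/assignment_demo',
                plot_cmap=True,
                limits=[0, 0.25])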
63 | -------------------------------------------------------------------------------- /src/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from src.dataset.pairwise_utils import get_matches, scale_to_feature_dims, get_y_mat_gt_assignment 2 | from src.matcher.ot_matcher import SoftmaxMatcher 3 | import torch 4 | 5 | def pairwise_loss(losses, data, prob, b, prt, masses): 6 | # evaluate the matches by comparing the ground truth matches with the matches from the model 7 | data_matches = get_matches(data, b) 8 | data_matches = scale_to_feature_dims(data_matches, data, b) 9 | 10 | assignment_dict = {} 11 | ft_orig_b = [data['src_ft'][b], data['trg_ft'][b]] 12 | ft_size_b = [torch.tensor(f.shape[-2:]) for f in ft_orig_b] 13 | 14 | for prefix in ['pos', 'bin', 'neg']: 15 | assignment_dict['y_mat_' + prefix] = get_y_mat_gt_assignment(data_matches[prefix+'_src_kps'].clone().detach(), data_matches[prefix+'_trg_kps'].clone().detach(), ft_size_b[0], ft_size_b[1]) 16 | idx = assignment_dict['y_mat_' + prefix].coalesce().indices() 17 | prob_prefix = prob[idx[0].long(),idx[1].long()] 18 | # compute the loss 19 | p = 1/prob.shape[0] 20 | if prefix in ['pos', 'bin']: 21 | losses[prefix].append(-torch.log(prob_prefix)*p) 22 | else: 23 | losses[prefix].append(-torch.log(1-prob_prefix)*(1-0)) 24 | 25 | # compute the loss 26 | if True: 27 | prefix = 'neg_fg_bkg' 28 | idx0 = prt[0]>0 29 | idx1 = prt[1]==0 30 | idx0 = torch.cat([idx0, torch.zeros(1).bool().to(idx0.device)]) # add the bin to bool mask 31 | idx1 = torch.cat([idx1, torch.zeros(1).bool().to(idx1.device)]) # add the bin to bool mask 32 | prob_prefix = prob[idx0,:][:,idx1] 33 | losses[prefix].append(-torch.log(1-prob_prefix.flatten())*(1-0)) 34 | idx0 = prt[0]==0 35 | idx1 = prt[1]>0 36 | idx0 = torch.cat([idx0, torch.zeros(1).bool().to(idx0.device)]) 37 | idx1 = torch.cat([idx1, torch.zeros(1).bool().to(idx1.device)]) 38 | prob_prefix = prob[idx0,:][:,idx1] 39 | losses[prefix].append(-torch.log(1-prob_prefix.flatten())*(1-0)) 40 | 41 | del data_matches, assignment_dict 42 | return losses 43 | 44 | class PairwiseLoss(): 45 | def __init__(self, args): 46 | 47 | ot_params = { 48 | 'reg':0.1, 49 | 'reg_kl':10, 50 | 'sinkhorn_iterations':10, 51 | 'mass':0.9, 52 | 'bin_score':0.3 53 | } 54 | self.args = args 55 | self.matcher = SoftmaxMatcher(**ot_params) 56 | self.matcher = self.matcher.eval() # init dict of losses 57 | self.prefix = ['pos', 'bin', 'neg', 'neg_fg_bkg'] 58 | 59 | def get_loss(self, src_ft_, trg_ft_, data): 60 | 61 | losses = {p: [] for p in self.prefix} 62 | # compute the loss 63 | B = src_ft_.shape[0] 64 | for b in range(B): 65 | prt = [data['src_mask'][b], data['trg_mask'][b]] 66 | # resize the prt segmentation to feature dimensions 67 | prt = [torch.nn.functional.interpolate(p[None,None].float(), size=(f.shape[-2], f.shape[-1]), mode='nearest')[0,0].bool() for p,f in zip(prt, [data['src_ft'][b], data['trg_ft'][b]])] 68 | prt = [p.flatten() for p in prt] 69 | if prt[0].sum()*prt[1].sum()<1e-10: 70 | continue 71 | # get masses 72 | mass0 = (data['src_kps'][b,:,2].sum()+1)/(data['numkp'][b]+1)*self.matcher.mass 73 | mass1 = (data['trg_kps'][b,:,2].sum()+1)/(data['numkp'][b]+1)*self.matcher.mass 74 | prob = self.matcher(src_ft_[b], trg_ft_[b], segmentations=prt, masses=[mass0, mass1]) 75 | # use positive, negative and bin matches to compute the loss 76 | losses = pairwise_loss(losses, data, prob, b, prt, masses=[mass0, mass1]) 77 | losses = {k: torch.cat(v) for k,v in losses.items()} 
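# Illustrative, simplified sketch of the match-probability losses collected in
# pairwise_loss above: given an assignment matrix `prob` with an extra dust-bin
# row/column, positive and bin matches are pulled towards probability 1 and
# negative matches towards 0. The matrix and index tensors below are made up
# for illustration and omit the per-entry weighting used in the actual loss.
import torch

prob = torch.full((5, 5), 0.05)           # 4x4 correspondences + dust bin
prob[0, 1], prob[2, 2] = 0.7, 0.6         # two "positive" correspondences
prob[3, 4] = 0.4                          # a source point assigned to the bin

pos_idx = torch.tensor([[0, 2], [1, 2]])  # (src, trg) indices of positives
bin_idx = torch.tensor([[3], [4]])
neg_idx = torch.tensor([[1], [0]])

loss_pos = -torch.log(prob[pos_idx[0], pos_idx[1]]).mean()
loss_bin = -torch.log(prob[bin_idx[0], bin_idx[1]]).mean()
loss_neg = -torch.log(1 - prob[neg_idx[0], neg_idx[1]]).mean()
total = loss_pos + loss_bin + loss_neg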
78 | losses = {k: v for k,v in losses.items() if len(v)>0} 79 | return losses -------------------------------------------------------------------------------- /src/matcher/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/matcher/__init__.py -------------------------------------------------------------------------------- /src/matcher/argmaxmatcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class ArgmaxMatcher(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, src_vec, trg_vec): 10 | dev = src_vec.device 11 | n,d = src_vec.shape 12 | m,d = trg_vec.shape 13 | # match the feature vectors at the source point 14 | src_vec = F.normalize(src_vec) # 1, C 15 | trg_vec = F.normalize(trg_vec).transpose(0, 1) #C, H0W0 16 | cos_map = torch.mm(src_vec, trg_vec) #1, H0W0 17 | # get the probability map, which is a one-hot map 18 | #trg_point = torch.unravel_index(cos_map.argmax(), cos_map.shape) 19 | prob = torch.zeros(n+1, m+1).to(dev) 20 | idx_0 = torch.arange(n).to(dev) 21 | idx_1 = cos_map.argmax(dim = 1) 22 | prob[idx_0, idx_1] = 1 23 | return cos_map, prob 24 | 25 | def prepare_one_to_all(self, src_ft, trg_ft, src_point, src_img_size, trg_img_size, upsample): 26 | # We use the feature maps to match src_point to the all target points 27 | 28 | # upsample the feature maps to match the original image size 29 | if upsample: 30 | # print(f"Memory allocated ups: {torch.cuda.memory_allocated()//1024**3} GB") 31 | src_ft = nn.Upsample(size=src_img_size.tolist(), mode='bilinear')(src_ft) 32 | # print(f"Memory allocated: {torch.cuda.memory_allocated()//1024**3} GB") 33 | trg_ft = nn.Upsample(size=trg_img_size.tolist(), mode='bilinear')(trg_ft) 34 | # print(f"Memory allocated: {torch.cuda.memory_allocated()//1024**3} GB") 35 | 36 | # scale the source point to the feature map size, in case we did not upsample the feature maps 37 | # get the scale factor of the feature maps wrt the original image size 38 | src_ft_size = src_ft.shape[-2:] 39 | src_scale = torch.tensor(src_ft_size).float()/src_img_size.float() 40 | src_point_ = src_point.clone() 41 | src_point_[:2] = src_point[:2] * src_scale 42 | src_point_ = src_point_.floor().long() 43 | 44 | # get the feature vector at the source point 45 | num_channel = src_ft.size(1) 46 | src_vec = src_ft[0, :, src_point_[0], src_point_[1]].view(1, num_channel) # 1, C 47 | # get the feature vectors at all target points 48 | trg_vec = trg_ft.reshape(num_channel, -1).transpose(0, 1) # H0W0, C 49 | # get the size of the target feature map 50 | trg_ft_size = trg_ft.shape[-2:] 51 | return src_vec, trg_vec, trg_ft_size 52 | 53 | def get_one_trg_point(self, cos_map, prob, trg_ft_size, trg_img_size): 54 | # Input: 55 | # trg_ft_size: H0, W0 indicating the size of cos_map and prob 56 | # trg_img_size: H, W indicating the size of the original image 57 | # in case upsample is True, trg_ft_size=trg_img_size 58 | # Output: 59 | # trg_point: the target point in the original image size 60 | # cos_map: the cosine map in the original image size 61 | H0,W0 = trg_ft_size[-2], trg_ft_size[-1] 62 | trg_scale = trg_img_size.float()/torch.tensor(trg_ft_size).float() 63 | cos_map = cos_map.view(H0,W0) # H0,W0 64 | bin_prob = prob[-1] 65 | prob = prob[:-1].view(H0,W0) 66 | trg_point = 
torch.tensor(torch.unravel_index(prob.argmax(), prob.shape)) 67 | if prob[trg_point[0], trg_point[1]] torch.Tensor: 12 | """ 13 | Perform Sinkhorn Normalization in Log-space for stability 14 | lambda: higher values result in lower entropy 15 | """ 16 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 17 | Z = lmba*Z 18 | for _ in range(iters): 19 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1)*lmba, dim=2) 20 | u = u/lmba 21 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2)*lmba, dim=1) 22 | v = v/lmba 23 | return Z + u.unsqueeze(2)*lmba + v.unsqueeze(1)*lmba 24 | 25 | def log_optimal_transport(mu:torch.Tensor, nu:torch.Tensor, couplings: torch.Tensor, reg: float, sinkhorn_iterations: int) -> torch.Tensor: 26 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 27 | B, N_, M_ = couplings.shape 28 | log_mu, log_nu = mu.log(), nu.log() 29 | log_mu, log_nu = log_mu[None].expand(B, -1), log_nu[None].expand(B, -1) 30 | log_mu, log_nu = log_mu.to(couplings.device), log_nu.to(couplings.device) 31 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, sinkhorn_iterations, 1/reg) 32 | # Z = Z - norm # multiply probabilities by M+N 33 | return Z.exp() 34 | 35 | # 2 36 | def log_sinkhorn_iterations_kl(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int, lmba: float, reg_kl:float) -> torch.Tensor: 37 | """ 38 | Perform Sinkhorn Normalization in Log-space for stability 39 | lambda: higher values result in lower entropy 40 | reg_kl: higher values result in stronger KL regularization 41 | """ 42 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 43 | Z = lmba*Z 44 | phi= reg_kl/(reg_kl+1/lmba) 45 | for _ in range(iters): 46 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1)*lmba, dim=2) 47 | u = u/lmba * phi 48 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2)*lmba, dim=1) 49 | v = v/lmba * phi 50 | return Z + u.unsqueeze(2)*lmba + v.unsqueeze(1)*lmba 51 | 52 | def log_optimal_transport_kl(mu:torch.Tensor, nu:torch.Tensor, couplings: torch.Tensor, reg: float, reg_kl: float, sinkhorn_iterations: int) -> torch.Tensor: 53 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 54 | B, N_, M_ = couplings.shape 55 | log_mu, log_nu = mu.log(), nu.log() 56 | log_mu, log_nu = log_mu[None].expand(B, -1), log_nu[None].expand(B, -1) 57 | log_mu, log_nu = log_mu.to(couplings.device), log_nu.to(couplings.device) 58 | Z = log_sinkhorn_iterations_kl(couplings, log_mu, log_nu, sinkhorn_iterations, 1/reg, reg_kl) 59 | # Z = Z - norm # multiply probabilities by M+N 60 | return Z.exp() 61 | 62 | # 3 63 | def distributed_sinkhorn(couplings, reg, sinkhorn_iterations): 64 | 65 | Q = torch.exp(couplings / reg).transpose(-1,-2) # Q is N-by-M for consistency with notations from our paper 66 | N_1 = Q.shape[-2] # how many prototypes 67 | M_1 = Q.shape[-1] # number of samples to assign 68 | 69 | # make the matrix sums to 1 70 | sum_Q = torch.sum(Q) 71 | Q /= sum_Q 72 | 73 | for it in range(sinkhorn_iterations): 74 | # normalize each row: total weight per prototype must be 1/N 75 | sum_of_rows = torch.sum(Q, dim=-1, keepdim=True) 76 | Q /= sum_of_rows 77 | Q /= N_1 78 | 79 | # normalize each column: total weight per sample must be 1/M 80 | Q /= torch.sum(Q, dim=-2, keepdim=True) 81 | Q /= M_1 82 | 83 | Q *= M_1 # the colomns must sum to 1 so that Q is an assignment 84 | return Q.transpose(-1,-2) 85 | 86 | 87 | # 4 88 | def ot_solver(a, b, couplings, type = "partial_wasserstein", reg = 0.005, reg_m_kl = 0.05, reg_m_l2 = 5): 89 | B, M, N = 
couplings.shape 90 | P_list = [] 91 | for i in range(B): 92 | dist = 2-couplings[i] 93 | if type == "entropic": 94 | P = ot.sinkhorn(a, b, dist, reg) 95 | if type == "entropic_kl_uot": 96 | P = ot.unbalanced.sinkhorn_unbalanced(a, b, dist, reg, reg_m_kl) 97 | 98 | if type == "kl_uot": 99 | P = ot.unbalanced.mm_unbalanced(a, b, dist, reg_m_kl, div='kl') 100 | 101 | if type == "l2_uot": 102 | P = ot.unbalanced.mm_unbalanced(a, b, dist, reg_m_l2, div='l2') 103 | P_list.append(P) 104 | P = torch.stack(P_list) 105 | # if type == "partial_ot": 106 | # P = ot.partial.partial_wasserstein(a, b, dist, m=alpha.item()) 107 | # P = 108 | return P 109 | 110 | ############################################ 111 | # Problem definition, partial OT 112 | ############################################ 113 | 114 | def get_partial_ot_problem(scores, bin_score): 115 | if not isinstance(bin_score, torch.Tensor): 116 | bin_score = scores.new_tensor(bin_score) 117 | B, M, N = scores.shape 118 | dev = scores.device 119 | bins0 = bin_score.expand(B, M, 1).to(dev) 120 | bins1 = bin_score.expand(B, 1, N).to(dev) 121 | bin_score = bin_score.expand(B, 1, 1).to(dev) 122 | 123 | couplings = torch.cat([torch.cat([scores, bins0], -1), 124 | torch.cat([bins1, bin_score], -1)], 1) 125 | return couplings 126 | 127 | ############################################ 128 | # Problem definition, partial distributions that assign mass to all samples and (1-mass) to the bin 129 | ############################################ 130 | 131 | def get_gt_distributions(y_mat): 132 | # ground truth distribution 133 | mu = y_mat.sum(dim=1)/y_mat.sum() # distribution of prototypes 134 | nu = y_mat.sum(dim=0)/y_mat.sum() # distribution of features 135 | return mu, nu 136 | 137 | def get_partial_distributions(N, M, mass): 138 | a, b = torch.ones((N,)) / N, torch.ones((M,)) / M # uniform distribution on samples 139 | a = torch.cat([a*mass, a.new_tensor(1-mass)[None]]) 140 | b = torch.cat([b*mass, b.new_tensor(1-mass)[None]]) 141 | return a, b 142 | 143 | def get_partial_distributions_input_marginals(prt, masses, mass_fg): 144 | mass_bkg = 1- mass_fg 145 | a = prt[0].flatten() * masses[0]/prt[0].sum() 146 | if a[a<1e-10].shape[0]>0: 147 | mass_bkg_a = torch.tensor(mass_bkg) 148 | a[a<1e-10] = mass_bkg_a/(a[a<1e-10].shape[0]) 149 | else: 150 | mass_bkg_a = torch.tensor(0) 151 | # print(a[prt[0].flatten()>1e-10].sum()) 152 | a_bin = (1 - (masses[0] + mass_bkg_a)).clone().detach()[None].to(a.device) 153 | a = torch.cat([a, a_bin]) 154 | 155 | b = prt[1].flatten() * masses[1]/prt[1].sum() 156 | if b[b<1e-10].shape[0]>0: 157 | mass_bkg_b = torch.tensor(mass_bkg) 158 | b[b<1e-10] = mass_bkg_b/(b[b<1e-10].shape[0]) 159 | else: 160 | mass_bkg_b = torch.tensor(0) 161 | # print(a[prt[0].flatten()<1e-10].sum()) 162 | b_bin = (1 - (masses[1] + mass_bkg_b)).clone().detach()[None].to(b.device) 163 | b = torch.cat([b, b_bin]) 164 | return a, b 165 | 166 | def get_partial_log_distributions(N, M, mass): 167 | # specify the mass to be assigned 168 | one = torch.tensor(1.) 
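# Illustrative sketch of the "partial" marginals built by
# get_partial_distributions above: each side keeps `mass` of probability
# uniformly on its samples and puts the remaining (1 - mass) on an extra
# dust-bin entry, so both marginals still sum to one. The sizes below are
# placeholders.
import torch

def partial_marginals(N, M, mass=0.9):
    a = torch.full((N,), mass / N)
    b = torch.full((M,), mass / M)
    a = torch.cat([a, torch.tensor([1.0 - mass])])
    b = torch.cat([b, torch.tensor([1.0 - mass])])
    return a, b

a, b = partial_marginals(4, 6)
assert torch.isclose(a.sum(), torch.tensor(1.0)) and torch.isclose(b.sum(), torch.tensor(1.0))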
169 | ms, ns = (N*one), (M*one) 170 | # norm = - (ms + ns).log() 171 | # log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) 172 | # log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) 173 | mass = torch.tensor(mass) 174 | norm_m = - (ms).log() + mass.log() 175 | norm_n = - (ns).log() + mass.log() 176 | log_mu = torch.cat([norm_m.expand(N), (1-mass).log()[None]]) 177 | log_nu = torch.cat([norm_n.expand(M), (1-mass).log()[None]]) 178 | return log_mu, log_nu 179 | 180 | ############################################ 181 | # The Matcher Module with trainable bin score 182 | ############################################ 183 | 184 | def prob_add_bkg(prob_fg, mask0, mask1): 185 | # get the cosine similarity 186 | dev = prob_fg.device 187 | # n,m = prob_fg.shape 188 | n = mask0.shape[0] 189 | m = mask1.shape[0] 190 | # prepare the output, assign the probabilities to the foreground + bin 191 | mask_fg = torch.ones(n+1, m+1).to(dev) 192 | mask_fg[:-1,:-1][~mask0,:] = 0.0 193 | mask_fg[:-1,:-1][:,~mask1] = 0.0 194 | mask_fg[:-1,-1][~mask0] = 0.0 195 | mask_fg[-1,:-1][~mask1] = 0.0 196 | prob = torch.zeros(n+1, m+1).to(dev) #* -torch.inf 197 | prob[mask_fg.bool()] = prob_fg.flatten() 198 | prob[-1,-1] = 0 199 | return prob 200 | 201 | def ft_remove_bkg(ft0, ft1): 202 | th = 1e-8 203 | # get mask the background features 204 | mask0 = ft0.norm(dim=-1) > th 205 | mask1 = ft1.norm(dim=-1) > th 206 | 207 | ft0_fg = ft0[mask0] 208 | ft1_fg = ft1[mask1] 209 | return ft0_fg, ft1_fg, mask0, mask1 210 | 211 | class SoftmaxMatcher(nn.Module): 212 | def __init__(self, sinkhorn_iterations=100, bin_score=0.4, mass=0.9, reg=0.1, reg_kl=0.01): 213 | super().__init__() 214 | # super(SoftmaxMatcher, self).__init__() 215 | self.sinkhorn_iterations = sinkhorn_iterations 216 | bin_score = torch.nn.Parameter(torch.tensor(bin_score))# DINOv2 default value 217 | self.register_parameter('bin_score', bin_score) 218 | self.mass = mass 219 | self.reg = reg 220 | self.reg_kl = reg_kl 221 | 222 | def forward(self, ft0, ft1, y_mat=None, segmentations=None, masses=None): 223 | 224 | # remove bkg 225 | ft0_fg, ft1_fg, mask0, mask1 = ft_remove_bkg(ft0, ft1) 226 | 227 | # normalize the features (worse results!) 228 | ft0_fg = ft0_fg/ft0_fg.norm(dim=-1)[:,None] 229 | ft1_fg = ft1_fg/ft1_fg.norm(dim=-1)[:,None] 230 | 231 | cos_sim = torch.mm(ft0_fg, ft1_fg.t()) # N, M 232 | # Run the optimal transport on foreground features 233 | N, M = cos_sim.shape 234 | device = cos_sim.device 235 | 236 | # no gt y_mat available, we use the cosine similarity matrix 237 | if y_mat is not None: 238 | mu, nu = get_gt_distributions(y_mat) 239 | elif segmentations is not None and masses is not None: 240 | mu, nu = get_partial_distributions_input_marginals(segmentations, masses, self.mass) 241 | else: 242 | mu, nu = get_partial_distributions(N, M, mass=self.mass) 243 | couplings = get_partial_ot_problem(cos_sim[None], bin_score=self.bin_score) 244 | mu, nu = mu.to(device), nu.to(device) 245 | prob_out = log_optimal_transport_kl(mu, nu, couplings, reg=self.reg, reg_kl=self.reg_kl, sinkhorn_iterations=self.sinkhorn_iterations)[0] # diverging if reg_kl<1 246 | # prob_out = log_optimal_transport(mu, nu, couplings, reg=self.reg, sinkhorn_iterations=self.sinkhorn_iterations)[0] 247 | 248 | # add bkg 249 | prob = prob_add_bkg(prob_out, mask0, mask1) 250 | 251 | del prob_out, cos_sim 252 | return prob 253 | 254 | def get_cossim(self, ft0, ft1): 255 | ft0_fg, ft1_fg, mask0, mask1 = ft_remove_bkg(ft0, ft1) 256 | # normalize the features (worse results!) 
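# Illustrative usage sketch for the SoftmaxMatcher defined below: match two
# small sets of flattened foreground features of shape (num_locations,
# feat_dim). The random features and hyper-parameter values are placeholders;
# the returned matrix has one extra dust-bin row and column.
import torch
from src.matcher.ot_matcher import SoftmaxMatcher

matcher = SoftmaxMatcher(sinkhorn_iterations=10, bin_score=0.3, mass=0.9, reg=0.1, reg_kl=10)
ft0 = torch.rand(16, 64)    # e.g. a 4x4 source feature map, flattened
ft1 = torch.rand(25, 64)    # e.g. a 5x5 target feature map, flattened
prob = matcher(ft0, ft1)    # shape (17, 26): soft assignment incl. dust bin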
257 | ft0_fg = ft0_fg/ft0_fg.norm(dim=-1)[:,None] 258 | ft1_fg = ft1_fg/ft1_fg.norm(dim=-1)[:,None] 259 | 260 | cos_sim = torch.mm(ft0_fg, ft1_fg.t()) 261 | cos_sim = F.pad(cos_sim, (0,1,0,1), value=0) 262 | return cos_sim -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/classifier/supervised/nearest_centroid.py: -------------------------------------------------------------------------------- 1 | from src.models.pca import compute_pca 2 | from sklearn.neighbors import NearestCentroid 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Nearest_Centroid_fg_Classifier(nn.Module): 7 | name = 'nearest_centroid_fg' 8 | def __init__(self, n_components_pca, x, y_mat): 9 | ''' 10 | Input: 11 | n_components_pca: int, number of components for PCA 12 | x: torch.tensor, shape (B, feature_dim), feature vectors 13 | y_mat: torch.tensor, shape (B, num_parts), part labels, one-hot encoding, 14 | no background part, as we assume to only receive foreground features 15 | ''' 16 | super().__init__() 17 | feature_dim = x.shape[-1] 18 | x = x.reshape(-1,feature_dim) 19 | x = x/x.norm(dim=-1, keepdim=True) 20 | self.part_components = compute_pca(x, n_components_pca) 21 | x_proj = torch.mm(x, self.part_components.t()) 22 | self.classifier = NearestCentroid(metric="manhattan") 23 | self.num_parts = y_mat.shape[-1] 24 | y = y_mat.argmax(-1) 25 | self.parts = torch.unique(y) 26 | # assert(len(self.parts) == self.num_parts) # check if all parts are present 27 | x_proj = x_proj/x_proj.norm(dim=-1, keepdim=True) 28 | self.classifier.fit(x_proj.cpu().numpy(), y.cpu().numpy()) 29 | self.prototypes_proj = torch.tensor(self.classifier.centroids_, device=self.part_components.device, dtype=self.part_components.dtype) # num_parts, n_components_pca 30 | self.prototypes = torch.mm(self.prototypes_proj, self.part_components) 31 | 32 | def get_prototypes(self): 33 | if len(self.parts) == self.num_parts: 34 | return self.prototypes 35 | else: 36 | prototypes = torch.zeros(self.num_parts, self.prototypes.shape[-1], device=self.prototypes.device, dtype=self.prototypes.dtype) 37 | prototypes[self.parts] = self.prototypes 38 | return prototypes 39 | 40 | def forward(self, x): 41 | ''' 42 | Input: 43 | x: torch.tensor, shape (B, feature_dim), feature vectors 44 | Output: 45 | y_mat: torch.tensor, shape (B, num_parts), part labels, one-hot encoding 46 | probably neg values for background, but not trained for that 47 | ''' 48 | feature_dim = x.shape[-1] 49 | # get the scalar product of the input with the principal components 50 | x = x.reshape(-1,feature_dim) 51 | x_proj = torch.mm(x, self.part_components.t()) 52 | x_proj = x_proj/x_proj.norm(dim=-1, keepdim=True) 53 | prototypes_proj = self.prototypes_proj/self.prototypes_proj.norm(dim=-1, keepdim=True) 54 | y_mat = torch.zeros((x.shape[0],self.num_parts)).to(x.device) # B, num_parts 55 | y_mat_ = torch.mm(x_proj, prototypes_proj.t()) 56 | y_mat[:, self.parts] = y_mat_ 57 | return y_mat 58 | -------------------------------------------------------------------------------- /src/models/classifier/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.dataset.utils import 
get_init_feats_and_labels 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | from src.models.classifier.supervised.nearest_centroid import Nearest_Centroid_fg_Classifier 7 | 8 | def train_classifier(args_seg, dataset_train, model_refine=None): 9 | 10 | if args_seg.model == 'nearest_centroid_fg': 11 | feats, y_mat = get_init_feats_and_labels(dataset_train, args_seg.num_samples, only_fg=True, model_refine=model_refine) 12 | model_seg = Nearest_Centroid_fg_Classifier(args_seg.num_pcaparts, feats, y_mat) 13 | 14 | return model_seg 15 | 16 | def forward_classifier(dataset, idx, model_seg, model_refine=None): 17 | data = dataset[idx] 18 | 19 | # get the part segmentation 20 | imsize = data['imsize'] 21 | if model_seg.name in ['nearest_centroid_fg']: 22 | ft = data['ft'][None].to(device) # B, C, H, W 23 | if model_refine is not None: 24 | ft = model_refine(ft) 25 | ft_interp = F.interpolate(ft, size=imsize.tolist(), mode='bilinear', align_corners=False) 26 | prt = model_seg(ft_interp.permute(0,2,3,1).flatten(0,-2)) # B, C, H, W -> H*W, C 27 | prt = prt.reshape(ft.shape[0], imsize[0], imsize[1], -1).permute(0,3,1,2) # H*W, C -> B, C, H, W 28 | 29 | return prt -------------------------------------------------------------------------------- /src/models/featurizer/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import open_clip 4 | import math 5 | 6 | 7 | def interpolate_pos_encoding(clip_model, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: 8 | """ 9 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher 10 | resolution images. 11 | Source: 12 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 13 | """ 14 | 15 | num_patches = embeddings.shape[1] - 1 16 | pos_embedding = clip_model.positional_embedding.unsqueeze(0) 17 | num_positions = pos_embedding.shape[1] - 1 18 | if num_patches == num_positions and height == width: 19 | return clip_model.positional_embedding 20 | class_pos_embed = pos_embedding[:, 0] 21 | patch_pos_embed = pos_embedding[:, 1:] 22 | dim = embeddings.shape[-1] 23 | h0 = height // clip_model.patch_size[0] 24 | w0 = width // clip_model.patch_size[1] 25 | # we add a small number to avoid floating point error in the interpolation 26 | # see discussion at https://github.com/facebookresearch/dino/issues/8 27 | h0, w0 = h0 + 0.1, w0 + 0.1 28 | patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) 29 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) 30 | patch_pos_embed = nn.functional.interpolate( 31 | patch_pos_embed, 32 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), 33 | mode="bicubic", 34 | align_corners=False, 35 | ) 36 | assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] 37 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 38 | output = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 39 | 40 | return output 41 | 42 | def get_name(): 43 | return 'open_clip' 44 | 45 | class CLIPFeaturizer: 46 | name = 'open_clip' 47 | def __init__(self): 48 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k') 49 | visual_model = clip_model.visual 50 | visual_model.output_tokens = True 51 | 
self.clip_model = visual_model.eval().cuda() 52 | 53 | 54 | def get_models(self): 55 | return [self.clip_model] 56 | 57 | @torch.no_grad() 58 | def forward(self, 59 | x, # single image, [1,c,h,w] 60 | block_index, **kwargs): 61 | batch_size = 1 62 | clip_model = self.clip_model 63 | if clip_model.input_patchnorm: 64 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 65 | x = x.reshape(x.shape[0], x.shape[1], clip_model.grid_size[0], clip_model.patch_size[0], clip_model.grid_size[1], clip_model.patch_size[1]) 66 | x = x.permute(0, 2, 4, 1, 3, 5) 67 | x = x.reshape(x.shape[0], clip_model.grid_size[0] * clip_model.grid_size[1], -1) 68 | x = clip_model.patchnorm_pre_ln(x) 69 | x = clip_model.conv1(x) 70 | else: 71 | x = clip_model.conv1(x) # shape = [*, width, grid, grid] 72 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 73 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 74 | # class embeddings and positional embeddings 75 | x = torch.cat( 76 | [clip_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 77 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 78 | if(x.shape[1] > clip_model.positional_embedding.shape[0]): 79 | dim = int(math.sqrt(x.shape[1]) * clip_model.patch_size[0]) 80 | x = x + interpolate_pos_encoding(clip_model, x, dim, dim).to(x.dtype) 81 | else: 82 | x = x + clip_model.positional_embedding.to(x.dtype) 83 | 84 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 85 | x = clip_model.patch_dropout(x) 86 | x = clip_model.ln_pre(x) 87 | x = x.permute(1, 0, 2) # NLD -> LND 88 | 89 | num_channel = x.size(2) 90 | ft_size = int((x.shape[0]-1) ** 0.5) 91 | 92 | for i, r in enumerate(clip_model.transformer.resblocks): 93 | x = r(x) 94 | 95 | if i == block_index: 96 | tokens = x.permute(1, 0, 2) # LND -> NLD 97 | tokens = tokens[:, 1:] 98 | tokens = tokens.transpose(1, 2).contiguous().view(batch_size, num_channel, ft_size, ft_size) # NCHW 99 | 100 | return tokens -------------------------------------------------------------------------------- /src/models/featurizer/dift_adm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | from torchvision import transforms 5 | main_path = Path(__file__).resolve().parent.parent.parent.parent 6 | print(f'main path: {main_path}') 7 | 8 | import sys 9 | sys.path.append(os.path.join(main_path, 'guided-diffusion')) 10 | from guided_diffusion.script_util import create_model_and_diffusion 11 | from guided_diffusion.nn import timestep_embedding 12 | 13 | def get_name(): 14 | return 'dift_adm' 15 | 16 | class ADMFeaturizer: 17 | name = 'dift_adm' 18 | def __init__(self): 19 | model, diffusion = create_model_and_diffusion( 20 | image_size=256, 21 | class_cond=False, 22 | learn_sigma=True, 23 | num_channels=256, 24 | num_res_blocks=2, 25 | channel_mult="", 26 | num_heads=4, 27 | num_head_channels=64, 28 | num_heads_upsample=-1, 29 | attention_resolutions="32,16,8", 30 | dropout=0.0, 31 | diffusion_steps=1000, 32 | noise_schedule='linear', 33 | timestep_respacing='', 34 | use_kl=False, 35 | predict_xstart=False, 36 | rescale_timesteps=False, 37 | rescale_learned_sigmas=False, 38 | use_checkpoint=False, 39 | use_scale_shift_norm=True, 40 | resblock_updown=True, 41 | use_fp16=False, 42 | use_new_attention_order=False, 43 | ) 44 | model_path = os.path.join(main_path, 
'guided-diffusion/models/256x256_diffusion_uncond.pt') 45 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 46 | self.model = model.eval().cuda() 47 | self.diffusion = diffusion 48 | 49 | self.adm_transforms = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 52 | ]) 53 | 54 | @torch.no_grad() 55 | def forward(self, img_tensor, 56 | t=101, 57 | up_ft_index=4, 58 | ensemble_size=8): 59 | model = self.model 60 | diffusion = self.diffusion 61 | 62 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 63 | t = torch.ones((img_tensor.shape[0],), device='cuda', dtype=torch.int64) * t 64 | x_t = diffusion.q_sample(img_tensor, t, noise=None) 65 | 66 | # get layer-wise features 67 | hs = [] 68 | emb = model.time_embed(timestep_embedding(t, model.model_channels)) 69 | h = x_t.type(model.dtype) 70 | for module in model.input_blocks: 71 | h = module(h, emb) 72 | hs.append(h) 73 | h = model.middle_block(h, emb) 74 | for i, module in enumerate(model.output_blocks): 75 | h = torch.cat([h, hs.pop()], dim=1) 76 | h = module(h, emb) 77 | 78 | if i == up_ft_index: 79 | ft = h.mean(0, keepdim=True).detach() 80 | return ft 81 | 82 | 83 | class ADMFeaturizer4Eval(ADMFeaturizer): 84 | 85 | @torch.no_grad() 86 | def forward(self, img, 87 | img_size=[512, 512], 88 | t=101, 89 | up_ft_index=4, 90 | ensemble_size=8, 91 | **kwargs): 92 | 93 | img_tensor = self.adm_transforms(img.resize(img_size)) 94 | ft = super().forward(img_tensor, 95 | t=t, 96 | up_ft_index=up_ft_index, 97 | ensemble_size=ensemble_size) 98 | del img_tensor 99 | torch.cuda.empty_cache() 100 | return ft 101 | 102 | def get_models(self): 103 | return [self.model] -------------------------------------------------------------------------------- /src/models/featurizer/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | from torchvision import transforms as T 5 | def prepare_ViT_images(vit_patch_size, img, img_size): 6 | 7 | transform = T.Compose([ 8 | T.Resize(img_size), 9 | T.ToTensor(), 10 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 11 | T.ConvertImageDtype(torch.float) 12 | ]) 13 | img = transform(img) 14 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size 15 | img = img[..., :w, :h].float().to(device) 16 | # check if shape of image has 3 dimensions 17 | if len(img.shape) == 3: 18 | img = img[None] 19 | return img # B, C, W, H 20 | 21 | def get_name(dino_id='dino_vitb8', log_bin=True, img_size='224', up_ft_index='4', **kwargs): 22 | name = 'dino' 23 | if log_bin: 24 | name = name + '_logbin' 25 | return name + f'_%d_upft%d' % (img_size[0], up_ft_index) 26 | 27 | class DINOFeaturizer: 28 | name = 'dino' 29 | def __init__(self, dino_id='dino_vitb8', log_bin=True, img_size='224', up_ft_index='4', **kwargs): 30 | model = torch.hub.load('facebookresearch/dino:main', dino_id) 31 | self.model = model.eval().to(device) 32 | self._use_log_bin = log_bin 33 | self.name = get_name(dino_id, log_bin, img_size, up_ft_index, **kwargs) 34 | self.vit_patch_size = 8 35 | 36 | def get_models(self): 37 | return [self.model] 38 | 39 | def prepare_tokens(self, img_tensor, up_ft_index): 40 | B, C, W, H = img_tensor.shape 41 | # we return the output tokens from the `n` last blocks as a list 42 | out = self.model.get_intermediate_layers(img_tensor, 
n=up_ft_index) 43 | out = out[0] # take the output of the the n-th last block 44 | out = out[:, 1:, :] # B, w0*h0, D 45 | D = out.shape[-1] 46 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0) 47 | return out 48 | 49 | @torch.no_grad() 50 | def forward(self, img, up_ft_index=3, **kwargs): 51 | # convert pil to tensor 52 | img_size = kwargs.get('img_size') 53 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H] 54 | self.w0 = img_tensor.shape[-2] // self.vit_patch_size 55 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size 56 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0 57 | if self._use_log_bin: 58 | out = self._log_bin(out.permute(0,2,3,1)) 59 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D' 60 | out = out.permute(0,3,1,2) # B, D', w0, h0 61 | return out 62 | 63 | def get_w0_h0(self): 64 | return self.w0, self.h0 65 | 66 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: 67 | """ 68 | create a log-binned descriptor. 69 | :param x: tensor of features. Has shape Bxwxhxd. 70 | :param hierarchy: how many bin hierarchies to use. 71 | """ 72 | B, w0, h0, d = x.shape 73 | num_bins = 1 + self.vit_patch_size * hierarchy 74 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0 75 | sub_desc_dim = bin_x.shape[1] 76 | 77 | avg_pools = [] 78 | # compute bins of all sizes for all spatial locations. 79 | for k in range(0, hierarchy): 80 | # avg pooling with kernel 3**kx3**k 81 | win_size = 3 ** k 82 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) 83 | avg_pools.append(avg_pool(bin_x)) 84 | 85 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device) 86 | for y in range(w0): 87 | for x in range(h0): 88 | part_idx: int = 0 89 | # fill all bins for a spatial location (y, x) 90 | for k in range(0, hierarchy): 91 | kernel_size = 3 ** k 92 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): 93 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): 94 | if i == y and j == x and k != 0: 95 | continue 96 | if 0 <= i < w0 and 0 <= j < h0: 97 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 98 | :, :, i, j] 99 | else: # handle padding in a more delicate way than zero padding 100 | temp_i = max(0, min(i, w0 - 1)) 101 | temp_j = max(0, min(j, h0 - 1)) 102 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 103 | :, :, temp_i, 104 | temp_j] 105 | part_idx += 1 106 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) 107 | # Bx1x(t-1)x(dxh) 108 | return bin_x -------------------------------------------------------------------------------- /src/models/featurizer/dinov2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | from torchvision import transforms as T 5 | 6 | import torch 7 | 8 | def prepare_ViT_images(vit_patch_size, img, img_size): 9 | transform = T.Compose([ 10 | T.Resize(img_size), 11 | T.ToTensor(), 12 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 13 | T.ConvertImageDtype(torch.float) 14 | ]) 15 | img = transform(img) 16 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size 17 | img = img[..., :w, :h].float().to(device) 18 | # check if shape of image has 3 dimensions 19 | if len(img.shape) == 3: 20 | img = 
img[None] 21 | return img # B, C, W, H 22 | 23 | def get_name(log_bin=True, img_size='518', up_ft_index='1', **kwargs): 24 | name = 'dinov2' 25 | if log_bin: 26 | name = 'dinov2_logbin' 27 | 28 | model_size = kwargs.get('model_size', "dinov2_vits14") 29 | if model_size != "dinov2_vits14": 30 | # remove dinov2_vit from model_size 31 | model_size = model_size.replace("dinov2_vit", "") 32 | name += f'_{model_size}' 33 | 34 | return name + f'_%d_upft%d' % (img_size[0], up_ft_index) 35 | 36 | class DINOv2Featurizer: 37 | name = 'dinov2' 38 | def __init__(self, log_bin=True, img_size='518', up_ft_index='1', **kwargs): 39 | self.model_size = kwargs.get('model_size', "dinov2_vits14") 40 | model = torch.hub.load("facebookresearch/dinov2", self.model_size) 41 | self.model = model.eval().to(device) 42 | self._use_log_bin = log_bin 43 | self.name = get_name(log_bin, img_size, up_ft_index, **kwargs) 44 | self.vit_patch_size = 14 45 | 46 | def get_models(self): 47 | return [self.model] 48 | 49 | def prepare_tokens(self, img_tensor, up_ft_index): 50 | B, C, W, H = img_tensor.shape 51 | # out_dict = self.model(pixel_values=img_tensor, return_dict=True, output_hidden_states=True) 52 | #https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/dinov2/modeling_dinov2.py#L661 53 | # out = out_dict.hidden_states[-up_ft_index] # B, w0*h0, D 54 | out = self.model.get_intermediate_layers(img_tensor, n=up_ft_index) 55 | out = out[0] # take the output of the the n-th last block 56 | out = out[:, :, :] # B, w0*h0, D 57 | D = out.shape[-1] 58 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0) 59 | return out 60 | 61 | @torch.no_grad() 62 | def forward(self, img, up_ft_index=3, **kwargs): 63 | # convert pil to tensor 64 | img_size = kwargs.get('img_size') 65 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H] 66 | self.w0 = img_tensor.shape[-2] // self.vit_patch_size 67 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size 68 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0 69 | if self._use_log_bin: 70 | out = self._log_bin(out.permute(0,2,3,1)) 71 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D' 72 | out = out.permute(0,3,1,2) # B, D', w0, h0 73 | return out 74 | 75 | def get_w0_h0(self): 76 | return self.w0, self.h0 77 | 78 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: 79 | """ 80 | create a log-binned descriptor. 81 | :param x: tensor of features. Has shape Bxwxhxd. 82 | :param hierarchy: how many bin hierarchies to use. 83 | """ 84 | B, w0, h0, d = x.shape 85 | num_bins = 1 + self.vit_patch_size * hierarchy 86 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0 87 | sub_desc_dim = bin_x.shape[1] 88 | 89 | avg_pools = [] 90 | # compute bins of all sizes for all spatial locations. 
91 | for k in range(0, hierarchy): 92 | # avg pooling with kernel 3**kx3**k 93 | win_size = 3 ** k 94 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) 95 | avg_pools.append(avg_pool(bin_x)) 96 | 97 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device) 98 | for y in range(w0): 99 | for x in range(h0): 100 | part_idx: int = 0 101 | # fill all bins for a spatial location (y, x) 102 | for k in range(0, hierarchy): 103 | kernel_size = 3 ** k 104 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): 105 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): 106 | if i == y and j == x and k != 0: 107 | continue 108 | if 0 <= i < w0 and 0 <= j < h0: 109 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 110 | :, :, i, j] 111 | else: # handle padding in a more delicate way than zero padding 112 | temp_i = max(0, min(i, w0 - 1)) 113 | temp_j = max(0, min(j, h0 - 1)) 114 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 115 | :, :, temp_i, 116 | temp_j] 117 | part_idx += 1 118 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) 119 | # Bx1x(t-1)x(dxh) 120 | return bin_x -------------------------------------------------------------------------------- /src/models/featurizer/dinov2_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | from torchvision import transforms as T 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | # LoRA for DINOv2 11 | # This code is based on the original DINOv2 implementation and the LoRA implementation from https://github.com/RobvanGastel/dinov2-finetune/blob/main/dino_finetune/model/dino_v2.py 12 | 13 | class LoRA(nn.Module): 14 | """Low-Rank Adaptation applied to the Query (Q) and Value (V) projections of the attention qkv matrix""" 15 | 16 | def __init__( 17 | self, 18 | qkv: nn.Module, 19 | linear_a_q: nn.Module, 20 | linear_b_q: nn.Module, 21 | linear_a_v: nn.Module, 22 | linear_b_v: nn.Module, 23 | ): 24 | super().__init__() 25 | self.qkv = qkv 26 | self.linear_a_q = linear_a_q 27 | self.linear_b_q = linear_b_q 28 | self.linear_a_v = linear_a_v 29 | self.linear_b_v = linear_b_v 30 | self.dim = getattr(qkv, 'in_features', 768) # Default fallback 31 | self.w_identity = torch.eye(self.dim) 32 | 33 | def forward(self, x) -> torch.Tensor: 34 | # Compute the original qkv 35 | qkv = self.qkv(x) # Shape: (B, N, 3 * org_C) 36 | 37 | # Compute the new q and v components 38 | new_q = self.linear_b_q(self.linear_a_q(x)) 39 | new_v = self.linear_b_v(self.linear_a_v(x)) 40 | 41 | # Add new q and v components to the original qkv tensor 42 | qkv[:, :, : self.dim] += new_q 43 | qkv[:, :, -self.dim :] += new_v 44 | 45 | return qkv 46 | 47 | 48 | def prepare_ViT_images(vit_patch_size, img, img_size): 49 | transform = T.Compose([ 50 | T.Resize(img_size), 51 | T.ToTensor(), 52 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 53 | T.ConvertImageDtype(torch.float) 54 | ]) 55 | img = transform(img) 56 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size 57 | img = img[..., :w, :h].float().to(device) 58 | # check if shape of image has 3 dimensions 59 | if len(img.shape) == 3: 60 | img = img[None] 61 | return img # B, C, W, H 62 | 63 | def get_name(log_bin=True, img_size='518', up_ft_index='1',
**kwargs): 64 | name = 'dinov2lora' 65 | if log_bin: 66 | name = 'dinov2lora_logbin' 67 | # Handle both string and list inputs for img_size 68 | if isinstance(img_size, list): 69 | img_size_val = img_size[0] 70 | else: 71 | img_size_val = img_size 72 | # Handle up_ft_index conversion 73 | if hasattr(up_ft_index, '__iter__') and not isinstance(up_ft_index, str): 74 | up_ft_index_val = up_ft_index[0] if len(up_ft_index) > 0 else 1 75 | else: 76 | up_ft_index_val = up_ft_index 77 | # Convert to int, handling ListConfig objects 78 | try: 79 | img_size_int = int(img_size_val) 80 | up_ft_index_int = int(up_ft_index_val) 81 | except (ValueError, TypeError): 82 | # Fallback values if conversion fails 83 | img_size_int = 518 84 | up_ft_index_int = 1 85 | return name + f'_%d_upft%d' % (img_size_int, up_ft_index_int) 86 | 87 | class DINOv2LoRAFeaturizer(nn.Module): 88 | name = 'dinov2lora' 89 | def __init__(self, log_bin=True, img_size='518', up_ft_index='1', **kwargs): 90 | super().__init__() 91 | self.model_size = kwargs.get('model_size', "dinov2_vits14") 92 | self.model = torch.hub.load("facebookresearch/dinov2", self.model_size) 93 | self.model = self.model.to(device) # type: ignore 94 | for param in self.model.parameters(): 95 | param.requires_grad = False 96 | self._use_log_bin = log_bin 97 | self.name = get_name(log_bin, img_size, up_ft_index, **kwargs) 98 | self.vit_patch_size = 14 99 | 100 | self.lora_layers = list(range(len(self.model.blocks))) 101 | self.w_a = [] 102 | self.w_b = [] 103 | self.r = kwargs.get('lora_rank', 10) 104 | 105 | for i, block in enumerate(self.model.blocks): 106 | if i not in self.lora_layers: 107 | continue 108 | w_qkv_linear = block.attn.qkv 109 | dim = w_qkv_linear.in_features 110 | 111 | w_a_linear_q, w_b_linear_q = self._create_lora_layer(dim, self.r) 112 | w_a_linear_v, w_b_linear_v = self._create_lora_layer(dim, self.r) 113 | 114 | self.w_a.extend([w_a_linear_q, w_a_linear_v]) 115 | self.w_b.extend([w_b_linear_q, w_b_linear_v]) 116 | 117 | block.attn.qkv = LoRA( 118 | w_qkv_linear, 119 | w_a_linear_q, 120 | w_b_linear_q, 121 | w_a_linear_v, 122 | w_b_linear_v, 123 | ) 124 | self._reset_lora_parameters() 125 | 126 | def _create_lora_layer(self, dim: int, r: int): 127 | w_a = nn.Linear(dim, r, bias=False).to(device) 128 | w_b = nn.Linear(r, dim, bias=False).to(device) 129 | return w_a, w_b 130 | 131 | def _reset_lora_parameters(self) -> None: 132 | for w_a in self.w_a: 133 | nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5)) 134 | for w_b in self.w_b: 135 | nn.init.zeros_(w_b.weight) 136 | 137 | def get_models(self): 138 | return [self] 139 | 140 | def prepare_tokens(self, img_tensor, up_ft_index): 141 | B, C, W, H = img_tensor.shape 142 | # out_dict = self.model(pixel_values=img_tensor, return_dict=True, output_hidden_states=True) 143 | #https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/dinov2/modeling_dinov2.py#L661 144 | # out = out_dict.hidden_states[-up_ft_index] # B, w0*h0, D 145 | out = self.model.get_intermediate_layers(img_tensor, n=up_ft_index) # type: ignore 146 | out = out[0] # take the output of the the n-th last block 147 | out = out[:, :, :] # B, w0*h0, D 148 | D = out.shape[-1] 149 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0) 150 | return out 151 | 152 | def forward(self, img, up_ft_index=3, **kwargs): 153 | # convert pil to tensor 154 | img_size = kwargs.get('img_size') 155 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H] 156 | self.w0 = img_tensor.shape[-2] // 
self.vit_patch_size 157 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size 158 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0 159 | if self._use_log_bin: 160 | out = self._log_bin(out.permute(0,2,3,1)) 161 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D' 162 | out = out.permute(0,3,1,2) # B, D', w0, h0 163 | return out 164 | 165 | def get_w0_h0(self): 166 | return self.w0, self.h0 167 | 168 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: 169 | """ 170 | create a log-binned descriptor. 171 | :param x: tensor of features. Has shape Bxwxhxd. 172 | :param hierarchy: how many bin hierarchies to use. 173 | """ 174 | B, w0, h0, d = x.shape 175 | num_bins = 1 + self.vit_patch_size * hierarchy 176 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0 177 | sub_desc_dim = bin_x.shape[1] 178 | 179 | avg_pools = [] 180 | # compute bins of all sizes for all spatial locations. 181 | for k in range(0, hierarchy): 182 | # avg pooling with kernel 3**kx3**k 183 | win_size = 3 ** k 184 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) 185 | avg_pools.append(avg_pool(bin_x)) 186 | 187 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device) 188 | for y in range(w0): 189 | for x in range(h0): 190 | part_idx = 0 191 | # fill all bins for a spatial location (y, x) 192 | for k in range(0, hierarchy): 193 | kernel_size = 3 ** k 194 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): 195 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): 196 | if i == y and j == x and k != 0: 197 | continue 198 | if 0 <= i < w0 and 0 <= j < h0: 199 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 200 | :, :, i, j] 201 | else: # handle padding in a more delicate way than zero padding 202 | temp_i = max(0, min(i, w0 - 1)) 203 | temp_j = max(0, min(j, h0 - 1)) 204 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 205 | :, :, temp_i, 206 | temp_j] 207 | part_idx += 1 208 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) 209 | # Bx1x(t-1)x(dxh) 210 | return bin_x 211 | 212 | def save_parameters(self, filename: str) -> None: 213 | """Save the LoRA weights and decoder weights to a .pt file 214 | 215 | Args: 216 | filename (str): Filename of the weights 217 | """ 218 | w_a, w_b = {}, {} 219 | w_a = {f"w_a_{i:03d}": self.w_a[i].weight.to('cpu') for i in range(len(self.w_a))} 220 | w_b = {f"w_b_{i:03d}": self.w_b[i].weight.to('cpu') for i in range(len(self.w_a))} 221 | 222 | torch.save({**w_a, **w_b}, filename) 223 | 224 | def load_parameters(self, filename: str) -> None: 225 | """Load the LoRA and decoder weights from a file 226 | 227 | Args: 228 | filename (str): File name of the weights 229 | """ 230 | state_dict = torch.load(filename) 231 | 232 | # Load the LoRA parameters 233 | for i, w_A_linear in enumerate(self.w_a): 234 | saved_key = f"w_a_{i:03d}" 235 | saved_tensor = state_dict[saved_key].to(device) 236 | w_A_linear.weight = nn.Parameter(saved_tensor) 237 | 238 | for i, w_B_linear in enumerate(self.w_b): 239 | saved_key = f"w_b_{i:03d}" 240 | saved_tensor = state_dict[saved_key].to(device) 241 | w_B_linear.weight = nn.Parameter(saved_tensor) -------------------------------------------------------------------------------- /src/models/featurizer/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 
from src.models.featurizer.dinov2_lora import DINOv2LoRAFeaturizer 4 | from pathlib import Path 5 | import wandb 6 | import os 7 | 8 | def get_featurizer(featurizer_args): 9 | load_pretrained = 'init' in featurizer_args 10 | if not load_pretrained: 11 | if featurizer_args.model == 'dift_sd': 12 | from src.models.featurizer.dift_sd import SDFeaturizer4Eval 13 | featurizer = SDFeaturizer4Eval(cat_list=featurizer_args.all_cats) 14 | elif featurizer_args.model == 'dift_adm': 15 | from src.models.featurizer.dift_adm import ADMFeaturizer4Eval 16 | featurizer = ADMFeaturizer4Eval() 17 | elif featurizer_args.model == 'open_clip': 18 | from src.models.featurizer.clip import CLIPFeaturizer 19 | featurizer = CLIPFeaturizer() 20 | elif featurizer_args.model == 'dino': 21 | from src.models.featurizer.dino import DINOFeaturizer 22 | featurizer = DINOFeaturizer(**featurizer_args) 23 | elif featurizer_args.model == 'dinov2': 24 | from src.models.featurizer.dinov2 import DINOv2Featurizer 25 | featurizer = DINOv2Featurizer(**featurizer_args) 26 | elif featurizer_args.model == 'dinov2lora': 27 | from src.models.featurizer.dinov2_lora import DINOv2LoRAFeaturizer 28 | featurizer = DINOv2LoRAFeaturizer(**featurizer_args) 29 | elif featurizer_args.model == 'sd15ema_dinov2': 30 | from src.models.featurizer.sd15ema_dinov2 import sd15ema_dinov2_Featurizer 31 | featurizer = sd15ema_dinov2_Featurizer(**featurizer_args) 32 | else: 33 | raise ValueError('featurizer model not supported') 34 | else: 35 | featurizer = load_checkpoint(featurizer_args) 36 | return featurizer 37 | 38 | def load_checkpoint(args): 39 | if args.model == 'dinov2lora': 40 | model = DINOv2LoRAFeaturizer(**args) 41 | model_out_path = Path(args.model_out_path).joinpath(args.init.id) 42 | name = "last_lora_weights.pth" if 'eval_last' in args.init else "best_lora_weights.pth" 43 | model_out_path_new = model_out_path.joinpath(name) 44 | model.load_parameters(str(model_out_path_new)) 45 | else: 46 | raise ValueError 47 | 48 | model.id = args.init.id 49 | model.name = args.init.id 50 | return model 51 | 52 | def load_checkpoint_old(args): 53 | model_out_path = Path(args.model_out_path).joinpath(args.init.id) 54 | name = "last.pth" if 'eval_last' in args.init else "best.pth" 55 | model_out_path_new = model_out_path.joinpath(name) 56 | if args.model == 'dinov2lora': 57 | featurizer = DINOv2LoRAFeaturizer(**args) 58 | try: 59 | featurizer.load_state_dict(torch.load(model_out_path_new)) 60 | except: 61 | featurizer = torch.load(model_out_path_new) 62 | featurizer.id = args.init.id 63 | featurizer.name = args.init.id 64 | return featurizer 65 | 66 | def save_checkpoint(model, args, name = "lora_weights"): 67 | id = wandb.run.id if wandb.run else model.id 68 | model_out_path = Path(args.model_out_path).joinpath(id) 69 | os.makedirs(model_out_path, exist_ok=True) 70 | model_out_path_new = model_out_path.joinpath(name+".pth") 71 | # save state dict 72 | model.save_parameters(model_out_path_new) 73 | 74 | def get_featurizer_name(featurizer_args): 75 | if featurizer_args.model == 'dift_sd': 76 | from src.models.featurizer.dift_sd import get_name 77 | name = get_name(cat_list=featurizer_args.all_cats) 78 | elif featurizer_args.model == 'dift_adm': 79 | from src.models.featurizer.dift_adm import get_name 80 | name = get_name() 81 | elif featurizer_args.model == 'open_clip': 82 | from src.models.featurizer.clip import get_name 83 | name = get_name() 84 | elif featurizer_args.model == 'dino': 85 | from src.models.featurizer.dino import get_name 86 | name = 
get_name(**featurizer_args) 87 | elif featurizer_args.model == 'dinov2': 88 | from src.models.featurizer.dinov2 import get_name 89 | name = get_name(**featurizer_args) 90 | elif featurizer_args.model == 'dinov2lora': 91 | from src.models.featurizer.dinov2_lora import get_name 92 | name = get_name(**featurizer_args) 93 | elif featurizer_args.model == 'sd15ema_dinov2': 94 | from src.models.featurizer.sd15ema_dinov2 import get_name 95 | name = get_name(**featurizer_args) 96 | else: 97 | raise ValueError('featurizer model not supported') 98 | return name -------------------------------------------------------------------------------- /src/models/featurizer_refine/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .feat_refine_geosc import AggregationNetwork 3 | from src.dataset.cub_200 import CUBDatasetBordersCut 4 | import torch 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | from pathlib import Path 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 9 | import wandb 10 | import os 11 | def get_init_dataset(args): 12 | dataset_pca = CUBDatasetBordersCut(args.init.dataset, split='train') 13 | from src.models.featurizer.utils import get_featurizer 14 | if args.init.featurizer.model == 'dift_sd': 15 | args.init.featurizer.all_cats = dataset_pca.all_cats 16 | featurizer = get_featurizer(args.init.featurizer) 17 | dataset_pca.init_kps_cat(args.init.dataset.cat) 18 | dataset_pca.featurizer = featurizer 19 | dataset_pca.featurizer_kwargs = args.init.featurizer 20 | return dataset_pca 21 | 22 | def get_model_refine(args): 23 | load_pretrained = 'init' in args 24 | if not load_pretrained: 25 | if args.model == 'GeoSc': 26 | model_refine = AggregationNetwork(feature_dims=args.feature_dims, projection_dim=args.projection_dim, device='cuda', feat_map_dropout=args.feat_map_dropout) 27 | model_refine.id = wandb.run.id if wandb.run!=None else "" 28 | else: 29 | model_refine = load_checkpoint(args) 30 | return model_refine 31 | 32 | def load_checkpoint(args): 33 | model_out_path = Path(args.model_out_path).joinpath(args.init.id) 34 | name = "last.pth" if 'eval_last' in args.init else "best.pth" 35 | name = "{}.pth".format(args.init.epoch) if args.model == 'GeoSc' else name 36 | model_out_path_new = model_out_path.joinpath(name) 37 | if args.model == 'GeoSc': 38 | # check if the model exists 39 | if not model_out_path_new.exists(): 40 | # wget the model 41 | os.mkdir(model_out_path) 42 | os.system(f"wget -O {model_out_path_new} {args.init.url}") 43 | model_refine = AggregationNetwork(feature_dims=args.feature_dims, projection_dim=args.projection_dim, device='cuda') 44 | try: 45 | model_refine.load_state_dict(torch.load(model_out_path_new)) 46 | except: 47 | model_refine = torch.load(model_out_path_new) 48 | model_refine.id = args.init.id 49 | return model_refine 50 | 51 | def save_checkpoint(model_refine, args, name): 52 | id = wandb.run.id if wandb.run else "test" 53 | model_out_path = Path(args.model_out_path).joinpath(id) 54 | os.makedirs(model_out_path, exist_ok=True) 55 | model_out_path_new = model_out_path.joinpath(name+".pth") 56 | # save state dict 57 | torch.save(model_refine.state_dict(), model_out_path_new) 58 | -------------------------------------------------------------------------------- /src/models/featurizer_refine/feat_refine_geosc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from .resnet 
import ResNet, BottleneckBlock 5 | import torch.nn.functional as F 6 | 7 | class DummyAggregationNetwork(nn.Module): # for testing, return the input 8 | def __init__(self): 9 | super(DummyAggregationNetwork, self).__init__() 10 | # dummy paprameter 11 | self.dummy = nn.Parameter(torch.ones([])) 12 | def forward(self, batch, pose=None): 13 | return batch * self.dummy 14 | 15 | class AggregationNetwork(nn.Module): 16 | """ 17 | Module for aggregating feature maps across time and space. 18 | Design inspired by the Feature Extractor from ODISE (Xu et. al., CVPR 2023). 19 | https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py 20 | """ 21 | def __init__( 22 | self, 23 | device, 24 | feature_dims=[640,1280,1280,768], 25 | projection_dim=384, 26 | num_norm_groups=32, 27 | save_timestep=[1], 28 | kernel_size = [1,3,1], 29 | contrastive_temp = 10, 30 | feat_map_dropout=0.0, 31 | ): 32 | super().__init__() 33 | self.skip_connection = True 34 | self.feat_map_dropout = feat_map_dropout 35 | self.azimuth_embedding = None 36 | self.pos_embedding = None 37 | self.bottleneck_layers = nn.ModuleList() 38 | self.feature_dims = feature_dims 39 | # For CLIP symmetric cross entropy loss during training 40 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 41 | self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp)) 42 | self.device = device 43 | self.save_timestep = save_timestep 44 | 45 | self.mixing_weights_names = [] 46 | for l, feature_dim in enumerate(self.feature_dims): 47 | bottleneck_layer = nn.Sequential( 48 | *ResNet.make_stage( 49 | BottleneckBlock, 50 | num_blocks=1, 51 | in_channels=feature_dim, 52 | bottleneck_channels=projection_dim // 4, 53 | out_channels=projection_dim, 54 | norm="GN", 55 | num_norm_groups=num_norm_groups, 56 | kernel_size=kernel_size 57 | ) 58 | ) 59 | self.bottleneck_layers.append(bottleneck_layer) 60 | for t in save_timestep: 61 | # 1-index the layer name following prior work 62 | self.mixing_weights_names.append(f"timestep-{save_timestep}_layer-{l+1}") 63 | self.last_layer = None 64 | self.bottleneck_layers = self.bottleneck_layers.to(device) 65 | mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep)) 66 | self.mixing_weights = nn.Parameter(mixing_weights.to(device)) 67 | # count number of parameters 68 | num_params = 0 69 | for param in self.parameters(): 70 | num_params += param.numel() 71 | print(f"AggregationNetwork has {num_params} parameters.") 72 | 73 | def load_pretrained_weights(self, pretrained_dict): 74 | custom_dict = self.state_dict() 75 | 76 | # Handle size mismatch 77 | if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape: 78 | # Keep the first four weights from the pretrained model, and randomly initialize the fifth weight 79 | custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4] 80 | custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4]) 81 | else: 82 | custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4] 83 | 84 | # Load the weights that do match 85 | matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'} 86 | custom_dict.update(matching_keys) 87 | 88 | # Now load the updated state_dict 89 | self.load_state_dict(custom_dict, strict=False) 90 | 91 | def forward(self, batch, pose=None): 92 | """ 93 | 
Assumes batch is shape (B, C, H, W) where C is the concatentation of all layer features. 94 | """ 95 | if self.feat_map_dropout > 0 and self.training: 96 | batch = F.dropout(batch, p=self.feat_map_dropout) 97 | 98 | output_feature = None 99 | start = 0 100 | mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0) 101 | if self.pos_embedding is not None: #position embedding 102 | batch = torch.cat((batch, self.pos_embedding), dim=1) 103 | for i in range(len(mixing_weights)): 104 | # Share bottleneck layers across timesteps 105 | bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)] 106 | # Chunk the batch according the layer 107 | # Account for looping if there are multiple timesteps 108 | end = start + self.feature_dims[i % len(self.feature_dims)] 109 | feats = batch[:, start:end, :, :] 110 | start = end 111 | # Downsample the number of channels and weight the layer 112 | bottlenecked_feature = bottleneck_layer(feats) 113 | bottlenecked_feature = mixing_weights[i] * bottlenecked_feature 114 | if output_feature is None: 115 | output_feature = bottlenecked_feature 116 | else: 117 | output_feature += bottlenecked_feature 118 | 119 | if self.last_layer is not None: 120 | 121 | output_feature_after = self.last_layer(output_feature) 122 | if self.skip_connection: 123 | # skip connection 124 | output_feature = output_feature + output_feature_after 125 | 126 | 127 | norms_ft_new = torch.linalg.norm(output_feature, dim=1, keepdim=True) 128 | output_feature = output_feature / (norms_ft_new + 1e-8) 129 | return output_feature 130 | 131 | 132 | def conv1x1(in_planes, out_planes, stride=1): 133 | """1x1 convolution without padding""" 134 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 135 | 136 | 137 | def conv3x3(in_planes, out_planes, stride=1): 138 | """3x3 convolution with padding""" 139 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 140 | 141 | 142 | class BasicBlock(nn.Module): 143 | def __init__(self, in_planes, planes, stride=1): 144 | super().__init__() 145 | self.conv1 = conv3x3(in_planes, planes, stride) 146 | self.conv2 = conv3x3(planes, planes) 147 | self.bn1 = nn.BatchNorm2d(planes) 148 | self.bn2 = nn.BatchNorm2d(planes) 149 | self.relu = nn.ReLU(inplace=True) 150 | 151 | if stride == 1: 152 | self.downsample = None 153 | else: 154 | self.downsample = nn.Sequential( 155 | conv1x1(in_planes, planes, stride=stride), 156 | nn.BatchNorm2d(planes) 157 | ) 158 | 159 | def forward(self, x): 160 | y = x 161 | y = self.relu(self.bn1(self.conv1(y))) 162 | y = self.bn2(self.conv2(y)) 163 | 164 | if self.downsample is not None: 165 | x = self.downsample(x) 166 | 167 | return self.relu(x+y) 168 | -------------------------------------------------------------------------------- /src/models/pca.py: -------------------------------------------------------------------------------- 1 | from sklearn.decomposition import PCA 2 | import torch 3 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 4 | 5 | def compute_pca(x, n_components): 6 | feature_dim = x.shape[-1] 7 | x = x.reshape(-1,feature_dim) 8 | pca = PCA(n_components=n_components).fit(x.cpu().numpy()) 9 | components = pca.components_[None, ...] 
10 | components = components.reshape(-1,feature_dim) 11 | return torch.tensor(components).to(device) 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/models/segmentation/sam.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | # Ensure the segment_anything package is accessible. 5 | sys.path.append("..") 6 | """pip install git+https://github.com/facebookresearch/segment-anything.git""" 7 | from segment_anything import sam_model_registry, SamPredictor 8 | device = "cuda" 9 | import numpy as np 10 | 11 | def get_name(): 12 | return 'sam' 13 | class SAM(nn.Module): 14 | name = "sam" 15 | def __init__(self, args): 16 | super().__init__() 17 | """wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth""" 18 | sam_checkpoint = args.path_model_seg 19 | model_type = "vit_h" 20 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) 21 | self.sam = SamPredictor(sam) 22 | return 23 | 24 | def forward(self, im, bbox=None, kps=None): 25 | kps = kps[:, [1,0,2]] 26 | self.sam.set_image(np.array(im)) 27 | if bbox is not None: 28 | input_box = np.array(bbox)[None, :] 29 | else: 30 | input_box = None 31 | prt, _, _ = self.sam.predict(box=input_box, multimask_output=False, point_coords=kps[:,:2].cpu().numpy(), point_labels=kps[:,2].cpu().numpy()) 32 | prt = torch.tensor(prt[0]) 33 | return prt --------------------------------------------------------------------------------
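Usage sketch (editor's addition, not a file from the repository): the snippet below illustrates one way the featurizer classes above can be exercised end to end -- instantiate DINOv2Featurizer directly, run its forward on two PIL images to obtain (1, D, w0, h0) feature maps, and match a single source patch to its most similar target patch by cosine similarity. The image paths, the choice of img_size=[518, 518], up_ft_index=1, log_bin=False, and the probed patch location are placeholder assumptions; in the repository this flow is normally driven through the configs, matcher classes, and demo notebooks rather than called by hand.

import torch
from PIL import Image
from src.models.featurizer.dinov2 import DINOv2Featurizer

# Build a plain DINOv2 featurizer (no LoRA, no log-binning); weights are pulled via torch.hub.
featurizer = DINOv2Featurizer(log_bin=False, img_size=[518, 518], up_ft_index=1, model_size="dinov2_vits14")

# Placeholder image paths -- any RGB PIL images work here.
src_img = Image.open("source.jpg").convert("RGB")
trg_img = Image.open("target.jpg").convert("RGB")

# forward() resizes, normalizes, and crops to a multiple of the 14-pixel patch size,
# then returns patch features of shape (1, D, w0, h0) with w0 = h0 = 518 // 14 = 37.
src_ft = featurizer.forward(src_img, up_ft_index=1, img_size=[518, 518])
trg_ft = featurizer.forward(trg_img, up_ft_index=1, img_size=[518, 518])

# Match the central source patch to the most similar target patch by cosine similarity.
w0, h0 = featurizer.get_w0_h0()
src_vec = src_ft[0, :, w0 // 2, h0 // 2]            # (D,)
trg_mat = trg_ft[0].reshape(trg_ft.shape[1], -1)    # (D, w0*h0)
sim = torch.nn.functional.cosine_similarity(src_vec[None, :], trg_mat.T, dim=1)
row, col = divmod(int(sim.argmax()), h0)
print(f"central source patch best matches target patch at (row={row}, col={col})")

The same pattern applies to the other featurizers (DINOFeaturizer, DINOv2LoRAFeaturizer, the DIFT variants): they all expose forward(img, ...) returning a (B, D, w0, h0) feature map and get_w0_h0() for the patch grid, so downstream matching and segmentation code can treat them interchangeably.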