├── .gitignore
├── README.md
├── assets
│   ├── birds_teaser.png
│   ├── src_01_5_crop.png
│   ├── src_10_2_crop.png
│   ├── trg_01_5_crop_blend.gif
│   └── trg_10_2_crop_blend.gif
├── configs
│   ├── classifier
│   │   └── nearest_centroid.yaml
│   ├── dataset
│   │   ├── apk.yaml
│   │   ├── cub.yaml
│   │   ├── pascalparts.yaml
│   │   ├── pfpascal.yaml
│   │   └── spair.yaml
│   ├── demo_data_loading.yaml
│   ├── demo_keypoint_transfer.yaml
│   ├── demo_segmenation_nearest_centroid.yaml
│   ├── eval_all.yaml
│   ├── eval_seg.yaml
│   ├── eval_time_mem.yaml
│   ├── feat_refine
│   │   └── geosc.yaml
│   ├── featurizer
│   │   ├── dift_adm.yaml
│   │   ├── dift_sd.yaml
│   │   ├── dino.yaml
│   │   ├── dinov2.yaml
│   │   ├── dinov2lora.yaml
│   │   ├── open_clip.yaml
│   │   └── sd15ema_dinov2.yaml
│   ├── store_feats.yaml
│   ├── store_masks.yaml
│   ├── train_pairs.yaml
│   └── train_setting_default.yaml
├── demo_data_loading.ipynb
├── demo_keypoint_transfer.ipynb
├── demo_segmenation_nearest_centroid.ipynb
├── download_data
│   ├── prepare_apk.sh
│   ├── prepare_cub.sh
│   ├── prepare_pascalparts.sh
│   ├── prepare_pfpascal.sh
│   ├── prepare_spair.sh
│   ├── spair_keypoint_l_r_permutation.csv
│   ├── spair_keypoint_names.csv
│   └── spair_keypoint_permutation.csv
├── pretrained_weights
│   └── geco
│       └── last_lora_weights.pth
├── scripts
│   ├── run_evaluation_all.py
│   ├── run_evaluation_seg.py
│   ├── run_evaluation_time_mem.py
│   ├── store_feats.py
│   ├── store_masks.py
│   └── train_pairs.py
├── setup_env.sh
└── src
    ├── __init__.py
    ├── dataset
    │   ├── apk_pairs.py
    │   ├── augmentations.py
    │   ├── cub_200.py
    │   ├── cub_200_pairs.py
    │   ├── pairwise_utils.py
    │   ├── pascalparts.py
    │   ├── pascalparts_part2ind.py
    │   ├── pfpascal_pairs.py
    │   ├── random_utils.py
    │   ├── spair.py
    │   ├── spair_single.py
    │   └── utils.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── evaluation.py
    │   ├── pck.py
    │   ├── runtime_mem.py
    │   ├── segmentation.py
    │   └── segmentation_metrics.py
    ├── logging
    │   ├── log_results.py
    │   ├── visualization.py
    │   ├── visualization_pck.py
    │   └── visualization_seg.py
    ├── losses
    │   └── __init__.py
    ├── matcher
    │   ├── __init__.py
    │   ├── argmaxmatcher.py
    │   └── ot_matcher.py
    └── models
        ├── __init__.py
        ├── classifier
        │   ├── supervised
        │   │   └── nearest_centroid.py
        │   └── utils.py
        ├── featurizer
        │   ├── clip.py
        │   ├── dift_adm.py
        │   ├── dift_sd.py
        │   ├── dino.py
        │   ├── dinov2.py
        │   ├── dinov2_lora.py
        │   ├── sd15ema_dinov2.py
        │   └── utils.py
        ├── featurizer_refine
        │   ├── __init__.py
        │   ├── feat_refine_geosc.py
        │   └── resnet.py
        ├── pca.py
        └── segmentation
            └── sam.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.png
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | */.DS_Store
163 | .DS_Store
164 |
165 | guided-diffusion/
166 | davis_results_sd/
167 | davis_results_adm/
168 | superpoint-1k/
169 | hpatches_results/
170 | superpoint-1k.zip
171 | SPair-71k.tar.gz
172 | SPair-71k/
173 | ./guided-diffusion/models/256x256_diffusion_uncond.pt
174 | wandb/
175 | outputs/
176 | feats/
177 | spair/
178 | pretrained_weights/
179 | runs/
180 | pretrained_weights_backup/
181 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🦎 GECO: Geometrically Consistent Embedding with Lightspeed Inference (ICCV 2025)
2 |
3 |
4 | [ 🌐**Project Page**](https://reginehartwig.github.io/publications/geco/) • [📄 **Paper**](https://arxiv.org/pdf/2508.00746)
5 |
6 |
7 |

8 |
9 |
10 |
11 |
12 | This is the official repository for the ICCV 2025 paper:
13 |
14 | > **GECO: Geometrically Consistent Embedding with Lightspeed Inference**.
15 | >
16 | >[Regine Hartwig](https://reginehartwig.github.io/)1,2,
17 | >[Dominik Muhle](https://dominikmuhle.github.io/)1,2,
18 | >[Riccardo Marin](https://ricma.netlify.app/)1,2,
19 | >[Daniel Cremers](https://cvg.cit.tum.de/members/cremers)1,2,
20 | >
21 | > 1Technical University of Munich, 2MCML
22 | >
23 | > [**ICCV 2025** (arXiv)](https://arxiv.org/pdf/2508.00746)
24 |
25 |
26 | If you find our work useful, please consider citing our paper:
27 | ```
28 | @inproceedings{hartwig2025geco,
29 | title={GECO: Geometrically Consistent Embedding with Lightspeed Inference},
30 | author={Hartwig, Regine and Muhle, Dominik and Marin, Riccardo and Cremers, Daniel},
31 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
32 | year={2025}
33 | }
34 | ```
35 |
36 | ## Intro
37 |
38 | We address the task of geometry-aware feature encoding. A common way to test geometric awareness is through keypoint matching: Given a source image with an annotated keypoint, the goal is to predict the keypoint in the target image by selecting the location with the highest feature similarity. Below are two examples:
39 |
40 |
41 | *(Example figures: `assets/src_01_5_crop.png` with `assets/trg_01_5_crop_blend.gif`, and `assets/src_10_2_crop.png` with `assets/trg_10_2_crop_blend.gif`.)*
45 |
46 |
47 |
48 | We introduce a training paradigm and a lightweight architecture for learning from image pairs with sparsely annotated keypoints. In addition, we strengthen feature evaluation by introducing subdivisions of the commonly used PCK metric and a nearest-centroid clustering approach that probes the feature space more densely.
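
To make the matching criterion concrete, here is a minimal, self-contained sketch of transferring a keypoint by picking the target location with the highest cosine feature similarity. It is an illustration only, with dummy tensors in place of real featurizer outputs, and is not the repository's matcher (see `src/matcher/`):

```python
import torch
import torch.nn.functional as F

# Dummy feature maps standing in for featurizer outputs: C channels on an h x w grid.
C, h, w = 64, 32, 32
src_ft = F.normalize(torch.randn(C, h, w), dim=0)  # source feature map (unit-norm per location)
trg_ft = F.normalize(torch.randn(C, h, w), dim=0)  # target feature map

src_kp = (12, 20)                               # annotated source keypoint (row, col) on the feature grid
query = src_ft[:, src_kp[0], src_kp[1]]         # feature vector at the source keypoint

sim = torch.einsum('c,chw->hw', query, trg_ft)  # cosine similarity to every target location
pred = divmod(int(sim.argmax()), w)             # predicted target keypoint (row, col)
print(pred)
```
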
49 | ## 🔧 Environment Setup
50 | If you're using a Linux machine, set up the Python environment with:
51 | ```bash
52 | conda create --name geco python=3.10
53 | conda activate geco
54 | bash setup_env.sh
55 | ```
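
Optionally, a quick sanity check that the environment can see PyTorch and a GPU (assuming the CUDA 11.8 wheels installed by `setup_env.sh`):

```python
import torch

print(torch.__version__)          # version installed by setup_env.sh
print(torch.cuda.is_available())  # should be True on a machine with a CUDA-capable GPU
```
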
56 | To use [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything) for mask extraction:
57 |
58 | ```bash
59 | pip install git+https://github.com/facebookresearch/segment-anything.git
60 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
61 |
62 | ```
63 |
64 | Install ODISE if you want to run the Geo baseline:
65 |
66 | ```bash
67 | git clone git@github.com:NVlabs/ODISE.git
68 | cd ODISE
69 | pip install -e .
70 | ```
71 |
72 | ## 🚀 Get Started
73 |
74 | ### 📁 Prepare the Datasets
75 |
76 | Run the following scripts to prepare each dataset:
77 | * APK:
78 | ```bash
79 | bash download_data/prepare_apk.sh
80 | wget https://raw.githubusercontent.com/Junyi42/GeoAware-SC/master/prepare_ap10k.ipynb
81 | ```
82 | Then run the notebook from GeoAware-SC to preprocess the data.
83 | * CUB:
84 | ```bash
85 | bash download_data/prepare_cub.sh
86 | ```
87 | * PascalParts
88 | ```bash
89 | bash download_data/prepare_pascalparts.sh
90 | ```
91 | * PFPascal
92 | ```bash
93 | bash download_data/prepare_pfpascal.sh
94 | ```
95 | * SPair-71k:
96 | ```bash
97 | bash download_data/prepare_spair.sh
98 | ```
99 |
100 | ### Extract the masks
101 | 1. Define `save_path_masks` in the dataset config files.
102 | 2. Define `path_model_seg` in `store_masks.yaml`, pointing to the directory where `sam_vit_h_4b8939.pth` is stored.
103 | 3. Select the datasets to process in `store_masks.yaml`.
104 | 4. Run:
105 | ```bash
106 | python scripts/store_masks.py --config-name=store_masks.yaml
107 | ```
108 |
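For reference, a condensed sketch of what `scripts/store_masks.py` does for a single dataset (the real script also dispatches to the other dataset classes); the behavior of `SAM` and `store_masks` is assumed from the script's usage:

```python
import torch
from src.dataset.spair_single import SpairDatasetSingle
from src.models.segmentation.sam import SAM

def store_spair_masks(args):
    """args: the composed Hydra config (store_masks.yaml with dataset=spair)."""
    torch.cuda.set_device(0)
    args.dataset.sup = "sup_original"
    dataset = SpairDatasetSingle(args.dataset, split=args.dataset.split)
    model_seg = SAM(args)                  # presumably loads the checkpoint pointed to by path_model_seg
    dataset.model_seg = model_seg
    dataset.return_masks = True
    dataset.model_seg_name = model_seg.name
    overwrite = False
    dataset.store_masks(overwrite)         # masks end up under the dataset's save_path_masks
```
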
109 | ### Precompute the features (not recommended)
110 | Choose a path and define `save_path` in the dataset config files.
111 | Select the dataset you want to extract features for in `store_feats.yaml`, then run:
112 | ```bash
113 | python scripts/store_feats.py --config-name=store_feats.yaml
114 | ```
115 |
116 |
117 | ## 🎯 Pretrained Weights
118 | Pretrained weights are available in `pretrained_weights/geco`.
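
A minimal sketch of loading the GECO featurizer outside the provided scripts, using Hydra's compose API; we assume `get_featurizer` consumes the `init` block in `configs/featurizer/dinov2lora.yaml` (`id: geco`) to load `pretrained_weights/geco/last_lora_weights.pth`:

```python
from hydra import compose, initialize
from src.models.featurizer.utils import get_featurizer

# Compose the same config the keypoint-transfer demo uses (featurizer/dinov2lora + dataset/apk).
with initialize(config_path="configs"):
    cfg = compose(config_name="demo_keypoint_transfer")

featurizer = get_featurizer(cfg.featurizer)  # returns the LoRA-adapted DINOv2 featurizer
```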
119 |
120 | ## 🧪 Interactive Demos: Give it a Try!
121 |
122 | We provide interactive Jupyter notebooks for testing:
123 |
124 | * [📚 Data Loading Demo](demo_data_loading.ipynb)
125 |
126 | Validate dataset preparation and path setup.
127 |
128 | * [🎨 Segmentation Demo](demo_segmenation_nearest_centroid.ipynb)
129 |
130 | Visualize part segmentation using a simple nearest-centroid classifier (a minimal sketch of the idea follows after this list).
131 |
132 | * [📍 Keypoint Transfer Demo](demo_keypoint_transfer.ipynb)
133 |
134 | Explore keypoint transfer and interactive attention maps.
135 |
136 |
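As referenced in the segmentation demo above, the nearest-centroid idea is simple: each part is represented by the mean feature of its annotated pixels, and every pixel is assigned to the closest centroid. A generic sketch with dummy tensors, not the repository's implementation in `src/models/classifier/`:

```python
import torch
import torch.nn.functional as F

def nearest_centroid_segmentation(feats, centroids):
    """feats: (C, H, W) per-pixel features; centroids: (K, C) mean feature of each part."""
    feats = F.normalize(feats, dim=0)
    centroids = F.normalize(centroids, dim=1)
    sim = torch.einsum('kc,chw->khw', centroids, feats)  # cosine similarity to each part centroid
    return sim.argmax(dim=0)                              # (H, W) part label per pixel

# Toy example with 4 hypothetical parts
labels = nearest_centroid_segmentation(torch.randn(64, 32, 32), torch.randn(4, 64))
print(labels.shape)  # torch.Size([32, 32])
```
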
137 | ## 📊 Run Evaluation
138 |
139 | Run full evaluation:
140 | ```bash
141 | python scripts/run_evaluation_all.py --config-name=eval_all.yaml
142 | ```
143 | Evaluate inference time and memory usage:
144 | ```bash
145 | python scripts/run_evaluation_time_mem.py --config-name=eval_time_mem.yaml
146 | ```
147 | Evaluate segmentation metrics:
148 | ```bash
149 | python scripts/run_evaluation_seg.py --config-name=eval_seg.yaml
150 | ```
151 |
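The keypoint evaluations report PCK@α (Percentage of Correct Keypoints): a transferred keypoint counts as correct if it lies within α·max(h, w) of the ground truth, where h and w come from the object bounding box or the whole image depending on `alpha_bbox`. A small illustrative computation, not the repository's implementation in `src/evaluation/pck.py`:

```python
import torch

def pck(pred_kps, gt_kps, h, w, alpha=0.1):
    """pred_kps, gt_kps: (N, 2) keypoint coordinates; h, w: bbox or image height/width."""
    thresh = alpha * max(h, w)
    dist = (pred_kps - gt_kps).norm(dim=1)
    return (dist <= thresh).float().mean().item()

# Toy example: 3 of 4 predictions fall within 0.1 * max(200, 150) = 20 px
gt = torch.tensor([[10., 10.], [50., 60.], [120., 80.], [30., 200.]])
pred = gt + torch.tensor([[5., 0.], [0., 15.], [30., 0.], [1., 1.]])
print(pck(pred, gt, h=200, w=150))  # 0.75
```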
152 |
153 | ## 🏋️ Train the Model
154 | Before training, comment out the following block in `configs/featurizer/dinov2lora.yaml`:
155 |
156 | ```yaml
157 | init:
158 | id: geco
159 | eval_last: True
160 | ```
161 | Then run:
162 | ```bash
163 | python scripts/train_pairs.py --config-name=train_pairs.yaml
164 | ```
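
Training minimizes a weighted sum of the pairwise loss terms whose weights are set in `configs/train_setting_default.yaml`. A toy illustration of how those weights enter the objective, mirroring the weighting step in `scripts/train_pairs.py` (the per-term values here are dummies standing in for `PairwiseLoss.get_loss`):

```python
import torch

# Dummy per-term losses, one value per sample in a batch of 6 (stand-ins for PairwiseLoss.get_loss output)
losses = {k: torch.rand(6) for k in ('pos', 'bin', 'neg', 'neg_fg_bkg')}
weights = {'pos': 1, 'bin': 1, 'neg': 10, 'neg_fg_bkg': 10}  # values from configs/train_setting_default.yaml

total = sum(losses[k].mean() * weights[k] for k in weights)  # weighted sum that is backpropagated
print(total)
```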
--------------------------------------------------------------------------------
/assets/birds_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/birds_teaser.png
--------------------------------------------------------------------------------
/assets/src_01_5_crop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/src_01_5_crop.png
--------------------------------------------------------------------------------
/assets/src_10_2_crop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/src_10_2_crop.png
--------------------------------------------------------------------------------
/assets/trg_01_5_crop_blend.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/trg_01_5_crop_blend.gif
--------------------------------------------------------------------------------
/assets/trg_10_2_crop_blend.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/assets/trg_10_2_crop_blend.gif
--------------------------------------------------------------------------------
/configs/classifier/nearest_centroid.yaml:
--------------------------------------------------------------------------------
1 | model: 'nearest_centroid_fg'
2 | num_pcaparts: 32
3 | num_samples: 100
--------------------------------------------------------------------------------
/configs/dataset/apk.yaml:
--------------------------------------------------------------------------------
1 | name: 'ap10k'
2 | dataset_path: '/ap-10k' # path to the AP-10K dataset
3 | save_path: ''
4 | save_path_masks: ''
5 | hydra:
6 | output_subdir: null
7 | sup: sup_augmented # ['sup_augmented', 'sup_original']
8 | flip_aug: True
9 | cat: all
10 | subsample: null
--------------------------------------------------------------------------------
/configs/dataset/cub.yaml:
--------------------------------------------------------------------------------
1 | name: 'cub'
2 | dataset_path: ''
3 | dataset_path_annotations: ''
4 | n_pairs: 10000
5 | save_path: ''
6 | save_path_masks: ''
7 | hydra:
8 | output_subdir: null
9 | sup: sup_augmented # ['sup_augmented', 'sup_original']
10 | borders_cut: True
11 | flip_aug: True
12 | cat: bird
--------------------------------------------------------------------------------
/configs/dataset/pascalparts.yaml:
--------------------------------------------------------------------------------
1 | name: 'pascalparts'
2 | dataset_path: '/PASCAL-VOC/' # path to the PASCAL VOC / PASCAL-Parts dataset
3 | save_path: ''
4 | hydra:
5 | output_subdir: null
6 | sup: sup_original # ['sup_augmented', 'sup_original']
--------------------------------------------------------------------------------
/configs/dataset/pfpascal.yaml:
--------------------------------------------------------------------------------
1 | name: 'pfpascal'
2 | dataset_path: '/PF-dataset-PASCAL' # path to the PF-PASCAL dataset
3 | save_path: ''
4 | save_path_masks: ''
5 | hydra:
6 | output_subdir: null
7 | sup: sup_original # ['sup_augmented', 'sup_original']
8 | cat: all
9 |
--------------------------------------------------------------------------------
/configs/dataset/spair.yaml:
--------------------------------------------------------------------------------
1 | name: 'spair'
2 | dataset_path: '/SPair-71k' #path to spair dataset
3 | save_path: ''
4 | save_path_masks: ''
5 | hydra:
6 | output_subdir: null
7 | sup: sup_augmented # ['sup_augmented', 'sup_original']
8 | flip_aug: True
9 | cat: all
10 |
--------------------------------------------------------------------------------
/configs/demo_data_loading.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/spair # ['apk', 'cub', 'spair', 'pfpascal']
3 | - featurizer/dinov2 # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
4 |
5 | model_seg_name: sam
--------------------------------------------------------------------------------
/configs/demo_keypoint_transfer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/apk # ['apk', 'cub', 'spair', 'pfpascal']
3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
4 | dataset:
5 | sup: sup_original # ['sup_augmented', 'sup_original']
--------------------------------------------------------------------------------
/configs/demo_segmenation_nearest_centroid.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/pascalparts
3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
4 | - classifier/nearest_centroid@sup_classifier
--------------------------------------------------------------------------------
/configs/eval_all.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/spair # ['apk', 'cub', 'spair', 'pfpascal']
3 | - dataset/apk@dataset1
4 | - dataset/cub@dataset2
5 | - dataset/pfpascal@dataset3
6 | - featurizer/dinov2lora # which model to use; one of ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
7 |
8 | dataset:
9 | sup: sup_original # ['sup_augmented', 'sup_original']
10 | dataset1:
11 | sup: sup_original # ['sup_augmented', 'sup_original']
12 | dataset2:
13 | sup: sup_original # ['sup_augmented', 'sup_original']
14 | dataset3:
15 | sup: sup_original # ['sup_augmented', 'sup_original']
16 |
17 | # PCK
18 | upsample: True
19 | alpha_bbox: False
20 | n_pairs_eval_pck: 10000
21 |
--------------------------------------------------------------------------------
/configs/eval_seg.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/pascalparts
3 | - featurizer/dinov2lora
4 | - classifier/nearest_centroid@sup_classifier
--------------------------------------------------------------------------------
/configs/eval_time_mem.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset/apk # ['apk', 'cub', 'spair', 'pfpascal']
3 | - featurizer/dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
4 |
5 | num_imgs_time_mem: 1000
6 | dataset:
7 | sup: sup_original # ['sup_augmented', 'sup_original']
--------------------------------------------------------------------------------
/configs/feat_refine/geosc.yaml:
--------------------------------------------------------------------------------
1 | model: GeoSc
2 | model_out_path: ''
3 | feature_dims: [640,1280,1280,768]
4 | projection_dim: 768
5 | # in_dim: 3968
6 | feat_map_dropout: 0.2
7 |
8 | init:
9 | id: geosc
10 | epoch: 856
11 | url: "https://github.com/Junyi42/GeoAware-SC/blob/master/results_spair/best_856.PTH"
--------------------------------------------------------------------------------
/configs/featurizer/dift_adm.yaml:
--------------------------------------------------------------------------------
1 | model: 'dift_adm'
2 | img_size: [512, 512] #in the order of [width, height], resize input image to [w, h] before fed into diffusion model, if set to 0, will stick to the original input size.
3 | t: 101 #t for diffusion
4 | up_ft_index: 4 #which upsampling block to extract the ft map
5 | ensemble_size: 1 #ensemble size for getting an image ft map
--------------------------------------------------------------------------------
/configs/featurizer/dift_sd.yaml:
--------------------------------------------------------------------------------
1 | model: 'dift_sd'
2 | img_size: [768, 768] #in the order of [width, height], resize input image to [w, h] before fed into diffusion model, if set to 0, will stick to the original input size. by default is 768x768.
3 | t: 261 #t for diffusion
4 | up_ft_index: 2 #which upsampling block to extract the ft map
5 | ensemble_size: 1 #ensemble size for getting an image ft map
6 | all_cats: none
--------------------------------------------------------------------------------
/configs/featurizer/dino.yaml:
--------------------------------------------------------------------------------
1 | model: 'dino'
2 | img_size: [224, 224] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size.
3 | up_ft_index: 4 #which upsampling block to extract the ft map
4 | log_bin: False
--------------------------------------------------------------------------------
/configs/featurizer/dinov2.yaml:
--------------------------------------------------------------------------------
1 | model: 'dinov2'
2 | model_size: 'dinov2_vitb14' # ['dinov2_vits14', 'dinov2_vitb14']
3 | img_size: [518, 518] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size.
4 | up_ft_index: 1 #which upsampling block to extract the ft map
5 | log_bin: False
--------------------------------------------------------------------------------
/configs/featurizer/dinov2lora.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dinov2
3 | model: 'dinov2lora'
4 | model_out_path: ''
5 | lora_rank: 10
6 |
7 | init:
8 | id: geco
9 | eval_last: True
--------------------------------------------------------------------------------
/configs/featurizer/open_clip.yaml:
--------------------------------------------------------------------------------
1 | model: 'open_clip'
2 | img_size: [512, 512] #in the order of [width, height], resize input image to [w, h] before fed into the model, if set to 0, will stick to the original input size.
3 | up_ft_index: 4 #which upsampling block to extract the ft map
4 | ensemble_size: 1 #ensemble size for getting an image ft map
5 | t: 101 #t for diffusion
--------------------------------------------------------------------------------
/configs/featurizer/sd15ema_dinov2.yaml:
--------------------------------------------------------------------------------
1 | model: 'sd15ema_dinov2'
2 | crop_feats: True
--------------------------------------------------------------------------------
/configs/store_feats.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - featurizer/sd15ema_dinov2 # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
3 | - dataset/cub # ['apk', 'cub', 'spair', 'pfpascal']
4 |
5 | dataset:
6 | split: 'train'
--------------------------------------------------------------------------------
/configs/store_masks.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: cub # ['apk', 'cub', 'spair', 'pfpascal']
3 |
4 | dataset:
5 | split: 'train'
6 | path_model_seg:
7 | model_seg_name: sam
--------------------------------------------------------------------------------
/configs/train_pairs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - featurizer: dinov2lora # which model to use ['dift_sd', 'dift_adm', 'open_clip', 'dino', 'dinov2', 'sd15ema_dinov2', 'dinov2lora']
3 | - dataset: apk # ['apk', 'cub', 'spair', 'pfpascal']
4 | - dataset: spair@dataset2 # ['apk', 'cub', 'spair', 'pfpascal']
5 | - dataset: pfpascal@dataset3 # ['apk', 'cub', 'spair', 'pfpascal']
6 | - dataset: pfpascal@datasetgeneralization # ['apk', 'cub', 'spair', 'pfpascal']
7 | - dataset: spair@datasetgeneralization2 # ['apk', 'cub', 'spair', 'pfpascal']
8 | - dataset: cub@datasetgeneralization3 # ['apk', 'cub', 'spair', 'pfpascal']
9 | - train_setting_default
10 |
11 | # fix number of elements for each category in the dataset
12 | dataset:
13 | num_el: 800
14 | dataset2:
15 | num_el: 800
16 | dataset3:
17 | num_el: 800
18 |
19 | ## for evaluating PCK
20 | upsample: True
21 | alpha_bbox: False
22 | n_pairs_eval_pck: 10000
23 |
24 |
--------------------------------------------------------------------------------
/configs/train_setting_default.yaml:
--------------------------------------------------------------------------------
1 | learning_rate: 0.0001
2 | epoch: 8
3 | batch_size: 6
4 | weight_decay: 0
5 | scheduler: !!null
6 |
7 | # losses
8 | losses:
9 | pos: 1
10 | bin: 1
11 | neg: 10
12 | neg_fg_bkg: 10
13 | model_seg_name: sam
--------------------------------------------------------------------------------
/download_data/prepare_apk.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p
3 | cd
4 | gdown https://drive.google.com/uc?id=1-FNNGcdtAQRehYYkGY1y4wzFNg4iWNad
5 | unzip ap-10k.zip -d ap-10k
6 | rm ap-10k.zip
7 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_cub.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p
3 | cd
4 | wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
5 | if [ -f CUB_200_2011.tgz ]; then
6 | echo "File downloaded successfully"
7 | else
8 | echo "File download failed"
9 | fi
10 | tar -xvzf CUB_200_2011.tgz
11 | if [ -d CUB_200_2011 ]; then
12 | echo "File extracted successfully"
13 | rm CUB_200_2011.tgz
14 | else
15 | echo "File extraction failed"
16 | fi
17 |
18 | mkdir -p
19 | cd
20 | apt-get install gdown
21 | gdown https://drive.google.com/drive/folders/1DnmpG8Owhv9Rmz_KFozqIKVTJCNSbWc9?usp=sharing
22 | unzip data.zip -C data-v2
23 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_pascalparts.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p
3 | cd
4 |
5 | mkdir PASCAL-VOC
6 | cd PASCAL-VOC
7 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar
8 | tar -xf VOCtrainval_03-May-2010.tar
9 | if [ -d VOCdevkit ]; then
10 | echo "File extracted successfully"
11 | rm VOCtrainval_03-May-2010.tar
12 | else
13 | echo "File extraction failed"
14 | fi
15 |
16 | mkdir Parts
17 | cd Parts
18 | wget https://roozbehm.info/pascal-parts/trainval.tar.gz
19 | tar -xf trainval.tar.gz
20 |
21 | if [ -d Annotations_Part ]; then
22 | echo "File extracted successfully"
23 | rm trainval.tar.gz
24 | else
25 | echo "File extraction failed"
26 | fi
27 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_pfpascal.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p
3 | cd
4 | wget http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip
5 | wget http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf_pascal.csv
6 | wget http://www.di.ens.fr/willow/research/cnngeometric/other_resources/val_pairs_pf_pascal.csv
7 | gdown https://drive.google.com/uc?id=111tpXshLiJ4qudBHoGK3HbMSNr9vVRq9 # download the trin_pairs_pf_pascal.csv
8 | unzip PF-dataset-PASCAL.zip -d .
9 | rm PF-dataset-PASCAL.zip
10 | rm __MACOSX -r
11 | rm PF-dataset-PASCAL/Annotations/.DS_Store
12 | mv test_pairs_pf_pascal.csv PF-dataset-PASCAL
13 | mv val_pairs_pf_pascal.csv PF-dataset-PASCAL
14 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/prepare_spair.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(pwd)
2 | mkdir -p
3 | cd
4 | wget http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
5 | tar -xf SPair-71k.tar.gz
6 | if [ -d SPair-71k ]; then
7 | echo "File extracted successfully"
8 | rm SPair-71k.tar.gz
9 | else
10 | echo "File extraction failed"
11 | fi
12 | cp ${CURRENT_DIR}/spair*.csv /SPair-71k
13 | cd ${CURRENT_DIR}
--------------------------------------------------------------------------------
/download_data/spair_keypoint_l_r_permutation.csv:
--------------------------------------------------------------------------------
1 | "kp_id ",aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor
2 | "0",0,1,0,0,1,1,1,1,1,1,1,1,1,1,2,1,1,2
3 | "1",1,0,2,2,0,0,0,0,0,0,0,0,0,0,3,0,0,1
4 | "2",2,3,1,1,3,3,3,3,3,3,3,3,3,3,0,3,3,0
5 | "3",3,2,3,4,2,2,2,2,2,2,2,2,2,2,1,2,2,7
6 | "4",5,4,5,3,5,7,8,5,5,5,5,5,5,4,5,5,5,6
7 | "5",4,5,4,6,4,6,9,4,4,4,4,4,4,5,4,4,4,5
8 | "6",7,7,6,5,7,5,7,7,7,7,6,7,7,6,8,7,7,4
9 | "7",6,6,8,8,6,4,6,6,6,6,7,6,6,7,7,6,6,3
10 | "8",9,8,7,7,9,8,4,8,9,8,8,8,8,9,6,8,9,10
11 | "9",8,10,9,10,8,9,5,10,8,10,10,9,9,8,,10,8,9
12 | "10",11,9,11,9,,20,20,9,11,9,9,11,10,11,,9,11,8
13 | "11",10,11,10,12,,21,21,12,10,12,12,10,12,10,,12,10,15
14 | "12",13,12,13,11,,22,22,11,13,11,11,13,11,13,,11,13,14
15 | "13",12,13,12,13,,23,23,13,12,13,13,12,,12,,13,12,13
16 | "14",15,,15,,,24,24,14,,14,14,14,,15,,14,15,12
17 | "15",14,,14,,,25,25,,,16,15,15,,14,,16,14,11
18 | "16",17,,16,,,17,17,,,15,,17,,17,,15,17,
19 | "17",16,,,,,16,16,,,18,,16,,16,,18,16,
20 | "18",19,,,,,19,19,,,17,,19,,19,,17,,
21 | "19",18,,,,,18,18,,,20,,18,,18,,20,,
22 | "20",21,,,,,10,10,,,19,,,,,,19,,
23 | "21",20,,,,,11,11,,,,,,,,,,,
24 | "22",22,,,,,12,12,,,,,,,,,,,
25 | "23",23,,,,,13,13,,,,,,,,,,,
26 | "24",24,,,,,14,14,,,,,,,,,,,
27 | "25",,,,,,15,15,,,,,,,,,,,
28 | "26",,,,,,27,27,,,,,,,,,,,
29 | "27",,,,,,26,26,,,,,,,,,,,
30 | "28",,,,,,29,29,,,,,,,,,,,
31 | "29",,,,,,28,28,,,,,,,,,,,
32 |
--------------------------------------------------------------------------------
/download_data/spair_keypoint_names.csv:
--------------------------------------------------------------------------------
1 | kp_id,aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor
2 | 0,nose,front wheel,crown,front deck,right cap,right mirror,right mirror,right ear,right front-seat,right ear,right ear,right ear,right mirror,right eye,right rim,right ear,top-left front,top-right frame-corner
3 | 1,windshield,back wheel,right wing,front-right deck,left cap,left mirror,left mirror,left ear,left front-seat,left ear,left ear,left ear,left mirror,left eye,back rim,left ear,top-right front,frame-top
4 | 2,cockpit,right handlebar,left wing,front-left deck,right neck,right headlight,right headlight,right ear-tip,bottom-right back,right ear-tip,right ear-tip,right ear-tip,right handlebar,right ear,left rim,right ear-tip,bottom-left front,top-left frame-corner
5 | 3,front wheel,left handlebar ,beak,mid-right deck,left neck ,left headlight,left headlight,left ear-tip,bottom-left back,left ear-tip,left ear-tip,left ear-tip,left handlebar ,left ear,front rim,left ear-tip,bottom-right front,left frame-side
6 | 4,right wheel,middle handlebar,right cheek,mid-left deck ,top-right body,front license-plate,front license-plate,right eye,front-right leg,right eye,right eye,right eye,front light,nose,right body,right eye,top-left back,bottom-left frame-corner
7 | 5,left wheel,front seat,left cheek ,back-right deck,top-left body,right tail-light,front emblem,left eye,front-left leg,left eye,left eye,left eye,back light,chin,left body,left eye,top-right back,frame-base
8 | 6,right engine-fan,back-right seat,forehead ,back-left deck,right-center body,left tail-light,right tail-light,right nostril,back-right leg,right nostril,nose ,right nostril,front wheel,nape,right base,right nostril,bottom-left back,bottom-right frame-corner
9 | 7,left engine-fan,back-left seat,right eye,front-right hull,left-center body,back license-plate,left tail-light,left nostril,back-left leg,left nostril,forehead,left nostril,back wheel,mouth,middle base,left nostril,bottom-right back,right frame-side
10 | 8,right wing-tip,seat-post,left eye,front-left hull,right-bottom body,N/A,back license-plate,mouth,top-right back,mouth,mouth,mouth,license-plate,right shoulder,left base,mouth,top-left windshield ,top-right screen-corner
11 | 9,left wing-tip,right pedal,nape,mid-right hull,left-bottom body,N/A,back emblem,front-right foot,top-left back,front-right foot,front-right foot,forehead,tail ,left shoulder,,front-right foot,top-right windshield,screen-top
12 | 10,right engine,left pedal ,right foot,mid-left hull ,,back-right wheel-back-liner,back-right wheel-back-liner,front-left foot,right armrest,front-left foot,front-left foot,front-right foot,exhaustpipe,right elbow,,front-left foot,bottom-left windshield ,top-left screen-corner
13 | 11,left engine ,N/A,left foot,back-right hull,,back-right wheel,back-right wheel,back-right foot,left armrest,back-right foot,back-right foot,front-left foot,front wheelhub,left elbow,,back-right foot,bottom-right windshield ,left screen-side
14 | 12,right wing-front,N/A,right knee,back-left hull,,back-right wheel-front-liner,back-right wheel-front-liner,back-left foot,right elbowrest,back-left foot,back-left foot,back-right foot,back wheelhub ,right hand,,back-left foot,top-left windshield-center,bottom-left screen-corner
15 | 13,left wing-front,chainring ,left knee,front hull,,front-right wheel-back-liner,front-right wheel-back-liner,tail-tip,left elbowrest,tail-tip,tail-tip,back-left foot,,left hand,,tail-tip,top-right windshield-center,screen-base
16 | 14,right wing-back,,right hip,,,front-right wheel,front-right wheel,tail-base,,tail-base,tail-base,tail-tip,,right knee,,tail-base,bottom-left windshield-center,bottom-right screen-corner
17 | 15,left wing-back,,left hip,,,front-right wheel-front-liner,front-right wheel-front-liner,,,front-right knee,nape,tail-base,,left knee,,front-right knee,bottom-right windshield-center,right screen-side
18 | 16,right stabilizer-tip,,tail,,,bottom-right bus-front,bottom-right front-windshield,,,front-left knee,,front-right knee,,right ankle,,front-left knee,left headlight,
19 | 17,left stabilizer-tip,,,,,bottom-left bus-front,bottom-left front-windshield,,,back-right knee,,front-left knee,,left ankle,,back-right knee,right headlight,
20 | 18,right stabilizer-front,,,,,top-right bus-front,top-right front-windshield,,,back-left knee,,back-right knee,,right foot,,back-left knee,,
21 | 19,left stabilizer-front,,,,,top-left bus-front,top-left front-windshield,,,right horn,,back-left knee,,left foot,,right horn,,
22 | 20,right stabilizer-back,,,,,back-left wheel-back-liner,back-left wheel-back-liner,,,left horn,,,,,,left horn,,
23 | 21,left stabilizer-back,,,,,back-left wheel,back-left wheel,,,,,,,,,,,
24 | 22,vertical stabilizer-tip,,,,,back-left wheel-front-liner,back-left wheel-front-liner,,,,,,,,,,,
25 | 23,vertical stabilizer-front,,,,,front-left wheel-back-liner,front-left wheel-back-liner,,,,,,,,,,,
26 | 24,vertical stabilizer-back,,,,,front-left wheel,front-left wheel,,,,,,,,,,,
27 | 25,,,,,,front-left wheel-front-liner,front-left wheel-front-liner,,,,,,,,,,,
28 | 26,,,,,,bus-back bottom-right,bottom-right back-windshield,,,,,,,,,,,
29 | 27,,,,,,bus-back bottom-left,bottom-left back-windshield,,,,,,,,,,,
30 | 28,,,,,,bus-back top-right,top-right back-windshield,,,,,,,,,,,
31 | 29,,,,,,bus-back top-left,top-left back-windshield,,,,,,,,,,,
32 |
--------------------------------------------------------------------------------
/download_data/spair_keypoint_permutation.csv:
--------------------------------------------------------------------------------
1 | "kp_id ",aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,dog,horse,motorbike,person,pottedplant,sheep,train,tvmonitor
2 | "0",0,1,0,0,1,1,1,1,1,1,1,1,1,1,2,1,1,2
3 | "1",1,0,2,2,0,0,0,0,0,0,0,0,0,0,3,0,0,1
4 | "2",2,3,1,1,3,3,3,3,3,3,3,3,3,3,0,3,3,0
5 | "3",3,2,3,4,2,2,2,2,2,2,2,2,2,2,1,2,2,7
6 | "4",5,4,5,3,5,7,8,5,5,5,5,5,5,4,5,5,5,6
7 | "5",4,5,4,6,4,6,9,4,4,4,4,4,4,5,4,4,4,5
8 | "6",7,7,6,5,7,5,7,7,7,7,6,7,7,6,8,7,7,4
9 | "7",6,6,8,8,6,4,6,6,6,6,7,6,6,7,7,6,6,3
10 | "8",9,8,7,7,9,8,4,8,9,8,8,8,8,9,6,8,9,10
11 | "9",8,10,9,10,8,9,5,10,8,10,10,9,9,8,,10,8,9
12 | "10",11,9,11,9,,20,20,9,11,9,9,11,10,11,,9,11,8
13 | "11",10,11,10,12,,21,21,12,10,12,12,10,12,10,,12,10,15
14 | "12",13,12,13,11,,22,22,11,13,11,11,13,11,13,,11,13,14
15 | "13",12,13,12,13,,23,23,13,12,13,13,12,,12,,13,12,13
16 | "14",15,,15,,,24,24,14,,14,14,14,,15,,14,15,12
17 | "15",14,,14,,,25,25,,,16,15,15,,14,,16,14,11
18 | "16",17,,16,,,17,17,,,15,,17,,17,,15,17,
19 | "17",16,,,,,16,16,,,18,,16,,16,,18,16,
20 | "18",19,,,,,19,19,,,17,,19,,19,,17,,
21 | "19",18,,,,,18,18,,,20,,18,,18,,20,,
22 | "20",21,,,,,10,10,,,19,,,,,,19,,
23 | "21",20,,,,,11,11,,,,,,,,,,,
24 | "22",22,,,,,12,12,,,,,,,,,,,
25 | "23",23,,,,,13,13,,,,,,,,,,,
26 | "24",24,,,,,14,14,,,,,,,,,,,
27 | "25",,,,,,15,15,,,,,,,,,,,
28 | "26",,,,,,27,27,,,,,,,,,,,
29 | "27",,,,,,26,26,,,,,,,,,,,
30 | "28",,,,,,29,29,,,,,,,,,,,
31 | "29",,,,,,28,28,,,,,,,,,,,
32 |
--------------------------------------------------------------------------------
/pretrained_weights/geco/last_lora_weights.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/pretrained_weights/geco/last_lora_weights.pth
--------------------------------------------------------------------------------
/scripts/run_evaluation_all.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | from src.logging.log_results import finish_wandb, init_wandb
7 | from src.evaluation.evaluation import Evaluation
8 | from omegaconf import DictConfig, OmegaConf
9 | import hydra
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 | from src.models.featurizer_refine import load_checkpoint as load_checkpoint_refiner
12 | from src.models.featurizer.utils import get_featurizer
13 | import copy
14 |
15 | @hydra.main(config_path="../configs")
16 | def main(args: DictConfig):
17 |
18 | # load the model if feat_refine is part of the config
19 | if 'feat_refine' in args:
20 | model_refine = load_checkpoint_refiner(args.feat_refine)
21 | else:
22 | model_refine = None
23 |
24 | featurizer = get_featurizer(args.featurizer)
25 |
26 | cfg = OmegaConf.to_container(args)
27 | # convert hydra to dict
28 | if 'feat_refine' in args:
29 | init_wandb(cfg, f'eval {args.feat_refine.init.id}')
30 | elif 'init' in args.featurizer:
31 | init_wandb(cfg, f'eval {args.featurizer.init.id}')
32 | else:
33 | init_wandb(cfg, 'eval ')
34 |
35 | evaluation_0 = Evaluation(args, featurizer)
36 | args_0 = copy.deepcopy(args)
37 | args_0.dataset = args.dataset1
38 | evaluation_1 = Evaluation(args_0, featurizer)
39 | args_2 = copy.deepcopy(args)
40 | args_2.dataset = args.dataset2
41 | evaluation_2 = Evaluation(args_2, featurizer)
42 | args_3 = copy.deepcopy(args)
43 | args_3.dataset = args.dataset3
44 | evaluation_3 = Evaluation(args_3, featurizer)
45 |
46 | evaluation_0.evaluate_pck(model_refine=model_refine, suffix=evaluation_0.dataset_test_pck.name)
47 | evaluation_1.evaluate_pck(model_refine=model_refine, suffix=evaluation_1.dataset_test_pck.name)
48 | evaluation_2.evaluate_pck(model_refine=model_refine, suffix=evaluation_2.dataset_test_pck.name)
49 | evaluation_3.evaluate_pck(model_refine=model_refine, suffix=evaluation_3.dataset_test_pck.name)
50 |
51 | evaluation_val_0 = Evaluation(args, featurizer, split='val')
52 | args_val_0 = copy.deepcopy(args)
53 | args_val_0.dataset = args.dataset1
54 | evaluation_val_1 = Evaluation(args_val_0, featurizer, split='val')
55 | args_val_2 = copy.deepcopy(args)
56 | args_val_2.dataset = args.dataset2
57 | evaluation_val_2 = Evaluation(args_val_2, featurizer, split='val')
58 | args_val_3 = copy.deepcopy(args)
59 | args_val_3.dataset = args.dataset3
60 | evaluation_val_3 = Evaluation(args_val_3, featurizer, split='val')
61 |
62 | n_eval_imgs = 10
63 | evaluation_val_0.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_0.dataset_test_pck.name)
64 | evaluation_val_1.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_1.dataset_test_pck.name)
65 | evaluation_val_2.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_2.dataset_test_pck.name)
66 | evaluation_val_3.evaluate_pck(model_refine=model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_3.dataset_test_pck.name)
67 |
68 | finish_wandb()
69 |
70 | if __name__ == '__main__':
71 | main()
--------------------------------------------------------------------------------
/scripts/run_evaluation_seg.py:
--------------------------------------------------------------------------------
1 | import sys, os; sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # make src importable when run from scripts/, as in the other scripts
2 | import torch
2 | from src.logging.log_results import finish_wandb, init_wandb
3 | from src.evaluation.evaluation import Evaluation
4 | from omegaconf import DictConfig, OmegaConf
5 | import hydra
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 | from src.models.featurizer_refine import load_checkpoint
8 |
9 | @hydra.main(config_path="../configs")
10 | def main(args: DictConfig):
11 |
12 | # load the model if feat_refine is part of the config
13 | if 'feat_refine' in args:
14 | model_refine = load_checkpoint(args.feat_refine)
15 | else:
16 | model_refine = None
17 |
18 | cfg = OmegaConf.to_container(args)
19 |
20 | if 'feat_refine' in args:
21 | init_wandb(cfg, f'eval_seg {args.feat_refine.init.id}')
22 | elif 'init' in args.featurizer:
23 | init_wandb(cfg, f'eval_seg {args.featurizer.init.id}')
24 | else:
25 | init_wandb(cfg, 'eval_seg ')
26 |
27 | evaluation = Evaluation(args)
28 | evaluation.evaluate_seg(model_refine=model_refine)
29 | finish_wandb()
30 |
31 | if __name__ == '__main__':
32 | main()
--------------------------------------------------------------------------------
/scripts/run_evaluation_time_mem.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | from src.logging.log_results import finish_wandb, init_wandb
7 | from src.evaluation.runtime_mem import evaluate
8 | from omegaconf import DictConfig, OmegaConf
9 | import hydra
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 | from src.models.featurizer_refine import load_checkpoint
12 | from src.dataset.spair_single import SpairDatasetSingle
13 | from src.dataset.cub_200 import CUBDataset
14 | from src.dataset.apk_pairs import AP10KPairs
15 | from src.dataset.pfpascal_pairs import PFPascalPairs
16 | from src.models.featurizer.utils import get_featurizer
17 | import wandb
18 |
19 | @hydra.main(config_path="../configs")
20 | def main(args: DictConfig):
21 |
22 | # load the model if feat_refine is part of the config
23 | if 'feat_refine' in args:
24 | model_refine = load_checkpoint(args.feat_refine)
25 | else:
26 | model_refine = None
27 |
28 | cfg = OmegaConf.to_container(args)
29 | # convert hydra to dict
30 | if 'feat_refine' in args:
31 | init_wandb(cfg, f'eval_time_mem {args.feat_refine.init.id}')
32 | elif 'init' in args.featurizer:
33 | init_wandb(cfg, f'eval_time_mem {args.featurizer.init.id}')
34 | else:
35 | init_wandb(cfg, 'eval_time_mem ')
36 | if args.dataset.name == "spair":
37 | dataset_test_time_mem = SpairDatasetSingle(args.dataset, split="test")
38 | elif args.dataset.name == "cub":
39 | dataset_test_time_mem = CUBDataset(args.dataset, split="test")
40 | elif args.dataset.name == 'ap10k':
41 | dataset_test_time_mem = AP10KPairs(args.dataset, split="test")
42 | elif args.dataset.name == 'pfpascal':
43 | dataset_test_time_mem = PFPascalPairs(args.dataset, split="test")
44 |
45 | dataset_test_time_mem.featurizer = get_featurizer(args.featurizer)
46 | dataset_test_time_mem.featurizer_kwargs = args.featurizer
47 | results = evaluate(dataset_test_time_mem, args.num_imgs_time_mem, model_refine=model_refine)
48 | wandb.log(results)
49 | finish_wandb()
50 |
51 | if __name__ == '__main__':
52 | main()
--------------------------------------------------------------------------------
/scripts/store_feats.py:
--------------------------------------------------------------------------------
1 | import sys, os; sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # make src importable when run from scripts/, as in the other scripts
2 | import torch
2 | from omegaconf import DictConfig, OmegaConf
3 | import hydra
4 | from src.dataset.cub_200 import CUBDatasetBordersCut, CUBDataset
5 | from src.dataset.spair_single import SpairDatasetSingle
6 | from src.dataset.pascalparts import PascalParts
7 | from src.dataset.apk_pairs import AP10KPairs
8 | from src.dataset.pfpascal_pairs import PFPascalPairs
9 | from src.models.featurizer.utils import get_featurizer
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 | @hydra.main(config_path="../configs")
13 | def main(args: DictConfig):
14 | # convert hydra to dict
15 | cfg = OmegaConf.to_container(args)
16 | # init the dataset
17 | args.dataset.sup = "sup_original"
18 | torch.cuda.set_device(0)
19 | if args.dataset.name == 'spair':
20 | dataset = SpairDatasetSingle(args.dataset, split=args.dataset.split)
21 | elif args.dataset.name == 'cub':
22 | # dataset = CUBDatasetBordersCut(args.dataset, split=args.dataset.split)
23 | dataset = CUBDataset(args.dataset, split=args.dataset.split)
24 | elif args.dataset.name == 'pascalparts':
25 | dataset = PascalParts(args.dataset, split=args.dataset.split)
26 | elif args.dataset.name == 'ap10k':
27 | dataset = AP10KPairs(args.dataset, split=args.dataset.split)
28 | elif args.dataset.name == 'pfpascal':
29 | dataset = PFPascalPairs(args.dataset, split=args.dataset.split)
30 |
31 | if args.featurizer.model == 'dift_sd':
32 | args.featurizer.all_cats = dataset.all_cats
33 | featurizer = get_featurizer(args.featurizer)
34 | dataset.featurizer_name = featurizer.name
35 | dataset.featurizer = featurizer
36 | dataset.featurizer_kwargs = args.featurizer
37 | print(dataset.featurizer_name)
38 |
39 | overwrite=False
40 | dataset.store_feats(featurizer, overwrite, args.featurizer)
41 |
42 | if __name__ == '__main__':
43 | main()
--------------------------------------------------------------------------------
/scripts/store_masks.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | from omegaconf import DictConfig, OmegaConf
7 | import hydra
8 | from src.dataset.cub_200 import CUBDatasetBordersCut, CUBDataset
9 | from src.dataset.spair_single import SpairDatasetSingle
10 | from src.dataset.pascalparts import PascalParts
11 | from src.dataset.apk_pairs import AP10KPairs
12 | from src.dataset.pfpascal_pairs import PFPascalPairs
13 | from src.models.segmentation.sam import SAM
14 |
15 | @hydra.main(config_path="../configs")
16 | def main(args: DictConfig):
17 | # init the dataset
18 | args.dataset.sup = "sup_original"
19 | torch.cuda.set_device(0)
20 | if args.dataset.name == 'spair':
21 | dataset = SpairDatasetSingle(args.dataset, split=args.dataset.split)
22 | elif args.dataset.name == 'cub':
23 | # dataset = CUBDatasetBordersCut(args.dataset, split=args.dataset.split)
24 | dataset = CUBDataset(args.dataset, split=args.dataset.split)
25 | elif args.dataset.name == 'pascalparts':
26 | dataset = PascalParts(args.dataset, split=args.dataset.split)
27 | elif args.dataset.name == 'ap10k':
28 | dataset = AP10KPairs(args.dataset, split=args.dataset.split)
29 | elif args.dataset.name == 'pfpascal':
30 | dataset = PFPascalPairs(args.dataset, split=args.dataset.split)
31 |
32 | model_seg = SAM(args)
33 | dataset.model_seg = model_seg
34 | dataset.return_masks = True
35 | dataset.model_seg_name = model_seg.name
36 |
37 | overwrite=False
38 | dataset.store_masks(overwrite)
39 |
40 | if __name__ == '__main__':
41 | main()
--------------------------------------------------------------------------------
/scripts/train_pairs.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | from omegaconf import DictConfig, OmegaConf
7 | import hydra
8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9 | import copy
10 | from src.models.featurizer.utils import get_featurizer
11 | from src.models.featurizer.utils import save_checkpoint as save_featurizer
12 | from src.logging.log_results import log_wandb_epoch, log_wandb_cfg, log_wandb_ram_usage, finish_wandb, log_wandb, init_wandb
13 | from src.dataset.cub_200_pairs import CUBPairDataset
14 | from src.dataset.spair import SpairDataset2
15 | from src.dataset.apk_pairs import AP10KPairs
16 | from src.dataset.pfpascal_pairs import PFPascalPairs
17 | from src.evaluation.evaluation import Evaluation
18 | from src.models.featurizer_refine import get_model_refine
19 | from src.models.featurizer_refine import save_checkpoint as save_model_refine
20 | from src.losses import PairwiseLoss
21 | import time
22 | from src.dataset.utils import get_multi_cat_dataset
23 |
24 | def set_seed(dataset_train, args, epoch):
25 | if args.dataset.cat == "all":
26 | # set seed of all the subdatasets of the concatenated dataset
27 | for dataset in dataset_train.datasets:
28 | dataset.seed_pairs = epoch
29 | else:
30 | dataset_train.seed_pairs = epoch # set the seed for the dataset to get the same pairs for each run, but different pairs for each epoch
31 |
32 | def forward_pass(model_refine, ft0, ft1):
33 | # get the new features
34 | if model_refine is not None:
35 | ft_new = [model_refine(f) for f in [ft0, ft1]]
36 | else:
37 | ft_new = [ft0, ft1]
38 | ft_new_ = [f.permute(0,2,3,1).flatten(1,-2) for f in ft_new] # each of shape (B, H*W, C)
39 | return ft_new_[0], ft_new_[1]
40 |
41 | def train_batch(data, model_refine, loss_fun):
42 | torch.cuda.empty_cache()
43 | ft0, ft1 = forward_pass(model_refine, data['src_ft'], data['trg_ft'])
44 | losses = loss_fun.get_loss(ft0, ft1, data)
45 | return losses
46 |
47 | def train_epoch(train_loader, model_refine, optimizer, weights, loss_fun, args, epoch, evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3, best_pck, featurizer=None):
48 | # iterate over the dataset
49 | I = len(train_loader)
50 | for i, data in enumerate(train_loader):
51 | start_loss = time.time()
52 | if model_refine is not None:
53 | model_refine.train()
54 | else:
55 | featurizer.train()
56 | losses = train_batch(data, model_refine, loss_fun)
57 | # compute the mean loss
58 | losses = {k: v.mean()*weights[k] for k,v in losses.items()}
59 | # backprop
60 | optimizer.zero_grad()
61 | loss = sum(losses.values())
62 | loss.backward()
63 | optimizer.step()
64 | # measure runtime of loss computation
65 | timediff_loss = time.time()-start_loss
66 | log_wandb({'loss_time':timediff_loss})
67 | # log the losses, ram usage and evaluation
68 | if ((epoch-1)*I+i+1) % round(1*(6/args.batch_size)) == 0:
69 | log_wandb_epoch(epoch-1+i/I)
70 | if ((epoch-1)*I+i+1) % round(50*(6/args.batch_size)) == 0:
71 | losses['epoch'] = epoch-1+i/I
72 | log_wandb(losses)
73 | print ('Epoch [{}/{}], Step [{}/{}] Losses saved'.format(epoch, args.epoch,i,I))
74 | if ((epoch-1)*I+i+1) % round(1000*(6/args.batch_size)) == 0:
75 | if model_refine is not None:
76 | model_refine.eval()
77 | else:
78 | featurizer.eval()
79 | start_eval = time.time()
80 | log_wandb_ram_usage()
81 | print ('Epoch [{}/{}], Step [{}/{}] Eval start'.format(epoch, args.epoch,i,I))
82 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
83 | for eva in [evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3]:
84 | eva.epoch = epoch-1+i/I
85 | print("Epoch [{}/{}], Step [{}/{}] Eval on val set".format(epoch, args.epoch,i,I))
86 | n_eval_imgs = 10
87 | evaluation_val.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val.dataset_test_pck.name)
88 | evaluation_val_gen.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen.dataset_test_pck.name)
89 | evaluation_val_gen2.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen2.dataset_test_pck.name)
90 | evaluation_val_gen3.evaluate_pck(model_refine, n_pairs_eval_pck=n_eval_imgs, suffix=evaluation_val_gen3.dataset_test_pck.name)
91 | log_wandb_ram_usage()
92 | print ('Epoch [{}/{}], Step [{}/{}] Eval end'.format(epoch, args.epoch,i,I))
93 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
94 | # process checkpoint
95 | suffix = evaluation_val.dataset_test_pck.name
96 | val_pck = evaluation_val.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}']
97 | suffix = evaluation_val_gen.dataset_test_pck.name
98 | val_pck+=evaluation_val_gen.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}']
99 | suffix = evaluation_val_gen2.dataset_test_pck.name
100 | val_pck+=evaluation_val_gen2.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}']
101 | suffix = evaluation_val_gen3.dataset_test_pck.name
102 | val_pck+=evaluation_val_gen3.results_pck[f'per point PCK@0.1 _val_{n_eval_imgs}pairs_{suffix}']
103 | timediff_eval = time.time()-start_eval
104 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
105 | print(f"Epoch [{epoch}/{args.epoch}], Step [{i}/{I}] Eval time: {timediff_eval:.2f} s, PCK@0.1: {val_pck:.2f}")
106 | log_wandb({'eval_time':timediff_eval})
107 | if val_pck> best_pck:
108 | best_pck = val_pck
109 | if model_refine is not None:
110 | save_model_refine(model_refine, args.feat_refine, "best")
111 | else:
112 | save_featurizer(featurizer, args.featurizer, "best_lora_weights")
113 |
114 | if model_refine is not None:
115 | model_refine.train()
116 | else:
117 | featurizer.train()
118 | if ((epoch-1)*I+i+1) % round(1000*(6/args.batch_size)) == 0:
119 | if model_refine is not None:
120 | model_refine.epoch = epoch-1+i/I
121 | save_model_refine(model_refine, args.feat_refine, "last")
122 | else:
123 | featurizer.epoch = epoch-1+i/I
124 | save_featurizer(featurizer, args.featurizer, "last_lora_weights")
125 | if model_refine is not None:
126 | save_model_refine(model_refine, args.feat_refine, "last")
127 | else:
128 | save_featurizer(featurizer, args.featurizer, "last_lora_weights")
129 | return best_pck
130 |
131 | @hydra.main(config_path="../configs")
132 | def main(args: DictConfig):
133 | cfg = OmegaConf.to_container(args)
134 | init_wandb(cfg, 'train_pairs')
135 | torch.cuda.set_device(0)
136 | # init the dataset
137 | dataset_list = []
138 | for dataset_name in ['dataset', 'dataset2', 'dataset3', 'dataset4']:
139 | if dataset_name in args:
140 | dataset_args = args[dataset_name]
141 | if dataset_args.name == 'spair':
142 | dataset_list.append(SpairDataset2(dataset_args, split='train'))
143 | elif dataset_args.name == 'cub':
144 | dataset_list.append(CUBPairDataset(dataset_args, split='train'))
145 | elif dataset_args.name == 'ap10k':
146 | dataset_list.append(AP10KPairs(dataset_args, split='train'))
147 | elif dataset_args.name == 'pfpascal':
148 | dataset_list.append(PFPascalPairs(dataset_args, split='train'))
149 | # init the featurizer
150 | log_wandb_ram_usage()
151 | featurizer = get_featurizer(args.featurizer)
152 | log_wandb_ram_usage()
153 | # init the evaluation
154 | evaluation_test = Evaluation(args, featurizer)
155 | evaluation_train = Evaluation(args, featurizer, split='train')
156 | evaluation_val = Evaluation(args, featurizer, split='val')
157 | log_wandb_ram_usage()
158 | if args.dataset.cat == "all" or 'dataset2' in args:
159 | # avoid deep copy of the featurizer
160 | dataset_train_multi = get_multi_cat_dataset(dataset_list, featurizer, args.featurizer, model_seg_name=args.model_seg_name)
161 | train_loader = torch.utils.data.DataLoader(dataset_train_multi, batch_size=args.batch_size, shuffle=True)
162 | for dataset in dataset_train_multi.datasets:
163 | print(f"Number of pairs in dataset {dataset.name} {dataset.cat}: {len(dataset)}")
164 | else:
165 | # We train for each category separately
166 | dataset_list[-1].featurizer = featurizer
167 | dataset_list[-1].featurizer_kwargs = args.featurizer
168 | dataset_list[-1].init_kps_cat(args.dataset.cat)
169 | dataset_list[-1].model_seg_name = args.model_seg_name
170 | dataset_list[-1].return_masks = True
171 | train_loader = torch.utils.data.DataLoader(dataset_list[-1], batch_size=args.batch_size, shuffle=True)
172 | log_wandb_ram_usage()
173 | weights = {'pos':args.losses.pos, 'bin':args.losses.bin, 'neg':args.losses.neg, 'neg_fg_bkg':args.losses.neg_fg_bkg}
174 | log_wandb_cfg({'weights':weights})
175 | # init the model
176 | ep=0
177 | if 'feat_refine' in args:
178 | model_refine = get_model_refine(args.feat_refine)
179 | log_wandb_ram_usage()
180 | model_refine = model_refine.to(device)
181 | model_refine = model_refine.train()
182 | log_wandb_ram_usage()
183 | # define the optimizer
184 | # optimizer = torch.optim.Adam(list(model_refine.parameters())+list(matcher.parameters()), lr=args.learning_rate)
185 | optimizer = torch.optim.AdamW(list(model_refine.parameters()), lr=args.learning_rate, weight_decay=args.weight_decay)
186 | else:
187 | featurizer = featurizer.to(device)
188 | featurizer = featurizer.train()
189 | model_refine = None
190 | lora_params = [p for p in featurizer.parameters() if p.requires_grad]
191 | optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate, weight_decay=args.weight_decay)
192 | # define the scheduler
193 | if args.scheduler is not None:
194 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.learning_rate, steps_per_epoch=len(dataset_list[-1])//args.batch_size, epochs=args.epoch, pct_start=0.3)
195 | else:
196 | scheduler = None
197 | # add evaluation for different datasets
198 | argsgen = copy.deepcopy(args)
199 | argsgen.dataset = argsgen.datasetgeneralization
200 | evaluation_val_gen = Evaluation(argsgen, featurizer, split='val')
201 | argsgen2 = copy.deepcopy(args)
202 | argsgen2.dataset = args.datasetgeneralization2
203 | evaluation_val_gen2 = Evaluation(argsgen2, featurizer, split='val')
204 | argsgen3 = copy.deepcopy(args)
205 | argsgen3.dataset = args.datasetgeneralization3
206 | evaluation_val_gen3 = Evaluation(argsgen3, featurizer, split='val')
207 | # define the loss function
208 | loss_fun = PairwiseLoss(args)
209 | # start training
210 | best_pck = 0
211 | for epoch in range(ep+1, args.epoch+1):
212 | # set seed
213 | set_seed(train_loader.dataset, args, epoch)
214 |
215 | best_pck = train_epoch(train_loader, model_refine, optimizer, weights, loss_fun, args, epoch, evaluation_test, evaluation_train, evaluation_val, evaluation_val_gen, evaluation_val_gen2, evaluation_val_gen3, best_pck, featurizer=featurizer)
216 | if scheduler is not None:
217 | scheduler.step()
218 | # evaluate the best model
219 | if model_refine is not None:
220 | OmegaConf.update(args, "feat_refine.init.load_pretrained", True, force_add=True)
221 | OmegaConf.update(args, "feat_refine.init.id", model_refine.id, force_add=True)
222 | model_refine = get_model_refine(args.feat_refine)
223 | evaluation_test.reset_cat()
224 | evaluation_test.evaluate_pck(model_refine)
225 | finish_wandb()
226 |
227 | if __name__ == '__main__':
228 | main()
229 |
--------------------------------------------------------------------------------
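A minimal sketch (not part of the repository) of the "last"-checkpoint cadence in train_epoch above: a checkpoint is written whenever the global step (epoch-1)*I + i + 1 is a multiple of round(1000*(6/args.batch_size)), i.e. roughly every 6000/batch_size optimizer steps.

    # illustrative only; the batch sizes below are assumptions, not repository defaults
    for batch_size in (6, 12, 24):
        steps_between_saves = round(1000 * (6 / batch_size))
        print(batch_size, steps_between_saves)  # 6 -> 1000, 12 -> 500, 24 -> 250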
/setup_env.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
2 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118 # this deletes the torch installation, so we need to reinstall it and cleanup the nvidia installs
3 | pip install jupyterlab
4 | pip install -U matplotlib
5 | pip install transformers
6 | pip install ipympl
7 | pip install triton
8 | pip install open_clip_torch # this deletes the torch installation, so we need to reinstall it and cleanup the nvidia installs
9 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
10 | pip install hydra-core
11 | pip install -U scikit-learn
12 | pip install pandas
13 | pip install wandb
14 | pip install POT
15 | pip install --upgrade diffusers[torch]
16 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
17 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/__init__.py
--------------------------------------------------------------------------------
/src/dataset/augmentations.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from torchvision import transforms
4 | import torchvision.transforms.functional as F
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class GaussianBlur(transforms.RandomApply):
10 | """
11 | Apply Gaussian Blur to the PIL image.
12 | """
13 |
14 | def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
15 | # NOTE: torchvision is applying 1 - probability to return the original image
16 | keep_p = 1 - p
17 | transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
18 | super().__init__(transforms=[transform], p=keep_p)
19 |
20 |
21 | class DataAugmentation(object):
22 | def __init__(
23 | self,
24 | crops_scale=None,
25 | crops_size=None,
26 | KP_LEFT_RIGHT_PERMUTATION=None,
27 | KP_WITH_ORIENTATION=None,
28 | color_aug=True,
29 | flip_aug=True,
30 | ):
31 | self.KP_LEFT_RIGHT_PERMUTATION = KP_LEFT_RIGHT_PERMUTATION
32 | self.KP_WITH_ORIENTATION = KP_WITH_ORIENTATION
33 | if self.KP_WITH_ORIENTATION is not None:
34 | self.KP_WITH_ORIENTATION = torch.tensor(self.KP_WITH_ORIENTATION).flatten()
35 | self.flip_aug = flip_aug
36 | if crops_scale!=None:
37 | # random resized crop
38 | self.crops_size = crops_size
39 | self.geometric_augmentation_1 = transforms.RandomResizedCrop(
40 | crops_size, scale=crops_scale, ratio=(0.9, 1.1),
41 | )
42 | else:
43 | self.crops_size = None
44 |
45 | # color distortions / blurring
46 | if color_aug:
47 | color_jittering = transforms.Compose(
48 | [
49 | transforms.RandomApply(
50 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
51 | p=0.8,
52 | )
53 | ]
54 | )
55 |
56 | colors = transforms.Compose(
57 | [
58 | GaussianBlur(p=0.1),
59 | transforms.RandomSolarize(threshold=128, p=0.2),
60 | ]
61 | )
62 |
63 | self.color_augmentation = transforms.Compose([color_jittering, colors])
64 | else:
65 | self.color_augmentation = None
66 |
67 | def augment_mask(self, mask):
68 |
69 | # Horizontal flip
70 | if self.flip_aug:
71 | hflip = np.random.binomial(1, p=0.5)
72 | else:
73 | hflip = False
74 | # hflip = False
75 | if hflip:
76 | if mask is not None:
77 | mask = F.hflip(mask)
78 |
79 | if self.crops_size!=None:
80 | i, j, h, w = self.geometric_augmentation_1.get_params(mask, scale=list(self.geometric_augmentation_1.scale), ratio=list(self.geometric_augmentation_1.ratio))
81 | # crop the image
82 | if mask is not None:
83 | mask = F.resized_crop(mask, i, j, h, w, [self.crops_size, self.crops_size])
84 | return mask
85 |
86 | def __call__(self, image, data):
87 | #{'ft': ft, 'imsize': imsize, 'kps': keypoints, 'kps_symm_only': kps_symm_only, 'bndbox':bndbox, 'idx':idx}
88 |
89 | # Horizontal flip
90 | if self.flip_aug:
91 | hflip = np.random.binomial(1, p=0.5)
92 | else:
93 | hflip = False
94 |
95 | # hflip = False
96 | if hflip:
97 | image = F.hflip(image)
98 | # data['bndbox'] is a list of 4 elements [xmin, ymin, xmax, ymax]
99 | xmin = data['bndbox'][0]
100 | xmax = data['bndbox'][2]
101 | data['bndbox'][0] = data['imsize'][1].item()-xmax
102 | data['bndbox'][2] = data['imsize'][1].item()-xmin
103 | data['hflip'] = True
104 | else:
105 | data['hflip'] = False
106 |
107 | if hflip:
108 | if self.KP_WITH_ORIENTATION is not None:
109 | # Horizontal flip
110 | data['kps'][:, 1] = data['imsize'][1].item() - data['kps'][:, 1] - 1
111 | data['kps_symm_only'] = data['kps'].clone()
112 |
113 | data['kps_symm_only'][~self.KP_WITH_ORIENTATION, 2] = 0
114 | data['kps'][self.KP_WITH_ORIENTATION, 2] = 0.5
115 | elif self.KP_LEFT_RIGHT_PERMUTATION is not None:
116 | for key in ['kps','kps_symm_only']:
117 | data[key][:, 1] = data['imsize'][1].item() - data[key][:, 1] - 1
118 | data[key] = data[key][self.KP_LEFT_RIGHT_PERMUTATION]
119 |
120 | # Crop
121 | if self.crops_size!=None:
122 | i, j, h, w = self.geometric_augmentation_1.get_params(image, scale=list(self.geometric_augmentation_1.scale), ratio=list(self.geometric_augmentation_1.ratio))
123 | # i, j are the top-left corner of the crop (first element is top, second is left)
124 | # h,w are the height and width of the crop
125 | # crop the image
126 | image = F.resized_crop(image, i, j, h, w, [self.crops_size, self.crops_size])
127 | data['imsize'] = torch.tensor((self.crops_size, self.crops_size))
128 | # data['bndbox'] is a list of 4 elements [xmin, ymin, xmax, ymax]
129 | data['bndbox'] = data['bndbox'] - np.array([j, i, j, i])
130 | # Scale the bounding box to the new size
131 | data['bndbox'] = data['bndbox'] * np.array([self.crops_size / w, self.crops_size / h, self.crops_size / w, self.crops_size / h])
132 |
133 |
134 | for key in ['kps','kps_symm_only']:
135 | keypoints = data[key]
136 | visible = keypoints[:, 2]>0.5
137 | if self.crops_size!=None:
138 | # Crop
139 | keypoints[visible, 0] -= i # top
140 | keypoints[visible, 1] -= j # left
141 | keypoints[visible, 0] *= self.crops_size/h
142 | keypoints[visible, 1] *= self.crops_size/w
143 | keypoints[visible, 2] *= torch.bitwise_and((keypoints[visible, 0] < data['imsize'][0]) , (keypoints[visible, 1] < data['imsize'][1]) ).float()
144 | keypoints[visible, 2] *= torch.bitwise_and((keypoints[visible, 0] >= 0 ) , (keypoints[visible, 1] >= 0 )).float()
145 | data[key] = keypoints
146 |
147 |
148 | # Color augmentation
149 | if self.color_augmentation is not None:
150 | image = self.color_augmentation(image)
151 |
152 | return image, data
153 |
--------------------------------------------------------------------------------
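A hypothetical usage sketch of DataAugmentation (all values invented; it assumes, as the crop and flip code above suggests, that keypoint rows are (y, x, visibility) and imsize is (H, W)):

    from PIL import Image
    import numpy as np
    import torch

    aug = DataAugmentation(crops_scale=(0.5, 1.0), crops_size=224, color_aug=True, flip_aug=False)
    image = Image.new('RGB', (640, 480))                        # W x H
    data = {
        'imsize': torch.tensor((480, 640)),                     # (H, W)
        'kps': torch.tensor([[100., 200., 1.], [0., 0., 0.]]),  # (y, x, visibility)
        'kps_symm_only': torch.zeros(2, 3),
        'bndbox': np.array([50, 60, 300, 400]),                 # [xmin, ymin, xmax, ymax]
    }
    image_aug, data_aug = aug(image, data)                      # keypoints and bbox are remapped to the 224x224 crop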
/src/dataset/cub_200_pairs.py:
--------------------------------------------------------------------------------
1 | from src.dataset.cub_200 import CUBDataset, CUBDatasetBordersCut, CUBDatasetAugmented
2 | from src.dataset.pairwise_utils import random_pairs_2
3 | from src.dataset.random_utils import use_seed
4 | import torch
5 | import numpy as np
6 |
7 | def CUBPairDataset(args, **kwargs):
8 | if args.sup == 'sup_original':
9 | if args.borders_cut:
10 | return CUBPairDatasetOrigBC(args, **kwargs)
11 | else:
12 | return CUBPairDatasetOrig(args, **kwargs)
13 | elif args.sup == 'sup_augmented':
14 | return CUBPairDatasetAugmentedPadded(args, **kwargs)
15 | else:
16 | raise ValueError(f"Unknown supervision type {args.sup}")
17 |
18 | class CUBPairDatasetOrig(CUBDataset):
19 | pck_symm = True
20 |
21 | def __init__(self, args, **kwargs):
22 | # init the parent
23 | super().__init__(args, **kwargs)
24 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data))
25 | self.pairs = list(zip(idx0, idx1))
26 | self.return_imgs = False
27 | self.return_masks = False
28 | assert(args.borders_cut == False)
29 |
30 | def __len__(self):
31 | return len(self.pairs)
32 |
33 | def get_imgs(self, idx):
34 | idx0, idx1 = self.pairs[idx]
35 | img0 = super().get_img(idx0)
36 | img1 = super().get_img(idx1)
37 | return img0, img1
38 |
39 | def get_masks(self, idx):
40 | idx0, idx1 = self.pairs[idx]
41 | mask0 = super().get_mask(idx0)
42 | mask1 = super().get_mask(idx1)
43 | return mask0, mask1
44 |
45 | def __getitem__(self, idx):
46 | idx0, idx1 = self.pairs[idx]
47 | data0 = super().__getitem__(idx0)
48 | data1 = super().__getitem__(idx1)
49 | data = {'src_imsize': data0['imsize'],
50 | 'trg_imsize': data1['imsize'],
51 | 'src_kps': data0['kps'],
52 | 'trg_kps': data1['kps'],
53 | 'trg_kps_symm_only': data1['kps_symm_only'],
54 | 'src_kps_symm_only': data0['kps_symm_only'],
55 | 'src_bndbox': data0['bndbox'],
56 | 'trg_bndbox': data1['bndbox'],
57 | 'cat': self.cat,
58 | 'idx': idx}
59 |
60 | data['numkp'] = data0['kps'].shape[0]
61 | if self.return_feats:
62 | data['src_ft'] = data0['ft']
63 | data['trg_ft'] = data1['ft']
64 | if self.return_imgs:
65 | img0,img1 = self.get_imgs(idx)
66 | data['src_img'] = np.array(img0)
67 | data['trg_img'] = np.array(img1)
68 | if self.return_masks:
69 | mask0,mask1 = self.get_masks(idx)
70 | data['src_mask'] = mask0
71 | data['trg_mask'] = mask1
72 | return data
73 |
74 | class CUBPairDatasetOrigBC(CUBDatasetBordersCut):
75 | pck_symm = True
76 |
77 | def __init__(self, args, **kwargs):
78 | # init the parent
79 | super().__init__(args, **kwargs)
80 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data))
81 | self.pairs = list(zip(idx0, idx1))
82 | self.return_imgs = False
83 | self.return_masks = False
84 |
85 | def __len__(self):
86 | return len(self.pairs)
87 |
88 | def get_imgs(self, idx):
89 | idx0, idx1 = self.pairs[idx]
90 | img0 = super().get_img(idx0)
91 | img1 = super().get_img(idx1)
92 | return img0, img1
93 |
94 | def __getitem__(self, idx):
95 | idx0, idx1 = self.pairs[idx]
96 | data0 = super().__getitem__(idx0)
97 | data1 = super().__getitem__(idx1)
98 | data = {'src_imsize': data0['imsize'],
99 | 'trg_imsize': data1['imsize'],
100 | 'src_kps': data0['kps'],
101 | 'trg_kps': data1['kps'],
102 | 'trg_kps_symm_only': data1['kps_symm_only'],
103 | 'src_kps_symm_only': data0['kps_symm_only'],
104 | 'src_bndbox': data0['bndbox'],
105 | 'trg_bndbox': data1['bndbox'],
106 | 'cat': self.cat,
107 | 'idx': idx}
108 | if self.return_feats:
109 | data['src_ft'] = data0['ft']
110 | data['trg_ft'] = data1['ft']
111 | if self.return_imgs:
112 | img0,img1 = self.get_imgs(idx)
113 | data['src_img'] = np.array(img0)
114 | data['trg_img'] = np.array(img1)
115 | data['numkp'] = data0['kps'].shape[0]
116 | if self.return_masks:
117 | mask0,mask1 = self.get_masks(idx)
118 | data['src_mask'] = mask0
119 | data['trg_mask'] = mask1
120 | return data
121 |
122 | class CUBPairDatasetAugmented(CUBDatasetAugmented):
123 | pck_symm = True
124 |
125 | def __init__(self, args, **kwargs):
126 | # init the parent
127 | super().__init__(args, **kwargs)
128 | idx0, idx1 = random_pairs_2(args.n_pairs, len(self.data))
129 | self.pairs = list(zip(idx0, idx1))
130 | self.seed_pairs = 1 # this can be changed for different epochs
131 | self.return_imgs = False
132 | self.return_masks = False
133 |
134 | def __len__(self):
135 | return len(self.pairs)
136 |
137 | def get_imgs(self, idx):
138 | idx0, idx1 = self.pairs[idx]
139 | with use_seed(idx+self.seed_pairs):
140 | seed_offset0 = torch.randint(1000000, (1,)).item()
141 | seed_offset1 = torch.randint(1000000, (1,)).item()
142 | self.seed = self.seed+seed_offset0
143 | img0 = super().get_img(idx0)
144 | self.seed = self.seed-seed_offset0
145 | self.seed = self.seed+seed_offset1
146 | img1 = super().get_img(idx1)
147 | self.seed = self.seed-seed_offset1
148 | return img0, img1
149 |
150 | def get_masks(self, idx):
151 | idx0, idx1 = self.pairs[idx]
152 | with use_seed(idx+self.seed_pairs):
153 | seed_offset0 = torch.randint(1000000, (1,)).item()
154 | seed_offset1 = torch.randint(1000000, (1,)).item()
155 | self.seed = self.seed+seed_offset0
156 | mask0 = super().get_mask(idx0)
157 | self.seed = self.seed-seed_offset0
158 | self.seed = self.seed+seed_offset1
159 | mask1 = super().get_mask(idx1)
160 | self.seed = self.seed-seed_offset1
161 | return mask0, mask1
162 |
163 | def __getitem__(self, idx):
164 | idx0, idx1 = self.pairs[idx]
165 | with use_seed(idx+self.seed_pairs):
166 | seed_offset0 = torch.randint(1000000, (1,)).item()
167 | seed_offset1 = torch.randint(1000000, (1,)).item()
168 | self.seed = self.seed+seed_offset0
169 | data0 = super().__getitem__(idx0)
170 | self.seed = self.seed-seed_offset0
171 | self.seed = self.seed+seed_offset1
172 | data1 = super().__getitem__(idx1)
173 | self.seed = self.seed-seed_offset1
174 | data = {'src_imsize': data0['imsize'],
175 | 'trg_imsize': data1['imsize'],
176 | 'src_kps': data0['kps'],
177 | 'trg_kps': data1['kps'],
178 | 'trg_kps_symm_only': data1['kps_symm_only'],
179 | 'src_kps_symm_only': data0['kps_symm_only'],
180 | 'src_bndbox': data0['bndbox'],
181 | 'trg_bndbox': data1['bndbox'],
182 | 'src_hflip': data0['hflip'],
183 | 'trg_hflip': data1['hflip'],
184 | 'cat': self.cat,
185 | 'idx': idx}
186 |
187 | data['numkp'] = len(self.KP_NAMES)
188 |
189 | if self.return_feats:
190 | data['src_ft'] = data0['ft']
191 | data['trg_ft'] = data1['ft']
192 | if self.return_imgs:
193 | img0,img1 = self.get_imgs(idx)
194 | data['src_img'] = np.array(img0)
195 | data['trg_img'] = np.array(img1)
196 | if self.return_masks:
197 | mask0,mask1 = self.get_masks(idx)
198 | data['src_mask'] = mask0
199 | data['trg_mask'] = mask1
200 | return data
201 |
202 | class CUBPairDatasetAugmentedPadded(CUBPairDatasetAugmented):
203 | def __init__(self, args, **kwargs):
204 | super().__init__(args, **kwargs)
205 | self.n_kps = 100
206 | self.num_el = args.num_el if 'num_el' in args else None
207 |
208 | def total_len(self):
209 | return len(self.pairs)
210 |
211 | def __len__(self):
212 | if self.num_el is None:
213 | return self.total_len()
214 | return min(self.num_el, self.total_len())
215 |
216 | def pad_kps(self, kps):
217 | kps = kps.clone()
218 | pad = self.n_kps - kps.shape[0]
219 | if pad > 0:
220 | kps = torch.cat([kps, torch.zeros(pad, 3)], dim=0)
221 | return kps
222 |
223 | def init_kps_cat(self, cat):
224 | super().init_kps_cat(cat)
225 | # remove the shuffeled_idx_list
226 | if hasattr(self, 'shuffeled_idx_list'):
227 | delattr(self, 'shuffeled_idx_list')
228 |
229 | def __getitem__(self, idx):
230 | totallen = self.total_len()
231 | subsetlen = self.__len__()
232 | # create new idx list excluding the idx that have been used
233 | if not hasattr(self, 'shuffeled_idx_list'):
234 | with use_seed(self.seed_pairs):
235 | self.shuffeled_idx_list = np.random.permutation(totallen)
236 | elif len(self.shuffeled_idx_list) < self.seed_pairs*subsetlen:
237 | with use_seed(self.seed_pairs+1):
238 | self.shuffeled_idx_list = np.append(self.shuffeled_idx_list, np.random.permutation(totallen))
239 | idx_ = self.shuffeled_idx_list[idx+(self.seed_pairs-1)*subsetlen]
240 |
241 | data_out = super().__getitem__(int(idx_))
242 | # pad the keypoints
243 | for k in data_out.keys():
244 | if 'kps' in k:
245 | data_out[k] = self.pad_kps(data_out[k])
246 | return data_out
247 |
--------------------------------------------------------------------------------
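A small sketch (illustrative values, not from the repository) of the fixed-size padding performed by CUBPairDatasetAugmentedPadded.pad_kps, which makes variable-length keypoint sets stackable into batches (n_kps = 100):

    import torch

    kps = torch.rand(15, 3)                                    # 15 annotated keypoints, last column is visibility
    pad = 100 - kps.shape[0]
    kps_padded = torch.cat([kps, torch.zeros(pad, 3)], dim=0)  # -> (100, 3); padded rows have visibility 0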
/src/dataset/pairwise_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.dataset.random_utils import use_seed
3 | import numpy as np
4 | from src.dataset.utils import to_flattened_idx_torch
5 | import copy
6 | def random_pairs(n_pairs,n_imgs,permute_first=True):
7 | if permute_first:
8 | idx0 = torch.randint(n_imgs, (n_pairs,))
9 | else:
10 | idx0 = torch.arange(n_imgs).repeat(n_pairs//n_imgs)
11 | idx0 = idx0[:n_pairs]
12 | idx1 = torch.randint(n_imgs, (n_pairs,))
13 | for n in range(n_pairs):
14 | while idx0[n]==idx1[n]:
15 | idx1[n] = torch.randint(n_imgs, (1,))
16 | return idx0, idx1
17 |
18 | def random_pairs_2(n_pairs, n_imgs):
19 | idx0, idx1 =[], []
20 | while len(idx0) < n_pairs:
21 | with use_seed(len(idx0) + 123):
22 | indices = np.random.permutation(np.arange(0, n_imgs))
23 | middle = n_imgs // 2
24 | idx0, idx1 = idx0 + indices[:middle].tolist(), idx1 + indices[-middle:].tolist()
25 | idx0, idx1 = idx0[:n_pairs], idx1[:n_pairs]
26 | return idx0, idx1
27 |
28 | def get_pos_pairs(data_in, b=0):
29 | src_kps_symm_only = data_in['src_kps_symm_only'][b]
30 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b]
31 | src_kps = data_in['src_kps'][b]
32 | trg_kps = data_in['trg_kps'][b]
33 | vis_both = torch.bitwise_and(src_kps[:,2] > 0.5, trg_kps[:,2] > 0.5)
34 | pos_src_kps = src_kps[vis_both]
35 | pos_trg_kps = trg_kps[vis_both]
36 | flag_11_src = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, src_kps[:,2] > 0.5)
37 | flag_11_trg = torch.bitwise_and(trg_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5)
38 | flags_11 = {'pos_src_11': flag_11_src[vis_both], 'pos_trg_11': flag_11_trg[vis_both]}
39 | return pos_src_kps, pos_trg_kps, flags_11
40 |
41 | def get_neg_pairs(data_in, b=0):
42 | src_kps_symm_only = data_in['src_kps_symm_only'][b]
43 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b]
44 | src_kps = data_in['src_kps'][b]
45 | trg_kps = data_in['trg_kps'][b]
46 | # match to other points in src
47 | vis_both = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5)
48 | neg_src_kps = src_kps_symm_only.clone()
49 | neg_trg_kps = trg_kps.clone()
50 | neg_src_kps[:,2] = vis_both
51 | neg_trg_kps[:,2] = vis_both
52 | flag_11_src = torch.bitwise_and(src_kps_symm_only[:,2] > 0.5, src_kps[:,2] > 0.5)
53 | # match to other points in trg
54 | vis_both = torch.bitwise_and(src_kps[:,2] > 0.5, trg_kps_symm_only[:,2] > 0.5)
55 | neg_src_kps[vis_both] = src_kps[vis_both]
56 | neg_trg_kps[vis_both] = trg_kps_symm_only[vis_both]
57 | neg_trg_kps[:,2] = torch.bitwise_or(neg_src_kps[:,2] > 0.5, vis_both)
58 | flag_11_trg = torch.bitwise_and(trg_kps_symm_only[:,2] > 0.5, trg_kps[:,2] > 0.5)
59 | # only keep the vis_both keypoints
60 | flags_11 = {'neg_src_11': flag_11_src[neg_src_kps[:,2] > 0.5], 'neg_trg_11': flag_11_trg[neg_trg_kps[:,2] > 0.5]}
61 | neg_src_kps = neg_src_kps[neg_src_kps[:,2] > 0.5]
62 | neg_trg_kps = neg_trg_kps[neg_trg_kps[:,2] > 0.5]
63 | return neg_src_kps, neg_trg_kps, flags_11
64 |
65 | def get_bin_pairs(data_in, b=0):
66 | src_kps_symm_only = data_in['src_kps_symm_only'][b]
67 | trg_kps_symm_only = data_in['trg_kps_symm_only'][b]
68 | src_kps = data_in['src_kps'][b]
69 | trg_kps = data_in['trg_kps'][b]
70 | occluded = torch.bitwise_xor(src_kps[:,2] == 0, trg_kps[:,2] == 0)
71 | vis_both = torch.bitwise_or(src_kps[:,2] > 0.5, trg_kps[:,2] > 0.5)
72 | bin_flag = torch.bitwise_and(occluded, vis_both)
73 | bin_src_kps = src_kps[bin_flag]
74 | bin_trg_kps = trg_kps[bin_flag]
75 | flag = bin_flag[bin_flag]
76 | flags_11 = {'bin_src_11': flag, 'bin_trg_11': flag}
77 | return bin_src_kps, bin_trg_kps, flags_11
78 |
79 | def get_matches(data_in, b=0):
80 | pos_src_kps, pos_trg_kps, pos_flag_11 = get_pos_pairs(data_in, b)
81 | neg_src_kps, neg_trg_kps, neg_flag_11 = get_neg_pairs(data_in, b)
82 | bin_src_kps, bin_trg_kps, bin_flag_11 = get_bin_pairs(data_in, b)
83 | # put them in a dictionary
84 | data_out = {'pos_src_kps': pos_src_kps,
85 | 'pos_trg_kps': pos_trg_kps,
86 | 'neg_src_kps': neg_src_kps,
87 | 'neg_trg_kps': neg_trg_kps,
88 | 'bin_src_kps': bin_src_kps,
89 | 'bin_trg_kps': bin_trg_kps}
90 | data_out.update(pos_flag_11)
91 | data_out.update(neg_flag_11)
92 | data_out.update(bin_flag_11)
93 | return data_out
94 |
95 | def scale_to_feature_dims(data_out, data_in, b=0):
96 | scale_src = torch.tensor(data_in['src_ft'].shape[-2:])/data_in['src_imsize'][b]# get the scale for the feature image coordinates to the original image coordinates
97 | scale_trg = torch.tensor(data_in['trg_ft'].shape[-2:])/data_in['trg_imsize'][b]
98 |
99 | data_out_ft = copy.deepcopy(data_out)
100 | for prefix in ['pos', 'bin', 'neg']:
101 | src_kps = data_out[prefix+'_src_kps'].clone()
102 | trg_kps = data_out[prefix+'_trg_kps'].clone()
103 | # convert to feature image coordinates
104 | src_kps[:,:2] = (src_kps[:,:2]*scale_src).floor().long()
105 | trg_kps[:,:2] = (trg_kps[:,:2]*scale_trg).floor().long()
106 | data_out_ft[prefix+'_src_kps'] = src_kps
107 | data_out_ft[prefix+'_trg_kps'] = trg_kps
108 | return data_out_ft
109 |
110 | def get_y_mat_gt_assignment(kp0, kp1, ft_size0, ft_size1):
111 | size = (ft_size0.prod()+1, ft_size1.prod()+1)
112 | if len(kp0) > 0:
113 | # enter dummy values for the keypoints that are not vis_both to avoid error in assert of flattened index function
114 | kp0[kp0[:,2]==0,0] = 0
115 | kp0[kp0[:,2]==0,1] = 0
116 | kp1[kp1[:,2]==0,0] = 0
117 | kp1[kp1[:,2]==0,1] = 0
118 | kp0_idx = to_flattened_idx_torch(kp0[:,0], kp0[:,1], ft_size0[0].item(), ft_size0[1].item())
119 | kp1_idx = to_flattened_idx_torch(kp1[:,0], kp1[:,1], ft_size1[0].item(), ft_size1[1].item())
120 |
121 | kp0_idx[kp0[:,2]==0] = ft_size0.prod() # bin idx
122 | kp1_idx[kp1[:,2]==0] = ft_size1.prod() # bin idx
123 |
124 | indices = torch.stack([kp0_idx, kp1_idx]).unique(dim=1, sorted=False)
125 | values = torch.ones_like(indices[0]).float()
126 | matrix = torch.sparse_coo_tensor(indices, values, size).coalesce()
127 |
128 | else:
129 | # if no keypoints are annotated, return a matrix with only zeros
130 | matrix = torch.sparse_coo_tensor(torch.empty((2, 0), dtype=torch.long), [], size).coalesce()
131 | return matrix
--------------------------------------------------------------------------------
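A minimal usage sketch of random_pairs_2 (numbers invented): it deterministically permutes the image indices under a fixed seed and pairs the first half of each permutation with the second half, so the two halves are disjoint and idx0[k] != idx1[k] for every pair.

    idx0, idx1 = random_pairs_2(n_pairs=100, n_imgs=40)
    assert len(idx0) == len(idx1) == 100
    assert all(a != b for a, b in zip(idx0, idx1))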
/src/dataset/pascalparts.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data.dataset import Dataset as TorchDataset
6 | from scipy.io import loadmat
7 | import os
8 | from tqdm import tqdm
9 | from src.dataset.pascalparts_part2ind import part2ind
10 | from typing import Optional, Any
11 | OBJECT_CLASSES = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat',
12 | 'chair','cow','table','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor']
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 | class PascalParts(TorchDataset):
16 | name = 'pascalparts'
17 | name_this = 'pascalparts'
18 |
19 | def __init__(self, args, split='test'):
20 | self.cat = None
21 | self.root = args.dataset_path
22 | # join the path
23 |
24 | self.root_annotations = os.path.join(self.root,'Parts/Annotations_Part/')
25 | self.root_split = os.path.join(self.root,'VOCdevkit/VOC2010/ImageSets/Main/')
26 | self.root_imgs = os.path.join(self.root,'VOCdevkit/VOC2010/JPEGImages/')
27 | self.save_path = args.save_path
28 | self.split = split
29 |
30 | self.featurizer_name = None
31 | # Initialize attributes that may be set externally
32 | self.featurizer: Optional[Any] = None
33 | self.featurizer_kwargs: Optional[Any] = None
34 | self.model_seg: Optional[Any] = None
35 | self.model_seg_name: Optional[str] = None
36 | self.return_masks: bool = False
37 |
38 | self.all_cats = OBJECT_CLASSES
39 | self.all_cats_multipart = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'person'] # categories with multiple parts that are annotated in many images
40 |
41 | try:
42 | self.data = torch.load(os.path.join(self.save_path, 'pascalparts.pth'))
43 | except:
44 | self.structure_mat()
45 | self.data = torch.load(os.path.join(self.save_path, 'pascalparts.pth'))
46 |
47 | def init_kps_cat(self, cat):
48 | self.cat = cat
49 | cat_idx = OBJECT_CLASSES.index(self.cat)+1 # the index of the category in the pascal parts dataset
50 | part_dict, KP_LEFT_RIGHT_PERMUTATION, KP_WITH_ORIENTATION = part2ind(cat_idx) # get the part dictionary for the category
51 |
52 | if KP_LEFT_RIGHT_PERMUTATION is not None:
53 | self.KP_WITH_ORIENTATION = KP_WITH_ORIENTATION
54 | self.KP_LEFT_RIGHT_PERMUTATION = KP_LEFT_RIGHT_PERMUTATION
55 | else:
56 | self.KP_WITH_ORIENTATION = None
57 | self.KP_LEFT_RIGHT_PERMUTATION = None
58 |
59 | if len(part_dict.values()) == 0:
60 | part_dict = {'fg': 1}
61 |
62 | self.KP_NAMES = [k for k, v in sorted(part_dict.items(), key=lambda item: item[1])]
63 |
64 | def structure_mat(self):
65 | data = {}
66 | mat_filenames = os.listdir(self.root_annotations)
67 | for cat in OBJECT_CLASSES:
68 | data[cat] = []
69 |
70 | for idx, annotation_filename in enumerate(mat_filenames):
71 | mat = loadmat(os.path.join(self.root_annotations, annotation_filename), struct_as_record=False, squeeze_me=True)['anno']
72 | obj = mat.__dict__['objects']
73 | # check if obj is array, we only consider images with one object
74 | if isinstance(obj, np.ndarray):
75 | continue
76 | else:
77 | mat_cat = obj.__dict__['class']
78 | data[mat_cat].append(mat.__dict__['imname'])
79 | torch.save(data, os.path.join(self.save_path, 'pascalparts.pth'))
80 |
81 | def __len__(self):
82 | return len(self.data[self.cat])
83 |
84 | def _get_feat(self, idx, featurizer, featurizer_kwargs):
85 | cat = self.cat
86 | img = self.get_img(idx)
87 | feat = featurizer.forward(img,
88 | category=cat,
89 | **featurizer_kwargs)
90 | return feat
91 |
92 | def store_feats(self, featurizer, overwrite, featurizer_kwargs):
93 | assert(self.name == self.name_this)
94 | print("saving all %s images' features..."%self.split)
95 | self.imsizes = {}
96 | path = os.path.join(self.save_path, self.name_this, featurizer.name)
97 | for cat in tqdm(self.all_cats):
98 | self.init_kps_cat(cat)
99 | for idx in range(len(self)):
100 | ft_name = self.data[self.cat][idx]+'.pth'
101 | if os.path.exists(os.path.join(path, ft_name)) and not overwrite:
102 | continue
103 | feat = self._get_feat(idx, featurizer, featurizer_kwargs)
104 | # make directory of parent directory
105 | os.makedirs(os.path.join(path, os.path.dirname(ft_name)), exist_ok=True)
106 | torch.save(feat.detach().cpu(), os.path.join(path, ft_name))
107 | del feat
108 | torch.cuda.empty_cache()
109 |
110 | def get_img(self, idx):
111 | imname = self.data[self.cat][idx]+'.jpg'
112 | img_path = os.path.join(self.root_imgs, imname)
113 | img = Image.open(img_path).convert('RGB')
114 | return img
115 |
116 | def get_feat(self, idx):
117 | assert(self.featurizer_name is not None)
118 | assert(self.cat is not None)
119 | ft_name = self.data[self.cat][idx]+'.pth'
120 | ft = torch.load(os.path.join(self.save_path, self.name_this, self.featurizer_name, ft_name)).to(device)
121 | return ft
122 |
123 |
124 | def get_parts(self, idx):
125 | '''
126 | Output:
127 | part_dict: dictionary with part names as keys and part indices as values
128 | parts_mask: tensor of shape (num_parts, H, W) with one hot encoding of the parts,
129 | where num_parts is the number of parts in the category without the background,
130 | i.e. part_mask.sum(dim=0) should be 0 for the background
131 |
132 | '''
133 | name = self.data[self.cat][idx]
134 | mat = loadmat(os.path.join(self.root_annotations, name+'.mat'), struct_as_record=False, squeeze_me=True)['anno']
135 | # check if there is only one object in the image
136 | if isinstance(mat.__dict__['objects'], np.ndarray):
137 | raise ValueError('Multiple objects in image')
138 | obj_mask = mat.__dict__['objects'].__dict__['mask']
139 | assert(self.cat is not None)
140 | cat_idx = OBJECT_CLASSES.index(self.cat)+1 # the index of the category in the pascal parts dataset
141 | part_dict, _, _ = part2ind(cat_idx) # get the part dictionary for the category
142 | if len(part_dict.values()) == 0:
143 | num_parts = 1
144 | parts_mask = torch.zeros(num_parts, obj_mask.shape[0], obj_mask.shape[1])
145 | parts_mask[0, :, :] = torch.tensor(obj_mask)
146 | part_dict = {'fg': 1}
147 | else:
148 | num_parts = max(part_dict.values()) # not really the number of parts, as part indices are not continuous
149 | parts_mask = torch.zeros(num_parts, obj_mask.shape[0], obj_mask.shape[1])
150 | # check if there is only one part in the object
151 | parts = mat.__dict__['objects'].__dict__['parts']
152 | if not isinstance(parts, np.ndarray):
153 | parts = [parts]
154 | for part in parts:
155 | part_idx = part_dict[part.__dict__['part_name']]-1
156 | parts_mask[part_idx, :, :] = torch.tensor(part.__dict__['mask'])
157 |
158 | def one_hot_parts(part_mask):
159 | # pascalparts can assign multiple part labels to a pixel; resolve each foreground pixel to the part with the highest index
160 | bkg = (part_mask.sum(dim=0, keepdim=True)==0).repeat(part_mask.shape[0],1,1)
161 | prt_reverse = torch.flip(part_mask, [0])
162 | label = part_mask.shape[0] - prt_reverse.argmax(dim=0, keepdim=True) - 1 # undo the reverse
163 | prt = torch.zeros_like(part_mask)
164 | prt = prt.scatter_(0, label, 1)
165 | prt[bkg]=0
166 | return prt
167 |
168 | parts_mask = one_hot_parts(parts_mask)
169 |
170 | num_parts = max(part_dict.values()) # not really the number of parts, as part indices are not continuous
171 | self.parts = torch.unique( torch.tensor(list(part_dict.values())) -1)
172 | num_parts_ = len(self.parts)
173 |
174 | if num_parts_ != num_parts:
175 | parts_mask = parts_mask[self.parts]
176 |
177 | return part_dict, parts_mask
178 |
179 | def __getitem__(self, idx):
180 | part_dict, parts_mask = self.get_parts(idx)
181 | try:
182 | ft = self.get_feat(idx)[0]
183 | except:
184 | assert(self.featurizer is not None)
185 | assert(self.featurizer_kwargs is not None)
186 | ft = self._get_feat(idx, self.featurizer, self.featurizer_kwargs)[0]
187 | imsize = torch.tensor((parts_mask.shape[-2], parts_mask.shape[-1]))
188 | return {'ft': ft, 'imsize': imsize, 'parts_mask': parts_mask, 'part_dict': part_dict, 'idx': idx}
189 |
--------------------------------------------------------------------------------
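The flip/argmax trick inside get_parts resolves pixels that carry several part labels to the highest-indexed part. A toy sketch with invented values:

    import torch

    part_mask = torch.zeros(4, 1, 1)     # 4 parts, a single pixel
    part_mask[[1, 3], 0, 0] = 1          # the pixel is labelled with parts 1 and 3
    prt_reverse = torch.flip(part_mask, [0])
    label = part_mask.shape[0] - prt_reverse.argmax(dim=0, keepdim=True) - 1
    print(label.squeeze())               # tensor(3): the higher-indexed part wins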
/src/dataset/random_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from functools import wraps
3 |
4 |
5 | from numpy.random import seed as np_seed
6 | from numpy.random import get_state as np_get_state
7 | from numpy.random import set_state as np_set_state
8 | from random import seed as rand_seed
9 | from random import getstate as rand_get_state
10 | from random import setstate as rand_set_state
11 | import torch
12 | from torch import manual_seed as torch_seed
13 | from torch import get_rng_state as torch_get_state
14 | from torch import set_rng_state as torch_set_state
15 |
16 | class use_seed:
17 | def __init__(self, seed=None):
18 | if seed is not None:
19 | assert isinstance(seed, int) and seed >= 0
20 | self.seed = seed
21 |
22 | def __enter__(self):
23 | if self.seed is not None:
24 | self.rand_state = rand_get_state()
25 | self.np_state = np_get_state()
26 | self.torch_state = torch_get_state()
27 | self.torch_cudnn_deterministic = torch.backends.cudnn.deterministic
28 | rand_seed(self.seed)
29 | np_seed(self.seed)
30 | torch_seed(self.seed)
31 | torch.backends.cudnn.deterministic = True
32 | return self
33 |
34 | def __exit__(self, typ, val, _traceback):
35 | if self.seed is not None:
36 | rand_set_state(self.rand_state)
37 | np_set_state(self.np_state)
38 | torch_set_state(self.torch_state)
39 | torch.backends.cudnn.deterministic = self.torch_cudnn_deterministic
40 |
41 | def __call__(self, f):
42 | @wraps(f)
43 | def wrapper(*args, **kw):
44 | seed = self.seed if self.seed is not None else kw.pop('seed', None)
45 | with use_seed(seed):
46 | return f(*args, **kw)
47 |
48 | return wrapper
49 |
--------------------------------------------------------------------------------
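use_seed works both as a context manager and as a decorator: it seeds random, numpy and torch, then restores their RNG states on exit, so seeded sections do not disturb the surrounding randomness. A short usage sketch (my own example, not repository code):

    import torch

    with use_seed(0):
        a = torch.rand(3)
    with use_seed(0):
        b = torch.rand(3)
    assert torch.equal(a, b)            # identical draws under the same seed

    @use_seed()
    def sample(n, seed=None):           # the 'seed' kwarg is consumed by the decorator
        return torch.rand(n)

    assert torch.equal(sample(3, seed=0), sample(3, seed=0))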
/src/dataset/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import copy
6 | from torch.utils.data import ConcatDataset
7 |
8 | @torch.no_grad()
9 | def get_feats(dataset, idx, model_refine=None, only_fg=False, feat_interpol=False, prt_interpol=True):
10 | try:
11 | ft = dataset[idx]['ft'][None]
12 | except:
13 | ft = dataset._get_feat(idx, dataset.featurizer, dataset.featurizer_kwargs)
14 | if model_refine is not None:
15 | ft = model_refine(ft)
16 | feats = ft.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, C
17 |
18 | if 'parts_mask' in dataset[idx].keys():
19 | prt = dataset[idx]['parts_mask'][None]
20 | ftsize = torch.tensor(ft.shape[-2:])
21 | prtsize = torch.tensor(prt.shape[-2:])
22 | if feat_interpol:
23 | ft_interp = F.interpolate(ft, size=prtsize.tolist(), mode='bilinear', align_corners=False)
24 | feats = ft_interp.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, C
25 | if prt_interpol:
26 | prt_interp = F.interpolate(prt, size=ftsize.tolist(), mode='bilinear', align_corners=False)
27 | parts = prt_interp.permute(0,2,3,1).flatten(0,-2)# B, C, H, W -> H*W, num_parts
28 | else:
29 | parts = prt.permute(0,2,3,1).flatten(0,-2)
30 |
31 | if only_fg:
32 | mask = parts.sum(-1)>0
33 | feats = feats[mask]
34 | parts = parts[mask]
35 |
36 | elif 'kps' in dataset[idx].keys() and only_fg:
37 |
38 | # Extract keypoint information
39 | kp = dataset[idx]['kps']
40 | imgsize = dataset[idx]['imsize']
41 |
42 |
43 | # Filter keypoints based on visibility
44 | visible_mask = kp[:, 2] > 0.5
45 | kp_ = kp[visible_mask, :2]
46 |
47 | # Normalize keypoint coordinates to range [-1, 1] for grid sampling
48 | h, w = ft.shape[2:]
49 | kp_normalized = torch.stack([
50 | 2 * kp_[:, 1] / (imgsize[1] - 1) - 1,
51 | 2 * kp_[:, 0] / (imgsize[0] - 1) - 1
52 | ], dim=-1).unsqueeze(0).unsqueeze(0).to(ft.device) # Shape: (1, num_kps, 2)
53 |
54 | # Use grid_sample to directly sample features at keypoint locations
55 | # Ensure kp_normalized and ft are on the same device
56 | feats = F.grid_sample(ft, kp_normalized, mode='bilinear', align_corners=False)
57 | feats = feats.squeeze(0).squeeze(1).t() # Shape: (num_vis_kps, C)
58 |
59 | parts = torch.eye(kp.shape[0])[visible_mask] # num_vis_kps, num_parts
60 | else:
61 | parts = None
62 |
63 | return feats, parts
64 |
65 | @torch.no_grad()
66 | def get_featpairs(dataset, idx, model_refine=None, only_fg=False, feat_interpol=False, prt_interpol=True):
67 | try:
68 | ft0 = dataset[idx]['src_ft'][None]
69 | ft1 = dataset[idx]['trg_ft'][None]
70 | except:
71 | raise ValueError('Need to have src_ft and tgt_ft in dataset')
72 | if model_refine is not None:
73 | ft0 = model_refine(ft0)
74 | ft1 = model_refine(ft1)
75 | feats = [ft0.permute(0,2,3,1).flatten(0,-2), ft1.permute(0,2,3,1).flatten(0,-2)]# B, C, H, W -> H*W, C
76 | parts = []
77 | for i,(ft, prefix) in enumerate(zip([ft0, ft1], ['src', 'trg'])):
78 | if prefix+'_kps' in dataset[idx].keys():
79 |
80 | kp = dataset[idx][prefix+'_kps']
81 | imgsize = dataset[idx][prefix+'_imsize']
82 |
83 | # Filter keypoints based on visibility
84 | visible_mask = kp[:, 2] > 0.5
85 | kp_ = kp[visible_mask, :2]
86 |
87 | # Normalize keypoint coordinates to range [-1, 1] for grid sampling
88 | h, w = ft.shape[2:]
89 | kp_normalized = torch.stack([
90 | 2 * kp_[:, 1] / (imgsize[1] - 1) - 1,
91 | 2 * kp_[:, 0] / (imgsize[0] - 1) - 1
92 | ], dim=-1).unsqueeze(0).unsqueeze(0).to(ft.device) # Shape: (1, num_kps, 2)
93 |
94 | # Use grid_sample to directly sample features at keypoint locations
95 | # Ensure kp_normalized and ft are on the same device
96 | feats[i] = F.grid_sample(ft, kp_normalized, mode='bilinear', align_corners=False)
97 | feats[i] = feats[i].squeeze(0).squeeze(1).t() # Shape: (num_vis_kps, C)
98 |
99 | # free up memory
100 | torch.cuda.empty_cache()
101 | parts.append(torch.eye(kp.shape[0])[visible_mask])
102 | else:
103 | parts.append(None)
104 |
105 | return feats[0], parts[0], feats[1], parts[1]
106 |
107 | @torch.no_grad()
108 | def get_init_feats_and_labels(dataset, N, model_refine=None, feat_interpol=False, prt_interpol=True, only_fg=False):
109 | # use seeds to make sure the same samples are used for all models
110 | np.random.seed(0)
111 | indices = np.random.permutation(np.arange(0, len(dataset)))[:N]
112 | feats = []
113 | parts = []
114 | for idx in indices:
115 | feats_idx, parts_idx = get_feats(dataset, idx, model_refine=model_refine, only_fg=only_fg, feat_interpol=feat_interpol, prt_interpol=prt_interpol)
116 | feats.append(feats_idx)
117 | parts.append(parts_idx)
118 | feats = torch.cat(feats)
119 | parts = torch.cat(parts)
120 | return feats, parts
121 |
122 | def to_flattened_idx_torch(x, y, x_width, y_width):
123 | idx = (y.round() + x.round()*y_width).int()
124 | x_, y_ = torch.unravel_index(idx, (x_width, y_width))
125 | # assert that all the indices are the original ones (up to 0.5)
126 | assert torch.all(torch.abs(x_-x)<0.5)
127 | assert torch.all(torch.abs(y_-y)<0.5)
128 | return idx
129 |
130 |
131 | def get_multi_cat_dataset(dataset_list, featurizer, featurizer_kwargs, cat_list_=None, model_seg_name=None):
132 | datasets = []
133 | for dataset in dataset_list:
134 | cat_list = dataset.all_cats if cat_list_ is None else cat_list_
135 | for cat in cat_list:
136 | dataset.padding_kps = True
137 | dataset.init_kps_cat(cat)
138 | datasets.append(copy.deepcopy(dataset))
139 | datasets[-1].featurizer = featurizer
140 | datasets[-1].featurizer_kwargs = featurizer_kwargs
141 | datasets[-1].model_seg_name = model_seg_name
142 | datasets[-1].return_masks = True
143 | return ConcatDataset(datasets)
144 |
145 | # def get_multi_cat_dataset(dataset_train, cat_list, featurizer, featurizer_kwargs):
146 | # datasets = []
147 | # for cat in cat_list:
148 | # dataset_train.padding_kps = True
149 | # dataset_train.init_kps_cat(cat)
150 | # datasets.append(copy.deepcopy(dataset_train))
151 | # datasets[-1].featurizer = featurizer
152 | # datasets[-1].featurizer_kwargs = featurizer_kwargs
153 | # return ConcatDataset(datasets)
154 |
155 |
--------------------------------------------------------------------------------
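to_flattened_idx_torch maps (x, y) grid coordinates to a single row-major index over an x_width-by-y_width grid and asserts that rounding did not move the points. A small sketch with invented values:

    import torch

    x = torch.tensor([0., 1., 2.])
    y = torch.tensor([3., 0., 5.])
    idx = to_flattened_idx_torch(x, y, x_width=4, y_width=6)
    print(idx)                          # tensor([ 3,  6, 17], dtype=torch.int32)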
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | STOREBASEPATH = '/storage/group/cvpr/regine/dataset_matching_symm/figures_new'
--------------------------------------------------------------------------------
/src/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from src.evaluation.pck import evaluate as evaluate_pck_cat
4 | from src.evaluation.pck import get_per_point_pck
5 | from src.evaluation.segmentation import evaluate as evaluate_seg_cat
6 | from src.logging.log_results import log_wandb
7 |
8 | from src.dataset.cub_200_pairs import CUBPairDataset
9 | from src.dataset.apk_pairs import AP10KPairs
10 | from src.dataset.spair import SpairDataset2
11 | from src.dataset.pfpascal_pairs import PFPascalPairsOrig
12 | from src.dataset.pascalparts import PascalParts
13 | import copy
14 |
15 | from src.models.featurizer.utils import get_featurizer_name, get_featurizer
16 | class Evaluation():
17 | def __init__(self, args, featurizer=None, split='test') -> None:
18 | self.epoch = None
19 | self.args = copy.deepcopy(args)
20 | self.split = split
21 | torch.cuda.set_device(0)
22 | self.args.dataset.sup = "sup_original"
23 | if self.args.dataset.name == 'spair':
24 | self.dataset_test_pck = SpairDataset2(self.args.dataset, split=split)
25 | precomputed = False
26 | elif self.args.dataset.name == 'cub':
27 | self.args.dataset.borders_cut = False
28 | self.dataset_test_pck = CUBPairDataset(self.args.dataset, split=split)
29 | self.args.dataset.borders_cut = True
30 | precomputed = False
31 | elif self.args.dataset.name == 'ap10k':
32 | if hasattr(self.args, 'upsample'):
33 | self.args.upsample = False
34 | self.dataset_test_pck = AP10KPairs(self.args.dataset, split=split)
35 | precomputed = False
36 | elif self.args.dataset.name == 'pfpascal':
37 | self.dataset_test_pck = PFPascalPairsOrig(self.args.dataset, split=split)
38 | precomputed = False
39 | else:
40 | self.dataset_test_pck = None
41 | self.dataset_test_ot_pairs = None
42 |
43 | if self.args.dataset.name == 'pascalparts':
44 | self.dataset_test_segmentation = PascalParts(self.args.dataset, split="test")
45 | self.dataset_train_segmentation = PascalParts(self.args.dataset, split="train")
46 | precomputed = False
47 | else:
48 | self.dataset_test_segmentation = None
49 | self.dataset_train_segmentation = None
50 |
51 | self.init_featurizer(precomputed, featurizer)
52 | self.reset_cat()
53 |
54 | def init_featurizer(self, precomputed, featurizer):
55 | if featurizer is not None:
56 | if self.args.featurizer.model == 'dift_sd':
57 | self.args.featurizer.all_cats = self.dataset_test_pck.all_cats
58 | featurizer_name = featurizer.name
59 | elif not precomputed:
60 | featurizer = get_featurizer(self.args.featurizer)
61 | featurizer_name = featurizer.name
62 | else:
63 | featurizer_name = get_featurizer_name(self.args.featurizer)
64 | # test datasets
65 | for dataset in [self.dataset_test_pck, self.dataset_test_segmentation, self.dataset_train_segmentation]:
66 | if dataset is not None:
67 | if not precomputed:
68 | dataset.featurizer_name = featurizer_name
69 | dataset.featurizer = featurizer
70 | dataset.featurizer_kwargs = self.args.featurizer
71 | else:
72 | dataset.featurizer_name = featurizer_name
73 |
74 | def reset_cat(self):
75 | for dataset in [ self.dataset_test_pck, self.dataset_test_segmentation, self.dataset_train_segmentation]:
76 | if dataset is not None:
77 | dataset.all_cats_eval = dataset.all_cats
78 |
79 | def add_suffix(self, result, suffix):
80 | return {k+suffix:v for k,v in result.items()}
81 |
82 | @torch.no_grad()
83 | def evaluate_pck(self, model_refine=None, n_pairs_eval_pck=None, suffix=''):
84 | results = {}
85 | if self.dataset_test_pck is None:
86 | return results
87 | print("evaluate PCK...")
88 |
89 | if n_pairs_eval_pck is None:
90 | n_pairs = self.args.n_pairs_eval_pck
91 | else:
92 | n_pairs = n_pairs_eval_pck
93 |
94 | # per point PCK (over all categories)
95 | n_total = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0,'11_underline': 0, '01_overline': 0, '11_tilde': 0}
96 | n_total_10 = n_total.copy()
97 | n_total_05 = n_total.copy()
98 | n_total_15 = n_total.copy()
99 |
100 | for cat in self.dataset_test_pck.all_cats_eval:
101 | print(f'... for cat: {cat}')
102 | # evaluate the PCK for each category
103 | self.dataset_test_pck.init_kps_cat(cat)
104 | if model_refine is not None:
105 | model_refine = model_refine.eval()
106 | cat_results, n_total_05_cat, n_total_10_cat, n_total_15_cat = evaluate_pck_cat(cat, self.dataset_test_pck, self.args.upsample, self.args.alpha_bbox, n_pairs=n_pairs, model_refine=model_refine)
107 | results.update(cat_results)
108 |
109 | # per point PCK (over all categories)
110 | for key in n_total.keys():
111 | n_total_05[key] += n_total_05_cat[key]
112 | n_total_10[key] += n_total_10_cat[key]
113 | n_total_15[key] += n_total_15_cat[key]
114 |
115 | # per point PCK (over all categories)
116 | for n_total, alph in zip([n_total_10, n_total_05, n_total_15], ['0.1', '0.05', '0.15']):
117 | results_alph = get_per_point_pck(n_total, '', alph)
118 | results.update(results_alph)
119 |
120 | if self.split in ['train', 'val']:
121 | results = self.add_suffix(results, f'_{self.split}')
122 | if n_pairs_eval_pck is not None:
123 | results = self.add_suffix(results, f'_{n_pairs}pairs')
124 | if suffix != '':
125 | results = self.add_suffix(results, f'_{suffix}')
126 | if self.epoch is not None:
127 | results['epoch'] = self.epoch
128 | log_wandb(results)
129 | self.results_pck = results
130 |
131 | @torch.no_grad()
132 | def evaluate_seg(self, model_refine=None, suffix=''):
133 | results, results_mean = {}, {}
134 | if self.dataset_test_segmentation is None:
135 | return results
136 | print("evaluate Segmentation...")
137 | modeldict = {'model_refine': model_refine, 'dataset_test': self.dataset_test_segmentation, 'dataset_train': self.dataset_train_segmentation}
138 | for cat in self.dataset_test_segmentation.all_cats_eval:
139 | print(f'... for cat: {cat}')
140 | self.dataset_test_segmentation.init_kps_cat(cat)
141 | self.dataset_train_segmentation.init_kps_cat(cat)
142 | if model_refine is not None:
143 | model_refine = model_refine.eval()
144 | cat_results = evaluate_seg_cat(self.args, modeldict)
145 | for k,v in cat_results.items():
146 | if not k in results_mean.keys():
147 | results_mean[k] = []
148 | results_mean[k].append(v)
149 | results.update({'segmentation'+cat+'_'+k: v for k, v in cat_results.items()})
150 |
151 | for k,v in results_mean.items():
152 | results_mean[k] = torch.tensor(v).mean().numpy()
153 | results.update({'segmentation_mean_'+k: v for k, v in results_mean.items()})
154 | if self.split in ['train', 'val']:
155 | results = self.add_suffix(results, f'_{self.split}')
156 | if suffix != '':
157 | results = self.add_suffix(results, f'_{suffix}')
158 | if self.epoch is not None:
159 | results['epoch'] = self.epoch
160 | log_wandb(results)
161 |
--------------------------------------------------------------------------------
/src/evaluation/pck.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | from src.matcher.argmaxmatcher import ArgmaxMatcher
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | import copy
7 | from src.evaluation import STOREBASEPATH
8 | NUM_VIZ = 15
9 | from src.logging.visualization_pck import plot_src, plot_trg
10 |
11 | def viz_pck_alpha(dataset, idx, data, src_to_trg_point, heatmap_, store_path, count_viz, key):
12 | if key == '00':
13 | return count_viz
14 | if count_viz[key] 0 else 0
101 | results['per point PCK@'+alph+' \hat{n}_11/n_11' + cat] = n_total['11_hat'] / n_total['11'] * 100 if n_total['11'] > 0 else 0
102 | results['per point PCK@'+alph+' \hat{n}_1x/n_1x' + cat] = n_total['1x_hat'] / n_total['1x'] * 100 if n_total['1x'] > 0 else 0
103 | results['per point PCK@'+alph+' \overline{n}_11/n_11 ' + cat] = n_total['11_overline'] / n_total['11'] * 100 if n_total['11'] > 0 else 0
104 | results['per point PCK@'+alph+' \tilde{n}_11/n_11 ' + cat] = n_total['11_tilde'] / n_total['11'] * 100 if n_total['11'] > 0 else 0
105 | # results['per point PCK@'+alph+' \overline{n}_01/n_01 ' + cat] = n_total['01_overline'] / n_total['01'] * 100 if n_total['01'] > 0 else 0
106 | return results
107 |
108 | @torch.no_grad()
109 | def evaluate(cat, dataset, upsample, bbox, n_pairs, model_refine=None, path=STOREBASEPATH+'/05_experiments/', visualize=False):
110 | if model_refine is not None:
111 | store_path = path+'pck/'+dataset.name+'/'+cat+'/'+model_refine.id+'/'
112 | else:
113 | store_path = path+'pck/'+dataset.name+'/'+cat+'/'+dataset.featurizer_name+'/'
114 | matcher = ArgmaxMatcher()
115 | results = {}
116 | count_viz = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0}
117 | # iterate over all categories
118 | dataset.init_kps_cat(cat)
119 | # init the counters
120 | n_total_10 = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0,'11_underline': 0, '01_overline': 0, '11_tilde': 0}
121 | n_total_05 = copy.deepcopy(n_total_10)
122 | n_total_15 = copy.deepcopy(n_total_10)
123 | alpha = 0.1
124 |
125 | # iterate over all test image pairs in the category
126 | perimg_PCK = []
127 | for i, data in enumerate(tqdm(dataset)):
128 | if i == n_pairs:
129 | break
130 |
131 | # get the data for the pair
132 | src_ft = data['src_ft'][None].to(device)
133 | trg_ft = data['trg_ft'][None].to(device)
134 | if model_refine!=None:
135 | ft_orig = [src_ft, trg_ft]
136 | ft_new = [model_refine(f) for f in ft_orig]
137 | src_ft = ft_new[0]
138 | trg_ft = ft_new[1]
139 |
140 | # get the spatial resolution of the feature maps to match the original image size, where keypoints are annotated
141 | src_img_size = data['src_imsize']
142 | trg_img_size = data['trg_imsize']
143 | if bbox:
144 | trg_bndbox = data['trg_bndbox']
145 | threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0])
146 | else:
147 | threshold = max(trg_img_size[0], trg_img_size[1])
148 |
149 | # init the per image counters
150 | n_img_10 = {'10': 0, '01': 0, '11': 0, '00': 0, '1x': 0, '1x_hat': 0, '10_hat': 0, '11_hat': 0, '11_overline': 0, '01_overline': 0, '11_tilde': 0}
151 | n_img_05 = copy.deepcopy(n_img_10)
152 | n_img_15 = copy.deepcopy(n_img_10)
153 |
154 | # iterate over all points in the pair and find the second point in the target image by argmax matching between query feature and all target features
155 | for idx in range(len(data['src_kps'])):
156 |
157 | # skip the points that are not annotated
158 | if data['src_kps'][idx][2] == 0:
159 | continue
160 |
161 | # match the keypoints
162 | src_point = data['src_kps'][idx]
163 | ft0, ft1, trg_ft_size = matcher.prepare_one_to_all(src_ft, trg_ft, src_point, src_img_size, trg_img_size, upsample)
164 | heatmap, prob = matcher(ft0, ft1)
165 | src_to_trg_point, heatmap_ = matcher.get_one_trg_point(heatmap[0], prob[0], trg_ft_size, trg_img_size)
166 |
167 | match_vis = data['trg_kps'][idx][2] > 0.5
168 | symm_vis = data['trg_kps_symm_only'][idx][2] > 0.5 if dataset.pck_symm else False
169 |
170 | # update counters
171 | n_img_10, key_10 = update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.10, match_vis, symm_vis, n_img_10)
172 | n_img_05, key_05 = update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.05, match_vis, symm_vis, n_img_05)
173 | n_img_15, key_15 = update_counters_pck_alpha(dataset, idx, data, src_to_trg_point, threshold, 0.15, match_vis, symm_vis, n_img_15)
174 | # visualize the keypoints
175 | if visualize:
176 | count_viz = viz_pck_alpha(dataset, idx, data, src_to_trg_point, heatmap_, store_path, count_viz, key_10)
177 |
178 | # update the counters
179 | for key, value in n_img_10.items():
180 | n_total_10[key] += value
181 | for key, value in n_img_05.items():
182 | n_total_05[key] += value
183 | for key, value in n_img_15.items():
184 | n_total_15[key] += value
185 |
186 | n_total_pck_denom = n_total_10['10'] + n_total_10['11'] + n_total_10['1x']
187 |
188 | results_10 = get_per_point_pck(n_total_10, cat, '0.1')
189 | results.update(results_10)
190 | results_05 = get_per_point_pck(n_total_05, cat, '0.05')
191 | results.update(results_05)
192 | results_15 = get_per_point_pck(n_total_15, cat, '0.15')
193 | results.update(results_15)
194 |
195 | results['n_10/n ' + cat] = n_total_10['10'] / n_total_pck_denom if n_total_pck_denom > 0 else 0
196 | results['n_11/n ' + cat] = n_total_10['11'] / n_total_pck_denom if n_total_pck_denom > 0 else 0
197 | results['n_1x/n ' + cat] = n_total_10['1x'] / n_total_pck_denom if n_total_pck_denom > 0 else 0
198 | results['n ' + cat] = n_total_pck_denom
199 |
200 | return results, n_total_05, n_total_10, n_total_15
--------------------------------------------------------------------------------
/src/evaluation/runtime_mem.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from tqdm import tqdm
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 | from src.evaluation import STOREBASEPATH
6 | import time
7 |
8 | def get_model_size(model):
9 | param_size = 0
10 | n_params = 0
11 | for param in model.parameters():
12 | param_size += param.nelement() * param.element_size()
13 | n_params += param.nelement()
14 | buffer_size = 0
15 | n_buffers = 0
16 | for buffer in model.buffers():
17 | buffer_size += buffer.nelement() * buffer.element_size()
18 | n_buffers += buffer.nelement()
19 |
20 | print('n_params:', n_params)
21 | print('n_buffers:', n_buffers)
22 | size_all_mb = (param_size + buffer_size) / 1024**2
23 | print('model size: {:.3f}MB'.format(size_all_mb))
24 | return size_all_mb
25 |
26 | @torch.no_grad()
27 | def evaluate(dataset, n_imgs, model_refine=None, path=STOREBASEPATH+'/05_experiments/'):
28 | results = {}
29 |
30 | assert(dataset.featurizer is not None)
31 | # measure the runtime and memory usage
32 | # parameter sizes of dataset.featurizer model:
33 | models = dataset.featurizer.get_models()
34 | sizes = [get_model_size(model) for model in models]
35 | results["size featurizer (MB)"] = sum(sizes)
36 |
37 | if model_refine!=None:
38 | results["size refiner (MB)"] = get_model_size(model_refine)
39 | # measure the runtime
40 | i=0
41 | time_sum = 0
42 | for cat in tqdm(dataset.all_cats):
43 | dataset.init_kps_cat(cat)
44 | for idx_ in range(len(dataset)):
45 | if i == n_imgs:
46 | break
47 |
48 | img = dataset._get_img(idx_)
49 |
50 | starttime = time.time()
51 | ft_orig = dataset.featurizer.forward(img,
52 | category=cat,
53 | **dataset.featurizer_kwargs)
54 | if model_refine!=None:
55 | ft_new = model_refine(ft_orig)
56 | endtime = time.time()
57 | time_sum += endtime - starttime
58 | i+=1
59 |
60 | endtime = time.time()
61 | results["runtime"] = time_sum/i *1000 # in ms
62 | return results
63 |
--------------------------------------------------------------------------------
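get_model_size sums parameter and buffer bytes and reports megabytes. A quick sanity check (illustrative, not from the repository):

    import torch

    linear = torch.nn.Linear(1024, 1024)    # fp32: (1024*1024 + 1024) * 4 bytes
    print(get_model_size(linear))           # ~ 4.004 MB (the helper also prints n_params / n_buffers)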
/src/evaluation/segmentation.py:
--------------------------------------------------------------------------------
1 | from src.evaluation.segmentation_metrics import confusion_matrix, accuracy, mean_precision, mean_recall, mean_iou
2 | from src.models.classifier.utils import train_classifier, forward_classifier
3 | from typing import Dict
4 | import torch.nn.functional as F
5 | import torch
6 | from src.logging.visualization_seg import plot_assignment
7 | from src.evaluation import STOREBASEPATH
8 |
9 | def evaluate_img(dataset_test, model_seg, model_refine=None, full_resolution=False, idx=0):
10 | # get the part segmentation prediction
11 | prt_pred = forward_classifier(dataset_test, idx, model_seg, model_refine=model_refine)
12 |
13 | data = dataset_test[idx]
14 | prt_gt = data['parts_mask'][None]
15 | # rescale to fit to gt size
16 | if full_resolution:
17 | prt_pred = F.interpolate(prt_pred, size=prt_gt.shape[-2:], mode='bilinear', align_corners=False)
18 | y_max = len(dataset_test.KP_NAMES)
19 |
20 | indices = prt_gt.sum(1)>1e-20 # only consider foreground pixels
21 | # generate labels
22 | prt_gt = prt_gt.argmax(1)+1
23 | prt_pred = prt_pred.argmax(1)+1
24 |
25 | prt_pred[~indices] = 0
26 | prt_gt[~indices] = 0
27 |
28 | conf_matrix = confusion_matrix(prt_pred[indices].detach().cpu()-1, prt_gt[indices].detach().cpu()-1, y_max).detach().cpu()
29 |
30 | cat_metrics: Dict[str, float] = {}
31 | cat_metrics["acc"] = accuracy(conf_matrix)
32 | cat_metrics["m_prcn"] = mean_precision(conf_matrix)
33 | cat_metrics["m_rcll"] = mean_recall(conf_matrix)
34 | cat_metrics["m_iou"] = mean_iou(conf_matrix)
35 |
36 | return cat_metrics, prt_pred, prt_gt
37 |
38 | def evaluate(args, modeldict, full_resolution=False, path=STOREBASEPATH+'/05_experiments/'):
39 |
40 | y_max = len(modeldict['dataset_train'].KP_NAMES)
41 | if y_max<2:
42 | return {}
43 | modeldict['model_seg'] = train_classifier(args.sup_classifier, modeldict['dataset_train'], model_refine=modeldict['model_refine'])
44 | pre_list = []
45 | gt_list = []
46 |
47 | for idx in range(len(modeldict['dataset_test'])):
48 | _, prt_pred, prt_gt = evaluate_img(modeldict['dataset_test'], modeldict['model_seg'], model_refine=modeldict['model_refine'], full_resolution=full_resolution, idx=idx)
49 | pre_list.append(prt_pred.flatten())
50 | gt_list.append(prt_gt.flatten())
51 |
52 | pred = torch.cat(pre_list, 0)
53 | gt = torch.cat(gt_list, 0)
54 | conf_matrix_all = confusion_matrix(pred.detach().cpu(), gt.detach().cpu(), y_max+1).detach().cpu()
55 | assert conf_matrix_all.shape[0] == y_max+1
56 |
57 | def get_results_dict_cat(y_max, conf_matrix, gt, text="", oriented=None):
58 | results_dict_cat = {}
59 | # evaluate the confusion matrix for all parts with at least one pixel in the ground truth
60 | num_gt_pixels = torch.zeros(y_max+1).int()
61 | for i in range(0, y_max+1):
62 | num_gt_pixels[i] = torch.sum(gt==i).int()
63 | if oriented is not None:
64 | if not oriented[i]:
65 | num_gt_pixels[i] = 0
66 |
67 | index = num_gt_pixels==0
68 |
69 | conf_matrix = conf_matrix[~index][:,~index]
70 | results_dict_cat["acc"+text] = accuracy(conf_matrix)
71 | results_dict_cat["m_prcn"+text] = mean_precision(conf_matrix)
72 | results_dict_cat["m_rcll"+text] = mean_recall(conf_matrix)
73 | results_dict_cat["m_iou"+text] = mean_iou(conf_matrix)
74 |
75 | num_gt_pixels = num_gt_pixels[~index]
76 | conf_matrix_normalized = conf_matrix/num_gt_pixels[None]
77 | results_dict_cat["acc_precnorm"+text] = accuracy(conf_matrix_normalized)
78 | results_dict_cat["m_prcn_precnorm"+text] = mean_precision(conf_matrix_normalized)
79 | results_dict_cat["m_rcll_precnorm"+text] = mean_recall(conf_matrix_normalized)
80 | results_dict_cat["m_iou_precnorm"+text] = mean_iou(conf_matrix_normalized)
81 | return results_dict_cat, conf_matrix, conf_matrix_normalized, index
82 |
83 | results_dict_cat, conf_matrix, conf_matrix_normalized, index = get_results_dict_cat(y_max, conf_matrix_all, gt)
84 | visualize = True
85 | def visualization(conf_matrix, conf_matrix_normalized, index, modeldict, path, text=""):
86 | if modeldict['model_refine'] is not None:
87 | store_path = path+'seg/'+modeldict['dataset_test'].name+'/'+modeldict['dataset_test'].cat+'/'+modeldict['model_refine'].id+'/'
88 | else:
89 | store_path = path+'seg/'+modeldict['dataset_test'].name+'/'+modeldict['dataset_test'].cat+'/'+modeldict['dataset_test'].featurizer_name+'/'
90 | # index = assign_dict['conf_matrix_normalized']==torch.nan
91 | assign_dict = {'conf_matrix': conf_matrix,
92 | 'conf_matrix_normalized': conf_matrix_normalized}
93 | kpnames = modeldict['dataset_test'].KP_NAMES.copy()
94 | kpnames = ['bkg']+kpnames
95 | kpnames = [kpnames[i] for i in range(len(kpnames)) if i not in torch.where(index)[0]]
96 | # remove the background
97 | plot_dict = {k:v[1:,1:] for k,v in assign_dict.items()}
98 | plot_kpnames = kpnames[1:]
99 | plot_assignment(plot_dict, plot_kpnames, plot_cmap=True, path=store_path+f'/conf/'+modeldict['dataset_test'].split+"/"+text+"/", limits=[0, 1.0/len(kpnames)])
100 |
101 | if visualize:
102 | visualization(conf_matrix, conf_matrix_normalized, index, modeldict, path, "all")
103 |
104 | oriented = modeldict['dataset_test'].KP_WITH_ORIENTATION
105 | # add bkg to oriented
106 | oriented = torch.cat([torch.tensor([True]), torch.tensor(oriented)])
107 | results_dict_cat_geo, conf_matrix_geo, conf_matrix_normalized_geo, index_geo = get_results_dict_cat(y_max, conf_matrix_all, gt, "_geo", oriented)
108 |
109 | if visualize:
110 | if conf_matrix_geo.shape[0] > 2:
111 | visualization(conf_matrix_geo, conf_matrix_normalized_geo, index_geo, modeldict, path, "geo")
112 |
113 | results_dict_cat.update(results_dict_cat_geo)
114 |
115 | return results_dict_cat
116 |
--------------------------------------------------------------------------------
/src/evaluation/segmentation_metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import torch
4 | from sklearn.metrics import confusion_matrix as sk_conf_matrix
5 |
6 |
7 | def pixelwise_accuracy(prediction: torch.Tensor, annotations: torch.Tensor) -> Tuple[int, int]:
8 | # prediction: B,H,W
9 | # annotations: B,H,W
10 | pred_correct = int(((prediction - annotations) == 0).sum().item())
11 | total = int(prediction.view(-1).shape[0])
12 |
13 | return pred_correct, total
14 |
15 |
16 | def confusion_matrix(prediction: torch.Tensor, annotations: torch.Tensor, num_classes: int) -> torch.Tensor:
17 | conf_matrix = torch.zeros(num_classes, num_classes).int()
18 |
19 | cm = torch.from_numpy(sk_conf_matrix(prediction.reshape(-1).cpu(), annotations.reshape(-1).cpu(), labels=range(num_classes)))
20 | conf_matrix[: cm.shape[0], : cm.shape[1]] = cm
21 |
22 | return conf_matrix
23 |
24 |
25 | def precision(true_positive: int, false_positive: int) -> float:
26 | return (true_positive) / (true_positive + false_positive)
27 |
28 |
29 | def recall(true_positive: int, false_negative: int) -> float:
30 | return (true_positive) / (true_positive + false_negative)
31 |
32 |
33 | def accuracy(conf_matrix: torch.Tensor) -> float:
34 | if torch.sum(conf_matrix) > 1e-20:
35 | return (torch.trace(conf_matrix) / torch.sum(conf_matrix)).item()
36 | else:
37 | return 0.0
38 |
39 |
40 | def mean_precision(conf_matrix: torch.Tensor, mask: Optional[List] = None):
41 | precisions = []
42 | for idx, row in enumerate(conf_matrix):
43 | if mask is not None:
44 | if mask[idx] == 0:
45 | continue
46 | if row.sum() > 1e-20:
47 | precisions.append((row[idx] / row.sum()).item())
48 |
49 | return sum(precisions) / len(precisions) if len(precisions) > 0 else 0
50 |
51 |
52 | def mean_recall(conf_matrix: torch.Tensor, mask: Optional[List] = None):
53 | recalls = []
54 | for idx, col in enumerate(conf_matrix.transpose(0, 1)):
55 | if mask is not None:
56 | if mask[idx] == 0:
57 | continue
58 | if col.sum() > 1e-20:
59 | recalls.append((col[idx] / col.sum()).item())
60 |
61 | return sum(recalls) / len(recalls) if len(recalls) > 0 else 0
62 |
63 |
64 | def mean_iou(conf_matrix: torch.Tensor, mask: Optional[List] = None):
65 | ious = []
66 | for idx in range(conf_matrix.shape[0]):
67 | if mask is not None:
68 | if mask[idx] == 0:
69 | continue
70 | denom = (conf_matrix[idx, :].sum() + conf_matrix[:, idx].sum() - conf_matrix[idx, idx])
71 | if denom > 1e-20:
72 | ious.append(
73 | (
74 | conf_matrix[idx, idx] / denom
75 | ).sum().item()
76 | )
77 |
78 | return sum(ious) / len(ious) if len(ious) > 0 else 0
79 |
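# Minimal worked example (illustrative, with made-up labels). Rows of the returned matrix index
# predictions and columns index annotations, because the prediction tensor is passed as the first
# argument to sklearn's confusion_matrix above.
#
#   pred = torch.tensor([0, 0, 1, 1])
#   gt   = torch.tensor([0, 1, 1, 1])
#   cm   = confusion_matrix(pred, gt, num_classes=2)
#   # cm = [[1, 1],
#   #       [0, 2]]
#   accuracy(cm)        # (1 + 2) / 4 = 0.75
#   mean_precision(cm)  # mean of 1/2 and 2/2 = 0.75
#   mean_recall(cm)     # mean of 1/1 and 2/3 ~= 0.83
#   mean_iou(cm)        # mean of 1/2 and 2/3 ~= 0.58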
--------------------------------------------------------------------------------
/src/logging/log_results.py:
--------------------------------------------------------------------------------
1 | import wandb
2 |
3 | def init_wandb(cfg, prefix=''):
4 | # start a new wandb run to track this script
5 | # check if feat model is in the cfg
6 | name = ''
7 | if 'featurizer' in cfg:
8 | name = name+cfg['featurizer']['model']
9 | if 'feat_refine' in cfg:
10 | name = name+'+'+cfg['feat_refine']['model']
11 | name = name+'+'+cfg['dataset']['name']
12 | if 'train' in prefix:
13 | if 'sup' in cfg['dataset']:
14 | name = name+cfg['dataset']['sup']
15 | if 'dataset2' in cfg:
16 | name = name+'+'+cfg['dataset2']['name']
17 | if 'sup' in cfg['dataset2']:
18 | name = name+cfg['dataset2']['sup']
19 | if 'dataset3' in cfg:
20 | name = name+'+'+cfg['dataset3']['name']
21 | if 'sup' in cfg['dataset3']:
22 | name = name+cfg['dataset3']['sup']
23 | if 'dataset4' in cfg:
24 | name = name+'+'+cfg['dataset4']['name']
25 | if 'sup' in cfg['dataset4']:
26 | name = name+cfg['dataset4']['sup']
27 |
28 | wandb.init(entity="",
29 | project="",
30 | name=prefix+' '+name,
31 | config=cfg,
32 | settings=wandb.Settings(code_dir="."))
33 |
34 | def log_wandb(results):
35 | if wandb.run:
36 | wandb.log(results)
37 |
38 | def log_wandb_epoch(epoch):
39 |
40 | if wandb.run:
41 | results = {'epoch': epoch}
42 | wandb.log(results)
43 |
44 | def log_wandb_cfg(cfg):
45 | # init wandb if not already
46 | if wandb.run:
47 | wandb.config.update(cfg, allow_val_change=True)
48 |
49 | def log_wandb_ram_usage():
50 | import psutil
51 | import os
52 | process = psutil.Process(os.getpid())
53 | other = psutil.virtual_memory()
54 |
55 | if wandb.run:
56 | wandb.log({ 'ram_usage': process.memory_info().rss/1024/1024
57 | , 'ram_total': other.total/1024/1024
58 | , 'ram_free': other.available/1024/1024
59 | , 'ram_percent': other.percent})
60 | print('ram_usage: ', process.memory_info().rss/1024/1024, 'MB')
61 | print('ram_total: ', other.total/1024/1024, 'MB')
62 |
63 | def finish_wandb():
64 | if wandb.run:
65 | wandb.finish()
--------------------------------------------------------------------------------
/src/logging/visualization.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import numpy as np
5 | from src.matcher.argmaxmatcher import ArgmaxMatcher
6 | from torch import nn
7 | from src.dataset.utils import to_flattened_idx_torch
8 | from PIL import Image
9 |
10 | class Demo:
11 |
12 | def __init__(self, imgs, ft, img_size):
13 | self.ft = ft # N+1, C, H, W
14 | # check if image is pil image
15 | if not isinstance(imgs[0], Image.Image):
16 | imgs = [Image.fromarray(img) for img in imgs]
17 | self.imgs = imgs
18 | self.num_imgs = len(imgs)
19 | self.img_size = img_size
20 |
21 | def plot_imgs_joint(self, fig_size=3):
22 | # concatenate the images and plot them
23 | # pad imgs to the same height
24 | max_h = max([img.size[1] for img in self.imgs])
25 | imgs = []
26 | for i in range(len(self.imgs)):
27 | if self.imgs[i].size[1] < max_h:
28 | # add zero padding using pillow
29 | img = Image.new('RGB', (self.imgs[i].size[0], max_h), (0, 0, 0))
30 | img.paste(self.imgs[i], (0, 0))
31 | else:
32 | img = self.imgs[i]
33 | imgs.append(img)
34 |
35 | img = np.concatenate([np.array(img) for img in imgs], axis=1)
36 | fig, ax = plt.subplots(1, 1, figsize=(fig_size*len(self.imgs), fig_size))
37 | plt.tight_layout()
38 | ax.imshow(img)
39 | # no axis
40 | ax.axis('off')
41 | return fig, ax
42 |
43 | def plot_images(self, fig_size=3):
44 | # plot the source image and the target images with the heatmap corresponding to the source point
45 | fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size))
46 | plt.tight_layout()
47 | for i in range(self.num_imgs):
48 | axes[i].imshow(self.imgs[i])
49 | axes[i].axis('off')
50 | if i == 0:
51 | axes[i].set_title('Source image $S$')
52 | else:
53 | axes[i].set_title('Target image $T$')
54 | return fig, axes
55 |
56 | def plot_src(self, axes, src_point, scatter_size):
57 | # plot src image
58 | axes[0].clear()
59 | axes[0].imshow(self.imgs[0])
60 | axes[0].axis('off')
61 | axes[0].scatter(src_point[1], src_point[0], c='r', s=scatter_size)
62 | # scale the point to the feature map size
63 | scale0 = torch.tensor(self.ft[0].shape[-2:]).float()/torch.tensor(self.img_size[0]).float()
64 | src_point_ = (src_point[:2]*scale0).clamp(torch.zeros(2), torch.tensor(self.ft[0].shape[-2:])-1)
65 |
66 | idx_src = to_flattened_idx_torch(src_point_[0], src_point_[1], self.ft[0].shape[-2], self.ft[0].shape[-1])
67 | axes[0].set_title('source image\nidx_src: %d' % idx_src)
68 |
69 | def plot_heatmap(self, axes, i, heatmap, src_to_trg_point, alpha, scatter_size, limits=None):
70 | # plot the heatmap
71 | axes[i].clear()
72 | # heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1]
73 | # convert to grayscale
74 | img_i = self.imgs[i].copy()
75 | img_i = img_i.convert('L')
76 | axes[i].imshow(img_i, alpha=1-alpha, cmap=plt.get_cmap('gray'))
77 | if limits is None:
78 | vmax = np.max(heatmap)*255
79 | vmin = np.min(heatmap)*255
80 | else:
81 | vmin = limits[0]*255
82 | vmax = limits[1]*255
83 | axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis', vmin=vmin, vmax=vmax)
84 | axes[i].axis('off')
85 | if src_to_trg_point is not None:
86 | axes[i].scatter(src_to_trg_point[1], src_to_trg_point[0], c='r', s=scatter_size)
87 | axes[i].set_title('target image\nmax value: %.2f' % (heatmap).max())
88 |
89 | def plot_matched_heatmap(self, axes, src_points, alpha, scatter_size, upsample):
90 | matcher = ArgmaxMatcher()
91 | for i in range(1, self.num_imgs):
92 | # prepare the feature maps for matching
93 | ft0, ft1, trg_ft_size = matcher.prepare_one_to_all(self.ft[0], self.ft[i], src_points[0], self.img_size[0], self.img_size[i], upsample=upsample)
94 | # match the feature maps
95 | heatmap, prob = matcher(ft0, ft1)
96 | for j in range(1): # only plot the first source point, in case we have multiple source points
97 | src_to_trg_point, heatmap_ = matcher.get_one_trg_point(heatmap[j], prob[j], trg_ft_size, self.img_size[i])
98 | self.plot_heatmap(axes, i, heatmap_, src_to_trg_point, alpha, scatter_size, limits=(0.5, 1))
99 | del heatmap
100 |
101 | def plot_img_pairs_click(self, fig_size=3, alpha=0.45, scatter_size=70, upsample=True):
102 | fig, axes = self.plot_images(fig_size)
103 | def onclick(event):
104 | if event.inaxes == axes[0]:
105 | with torch.no_grad():
106 | x, y = int(np.round(event.xdata)), int(np.round(event.ydata))
107 | src_point = torch.tensor([y,x,1])
108 | self.plot_src(axes, src_point, scatter_size)
109 | self.plot_matched_heatmap(axes, src_point[None], alpha, scatter_size, upsample)
110 | gc.collect()
111 |
112 | fig.canvas.mpl_connect('button_press_event', onclick)
113 | plt.show()
114 |
115 | def plot_img_pairs(self, src_point, fig_size=3, alpha=0.45, scatter_size=70, upsample=True):
116 | fig, axes = self.plot_images(fig_size)
117 | axes[0].clear()
118 | axes[0].imshow(self.imgs[0])
119 | axes[0].axis('off')
120 | axes[0].scatter(int(src_point[1]), int(src_point[0]), c='r', s=scatter_size) # scatter needs flipped x and y
121 | axes[0].set_title('source image')
122 | # plot trg heatmap
123 | self.plot_matched_heatmap(axes, src_point[None], alpha, scatter_size, upsample)
124 | plt.show()
125 |
126 | def plot_matches(self, src_points, trg_points, fig_size=3, alpha=0.45, scatter_size=40, title='', path=None):
127 | # fig, axes = self.plot_images(fig_size)
128 | fig, axes = self.plot_imgs_joint(fig_size)
129 | #colours = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w']
130 | colours = [
131 | (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 0, 0), (255, 0, 255),
132 | (255, 255, 0), (0, 0, 255), (0, 128, 255), (128, 0, 255), (0, 128, 0),
133 | (128, 0, 0), (0, 0, 128), (128, 128, 0), (0, 128, 128), (128, 0, 128),
134 | ]
135 |
136 | # Normalize RGB values (0-255 -> 0-1)
137 | normalized_colours = [(r/255, g/255, b/255) for r, g, b in colours]
138 | num_colours = len(colours)
139 | # matplot colours are in the range [0, 1]
140 | colours = [(float(r)/255, float(g)/255, float(b)/255) for r, g, b in colours]
141 | for j in range(len(src_points)):
142 | src_point = src_points[j]
143 | if src_point[2]!=0:
144 | axes.scatter(int(src_point[1]), int(src_point[0]), color=normalized_colours[j%num_colours], s=scatter_size) # scatter needs flipped x and y
145 | for j in range(len(trg_points)):
146 | trg_point = trg_points[j]
147 | if trg_point[2]!=0:
148 | offset = self.img_size[0][1]
149 | axes.scatter(int(trg_point[1])+offset, int(trg_point[0]), color=normalized_colours[j%num_colours], s=scatter_size)
150 | c = 'r' if 'neg' in title else 'g' if 'pos' in title else 'b'
151 | if 'neg' in title or 'pos' in title:
152 | # plot red connections for negative matches
153 | for j in range(len(src_points)):
154 | src_point = src_points[j]
155 | trg_point = trg_points[j]
156 | if src_point[2]!=0 and trg_point[2]!=0:
157 | axes.plot([src_point[1], trg_point[1]+offset], [src_point[0], trg_point[0]], color=c, linewidth=2)
158 | if 'bin' in title:
159 | # plot lines from the points to the boundaries
160 | for j in range(len(src_points)):
161 | src_point = src_points[j]
162 | if src_point[2]!=0:
163 | axes.plot([src_point[1], 0], [src_point[0], src_point[0]], color=c, linewidth=2)
164 | for j in range(len(trg_points)):
165 | trg_point = trg_points[j]
166 | offset1 = self.img_size[1][1]
167 | if trg_point[2]!=0:
168 | axes.plot([trg_point[1]+offset, offset+offset1-1], [trg_point[0], trg_point[0]], color=c, linewidth=2)
169 | # fig.suptitle(title)
170 | plt.show()
171 | if path is not None:
172 | fig.savefig(path,bbox_inches='tight')
173 | plt.close('all')
174 |
175 |
176 | def helper_fun(self, src_point, heatmap, axes, alpha, scatter_size, limits):
177 | # scale the point to the feature map size for indexing
178 | scale0 = torch.tensor(self.ft[0].shape[-2:]).float()/torch.tensor(self.img_size[0]).float()
179 | src_point_ = (src_point[:2]*scale0).clamp(torch.zeros(2), torch.tensor(self.ft[0].shape[-2:])-1) # ft indexing
180 |
181 | idx_src = to_flattened_idx_torch(src_point_[0], src_point_[1], self.ft[0].shape[-2], self.ft[0].shape[-1])
182 | heatmap_ = heatmap[idx_src.long().item()]
183 | heatmap_ = nn.Upsample(size=self.img_size[1].tolist(), mode='bilinear')(torch.tensor(heatmap_).view(1,1,self.ft[1].shape[-2],self.ft[1].shape[-1])).squeeze(0).squeeze(0).cpu().numpy()
184 | self.plot_heatmap(axes, 1, heatmap_, None, alpha, scatter_size, limits=limits)
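# Minimal usage sketch (illustrative; 'src.png'/'trg.png' and the feature maps are placeholders,
# features come from one of the featurizers in src/models/featurizer):
#
#   from PIL import Image
#   imgs = [Image.open('src.png').convert('RGB'), Image.open('trg.png').convert('RGB')]
#   # ft: one feature map per image, each of shape [1, C, h, w]
#   # img_size: one (H, W) tensor per image, matching the PIL image sizes
#   demo = Demo(imgs, ft, img_size)
#   demo.plot_img_pairs(torch.tensor([100, 120, 1]))  # source point given as (y, x, visibility)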
--------------------------------------------------------------------------------
/src/logging/visualization_pck.py:
--------------------------------------------------------------------------------
1 |
2 | import matplotlib.pyplot as plt
3 | import matplotlib
4 | matplotlib.use('Agg')
5 | import os
6 | import numpy as np
7 | import copy
8 |
9 | def crop_centered(img, kp_dict):
10 | if img.size[0] > img.size[1]:
11 | for k, v in kp_dict.items():
12 | if v is not None:
13 | kp_dict[k][1] = kp_dict[k][1] - (img.size[0] - img.size[1]) // 2 # H,W -> W,H from PIL to numpy
14 | if kp_dict[k][1] < 0:
15 | kp_dict[k][1] = 0
16 | elif kp_dict[k][1] >= img.size[1]:
17 | kp_dict[k][1] = img.size[1]-1
18 | img = img.crop(((img.size[0] - img.size[1]) // 2, 0, (img.size[1] + img.size[0]) // 2, img.size[1]))
19 | else:
20 | for k, v in kp_dict.items():
21 | if v is not None:
22 | kp_dict[k][0] = kp_dict[k][0] - (img.size[1] - img.size[0]) // 2 # H,W -> W,H from PIL to numpy
23 | if kp_dict[k][0] < 0:
24 | kp_dict[k][0] = 0
25 | elif kp_dict[k][0] >= img.size[0]:
26 | kp_dict[k][0] = img.size[0]-1
27 | img = img.crop((0, (img.size[1] - img.size[0]) // 2, img.size[0], (img.size[1] + img.size[0]) // 2))
28 | return img, kp_dict
29 |
30 | def crop_centered_heatmap(heatmap):
31 | if heatmap.shape[0] > heatmap.shape[1]:
32 | heatmap = heatmap[(heatmap.shape[0] - heatmap.shape[1]) // 2:(heatmap.shape[1] + heatmap.shape[0]) // 2, :]
33 | else:
34 | heatmap = heatmap[:, (heatmap.shape[1] - heatmap.shape[0]) // 2:(heatmap.shape[1] + heatmap.shape[0]) // 2]
35 | return heatmap
36 |
37 | def plot_src(img_, src_point, alpha=0.2, path=None, scatter_size=70):
38 | img = img_
39 | def plotting():
40 | ax.clear()
41 | ax.imshow(img)
42 | ax.axis('off')
43 | ax.scatter(src_point[1], src_point[0], edgecolor='black', linewidth=1, facecolor='w', s=scatter_size, label="$Q$ query")
44 | if path is not None:
45 | ax.legend(fontsize="20", loc ="upper left")
46 | os.makedirs(os.path.dirname(path), exist_ok=True)
47 | plt.savefig(path, bbox_inches='tight', pad_inches=0)
48 | plt.close('all')
49 | ##############################################
50 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
51 | plotting()
52 | # ##############################################
53 | # crop to quadratic
54 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
55 | kp_dict = {'src': src_point}
56 | img,kp_dict = crop_centered(img, kp_dict)
57 | src_point = kp_dict['src']
58 | if path is not None:
59 | # add padding to path
60 | path = path.split('.')
61 | path = path[0] + '_crop.' + path[1]
62 | plotting()
63 |
64 | def plot_trg(img_, heatmap, path=None, scatter_size=70, alpha_gt=0.2, alpha=0.45, limits=None, src_to_trg_point=None, trg_point=None, c='g', trg_point2=None, c2='r'):
65 | img = img_.convert('L').copy()
66 | if limits is None:
67 | vmax = np.max(heatmap)*255
68 | vmin = np.min(heatmap)*255
69 | else:
70 | vmin = limits[0]*255
71 | vmax = limits[1]*255
72 |
73 | def plotting_gt():
74 | ax.clear()
75 | ax.imshow(img_, alpha=1-alpha_gt, cmap=plt.get_cmap('gray'))
76 | ax.axis('off')
77 | if path_plain is not None:
78 | os.makedirs(os.path.dirname(path_plain), exist_ok=True)
79 | plt.savefig(path_plain, bbox_inches='tight', pad_inches=0)
80 | # plt.close('all')
81 | if trg_point is not None and trg_point[2] > 0:
82 | ax.scatter(trg_point[1], trg_point[0], edgecolor='black', linewidth=1, facecolor=c, s=scatter_size, label="$Q_{GT}$ visible \u2714 (1)")
83 | # # plot text at the bottom of the image with a label for the target point
84 | # x, y = img.size[0] // 2, img.size[1] - 20
85 | # s = f'$Q$ \n $Q_s$'
86 | # plt.text(x, y, s, bbox=dict(fill=True, facecolor='w', linewidth=2))
87 | else:
88 | # add to legend without scatter
89 | ax.scatter([], [], edgecolor='black', linewidth=1, facecolor=c, s=scatter_size, label="$Q_{GT}$ visible \u2718 (0)")
90 | if trg_point2 is not None and trg_point2[2] > 0:
91 | ax.scatter(trg_point2[1], trg_point2[0], edgecolor='black', linewidth=1, facecolor=c2, s=scatter_size, label="$Q_{Symm}$ visible \u2714 (1)")
92 | else:
93 | # add to legend without scatter
94 | ax.scatter([], [], edgecolor='black', linewidth=1, facecolor=c2, s=scatter_size, label="$Q_{Symm}$ visible \u2718 (0)")
95 | if path_gt is not None:
96 | ax.legend(fontsize="20", loc ="upper left")
97 | os.makedirs(os.path.dirname(path_gt), exist_ok=True)
98 | plt.savefig(path_gt, bbox_inches='tight', pad_inches=0)
99 | plt.close('all')
100 |
101 | def plotting_pred():
102 | ax.clear()
103 | ax.imshow(img, alpha=1-alpha, cmap=plt.get_cmap('gray'))
104 | ax.imshow(255 * heatmap, alpha=alpha, cmap='viridis', vmin=vmin, vmax=vmax)
105 | ax.axis('off')
106 | if src_to_trg_point is not None:
107 | if heatmap[src_to_trg_point[0], src_to_trg_point[1]] > 0.3:
108 | ax.scatter(src_to_trg_point[1], src_to_trg_point[0], edgecolor='black', linewidth=1, facecolor='y', s=scatter_size)
109 | if path is not None:
110 | os.makedirs(os.path.dirname(path), exist_ok=True)
111 | plt.savefig(path, bbox_inches='tight', pad_inches=0)
112 | plt.close('all')
113 |
114 | ##############################################
115 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
116 | if path is not None:
117 | path_list = path.split('.')
118 | path_plain = path_list[0] + '_plain.' + path_list[1]
119 | path_gt = path_list[0] + '_gt.' + path_list[1]
120 | plotting_gt()
121 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
122 | plotting_pred()
123 |
124 | ##############################################
125 | # crop to quadratic
126 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
127 | kp_dict = {'src': src_to_trg_point, 'trg': trg_point, 'trg2': trg_point2}
128 | img, kp_dict = crop_centered(img, kp_dict)
129 | img_, _ = crop_centered(img_, copy.deepcopy(kp_dict))
130 | src_to_trg_point = kp_dict['src']
131 | trg_point = kp_dict['trg']
132 | trg_point2 = kp_dict['trg2']
133 | heatmap = crop_centered_heatmap(heatmap)
134 | if path is not None:
135 | path_list = path.split('.')
136 | path_plain = path_list[0] + '_plain_crop.' + path_list[1]
137 | path_gt = path_list[0] + '_gt_crop.' + path_list[1]
138 | plotting_gt()
139 | fig, ax = plt.subplots(1, 1, figsize=(6, 6))
140 | if path is not None:
141 | path_list = path.split('.')
142 | path = path_list[0] + '_crop.' + path_list[1]
143 | plotting_pred()
--------------------------------------------------------------------------------
/src/logging/visualization_seg.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import wandb
5 |
6 | COLOURS = [
7 | (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 0, 0), (255, 0, 255),
8 | (255, 255, 0), (0, 0, 255), (0, 128, 255), (128, 0, 255), (0, 128, 0),
9 | (128, 0, 0), (0, 0, 128), (128, 128, 0), (0, 128, 128), (128, 0, 128),
10 | ]
11 |
12 | def plot_assignment(assignment_dict, list_of_kp_names=None, path=None, wandb_suffix='', plot_cmap=False, limits=None):
13 | import os
14 | if path is not None:
15 | if not os.path.exists(path):
16 | os.makedirs(path)
17 | def plot_matrix(M_, title, limits):
18 | if path is not None and 'single' in path and "cos_sim" in title:
19 | M = M_.clone()
20 | limits = [0, 1]
21 | elif path is not None and 'single' in path and "diag" in path:
22 | M = M_.clone()
23 | else:
24 | M = M_.clone()/M_.sum()
25 | ratio = M.shape[1]/M.shape[0]
26 | plt.figure(figsize=(5*ratio,5))
27 | if limits is None:
28 | plt.imshow(M.detach().cpu().numpy(), cmap='jet')
29 | else:
30 | plt.imshow(M.detach().cpu().numpy(), cmap='jet', vmin=limits[0], vmax=limits[1])
31 | plt.yticks(())
32 | plt.xticks(())
33 | # title_suffix = 'max:%.4f' % M.max().item()
34 | # plt.title(title + ' - ' + title_suffix)
35 | plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
36 | if list_of_kp_names is not None:
37 | # x ticks on top
38 | plt.xticks(range(len(list_of_kp_names)), list_of_kp_names, rotation=90)
39 | plt.yticks(range(len(list_of_kp_names)), list_of_kp_names)
40 | # adjust the plot such that the heading and the colorbar are not cut off
41 | plt.subplots_adjust(top=0.78)
42 | if path is not None:
43 | plt.savefig(path+'/matrix_'+title + ".png")
44 | if wandb.run is not None:
45 | # log to wandb
46 | wandb.log({wandb_suffix+'_'+title : wandb.Image(path+'/matrix_'+title + ".png")})
47 | if plot_cmap:
48 | # plot the jet colormap
49 | plt.gca().set_visible(False)
50 | cax = plt.axes((0.0, 0.0, 1.0, 0.05))
51 | plt.colorbar(orientation="horizontal", cax=cax)
52 | plt.savefig(path+'/colourbar_'+title + ".png", bbox_inches='tight')
53 |
54 | plt.close('all')
55 |
56 | for key, value in assignment_dict.items():
57 | # if value.sum() > 0:
58 | # check for sparse matrix
59 | if hasattr(value, 'to_dense'):
60 | plot_matrix(value.to_dense(), key, limits)
61 | else:
62 | plot_matrix(value, key, limits)
63 |
--------------------------------------------------------------------------------
/src/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from src.dataset.pairwise_utils import get_matches, scale_to_feature_dims, get_y_mat_gt_assignment
2 | from src.matcher.ot_matcher import SoftmaxMatcher
3 | import torch
4 |
5 | def pairwise_loss(losses, data, prob, b, prt, masses):
6 | # evaluate the matches by comparing the ground truth matches with the matches from the model
7 | data_matches = get_matches(data, b)
8 | data_matches = scale_to_feature_dims(data_matches, data, b)
9 |
10 | assignment_dict = {}
11 | ft_orig_b = [data['src_ft'][b], data['trg_ft'][b]]
12 | ft_size_b = [torch.tensor(f.shape[-2:]) for f in ft_orig_b]
13 |
14 | for prefix in ['pos', 'bin', 'neg']:
15 | assignment_dict['y_mat_' + prefix] = get_y_mat_gt_assignment(data_matches[prefix+'_src_kps'].clone().detach(), data_matches[prefix+'_trg_kps'].clone().detach(), ft_size_b[0], ft_size_b[1])
16 | idx = assignment_dict['y_mat_' + prefix].coalesce().indices()
17 | prob_prefix = prob[idx[0].long(),idx[1].long()]
18 | # compute the loss
19 | p = 1/prob.shape[0]
20 | if prefix in ['pos', 'bin']:
21 | losses[prefix].append(-torch.log(prob_prefix)*p)
22 | else:
23 | losses[prefix].append(-torch.log(1-prob_prefix)*(1-0))
24 |
25 | # compute the loss
26 | if True:
27 | prefix = 'neg_fg_bkg'
28 | idx0 = prt[0]>0
29 | idx1 = prt[1]==0
30 | idx0 = torch.cat([idx0, torch.zeros(1).bool().to(idx0.device)]) # add the bin to bool mask
31 | idx1 = torch.cat([idx1, torch.zeros(1).bool().to(idx1.device)]) # add the bin to bool mask
32 | prob_prefix = prob[idx0,:][:,idx1]
33 | losses[prefix].append(-torch.log(1-prob_prefix.flatten())*(1-0))
34 | idx0 = prt[0]==0
35 | idx1 = prt[1]>0
36 | idx0 = torch.cat([idx0, torch.zeros(1).bool().to(idx0.device)])
37 | idx1 = torch.cat([idx1, torch.zeros(1).bool().to(idx1.device)])
38 | prob_prefix = prob[idx0,:][:,idx1]
39 | losses[prefix].append(-torch.log(1-prob_prefix.flatten())*(1-0))
40 |
41 | del data_matches, assignment_dict
42 | return losses
43 |
44 | class PairwiseLoss():
45 | def __init__(self, args):
46 |
47 | ot_params = {
48 | 'reg':0.1,
49 | 'reg_kl':10,
50 | 'sinkhorn_iterations':10,
51 | 'mass':0.9,
52 | 'bin_score':0.3
53 | }
54 | self.args = args
55 | self.matcher = SoftmaxMatcher(**ot_params)
56 | self.matcher = self.matcher.eval() # init dict of losses
57 | self.prefix = ['pos', 'bin', 'neg', 'neg_fg_bkg']
58 |
59 | def get_loss(self, src_ft_, trg_ft_, data):
60 |
61 | losses = {p: [] for p in self.prefix}
62 | # compute the loss
63 | B = src_ft_.shape[0]
64 | for b in range(B):
65 | prt = [data['src_mask'][b], data['trg_mask'][b]]
66 | # resize the prt segmentation to feature dimensions
67 | prt = [torch.nn.functional.interpolate(p[None,None].float(), size=(f.shape[-2], f.shape[-1]), mode='nearest')[0,0].bool() for p,f in zip(prt, [data['src_ft'][b], data['trg_ft'][b]])]
68 | prt = [p.flatten() for p in prt]
69 | if prt[0].sum()*prt[1].sum()<1e-10:
70 | continue
71 | # get masses
72 | mass0 = (data['src_kps'][b,:,2].sum()+1)/(data['numkp'][b]+1)*self.matcher.mass
73 | mass1 = (data['trg_kps'][b,:,2].sum()+1)/(data['numkp'][b]+1)*self.matcher.mass
74 | prob = self.matcher(src_ft_[b], trg_ft_[b], segmentations=prt, masses=[mass0, mass1])
75 | # use positive, negative and bin matches to compute the loss
76 | losses = pairwise_loss(losses, data, prob, b, prt, masses=[mass0, mass1])
77 | losses = {k: torch.cat(v) for k,v in losses.items()}
78 | losses = {k: v for k,v in losses.items() if len(v)>0}
79 | return losses
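# Shape conventions assumed by PairwiseLoss.get_loss (inferred from the code above, illustrative):
#   src_ft_, trg_ft_                  : refined features, [B, H*W, C] (flattened spatial grid)
#   data['src_ft'], data['trg_ft']    : original feature maps, [B, C, H, W] (used for sizes and matches)
#   data['src_mask'], data['trg_mask']: binary foreground masks at image resolution
#   data['src_kps'], data['trg_kps']  : keypoints, [B, K, 3] as (y, x, visibility)
#   data['numkp']                     : number of annotated keypoints per pair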
--------------------------------------------------------------------------------
/src/matcher/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/matcher/__init__.py
--------------------------------------------------------------------------------
/src/matcher/argmaxmatcher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class ArgmaxMatcher(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def forward(self, src_vec, trg_vec):
10 | dev = src_vec.device
11 | n,d = src_vec.shape
12 | m,d = trg_vec.shape
13 | # match the feature vectors at the source point
14 | src_vec = F.normalize(src_vec) # 1, C
15 | trg_vec = F.normalize(trg_vec).transpose(0, 1) #C, H0W0
16 | cos_map = torch.mm(src_vec, trg_vec) #1, H0W0
17 | # get the probability map, which is a one-hot map
18 | #trg_point = torch.unravel_index(cos_map.argmax(), cos_map.shape)
19 | prob = torch.zeros(n+1, m+1).to(dev)
20 | idx_0 = torch.arange(n).to(dev)
21 | idx_1 = cos_map.argmax(dim = 1)
22 | prob[idx_0, idx_1] = 1
23 | return cos_map, prob
24 |
25 | def prepare_one_to_all(self, src_ft, trg_ft, src_point, src_img_size, trg_img_size, upsample):
26 | # We use the feature maps to match src_point to the all target points
27 |
28 | # upsample the feature maps to match the original image size
29 | if upsample:
30 | # print(f"Memory allocated ups: {torch.cuda.memory_allocated()//1024**3} GB")
31 | src_ft = nn.Upsample(size=src_img_size.tolist(), mode='bilinear')(src_ft)
32 | # print(f"Memory allocated: {torch.cuda.memory_allocated()//1024**3} GB")
33 | trg_ft = nn.Upsample(size=trg_img_size.tolist(), mode='bilinear')(trg_ft)
34 | # print(f"Memory allocated: {torch.cuda.memory_allocated()//1024**3} GB")
35 |
36 | # scale the source point to the feature map size, in case we did not upsample the feature maps
37 | # get the scale factor of the feature maps wrt the original image size
38 | src_ft_size = src_ft.shape[-2:]
39 | src_scale = torch.tensor(src_ft_size).float()/src_img_size.float()
40 | src_point_ = src_point.clone()
41 | src_point_[:2] = src_point[:2] * src_scale
42 | src_point_ = src_point_.floor().long()
43 |
44 | # get the feature vector at the source point
45 | num_channel = src_ft.size(1)
46 | src_vec = src_ft[0, :, src_point_[0], src_point_[1]].view(1, num_channel) # 1, C
47 | # get the feature vectors at all target points
48 | trg_vec = trg_ft.reshape(num_channel, -1).transpose(0, 1) # H0W0, C
49 | # get the size of the target feature map
50 | trg_ft_size = trg_ft.shape[-2:]
51 | return src_vec, trg_vec, trg_ft_size
52 |
53 | def get_one_trg_point(self, cos_map, prob, trg_ft_size, trg_img_size):
54 | # Input:
55 | # trg_ft_size: H0, W0 indicating the size of cos_map and prob
56 | # trg_img_size: H, W indicating the size of the original image
57 | # in case upsample is True, trg_ft_size=trg_img_size
58 | # Output:
59 | # trg_point: the target point in the original image size
60 | # cos_map: the cosine map in the original image size
61 | H0,W0 = trg_ft_size[-2], trg_ft_size[-1]
62 | trg_scale = trg_img_size.float()/torch.tensor(trg_ft_size).float()
63 | cos_map = cos_map.view(H0,W0) # H0,W0
64 | bin_prob = prob[-1]
65 | prob = prob[:-1].view(H0,W0)
66 | trg_point = torch.tensor(torch.unravel_index(prob.argmax(), prob.shape))
67 | if prob[trg_point[0], trg_point[1]] < bin_prob:
--------------------------------------------------------------------------------
/src/matcher/ot_matcher.py:
--------------------------------------------------------------------------------
11 | def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int, lmba: float) -> torch.Tensor:
12 | """
13 | Perform Sinkhorn Normalization in Log-space for stability
14 | lambda: higher values result in lower entropy
15 | """
16 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
17 | Z = lmba*Z
18 | for _ in range(iters):
19 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1)*lmba, dim=2)
20 | u = u/lmba
21 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2)*lmba, dim=1)
22 | v = v/lmba
23 | return Z + u.unsqueeze(2)*lmba + v.unsqueeze(1)*lmba
24 |
25 | def log_optimal_transport(mu:torch.Tensor, nu:torch.Tensor, couplings: torch.Tensor, reg: float, sinkhorn_iterations: int) -> torch.Tensor:
26 | """ Perform Differentiable Optimal Transport in Log-space for stability"""
27 | B, N_, M_ = couplings.shape
28 | log_mu, log_nu = mu.log(), nu.log()
29 | log_mu, log_nu = log_mu[None].expand(B, -1), log_nu[None].expand(B, -1)
30 | log_mu, log_nu = log_mu.to(couplings.device), log_nu.to(couplings.device)
31 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, sinkhorn_iterations, 1/reg)
32 | # Z = Z - norm # multiply probabilities by M+N
33 | return Z.exp()
34 |
35 | # 2
36 | def log_sinkhorn_iterations_kl(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int, lmba: float, reg_kl:float) -> torch.Tensor:
37 | """
38 | Perform Sinkhorn Normalization in Log-space for stability
39 | lambda: higher values result in lower entropy
40 | reg_kl: higher values result in stronger KL regularization
41 | """
42 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
43 | Z = lmba*Z
44 | phi= reg_kl/(reg_kl+1/lmba)
45 | for _ in range(iters):
46 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1)*lmba, dim=2)
47 | u = u/lmba * phi
48 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2)*lmba, dim=1)
49 | v = v/lmba * phi
50 | return Z + u.unsqueeze(2)*lmba + v.unsqueeze(1)*lmba
51 |
52 | def log_optimal_transport_kl(mu:torch.Tensor, nu:torch.Tensor, couplings: torch.Tensor, reg: float, reg_kl: float, sinkhorn_iterations: int) -> torch.Tensor:
53 | """ Perform Differentiable Optimal Transport in Log-space for stability"""
54 | B, N_, M_ = couplings.shape
55 | log_mu, log_nu = mu.log(), nu.log()
56 | log_mu, log_nu = log_mu[None].expand(B, -1), log_nu[None].expand(B, -1)
57 | log_mu, log_nu = log_mu.to(couplings.device), log_nu.to(couplings.device)
58 | Z = log_sinkhorn_iterations_kl(couplings, log_mu, log_nu, sinkhorn_iterations, 1/reg, reg_kl)
59 | # Z = Z - norm # multiply probabilities by M+N
60 | return Z.exp()
61 |
62 | # 3
63 | def distributed_sinkhorn(couplings, reg, sinkhorn_iterations):
64 |
65 | Q = torch.exp(couplings / reg).transpose(-1,-2) # Q is N-by-M for consistency with notations from our paper
66 | N_1 = Q.shape[-2] # how many prototypes
67 | M_1 = Q.shape[-1] # number of samples to assign
68 |
69 | # make the matrix sums to 1
70 | sum_Q = torch.sum(Q)
71 | Q /= sum_Q
72 |
73 | for it in range(sinkhorn_iterations):
74 | # normalize each row: total weight per prototype must be 1/N
75 | sum_of_rows = torch.sum(Q, dim=-1, keepdim=True)
76 | Q /= sum_of_rows
77 | Q /= N_1
78 |
79 | # normalize each column: total weight per sample must be 1/M
80 | Q /= torch.sum(Q, dim=-2, keepdim=True)
81 | Q /= M_1
82 |
83 | Q *= M_1 # the columns must sum to 1 so that Q is an assignment
84 | return Q.transpose(-1,-2)
85 |
86 |
87 | # 4
88 | def ot_solver(a, b, couplings, type = "partial_wasserstein", reg = 0.005, reg_m_kl = 0.05, reg_m_l2 = 5):
89 | B, M, N = couplings.shape
90 | P_list = []
91 | for i in range(B):
92 | dist = 2-couplings[i]
93 | if type == "entropic":
94 | P = ot.sinkhorn(a, b, dist, reg)
95 | if type == "entropic_kl_uot":
96 | P = ot.unbalanced.sinkhorn_unbalanced(a, b, dist, reg, reg_m_kl)
97 |
98 | if type == "kl_uot":
99 | P = ot.unbalanced.mm_unbalanced(a, b, dist, reg_m_kl, div='kl')
100 |
101 | if type == "l2_uot":
102 | P = ot.unbalanced.mm_unbalanced(a, b, dist, reg_m_l2, div='l2')
103 | P_list.append(P)
104 | P = torch.stack(P_list)
105 | # if type == "partial_ot":
106 | # P = ot.partial.partial_wasserstein(a, b, dist, m=alpha.item())
107 | # P =
108 | return P
109 |
110 | ############################################
111 | # Problem definition, partial OT
112 | ############################################
113 |
114 | def get_partial_ot_problem(scores, bin_score):
115 | if not isinstance(bin_score, torch.Tensor):
116 | bin_score = scores.new_tensor(bin_score)
117 | B, M, N = scores.shape
118 | dev = scores.device
119 | bins0 = bin_score.expand(B, M, 1).to(dev)
120 | bins1 = bin_score.expand(B, 1, N).to(dev)
121 | bin_score = bin_score.expand(B, 1, 1).to(dev)
122 |
123 | couplings = torch.cat([torch.cat([scores, bins0], -1),
124 | torch.cat([bins1, bin_score], -1)], 1)
125 | return couplings
126 |
127 | ############################################
128 | # Problem definition, partial distributions that assign mass to all samples and (1-mass) to the bin
129 | ############################################
130 |
131 | def get_gt_distributions(y_mat):
132 | # ground truth distribution
133 | mu = y_mat.sum(dim=1)/y_mat.sum() # distribution of prototypes
134 | nu = y_mat.sum(dim=0)/y_mat.sum() # distribution of features
135 | return mu, nu
136 |
137 | def get_partial_distributions(N, M, mass):
138 | a, b = torch.ones((N,)) / N, torch.ones((M,)) / M # uniform distribution on samples
139 | a = torch.cat([a*mass, a.new_tensor(1-mass)[None]])
140 | b = torch.cat([b*mass, b.new_tensor(1-mass)[None]])
141 | return a, b
142 |
143 | def get_partial_distributions_input_marginals(prt, masses, mass_fg):
144 | mass_bkg = 1- mass_fg
145 | a = prt[0].flatten() * masses[0]/prt[0].sum()
146 | if a[a<1e-10].shape[0]>0:
147 | mass_bkg_a = torch.tensor(mass_bkg)
148 | a[a<1e-10] = mass_bkg_a/(a[a<1e-10].shape[0])
149 | else:
150 | mass_bkg_a = torch.tensor(0)
151 | # print(a[prt[0].flatten()>1e-10].sum())
152 | a_bin = (1 - (masses[0] + mass_bkg_a)).clone().detach()[None].to(a.device)
153 | a = torch.cat([a, a_bin])
154 |
155 | b = prt[1].flatten() * masses[1]/prt[1].sum()
156 | if b[b<1e-10].shape[0]>0:
157 | mass_bkg_b = torch.tensor(mass_bkg)
158 | b[b<1e-10] = mass_bkg_b/(b[b<1e-10].shape[0])
159 | else:
160 | mass_bkg_b = torch.tensor(0)
161 | # print(a[prt[0].flatten()<1e-10].sum())
162 | b_bin = (1 - (masses[1] + mass_bkg_b)).clone().detach()[None].to(b.device)
163 | b = torch.cat([b, b_bin])
164 | return a, b
165 |
166 | def get_partial_log_distributions(N, M, mass):
167 | # specify the mass to be assigned
168 | one = torch.tensor(1.)
169 | ms, ns = (N*one), (M*one)
170 | # norm = - (ms + ns).log()
171 | # log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
172 | # log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
173 | mass = torch.tensor(mass)
174 | norm_m = - (ms).log() + mass.log()
175 | norm_n = - (ns).log() + mass.log()
176 | log_mu = torch.cat([norm_m.expand(N), (1-mass).log()[None]])
177 | log_nu = torch.cat([norm_n.expand(M), (1-mass).log()[None]])
178 | return log_mu, log_nu
179 |
180 | ############################################
181 | # The Matcher Module with trainable bin score
182 | ############################################
183 |
184 | def prob_add_bkg(prob_fg, mask0, mask1):
185 | # get the cosine similarity
186 | dev = prob_fg.device
187 | # n,m = prob_fg.shape
188 | n = mask0.shape[0]
189 | m = mask1.shape[0]
190 | # prepare the output, assign the probabilities to the foreground + bin
191 | mask_fg = torch.ones(n+1, m+1).to(dev)
192 | mask_fg[:-1,:-1][~mask0,:] = 0.0
193 | mask_fg[:-1,:-1][:,~mask1] = 0.0
194 | mask_fg[:-1,-1][~mask0] = 0.0
195 | mask_fg[-1,:-1][~mask1] = 0.0
196 | prob = torch.zeros(n+1, m+1).to(dev) #* -torch.inf
197 | prob[mask_fg.bool()] = prob_fg.flatten()
198 | prob[-1,-1] = 0
199 | return prob
200 |
201 | def ft_remove_bkg(ft0, ft1):
202 | th = 1e-8
203 | # get mask the background features
204 | mask0 = ft0.norm(dim=-1) > th
205 | mask1 = ft1.norm(dim=-1) > th
206 |
207 | ft0_fg = ft0[mask0]
208 | ft1_fg = ft1[mask1]
209 | return ft0_fg, ft1_fg, mask0, mask1
210 |
211 | class SoftmaxMatcher(nn.Module):
212 | def __init__(self, sinkhorn_iterations=100, bin_score=0.4, mass=0.9, reg=0.1, reg_kl=0.01):
213 | super().__init__()
214 | # super(SoftmaxMatcher, self).__init__()
215 | self.sinkhorn_iterations = sinkhorn_iterations
216 | bin_score = torch.nn.Parameter(torch.tensor(bin_score))# DINOv2 default value
217 | self.register_parameter('bin_score', bin_score)
218 | self.mass = mass
219 | self.reg = reg
220 | self.reg_kl = reg_kl
221 |
222 | def forward(self, ft0, ft1, y_mat=None, segmentations=None, masses=None):
223 |
224 | # remove bkg
225 | ft0_fg, ft1_fg, mask0, mask1 = ft_remove_bkg(ft0, ft1)
226 |
227 | # normalize the features (worse results!)
228 | ft0_fg = ft0_fg/ft0_fg.norm(dim=-1)[:,None]
229 | ft1_fg = ft1_fg/ft1_fg.norm(dim=-1)[:,None]
230 |
231 | cos_sim = torch.mm(ft0_fg, ft1_fg.t()) # N, M
232 | # Run the optimal transport on foreground features
233 | N, M = cos_sim.shape
234 | device = cos_sim.device
235 |
236 | # no gt y_mat available, we use the cosine similarity matrix
237 | if y_mat is not None:
238 | mu, nu = get_gt_distributions(y_mat)
239 | elif segmentations is not None and masses is not None:
240 | mu, nu = get_partial_distributions_input_marginals(segmentations, masses, self.mass)
241 | else:
242 | mu, nu = get_partial_distributions(N, M, mass=self.mass)
243 | couplings = get_partial_ot_problem(cos_sim[None], bin_score=self.bin_score)
244 | mu, nu = mu.to(device), nu.to(device)
245 | prob_out = log_optimal_transport_kl(mu, nu, couplings, reg=self.reg, reg_kl=self.reg_kl, sinkhorn_iterations=self.sinkhorn_iterations)[0] # diverging if reg_kl<1
246 | # prob_out = log_optimal_transport(mu, nu, couplings, reg=self.reg, sinkhorn_iterations=self.sinkhorn_iterations)[0]
247 |
248 | # add bkg
249 | prob = prob_add_bkg(prob_out, mask0, mask1)
250 |
251 | del prob_out, cos_sim
252 | return prob
253 |
254 | def get_cossim(self, ft0, ft1):
255 | ft0_fg, ft1_fg, mask0, mask1 = ft_remove_bkg(ft0, ft1)
256 | # normalize the features (worse results!)
257 | ft0_fg = ft0_fg/ft0_fg.norm(dim=-1)[:,None]
258 | ft1_fg = ft1_fg/ft1_fg.norm(dim=-1)[:,None]
259 |
260 | cos_sim = torch.mm(ft0_fg, ft1_fg.t())
261 | cos_sim = F.pad(cos_sim, (0,1,0,1), value=0)
262 | return cos_sim
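# Minimal usage sketch (illustrative; the hyperparameters here mirror the defaults used in
# src/losses/__init__.py and are not prescriptive):
#
#   matcher = SoftmaxMatcher(sinkhorn_iterations=10, bin_score=0.3, mass=0.9, reg=0.1, reg_kl=10)
#   # ft0: [N, C] source features, ft1: [M, C] target features
#   # rows with ~zero norm are treated as background and excluded from the transport problem
#   prob = matcher(ft0, ft1)  # [N+1, M+1]; the extra row/column is the unmatched bin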
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reginehartwig/geco/3d3f37530bd61f36602f8ac840da13200c3c2e72/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/classifier/supervised/nearest_centroid.py:
--------------------------------------------------------------------------------
1 | from src.models.pca import compute_pca
2 | from sklearn.neighbors import NearestCentroid
3 | import torch
4 | import torch.nn as nn
5 |
6 | class Nearest_Centroid_fg_Classifier(nn.Module):
7 | name = 'nearest_centroid_fg'
8 | def __init__(self, n_components_pca, x, y_mat):
9 | '''
10 | Input:
11 | n_components_pca: int, number of components for PCA
12 | x: torch.tensor, shape (B, feature_dim), feature vectors
13 | y_mat: torch.tensor, shape (B, num_parts), part labels, one-hot encoding,
14 | no background part, as we assume to only receive foreground features
15 | '''
16 | super().__init__()
17 | feature_dim = x.shape[-1]
18 | x = x.reshape(-1,feature_dim)
19 | x = x/x.norm(dim=-1, keepdim=True)
20 | self.part_components = compute_pca(x, n_components_pca)
21 | x_proj = torch.mm(x, self.part_components.t())
22 | self.classifier = NearestCentroid(metric="manhattan")
23 | self.num_parts = y_mat.shape[-1]
24 | y = y_mat.argmax(-1)
25 | self.parts = torch.unique(y)
26 | # assert(len(self.parts) == self.num_parts) # check if all parts are present
27 | x_proj = x_proj/x_proj.norm(dim=-1, keepdim=True)
28 | self.classifier.fit(x_proj.cpu().numpy(), y.cpu().numpy())
29 | self.prototypes_proj = torch.tensor(self.classifier.centroids_, device=self.part_components.device, dtype=self.part_components.dtype) # num_parts, n_components_pca
30 | self.prototypes = torch.mm(self.prototypes_proj, self.part_components)
31 |
32 | def get_prototypes(self):
33 | if len(self.parts) == self.num_parts:
34 | return self.prototypes
35 | else:
36 | prototypes = torch.zeros(self.num_parts, self.prototypes.shape[-1], device=self.prototypes.device, dtype=self.prototypes.dtype)
37 | prototypes[self.parts] = self.prototypes
38 | return prototypes
39 |
40 | def forward(self, x):
41 | '''
42 | Input:
43 | x: torch.tensor, shape (B, feature_dim), feature vectors
44 | Output:
45 | y_mat: torch.tensor, shape (B, num_parts), per-part similarity scores
46 | (may contain negative values for background features, which the classifier is not trained on)
47 | '''
48 | feature_dim = x.shape[-1]
49 | # get the scalar product of the input with the principal components
50 | x = x.reshape(-1,feature_dim)
51 | x_proj = torch.mm(x, self.part_components.t())
52 | x_proj = x_proj/x_proj.norm(dim=-1, keepdim=True)
53 | prototypes_proj = self.prototypes_proj/self.prototypes_proj.norm(dim=-1, keepdim=True)
54 | y_mat = torch.zeros((x.shape[0],self.num_parts)).to(x.device) # B, num_parts
55 | y_mat_ = torch.mm(x_proj, prototypes_proj.t())
56 | y_mat[:, self.parts] = y_mat_
57 | return y_mat
58 |
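# Minimal usage sketch (illustrative, with random data standing in for sampled foreground features):
#
#   feats = torch.randn(500, 768)                                       # foreground feature vectors
#   y_mat = torch.nn.functional.one_hot(torch.randint(0, 8, (500,)), 8).float()
#   clf = Nearest_Centroid_fg_Classifier(n_components_pca=64, x=feats, y_mat=y_mat)
#   scores = clf(torch.randn(100, 768))                                 # [100, 8] per-part scores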
--------------------------------------------------------------------------------
/src/models/classifier/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.dataset.utils import get_init_feats_and_labels
4 | import torch.nn.functional as F
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | from src.models.classifier.supervised.nearest_centroid import Nearest_Centroid_fg_Classifier
7 |
8 | def train_classifier(args_seg, dataset_train, model_refine=None):
9 |
10 | if args_seg.model == 'nearest_centroid_fg':
11 | feats, y_mat = get_init_feats_and_labels(dataset_train, args_seg.num_samples, only_fg=True, model_refine=model_refine)
12 | model_seg = Nearest_Centroid_fg_Classifier(args_seg.num_pcaparts, feats, y_mat)
13 |
14 | return model_seg
15 |
16 | def forward_classifier(dataset, idx, model_seg, model_refine=None):
17 | data = dataset[idx]
18 |
19 | # get the part segmentation
20 | imsize = data['imsize']
21 | if model_seg.name in ['nearest_centroid_fg']:
22 | ft = data['ft'][None].to(device) # B, C, H, W
23 | if model_refine is not None:
24 | ft = model_refine(ft)
25 | ft_interp = F.interpolate(ft, size=imsize.tolist(), mode='bilinear', align_corners=False)
26 | prt = model_seg(ft_interp.permute(0,2,3,1).flatten(0,-2)) # B, C, H, W -> H*W, C
27 | prt = prt.reshape(ft.shape[0], imsize[0], imsize[1], -1).permute(0,3,1,2) # H*W, C -> B, C, H, W
28 |
29 | return prt
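# Minimal usage sketch (illustrative; args_seg corresponds to configs/classifier/nearest_centroid.yaml
# and is assumed to provide .model, .num_samples and .num_pcaparts):
#
#   model_seg = train_classifier(args_seg, dataset_train)
#   prt = forward_classifier(dataset_test, idx=0, model_seg=model_seg)  # (1, num_parts, H, W) scores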
--------------------------------------------------------------------------------
/src/models/featurizer/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import open_clip
4 | import math
5 |
6 |
7 | def interpolate_pos_encoding(clip_model, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
8 | """
9 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
10 | resolution images.
11 | Source:
12 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
13 | """
14 |
15 | num_patches = embeddings.shape[1] - 1
16 | pos_embedding = clip_model.positional_embedding.unsqueeze(0)
17 | num_positions = pos_embedding.shape[1] - 1
18 | if num_patches == num_positions and height == width:
19 | return clip_model.positional_embedding
20 | class_pos_embed = pos_embedding[:, 0]
21 | patch_pos_embed = pos_embedding[:, 1:]
22 | dim = embeddings.shape[-1]
23 | h0 = height // clip_model.patch_size[0]
24 | w0 = width // clip_model.patch_size[1]
25 | # we add a small number to avoid floating point error in the interpolation
26 | # see discussion at https://github.com/facebookresearch/dino/issues/8
27 | h0, w0 = h0 + 0.1, w0 + 0.1
28 | patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
29 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
30 | patch_pos_embed = nn.functional.interpolate(
31 | patch_pos_embed,
32 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
33 | mode="bicubic",
34 | align_corners=False,
35 | )
36 | assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
37 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
38 | output = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
39 |
40 | return output
41 |
42 | def get_name():
43 | return 'open_clip'
44 |
45 | class CLIPFeaturizer:
46 | name = 'open_clip'
47 | def __init__(self):
48 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
49 | visual_model = clip_model.visual
50 | visual_model.output_tokens = True
51 | self.clip_model = visual_model.eval().cuda()
52 |
53 |
54 | def get_models(self):
55 | return [self.clip_model]
56 |
57 | @torch.no_grad()
58 | def forward(self,
59 | x, # single image, [1,c,h,w]
60 | block_index, **kwargs):
61 | batch_size = 1
62 | clip_model = self.clip_model
63 | if clip_model.input_patchnorm:
64 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
65 | x = x.reshape(x.shape[0], x.shape[1], clip_model.grid_size[0], clip_model.patch_size[0], clip_model.grid_size[1], clip_model.patch_size[1])
66 | x = x.permute(0, 2, 4, 1, 3, 5)
67 | x = x.reshape(x.shape[0], clip_model.grid_size[0] * clip_model.grid_size[1], -1)
68 | x = clip_model.patchnorm_pre_ln(x)
69 | x = clip_model.conv1(x)
70 | else:
71 | x = clip_model.conv1(x) # shape = [*, width, grid, grid]
72 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
73 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
74 | # class embeddings and positional embeddings
75 | x = torch.cat(
76 | [clip_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
77 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
78 | if(x.shape[1] > clip_model.positional_embedding.shape[0]):
79 | dim = int(math.sqrt(x.shape[1]) * clip_model.patch_size[0])
80 | x = x + interpolate_pos_encoding(clip_model, x, dim, dim).to(x.dtype)
81 | else:
82 | x = x + clip_model.positional_embedding.to(x.dtype)
83 |
84 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
85 | x = clip_model.patch_dropout(x)
86 | x = clip_model.ln_pre(x)
87 | x = x.permute(1, 0, 2) # NLD -> LND
88 |
89 | num_channel = x.size(2)
90 | ft_size = int((x.shape[0]-1) ** 0.5)
91 |
92 | for i, r in enumerate(clip_model.transformer.resblocks):
93 | x = r(x)
94 |
95 | if i == block_index:
96 | tokens = x.permute(1, 0, 2) # LND -> NLD
97 | tokens = tokens[:, 1:]
98 | tokens = tokens.transpose(1, 2).contiguous().view(batch_size, num_channel, ft_size, ft_size) # NCHW
99 |
100 | return tokens
--------------------------------------------------------------------------------
/src/models/featurizer/dift_adm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import torch
4 | from torchvision import transforms
5 | main_path = Path(__file__).resolve().parent.parent.parent.parent
6 | print(f'main path: {main_path}')
7 |
8 | import sys
9 | sys.path.append(os.path.join(main_path, 'guided-diffusion'))
10 | from guided_diffusion.script_util import create_model_and_diffusion
11 | from guided_diffusion.nn import timestep_embedding
12 |
13 | def get_name():
14 | return 'dift_adm'
15 |
16 | class ADMFeaturizer:
17 | name = 'dift_adm'
18 | def __init__(self):
19 | model, diffusion = create_model_and_diffusion(
20 | image_size=256,
21 | class_cond=False,
22 | learn_sigma=True,
23 | num_channels=256,
24 | num_res_blocks=2,
25 | channel_mult="",
26 | num_heads=4,
27 | num_head_channels=64,
28 | num_heads_upsample=-1,
29 | attention_resolutions="32,16,8",
30 | dropout=0.0,
31 | diffusion_steps=1000,
32 | noise_schedule='linear',
33 | timestep_respacing='',
34 | use_kl=False,
35 | predict_xstart=False,
36 | rescale_timesteps=False,
37 | rescale_learned_sigmas=False,
38 | use_checkpoint=False,
39 | use_scale_shift_norm=True,
40 | resblock_updown=True,
41 | use_fp16=False,
42 | use_new_attention_order=False,
43 | )
44 | model_path = os.path.join(main_path, 'guided-diffusion/models/256x256_diffusion_uncond.pt')
45 | model.load_state_dict(torch.load(model_path, map_location="cpu"))
46 | self.model = model.eval().cuda()
47 | self.diffusion = diffusion
48 |
49 | self.adm_transforms = transforms.Compose([
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
52 | ])
53 |
54 | @torch.no_grad()
55 | def forward(self, img_tensor,
56 | t=101,
57 | up_ft_index=4,
58 | ensemble_size=8):
59 | model = self.model
60 | diffusion = self.diffusion
61 |
62 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
63 | t = torch.ones((img_tensor.shape[0],), device='cuda', dtype=torch.int64) * t
64 | x_t = diffusion.q_sample(img_tensor, t, noise=None)
65 |
66 | # get layer-wise features
67 | hs = []
68 | emb = model.time_embed(timestep_embedding(t, model.model_channels))
69 | h = x_t.type(model.dtype)
70 | for module in model.input_blocks:
71 | h = module(h, emb)
72 | hs.append(h)
73 | h = model.middle_block(h, emb)
74 | for i, module in enumerate(model.output_blocks):
75 | h = torch.cat([h, hs.pop()], dim=1)
76 | h = module(h, emb)
77 |
78 | if i == up_ft_index:
79 | ft = h.mean(0, keepdim=True).detach()
80 | return ft
81 |
82 |
83 | class ADMFeaturizer4Eval(ADMFeaturizer):
84 |
85 | @torch.no_grad()
86 | def forward(self, img,
87 | img_size=[512, 512],
88 | t=101,
89 | up_ft_index=4,
90 | ensemble_size=8,
91 | **kwargs):
92 |
93 | img_tensor = self.adm_transforms(img.resize(img_size))
94 | ft = super().forward(img_tensor,
95 | t=t,
96 | up_ft_index=up_ft_index,
97 | ensemble_size=ensemble_size)
98 | del img_tensor
99 | torch.cuda.empty_cache()
100 | return ft
101 |
102 | def get_models(self):
103 | return [self.model]
--------------------------------------------------------------------------------
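A minimal usage sketch for ADMFeaturizer4Eval above (not part of the repository): it assumes the guided-diffusion checkout and its 256x256_diffusion_uncond.pt checkpoint sit where the imports above expect them, plus an available CUDA device; the image path is illustrative.

    from PIL import Image
    from src.models.featurizer.dift_adm import ADMFeaturizer4Eval

    featurizer = ADMFeaturizer4Eval()                # builds the unconditional ADM UNet and moves it to the GPU
    img = Image.open('example.jpg').convert('RGB')   # hypothetical input image
    ft = featurizer.forward(img, img_size=[256, 256], t=101, up_ft_index=4, ensemble_size=8)
    print(ft.shape)                                  # (1, C, h, w) decoder features averaged over the noise ensemble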
/src/models/featurizer/dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3 |
4 | from torchvision import transforms as T
5 | def prepare_ViT_images(vit_patch_size, img, img_size):
6 |
7 | transform = T.Compose([
8 | T.Resize(img_size),
9 | T.ToTensor(),
10 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
11 | T.ConvertImageDtype(torch.float)
12 | ])
13 | img = transform(img)
14 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size
15 | img = img[..., :w, :h].float().to(device)
16 | # check if shape of image has 3 dimensions
17 | if len(img.shape) == 3:
18 | img = img[None]
19 | return img # B, C, W, H
20 |
21 | def get_name(dino_id='dino_vitb8', log_bin=True, img_size='224', up_ft_index='4', **kwargs):
22 | name = 'dino'
23 | if log_bin:
24 | name = name + '_logbin'
25 |     return name + '_%d_upft%d' % (img_size[0], up_ft_index)
26 |
27 | class DINOFeaturizer:
28 | name = 'dino'
29 | def __init__(self, dino_id='dino_vitb8', log_bin=True, img_size='224', up_ft_index='4', **kwargs):
30 | model = torch.hub.load('facebookresearch/dino:main', dino_id)
31 | self.model = model.eval().to(device)
32 | self._use_log_bin = log_bin
33 | self.name = get_name(dino_id, log_bin, img_size, up_ft_index, **kwargs)
34 | self.vit_patch_size = 8
35 |
36 | def get_models(self):
37 | return [self.model]
38 |
39 | def prepare_tokens(self, img_tensor, up_ft_index):
40 | B, C, W, H = img_tensor.shape
41 | # we return the output tokens from the `n` last blocks as a list
42 | out = self.model.get_intermediate_layers(img_tensor, n=up_ft_index)
43 |         out = out[0] # take the output of the n-th last block
44 | out = out[:, 1:, :] # B, w0*h0, D
45 | D = out.shape[-1]
46 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0)
47 | return out
48 |
49 | @torch.no_grad()
50 | def forward(self, img, up_ft_index=3, **kwargs):
51 | # convert pil to tensor
52 | img_size = kwargs.get('img_size')
53 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H]
54 | self.w0 = img_tensor.shape[-2] // self.vit_patch_size
55 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size
56 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0
57 | if self._use_log_bin:
58 | out = self._log_bin(out.permute(0,2,3,1))
59 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D'
60 | out = out.permute(0,3,1,2) # B, D', w0, h0
61 | return out
62 |
63 | def get_w0_h0(self):
64 | return self.w0, self.h0
65 |
66 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
67 | """
68 | create a log-binned descriptor.
69 | :param x: tensor of features. Has shape Bxwxhxd.
70 | :param hierarchy: how many bin hierarchies to use.
71 | """
72 | B, w0, h0, d = x.shape
73 | num_bins = 1 + self.vit_patch_size * hierarchy
74 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0
75 | sub_desc_dim = bin_x.shape[1]
76 |
77 | avg_pools = []
78 | # compute bins of all sizes for all spatial locations.
79 | for k in range(0, hierarchy):
80 | # avg pooling with kernel 3**kx3**k
81 | win_size = 3 ** k
82 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
83 | avg_pools.append(avg_pool(bin_x))
84 |
85 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device)
86 | for y in range(w0):
87 | for x in range(h0):
88 | part_idx: int = 0
89 | # fill all bins for a spatial location (y, x)
90 | for k in range(0, hierarchy):
91 | kernel_size = 3 ** k
92 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
93 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
94 | if i == y and j == x and k != 0:
95 | continue
96 | if 0 <= i < w0 and 0 <= j < h0:
97 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
98 | :, :, i, j]
99 | else: # handle padding in a more delicate way than zero padding
100 | temp_i = max(0, min(i, w0 - 1))
101 | temp_j = max(0, min(j, h0 - 1))
102 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
103 | :, :, temp_i,
104 | temp_j]
105 | part_idx += 1
106 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
107 |         # B x 1 x (w0*h0) x (d*num_bins)
108 | return bin_x
--------------------------------------------------------------------------------
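A minimal standalone sketch of calling the DINO featurizer above (not part of the repository); the values are illustrative stand-ins for what the OmegaConf featurizer config normally supplies, and torch.hub downloads the DINO ViT-B/8 weights on first use.

    from PIL import Image
    from src.models.featurizer.dino import DINOFeaturizer

    featurizer = DINOFeaturizer(dino_id='dino_vitb8', log_bin=False, img_size=[224, 224], up_ft_index=4)
    img = Image.open('example.jpg').convert('RGB')   # hypothetical input image
    feat = featurizer.forward(img, up_ft_index=4, img_size=[224, 224])
    print(feat.shape)                                # (1, 768, 28, 28): ViT-B/8 tokens on a 224/8 = 28 grid

With log_bin=True the per-location descriptors are additionally pooled over 3**k neighborhoods, which multiplies the channel dimension by the number of bins.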
/src/models/featurizer/dinov2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3 |
4 | from torchvision import transforms as T
5 |
6 | import torch
7 |
8 | def prepare_ViT_images(vit_patch_size, img, img_size):
9 | transform = T.Compose([
10 | T.Resize(img_size),
11 | T.ToTensor(),
12 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
13 | T.ConvertImageDtype(torch.float)
14 | ])
15 | img = transform(img)
16 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size
17 | img = img[..., :w, :h].float().to(device)
18 | # check if shape of image has 3 dimensions
19 | if len(img.shape) == 3:
20 | img = img[None]
21 | return img # B, C, W, H
22 |
23 | def get_name(log_bin=True, img_size='518', up_ft_index='1', **kwargs):
24 | name = 'dinov2'
25 | if log_bin:
26 | name = 'dinov2_logbin'
27 |
28 | model_size = kwargs.get('model_size', "dinov2_vits14")
29 | if model_size != "dinov2_vits14":
30 | # remove dinov2_vit from model_size
31 | model_size = model_size.replace("dinov2_vit", "")
32 | name += f'_{model_size}'
33 |
34 |     return name + '_%d_upft%d' % (img_size[0], up_ft_index)
35 |
36 | class DINOv2Featurizer:
37 | name = 'dinov2'
38 | def __init__(self, log_bin=True, img_size='518', up_ft_index='1', **kwargs):
39 | self.model_size = kwargs.get('model_size', "dinov2_vits14")
40 | model = torch.hub.load("facebookresearch/dinov2", self.model_size)
41 | self.model = model.eval().to(device)
42 | self._use_log_bin = log_bin
43 | self.name = get_name(log_bin, img_size, up_ft_index, **kwargs)
44 | self.vit_patch_size = 14
45 |
46 | def get_models(self):
47 | return [self.model]
48 |
49 | def prepare_tokens(self, img_tensor, up_ft_index):
50 | B, C, W, H = img_tensor.shape
51 | # out_dict = self.model(pixel_values=img_tensor, return_dict=True, output_hidden_states=True)
52 | #https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/dinov2/modeling_dinov2.py#L661
53 | # out = out_dict.hidden_states[-up_ft_index] # B, w0*h0, D
54 | out = self.model.get_intermediate_layers(img_tensor, n=up_ft_index)
55 |         out = out[0] # take the output of the n-th last block
56 | out = out[:, :, :] # B, w0*h0, D
57 | D = out.shape[-1]
58 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0)
59 | return out
60 |
61 | @torch.no_grad()
62 | def forward(self, img, up_ft_index=3, **kwargs):
63 | # convert pil to tensor
64 | img_size = kwargs.get('img_size')
65 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H]
66 | self.w0 = img_tensor.shape[-2] // self.vit_patch_size
67 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size
68 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0
69 | if self._use_log_bin:
70 | out = self._log_bin(out.permute(0,2,3,1))
71 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D'
72 | out = out.permute(0,3,1,2) # B, D', w0, h0
73 | return out
74 |
75 | def get_w0_h0(self):
76 | return self.w0, self.h0
77 |
78 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
79 | """
80 | create a log-binned descriptor.
81 | :param x: tensor of features. Has shape Bxwxhxd.
82 | :param hierarchy: how many bin hierarchies to use.
83 | """
84 | B, w0, h0, d = x.shape
85 |         num_bins = 1 + self.vit_patch_size * hierarchy  # NOTE: the loops below fill only 1 + 8*hierarchy bins, so for patch size 14 the extra channels stay zero
86 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0
87 | sub_desc_dim = bin_x.shape[1]
88 |
89 | avg_pools = []
90 | # compute bins of all sizes for all spatial locations.
91 | for k in range(0, hierarchy):
92 | # avg pooling with kernel 3**kx3**k
93 | win_size = 3 ** k
94 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
95 | avg_pools.append(avg_pool(bin_x))
96 |
97 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device)
98 | for y in range(w0):
99 | for x in range(h0):
100 | part_idx: int = 0
101 | # fill all bins for a spatial location (y, x)
102 | for k in range(0, hierarchy):
103 | kernel_size = 3 ** k
104 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
105 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
106 | if i == y and j == x and k != 0:
107 | continue
108 | if 0 <= i < w0 and 0 <= j < h0:
109 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
110 | :, :, i, j]
111 | else: # handle padding in a more delicate way than zero padding
112 | temp_i = max(0, min(i, w0 - 1))
113 | temp_j = max(0, min(j, h0 - 1))
114 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
115 | :, :, temp_i,
116 | temp_j]
117 | part_idx += 1
118 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
119 |         # B x 1 x (w0*h0) x (d*num_bins)
120 | return bin_x
--------------------------------------------------------------------------------
/src/models/featurizer/dinov2_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3 |
4 | from torchvision import transforms as T
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 |
10 | # LoRA for Dinov2
11 | # This code is based on the original DINOv2 implementation and the LoRA implementation from https://github.com/RobvanGastel/dinov2-finetune/blob/main/dino_finetune/model/dino_v2.py
12 |
13 | class LoRA(nn.Module):
14 | """Low-Rank Adaptation for the for Query (Q), Key (Q), Value (V) matrices"""
15 |
16 | def __init__(
17 | self,
18 | qkv: nn.Module,
19 | linear_a_q: nn.Module,
20 | linear_b_q: nn.Module,
21 | linear_a_v: nn.Module,
22 | linear_b_v: nn.Module,
23 | ):
24 | super().__init__()
25 | self.qkv = qkv
26 | self.linear_a_q = linear_a_q
27 | self.linear_b_q = linear_b_q
28 | self.linear_a_v = linear_a_v
29 | self.linear_b_v = linear_b_v
30 | self.dim = getattr(qkv, 'in_features', 768) # Default fallback
31 | self.w_identity = torch.eye(self.dim)
32 |
33 | def forward(self, x) -> torch.Tensor:
34 | # Compute the original qkv
35 | qkv = self.qkv(x) # Shape: (B, N, 3 * org_C)
36 |
37 | # Compute the new q and v components
38 | new_q = self.linear_b_q(self.linear_a_q(x))
39 | new_v = self.linear_b_v(self.linear_a_v(x))
40 |
41 | # Add new q and v components to the original qkv tensor
42 | qkv[:, :, : self.dim] += new_q
43 | qkv[:, :, -self.dim :] += new_v
44 |
45 | return qkv
46 |
47 |
48 | def prepare_ViT_images(vit_patch_size, img, img_size):
49 | transform = T.Compose([
50 | T.Resize(img_size),
51 | T.ToTensor(),
52 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
53 | T.ConvertImageDtype(torch.float)
54 | ])
55 | img = transform(img)
56 | w, h = img.shape[-2] - img.shape[-2] % vit_patch_size, img.shape[-1] - img.shape[-1] % vit_patch_size
57 | img = img[..., :w, :h].float().to(device)
58 | # check if shape of image has 3 dimensions
59 | if len(img.shape) == 3:
60 | img = img[None]
61 | return img # B, C, W, H
62 |
63 | def get_name(log_bin=True, img_size='518', up_ft_index='1', **kwargs):
64 | name = 'dinov2lora'
65 | if log_bin:
66 | name = 'dinov2lora_logbin'
67 | # Handle both string and list inputs for img_size
68 | if isinstance(img_size, list):
69 | img_size_val = img_size[0]
70 | else:
71 | img_size_val = img_size
72 | # Handle up_ft_index conversion
73 | if hasattr(up_ft_index, '__iter__') and not isinstance(up_ft_index, str):
74 | up_ft_index_val = up_ft_index[0] if len(up_ft_index) > 0 else 1
75 | else:
76 | up_ft_index_val = up_ft_index
77 | # Convert to int, handling ListConfig objects
78 | try:
79 | img_size_int = int(img_size_val)
80 | up_ft_index_int = int(up_ft_index_val)
81 | except (ValueError, TypeError):
82 | # Fallback values if conversion fails
83 | img_size_int = 518
84 | up_ft_index_int = 1
85 |     return name + '_%d_upft%d' % (img_size_int, up_ft_index_int)
86 |
87 | class DINOv2LoRAFeaturizer(nn.Module):
88 | name = 'dinov2lora'
89 | def __init__(self, log_bin=True, img_size='518', up_ft_index='1', **kwargs):
90 | super().__init__()
91 | self.model_size = kwargs.get('model_size', "dinov2_vits14")
92 | self.model = torch.hub.load("facebookresearch/dinov2", self.model_size)
93 | self.model = self.model.to(device) # type: ignore
94 | for param in self.model.parameters():
95 | param.requires_grad = False
96 | self._use_log_bin = log_bin
97 | self.name = get_name(log_bin, img_size, up_ft_index, **kwargs)
98 | self.vit_patch_size = 14
99 |
100 | self.lora_layers = list(range(len(self.model.blocks)))
101 | self.w_a = []
102 | self.w_b = []
103 | self.r = kwargs.get('lora_rank', 10)
104 |
105 | for i, block in enumerate(self.model.blocks):
106 | if i not in self.lora_layers:
107 | continue
108 | w_qkv_linear = block.attn.qkv
109 | dim = w_qkv_linear.in_features
110 |
111 | w_a_linear_q, w_b_linear_q = self._create_lora_layer(dim, self.r)
112 | w_a_linear_v, w_b_linear_v = self._create_lora_layer(dim, self.r)
113 |
114 | self.w_a.extend([w_a_linear_q, w_a_linear_v])
115 | self.w_b.extend([w_b_linear_q, w_b_linear_v])
116 |
117 | block.attn.qkv = LoRA(
118 | w_qkv_linear,
119 | w_a_linear_q,
120 | w_b_linear_q,
121 | w_a_linear_v,
122 | w_b_linear_v,
123 | )
124 | self._reset_lora_parameters()
125 |
126 | def _create_lora_layer(self, dim: int, r: int):
127 | w_a = nn.Linear(dim, r, bias=False).to(device)
128 | w_b = nn.Linear(r, dim, bias=False).to(device)
129 | return w_a, w_b
130 |
131 | def _reset_lora_parameters(self) -> None:
132 | for w_a in self.w_a:
133 | nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
134 | for w_b in self.w_b:
135 | nn.init.zeros_(w_b.weight)
136 |
137 | def get_models(self):
138 | return [self]
139 |
140 | def prepare_tokens(self, img_tensor, up_ft_index):
141 | B, C, W, H = img_tensor.shape
142 | # out_dict = self.model(pixel_values=img_tensor, return_dict=True, output_hidden_states=True)
143 | #https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/dinov2/modeling_dinov2.py#L661
144 | # out = out_dict.hidden_states[-up_ft_index] # B, w0*h0, D
145 | out = self.model.get_intermediate_layers(img_tensor, n=up_ft_index) # type: ignore
146 |         out = out[0] # take the output of the n-th last block
147 | out = out[:, :, :] # B, w0*h0, D
148 | D = out.shape[-1]
149 | out = out.transpose(-2, -1).view(B, D, self.w0, self.h0)
150 | return out
151 |
152 | def forward(self, img, up_ft_index=3, **kwargs):
153 | # convert pil to tensor
154 | img_size = kwargs.get('img_size')
155 | img_tensor = prepare_ViT_images(self.vit_patch_size, img, img_size) # [B, C, W, H]
156 | self.w0 = img_tensor.shape[-2] // self.vit_patch_size
157 | self.h0 = img_tensor.shape[-1] // self.vit_patch_size
158 | out = self.prepare_tokens(img_tensor, up_ft_index) # B, D, w0, h0
159 | if self._use_log_bin:
160 | out = self._log_bin(out.permute(0,2,3,1))
161 | out = out.view(-1,self.w0,self.h0,out.shape[-1]) # B, w0, h0, D'
162 | out = out.permute(0,3,1,2) # B, D', w0, h0
163 | return out
164 |
165 | def get_w0_h0(self):
166 | return self.w0, self.h0
167 |
168 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
169 | """
170 | create a log-binned descriptor.
171 | :param x: tensor of features. Has shape Bxwxhxd.
172 | :param hierarchy: how many bin hierarchies to use.
173 | """
174 | B, w0, h0, d = x.shape
175 |         num_bins = 1 + self.vit_patch_size * hierarchy  # NOTE: the loops below fill only 1 + 8*hierarchy bins, so for patch size 14 the extra channels stay zero
176 | bin_x = x.permute(0, 3, 1, 2) # B, d, w0, h0
177 | sub_desc_dim = bin_x.shape[1]
178 |
179 | avg_pools = []
180 | # compute bins of all sizes for all spatial locations.
181 | for k in range(0, hierarchy):
182 | # avg pooling with kernel 3**kx3**k
183 | win_size = 3 ** k
184 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
185 | avg_pools.append(avg_pool(bin_x))
186 |
187 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, w0, h0)).to(x.device)
188 | for y in range(w0):
189 | for x in range(h0):
190 | part_idx = 0
191 | # fill all bins for a spatial location (y, x)
192 | for k in range(0, hierarchy):
193 | kernel_size = 3 ** k
194 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
195 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
196 | if i == y and j == x and k != 0:
197 | continue
198 | if 0 <= i < w0 and 0 <= j < h0:
199 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
200 | :, :, i, j]
201 | else: # handle padding in a more delicate way than zero padding
202 | temp_i = max(0, min(i, w0 - 1))
203 | temp_j = max(0, min(j, h0 - 1))
204 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
205 | :, :, temp_i,
206 | temp_j]
207 | part_idx += 1
208 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
209 |         # B x 1 x (w0*h0) x (d*num_bins)
210 | return bin_x
211 |
212 | def save_parameters(self, filename: str) -> None:
213 | """Save the LoRA weights and decoder weights to a .pt file
214 |
215 | Args:
216 | filename (str): Filename of the weights
217 | """
218 |         # collect the low-rank A and B matrices on the CPU
219 |         w_a = {f"w_a_{i:03d}": self.w_a[i].weight.to('cpu') for i in range(len(self.w_a))}
220 |         w_b = {f"w_b_{i:03d}": self.w_b[i].weight.to('cpu') for i in range(len(self.w_b))}
221 |
222 | torch.save({**w_a, **w_b}, filename)
223 |
224 | def load_parameters(self, filename: str) -> None:
225 | """Load the LoRA and decoder weights from a file
226 |
227 | Args:
228 | filename (str): File name of the weights
229 | """
230 | state_dict = torch.load(filename)
231 |
232 | # Load the LoRA parameters
233 | for i, w_A_linear in enumerate(self.w_a):
234 | saved_key = f"w_a_{i:03d}"
235 | saved_tensor = state_dict[saved_key].to(device)
236 | w_A_linear.weight = nn.Parameter(saved_tensor)
237 |
238 | for i, w_B_linear in enumerate(self.w_b):
239 | saved_key = f"w_b_{i:03d}"
240 | saved_tensor = state_dict[saved_key].to(device)
241 | w_B_linear.weight = nn.Parameter(saved_tensor)
--------------------------------------------------------------------------------
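A minimal sketch of the LoRA weight round-trip implemented above (not part of the repository): only the low-rank w_a / w_b matrices are serialized, while the frozen DINOv2 backbone is fetched again via torch.hub when the featurizer is rebuilt; the path and config values are illustrative.

    from src.models.featurizer.dinov2_lora import DINOv2LoRAFeaturizer

    featurizer = DINOv2LoRAFeaturizer(log_bin=False, img_size=[518, 518], up_ft_index=1,
                                      model_size='dinov2_vits14', lora_rank=10)
    featurizer.save_parameters('/tmp/lora_weights.pth')   # illustrative path
    featurizer.load_parameters('/tmp/lora_weights.pth')   # restores the adapters inside each wrapped qkv layer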
/src/models/featurizer/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from src.models.featurizer.dinov2_lora import DINOv2LoRAFeaturizer
4 | from pathlib import Path
5 | import wandb
6 | import os
7 |
8 | def get_featurizer(featurizer_args):
9 | load_pretrained = 'init' in featurizer_args
10 | if not load_pretrained:
11 | if featurizer_args.model == 'dift_sd':
12 | from src.models.featurizer.dift_sd import SDFeaturizer4Eval
13 | featurizer = SDFeaturizer4Eval(cat_list=featurizer_args.all_cats)
14 | elif featurizer_args.model == 'dift_adm':
15 | from src.models.featurizer.dift_adm import ADMFeaturizer4Eval
16 | featurizer = ADMFeaturizer4Eval()
17 | elif featurizer_args.model == 'open_clip':
18 | from src.models.featurizer.clip import CLIPFeaturizer
19 | featurizer = CLIPFeaturizer()
20 | elif featurizer_args.model == 'dino':
21 | from src.models.featurizer.dino import DINOFeaturizer
22 | featurizer = DINOFeaturizer(**featurizer_args)
23 | elif featurizer_args.model == 'dinov2':
24 | from src.models.featurizer.dinov2 import DINOv2Featurizer
25 | featurizer = DINOv2Featurizer(**featurizer_args)
26 | elif featurizer_args.model == 'dinov2lora':
27 | from src.models.featurizer.dinov2_lora import DINOv2LoRAFeaturizer
28 | featurizer = DINOv2LoRAFeaturizer(**featurizer_args)
29 | elif featurizer_args.model == 'sd15ema_dinov2':
30 | from src.models.featurizer.sd15ema_dinov2 import sd15ema_dinov2_Featurizer
31 | featurizer = sd15ema_dinov2_Featurizer(**featurizer_args)
32 | else:
33 | raise ValueError('featurizer model not supported')
34 | else:
35 | featurizer = load_checkpoint(featurizer_args)
36 | return featurizer
37 |
38 | def load_checkpoint(args):
39 | if args.model == 'dinov2lora':
40 | model = DINOv2LoRAFeaturizer(**args)
41 | model_out_path = Path(args.model_out_path).joinpath(args.init.id)
42 | name = "last_lora_weights.pth" if 'eval_last' in args.init else "best_lora_weights.pth"
43 | model_out_path_new = model_out_path.joinpath(name)
44 | model.load_parameters(str(model_out_path_new))
45 | else:
46 | raise ValueError
47 |
48 | model.id = args.init.id
49 | model.name = args.init.id
50 | return model
51 |
52 | def load_checkpoint_old(args):
53 | model_out_path = Path(args.model_out_path).joinpath(args.init.id)
54 | name = "last.pth" if 'eval_last' in args.init else "best.pth"
55 | model_out_path_new = model_out_path.joinpath(name)
56 | if args.model == 'dinov2lora':
57 | featurizer = DINOv2LoRAFeaturizer(**args)
58 | try:
59 | featurizer.load_state_dict(torch.load(model_out_path_new))
60 | except:
61 | featurizer = torch.load(model_out_path_new)
62 | featurizer.id = args.init.id
63 | featurizer.name = args.init.id
64 | return featurizer
65 |
66 | def save_checkpoint(model, args, name = "lora_weights"):
67 | id = wandb.run.id if wandb.run else model.id
68 | model_out_path = Path(args.model_out_path).joinpath(id)
69 | os.makedirs(model_out_path, exist_ok=True)
70 | model_out_path_new = model_out_path.joinpath(name+".pth")
71 | # save state dict
72 | model.save_parameters(model_out_path_new)
73 |
74 | def get_featurizer_name(featurizer_args):
75 | if featurizer_args.model == 'dift_sd':
76 | from src.models.featurizer.dift_sd import get_name
77 | name = get_name(cat_list=featurizer_args.all_cats)
78 | elif featurizer_args.model == 'dift_adm':
79 | from src.models.featurizer.dift_adm import get_name
80 | name = get_name()
81 | elif featurizer_args.model == 'open_clip':
82 | from src.models.featurizer.clip import get_name
83 | name = get_name()
84 | elif featurizer_args.model == 'dino':
85 | from src.models.featurizer.dino import get_name
86 | name = get_name(**featurizer_args)
87 | elif featurizer_args.model == 'dinov2':
88 | from src.models.featurizer.dinov2 import get_name
89 | name = get_name(**featurizer_args)
90 | elif featurizer_args.model == 'dinov2lora':
91 | from src.models.featurizer.dinov2_lora import get_name
92 | name = get_name(**featurizer_args)
93 | elif featurizer_args.model == 'sd15ema_dinov2':
94 | from src.models.featurizer.sd15ema_dinov2 import get_name
95 | name = get_name(**featurizer_args)
96 | else:
97 | raise ValueError('featurizer model not supported')
98 | return name
--------------------------------------------------------------------------------
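get_featurizer and get_featurizer_name dispatch on featurizer_args.model and rely on attribute-style access, so an OmegaConf config is the natural input. A minimal sketch with illustrative values (not part of the repository); without an 'init' key a fresh backbone is instantiated instead of a checkpoint being loaded.

    from omegaconf import OmegaConf
    from src.models.featurizer.utils import get_featurizer, get_featurizer_name

    featurizer_args = OmegaConf.create({
        'model': 'dinov2',
        'log_bin': False,
        'img_size': [518, 518],
        'up_ft_index': 1,
        'model_size': 'dinov2_vits14',
    })
    print(get_featurizer_name(featurizer_args))   # -> 'dinov2_518_upft1'
    featurizer = get_featurizer(featurizer_args)  # downloads/loads the DINOv2 backbone via torch.hub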
/src/models/featurizer_refine/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .feat_refine_geosc import AggregationNetwork
3 | from src.dataset.cub_200 import CUBDatasetBordersCut
4 | import torch
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 | from pathlib import Path
7 | import os
8 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
9 | import wandb
10 | import os
11 | def get_init_dataset(args):
12 | dataset_pca = CUBDatasetBordersCut(args.init.dataset, split='train')
13 | from src.models.featurizer.utils import get_featurizer
14 | if args.init.featurizer.model == 'dift_sd':
15 | args.init.featurizer.all_cats = dataset_pca.all_cats
16 | featurizer = get_featurizer(args.init.featurizer)
17 | dataset_pca.init_kps_cat(args.init.dataset.cat)
18 | dataset_pca.featurizer = featurizer
19 | dataset_pca.featurizer_kwargs = args.init.featurizer
20 | return dataset_pca
21 |
22 | def get_model_refine(args):
23 | load_pretrained = 'init' in args
24 | if not load_pretrained:
25 | if args.model == 'GeoSc':
26 | model_refine = AggregationNetwork(feature_dims=args.feature_dims, projection_dim=args.projection_dim, device='cuda', feat_map_dropout=args.feat_map_dropout)
27 |             model_refine.id = wandb.run.id if wandb.run is not None else ""
28 |     else:
29 |         model_refine = load_checkpoint(args)
30 | return model_refine
31 |
32 | def load_checkpoint(args):
33 | model_out_path = Path(args.model_out_path).joinpath(args.init.id)
34 | name = "last.pth" if 'eval_last' in args.init else "best.pth"
35 | name = "{}.pth".format(args.init.epoch) if args.model == 'GeoSc' else name
36 | model_out_path_new = model_out_path.joinpath(name)
37 | if args.model == 'GeoSc':
38 | # check if the model exists
39 | if not model_out_path_new.exists():
40 | # wget the model
41 |             os.makedirs(model_out_path, exist_ok=True)
42 | os.system(f"wget -O {model_out_path_new} {args.init.url}")
43 | model_refine = AggregationNetwork(feature_dims=args.feature_dims, projection_dim=args.projection_dim, device='cuda')
44 | try:
45 | model_refine.load_state_dict(torch.load(model_out_path_new))
46 | except:
47 | model_refine = torch.load(model_out_path_new)
48 | model_refine.id = args.init.id
49 | return model_refine
50 |
51 | def save_checkpoint(model_refine, args, name):
52 | id = wandb.run.id if wandb.run else "test"
53 | model_out_path = Path(args.model_out_path).joinpath(id)
54 | os.makedirs(model_out_path, exist_ok=True)
55 | model_out_path_new = model_out_path.joinpath(name+".pth")
56 | # save state dict
57 | torch.save(model_refine.state_dict(), model_out_path_new)
58 |
--------------------------------------------------------------------------------
/src/models/featurizer_refine/feat_refine_geosc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from .resnet import ResNet, BottleneckBlock
5 | import torch.nn.functional as F
6 |
7 | class DummyAggregationNetwork(nn.Module): # for testing, return the input
8 | def __init__(self):
9 | super(DummyAggregationNetwork, self).__init__()
10 |         # dummy parameter
11 | self.dummy = nn.Parameter(torch.ones([]))
12 | def forward(self, batch, pose=None):
13 | return batch * self.dummy
14 |
15 | class AggregationNetwork(nn.Module):
16 | """
17 | Module for aggregating feature maps across time and space.
18 |     Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
19 | https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
20 | """
21 | def __init__(
22 | self,
23 | device,
24 | feature_dims=[640,1280,1280,768],
25 | projection_dim=384,
26 | num_norm_groups=32,
27 | save_timestep=[1],
28 | kernel_size = [1,3,1],
29 | contrastive_temp = 10,
30 | feat_map_dropout=0.0,
31 | ):
32 | super().__init__()
33 | self.skip_connection = True
34 | self.feat_map_dropout = feat_map_dropout
35 | self.azimuth_embedding = None
36 | self.pos_embedding = None
37 | self.bottleneck_layers = nn.ModuleList()
38 | self.feature_dims = feature_dims
39 | # For CLIP symmetric cross entropy loss during training
40 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
41 | self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
42 | self.device = device
43 | self.save_timestep = save_timestep
44 |
45 | self.mixing_weights_names = []
46 | for l, feature_dim in enumerate(self.feature_dims):
47 | bottleneck_layer = nn.Sequential(
48 | *ResNet.make_stage(
49 | BottleneckBlock,
50 | num_blocks=1,
51 | in_channels=feature_dim,
52 | bottleneck_channels=projection_dim // 4,
53 | out_channels=projection_dim,
54 | norm="GN",
55 | num_norm_groups=num_norm_groups,
56 | kernel_size=kernel_size
57 | )
58 | )
59 | self.bottleneck_layers.append(bottleneck_layer)
60 | for t in save_timestep:
61 | # 1-index the layer name following prior work
62 |                 self.mixing_weights_names.append(f"timestep-{t}_layer-{l+1}")
63 | self.last_layer = None
64 | self.bottleneck_layers = self.bottleneck_layers.to(device)
65 | mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
66 | self.mixing_weights = nn.Parameter(mixing_weights.to(device))
67 | # count number of parameters
68 | num_params = 0
69 | for param in self.parameters():
70 | num_params += param.numel()
71 | print(f"AggregationNetwork has {num_params} parameters.")
72 |
73 | def load_pretrained_weights(self, pretrained_dict):
74 | custom_dict = self.state_dict()
75 |
76 | # Handle size mismatch
77 | if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
78 |             # Keep the first four weights from the pretrained model and zero-initialize the fifth weight
79 | custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
80 | custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
81 | else:
82 | custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
83 |
84 | # Load the weights that do match
85 | matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
86 | custom_dict.update(matching_keys)
87 |
88 | # Now load the updated state_dict
89 | self.load_state_dict(custom_dict, strict=False)
90 |
91 | def forward(self, batch, pose=None):
92 | """
93 |         Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
94 | """
95 | if self.feat_map_dropout > 0 and self.training:
96 | batch = F.dropout(batch, p=self.feat_map_dropout)
97 |
98 | output_feature = None
99 | start = 0
100 | mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
101 | if self.pos_embedding is not None: #position embedding
102 | batch = torch.cat((batch, self.pos_embedding), dim=1)
103 | for i in range(len(mixing_weights)):
104 | # Share bottleneck layers across timesteps
105 | bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
106 |             # Chunk the batch according to the layer
107 | # Account for looping if there are multiple timesteps
108 | end = start + self.feature_dims[i % len(self.feature_dims)]
109 | feats = batch[:, start:end, :, :]
110 | start = end
111 | # Downsample the number of channels and weight the layer
112 | bottlenecked_feature = bottleneck_layer(feats)
113 | bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
114 | if output_feature is None:
115 | output_feature = bottlenecked_feature
116 | else:
117 | output_feature += bottlenecked_feature
118 |
119 | if self.last_layer is not None:
120 |
121 | output_feature_after = self.last_layer(output_feature)
122 | if self.skip_connection:
123 | # skip connection
124 | output_feature = output_feature + output_feature_after
125 |
126 |
127 | norms_ft_new = torch.linalg.norm(output_feature, dim=1, keepdim=True)
128 | output_feature = output_feature / (norms_ft_new + 1e-8)
129 | return output_feature
130 |
131 |
132 | def conv1x1(in_planes, out_planes, stride=1):
133 | """1x1 convolution without padding"""
134 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
135 |
136 |
137 | def conv3x3(in_planes, out_planes, stride=1):
138 | """3x3 convolution with padding"""
139 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
140 |
141 |
142 | class BasicBlock(nn.Module):
143 | def __init__(self, in_planes, planes, stride=1):
144 | super().__init__()
145 | self.conv1 = conv3x3(in_planes, planes, stride)
146 | self.conv2 = conv3x3(planes, planes)
147 | self.bn1 = nn.BatchNorm2d(planes)
148 | self.bn2 = nn.BatchNorm2d(planes)
149 | self.relu = nn.ReLU(inplace=True)
150 |
151 | if stride == 1:
152 | self.downsample = None
153 | else:
154 | self.downsample = nn.Sequential(
155 | conv1x1(in_planes, planes, stride=stride),
156 | nn.BatchNorm2d(planes)
157 | )
158 |
159 | def forward(self, x):
160 | y = x
161 | y = self.relu(self.bn1(self.conv1(y)))
162 | y = self.bn2(self.conv2(y))
163 |
164 | if self.downsample is not None:
165 | x = self.downsample(x)
166 |
167 | return self.relu(x+y)
168 |
--------------------------------------------------------------------------------
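AggregationNetwork above expects the per-layer feature maps concatenated along the channel dimension (sum(feature_dims) channels) and returns a single projection_dim-channel, L2-normalized map. A minimal shape-check sketch with illustrative sizes and device (not part of the repository):

    import torch
    from src.models.featurizer_refine.feat_refine_geosc import AggregationNetwork

    dims = [640, 1280, 1280, 768]
    net = AggregationNetwork(device='cpu', feature_dims=dims, projection_dim=384)
    feats = torch.randn(2, sum(dims), 32, 32)   # B x sum(feature_dims) x H x W
    out = net(feats)
    print(out.shape)                            # torch.Size([2, 384, 32, 32]), ~unit L2 norm per pixel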
/src/models/pca.py:
--------------------------------------------------------------------------------
1 | from sklearn.decomposition import PCA
2 | import torch
3 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
4 |
5 | def compute_pca(x, n_components):
6 | feature_dim = x.shape[-1]
7 | x = x.reshape(-1,feature_dim)
8 | pca = PCA(n_components=n_components).fit(x.cpu().numpy())
9 | components = pca.components_[None, ...]
10 | components = components.reshape(-1,feature_dim)
11 | return torch.tensor(components).to(device)
12 |
13 |
14 |
--------------------------------------------------------------------------------
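compute_pca above flattens all leading dimensions, fits a scikit-learn PCA on the CPU, and returns the component matrix on the default device; features can then be projected with a matmul against components.T. A minimal sketch with illustrative shapes (not part of the repository):

    import torch
    from src.models.pca import compute_pca

    feats = torch.randn(2, 37, 37, 768)              # B x h x w x D feature maps
    components = compute_pca(feats, n_components=64)
    print(components.shape)                          # (64, 768) principal directions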
/src/models/segmentation/sam.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | # Ensure the segment_anything package is accessible.
5 | sys.path.append("..")
6 | """pip install git+https://github.com/facebookresearch/segment-anything.git"""
7 | from segment_anything import sam_model_registry, SamPredictor
8 | device = "cuda"
9 | import numpy as np
10 |
11 | def get_name():
12 | return 'sam'
13 | class SAM(nn.Module):
14 | name = "sam"
15 | def __init__(self, args):
16 | super().__init__()
17 | """wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"""
18 | sam_checkpoint = args.path_model_seg
19 | model_type = "vit_h"
20 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
21 | self.sam = SamPredictor(sam)
22 | return
23 |
24 | def forward(self, im, bbox=None, kps=None):
25 |         kps = kps[:, [1,0,2]]  # swap the first two columns so points are (x, y, label) as SamPredictor expects
26 | self.sam.set_image(np.array(im))
27 | if bbox is not None:
28 | input_box = np.array(bbox)[None, :]
29 | else:
30 | input_box = None
31 | prt, _, _ = self.sam.predict(box=input_box, multimask_output=False, point_coords=kps[:,:2].cpu().numpy(), point_labels=kps[:,2].cpu().numpy())
32 | prt = torch.tensor(prt[0])
33 | return prt
--------------------------------------------------------------------------------
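A minimal sketch of prompting the SAM wrapper above with keypoints (not part of the repository): the module hard-codes device = "cuda", so a GPU is assumed, and the checkpoint path, image, and keypoint are illustrative; keypoints are passed as (row, col, label) and reordered internally before being handed to SamPredictor.

    import torch
    from PIL import Image
    from omegaconf import OmegaConf
    from src.models.segmentation.sam import SAM

    args = OmegaConf.create({'path_model_seg': 'sam_vit_h_4b8939.pth'})   # illustrative checkpoint location
    model = SAM(args)
    im = Image.open('example.jpg').convert('RGB')
    kps = torch.tensor([[120.0, 200.0, 1.0]])    # one foreground point as (row, col, label)
    mask = model(im, bbox=None, kps=kps)         # boolean H x W mask tensor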