├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── cartoon.png
│   ├── cat.png
│   ├── demo_cat_480p.gif
│   ├── demo_guitar_480p.gif
│   ├── edit_cat.gif
│   ├── guitar.png
│   ├── guitar_cat.jpg
│   ├── painting_cat.jpg
│   ├── target_cat.png
│   ├── target_guitar.png
│   └── teaser.gif
├── demo.ipynb
├── edit_propagation.ipynb
├── environment.yml
├── eval_davis.py
├── eval_homography.py
├── eval_hpatches.py
├── eval_spair.py
├── extract_dift.py
├── extract_dift.sh
├── sd_featurizer_spair.py
├── setup_env.sh
└── src
    ├── models
    │   ├── clip.py
    │   ├── dift_adm.py
    │   ├── dift_sd.py
    │   └── dino.py
    └── utils
        └── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | */.DS_Store
163 | .DS_Store
164 |
165 | guided-diffusion/
166 | davis_results_sd/
167 | davis_results_adm/
168 | superpoint-1k/
169 | hpatches_results/
170 | superpoint-1k.zip
171 | SPair-71k.tar.gz
172 | SPair-71k/
173 | guided-diffusion/models/256x256_diffusion_uncond.pt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Luming Tang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Features (DIFT)
2 | This repository contains code for our NeurIPS 2023 paper "Emergent Correspondence from Image Diffusion".
3 |
4 | ### [Project Page](https://diffusionfeatures.github.io/) | [Paper](https://arxiv.org/abs/2306.03881) | [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing)
5 |
6 | ![teaser](assets/teaser.gif)
7 |
8 | ## Prerequisites
9 | If you have a Linux machine, you can either set up the Python environment with the following commands:
10 | ```
11 | conda env create -f environment.yml
12 | conda activate dift
13 | ```
14 | or create a new conda environment and install the packages manually using the
15 | shell commands in [setup_env.sh](setup_env.sh).
16 |
17 | ## Interactive Demo: Give it a Try!
18 | We provide an interactive Jupyter notebook, [demo.ipynb](demo.ipynb), to demonstrate the semantic correspondence established by DIFT, and you can try it on your own images! After loading two images, left-click on an interesting point in the source image on the left; after a second or two, the corresponding point in the target image will be marked with a red dot on the right, together with a heatmap showing the per-pixel cosine distance computed from DIFT features. Here are two examples on cat and guitar:
19 |
20 |
21 |
22 | ![demo_cat](assets/demo_cat_480p.gif)
23 | ![demo_guitar](assets/demo_guitar_480p.gif)
24 |
25 |
26 |
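Under the hood, this lookup is essentially a nearest-neighbor search in DIFT feature space: the feature at the clicked source pixel is compared against every target pixel's feature, the argmax gives the corresponding point, and the full similarity map is what gets displayed as the heatmap. Below is a minimal, illustrative sketch of that step (the actual demo code lives in [src/utils/visualization.py](src/utils/visualization.py)); it assumes `src_ft` and `trg_ft` are `[1, C, h, w]` feature maps returned by `SDFeaturizer.forward`, and `(x, y)` is the clicked pixel in the resized source image:

```python
import torch
import torch.nn.functional as F

def find_correspondence(src_ft, trg_ft, x, y, img_size=768):
    # upsample both DIFT maps to the (square) image resolution used in the demo
    src_ft = F.interpolate(src_ft, size=(img_size, img_size), mode='bilinear')
    trg_ft = F.interpolate(trg_ft, size=(img_size, img_size), mode='bilinear')
    num_channel = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].reshape(1, num_channel)         # feature at the clicked point, [1, C]
    trg_vecs = trg_ft[0].reshape(num_channel, -1).permute(1, 0)  # all target features, [H*W, C]
    cos_map = F.cosine_similarity(src_vec, trg_vecs).reshape(img_size, img_size)
    idx = torch.argmax(cos_map).item()
    trg_y, trg_x = divmod(idx, img_size)                         # corresponding point (the red dot)
    return trg_x, trg_y, cos_map                                 # cos_map is shown as the heatmap
```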
27 | If you don't have a local GPU, you can also use the provided [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing).
28 |
29 | ## Extract DIFT for a given image
30 | You can use the following [command](extract_dift.sh) to extract DIFT from a given image and save it as a torch tensor. By default, the arguments match those used for the semantic correspondence tasks.
31 | ```
32 | python extract_dift.py \
33 | --input_path ./assets/cat.png \
34 | --output_path dift_cat.pt \
35 | --img_size 768 768 \
36 | --t 261 \
37 | --up_ft_index 1 \
38 | --prompt 'a photo of a cat' \
39 | --ensemble_size 8
40 | ```
41 | Here is an explanation of each argument:
42 | - `input_path`: path to the input image file.
43 | - `output_path`: path to save the output features as torch tensor.
44 | - `img_size`: the width and height of the image after resizing, before it is fed into the diffusion model. If set to 0, no resizing is performed and the original image size is kept. It is [768, 768] by default; you can decrease it if you run into memory issues.
45 | - `t`: the diffusion time step, an integer in the range [0, 1000]. `t=261` by default for semantic correspondence.
46 | - `up_ft_index`: the index of the U-Net upsampling block from which to extract the feature map, chosen from [0, 1, 2, 3]. `up_ft_index=1` by default for semantic correspondence.
47 | - `prompt`: the text prompt used in the diffusion model.
48 | - `ensemble_size`: the number of repeated images in each batch used to compute the features. `ensemble_size=8` by default; you can reduce it if you run into memory issues.
49 |
50 | The spatial size of the output DIFT tensor is determined by both `img_size` and `up_ft_index`: with `up_ft_index=0` the output is 1/32 of `img_size`; with `up_ft_index=1`, 1/16; with `up_ft_index=2` or `3`, 1/8.
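For example, with the default command above (768×768 input and `up_ft_index=1`), the spatial dimensions of the saved feature map should be 768/16 = 48. A quick sanity check (assuming the `dift_cat.pt` output path from the command above, and that the last two dimensions of the saved tensor are height and width):

```python
import torch

ft = torch.load('dift_cat.pt')
print(ft.shape)  # last two dimensions should be (48, 48) for a 768x768 input with up_ft_index=1
assert ft.shape[-2:] == (768 // 16, 768 // 16)
```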
51 |
52 | ## Application: Edit Propagation
53 | Using DIFT, we can propagate edits in one image to other images that share semantic correspondences with it, even across categories and domains:
54 | ![edit propagation](assets/edit_cat.gif)
55 | More implementation details are in the notebook [edit_propagation.ipynb](edit_propagation.ipynb).
56 |
57 | ## Get Benchmark Evaluation Results
58 | First, run the following commands to enable the use of DIFT_adm:
59 | ```
60 | git clone git@github.com:openai/guided-diffusion.git
61 | cd guided-diffusion && mkdir models && cd models
62 | wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt
63 | ```
64 |
65 | ### SPair-71k
66 |
67 | First, download SPair-71k data:
68 | ```
69 | wget https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
70 | tar -xzvf SPair-71k.tar.gz
71 | ```
72 | Run the following command to get the PCK (both per point and per image) of DIFT_sd on SPair-71k, where `--save_path` specifies where the extracted features are saved:
73 | ```
74 | python eval_spair.py \
75 | --dataset_path ./SPair-71k \
76 | --save_path ./spair_ft \
77 | --dift_model sd \
78 | --img_size 768 768 \
79 | --t 261 \
80 | --up_ft_index 1 \
81 | --ensemble_size 8
82 | ```
83 | Run the following command to get the PCK (both per point and per image) of DIFT_adm on SPair-71k:
84 | ```
85 | python eval_spair.py \
86 | --dataset_path ./SPair-71k \
87 | --save_path ./spair_ft \
88 | --dift_model adm \
89 | --img_size 512 512 \
90 | --t 101 \
91 | --up_ft_index 4 \
92 | --ensemble_size 8
93 | ```
94 |
95 | ### HPatches
96 |
97 | First, prepare HPatches data:
98 | ```
99 | cd $HOME
100 | git clone git@github.com:mihaidusmanu/d2-net.git && cd d2-net/hpatches_sequences/
101 | chmod u+x download.sh
102 | ./download.sh
103 | ```
104 |
105 | Then, download the 1k SuperPoint keypoints:
106 | ```
107 | wget "https://www.dropbox.com/scl/fi/1mxy3oycnz7m2acd92u2x/superpoint-1k.zip?rlkey=fic30gr2tlth3cmsyyywcg385&dl=1" -O superpoint-1k.zip
108 | unzip superpoint-1k.zip
109 | rm superpoint-1k.zip
110 | ```
111 |
112 | Run the following commands to get the homography estimation accuracy of DIFT_sd on HPatches:
113 | ```
114 | python eval_hpatches.py \
115 | --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
116 | --kpts_path ./superpoint-1k \
117 | --save_path ./hpatches_results \
118 | --dift_model sd \
119 | --img_size 768 768 \
120 | --t 0 \
121 | --up_ft_index 2 \
122 | --ensemble_size 8
123 |
124 | python eval_homography.py \
125 | --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
126 | --save_path ./hpatches_results \
128 | --feat dift_sd \
129 | --metric cosine \
130 | --mode lmeds
131 | ```
132 |
133 | Run the following commands to get the homography estimation accuracy of DIFT_adm on HPatches:
134 | ```
135 | python eval_hpatches.py \
136 | --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
137 | --kpts_path ./superpoint-1k \
138 | --save_path ./hpatches_results \
139 | --dift_model adm \
140 | --img_size 768 768 \
141 | --t 41 \
142 | --up_ft_index 11 \
143 | --ensemble_size 4
144 |
145 | python eval_homography.py \
146 | --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
147 | --save_path ./hpatches_results \
149 | --feat dift_adm \
150 | --metric l2 \
151 | --mode ransac
152 | ```
153 |
154 | ### DAVIS
155 |
156 | We follow the same evaluation protocol as in DINO's [implementation](https://github.com/facebookresearch/dino#evaluation-davis-2017-video-object-segmentation).
157 |
158 | First, prepare DAVIS 2017 data and evaluation tools:
159 | ```
160 | cd $HOME
161 | git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
162 | ./data/get_davis.sh
163 | cd $HOME
164 | git clone https://github.com/davisvideochallenge/davis2017-evaluation
165 | ```
166 |
167 | Then, get segmentation results using DIFT_sd:
168 | ```
169 | python eval_davis.py \
170 | --dift_model sd \
171 | --t 51 \
172 | --up_ft_index 2 \
173 | --temperature 0.2 \
174 | --topk 15 \
175 | --n_last_frames 28 \
176 | --ensemble_size 8 \
177 | --size_mask_neighborhood 15 \
178 | --data_path $HOME/davis-2017/DAVIS/ \
179 | --output_dir ./davis_results_sd/
180 | ```
181 |
182 | and results using DIFT_adm:
183 | ```
184 | python eval_davis.py \
185 | --dift_model adm \
186 | --t 51 \
187 | --up_ft_index 7 \
188 | --temperature 0.1 \
189 | --topk 10 \
190 | --n_last_frames 28 \
191 | --ensemble_size 4 \
192 | --size_mask_neighborhood 15 \
193 | --data_path $HOME/davis-2017/DAVIS/ \
194 | --output_dir ./davis_results_adm/
195 | ```
196 |
197 | Finally, evaluate the results:
198 | ```
199 | python $HOME/davis2017-evaluation/evaluation_method.py \
200 | --task semi-supervised \
201 | --results_path ./davis_results_sd/ \
202 | --davis_path $HOME/davis-2017/DAVIS/
203 |
204 | python $HOME/davis2017-evaluation/evaluation_method.py \
205 | --task semi-supervised \
206 | --results_path ./davis_results_adm/ \
207 | --davis_path $HOME/davis-2017/DAVIS/
208 | ```
209 |
210 | ## Misc.
211 | If you find our code or paper useful for your research, please consider citing our work with the following BibTeX entry:
212 | ```
213 | @inproceedings{
214 | tang2023emergent,
215 | title={Emergent Correspondence from Image Diffusion},
216 | author={Luming Tang and Menglin Jia and Qianqian Wang and Cheng Perng Phoo and Bharath Hariharan},
217 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
218 | year={2023},
219 | url={https://openreview.net/forum?id=ypOiXjdfnU}
220 | }
221 | ```
222 |
--------------------------------------------------------------------------------
/assets/cartoon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/cartoon.png
--------------------------------------------------------------------------------
/assets/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/cat.png
--------------------------------------------------------------------------------
/assets/demo_cat_480p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/demo_cat_480p.gif
--------------------------------------------------------------------------------
/assets/demo_guitar_480p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/demo_guitar_480p.gif
--------------------------------------------------------------------------------
/assets/edit_cat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/edit_cat.gif
--------------------------------------------------------------------------------
/assets/guitar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/guitar.png
--------------------------------------------------------------------------------
/assets/guitar_cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/guitar_cat.jpg
--------------------------------------------------------------------------------
/assets/painting_cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/painting_cat.jpg
--------------------------------------------------------------------------------
/assets/target_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/target_cat.png
--------------------------------------------------------------------------------
/assets/target_guitar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/target_guitar.png
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/teaser.gif
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "f6e50bfe-edb9-4932-bf43-39f047bf36d1",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%matplotlib widget\n",
11 | "import argparse\n",
12 | "import gc\n",
13 | "import random\n",
14 | "import torch\n",
15 | "from PIL import Image\n",
16 | "from torchvision.transforms import PILToTensor\n",
17 | "from src.models.dift_sd import SDFeaturizer\n",
18 | "from src.utils.visualization import Demo"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "85e967e6-ca78-424d-ac69-2bc2c0bd744f",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "torch.cuda.set_device(0)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "1bf04cbe-b63d-4484-9918-7fac9f3506e9",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "dift = SDFeaturizer()"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "829c1cf2-82a1-45ea-b284-19a055311855",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "# you can choose visualize cat or guitar\n",
49 | "category = random.choice(['cat', 'guitar'])\n",
50 | "\n",
51 | "print(f\"let's visualize semantic correspondence on {category}\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "id": "076f7afc-232c-4fff-83dc-c63b876ee12f",
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "if category == 'cat':\n",
62 | " filelist = ['./assets/cat.png', './assets/target_cat.png', './assets/target_cat.png']\n",
63 | "elif category == 'guitar':\n",
64 | " filelist = ['./assets/guitar.png', './assets/target_guitar.png', './assets/target_guitar.png']\n",
65 | "\n",
66 | "prompt = f'a photo of a {category}'"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "id": "245f2c3f-8445-42c5-b31f-be438c7239d9",
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "ft = []\n",
77 | "imglist = []\n",
78 | "\n",
79 | "# decrease these two if you don't have enough RAM or GPU memory\n",
80 | "img_size = 768\n",
81 | "ensemble_size = 8\n",
82 | "\n",
83 | "for filename in filelist:\n",
84 | " img = Image.open(filename).convert('RGB')\n",
85 | " img = img.resize((img_size, img_size))\n",
86 | " imglist.append(img)\n",
87 | " img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2\n",
88 | " ft.append(dift.forward(img_tensor,\n",
89 | " prompt=prompt,\n",
90 | " ensemble_size=ensemble_size))\n",
91 | "ft = torch.cat(ft, dim=0)\n",
92 | "\n",
93 | "gc.collect()\n",
94 | "torch.cuda.empty_cache()"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "f81ae975-cce2-491d-9b6b-58412559b8b6",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "demo = Demo(imglist, ft, img_size)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "id": "6f9b5ef9-db57-46dc-9ad2-9758b5d573c4",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "'''\n",
115 | "left is source image, right is target image.\n",
116 | "you can click on the source image, and DIFT will find the corresponding\n",
117 | "point on the right image, mark it with red point and also plot the per-pixel \n",
118 | "cosine distance as heatmap.\n",
119 | "'''\n",
120 | "demo.plot_img_pairs(fig_size=5)"
121 | ]
122 | }
123 | ],
124 | "metadata": {
125 | "kernelspec": {
126 | "display_name": "Python 3 (ipykernel)",
127 | "language": "python",
128 | "name": "python3"
129 | },
130 | "language_info": {
131 | "codemirror_mode": {
132 | "name": "ipython",
133 | "version": 3
134 | },
135 | "file_extension": ".py",
136 | "mimetype": "text/x-python",
137 | "name": "python",
138 | "nbconvert_exporter": "python",
139 | "pygments_lexer": "ipython3",
140 | "version": "3.10.9"
141 | }
142 | },
143 | "nbformat": 4,
144 | "nbformat_minor": 5
145 | }
146 |
--------------------------------------------------------------------------------
/edit_propagation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "3306ccce-4b17-41a9-831d-add6cccddc0e",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import torch.nn as nn\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "import numpy as np\n",
14 | "import gc\n",
15 | "import imageio\n",
16 | "from PIL import Image\n",
17 | "from torchvision.transforms import PILToTensor\n",
18 | "import os\n",
19 | "import json\n",
20 | "from PIL import Image, ImageDraw\n",
21 | "import torch.nn.functional as F\n",
22 | "import cv2\n",
23 | "import glob\n",
24 | "from torchvision.transforms import PILToTensor\n",
25 | "from src.models.dift_sd import SDFeaturizer4Eval"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "id": "081cd585-9d9d-4ffe-8c9b-6c6360d2e4ad",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "def gen_grid(h, w, device, normalize=False, homogeneous=False):\n",
36 | " if normalize:\n",
37 | " lin_y = torch.linspace(-1., 1., steps=h, device=device)\n",
38 | " lin_x = torch.linspace(-1., 1., steps=w, device=device)\n",
39 | " else:\n",
40 | " lin_y = torch.arange(0, h, device=device)\n",
41 | " lin_x = torch.arange(0, w, device=device)\n",
42 | " grid_y, grid_x = torch.meshgrid((lin_y, lin_x))\n",
43 | " grid = torch.stack((grid_x, grid_y), -1)\n",
44 | " if homogeneous:\n",
45 | " grid = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)\n",
46 | " return grid # [h, w, 2 or 3]\n",
47 | "\n",
48 | "\n",
49 | "def normalize_coords(coords, h, w, no_shift=False):\n",
50 | " assert coords.shape[-1] == 2\n",
51 | " if no_shift:\n",
52 | " return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2\n",
53 | " else:\n",
54 | " return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2 - 1."
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "2a13b459-4698-4a9c-803f-d7ba8adb6962",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "cat = 'cat'\n",
65 | "dift = SDFeaturizer4Eval(cat_list=['cat'])"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "0606e9dd-9e51-49ec-bf37-1f2bc9f78a84",
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "src_img = Image.open('./assets/guitar_cat.jpg').convert('RGB')\n",
76 | "trg_img = Image.open('./assets/painting_cat.jpg').convert('RGB')\n",
77 | "sticker = imageio.imread('./assets/cartoon.png')\n",
78 | "sticker_color, sticker_mask = sticker[..., :3], sticker[..., 3]\n",
79 | "\n",
80 | "assert np.array(src_img).shape[:2] == sticker.shape[:2]\n",
81 | "h_src, w_src = sticker.shape[:2]\n",
82 | "h_trg, w_trg = np.array(trg_img).shape[:2]\n",
83 | "\n",
84 | "sd_feat_src = dift.forward(src_img, cat)\n",
85 | "sd_feat_trg = dift.forward(trg_img, cat)\n",
86 | "\n",
87 | "sd_feat_src = F.normalize(sd_feat_src.squeeze(), p=2, dim=0)\n",
88 | "sd_feat_trg = F.normalize(sd_feat_trg.squeeze(), p=2, dim=0)\n",
89 | "feat_dim = sd_feat_src.shape[0]\n",
90 | "\n",
91 | "grid_src = gen_grid(h_src, w_src, device='cuda')\n",
92 | "grid_trg = gen_grid(h_trg, w_trg, device='cuda')\n",
93 | "\n",
94 | "coord_src = grid_src[sticker_mask > 0]\n",
95 | "coord_src = coord_src[torch.randperm(len(coord_src))][:1000]\n",
96 | "coord_src_normed = normalize_coords(coord_src, h_src, w_src)\n",
97 | "grid_trg_normed = normalize_coords(grid_trg, h_trg, w_trg)\n",
98 | "\n",
99 | "feat_src = F.grid_sample(sd_feat_src[None], coord_src_normed[None, None], align_corners=True).squeeze().T\n",
100 | "feat_trg = F.grid_sample(sd_feat_trg[None], grid_trg_normed[None], align_corners=True).squeeze()\n",
101 | "feat_trg_flattened = feat_trg.permute(1, 2, 0).reshape(-1, feat_dim)\n",
102 | "\n",
103 | "distances = torch.cdist(feat_src, feat_trg_flattened)\n",
104 | "_, indices = torch.min(distances, dim=1)\n",
105 | "\n",
106 | "src_pts = coord_src.reshape(-1, 2).cpu().numpy()\n",
107 | "trg_pts = grid_trg.reshape(-1, 2)[indices].cpu().numpy()\n",
108 | "\n",
109 | "M, mask = cv2.findHomography(src_pts, trg_pts, cv2.RANSAC, 5.0)\n",
110 | "sticker_out = cv2.warpPerspective(sticker, M, (w_trg, h_trg))\n",
111 | "\n",
112 | "sticker_out_alpha = sticker_out[..., 3:] / 255\n",
113 | "sticker_alpha = sticker[..., 3:] / 255\n",
114 | "\n",
115 | "trg_img_with_sticker = sticker_out_alpha * sticker_out[..., :3] + (1 - sticker_out_alpha) * trg_img\n",
116 | "src_img_with_sticker = sticker_alpha * sticker[..., :3] + (1 - sticker_alpha) * src_img"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "id": "88723600-c18f-4eb1-aec7-feb4112e2610",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n",
127 | "\n",
128 | "axs[0, 0].imshow(src_img)\n",
129 | "axs[0, 0].set_title(\"Source Image\")\n",
130 | "axs[0, 0].axis('off')\n",
131 | "\n",
132 | "axs[0, 1].imshow(src_img_with_sticker.astype(np.uint8))\n",
133 | "axs[0, 1].set_title(\"Source Image with Edits\")\n",
134 | "axs[0, 1].axis('off')\n",
135 | "\n",
136 | "axs[1, 0].imshow(trg_img)\n",
137 | "axs[1, 0].set_title(\"Target Image\")\n",
138 | "axs[1, 0].axis('off')\n",
139 | "\n",
140 | "axs[1, 1].imshow(trg_img_with_sticker.astype(np.uint8))\n",
141 | "axs[1, 1].set_title(\"Target Image with Propagated Edits\")\n",
142 | "axs[1, 1].axis('off')\n",
143 | "\n",
144 | "plt.tight_layout()\n",
145 | "plt.show()"
146 | ]
147 | }
148 | ],
149 | "metadata": {
150 | "kernelspec": {
151 | "display_name": "Python 3 (ipykernel)",
152 | "language": "python",
153 | "name": "python3"
154 | },
155 | "language_info": {
156 | "codemirror_mode": {
157 | "name": "ipython",
158 | "version": 3
159 | },
160 | "file_extension": ".py",
161 | "mimetype": "text/x-python",
162 | "name": "python",
163 | "nbconvert_exporter": "python",
164 | "pygments_lexer": "ipython3",
165 | "version": "3.10.9"
166 | }
167 | },
168 | "nbformat": 4,
169 | "nbformat_minor": 5
170 | }
171 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dift
2 | channels:
3 | - xformers
4 | - pytorch
5 | - nvidia
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - blas=1.0=mkl
11 | - brotlipy=0.7.0=py310h7f8727e_1002
12 | - bzip2=1.0.8=h7b6447c_0
13 | - ca-certificates=2023.01.10=h06a4308_0
14 | - certifi=2023.5.7=py310h06a4308_0
15 | - cffi=1.15.1=py310h5eee18b_3
16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
17 | - cryptography=39.0.1=py310h9ce1e76_0
18 | - cuda-cudart=11.7.99=0
19 | - cuda-cupti=11.7.101=0
20 | - cuda-libraries=11.7.1=0
21 | - cuda-nvrtc=11.7.99=0
22 | - cuda-nvtx=11.7.91=0
23 | - cuda-runtime=11.7.1=0
24 | - ffmpeg=4.3=hf484d3e_0
25 | - freetype=2.12.1=h4a9f257_0
26 | - giflib=5.2.1=h5eee18b_3
27 | - gmp=6.2.1=h295c915_3
28 | - gnutls=3.6.15=he1e5248_0
29 | - idna=3.4=py310h06a4308_0
30 | - intel-openmp=2023.1.0=hdb19cb5_46305
31 | - jpeg=9e=h5eee18b_1
32 | - lame=3.100=h7b6447c_0
33 | - lcms2=2.12=h3be6417_0
34 | - ld_impl_linux-64=2.38=h1181459_1
35 | - lerc=3.0=h295c915_0
36 | - libcublas=11.10.3.66=0
37 | - libcufft=10.7.2.124=h4fbf590_0
38 | - libcufile=1.6.1.9=0
39 | - libcurand=10.3.2.106=0
40 | - libcusolver=11.4.0.1=0
41 | - libcusparse=11.7.4.91=0
42 | - libdeflate=1.17=h5eee18b_0
43 | - libffi=3.4.4=h6a678d5_0
44 | - libgcc-ng=11.2.0=h1234567_1
45 | - libgomp=11.2.0=h1234567_1
46 | - libiconv=1.16=h7f8727e_2
47 | - libidn2=2.3.4=h5eee18b_0
48 | - libnpp=11.7.4.75=0
49 | - libnvjpeg=11.8.0.2=0
50 | - libpng=1.6.39=h5eee18b_0
51 | - libstdcxx-ng=11.2.0=h1234567_1
52 | - libtasn1=4.19.0=h5eee18b_0
53 | - libtiff=4.5.0=h6a678d5_2
54 | - libunistring=0.9.10=h27cfd23_0
55 | - libuuid=1.41.5=h5eee18b_0
56 | - libwebp=1.2.4=h11a3e52_1
57 | - libwebp-base=1.2.4=h5eee18b_1
58 | - lz4-c=1.9.4=h6a678d5_0
59 | - mkl=2023.1.0=h6d00ec8_46342
60 | - mkl-service=2.4.0=py310h5eee18b_1
61 | - mkl_fft=1.3.6=py310h1128e8f_1
62 | - mkl_random=1.2.2=py310h1128e8f_1
63 | - ncurses=6.4=h6a678d5_0
64 | - nettle=3.7.3=hbbd107a_1
65 | - numpy=1.24.3=py310h5f9d8c6_1
66 | - numpy-base=1.24.3=py310hb5e798b_1
67 | - openh264=2.1.1=h4ff587b_0
68 | - openssl=1.1.1t=h7f8727e_0
69 | - pillow=9.4.0=py310h6a678d5_0
70 | - pip=23.0.1=py310h06a4308_0
71 | - pycparser=2.21=pyhd3eb1b0_0
72 | - pyopenssl=23.0.0=py310h06a4308_0
73 | - pysocks=1.7.1=py310h06a4308_0
74 | - python=3.10.9=h7a1cb2a_2
75 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0
76 | - pytorch-cuda=11.7=h778d358_5
77 | - pytorch-mutex=1.0=cuda
78 | - readline=8.2=h5eee18b_0
79 | - requests=2.29.0=py310h06a4308_0
80 | - setuptools=67.8.0=py310h06a4308_0
81 | - sqlite=3.41.2=h5eee18b_0
82 | - tbb=2021.8.0=hdb19cb5_0
83 | - tk=8.6.12=h1ccaba5_0
84 | - torchaudio=0.13.1=py310_cu117
85 | - torchvision=0.14.1=py310_cu117
86 | - typing_extensions=4.5.0=py310h06a4308_0
87 | - urllib3=1.26.15=py310h06a4308_0
88 | - wheel=0.38.4=py310h06a4308_0
89 | - xformers=0.0.20=py310_cu11.7.1_pyt1.13.1
90 | - xz=5.4.2=h5eee18b_0
91 | - zlib=1.2.13=h5eee18b_0
92 | - zstd=1.5.5=hc292b87_0
93 | - pip:
94 | - accelerate==0.19.0
95 | - aiohttp==3.8.4
96 | - aiosignal==1.3.1
97 | - anyio==3.7.0
98 | - argon2-cffi==21.3.0
99 | - argon2-cffi-bindings==21.2.0
100 | - arrow==1.2.3
101 | - asttokens==2.2.1
102 | - async-lru==2.0.2
103 | - async-timeout==4.0.2
104 | - attrs==23.1.0
105 | - babel==2.12.1
106 | - backcall==0.2.0
107 | - beautifulsoup4==4.12.2
108 | - bleach==6.0.0
109 | - brotli==1.0.9
110 | - cmake==3.26.3
111 | - comm==0.1.3
112 | - contourpy==1.0.7
113 | - cycler==0.11.0
114 | - debugpy==1.6.7
115 | - decorator==5.1.1
116 | - defusedxml==0.7.1
117 | - diffusers==0.15.0
118 | - exceptiongroup==1.1.1
119 | - executing==1.2.0
120 | - fastjsonschema==2.17.1
121 | - filelock==3.12.0
122 | - fonttools==4.39.4
123 | - fqdn==1.5.1
124 | - frozenlist==1.3.3
125 | - fsspec==2023.5.0
126 | - gevent==22.10.2
127 | - geventhttpclient==2.0.2
128 | - greenlet==2.0.2
129 | - grpcio==1.54.2
130 | - huggingface-hub==0.14.1
131 | - imageio==2.33.0
132 | - importlib-metadata==6.6.0
133 | - ipykernel==6.23.1
134 | - ipympl==0.9.3
135 | - ipython==8.13.2
136 | - ipython-genutils==0.2.0
137 | - ipywidgets==8.0.6
138 | - isoduration==20.11.0
139 | - jedi==0.18.2
140 | - jinja2==3.1.2
141 | - json5==0.9.14
142 | - jsonpointer==2.3
143 | - jsonschema==4.17.3
144 | - jupyter-client==8.2.0
145 | - jupyter-core==5.3.0
146 | - jupyter-events==0.6.3
147 | - jupyter-lsp==2.2.0
148 | - jupyter-server==2.6.0
149 | - jupyter-server-terminals==0.4.4
150 | - jupyterlab==4.0.0
151 | - jupyterlab-pygments==0.2.2
152 | - jupyterlab-server==2.22.1
153 | - jupyterlab-widgets==3.0.7
154 | - kiwisolver==1.4.4
155 | - lazy-loader==0.3
156 | - lit==16.0.5
157 | - markupsafe==2.1.2
158 | - matplotlib==3.7.1
159 | - matplotlib-inline==0.1.6
160 | - mistune==2.0.5
161 | - multidict==6.0.4
162 | - mypy-extensions==1.0.0
163 | - nbclient==0.8.0
164 | - nbconvert==7.4.0
165 | - nbformat==5.8.0
166 | - nest-asyncio==1.5.6
167 | - networkx==3.2.1
168 | - notebook-shim==0.2.3
169 | - opencv-python==4.8.1.78
170 | - overrides==7.3.1
171 | - packaging==23.1
172 | - pandas==2.1.4
173 | - pandocfilters==1.5.0
174 | - parso==0.8.3
175 | - pexpect==4.8.0
176 | - pickleshare==0.7.5
177 | - platformdirs==3.5.1
178 | - prometheus-client==0.17.0
179 | - prompt-toolkit==3.0.38
180 | - protobuf==3.20.3
181 | - psutil==5.9.5
182 | - ptyprocess==0.7.0
183 | - pure-eval==0.2.2
184 | - pygments==2.15.1
185 | - pyparsing==3.0.9
186 | - pyrsistent==0.19.3
187 | - python-dateutil==2.8.2
188 | - python-json-logger==2.0.7
189 | - python-rapidjson==1.10
190 | - pytz==2023.3.post1
191 | - pyyaml==6.0
192 | - pyzmq==25.1.0
193 | - regex==2023.5.5
194 | - rfc3339-validator==0.1.4
195 | - rfc3986-validator==0.1.1
196 | - scikit-image==0.22.0
197 | - scipy==1.11.4
198 | - send2trash==1.8.2
199 | - sh==1.14.3
200 | - six==1.16.0
201 | - sniffio==1.3.0
202 | - soupsieve==2.4.1
203 | - stack-data==0.6.2
204 | - terminado==0.17.1
205 | - tifffile==2023.9.26
206 | - tinycss2==1.2.1
207 | - tokenizers==0.13.3
208 | - tomli==2.0.1
209 | - tornado==6.3.2
210 | - tqdm==4.65.0
211 | - traitlets==5.9.0
212 | - transformers==4.29.2
213 | - triton==2.0.0.post1
214 | - tritonclient==2.33.0
215 | - typing-inspect==0.6.0
216 | - tzdata==2023.3
217 | - uri-template==1.2.0
218 | - wcwidth==0.2.6
219 | - webcolors==1.13
220 | - webencodings==0.5.1
221 | - websocket-client==1.5.2
222 | - widgetsnbextension==4.0.7
223 | - wrapt==1.15.0
224 | - yarl==1.9.2
225 | - zipp==3.15.0
226 | - zope-event==4.6
227 | - zope-interface==6.0
--------------------------------------------------------------------------------
/eval_davis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Some parts are taken from https://github.com/Liusifei/UVC
16 | """
17 | import os
18 | import copy
19 | import glob
20 | import queue
21 | from urllib.request import urlopen
22 | import argparse
23 | import numpy as np
24 | from tqdm import tqdm
25 |
26 | import gc
27 | import cv2
28 | import torch
29 | from torch.nn import functional as F
30 | from PIL import Image
31 | from src.models.dift_sd import SDFeaturizer
32 | from src.models.dift_adm import ADMFeaturizer
33 |
34 |
35 | @torch.no_grad()
36 | def eval_video_tracking_davis(args, model, scale_factor, frame_list, video_dir, first_seg, seg_ori, color_palette):
37 | """
38 | Evaluate tracking on a video given first frame & segmentation
39 | """
40 | video_folder = os.path.join(args.output_dir, video_dir.split('/')[-1])
41 | os.makedirs(video_folder, exist_ok=True)
42 |
43 | # The queue stores the n preceding frames
44 | que = queue.Queue(args.n_last_frames)
45 |
46 | # first frame
47 | frame1, ori_h, ori_w = read_frame(frame_list[0])
48 | # extract first frame feature
49 | frame1_feat = extract_feature(args, model, frame1).T # dim x h*w
50 |
51 | # saving first segmentation
52 | out_path = os.path.join(video_folder, "00000.png")
53 | imwrite_indexed(out_path, seg_ori, color_palette)
54 | mask_neighborhood = None
55 | for cnt in tqdm(range(1, len(frame_list))):
56 | frame_tar = read_frame(frame_list[cnt])[0]
57 |
58 | # we use the first segmentation and the n previous ones
59 | used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
60 | used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
61 |
62 | frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood)
63 |
64 | # pop out the oldest frame if necessary
65 | if que.qsize() == args.n_last_frames:
66 | que.get()
67 | # push current results into queue
68 | seg = copy.deepcopy(frame_tar_avg)
69 | que.put([feat_tar, seg])
70 |
71 | # upsampling & argmax
72 | frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=scale_factor, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0]
73 | frame_tar_avg = norm_mask(frame_tar_avg)
74 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
75 |
76 | # saving to disk
77 | frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8)
78 | frame_tar_seg = np.array(Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0))
79 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg", ".png")
80 | imwrite_indexed(os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette)
81 |
82 |
83 | def restrict_neighborhood(h, w):
84 | # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
85 | mask = torch.zeros(h, w, h, w)
86 | for i in range(h):
87 | for j in range(w):
88 | for p in range(2 * args.size_mask_neighborhood + 1):
89 | for q in range(2 * args.size_mask_neighborhood + 1):
90 | if i - args.size_mask_neighborhood + p < 0 or i - args.size_mask_neighborhood + p >= h:
91 | continue
92 | if j - args.size_mask_neighborhood + q < 0 or j - args.size_mask_neighborhood + q >= w:
93 | continue
94 | mask[i, j, i - args.size_mask_neighborhood + p, j - args.size_mask_neighborhood + q] = 1
95 |
96 | mask = mask.reshape(h * w, h * w)
97 | return mask.cuda(non_blocking=True)
98 |
99 |
100 | def norm_mask(mask):
101 | c, h, w = mask.size()
102 | for cnt in range(c):
103 | mask_cnt = mask[cnt,:,:]
104 | if(mask_cnt.max() > 0):
105 | mask_cnt = (mask_cnt - mask_cnt.min())
106 | mask_cnt = mask_cnt/mask_cnt.max()
107 | mask[cnt,:,:] = mask_cnt
108 | return mask
109 |
110 |
111 | def label_propagation(args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None):
112 | """
113 | propagate segs of frames in list_frames to frame_tar
114 | """
115 | gc.collect()
116 | torch.cuda.empty_cache()
117 |
118 | ## we only need to extract feature of the target frame
119 | feat_tar, h, w = extract_feature(args, model, frame_tar, return_h_w=True)
120 |
121 | gc.collect()
122 | torch.cuda.empty_cache()
123 |
124 | return_feat_tar = feat_tar.T # dim x h*w
125 |
126 | ncontext = len(list_frame_feats)
127 | feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
128 |
129 | feat_tar = F.normalize(feat_tar, dim=1, p=2)
130 | feat_sources = F.normalize(feat_sources, dim=1, p=2)
131 |
132 | feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
133 | aff = torch.exp(torch.bmm(feat_tar, feat_sources) / args.temperature) # nmb_context x h*w (tar: query) x h*w (source: keys)
134 |
135 | if args.size_mask_neighborhood > 0:
136 | if mask_neighborhood is None:
137 | mask_neighborhood = restrict_neighborhood(h, w)
138 | mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
139 | aff *= mask_neighborhood
140 |
141 | aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries)
142 | tk_val, _ = torch.topk(aff, dim=0, k=args.topk)
143 | tk_val_min, _ = torch.min(tk_val, dim=0)
144 | aff[aff < tk_val_min] = 0
145 |
146 | aff = aff / torch.sum(aff, keepdim=True, axis=0)
147 |
148 | gc.collect()
149 | torch.cuda.empty_cache()
150 |
151 | list_segs = [s.cuda() for s in list_segs]
152 | segs = torch.cat(list_segs)
153 | nmb_context, C, h, w = segs.shape
154 | segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w
155 | seg_tar = torch.mm(segs, aff)
156 | seg_tar = seg_tar.reshape(1, C, h, w)
157 |
158 | return seg_tar, return_feat_tar, mask_neighborhood
159 |
160 |
161 | def extract_feature(args, model, frame, return_h_w=False):
162 | """Extract one frame feature everytime."""
163 | with torch.no_grad():
164 | unet_ft = model.forward(frame,
165 | t=args.t,
166 | up_ft_index=args.up_ft_index,
167 | ensemble_size=args.ensemble_size).squeeze() # c, h, w
168 | dim, h, w = unet_ft.shape
169 | unet_ft = torch.permute(unet_ft, (1, 2, 0)) # h,w,c
170 | unet_ft = unet_ft.view(h * w, dim) # hw,c
171 | if return_h_w:
172 | return unet_ft, h, w
173 | return unet_ft
174 |
175 |
176 | def imwrite_indexed(filename, array, color_palette):
177 | """ Save indexed png for DAVIS."""
178 | if np.atleast_3d(array).shape[2] != 1:
179 | raise Exception("Saving indexed PNGs requires 2D array.")
180 |
181 | im = Image.fromarray(array)
182 | im.putpalette(color_palette.ravel())
183 | im.save(filename, format='PNG')
184 |
185 |
186 | def to_one_hot(y_tensor, n_dims=None):
187 | """
188 | Take integer y (tensor or variable) with n dims &
189 | convert it to 1-hot representation with n+1 dims.
190 | """
191 | if(n_dims is None):
192 | n_dims = int(y_tensor.max()+ 1)
193 | _,h,w = y_tensor.size()
194 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
195 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
196 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
197 | y_one_hot = y_one_hot.view(h,w,n_dims)
198 | return y_one_hot.permute(2, 0, 1).unsqueeze(0)
199 |
200 |
201 | def read_frame_list(video_dir):
202 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))]
203 | frame_list = sorted(frame_list)
204 | return frame_list
205 |
206 |
207 | def read_frame(frame_dir, scale_size=[480]):
208 | """
209 | read a single frame & preprocess
210 | """
211 | img = cv2.imread(frame_dir)
212 | ori_h, ori_w, _ = img.shape
213 | if len(scale_size) == 1:
214 | if(ori_h > ori_w):
215 | tw = scale_size[0]
216 | th = (tw * ori_h) / ori_w
217 | th = int((th // 32) * 32)
218 | else:
219 | th = scale_size[0]
220 | tw = (th * ori_w) / ori_h
221 | tw = int((tw // 32) * 32)
222 | else:
223 | th, tw = scale_size
224 | img = cv2.resize(img, (tw, th))
225 | img = img.astype(np.float32)
226 | img = img / 255.0
227 | img = img[:, :, ::-1]
228 | img = np.transpose(img.copy(), (2, 0, 1))
229 | img = torch.from_numpy(img).float()
230 | img = color_normalize(img)
231 | return img, ori_h, ori_w
232 |
233 |
234 | def read_seg(seg_dir, scale_factor, scale_size=[480]):
235 | seg = Image.open(seg_dir)
236 | _w, _h = seg.size # note PIL.Image.Image's size is (w, h)
237 | if len(scale_size) == 1:
238 | if(_w > _h):
239 | _th = scale_size[0]
240 | _tw = (_th * _w) / _h
241 | _tw = int((_tw // 32) * 32)
242 | else:
243 | _tw = scale_size[0]
244 | _th = (_tw * _h) / _w
245 | _th = int((_th // 32) * 32)
246 | else:
247 | _th = scale_size[1]
248 | _tw = scale_size[0]
249 | small_seg = np.array(seg.resize((_tw // scale_factor, _th // scale_factor), 0))
250 | small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0)
251 |
252 | return to_one_hot(small_seg), np.asarray(seg)
253 |
254 |
255 | def color_normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
256 | for t, m, s in zip(x, mean, std):
257 | t.sub_(m)
258 | t.div_(s)
259 | return x
260 |
261 |
262 | if __name__ == '__main__':
263 | parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017')
264 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use")
265 | parser.add_argument('--t', default=201, type=int, help='t for diffusion')
266 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map')
267 | parser.add_argument('--ensemble_size', default=4, type=int, help='ensemble size for getting an image ft map')
268 | parser.add_argument('--temperature', default=0.1, type=float, help='temperature for softmax')
269 |
270 | parser.add_argument('--output_dir', type=str, help='Path where to save segmentations')
271 | parser.add_argument('--data_path', type=str, help="path to davis dataset")
272 | parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames")
273 | parser.add_argument("--size_mask_neighborhood", default=12, type=int,
274 | help="We restrict the set of source nodes considered to a spatial neighborhood of the query node")
275 | parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors")
276 | args = parser.parse_args()
277 |
278 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
279 |
280 | color_palette = []
281 | for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"):
282 | color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")])
283 | color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3)
284 |
285 | video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines()
286 | n_last_frames = args.n_last_frames
287 |
288 | if args.dift_model == 'adm':
289 | index2factor = {0:32, 1:32, 2:16, 3:16, 4:16, 5:8, 6:8, 7:8, 8:4,
290 | 9:4, 10:4, 11:2, 12:2, 13:2, 14:1, 15:1, 16:1, 17:1}
291 | model = ADMFeaturizer()
292 | elif args.dift_model == 'sd':
293 | index2factor = {0:32, 1:16, 2:8, 3:8}
294 | model = SDFeaturizer()
295 |
296 | scale_factor = index2factor[args.up_ft_index]
297 | for i, video_name in enumerate(video_list):
298 | video_name = video_name.strip()
299 |
300 | if video_name == 'shooting':
301 | if args.n_last_frames > 10:
302 | args.n_last_frames = 10 # this can resolve the OOM issue
303 | else:
304 | args.n_last_frames = n_last_frames
305 |
306 | print(f'[{i}/{len(video_list)}] Begin to segment video {video_name}.')
307 | video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
308 | frame_list = read_frame_list(video_dir)
309 | seg_path = frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png")
310 | first_seg, seg_ori = read_seg(seg_path, scale_factor)
311 | eval_video_tracking_davis(args, model, scale_factor, frame_list, video_dir, first_seg, seg_ori, color_palette)
312 |
--------------------------------------------------------------------------------
/eval_homography.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import os
4 | import torch
5 | from tqdm import tqdm
6 | import cv2
7 | import torch.nn.functional as F
8 |
9 | def mnn_matcher(descriptors_a, descriptors_b, metric='cosine'):
10 | device = descriptors_a.device
11 | if metric == 'cosine':
12 | descriptors_a = F.normalize(descriptors_a)
13 | descriptors_b = F.normalize(descriptors_b)
14 | sim = descriptors_a @ descriptors_b.t()
15 | elif metric == 'l2':
16 | dist = torch.sum(descriptors_a**2, dim=1, keepdim=True) + torch.sum(descriptors_b**2, dim=1, keepdim=True).t() - \
17 | 2 * descriptors_a.mm(descriptors_b.t())
18 | sim = -dist
19 | nn12 = torch.max(sim, dim=1)[1]
20 | nn21 = torch.max(sim, dim=0)[1]
21 | ids1 = torch.arange(0, sim.shape[0], device=device)
22 | mask = (ids1 == nn21[nn12])
23 | matches = torch.stack([ids1[mask], nn12[mask]])
24 | return matches.t().data.cpu().numpy()
25 |
26 | def generate_read_function(save_path, method, extension='ppm', top_k=None):
27 | def read_function(seq_name, im_idx):
28 | aux = np.load(os.path.join(save_path, seq_name, '%d.%s.%s' % (im_idx, extension, method)))
29 | if top_k is None:
30 | return aux['keypoints'], aux['descriptors']
31 | else:
32 | if len(aux['scores']) != 0:
33 | ids = np.argsort(aux['scores'])[-top_k :]
34 | if len(aux['scores'].shape) == 2:
35 | scores = aux['scores'][0]
36 | elif len(aux['scores'].shape) == 1:
37 | scores = aux['scores']
38 | ids = np.argsort(scores)[-top_k :]
39 | return aux['keypoints'][ids, :], aux['descriptors'][ids, :]
40 | else:
41 | return aux['keypoints'][:, :2], aux['descriptors']
42 | return read_function
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser(description='HPatches Homography Evaluation Script')
47 | parser.add_argument('--hpatches_path', type=str, default='/scratch/dift_release/d2-net/hpatches_sequences/hpatches-sequences-release', help='path to hpatches dataset')
48 | parser.add_argument('--save_path', type=str, default='./hpatches_results', help='path to save features')
49 | parser.add_argument('--feat', choices=['dift_sd', 'dift_adm'], default='dift_sd', help="which feature to use")
50 | parser.add_argument('--metric', choices=['cosine', 'l2'], default='cosine', help="which distance metric to use")
51 | parser.add_argument('--mode', choices=['ransac', 'lmeds'], default='lmeds', help="which method to use when calculating homography")
52 | args = parser.parse_args()
53 |
54 | seq_names = sorted(os.listdir(args.hpatches_path))
55 | read_function = generate_read_function(args.save_path, args.feat)
56 | th = np.linspace(1, 5, 3)
57 |
58 | i_accuracy = []
59 | v_accuracy = []
60 |
61 | for seq_idx, seq_name in tqdm(enumerate(seq_names)):
62 | keypoints_a, descriptors_a = read_function(seq_name, 1)
63 | keypoints_a, unique_idx = np.unique(keypoints_a, return_index=True, axis=0)
64 | descriptors_a = descriptors_a[unique_idx]
65 |
66 | h, w = cv2.imread(os.path.join(args.hpatches_path, seq_name, '1.ppm')).shape[:2]
67 |
68 | for im_idx in range(2, 7):
69 | h2, w2 = cv2.imread(os.path.join(args.hpatches_path, seq_name, '{}.ppm'.format(im_idx))).shape[:2]
70 | keypoints_b, descriptors_b = read_function(seq_name, im_idx)
71 | keypoints_b, unique_idx = np.unique(keypoints_b, return_index=True, axis=0)
72 | descriptors_b = descriptors_b[unique_idx]
73 |
74 | matches = mnn_matcher(
75 | torch.from_numpy(descriptors_a).cuda(),
76 | torch.from_numpy(descriptors_b).cuda(),
77 | metric=args.metric
78 | )
79 |
80 | H_gt = np.loadtxt(os.path.join(args.hpatches_path, seq_name, "H_1_" + str(im_idx)))
81 | pts_a = keypoints_a[matches[:, 0]].reshape(-1, 1, 2).astype(np.float32)
82 | pts_b = keypoints_b[matches[:, 1]].reshape(-1, 1, 2).astype(np.float32)
83 |
84 | if args.mode == 'ransac':
85 | H, mask = cv2.findHomography(pts_a, pts_b, cv2.RANSAC, ransacReprojThreshold=3)
86 | elif args.mode == 'lmeds':
87 | H, mask = cv2.findHomography(pts_a, pts_b, cv2.LMEDS, ransacReprojThreshold=3)
88 |
89 | corners = np.array([[0, 0, 1],
90 | [0, h-1, 1],
91 | [w - 1, 0, 1],
92 | [w - 1, h - 1, 1]])
93 |
94 | real_warped_corners = np.dot(corners, np.transpose(H_gt))
95 | real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:]
96 | warped_corners = np.dot(corners, np.transpose(H))
97 | warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
98 |
99 | mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
100 | correctness = mean_dist <= th
101 |
102 | if seq_name[0] == 'i':
103 | i_accuracy.append(correctness)
104 | elif seq_name[0] == 'v':
105 | v_accuracy.append(correctness)
106 |
107 | i_accuracy = np.array(i_accuracy)
108 | v_accuracy = np.array(v_accuracy)
109 | i_mean_accuracy = np.mean(i_accuracy, axis=0)
110 | v_mean_accuracy = np.mean(v_accuracy, axis=0)
111 | overall_mean_accuracy = np.mean(np.concatenate((i_accuracy, v_accuracy), axis=0), axis=0)
112 | print('overall_acc: {}, i_acc: {}, v_acc: {}'.format(
113 | overall_mean_accuracy * 100, i_mean_accuracy * 100, v_mean_accuracy * 100))
--------------------------------------------------------------------------------
/eval_hpatches.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | import argparse
5 | import numpy as np
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from tqdm import tqdm
9 | import skimage.io as io
10 | import torch.nn.functional as F
11 | from src.models.dift_sd import SDFeaturizer4Eval
12 | from src.models.dift_adm import ADMFeaturizer4Eval
13 |
14 | class HPatchDataset(Dataset):
15 | def __init__(self, imdir, spdir):
16 | self.imfs = []
17 | for f in os.listdir(imdir):
18 | scene_dir = os.path.join(imdir, f)
19 | self.imfs.extend([os.path.join(scene_dir, '{}.ppm').format(ind) for ind in range(1, 7)])
20 | self.spdir = spdir
21 |
22 | def __getitem__(self, item):
23 | imf = self.imfs[item]
24 | im = io.imread(imf)
25 | name, idx = imf.split('/')[-2:]
26 | coord = np.loadtxt(os.path.join(self.spdir, f'{name}-{idx[0]}.kp')).astype(np.float32)
27 | out = {'coord': coord, 'imf': imf}
28 | return out
29 |
30 | def __len__(self):
31 | return len(self.imfs)
32 |
33 |
34 | def main(args):
35 | for arg in vars(args):
36 | value = getattr(args,arg)
37 | if value is not None:
38 | print('%s: %s' % (str(arg),str(value)))
39 |
40 | dataset = HPatchDataset(imdir=args.hpatches_path, spdir=args.kpts_path)
41 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
42 | if args.dift_model == 'sd':
43 | dift = SDFeaturizer4Eval()
44 | elif args.dift_model == 'adm':
45 | dift = ADMFeaturizer4Eval()
46 |
47 | with torch.no_grad():
48 | for data in tqdm(data_loader):
49 | img_path = data['imf'][0]
50 | img = Image.open(img_path)
51 | w, h = img.size
52 | coord = data['coord'].to('cuda')
53 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).to(coord.device).float()
54 | coord_norm = (coord - c) / c
55 |
56 | feat = dift.forward(img,
57 | img_size=args.img_size,
58 | t=args.t,
59 | up_ft_index=args.up_ft_index,
60 | ensemble_size=args.ensemble_size)
61 |
62 | feat = F.grid_sample(feat, coord_norm.unsqueeze(2)).squeeze(-1)
63 | feat = feat.transpose(1, 2)
64 |
65 | desc = feat.squeeze(0).detach().cpu().numpy()
66 | kpt = coord.cpu().numpy().squeeze(0)
67 |
68 | out_dir = os.path.join(args.save_path, os.path.basename(os.path.dirname(img_path)))
69 | os.makedirs(out_dir, exist_ok=True)
70 | with open(os.path.join(out_dir, f'{os.path.basename(img_path)}.dift_{args.dift_model}'), 'wb') as output_file:
71 | np.savez(
72 | output_file,
73 | keypoints=kpt,
74 | scores=[],
75 | descriptors=desc
76 | )
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = argparse.ArgumentParser(description='HPatches Feature Extraction Script')
81 | parser.add_argument('--hpatches_path', type=str, default='/scratch/dift_release/d2-net/hpatches_sequences/hpatches-sequences-release', help='path to hpatches dataset')
82 | parser.add_argument('--kpts_path', type=str, default='./superpoint-1k', help='path to 1k superpoint keypoints')
83 | parser.add_argument('--save_path', type=str, default='./hpatches_results', help='path to save features')
84 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use")
85 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768],
86 | help='''in the order of [width, height], resize input image
87 | to [w, h] before fed into diffusion model, if set to 0, will
88 | stick to the original input size. by default is 768x768.''')
89 | parser.add_argument('--t', default=261, type=int, help='t for diffusion')
90 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map')
91 | parser.add_argument('--ensemble_size', default=8, type=int, help='ensemble size for getting an image ft map')
92 | args = parser.parse_args()
93 | main(args)
--------------------------------------------------------------------------------
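Note: the keypoint-descriptor sampling in eval_hpatches.py above maps pixel coordinates into grid_sample's [-1, 1] range using the image center c = ((w-1)/2, (h-1)/2), so one feature vector is read per SuperPoint keypoint. The following is a minimal, self-contained sketch of that step on dummy tensors (illustrative only, not part of the evaluation code):

# Sketch: sample per-keypoint descriptors from a feature map with grid_sample,
# using the same (coord - c) / c normalization as eval_hpatches.py. Dummy data only.
import torch
import torch.nn.functional as F

feat = torch.randn(1, 256, 48, 64)               # 1, C, Hf, Wf feature map
w, h = 640, 480                                  # original image size in pixels
coord = torch.tensor([[[10.0, 20.0],             # 1, K, 2 keypoints as (x, y)
                       [320.0, 240.0],
                       [600.0, 400.0]]])

c = torch.tensor([(w - 1) / 2., (h - 1) / 2.])   # image center used for normalization
coord_norm = (coord - c) / c                     # pixel coords -> [-1, 1]

# grid_sample expects a 4-D grid: 1, K, 1, 2 -> output is 1, C, K, 1
desc = F.grid_sample(feat, coord_norm.unsqueeze(2)).squeeze(-1)  # 1, C, K
desc = desc.transpose(1, 2)                      # 1, K, C: one descriptor per keypoint
print(desc.shape)                                # torch.Size([1, 3, 256])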
/eval_spair.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch.nn import functional as F
4 | from tqdm import tqdm
5 | import numpy as np
6 | from src.models.dift_sd import SDFeaturizer4Eval
7 | from src.models.dift_adm import ADMFeaturizer4Eval
8 | import os
9 | import json
10 | from PIL import Image
11 | import torch.nn as nn
12 |
13 |
14 | def main(args):
15 | for arg in vars(args):
16 | value = getattr(args,arg)
17 | if value is not None:
18 | print('%s: %s' % (str(arg),str(value)))
19 |
20 | torch.cuda.set_device(0)
21 |
22 | dataset_path = args.dataset_path
23 | test_path = 'PairAnnotation/test'
24 | json_list = os.listdir(os.path.join(dataset_path, test_path))
25 | all_cats = os.listdir(os.path.join(dataset_path, 'JPEGImages'))
26 | cat2json = {}
27 |
28 | for cat in all_cats:
29 | cat_list = []
30 | for i in json_list:
31 | if cat in i:
32 | cat_list.append(i)
33 | cat2json[cat] = cat_list
34 |
35 | # get test image path for all cats
36 | cat2img = {}
37 | for cat in all_cats:
38 | cat2img[cat] = []
39 | cat_list = cat2json[cat]
40 | for json_path in cat_list:
41 | with open(os.path.join(dataset_path, test_path, json_path)) as temp_f:
42 | data = json.load(temp_f)
43 | temp_f.close()
44 | src_imname = data['src_imname']
45 | trg_imname = data['trg_imname']
46 | if src_imname not in cat2img[cat]:
47 | cat2img[cat].append(src_imname)
48 | if trg_imname not in cat2img[cat]:
49 | cat2img[cat].append(trg_imname)
50 |
51 | if args.dift_model == 'sd':
52 | dift = SDFeaturizer4Eval(cat_list=all_cats)
53 | elif args.dift_model == 'adm':
54 | dift = ADMFeaturizer4Eval()
55 |
56 | print("saving all test images' features...")
57 | os.makedirs(args.save_path, exist_ok=True)
58 | for cat in tqdm(all_cats):
59 | output_dict = {}
60 | image_list = cat2img[cat]
61 | for image_path in image_list:
62 | img = Image.open(os.path.join(dataset_path, 'JPEGImages', cat, image_path))
63 | output_dict[image_path] = dift.forward(img,
64 | category=cat,
65 | img_size=args.img_size,
66 | t=args.t,
67 | up_ft_index=args.up_ft_index,
68 | ensemble_size=args.ensemble_size)
69 | torch.save(output_dict, os.path.join(args.save_path, f'{cat}.pth'))
70 |
71 | total_pck = []
72 | all_correct = 0
73 | all_total = 0
74 |
75 | for cat in all_cats:
76 | cat_list = cat2json[cat]
77 | output_dict = torch.load(os.path.join(args.save_path, f'{cat}.pth'))
78 |
79 | cat_pck = []
80 | cat_correct = 0
81 | cat_total = 0
82 |
83 | for json_path in tqdm(cat_list):
84 |
85 | with open(os.path.join(dataset_path, test_path, json_path)) as temp_f:
86 | data = json.load(temp_f)
87 |
88 | src_img_size = data['src_imsize'][:2][::-1]
89 | trg_img_size = data['trg_imsize'][:2][::-1]
90 |
91 | src_ft = output_dict[data['src_imname']]
92 | trg_ft = output_dict[data['trg_imname']]
93 |
94 | src_ft = nn.Upsample(size=src_img_size, mode='bilinear')(src_ft)
95 | trg_ft = nn.Upsample(size=trg_img_size, mode='bilinear')(trg_ft)
96 | h = trg_ft.shape[-2]
97 | w = trg_ft.shape[-1]
98 |
99 | trg_bndbox = data['trg_bndbox']
100 | threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0])
101 |
102 | total = 0
103 | correct = 0
104 |
105 | for idx in range(len(data['src_kps'])):
106 | total += 1
107 | cat_total += 1
108 | all_total += 1
109 | src_point = data['src_kps'][idx]
110 | trg_point = data['trg_kps'][idx]
111 |
112 | num_channel = src_ft.size(1)
113 | src_vec = src_ft[0, :, src_point[1], src_point[0]].view(1, num_channel) # 1, C
114 | trg_vec = trg_ft.view(num_channel, -1).transpose(0, 1) # HW, C
115 | src_vec = F.normalize(src_vec).transpose(0, 1) # c, 1
116 | trg_vec = F.normalize(trg_vec) # HW, c
117 | cos_map = torch.mm(trg_vec, src_vec).view(h, w).cpu().numpy() # H, W
118 |
119 | max_yx = np.unravel_index(cos_map.argmax(), cos_map.shape)
120 |
121 | dist = ((max_yx[1] - trg_point[0]) ** 2 + (max_yx[0] - trg_point[1]) ** 2) ** 0.5
122 | if (dist / threshold) <= 0.1:
123 | correct += 1
124 | cat_correct += 1
125 | all_correct += 1
126 |
127 | cat_pck.append(correct / total)
128 | total_pck.extend(cat_pck)
129 |
130 | print(f'{cat} per image PCK@0.1: {np.mean(cat_pck) * 100:.2f}')
131 | print(f'{cat} per point PCK@0.1: {cat_correct / cat_total * 100:.2f}')
132 | print(f'All per image PCK@0.1: {np.mean(total_pck) * 100:.2f}')
133 | print(f'All per point PCK@0.1: {all_correct / all_total * 100:.2f}')
134 |
135 |
136 | if __name__ == "__main__":
137 | parser = argparse.ArgumentParser(description='SPair-71k Evaluation Script')
138 | parser.add_argument('--dataset_path', type=str, default='./SPair-71k/', help='path to spair dataset')
139 | parser.add_argument('--save_path', type=str, default='/scratch/lt453/spair_ft/', help='path to save features')
140 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use")
141 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768],
142 |                         help='''in the order of [width, height]; resize the input image
143 |                         to [w, h] before it is fed into the diffusion model. If set to 0,
144 |                         the original input size is kept. Defaults to 768x768.''')
145 | parser.add_argument('--t', default=261, type=int, help='t for diffusion')
146 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map')
147 | parser.add_argument('--ensemble_size', default=8, type=int, help='ensemble size for getting an image ft map')
148 | args = parser.parse_args()
149 | main(args)
--------------------------------------------------------------------------------
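Note: the PCK@0.1 metric computed in eval_spair.py above declares a keypoint correct when the argmax of the cosine-similarity map over the upsampled target features falls within 0.1 * max(bounding-box width, bounding-box height) of the ground-truth target point. A small self-contained sketch of that check on dummy data (illustrative only):

# Sketch: the per-keypoint PCK@0.1 check used in eval_spair.py, on dummy tensors.
import numpy as np
import torch
import torch.nn.functional as F

num_channel, h, w = 128, 240, 320
src_ft = torch.randn(1, num_channel, h, w)       # upsampled source feature map
trg_ft = torch.randn(1, num_channel, h, w)       # upsampled target feature map
src_point = (100, 80)                            # (x, y) source keypoint
trg_point = (150, 90)                            # (x, y) ground-truth target keypoint
trg_bndbox = (40, 30, 280, 200)                  # (x_min, y_min, x_max, y_max)

# cosine similarity between the source keypoint feature and every target location
src_vec = F.normalize(src_ft[0, :, src_point[1], src_point[0]].view(1, num_channel))  # 1, C
trg_vec = F.normalize(trg_ft.view(num_channel, -1).transpose(0, 1))                   # HW, C
cos_map = torch.mm(trg_vec, src_vec.t()).view(h, w).numpy()                           # H, W

max_yx = np.unravel_index(cos_map.argmax(), cos_map.shape)       # predicted (y, x)
dist = ((max_yx[1] - trg_point[0]) ** 2 + (max_yx[0] - trg_point[1]) ** 2) ** 0.5
threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0])
print('correct match:', (dist / threshold) <= 0.1)               # PCK@0.1 criterion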
/extract_dift.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from PIL import Image
4 | from torchvision.transforms import PILToTensor
5 | from src.models.dift_sd import SDFeaturizer
6 |
7 | def main(args):
8 | dift = SDFeaturizer(args.model_id)
9 | img = Image.open(args.input_path).convert('RGB')
10 | if args.img_size[0] > 0:
11 | img = img.resize(args.img_size)
12 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2
13 | ft = dift.forward(img_tensor,
14 | prompt=args.prompt,
15 | t=args.t,
16 | up_ft_index=args.up_ft_index,
17 | ensemble_size=args.ensemble_size)
18 |     torch.save(ft.squeeze(0).cpu(), args.output_path) # save feature in the shape of [c, h, w]
19 |
20 |
21 | if __name__ == '__main__':
22 |
23 | parser = argparse.ArgumentParser(
24 |         description='''extract dift from the input image and save it as a torch tensor,
25 | in the shape of [c, h, w].''')
26 |
27 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768],
28 |                         help='''in the order of [width, height]; resize the input image
29 |                         to [w, h] before it is fed into the diffusion model. If set to 0,
30 |                         the original input size is kept. Defaults to 768x768.''')
31 | parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1', type=str,
32 | help='model_id of the diffusion model in huggingface')
33 | parser.add_argument('--t', default=261, type=int,
34 | help='time step for diffusion, choose from range [0, 1000]')
35 | parser.add_argument('--up_ft_index', default=1, type=int, choices=[0, 1, 2 ,3],
36 | help='which upsampling block of U-Net to extract the feature map')
37 | parser.add_argument('--prompt', default='', type=str,
38 | help='prompt used in the stable diffusion')
39 | parser.add_argument('--ensemble_size', default=8, type=int,
40 | help='number of repeated images in each batch used to get features')
41 | parser.add_argument('--input_path', type=str,
42 | help='path to the input image file')
43 | parser.add_argument('--output_path', type=str, default='dift.pt',
44 | help='path to save the output features as torch tensor')
45 | args = parser.parse_args()
46 | main(args)
--------------------------------------------------------------------------------
/extract_dift.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python extract_dift.py \
3 | --input_path ./assets/cat.png \
4 | --output_path dift_cat.pt \
5 | --img_size 0 \
6 | --t 261 \
7 | --up_ft_index 1 \
8 | --prompt 'a photo of a cat' \
9 | --ensemble_size 8
--------------------------------------------------------------------------------
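Note: extract_dift.sh above writes a single DIFT feature map for ./assets/cat.png to dift_cat.pt in the shape [c, h, w]. The sketch below shows one way such saved tensors could be consumed, e.g. to find the target location whose feature best matches a chosen source pixel via cosine similarity; the second file name is hypothetical, so run extract_dift.py once per image first:

# Sketch: consume the [c, h, w] tensors written by extract_dift.py (file names are examples).
import torch
import torch.nn.functional as F

src_ft = torch.load('dift_cat.pt')               # c, h, w source feature map
trg_ft = torch.load('dift_target.pt')            # c, h, w target feature map (hypothetical file)
c, h, w = trg_ft.shape

y, x = 40, 60                                    # a query location on the source feature grid
src_vec = F.normalize(src_ft[:, y, x].view(1, c))             # 1, c
trg_vec = F.normalize(trg_ft.view(c, -1).transpose(0, 1))     # hw, c

cos = torch.mm(trg_vec, src_vec.t()).view(h, w)  # cosine-similarity map over the target
best = torch.argmax(cos).item()
print('best match (y, x):', divmod(best, w))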
/sd_featurizer_spair.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
8 | from diffusers import DDIMScheduler
9 | from diffusers.models.modeling_utils import ModelMixin
10 | import gc
11 | from PIL import Image
12 | from torchvision.transforms import PILToTensor
13 | import os
14 | from lavis.models import load_model_and_preprocess
15 | import json
16 | from PIL import Image, ImageDraw
17 |
18 |
19 | class MyUNet2DConditionModel(UNet2DConditionModel):
20 | def forward(
21 | self,
22 | sample: torch.FloatTensor,
23 | timestep: Union[torch.Tensor, float, int],
24 | up_ft_indices,
25 | encoder_hidden_states: torch.Tensor,
26 | class_labels: Optional[torch.Tensor] = None,
27 | timestep_cond: Optional[torch.Tensor] = None,
28 | attention_mask: Optional[torch.Tensor] = None,
29 | cross_attention_kwargs: Optional[Dict[str, Any]] = None):
30 | r"""
31 | Args:
32 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
33 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
34 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
35 | cross_attention_kwargs (`dict`, *optional*):
36 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
37 | `self.processor` in
38 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
39 | """
40 |         # By default samples have to be at least a multiple of the overall upsampling factor.
41 |         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
42 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
43 | # on the fly if necessary.
44 | default_overall_up_factor = 2**self.num_upsamplers
45 |
46 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
47 | forward_upsample_size = False
48 | upsample_size = None
49 |
50 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
51 | # logger.info("Forward upsample size to force interpolation output size.")
52 | forward_upsample_size = True
53 |
54 | # prepare attention_mask
55 | if attention_mask is not None:
56 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
57 | attention_mask = attention_mask.unsqueeze(1)
58 |
59 | # 0. center input if necessary
60 | if self.config.center_input_sample:
61 | sample = 2 * sample - 1.0
62 |
63 | # 1. time
64 | timesteps = timestep
65 | if not torch.is_tensor(timesteps):
66 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
67 | # This would be a good case for the `match` statement (Python 3.10+)
68 | is_mps = sample.device.type == "mps"
69 | if isinstance(timestep, float):
70 | dtype = torch.float32 if is_mps else torch.float64
71 | else:
72 | dtype = torch.int32 if is_mps else torch.int64
73 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
74 | elif len(timesteps.shape) == 0:
75 | timesteps = timesteps[None].to(sample.device)
76 |
77 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
78 | timesteps = timesteps.expand(sample.shape[0])
79 |
80 | t_emb = self.time_proj(timesteps)
81 |
82 | # timesteps does not contain any weights and will always return f32 tensors
83 | # but time_embedding might actually be running in fp16. so we need to cast here.
84 | # there might be better ways to encapsulate this.
85 | t_emb = t_emb.to(dtype=self.dtype)
86 |
87 | emb = self.time_embedding(t_emb, timestep_cond)
88 |
89 | if self.class_embedding is not None:
90 | if class_labels is None:
91 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
92 |
93 | if self.config.class_embed_type == "timestep":
94 | class_labels = self.time_proj(class_labels)
95 |
96 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
97 | emb = emb + class_emb
98 |
99 | # 2. pre-process
100 | sample = self.conv_in(sample)
101 |
102 | # 3. down
103 | down_block_res_samples = (sample,)
104 | for downsample_block in self.down_blocks:
105 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
106 | sample, res_samples = downsample_block(
107 | hidden_states=sample,
108 | temb=emb,
109 | encoder_hidden_states=encoder_hidden_states,
110 | attention_mask=attention_mask,
111 | cross_attention_kwargs=cross_attention_kwargs,
112 | )
113 | else:
114 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
115 |
116 | down_block_res_samples += res_samples
117 |
118 | # 4. mid
119 | if self.mid_block is not None:
120 | sample = self.mid_block(
121 | sample,
122 | emb,
123 | encoder_hidden_states=encoder_hidden_states,
124 | attention_mask=attention_mask,
125 | cross_attention_kwargs=cross_attention_kwargs,
126 | )
127 |
128 | # 5. up
129 | up_ft = {}
130 | for i, upsample_block in enumerate(self.up_blocks):
131 |
132 | if i > np.max(up_ft_indices):
133 | break
134 |
135 | is_final_block = i == len(self.up_blocks) - 1
136 |
137 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
138 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
139 |
140 | # if we have not reached the final block and need to forward the
141 | # upsample size, we do it here
142 | if not is_final_block and forward_upsample_size:
143 | upsample_size = down_block_res_samples[-1].shape[2:]
144 |
145 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
146 | sample = upsample_block(
147 | hidden_states=sample,
148 | temb=emb,
149 | res_hidden_states_tuple=res_samples,
150 | encoder_hidden_states=encoder_hidden_states,
151 | cross_attention_kwargs=cross_attention_kwargs,
152 | upsample_size=upsample_size,
153 | attention_mask=attention_mask,
154 | )
155 | else:
156 | sample = upsample_block(
157 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
158 | )
159 |
160 | if i in up_ft_indices:
161 | up_ft[i] = sample.detach()
162 |
163 | output = {}
164 | output['up_ft'] = up_ft
165 | return output
166 |
167 | class OneStepSDPipeline(StableDiffusionPipeline):
168 | @torch.no_grad()
169 | def __call__(
170 | self,
171 | img_tensor,
172 | t,
173 | up_ft_indices,
174 | negative_prompt: Optional[Union[str, List[str]]] = None,
175 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
176 | prompt_embeds: Optional[torch.FloatTensor] = None,
177 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
178 | callback_steps: int = 1,
179 | cross_attention_kwargs: Optional[Dict[str, Any]] = None
180 | ):
181 |
182 | device = self._execution_device
183 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
184 | t = torch.tensor(t, dtype=torch.long, device=device)
185 | noise = torch.randn_like(latents).to(device)
186 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
187 | unet_output = self.unet(latents_noisy,
188 | t,
189 | up_ft_indices,
190 | encoder_hidden_states=prompt_embeds,
191 | cross_attention_kwargs=cross_attention_kwargs)
192 | return unet_output
193 |
194 |
195 | class SDFeaturizer:
196 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt=''):
197 | unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
198 | onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
199 | onestep_pipe.vae.decoder = None
200 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
201 | with torch.no_grad():
202 | cat2prompt = {}
203 | all_cats = os.listdir('/home/lt453/SPair-71k/JPEGImages')
204 | for cat in all_cats:
205 | prompt = f"a photo of a {cat}"
206 | prompt_embeds = onestep_pipe._encode_prompt(
207 | prompt=prompt,
208 | device='cpu',
209 | num_images_per_prompt=1,
210 | do_classifier_free_guidance=False) # [1, 77, dim]
211 | cat2prompt[cat] = prompt_embeds
212 | null_prompt_embeds = onestep_pipe._encode_prompt(
213 | prompt=null_prompt,
214 | device='cpu',
215 | num_images_per_prompt=1,
216 | do_classifier_free_guidance=False) # [1, 77, dim]
217 | onestep_pipe.tokenizer = None
218 | onestep_pipe.text_encoder = None
219 | gc.collect()
220 | onestep_pipe = onestep_pipe.to("cuda")
221 | self.cat2prompt = cat2prompt
222 | self.null_prompt_embeds = null_prompt_embeds
223 | onestep_pipe.enable_attention_slicing()
224 | onestep_pipe.enable_xformers_memory_efficient_attention()
225 | self.pipe = onestep_pipe
226 |
227 | @torch.no_grad()
228 | def forward(self,
229 | img,
230 | category,
231 | img_size=[768, 768],
232 | t=261,
233 | up_ft_index=1,
234 | ensemble_size=8):
235 | if img_size is not None:
236 | img = img.resize(img_size)
237 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2
238 | img_tensor = img_tensor.unsqueeze(0).repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
239 | if category in self.cat2prompt:
240 | prompt_embeds = self.cat2prompt[category]
241 | else:
242 | prompt_embeds = self.null_prompt_embeds
243 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1).cuda()
244 | unet_ft_all = self.pipe(
245 | img_tensor=img_tensor,
246 | t=t,
247 | up_ft_indices=[up_ft_index],
248 | prompt_embeds=prompt_embeds)
249 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
250 | unet_ft = unet_ft.mean(0, keepdim=True) # n, c,h,w
251 | return unet_ft
--------------------------------------------------------------------------------
/setup_env.sh:
--------------------------------------------------------------------------------
1 | conda create -n dift python=3.10
2 | conda activate dift
3 |
4 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
5 | conda install xformers -c xformers
6 | pip install jupyterlab
7 | pip install diffusers[torch]==0.15.0
8 | pip install -U matplotlib
9 | pip install transformers
10 | pip install ipympl
11 | pip install triton
--------------------------------------------------------------------------------
/src/models/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import gc
6 | from PIL import Image, ImageDraw
7 | from torchvision.transforms import PILToTensor
8 | import os
9 | import open_clip
10 | from torchvision import transforms
11 | import copy
12 | import json
13 | from tqdm.notebook import tqdm
14 | import time
15 | import datetime
16 | import math
17 |
18 |
19 | def interpolate_pos_encoding(clip_model, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
20 | """
21 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
22 | resolution images.
23 | Source:
24 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
25 | """
26 |
27 | num_patches = embeddings.shape[1] - 1
28 | pos_embedding = clip_model.positional_embedding.unsqueeze(0)
29 | num_positions = pos_embedding.shape[1] - 1
30 | if num_patches == num_positions and height == width:
31 | return clip_model.positional_embedding
32 | class_pos_embed = pos_embedding[:, 0]
33 | patch_pos_embed = pos_embedding[:, 1:]
34 | dim = embeddings.shape[-1]
35 | h0 = height // clip_model.patch_size[0]
36 | w0 = width // clip_model.patch_size[1]
37 | # we add a small number to avoid floating point error in the interpolation
38 | # see discussion at https://github.com/facebookresearch/dino/issues/8
39 | h0, w0 = h0 + 0.1, w0 + 0.1
40 | patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
41 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
42 | patch_pos_embed = nn.functional.interpolate(
43 | patch_pos_embed,
44 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
45 | mode="bicubic",
46 | align_corners=False,
47 | )
48 | assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
49 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
50 | output = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
51 |
52 | return output
53 |
54 |
55 | class CLIPFeaturizer:
56 | def __init__(self):
57 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
58 | visual_model = clip_model.visual
59 | visual_model.output_tokens = True
60 | self.clip_model = visual_model.eval().cuda()
61 |
62 |
63 | @torch.no_grad()
64 | def forward(self,
65 | x, # single image, [1,c,h,w]
66 | block_index):
67 | batch_size = 1
68 | clip_model = self.clip_model
69 | if clip_model.input_patchnorm:
70 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
71 | x = x.reshape(x.shape[0], x.shape[1], clip_model.grid_size[0], clip_model.patch_size[0], clip_model.grid_size[1], clip_model.patch_size[1])
72 | x = x.permute(0, 2, 4, 1, 3, 5)
73 | x = x.reshape(x.shape[0], clip_model.grid_size[0] * clip_model.grid_size[1], -1)
74 | x = clip_model.patchnorm_pre_ln(x)
75 | x = clip_model.conv1(x)
76 | else:
77 | x = clip_model.conv1(x) # shape = [*, width, grid, grid]
78 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
79 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
80 | # class embeddings and positional embeddings
81 | x = torch.cat(
82 | [clip_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
83 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
84 | if(x.shape[1] > clip_model.positional_embedding.shape[0]):
85 | dim = int(math.sqrt(x.shape[1]) * clip_model.patch_size[0])
86 | x = x + interpolate_pos_encoding(clip_model, x, dim, dim).to(x.dtype)
87 | else:
88 | x = x + clip_model.positional_embedding.to(x.dtype)
89 |
90 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
91 | x = clip_model.patch_dropout(x)
92 | x = clip_model.ln_pre(x)
93 | x = x.permute(1, 0, 2) # NLD -> LND
94 |
95 | num_channel = x.size(2)
96 | ft_size = int((x.shape[0]-1) ** 0.5)
97 |
98 | for i, r in enumerate(clip_model.transformer.resblocks):
99 | x = r(x)
100 |
101 | if i == block_index:
102 | tokens = x.permute(1, 0, 2) # LND -> NLD
103 | tokens = tokens[:, 1:]
104 | tokens = tokens.transpose(1, 2).contiguous().view(batch_size, num_channel, ft_size, ft_size) # NCHW
105 |
106 | return tokens
--------------------------------------------------------------------------------
/src/models/dift_adm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import torch
4 | from torchvision import transforms
5 | main_path = Path(__file__).resolve().parent.parent.parent
6 | print(f'main path: {main_path}')
7 |
8 | import sys
9 | sys.path.append(os.path.join(main_path, 'guided-diffusion'))
10 | from guided_diffusion.script_util import create_model_and_diffusion
11 | from guided_diffusion.nn import timestep_embedding
12 |
13 |
14 | class ADMFeaturizer:
15 | def __init__(self):
16 | model, diffusion = create_model_and_diffusion(
17 | image_size=256,
18 | class_cond=False,
19 | learn_sigma=True,
20 | num_channels=256,
21 | num_res_blocks=2,
22 | channel_mult="",
23 | num_heads=4,
24 | num_head_channels=64,
25 | num_heads_upsample=-1,
26 | attention_resolutions="32,16,8",
27 | dropout=0.0,
28 | diffusion_steps=1000,
29 | noise_schedule='linear',
30 | timestep_respacing='',
31 | use_kl=False,
32 | predict_xstart=False,
33 | rescale_timesteps=False,
34 | rescale_learned_sigmas=False,
35 | use_checkpoint=False,
36 | use_scale_shift_norm=True,
37 | resblock_updown=True,
38 | use_fp16=False,
39 | use_new_attention_order=False,
40 | )
41 | model_path = os.path.join(main_path, 'guided-diffusion/models/256x256_diffusion_uncond.pt')
42 | model.load_state_dict(torch.load(model_path, map_location="cpu"))
43 | self.model = model.eval().cuda()
44 | self.diffusion = diffusion
45 |
46 | self.adm_transforms = transforms.Compose([
47 | transforms.ToTensor(),
48 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
49 | ])
50 |
51 | @torch.no_grad()
52 | def forward(self, img_tensor,
53 | t=101,
54 | up_ft_index=4,
55 | ensemble_size=8):
56 | model = self.model
57 | diffusion = self.diffusion
58 |
59 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
60 | t = torch.ones((img_tensor.shape[0],), device='cuda', dtype=torch.int64) * t
61 | x_t = diffusion.q_sample(img_tensor, t, noise=None)
62 |
63 | # get layer-wise features
64 | hs = []
65 | emb = model.time_embed(timestep_embedding(t, model.model_channels))
66 | h = x_t.type(model.dtype)
67 | for module in model.input_blocks:
68 | h = module(h, emb)
69 | hs.append(h)
70 | h = model.middle_block(h, emb)
71 | for i, module in enumerate(model.output_blocks):
72 | h = torch.cat([h, hs.pop()], dim=1)
73 | h = module(h, emb)
74 |
75 | if i == up_ft_index:
76 | ft = h.mean(0, keepdim=True).detach()
77 | return ft
78 |
79 |
80 | class ADMFeaturizer4Eval(ADMFeaturizer):
81 |
82 | @torch.no_grad()
83 | def forward(self, img,
84 | img_size=[512, 512],
85 | t=101,
86 | up_ft_index=4,
87 | ensemble_size=8,
88 | **kwargs):
89 |
90 | img_tensor = self.adm_transforms(img.resize(img_size))
91 | ft = super().forward(img_tensor,
92 | t=t,
93 | up_ft_index=up_ft_index,
94 | ensemble_size=ensemble_size)
95 | return ft
--------------------------------------------------------------------------------
/src/models/dift_sd.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from typing import Any, Callable, Dict, List, Optional, Union
7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel
8 | from diffusers import DDIMScheduler
9 | import gc
10 | import os
11 | from PIL import Image
12 | from torchvision.transforms import PILToTensor
13 |
14 | class MyUNet2DConditionModel(UNet2DConditionModel):
15 | def forward(
16 | self,
17 | sample: torch.FloatTensor,
18 | timestep: Union[torch.Tensor, float, int],
19 | up_ft_indices,
20 | encoder_hidden_states: torch.Tensor,
21 | class_labels: Optional[torch.Tensor] = None,
22 | timestep_cond: Optional[torch.Tensor] = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | cross_attention_kwargs: Optional[Dict[str, Any]] = None):
25 | r"""
26 | Args:
27 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
28 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
29 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
30 | cross_attention_kwargs (`dict`, *optional*):
31 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
32 | `self.processor` in
33 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
34 | """
35 |         # By default samples have to be at least a multiple of the overall upsampling factor.
36 |         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
37 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
38 | # on the fly if necessary.
39 | default_overall_up_factor = 2**self.num_upsamplers
40 |
41 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
42 | forward_upsample_size = False
43 | upsample_size = None
44 |
45 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
46 | # logger.info("Forward upsample size to force interpolation output size.")
47 | forward_upsample_size = True
48 |
49 | # prepare attention_mask
50 | if attention_mask is not None:
51 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
52 | attention_mask = attention_mask.unsqueeze(1)
53 |
54 | # 0. center input if necessary
55 | if self.config.center_input_sample:
56 | sample = 2 * sample - 1.0
57 |
58 | # 1. time
59 | timesteps = timestep
60 | if not torch.is_tensor(timesteps):
61 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
62 | # This would be a good case for the `match` statement (Python 3.10+)
63 | is_mps = sample.device.type == "mps"
64 | if isinstance(timestep, float):
65 | dtype = torch.float32 if is_mps else torch.float64
66 | else:
67 | dtype = torch.int32 if is_mps else torch.int64
68 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
69 | elif len(timesteps.shape) == 0:
70 | timesteps = timesteps[None].to(sample.device)
71 |
72 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
73 | timesteps = timesteps.expand(sample.shape[0])
74 |
75 | t_emb = self.time_proj(timesteps)
76 |
77 | # timesteps does not contain any weights and will always return f32 tensors
78 | # but time_embedding might actually be running in fp16. so we need to cast here.
79 | # there might be better ways to encapsulate this.
80 | t_emb = t_emb.to(dtype=self.dtype)
81 |
82 | emb = self.time_embedding(t_emb, timestep_cond)
83 |
84 | if self.class_embedding is not None:
85 | if class_labels is None:
86 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
87 |
88 | if self.config.class_embed_type == "timestep":
89 | class_labels = self.time_proj(class_labels)
90 |
91 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
92 | emb = emb + class_emb
93 |
94 | # 2. pre-process
95 | sample = self.conv_in(sample)
96 |
97 | # 3. down
98 | down_block_res_samples = (sample,)
99 | for downsample_block in self.down_blocks:
100 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
101 | sample, res_samples = downsample_block(
102 | hidden_states=sample,
103 | temb=emb,
104 | encoder_hidden_states=encoder_hidden_states,
105 | attention_mask=attention_mask,
106 | cross_attention_kwargs=cross_attention_kwargs,
107 | )
108 | else:
109 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
110 |
111 | down_block_res_samples += res_samples
112 |
113 | # 4. mid
114 | if self.mid_block is not None:
115 | sample = self.mid_block(
116 | sample,
117 | emb,
118 | encoder_hidden_states=encoder_hidden_states,
119 | attention_mask=attention_mask,
120 | cross_attention_kwargs=cross_attention_kwargs,
121 | )
122 |
123 | # 5. up
124 | up_ft = {}
125 | for i, upsample_block in enumerate(self.up_blocks):
126 |
127 | if i > np.max(up_ft_indices):
128 | break
129 |
130 | is_final_block = i == len(self.up_blocks) - 1
131 |
132 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
133 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
134 |
135 | # if we have not reached the final block and need to forward the
136 | # upsample size, we do it here
137 | if not is_final_block and forward_upsample_size:
138 | upsample_size = down_block_res_samples[-1].shape[2:]
139 |
140 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
141 | sample = upsample_block(
142 | hidden_states=sample,
143 | temb=emb,
144 | res_hidden_states_tuple=res_samples,
145 | encoder_hidden_states=encoder_hidden_states,
146 | cross_attention_kwargs=cross_attention_kwargs,
147 | upsample_size=upsample_size,
148 | attention_mask=attention_mask,
149 | )
150 | else:
151 | sample = upsample_block(
152 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
153 | )
154 |
155 | if i in up_ft_indices:
156 | up_ft[i] = sample.detach()
157 |
158 | output = {}
159 | output['up_ft'] = up_ft
160 | return output
161 |
162 | class OneStepSDPipeline(StableDiffusionPipeline):
163 | @torch.no_grad()
164 | def __call__(
165 | self,
166 | img_tensor,
167 | t,
168 | up_ft_indices,
169 | negative_prompt: Optional[Union[str, List[str]]] = None,
170 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
171 | prompt_embeds: Optional[torch.FloatTensor] = None,
172 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
173 | callback_steps: int = 1,
174 | cross_attention_kwargs: Optional[Dict[str, Any]] = None
175 | ):
176 |
177 | device = self._execution_device
178 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
179 | t = torch.tensor(t, dtype=torch.long, device=device)
180 | noise = torch.randn_like(latents).to(device)
181 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
182 | unet_output = self.unet(latents_noisy,
183 | t,
184 | up_ft_indices,
185 | encoder_hidden_states=prompt_embeds,
186 | cross_attention_kwargs=cross_attention_kwargs)
187 | return unet_output
188 |
189 |
190 | class SDFeaturizer:
191 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt=''):
192 | unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
193 | onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
194 | onestep_pipe.vae.decoder = None
195 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
196 | gc.collect()
197 | onestep_pipe = onestep_pipe.to("cuda")
198 | onestep_pipe.enable_attention_slicing()
199 | onestep_pipe.enable_xformers_memory_efficient_attention()
200 | null_prompt_embeds = onestep_pipe._encode_prompt(
201 | prompt=null_prompt,
202 | device='cuda',
203 | num_images_per_prompt=1,
204 | do_classifier_free_guidance=False) # [1, 77, dim]
205 |
206 | self.null_prompt_embeds = null_prompt_embeds
207 | self.null_prompt = null_prompt
208 | self.pipe = onestep_pipe
209 |
210 | @torch.no_grad()
211 | def forward(self,
212 | img_tensor,
213 | prompt='',
214 | t=261,
215 | up_ft_index=1,
216 | ensemble_size=8):
217 | '''
218 | Args:
219 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
220 | prompt: the prompt to use, a string
221 | t: the time step to use, should be an int in the range of [0, 1000]
222 |             up_ft_index: which upsampling block of the U-Net to extract the feature from; choose from [0, 1, 2, 3]
223 | ensemble_size: the number of repeated images used in the batch to extract features
224 | Return:
225 | unet_ft: a torch tensor in the shape of [1, c, h, w]
226 | '''
227 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
228 | if prompt == self.null_prompt:
229 | prompt_embeds = self.null_prompt_embeds
230 | else:
231 | prompt_embeds = self.pipe._encode_prompt(
232 | prompt=prompt,
233 | device='cuda',
234 | num_images_per_prompt=1,
235 | do_classifier_free_guidance=False) # [1, 77, dim]
236 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
237 | unet_ft_all = self.pipe(
238 | img_tensor=img_tensor,
239 | t=t,
240 | up_ft_indices=[up_ft_index],
241 | prompt_embeds=prompt_embeds)
242 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
243 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
244 | return unet_ft
245 |
246 |
247 | class SDFeaturizer4Eval(SDFeaturizer):
248 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt='', cat_list=[]):
249 | super().__init__(sd_id, null_prompt)
250 | with torch.no_grad():
251 | cat2prompt_embeds = {}
252 | for cat in cat_list:
253 | prompt = f"a photo of a {cat}"
254 | prompt_embeds = self.pipe._encode_prompt(
255 | prompt=prompt,
256 | device='cuda',
257 | num_images_per_prompt=1,
258 | do_classifier_free_guidance=False) # [1, 77, dim]
259 | cat2prompt_embeds[cat] = prompt_embeds
260 | self.cat2prompt_embeds = cat2prompt_embeds
261 |
262 | self.pipe.tokenizer = None
263 | self.pipe.text_encoder = None
264 | gc.collect()
265 | torch.cuda.empty_cache()
266 |
267 |
268 | @torch.no_grad()
269 | def forward(self,
270 | img,
271 | category=None,
272 | img_size=[768, 768],
273 | t=261,
274 | up_ft_index=1,
275 | ensemble_size=8):
276 | if img_size is not None:
277 | img = img.resize(img_size)
278 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2
279 | img_tensor = img_tensor.unsqueeze(0).repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
280 | if category in self.cat2prompt_embeds:
281 | prompt_embeds = self.cat2prompt_embeds[category]
282 | else:
283 | prompt_embeds = self.null_prompt_embeds
284 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1).cuda()
285 | unet_ft_all = self.pipe(
286 | img_tensor=img_tensor,
287 | t=t,
288 | up_ft_indices=[up_ft_index],
289 | prompt_embeds=prompt_embeds)
290 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
291 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
292 | return unet_ft
--------------------------------------------------------------------------------
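Note: the SDFeaturizer.forward docstring above defines the interface: a [C, H, W] or [1, C, H, W] image tensor scaled to [-1, 1], a prompt, a diffusion time step t, an upsampling-block index, and an ensemble size; it returns a [1, c, h, w] feature map averaged over the ensemble. A minimal usage sketch, mirroring the settings used in extract_dift.sh:

# Sketch: calling SDFeaturizer as documented above; values mirror extract_dift.sh.
import torch
from PIL import Image
from torchvision.transforms import PILToTensor
from src.models.dift_sd import SDFeaturizer

dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
img = Image.open('./assets/cat.png').convert('RGB')
img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2           # c, h, w scaled to [-1, 1]

ft = dift.forward(img_tensor,
                  prompt='a photo of a cat',
                  t=261,
                  up_ft_index=1,
                  ensemble_size=8)
print(ft.shape)                                               # 1, c, h', w' feature map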
/src/models/dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class DINOFeaturizer:
4 | def __init__(self, dino_id='dino_vitb8'):
5 | self.model = torch.hub.load('facebookresearch/dino:main', dino_id).eval().cuda()
6 |
7 | @torch.no_grad()
8 | def forward(self, img_tensor, block_index):
9 | h = img_tensor.shape[2] // 8
10 | w = img_tensor.shape[3] // 8
11 | n = 12 - block_index
12 | out = self.model.get_intermediate_layers(img_tensor, n=n)[0][0, 1:, :] # hw, c
13 | dim = out.shape[1]
14 | out = out.transpose(0, 1).view(1, dim, h, w)
15 | return out
--------------------------------------------------------------------------------
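Note: DINOFeaturizer above loads a DINO ViT from torch.hub and reshapes the patch tokens of the chosen transformer block into a [1, C, H/8, W/8] feature map. A short usage sketch; the ImageNet-style preprocessing and the block index are assumptions for illustration, not prescribed by the repo:

# Sketch: using DINOFeaturizer from src/models/dino.py (preprocessing is an assumption).
import torch
from PIL import Image
from torchvision import transforms
from src.models.dino import DINOFeaturizer

preprocess = transforms.Compose([
    transforms.Resize((480, 480)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

featurizer = DINOFeaturizer('dino_vitb8')
img_tensor = preprocess(Image.open('./assets/cat.png').convert('RGB')).unsqueeze(0).cuda()
ft = featurizer.forward(img_tensor, block_index=9)            # 1, C, H/8, W/8
print(ft.shape)                                               # e.g. torch.Size([1, 768, 60, 60])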
/src/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 | class Demo:
9 |
10 | def __init__(self, imgs, ft, img_size):
11 | self.ft = ft # N+1, C, H, W
12 | self.imgs = imgs
13 | self.num_imgs = len(imgs)
14 | self.img_size = img_size
15 |
16 | def plot_img_pairs(self, fig_size=3, alpha=0.45, scatter_size=70):
17 |
18 | fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size))
19 |
20 | plt.tight_layout()
21 |
22 | for i in range(self.num_imgs):
23 | axes[i].imshow(self.imgs[i])
24 | axes[i].axis('off')
25 | if i == 0:
26 | axes[i].set_title('source image')
27 | else:
28 | axes[i].set_title('target image')
29 |
30 | num_channel = self.ft.size(1)
31 |
32 | def onclick(event):
33 | if event.inaxes == axes[0]:
34 | with torch.no_grad():
35 |
36 | x, y = int(np.round(event.xdata)), int(np.round(event.ydata))
37 |
38 | src_ft = self.ft[0].unsqueeze(0)
39 | src_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(src_ft)
40 | src_vec = src_ft[0, :, y, x].view(1, num_channel) # 1, C
41 |
42 | del src_ft
43 | gc.collect()
44 | torch.cuda.empty_cache()
45 |
46 | trg_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(self.ft[1:]) # N, C, H, W
47 | trg_vec = trg_ft.view(self.num_imgs - 1, num_channel, -1) # N, C, HW
48 |
49 | del trg_ft
50 | gc.collect()
51 | torch.cuda.empty_cache()
52 |
53 | src_vec = F.normalize(src_vec) # 1, C
54 | trg_vec = F.normalize(trg_vec) # N, C, HW
55 | cos_map = torch.matmul(src_vec, trg_vec).view(self.num_imgs - 1, self.img_size, self.img_size).cpu().numpy() # N, H, W
56 |
57 | axes[0].clear()
58 | axes[0].imshow(self.imgs[0])
59 | axes[0].axis('off')
60 | axes[0].scatter(x, y, c='r', s=scatter_size)
61 | axes[0].set_title('source image')
62 |
63 | for i in range(1, self.num_imgs):
64 | max_yx = np.unravel_index(cos_map[i-1].argmax(), cos_map[i-1].shape)
65 | axes[i].clear()
66 |
67 | heatmap = cos_map[i-1]
68 | heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1]
69 | axes[i].imshow(self.imgs[i])
70 | axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis')
71 | axes[i].axis('off')
72 | axes[i].scatter(max_yx[1].item(), max_yx[0].item(), c='r', s=scatter_size)
73 | axes[i].set_title('target image')
74 |
75 | del cos_map
76 | del heatmap
77 | gc.collect()
78 |
79 | fig.canvas.mpl_connect('button_press_event', onclick)
80 | plt.show()
--------------------------------------------------------------------------------
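Note: the Demo class above takes a stacked feature tensor of shape [N+1, C, H, W] (source features first), the matching list of images resized to img_size, and opens an interactive figure in which clicking a source pixel marks its nearest cosine-similarity match in every target image. A hedged wiring sketch; the two .pt files are hypothetical outputs of extract_dift.py, and an interactive matplotlib backend (e.g. ipympl in a notebook) is needed for the click handler:

# Sketch: driving Demo with features saved by extract_dift.py (file names are examples).
import torch
from PIL import Image
from src.utils.visualization import Demo

img_size = 768
imgs = [Image.open(p).convert('RGB').resize((img_size, img_size))
        for p in ['./assets/cat.png', './assets/target_cat.png']]

# stack the per-image [c, h, w] features into [N+1, c, h, w]; the source image comes first
ft = torch.stack([torch.load('dift_cat.pt'), torch.load('dift_target_cat.pt')]).cuda()

demo = Demo(imgs, ft, img_size)
demo.plot_img_pairs(fig_size=4)   # click a point on the source image to see its match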