├── .gitignore ├── LICENSE ├── README.md ├── assets ├── cartoon.png ├── cat.png ├── demo_cat_480p.gif ├── demo_guitar_480p.gif ├── edit_cat.gif ├── guitar.png ├── guitar_cat.jpg ├── painting_cat.jpg ├── target_cat.png ├── target_guitar.png └── teaser.gif ├── demo.ipynb ├── edit_propagation.ipynb ├── environment.yml ├── eval_davis.py ├── eval_homography.py ├── eval_hpatches.py ├── eval_spair.py ├── extract_dift.py ├── extract_dift.sh ├── sd_featurizer_spair.py ├── setup_env.sh └── src ├── models ├── clip.py ├── dift_adm.py ├── dift_sd.py └── dino.py └── utils └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | */.DS_Store 163 | .DS_Store 164 | 165 | guided-diffusion/ 166 | davis_results_sd/ 167 | davis_results_adm/ 168 | superpoint-1k/ 169 | hpatches_results/ 170 | superpoint-1k.zip 171 | SPair-71k.tar.gz 172 | SPair-71k/ 173 | ./guided-diffusion/models/256x256_diffusion_uncond.pt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Luming Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Features (DIFT) 2 | This repository contains code for our NeurIPS 2023 paper "Emergent Correspondence from Image Diffusion". 
### [Project Page](https://diffusionfeatures.github.io/) | [Paper](https://arxiv.org/abs/2306.03881) | [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing)

![video](./assets/teaser.gif)

## Prerequisites
On a Linux machine, you can either set up the Python environment with:
```
conda env create -f environment.yml
conda activate dift
```
or create a new conda environment and install the packages manually using the shell commands in [setup_env.sh](setup_env.sh).

## Interactive Demo: Give it a Try!
We provide an interactive Jupyter notebook, [demo.ipynb](demo.ipynb), that demonstrates the semantic correspondence established by DIFT, and you can try it on your own images. After loading two images, left-click on a point of interest in the source image on the left; after a second or two, the corresponding point in the target image is marked in red on the right, together with a heatmap of the per-pixel cosine distance computed with DIFT. Here are two examples, on a cat and a guitar:
![demo cat](./assets/demo_cat_480p.gif)

![demo guitar](./assets/demo_guitar_480p.gif)
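Under the hood, the demo is just a cosine-similarity lookup between DIFT feature maps. If you prefer a plain script over the notebook widget, the following minimal sketch reproduces that lookup, reusing the preprocessing from [demo.ipynb](demo.ipynb) and the nearest-neighbor matching used in [eval_spair.py](eval_spair.py). It assumes a CUDA GPU, as in the notebook, and the query pixel `(x, y)` is an arbitrary example, not a value from the repo:
```
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import PILToTensor
from src.models.dift_sd import SDFeaturizer

img_size = 768          # decrease if you run out of memory
dift = SDFeaturizer()
prompt = 'a photo of a cat'

def extract(path):
    # same preprocessing as demo.ipynb: resize, then scale pixels to [-1, 1]
    img = Image.open(path).convert('RGB').resize((img_size, img_size))
    img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2
    return dift.forward(img_tensor, prompt=prompt, ensemble_size=8)   # [1, C, 48, 48]

src_ft = extract('./assets/cat.png')
trg_ft = extract('./assets/target_cat.png')

# upsample both feature maps back to image resolution, as eval_spair.py does
src_ft = F.interpolate(src_ft, size=(img_size, img_size), mode='bilinear')
trg_ft = F.interpolate(trg_ft, size=(img_size, img_size), mode='bilinear')

x, y = 400, 300                                        # query pixel in the source image
src_vec = F.normalize(src_ft[0, :, y, x], dim=0)       # [C]
trg_map = F.normalize(trg_ft[0], dim=0)                # [C, H, W]
cos_map = torch.einsum('c,chw->hw', src_vec, trg_map)  # per-pixel cosine similarity
best = cos_map.flatten().argmax().item()
print('best match at (x, y) =', (best % img_size, best // img_size))
```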
If you don't have a local GPU, you can also use the provided [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing).

## Extract DIFT for a given image
You can use the following [command](extract_dift.sh) to extract DIFT from a given image and save it as a torch tensor. By default, the arguments match those used for the semantic correspondence tasks.
```
python extract_dift.py \
    --input_path ./assets/cat.png \
    --output_path dift_cat.pt \
    --img_size 768 768 \
    --t 261 \
    --up_ft_index 1 \
    --prompt 'a photo of a cat' \
    --ensemble_size 8
```
Explanation of each argument:
- `input_path`: path to the input image file.
- `output_path`: path where the output features are saved as a torch tensor.
- `img_size`: width and height to which the image is resized before being fed into the diffusion model. If set to 0, no resizing is performed and the original image size is kept. The default is [768, 768]; decrease it if you run into memory issues.
- `t`: diffusion time step, an integer in the range [0, 1000]. `t=261` by default for semantic correspondence.
- `up_ft_index`: index of the U-Net upsampling block from which the feature map is extracted, chosen from [0, 1, 2, 3]. `up_ft_index=1` by default for semantic correspondence.
- `prompt`: the text prompt fed to the diffusion model.
- `ensemble_size`: the number of repeated copies of the image in each batch used to average features. `ensemble_size=8` by default; reduce it if you run into memory issues.

The spatial size of the output DIFT tensor is determined by both `img_size` and `up_ft_index`: with `up_ft_index=0` the output is 1/32 of `img_size`, with `up_ft_index=1` it is 1/16, and with `up_ft_index=2` or `3` it is 1/8. For example, the default command above (768×768 input, `up_ft_index=1`) produces a 48×48 feature map.

## Application: Edit Propagation
Using DIFT, we can propagate edits made in one image to other images that share semantic correspondences, even across categories and domains:

![edit cat](./assets/edit_cat.gif)

More implementation details are in the notebook [edit_propagation.ipynb](edit_propagation.ipynb).
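For reference, the notebook's recipe is: extract DIFT features for both images, match pixels inside the edited (sticker) region to their nearest neighbors in the target image, fit a homography to those matches, and warp the edit across. Below is a condensed sketch of the same idea using the repo's assets and `SDFeaturizer4Eval`; unlike the notebook, it upsamples both feature maps to full resolution (simpler to read but heavier on memory), and the output filename is arbitrary:
```
import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from src.models.dift_sd import SDFeaturizer4Eval

dift = SDFeaturizer4Eval(cat_list=['cat'])
src_img = Image.open('./assets/guitar_cat.jpg').convert('RGB')
trg_img = Image.open('./assets/painting_cat.jpg').convert('RGB')
sticker = imageio.imread('./assets/cartoon.png')       # RGBA, aligned with src_img
h_src, w_src = sticker.shape[:2]
w_trg, h_trg = trg_img.size

# L2-normalized DIFT features, upsampled to image resolution for direct pixel indexing
ft_src = F.interpolate(F.normalize(dift.forward(src_img, 'cat'), dim=1), size=(h_src, w_src), mode='bilinear')[0]
ft_trg = F.interpolate(F.normalize(dift.forward(trg_img, 'cat'), dim=1), size=(h_trg, w_trg), mode='bilinear')[0]

# sample up to 1000 source pixels covered by the sticker's alpha mask
ys, xs = np.nonzero(sticker[..., 3] > 0)
keep = np.random.permutation(len(ys))[:1000]
src_pts = np.stack([xs[keep], ys[keep]], axis=1).astype(np.float32)

# nearest-neighbor target pixel for each sampled source pixel
src_vecs = ft_src[:, ys[keep], xs[keep]].T             # [N, C]
trg_vecs = ft_trg.flatten(1).T                         # [H*W, C]
idx = torch.cdist(src_vecs, trg_vecs).argmin(dim=1).cpu().numpy()
trg_pts = np.stack([idx % w_trg, idx // w_trg], axis=1).astype(np.float32)

# fit a homography on the matches, then warp and alpha-composite the sticker
M, _ = cv2.findHomography(src_pts, trg_pts, cv2.RANSAC, 5.0)
warped = cv2.warpPerspective(sticker, M, (w_trg, h_trg))
alpha = warped[..., 3:] / 255.0
out = alpha * warped[..., :3] + (1 - alpha) * np.array(trg_img)
Image.fromarray(out.astype(np.uint8)).save('propagated_edit.png')
```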
## Get Benchmark Evaluation Results
First, run the following commands to enable the use of DIFT_adm:
```
git clone git@github.com:openai/guided-diffusion.git
cd guided-diffusion && mkdir models && cd models
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt
```

### SPair-71k

First, download the SPair-71k data:
```
wget https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
tar -xzvf SPair-71k.tar.gz
```
Run the following script to get the PCK (both per point and per image) of DIFT_sd on SPair-71k; `--save_path` is where the extracted features are saved:
```
python eval_spair.py \
    --dataset_path ./SPair-71k \
    --save_path ./spair_ft \
    --dift_model sd \
    --img_size 768 768 \
    --t 261 \
    --up_ft_index 1 \
    --ensemble_size 8
```
Run the following script to get the PCK (both per point and per image) of DIFT_adm on SPair-71k:
```
python eval_spair.py \
    --dataset_path ./SPair-71k \
    --save_path ./spair_ft \
    --dift_model adm \
    --img_size 512 512 \
    --t 101 \
    --up_ft_index 4 \
    --ensemble_size 8
```

### HPatches

First, prepare the HPatches data:
```
cd $HOME
git clone git@github.com:mihaidusmanu/d2-net.git && cd d2-net/hpatches_sequences/
chmod u+x download.sh
./download.sh
```

Then, download the 1k SuperPoint keypoints:
```
wget "https://www.dropbox.com/scl/fi/1mxy3oycnz7m2acd92u2x/superpoint-1k.zip?rlkey=fic30gr2tlth3cmsyyywcg385&dl=1" -O superpoint-1k.zip
unzip superpoint-1k.zip
rm superpoint-1k.zip
```

Run the following scripts to get the homography estimation accuracy of DIFT_sd on HPatches:
```
python eval_hpatches.py \
    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
    --kpts_path ./superpoint-1k \
    --save_path ./hpatches_results \
    --dift_model sd \
    --img_size 768 768 \
    --t 0 \
    --up_ft_index 2 \
    --ensemble_size 8

python eval_homography.py \
    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
    --save_path ./hpatches_results \
    --feat dift_sd \
    --metric cosine \
    --mode lmeds
```

Run the following scripts to get the homography estimation accuracy of DIFT_adm on HPatches:
```
python eval_hpatches.py \
    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
    --kpts_path ./superpoint-1k \
    --save_path ./hpatches_results \
    --dift_model adm \
    --img_size 768 768 \
    --t 41 \
    --up_ft_index 11 \
    --ensemble_size 4

python eval_homography.py \
    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
    --save_path ./hpatches_results \
    --feat dift_adm \
    --metric l2 \
    --mode ransac
```

### DAVIS

We follow the evaluation protocol in DINO's [implementation](https://github.com/facebookresearch/dino#evaluation-davis-2017-video-object-segmentation).
157 | 158 | First, prepare DAVIS 2017 data and evaluation tools: 159 | ``` 160 | cd $HOME 161 | git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017 162 | ./data/get_davis.sh 163 | cd $HOME 164 | git clone https://github.com/davisvideochallenge/davis2017-evaluation 165 | ``` 166 | 167 | Then, get segmentation results using DIFT_sd: 168 | ``` 169 | python eval_davis.py \ 170 | --dift_model sd \ 171 | --t 51 \ 172 | --up_ft_index 2 \ 173 | --temperature 0.2 \ 174 | --topk 15 \ 175 | --n_last_frames 28 \ 176 | --ensemble_size 8 \ 177 | --size_mask_neighborhood 15 \ 178 | --data_path $HOME/davis-2017/DAVIS/ \ 179 | --output_dir ./davis_results_sd/ 180 | ``` 181 | 182 | and results using DIFT_adm: 183 | ``` 184 | python eval_davis.py \ 185 | --dift_model adm \ 186 | --t 51 \ 187 | --up_ft_index 7 \ 188 | --temperature 0.1 \ 189 | --topk 10 \ 190 | --n_last_frames 28 \ 191 | --ensemble_size 4 \ 192 | --size_mask_neighborhood 15 \ 193 | --data_path $HOME/davis-2017/DAVIS/ \ 194 | --output_dir ./davis_results_adm/ 195 | ``` 196 | 197 | Finally, evaluate the results: 198 | ``` 199 | python $HOME/davis2017-evaluation/evaluation_method.py \ 200 | --task semi-supervised \ 201 | --results_path ./davis_results_sd/ \ 202 | --davis_path $HOME/davis-2017/DAVIS/ 203 | 204 | python $HOME/davis2017-evaluation/evaluation_method.py \ 205 | --task semi-supervised \ 206 | --results_path ./davis_results_adm/ \ 207 | --davis_path $HOME/davis-2017/DAVIS/ 208 | ``` 209 | 210 | # Misc. 211 | If you find our code or paper useful to your research work, please consider citing our work using the following bibtex: 212 | ``` 213 | @inproceedings{ 214 | tang2023emergent, 215 | title={Emergent Correspondence from Image Diffusion}, 216 | author={Luming Tang and Menglin Jia and Qianqian Wang and Cheng Perng Phoo and Bharath Hariharan}, 217 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 218 | year={2023}, 219 | url={https://openreview.net/forum?id=ypOiXjdfnU} 220 | } 221 | ``` 222 | -------------------------------------------------------------------------------- /assets/cartoon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/cartoon.png -------------------------------------------------------------------------------- /assets/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/cat.png -------------------------------------------------------------------------------- /assets/demo_cat_480p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/demo_cat_480p.gif -------------------------------------------------------------------------------- /assets/demo_guitar_480p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/demo_guitar_480p.gif -------------------------------------------------------------------------------- /assets/edit_cat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/edit_cat.gif 
-------------------------------------------------------------------------------- /assets/guitar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/guitar.png -------------------------------------------------------------------------------- /assets/guitar_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/guitar_cat.jpg -------------------------------------------------------------------------------- /assets/painting_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/painting_cat.jpg -------------------------------------------------------------------------------- /assets/target_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/target_cat.png -------------------------------------------------------------------------------- /assets/target_guitar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/target_guitar.png -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/dift/9421eb2034396c5b66f1aff37f03e540c264e52f/assets/teaser.gif -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f6e50bfe-edb9-4932-bf43-39f047bf36d1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib widget\n", 11 | "import argparse\n", 12 | "import gc\n", 13 | "import random\n", 14 | "import torch\n", 15 | "from PIL import Image\n", 16 | "from torchvision.transforms import PILToTensor\n", 17 | "from src.models.dift_sd import SDFeaturizer\n", 18 | "from src.utils.visualization import Demo" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "85e967e6-ca78-424d-ac69-2bc2c0bd744f", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "torch.cuda.set_device(0)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "1bf04cbe-b63d-4484-9918-7fac9f3506e9", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "dift = SDFeaturizer()" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "829c1cf2-82a1-45ea-b284-19a055311855", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# you can choose visualize cat or guitar\n", 49 | "category = random.choice(['cat', 'guitar'])\n", 50 | "\n", 51 | "print(f\"let's visualize semantic correspondence on {category}\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "076f7afc-232c-4fff-83dc-c63b876ee12f", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "if category == 'cat':\n", 62 | " filelist = ['./assets/cat.png', 
'./assets/target_cat.png', './assets/target_cat.png']\n", 63 | "elif category == 'guitar':\n", 64 | " filelist = ['./assets/guitar.png', './assets/target_guitar.png', './assets/target_guitar.png']\n", 65 | "\n", 66 | "prompt = f'a photo of a {category}'" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "245f2c3f-8445-42c5-b31f-be438c7239d9", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "ft = []\n", 77 | "imglist = []\n", 78 | "\n", 79 | "# decrease these two if you don't have enough RAM or GPU memory\n", 80 | "img_size = 768\n", 81 | "ensemble_size = 8\n", 82 | "\n", 83 | "for filename in filelist:\n", 84 | " img = Image.open(filename).convert('RGB')\n", 85 | " img = img.resize((img_size, img_size))\n", 86 | " imglist.append(img)\n", 87 | " img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2\n", 88 | " ft.append(dift.forward(img_tensor,\n", 89 | " prompt=prompt,\n", 90 | " ensemble_size=ensemble_size))\n", 91 | "ft = torch.cat(ft, dim=0)\n", 92 | "\n", 93 | "gc.collect()\n", 94 | "torch.cuda.empty_cache()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "f81ae975-cce2-491d-9b6b-58412559b8b6", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "demo = Demo(imglist, ft, img_size)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "6f9b5ef9-db57-46dc-9ad2-9758b5d573c4", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "'''\n", 115 | "left is source image, right is target image.\n", 116 | "you can click on the source image, and DIFT will find the corresponding\n", 117 | "point on the right image, mark it with red point and also plot the per-pixel \n", 118 | "cosine distance as heatmap.\n", 119 | "'''\n", 120 | "demo.plot_img_pairs(fig_size=5)" 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python 3 (ipykernel)", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.10.9" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 5 145 | } 146 | -------------------------------------------------------------------------------- /edit_propagation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3306ccce-4b17-41a9-831d-add6cccddc0e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import gc\n", 15 | "import imageio\n", 16 | "from PIL import Image\n", 17 | "from torchvision.transforms import PILToTensor\n", 18 | "import os\n", 19 | "import json\n", 20 | "from PIL import Image, ImageDraw\n", 21 | "import torch.nn.functional as F\n", 22 | "import cv2\n", 23 | "import glob\n", 24 | "from torchvision.transforms import PILToTensor\n", 25 | "from src.models.dift_sd import SDFeaturizer4Eval" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "081cd585-9d9d-4ffe-8c9b-6c6360d2e4ad", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "def gen_grid(h, w, device, normalize=False, 
homogeneous=False):\n", 36 | " if normalize:\n", 37 | " lin_y = torch.linspace(-1., 1., steps=h, device=device)\n", 38 | " lin_x = torch.linspace(-1., 1., steps=w, device=device)\n", 39 | " else:\n", 40 | " lin_y = torch.arange(0, h, device=device)\n", 41 | " lin_x = torch.arange(0, w, device=device)\n", 42 | " grid_y, grid_x = torch.meshgrid((lin_y, lin_x))\n", 43 | " grid = torch.stack((grid_x, grid_y), -1)\n", 44 | " if homogeneous:\n", 45 | " grid = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)\n", 46 | " return grid # [h, w, 2 or 3]\n", 47 | "\n", 48 | "\n", 49 | "def normalize_coords(coords, h, w, no_shift=False):\n", 50 | " assert coords.shape[-1] == 2\n", 51 | " if no_shift:\n", 52 | " return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2\n", 53 | " else:\n", 54 | " return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2 - 1." 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "2a13b459-4698-4a9c-803f-d7ba8adb6962", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "cat = 'cat'\n", 65 | "dift = SDFeaturizer4Eval(cat_list=['cat'])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "0606e9dd-9e51-49ec-bf37-1f2bc9f78a84", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "src_img = Image.open('./assets/guitar_cat.jpg').convert('RGB')\n", 76 | "trg_img = Image.open('./assets/painting_cat.jpg').convert('RGB')\n", 77 | "sticker = imageio.imread('./assets/cartoon.png')\n", 78 | "sticker_color, sticker_mask = sticker[..., :3], sticker[..., 3]\n", 79 | "\n", 80 | "assert np.array(src_img).shape[:2] == sticker.shape[:2]\n", 81 | "h_src, w_src = sticker.shape[:2]\n", 82 | "h_trg, w_trg = np.array(trg_img).shape[:2]\n", 83 | "\n", 84 | "sd_feat_src = dift.forward(src_img, cat)\n", 85 | "sd_feat_trg = dift.forward(trg_img, cat)\n", 86 | "\n", 87 | "sd_feat_src = F.normalize(sd_feat_src.squeeze(), p=2, dim=0)\n", 88 | "sd_feat_trg = F.normalize(sd_feat_trg.squeeze(), p=2, dim=0)\n", 89 | "feat_dim = sd_feat_src.shape[0]\n", 90 | "\n", 91 | "grid_src = gen_grid(h_src, w_src, device='cuda')\n", 92 | "grid_trg = gen_grid(h_trg, w_trg, device='cuda')\n", 93 | "\n", 94 | "coord_src = grid_src[sticker_mask > 0]\n", 95 | "coord_src = coord_src[torch.randperm(len(coord_src))][:1000]\n", 96 | "coord_src_normed = normalize_coords(coord_src, h_src, w_src)\n", 97 | "grid_trg_normed = normalize_coords(grid_trg, h_trg, w_trg)\n", 98 | "\n", 99 | "feat_src = F.grid_sample(sd_feat_src[None], coord_src_normed[None, None], align_corners=True).squeeze().T\n", 100 | "feat_trg = F.grid_sample(sd_feat_trg[None], grid_trg_normed[None], align_corners=True).squeeze()\n", 101 | "feat_trg_flattened = feat_trg.permute(1, 2, 0).reshape(-1, feat_dim)\n", 102 | "\n", 103 | "distances = torch.cdist(feat_src, feat_trg_flattened)\n", 104 | "_, indices = torch.min(distances, dim=1)\n", 105 | "\n", 106 | "src_pts = coord_src.reshape(-1, 2).cpu().numpy()\n", 107 | "trg_pts = grid_trg.reshape(-1, 2)[indices].cpu().numpy()\n", 108 | "\n", 109 | "M, mask = cv2.findHomography(src_pts, trg_pts, cv2.RANSAC, 5.0)\n", 110 | "sticker_out = cv2.warpPerspective(sticker, M, (w_trg, h_trg))\n", 111 | "\n", 112 | "sticker_out_alpha = sticker_out[..., 3:] / 255\n", 113 | "sticker_alpha = sticker[..., 3:] / 255\n", 114 | "\n", 115 | "trg_img_with_sticker = sticker_out_alpha * sticker_out[..., :3] + (1 - sticker_out_alpha) * trg_img\n", 116 | "src_img_with_sticker = sticker_alpha * sticker[..., :3] + (1 - 
sticker_alpha) * src_img" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "88723600-c18f-4eb1-aec7-feb4112e2610", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n", 127 | "\n", 128 | "axs[0, 0].imshow(src_img)\n", 129 | "axs[0, 0].set_title(\"Source Image\")\n", 130 | "axs[0, 0].axis('off')\n", 131 | "\n", 132 | "axs[0, 1].imshow(src_img_with_sticker.astype(np.uint8))\n", 133 | "axs[0, 1].set_title(\"Source Image with Edits\")\n", 134 | "axs[0, 1].axis('off')\n", 135 | "\n", 136 | "axs[1, 0].imshow(trg_img)\n", 137 | "axs[1, 0].set_title(\"Target Image\")\n", 138 | "axs[1, 0].axis('off')\n", 139 | "\n", 140 | "axs[1, 1].imshow(trg_img_with_sticker.astype(np.uint8))\n", 141 | "axs[1, 1].set_title(\"Target Image with Propagated Edits\")\n", 142 | "axs[1, 1].axis('off')\n", 143 | "\n", 144 | "plt.tight_layout()\n", 145 | "plt.show()" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3 (ipykernel)", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.10.9" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 5 170 | } 171 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dift 2 | channels: 3 | - xformers 4 | - pytorch 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotlipy=0.7.0=py310h7f8727e_1002 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2023.01.10=h06a4308_0 14 | - certifi=2023.5.7=py310h06a4308_0 15 | - cffi=1.15.1=py310h5eee18b_3 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - cryptography=39.0.1=py310h9ce1e76_0 18 | - cuda-cudart=11.7.99=0 19 | - cuda-cupti=11.7.101=0 20 | - cuda-libraries=11.7.1=0 21 | - cuda-nvrtc=11.7.99=0 22 | - cuda-nvtx=11.7.91=0 23 | - cuda-runtime=11.7.1=0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - freetype=2.12.1=h4a9f257_0 26 | - giflib=5.2.1=h5eee18b_3 27 | - gmp=6.2.1=h295c915_3 28 | - gnutls=3.6.15=he1e5248_0 29 | - idna=3.4=py310h06a4308_0 30 | - intel-openmp=2023.1.0=hdb19cb5_46305 31 | - jpeg=9e=h5eee18b_1 32 | - lame=3.100=h7b6447c_0 33 | - lcms2=2.12=h3be6417_0 34 | - ld_impl_linux-64=2.38=h1181459_1 35 | - lerc=3.0=h295c915_0 36 | - libcublas=11.10.3.66=0 37 | - libcufft=10.7.2.124=h4fbf590_0 38 | - libcufile=1.6.1.9=0 39 | - libcurand=10.3.2.106=0 40 | - libcusolver=11.4.0.1=0 41 | - libcusparse=11.7.4.91=0 42 | - libdeflate=1.17=h5eee18b_0 43 | - libffi=3.4.4=h6a678d5_0 44 | - libgcc-ng=11.2.0=h1234567_1 45 | - libgomp=11.2.0=h1234567_1 46 | - libiconv=1.16=h7f8727e_2 47 | - libidn2=2.3.4=h5eee18b_0 48 | - libnpp=11.7.4.75=0 49 | - libnvjpeg=11.8.0.2=0 50 | - libpng=1.6.39=h5eee18b_0 51 | - libstdcxx-ng=11.2.0=h1234567_1 52 | - libtasn1=4.19.0=h5eee18b_0 53 | - libtiff=4.5.0=h6a678d5_2 54 | - libunistring=0.9.10=h27cfd23_0 55 | - libuuid=1.41.5=h5eee18b_0 56 | - libwebp=1.2.4=h11a3e52_1 57 | - libwebp-base=1.2.4=h5eee18b_1 58 | - lz4-c=1.9.4=h6a678d5_0 59 | - mkl=2023.1.0=h6d00ec8_46342 60 | - mkl-service=2.4.0=py310h5eee18b_1 61 | - 
mkl_fft=1.3.6=py310h1128e8f_1 62 | - mkl_random=1.2.2=py310h1128e8f_1 63 | - ncurses=6.4=h6a678d5_0 64 | - nettle=3.7.3=hbbd107a_1 65 | - numpy=1.24.3=py310h5f9d8c6_1 66 | - numpy-base=1.24.3=py310hb5e798b_1 67 | - openh264=2.1.1=h4ff587b_0 68 | - openssl=1.1.1t=h7f8727e_0 69 | - pillow=9.4.0=py310h6a678d5_0 70 | - pip=23.0.1=py310h06a4308_0 71 | - pycparser=2.21=pyhd3eb1b0_0 72 | - pyopenssl=23.0.0=py310h06a4308_0 73 | - pysocks=1.7.1=py310h06a4308_0 74 | - python=3.10.9=h7a1cb2a_2 75 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0 76 | - pytorch-cuda=11.7=h778d358_5 77 | - pytorch-mutex=1.0=cuda 78 | - readline=8.2=h5eee18b_0 79 | - requests=2.29.0=py310h06a4308_0 80 | - setuptools=67.8.0=py310h06a4308_0 81 | - sqlite=3.41.2=h5eee18b_0 82 | - tbb=2021.8.0=hdb19cb5_0 83 | - tk=8.6.12=h1ccaba5_0 84 | - torchaudio=0.13.1=py310_cu117 85 | - torchvision=0.14.1=py310_cu117 86 | - typing_extensions=4.5.0=py310h06a4308_0 87 | - urllib3=1.26.15=py310h06a4308_0 88 | - wheel=0.38.4=py310h06a4308_0 89 | - xformers=0.0.20=py310_cu11.7.1_pyt1.13.1 90 | - xz=5.4.2=h5eee18b_0 91 | - zlib=1.2.13=h5eee18b_0 92 | - zstd=1.5.5=hc292b87_0 93 | - pip: 94 | - accelerate==0.19.0 95 | - aiohttp==3.8.4 96 | - aiosignal==1.3.1 97 | - anyio==3.7.0 98 | - argon2-cffi==21.3.0 99 | - argon2-cffi-bindings==21.2.0 100 | - arrow==1.2.3 101 | - asttokens==2.2.1 102 | - async-lru==2.0.2 103 | - async-timeout==4.0.2 104 | - attrs==23.1.0 105 | - babel==2.12.1 106 | - backcall==0.2.0 107 | - beautifulsoup4==4.12.2 108 | - bleach==6.0.0 109 | - brotli==1.0.9 110 | - cmake==3.26.3 111 | - comm==0.1.3 112 | - contourpy==1.0.7 113 | - cycler==0.11.0 114 | - debugpy==1.6.7 115 | - decorator==5.1.1 116 | - defusedxml==0.7.1 117 | - diffusers==0.15.0 118 | - exceptiongroup==1.1.1 119 | - executing==1.2.0 120 | - fastjsonschema==2.17.1 121 | - filelock==3.12.0 122 | - fonttools==4.39.4 123 | - fqdn==1.5.1 124 | - frozenlist==1.3.3 125 | - fsspec==2023.5.0 126 | - gevent==22.10.2 127 | - geventhttpclient==2.0.2 128 | - greenlet==2.0.2 129 | - grpcio==1.54.2 130 | - huggingface-hub==0.14.1 131 | - imageio==2.33.0 132 | - importlib-metadata==6.6.0 133 | - ipykernel==6.23.1 134 | - ipympl==0.9.3 135 | - ipython==8.13.2 136 | - ipython-genutils==0.2.0 137 | - ipywidgets==8.0.6 138 | - isoduration==20.11.0 139 | - jedi==0.18.2 140 | - jinja2==3.1.2 141 | - json5==0.9.14 142 | - jsonpointer==2.3 143 | - jsonschema==4.17.3 144 | - jupyter-client==8.2.0 145 | - jupyter-core==5.3.0 146 | - jupyter-events==0.6.3 147 | - jupyter-lsp==2.2.0 148 | - jupyter-server==2.6.0 149 | - jupyter-server-terminals==0.4.4 150 | - jupyterlab==4.0.0 151 | - jupyterlab-pygments==0.2.2 152 | - jupyterlab-server==2.22.1 153 | - jupyterlab-widgets==3.0.7 154 | - kiwisolver==1.4.4 155 | - lazy-loader==0.3 156 | - lit==16.0.5 157 | - markupsafe==2.1.2 158 | - matplotlib==3.7.1 159 | - matplotlib-inline==0.1.6 160 | - mistune==2.0.5 161 | - multidict==6.0.4 162 | - mypy-extensions==1.0.0 163 | - nbclient==0.8.0 164 | - nbconvert==7.4.0 165 | - nbformat==5.8.0 166 | - nest-asyncio==1.5.6 167 | - networkx==3.2.1 168 | - notebook-shim==0.2.3 169 | - opencv-python==4.8.1.78 170 | - overrides==7.3.1 171 | - packaging==23.1 172 | - pandas==2.1.4 173 | - pandocfilters==1.5.0 174 | - parso==0.8.3 175 | - pexpect==4.8.0 176 | - pickleshare==0.7.5 177 | - platformdirs==3.5.1 178 | - prometheus-client==0.17.0 179 | - prompt-toolkit==3.0.38 180 | - protobuf==3.20.3 181 | - psutil==5.9.5 182 | - ptyprocess==0.7.0 183 | - pure-eval==0.2.2 184 | - pygments==2.15.1 185 | - 
pyparsing==3.0.9 186 | - pyrsistent==0.19.3 187 | - python-dateutil==2.8.2 188 | - python-json-logger==2.0.7 189 | - python-rapidjson==1.10 190 | - pytz==2023.3.post1 191 | - pyyaml==6.0 192 | - pyzmq==25.1.0 193 | - regex==2023.5.5 194 | - rfc3339-validator==0.1.4 195 | - rfc3986-validator==0.1.1 196 | - scikit-image==0.22.0 197 | - scipy==1.11.4 198 | - send2trash==1.8.2 199 | - sh==1.14.3 200 | - six==1.16.0 201 | - sniffio==1.3.0 202 | - soupsieve==2.4.1 203 | - stack-data==0.6.2 204 | - terminado==0.17.1 205 | - tifffile==2023.9.26 206 | - tinycss2==1.2.1 207 | - tokenizers==0.13.3 208 | - tomli==2.0.1 209 | - tornado==6.3.2 210 | - tqdm==4.65.0 211 | - traitlets==5.9.0 212 | - transformers==4.29.2 213 | - triton==2.0.0.post1 214 | - tritonclient==2.33.0 215 | - typing-inspect==0.6.0 216 | - tzdata==2023.3 217 | - uri-template==1.2.0 218 | - wcwidth==0.2.6 219 | - webcolors==1.13 220 | - webencodings==0.5.1 221 | - websocket-client==1.5.2 222 | - widgetsnbextension==4.0.7 223 | - wrapt==1.15.0 224 | - yarl==1.9.2 225 | - zipp==3.15.0 226 | - zope-event==4.6 227 | - zope-interface==6.0 -------------------------------------------------------------------------------- /eval_davis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | Some parts are taken from https://github.com/Liusifei/UVC 16 | """ 17 | import os 18 | import copy 19 | import glob 20 | import queue 21 | from urllib.request import urlopen 22 | import argparse 23 | import numpy as np 24 | from tqdm import tqdm 25 | 26 | import gc 27 | import cv2 28 | import torch 29 | from torch.nn import functional as F 30 | from PIL import Image 31 | from src.models.dift_sd import SDFeaturizer 32 | from src.models.dift_adm import ADMFeaturizer 33 | 34 | 35 | @torch.no_grad() 36 | def eval_video_tracking_davis(args, model, scale_factor, frame_list, video_dir, first_seg, seg_ori, color_palette): 37 | """ 38 | Evaluate tracking on a video given first frame & segmentation 39 | """ 40 | video_folder = os.path.join(args.output_dir, video_dir.split('/')[-1]) 41 | os.makedirs(video_folder, exist_ok=True) 42 | 43 | # The queue stores the n preceeding frames 44 | que = queue.Queue(args.n_last_frames) 45 | 46 | # first frame 47 | frame1, ori_h, ori_w = read_frame(frame_list[0]) 48 | # extract first frame feature 49 | frame1_feat = extract_feature(args, model, frame1).T # dim x h*w 50 | 51 | # saving first segmentation 52 | out_path = os.path.join(video_folder, "00000.png") 53 | imwrite_indexed(out_path, seg_ori, color_palette) 54 | mask_neighborhood = None 55 | for cnt in tqdm(range(1, len(frame_list))): 56 | frame_tar = read_frame(frame_list[cnt])[0] 57 | 58 | # we use the first segmentation and the n previous ones 59 | used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)] 60 | used_segs = [first_seg] + [pair[1] for pair in list(que.queue)] 61 | 62 | frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood) 63 | 64 | # pop out oldest frame if neccessary 65 | if que.qsize() == args.n_last_frames: 66 | que.get() 67 | # push current results into queue 68 | seg = copy.deepcopy(frame_tar_avg) 69 | que.put([feat_tar, seg]) 70 | 71 | # upsampling & argmax 72 | frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=scale_factor, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0] 73 | frame_tar_avg = norm_mask(frame_tar_avg) 74 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0) 75 | 76 | # saving to disk 77 | frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8) 78 | frame_tar_seg = np.array(Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0)) 79 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg", ".png") 80 | imwrite_indexed(os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette) 81 | 82 | 83 | def restrict_neighborhood(h, w): 84 | # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. 
``local attention'') 85 | mask = torch.zeros(h, w, h, w) 86 | for i in range(h): 87 | for j in range(w): 88 | for p in range(2 * args.size_mask_neighborhood + 1): 89 | for q in range(2 * args.size_mask_neighborhood + 1): 90 | if i - args.size_mask_neighborhood + p < 0 or i - args.size_mask_neighborhood + p >= h: 91 | continue 92 | if j - args.size_mask_neighborhood + q < 0 or j - args.size_mask_neighborhood + q >= w: 93 | continue 94 | mask[i, j, i - args.size_mask_neighborhood + p, j - args.size_mask_neighborhood + q] = 1 95 | 96 | mask = mask.reshape(h * w, h * w) 97 | return mask.cuda(non_blocking=True) 98 | 99 | 100 | def norm_mask(mask): 101 | c, h, w = mask.size() 102 | for cnt in range(c): 103 | mask_cnt = mask[cnt,:,:] 104 | if(mask_cnt.max() > 0): 105 | mask_cnt = (mask_cnt - mask_cnt.min()) 106 | mask_cnt = mask_cnt/mask_cnt.max() 107 | mask[cnt,:,:] = mask_cnt 108 | return mask 109 | 110 | 111 | def label_propagation(args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None): 112 | """ 113 | propagate segs of frames in list_frames to frame_tar 114 | """ 115 | gc.collect() 116 | torch.cuda.empty_cache() 117 | 118 | ## we only need to extract feature of the target frame 119 | feat_tar, h, w = extract_feature(args, model, frame_tar, return_h_w=True) 120 | 121 | gc.collect() 122 | torch.cuda.empty_cache() 123 | 124 | return_feat_tar = feat_tar.T # dim x h*w 125 | 126 | ncontext = len(list_frame_feats) 127 | feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w 128 | 129 | feat_tar = F.normalize(feat_tar, dim=1, p=2) 130 | feat_sources = F.normalize(feat_sources, dim=1, p=2) 131 | 132 | feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1) 133 | aff = torch.exp(torch.bmm(feat_tar, feat_sources) / args.temperature) # nmb_context x h*w (tar: query) x h*w (source: keys) 134 | 135 | if args.size_mask_neighborhood > 0: 136 | if mask_neighborhood is None: 137 | mask_neighborhood = restrict_neighborhood(h, w) 138 | mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1) 139 | aff *= mask_neighborhood 140 | 141 | aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries) 142 | tk_val, _ = torch.topk(aff, dim=0, k=args.topk) 143 | tk_val_min, _ = torch.min(tk_val, dim=0) 144 | aff[aff < tk_val_min] = 0 145 | 146 | aff = aff / torch.sum(aff, keepdim=True, axis=0) 147 | 148 | gc.collect() 149 | torch.cuda.empty_cache() 150 | 151 | list_segs = [s.cuda() for s in list_segs] 152 | segs = torch.cat(list_segs) 153 | nmb_context, C, h, w = segs.shape 154 | segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w 155 | seg_tar = torch.mm(segs, aff) 156 | seg_tar = seg_tar.reshape(1, C, h, w) 157 | 158 | return seg_tar, return_feat_tar, mask_neighborhood 159 | 160 | 161 | def extract_feature(args, model, frame, return_h_w=False): 162 | """Extract one frame feature everytime.""" 163 | with torch.no_grad(): 164 | unet_ft = model.forward(frame, 165 | t=args.t, 166 | up_ft_index=args.up_ft_index, 167 | ensemble_size=args.ensemble_size).squeeze() # c, h, w 168 | dim, h, w = unet_ft.shape 169 | unet_ft = torch.permute(unet_ft, (1, 2, 0)) # h,w,c 170 | unet_ft = unet_ft.view(h * w, dim) # hw,c 171 | if return_h_w: 172 | return unet_ft, h, w 173 | return unet_ft 174 | 175 | 176 | def imwrite_indexed(filename, array, color_palette): 177 | """ Save indexed png for DAVIS.""" 178 | if np.atleast_3d(array).shape[2] != 1: 179 | raise Exception("Saving indexed PNGs requires 2D array.") 
180 | 181 | im = Image.fromarray(array) 182 | im.putpalette(color_palette.ravel()) 183 | im.save(filename, format='PNG') 184 | 185 | 186 | def to_one_hot(y_tensor, n_dims=None): 187 | """ 188 | Take integer y (tensor or variable) with n dims & 189 | convert it to 1-hot representation with n+1 dims. 190 | """ 191 | if(n_dims is None): 192 | n_dims = int(y_tensor.max()+ 1) 193 | _,h,w = y_tensor.size() 194 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1) 195 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1 196 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1) 197 | y_one_hot = y_one_hot.view(h,w,n_dims) 198 | return y_one_hot.permute(2, 0, 1).unsqueeze(0) 199 | 200 | 201 | def read_frame_list(video_dir): 202 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))] 203 | frame_list = sorted(frame_list) 204 | return frame_list 205 | 206 | 207 | def read_frame(frame_dir, scale_size=[480]): 208 | """ 209 | read a single frame & preprocess 210 | """ 211 | img = cv2.imread(frame_dir) 212 | ori_h, ori_w, _ = img.shape 213 | if len(scale_size) == 1: 214 | if(ori_h > ori_w): 215 | tw = scale_size[0] 216 | th = (tw * ori_h) / ori_w 217 | th = int((th // 32) * 32) 218 | else: 219 | th = scale_size[0] 220 | tw = (th * ori_w) / ori_h 221 | tw = int((tw // 32) * 32) 222 | else: 223 | th, tw = scale_size 224 | img = cv2.resize(img, (tw, th)) 225 | img = img.astype(np.float32) 226 | img = img / 255.0 227 | img = img[:, :, ::-1] 228 | img = np.transpose(img.copy(), (2, 0, 1)) 229 | img = torch.from_numpy(img).float() 230 | img = color_normalize(img) 231 | return img, ori_h, ori_w 232 | 233 | 234 | def read_seg(seg_dir, scale_factor, scale_size=[480]): 235 | seg = Image.open(seg_dir) 236 | _w, _h = seg.size # note PIL.Image.Image's size is (w, h) 237 | if len(scale_size) == 1: 238 | if(_w > _h): 239 | _th = scale_size[0] 240 | _tw = (_th * _w) / _h 241 | _tw = int((_tw // 32) * 32) 242 | else: 243 | _tw = scale_size[0] 244 | _th = (_tw * _h) / _w 245 | _th = int((_th // 32) * 32) 246 | else: 247 | _th = scale_size[1] 248 | _tw = scale_size[0] 249 | small_seg = np.array(seg.resize((_tw // scale_factor, _th // scale_factor), 0)) 250 | small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0) 251 | 252 | return to_one_hot(small_seg), np.asarray(seg) 253 | 254 | 255 | def color_normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 256 | for t, m, s in zip(x, mean, std): 257 | t.sub_(m) 258 | t.div_(s) 259 | return x 260 | 261 | 262 | if __name__ == '__main__': 263 | parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017') 264 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use") 265 | parser.add_argument('--t', default=201, type=int, help='t for diffusion') 266 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map') 267 | parser.add_argument('--ensemble_size', default=4, type=int, help='ensemble size for getting an image ft map') 268 | parser.add_argument('--temperature', default=0.1, type=float, help='temperature for softmax') 269 | 270 | parser.add_argument('--output_dir', type=str, help='Path where to save segmentations') 271 | parser.add_argument('--data_path', type=str, help="path to davis dataset") 272 | parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames") 273 | parser.add_argument("--size_mask_neighborhood", 
default=12, type=int, 274 | help="We restrict the set of source nodes considered to a spatial neighborhood of the query node") 275 | parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors") 276 | args = parser.parse_args() 277 | 278 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 279 | 280 | color_palette = [] 281 | for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"): 282 | color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")]) 283 | color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3) 284 | 285 | video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines() 286 | n_last_frames = args.n_last_frames 287 | 288 | if args.dift_model == 'adm': 289 | index2factor = {0:32, 1:32, 2:16, 3:16, 4:16, 5:8, 6:8, 7:8, 8:4, 290 | 9:4, 10:4, 11:2, 12:2, 13:2, 14:1, 15:1, 16:1, 17:1} 291 | model = ADMFeaturizer() 292 | elif args.dift_model == 'sd': 293 | index2factor = {0:32, 1:16, 2:8, 3:8} 294 | model = SDFeaturizer() 295 | 296 | scale_factor = index2factor[args.up_ft_index] 297 | for i, video_name in enumerate(video_list): 298 | video_name = video_name.strip() 299 | 300 | if video_name == 'shooting': 301 | if args.n_last_frames > 10: 302 | args.n_last_frames = 10 # this can resolve the OOM issue 303 | else: 304 | args.n_last_frames = n_last_frames 305 | 306 | print(f'[{i}/{len(video_list)}] Begin to segmentate video {video_name}.') 307 | video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name) 308 | frame_list = read_frame_list(video_dir) 309 | seg_path = frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png") 310 | first_seg, seg_ori = read_seg(seg_path, scale_factor) 311 | eval_video_tracking_davis(args, model, scale_factor, frame_list, video_dir, first_seg, seg_ori, color_palette) 312 | -------------------------------------------------------------------------------- /eval_homography.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os 4 | import torch 5 | from tqdm import tqdm 6 | import cv2 7 | import torch.nn.functional as F 8 | 9 | def mnn_matcher(descriptors_a, descriptors_b, metric='cosine'): 10 | device = descriptors_a.device 11 | if metric == 'cosine': 12 | descriptors_a = F.normalize(descriptors_a) 13 | descriptors_b = F.normalize(descriptors_b) 14 | sim = descriptors_a @ descriptors_b.t() 15 | elif metric == 'l2': 16 | dist = torch.sum(descriptors_a**2, dim=1, keepdim=True) + torch.sum(descriptors_b**2, dim=1, keepdim=True).t() - \ 17 | 2 * descriptors_a.mm(descriptors_b.t()) 18 | sim = -dist 19 | nn12 = torch.max(sim, dim=1)[1] 20 | nn21 = torch.max(sim, dim=0)[1] 21 | ids1 = torch.arange(0, sim.shape[0], device=device) 22 | mask = (ids1 == nn21[nn12]) 23 | matches = torch.stack([ids1[mask], nn12[mask]]) 24 | return matches.t().data.cpu().numpy() 25 | 26 | def generate_read_function(save_path, method, extension='ppm', top_k=None): 27 | def read_function(seq_name, im_idx): 28 | aux = np.load(os.path.join(save_path, seq_name, '%d.%s.%s' % (im_idx, extension, method))) 29 | if top_k is None: 30 | return aux['keypoints'], aux['descriptors'] 31 | else: 32 | if len(aux['scores']) != 0: 33 | ids = np.argsort(aux['scores'])[-top_k :] 34 | if len(aux['scores'].shape) == 2: 35 | scores = aux['scores'][0] 36 | elif len(aux['scores'].shape) == 1: 37 | scores = aux['scores'] 38 | ids = 
np.argsort(scores)[-top_k :] 39 | return aux['keypoints'][ids, :], aux['descriptors'][ids, :] 40 | else: 41 | return aux['keypoints'][:, :2], aux['descriptors'] 42 | return read_function 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser(description='SPair-71k Evaluation Script') 47 | parser.add_argument('--hpatches_path', type=str, default='/scratch/dift_release/d2-net/hpatches_sequences/hpatches-sequences-release', help='path to hpatches dataset') 48 | parser.add_argument('--save_path', type=str, default='./hpatches_results', help='path to save features') 49 | parser.add_argument('--feat', choices=['dift_sd', 'dift_adm'], default='dift_sd', help="which feature to use") 50 | parser.add_argument('--metric', choices=['cosine', 'l2'], default='cosine', help="which distance metric to use") 51 | parser.add_argument('--mode', choices=['ransac', 'lmeds'], default='lmeds', help="which method to use when calculating homography") 52 | args = parser.parse_args() 53 | 54 | seq_names = sorted(os.listdir(args.hpatches_path)) 55 | read_function = generate_read_function(args.save_path, args.feat) 56 | th = np.linspace(1, 5, 3) 57 | 58 | i_accuracy = [] 59 | v_accuracy = [] 60 | 61 | for seq_idx, seq_name in tqdm(enumerate(seq_names)): 62 | keypoints_a, descriptors_a = read_function(seq_name, 1) 63 | keypoints_a, unique_idx = np.unique(keypoints_a, return_index=True, axis=0) 64 | descriptors_a = descriptors_a[unique_idx] 65 | 66 | h, w = cv2.imread(os.path.join(args.hpatches_path, seq_name, '1.ppm')).shape[:2] 67 | 68 | for im_idx in range(2, 7): 69 | h2, w2 = cv2.imread(os.path.join(args.hpatches_path, seq_name, '{}.ppm'.format(im_idx))).shape[:2] 70 | keypoints_b, descriptors_b = read_function(seq_name, im_idx) 71 | keypoints_b, unique_idx = np.unique(keypoints_b, return_index=True, axis=0) 72 | descriptors_b = descriptors_b[unique_idx] 73 | 74 | matches = mnn_matcher( 75 | torch.from_numpy(descriptors_a).cuda(), 76 | torch.from_numpy(descriptors_b).cuda(), 77 | metric=args.metric 78 | ) 79 | 80 | H_gt = np.loadtxt(os.path.join(args.hpatches_path, seq_name, "H_1_" + str(im_idx))) 81 | pts_a = keypoints_a[matches[:, 0]].reshape(-1, 1, 2).astype(np.float32) 82 | pts_b = keypoints_b[matches[:, 1]].reshape(-1, 1, 2).astype(np.float32) 83 | 84 | if args.mode == 'ransac': 85 | H, mask = cv2.findHomography(pts_a, pts_b, cv2.RANSAC, ransacReprojThreshold=3) 86 | elif args.mode == 'lmeds': 87 | H, mask = cv2.findHomography(pts_a, pts_b, cv2.LMEDS, ransacReprojThreshold=3) 88 | 89 | corners = np.array([[0, 0, 1], 90 | [0, h-1, 1], 91 | [w - 1, 0, 1], 92 | [w - 1, h - 1, 1]]) 93 | 94 | real_warped_corners = np.dot(corners, np.transpose(H_gt)) 95 | real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:] 96 | warped_corners = np.dot(corners, np.transpose(H)) 97 | warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] 98 | 99 | mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) 100 | correctness = mean_dist <= th 101 | 102 | if seq_name[0] == 'i': 103 | i_accuracy.append(correctness) 104 | elif seq_name[0] == 'v': 105 | v_accuracy.append(correctness) 106 | 107 | i_accuracy = np.array(i_accuracy) 108 | v_accuracy = np.array(v_accuracy) 109 | i_mean_accuracy = np.mean(i_accuracy, axis=0) 110 | v_mean_accuracy = np.mean(v_accuracy, axis=0) 111 | overall_mean_accuracy = np.mean(np.concatenate((i_accuracy, v_accuracy), axis=0), axis=0) 112 | print('overall_acc: {}, i_acc: {}, v_acc: {}'.format( 113 | overall_mean_accuracy * 100, 
i_mean_accuracy * 100, v_mean_accuracy * 100)) -------------------------------------------------------------------------------- /eval_hpatches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | import argparse 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | import skimage.io as io 10 | import torch.nn.functional as F 11 | from src.models.dift_sd import SDFeaturizer4Eval 12 | from src.models.dift_adm import ADMFeaturizer4Eval 13 | 14 | class HPatchDataset(Dataset): 15 | def __init__(self, imdir, spdir): 16 | self.imfs = [] 17 | for f in os.listdir(imdir): 18 | scene_dir = os.path.join(imdir, f) 19 | self.imfs.extend([os.path.join(scene_dir, '{}.ppm').format(ind) for ind in range(1, 7)]) 20 | self.spdir = spdir 21 | 22 | def __getitem__(self, item): 23 | imf = self.imfs[item] 24 | im = io.imread(imf) 25 | name, idx = imf.split('/')[-2:] 26 | coord = np.loadtxt(os.path.join(self.spdir, f'{name}-{idx[0]}.kp')).astype(np.float32) 27 | out = {'coord': coord, 'imf': imf} 28 | return out 29 | 30 | def __len__(self): 31 | return len(self.imfs) 32 | 33 | 34 | def main(args): 35 | for arg in vars(args): 36 | value = getattr(args,arg) 37 | if value is not None: 38 | print('%s: %s' % (str(arg),str(value))) 39 | 40 | dataset = HPatchDataset(imdir=args.hpatches_path, spdir=args.kpts_path) 41 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) 42 | if args.dift_model == 'sd': 43 | dift = SDFeaturizer4Eval() 44 | elif args.dift_model == 'adm': 45 | dift = ADMFeaturizer4Eval() 46 | 47 | with torch.no_grad(): 48 | for data in tqdm(data_loader): 49 | img_path = data['imf'][0] 50 | img = Image.open(img_path) 51 | w, h = img.size 52 | coord = data['coord'].to('cuda') 53 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).to(coord.device).float() 54 | coord_norm = (coord - c) / c 55 | 56 | feat = dift.forward(img, 57 | img_size=args.img_size, 58 | t=args.t, 59 | up_ft_index=args.up_ft_index, 60 | ensemble_size=args.ensemble_size) 61 | 62 | feat = F.grid_sample(feat, coord_norm.unsqueeze(2)).squeeze(-1) 63 | feat = feat.transpose(1, 2) 64 | 65 | desc = feat.squeeze(0).detach().cpu().numpy() 66 | kpt = coord.cpu().numpy().squeeze(0) 67 | 68 | out_dir = os.path.join(args.save_path, os.path.basename(os.path.dirname(img_path))) 69 | os.makedirs(out_dir, exist_ok=True) 70 | with open(os.path.join(out_dir, f'{os.path.basename(img_path)}.dift_{args.dift_model}'), 'wb') as output_file: 71 | np.savez( 72 | output_file, 73 | keypoints=kpt, 74 | scores=[], 75 | descriptors=desc 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser(description='SPair-71k Evaluation Script') 81 | parser.add_argument('--hpatches_path', type=str, default='/scratch/dift_release/d2-net/hpatches_sequences/hpatches-sequences-release', help='path to hpatches dataset') 82 | parser.add_argument('--kpts_path', type=str, default='./superpoint-1k', help='path to 1k superpoint keypoints') 83 | parser.add_argument('--save_path', type=str, default='./hpatches_results', help='path to save features') 84 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use") 85 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768], 86 | help='''in the order of [width, height], resize input image 87 | to [w, h] before fed into diffusion model, if set to 0, will 88 | stick to the original 
input size. by default is 768x768.''') 89 | parser.add_argument('--t', default=261, type=int, help='t for diffusion') 90 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map') 91 | parser.add_argument('--ensemble_size', default=8, type=int, help='ensemble size for getting an image ft map') 92 | args = parser.parse_args() 93 | main(args) -------------------------------------------------------------------------------- /eval_spair.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.nn import functional as F 4 | from tqdm import tqdm 5 | import numpy as np 6 | from src.models.dift_sd import SDFeaturizer4Eval 7 | from src.models.dift_adm import ADMFeaturizer4Eval 8 | import os 9 | import json 10 | from PIL import Image 11 | import torch.nn as nn 12 | 13 | 14 | def main(args): 15 | for arg in vars(args): 16 | value = getattr(args,arg) 17 | if value is not None: 18 | print('%s: %s' % (str(arg),str(value))) 19 | 20 | torch.cuda.set_device(0) 21 | 22 | dataset_path = args.dataset_path 23 | test_path = 'PairAnnotation/test' 24 | json_list = os.listdir(os.path.join(dataset_path, test_path)) 25 | all_cats = os.listdir(os.path.join(dataset_path, 'JPEGImages')) 26 | cat2json = {} 27 | 28 | for cat in all_cats: 29 | cat_list = [] 30 | for i in json_list: 31 | if cat in i: 32 | cat_list.append(i) 33 | cat2json[cat] = cat_list 34 | 35 | # get test image path for all cats 36 | cat2img = {} 37 | for cat in all_cats: 38 | cat2img[cat] = [] 39 | cat_list = cat2json[cat] 40 | for json_path in cat_list: 41 | with open(os.path.join(dataset_path, test_path, json_path)) as temp_f: 42 | data = json.load(temp_f) 43 | temp_f.close() 44 | src_imname = data['src_imname'] 45 | trg_imname = data['trg_imname'] 46 | if src_imname not in cat2img[cat]: 47 | cat2img[cat].append(src_imname) 48 | if trg_imname not in cat2img[cat]: 49 | cat2img[cat].append(trg_imname) 50 | 51 | if args.dift_model == 'sd': 52 | dift = SDFeaturizer4Eval(cat_list=all_cats) 53 | elif args.dift_model == 'adm': 54 | dift = ADMFeaturizer4Eval() 55 | 56 | print("saving all test images' features...") 57 | os.makedirs(args.save_path, exist_ok=True) 58 | for cat in tqdm(all_cats): 59 | output_dict = {} 60 | image_list = cat2img[cat] 61 | for image_path in image_list: 62 | img = Image.open(os.path.join(dataset_path, 'JPEGImages', cat, image_path)) 63 | output_dict[image_path] = dift.forward(img, 64 | category=cat, 65 | img_size=args.img_size, 66 | t=args.t, 67 | up_ft_index=args.up_ft_index, 68 | ensemble_size=args.ensemble_size) 69 | torch.save(output_dict, os.path.join(args.save_path, f'{cat}.pth')) 70 | 71 | total_pck = [] 72 | all_correct = 0 73 | all_total = 0 74 | 75 | for cat in all_cats: 76 | cat_list = cat2json[cat] 77 | output_dict = torch.load(os.path.join(args.save_path, f'{cat}.pth')) 78 | 79 | cat_pck = [] 80 | cat_correct = 0 81 | cat_total = 0 82 | 83 | for json_path in tqdm(cat_list): 84 | 85 | with open(os.path.join(dataset_path, test_path, json_path)) as temp_f: 86 | data = json.load(temp_f) 87 | 88 | src_img_size = data['src_imsize'][:2][::-1] 89 | trg_img_size = data['trg_imsize'][:2][::-1] 90 | 91 | src_ft = output_dict[data['src_imname']] 92 | trg_ft = output_dict[data['trg_imname']] 93 | 94 | src_ft = nn.Upsample(size=src_img_size, mode='bilinear')(src_ft) 95 | trg_ft = nn.Upsample(size=trg_img_size, mode='bilinear')(trg_ft) 96 | h = trg_ft.shape[-2] 97 | w = trg_ft.shape[-1] 98 | 99 | 
trg_bndbox = data['trg_bndbox'] 100 | threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0]) 101 | 102 | total = 0 103 | correct = 0 104 | 105 | for idx in range(len(data['src_kps'])): 106 | total += 1 107 | cat_total += 1 108 | all_total += 1 109 | src_point = data['src_kps'][idx] 110 | trg_point = data['trg_kps'][idx] 111 | 112 | num_channel = src_ft.size(1) 113 | src_vec = src_ft[0, :, src_point[1], src_point[0]].view(1, num_channel) # 1, C 114 | trg_vec = trg_ft.view(num_channel, -1).transpose(0, 1) # HW, C 115 | src_vec = F.normalize(src_vec).transpose(0, 1) # c, 1 116 | trg_vec = F.normalize(trg_vec) # HW, c 117 | cos_map = torch.mm(trg_vec, src_vec).view(h, w).cpu().numpy() # H, W 118 | 119 | max_yx = np.unravel_index(cos_map.argmax(), cos_map.shape) 120 | 121 | dist = ((max_yx[1] - trg_point[0]) ** 2 + (max_yx[0] - trg_point[1]) ** 2) ** 0.5 122 | if (dist / threshold) <= 0.1: 123 | correct += 1 124 | cat_correct += 1 125 | all_correct += 1 126 | 127 | cat_pck.append(correct / total) 128 | total_pck.extend(cat_pck) 129 | 130 | print(f'{cat} per image PCK@0.1: {np.mean(cat_pck) * 100:.2f}') 131 | print(f'{cat} per point PCK@0.1: {cat_correct / cat_total * 100:.2f}') 132 | print(f'All per image PCK@0.1: {np.mean(total_pck) * 100:.2f}') 133 | print(f'All per point PCK@0.1: {all_correct / all_total * 100:.2f}') 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser(description='SPair-71k Evaluation Script') 138 | parser.add_argument('--dataset_path', type=str, default='./SPair-71k/', help='path to spair dataset') 139 | parser.add_argument('--save_path', type=str, default='/scratch/lt453/spair_ft/', help='path to save features') 140 | parser.add_argument('--dift_model', choices=['sd', 'adm'], default='sd', help="which dift version to use") 141 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768], 142 | help='''in the order of [width, height], resize input image 143 | to [w, h] before fed into diffusion model, if set to 0, will 144 | stick to the original input size. 
By default it is 768x768.''') 145 | parser.add_argument('--t', default=261, type=int, help='t for diffusion') 146 | parser.add_argument('--up_ft_index', default=1, type=int, help='which upsampling block to extract the ft map') 147 | parser.add_argument('--ensemble_size', default=8, type=int, help='ensemble size for getting an image ft map') 148 | args = parser.parse_args() 149 | main(args) -------------------------------------------------------------------------------- /extract_dift.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from PIL import Image 4 | from torchvision.transforms import PILToTensor 5 | from src.models.dift_sd import SDFeaturizer 6 | 7 | def main(args): 8 | dift = SDFeaturizer(args.model_id) 9 | img = Image.open(args.input_path).convert('RGB') 10 | if args.img_size[0] > 0: 11 | img = img.resize(args.img_size) 12 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2 13 | ft = dift.forward(img_tensor, 14 | prompt=args.prompt, 15 | t=args.t, 16 | up_ft_index=args.up_ft_index, 17 | ensemble_size=args.ensemble_size) 18 | torch.save(ft.squeeze(0).cpu(), args.output_path) # save feature in the shape of [c, h, w] 19 | 20 | 21 | if __name__ == '__main__': 22 | 23 | parser = argparse.ArgumentParser( 24 | description='''extract DIFT from an input image and save it as a torch tensor 25 | in the shape of [c, h, w].''') 26 | 27 | parser.add_argument('--img_size', nargs='+', type=int, default=[768, 768], 28 | help='''in the order of [width, height]; resize the input image 29 | to [w, h] before it is fed into the diffusion model. If set to 0, 30 | the original input size is kept. By default it is 768x768.''') 31 | parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1', type=str, 32 | help='model_id of the diffusion model in huggingface') 33 | parser.add_argument('--t', default=261, type=int, 34 | help='time step for diffusion, choose from range [0, 1000]') 35 | parser.add_argument('--up_ft_index', default=1, type=int, choices=[0, 1, 2, 3], 36 | help='which upsampling block of U-Net to extract the feature map') 37 | parser.add_argument('--prompt', default='', type=str, 38 | help='prompt used in the stable diffusion model') 39 | parser.add_argument('--ensemble_size', default=8, type=int, 40 | help='number of repeated images in each batch used to get features') 41 | parser.add_argument('--input_path', type=str, 42 | help='path to the input image file') 43 | parser.add_argument('--output_path', type=str, default='dift.pt', 44 | help='path to save the output features as torch tensor') 45 | args = parser.parse_args() 46 | main(args) -------------------------------------------------------------------------------- /extract_dift.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | python extract_dift.py \ 3 | --input_path ./assets/cat.png \ 4 | --output_path dift_cat.pt \ 5 | --img_size 0 \ 6 | --t 261 \ 7 | --up_ft_index 1 \ 8 | --prompt 'a photo of a cat' \ 9 | --ensemble_size 8 -------------------------------------------------------------------------------- /sd_featurizer_spair.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | import torch.nn as nn 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 7 | from diffusers.models.unet_2d_condition import
UNet2DConditionModel, UNet2DConditionOutput 8 | from diffusers import DDIMScheduler 9 | from diffusers.models.modeling_utils import ModelMixin 10 | import gc 11 | from PIL import Image 12 | from torchvision.transforms import PILToTensor 13 | import os 14 | from lavis.models import load_model_and_preprocess 15 | import json 16 | from PIL import Image, ImageDraw 17 | 18 | 19 | class MyUNet2DConditionModel(UNet2DConditionModel): 20 | def forward( 21 | self, 22 | sample: torch.FloatTensor, 23 | timestep: Union[torch.Tensor, float, int], 24 | up_ft_indices, 25 | encoder_hidden_states: torch.Tensor, 26 | class_labels: Optional[torch.Tensor] = None, 27 | timestep_cond: Optional[torch.Tensor] = None, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | cross_attention_kwargs: Optional[Dict[str, Any]] = None): 30 | r""" 31 | Args: 32 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 33 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 34 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 35 | cross_attention_kwargs (`dict`, *optional*): 36 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 37 | `self.processor` in 38 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 39 | """ 40 | # By default samples have to be AT least a multiple of the overall upsampling factor. 41 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 42 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 43 | # on the fly if necessary. 44 | default_overall_up_factor = 2**self.num_upsamplers 45 | 46 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 47 | forward_upsample_size = False 48 | upsample_size = None 49 | 50 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 51 | # logger.info("Forward upsample size to force interpolation output size.") 52 | forward_upsample_size = True 53 | 54 | # prepare attention_mask 55 | if attention_mask is not None: 56 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 57 | attention_mask = attention_mask.unsqueeze(1) 58 | 59 | # 0. center input if necessary 60 | if self.config.center_input_sample: 61 | sample = 2 * sample - 1.0 62 | 63 | # 1. time 64 | timesteps = timestep 65 | if not torch.is_tensor(timesteps): 66 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 67 | # This would be a good case for the `match` statement (Python 3.10+) 68 | is_mps = sample.device.type == "mps" 69 | if isinstance(timestep, float): 70 | dtype = torch.float32 if is_mps else torch.float64 71 | else: 72 | dtype = torch.int32 if is_mps else torch.int64 73 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 74 | elif len(timesteps.shape) == 0: 75 | timesteps = timesteps[None].to(sample.device) 76 | 77 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 78 | timesteps = timesteps.expand(sample.shape[0]) 79 | 80 | t_emb = self.time_proj(timesteps) 81 | 82 | # timesteps does not contain any weights and will always return f32 tensors 83 | # but time_embedding might actually be running in fp16. so we need to cast here. 84 | # there might be better ways to encapsulate this. 
85 | t_emb = t_emb.to(dtype=self.dtype) 86 | 87 | emb = self.time_embedding(t_emb, timestep_cond) 88 | 89 | if self.class_embedding is not None: 90 | if class_labels is None: 91 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 92 | 93 | if self.config.class_embed_type == "timestep": 94 | class_labels = self.time_proj(class_labels) 95 | 96 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 97 | emb = emb + class_emb 98 | 99 | # 2. pre-process 100 | sample = self.conv_in(sample) 101 | 102 | # 3. down 103 | down_block_res_samples = (sample,) 104 | for downsample_block in self.down_blocks: 105 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 106 | sample, res_samples = downsample_block( 107 | hidden_states=sample, 108 | temb=emb, 109 | encoder_hidden_states=encoder_hidden_states, 110 | attention_mask=attention_mask, 111 | cross_attention_kwargs=cross_attention_kwargs, 112 | ) 113 | else: 114 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 115 | 116 | down_block_res_samples += res_samples 117 | 118 | # 4. mid 119 | if self.mid_block is not None: 120 | sample = self.mid_block( 121 | sample, 122 | emb, 123 | encoder_hidden_states=encoder_hidden_states, 124 | attention_mask=attention_mask, 125 | cross_attention_kwargs=cross_attention_kwargs, 126 | ) 127 | 128 | # 5. up 129 | up_ft = {} 130 | for i, upsample_block in enumerate(self.up_blocks): 131 | 132 | if i > np.max(up_ft_indices): 133 | break 134 | 135 | is_final_block = i == len(self.up_blocks) - 1 136 | 137 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 138 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 139 | 140 | # if we have not reached the final block and need to forward the 141 | # upsample size, we do it here 142 | if not is_final_block and forward_upsample_size: 143 | upsample_size = down_block_res_samples[-1].shape[2:] 144 | 145 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 146 | sample = upsample_block( 147 | hidden_states=sample, 148 | temb=emb, 149 | res_hidden_states_tuple=res_samples, 150 | encoder_hidden_states=encoder_hidden_states, 151 | cross_attention_kwargs=cross_attention_kwargs, 152 | upsample_size=upsample_size, 153 | attention_mask=attention_mask, 154 | ) 155 | else: 156 | sample = upsample_block( 157 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 158 | ) 159 | 160 | if i in up_ft_indices: 161 | up_ft[i] = sample.detach() 162 | 163 | output = {} 164 | output['up_ft'] = up_ft 165 | return output 166 | 167 | class OneStepSDPipeline(StableDiffusionPipeline): 168 | @torch.no_grad() 169 | def __call__( 170 | self, 171 | img_tensor, 172 | t, 173 | up_ft_indices, 174 | negative_prompt: Optional[Union[str, List[str]]] = None, 175 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 176 | prompt_embeds: Optional[torch.FloatTensor] = None, 177 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 178 | callback_steps: int = 1, 179 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 180 | ): 181 | 182 | device = self._execution_device 183 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor 184 | t = torch.tensor(t, dtype=torch.long, device=device) 185 | noise = torch.randn_like(latents).to(device) 186 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 
187 | unet_output = self.unet(latents_noisy, 188 | t, 189 | up_ft_indices, 190 | encoder_hidden_states=prompt_embeds, 191 | cross_attention_kwargs=cross_attention_kwargs) 192 | return unet_output 193 | 194 | 195 | class SDFeaturizer: 196 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt=''): 197 | unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet") 198 | onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None) 199 | onestep_pipe.vae.decoder = None 200 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler") 201 | with torch.no_grad(): 202 | cat2prompt = {} 203 | all_cats = os.listdir('/home/lt453/SPair-71k/JPEGImages') 204 | for cat in all_cats: 205 | prompt = f"a photo of a {cat}" 206 | prompt_embeds = onestep_pipe._encode_prompt( 207 | prompt=prompt, 208 | device='cpu', 209 | num_images_per_prompt=1, 210 | do_classifier_free_guidance=False) # [1, 77, dim] 211 | cat2prompt[cat] = prompt_embeds 212 | null_prompt_embeds = onestep_pipe._encode_prompt( 213 | prompt=null_prompt, 214 | device='cpu', 215 | num_images_per_prompt=1, 216 | do_classifier_free_guidance=False) # [1, 77, dim] 217 | onestep_pipe.tokenizer = None 218 | onestep_pipe.text_encoder = None 219 | gc.collect() 220 | onestep_pipe = onestep_pipe.to("cuda") 221 | self.cat2prompt = cat2prompt 222 | self.null_prompt_embeds = null_prompt_embeds 223 | onestep_pipe.enable_attention_slicing() 224 | onestep_pipe.enable_xformers_memory_efficient_attention() 225 | self.pipe = onestep_pipe 226 | 227 | @torch.no_grad() 228 | def forward(self, 229 | img, 230 | category, 231 | img_size=[768, 768], 232 | t=261, 233 | up_ft_index=1, 234 | ensemble_size=8): 235 | if img_size is not None: 236 | img = img.resize(img_size) 237 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2 238 | img_tensor = img_tensor.unsqueeze(0).repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 239 | if category in self.cat2prompt: 240 | prompt_embeds = self.cat2prompt[category] 241 | else: 242 | prompt_embeds = self.null_prompt_embeds 243 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1).cuda() 244 | unet_ft_all = self.pipe( 245 | img_tensor=img_tensor, 246 | t=t, 247 | up_ft_indices=[up_ft_index], 248 | prompt_embeds=prompt_embeds) 249 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 250 | unet_ft = unet_ft.mean(0, keepdim=True) # n, c,h,w 251 | return unet_ft -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | conda create -n dift python=3.10 2 | conda activate dift 3 | 4 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 5 | conda install xformers -c xformers 6 | pip install jupyterlab 7 | pip install diffusers[torch]==0.15.0 8 | pip install -U matplotlib 9 | pip install transformers 10 | pip install ipympl 11 | pip install triton -------------------------------------------------------------------------------- /src/models/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import gc 6 | from PIL import Image, ImageDraw 7 | from torchvision.transforms import PILToTensor 8 | import os 9 | import open_clip 10 | from torchvision import transforms 11 | import copy 12 | import json 13 | from tqdm.notebook import tqdm 
14 | import time 15 | import datetime 16 | import math 17 | 18 | 19 | def interpolate_pos_encoding(clip_model, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: 20 | """ 21 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher 22 | resolution images. 23 | Source: 24 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 25 | """ 26 | 27 | num_patches = embeddings.shape[1] - 1 28 | pos_embedding = clip_model.positional_embedding.unsqueeze(0) 29 | num_positions = pos_embedding.shape[1] - 1 30 | if num_patches == num_positions and height == width: 31 | return clip_model.positional_embedding 32 | class_pos_embed = pos_embedding[:, 0] 33 | patch_pos_embed = pos_embedding[:, 1:] 34 | dim = embeddings.shape[-1] 35 | h0 = height // clip_model.patch_size[0] 36 | w0 = width // clip_model.patch_size[1] 37 | # we add a small number to avoid floating point error in the interpolation 38 | # see discussion at https://github.com/facebookresearch/dino/issues/8 39 | h0, w0 = h0 + 0.1, w0 + 0.1 40 | patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) 41 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) 42 | patch_pos_embed = nn.functional.interpolate( 43 | patch_pos_embed, 44 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), 45 | mode="bicubic", 46 | align_corners=False, 47 | ) 48 | assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] 49 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 50 | output = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 51 | 52 | return output 53 | 54 | 55 | class CLIPFeaturizer: 56 | def __init__(self): 57 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k') 58 | visual_model = clip_model.visual 59 | visual_model.output_tokens = True 60 | self.clip_model = visual_model.eval().cuda() 61 | 62 | 63 | @torch.no_grad() 64 | def forward(self, 65 | x, # single image, [1,c,h,w] 66 | block_index): 67 | batch_size = 1 68 | clip_model = self.clip_model 69 | if clip_model.input_patchnorm: 70 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 71 | x = x.reshape(x.shape[0], x.shape[1], clip_model.grid_size[0], clip_model.patch_size[0], clip_model.grid_size[1], clip_model.patch_size[1]) 72 | x = x.permute(0, 2, 4, 1, 3, 5) 73 | x = x.reshape(x.shape[0], clip_model.grid_size[0] * clip_model.grid_size[1], -1) 74 | x = clip_model.patchnorm_pre_ln(x) 75 | x = clip_model.conv1(x) 76 | else: 77 | x = clip_model.conv1(x) # shape = [*, width, grid, grid] 78 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 79 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 80 | # class embeddings and positional embeddings 81 | x = torch.cat( 82 | [clip_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 83 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 84 | if(x.shape[1] > clip_model.positional_embedding.shape[0]): 85 | dim = int(math.sqrt(x.shape[1]) * clip_model.patch_size[0]) 86 | x = x + interpolate_pos_encoding(clip_model, x, dim, dim).to(x.dtype) 87 | else: 88 | x = x + clip_model.positional_embedding.to(x.dtype) 89 | 90 | # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in 91 | x = clip_model.patch_dropout(x) 92 | x = clip_model.ln_pre(x) 93 | x = x.permute(1, 0, 2) # NLD -> LND 94 | 95 | num_channel = x.size(2) 96 | ft_size = int((x.shape[0]-1) ** 0.5) 97 | 98 | for i, r in enumerate(clip_model.transformer.resblocks): 99 | x = r(x) 100 | 101 | if i == block_index: 102 | tokens = x.permute(1, 0, 2) # LND -> NLD 103 | tokens = tokens[:, 1:] 104 | tokens = tokens.transpose(1, 2).contiguous().view(batch_size, num_channel, ft_size, ft_size) # NCHW 105 | 106 | return tokens -------------------------------------------------------------------------------- /src/models/dift_adm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | from torchvision import transforms 5 | main_path = Path(__file__).resolve().parent.parent.parent 6 | print(f'main path: {main_path}') 7 | 8 | import sys 9 | sys.path.append(os.path.join(main_path, 'guided-diffusion')) 10 | from guided_diffusion.script_util import create_model_and_diffusion 11 | from guided_diffusion.nn import timestep_embedding 12 | 13 | 14 | class ADMFeaturizer: 15 | def __init__(self): 16 | model, diffusion = create_model_and_diffusion( 17 | image_size=256, 18 | class_cond=False, 19 | learn_sigma=True, 20 | num_channels=256, 21 | num_res_blocks=2, 22 | channel_mult="", 23 | num_heads=4, 24 | num_head_channels=64, 25 | num_heads_upsample=-1, 26 | attention_resolutions="32,16,8", 27 | dropout=0.0, 28 | diffusion_steps=1000, 29 | noise_schedule='linear', 30 | timestep_respacing='', 31 | use_kl=False, 32 | predict_xstart=False, 33 | rescale_timesteps=False, 34 | rescale_learned_sigmas=False, 35 | use_checkpoint=False, 36 | use_scale_shift_norm=True, 37 | resblock_updown=True, 38 | use_fp16=False, 39 | use_new_attention_order=False, 40 | ) 41 | model_path = os.path.join(main_path, 'guided-diffusion/models/256x256_diffusion_uncond.pt') 42 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 43 | self.model = model.eval().cuda() 44 | self.diffusion = diffusion 45 | 46 | self.adm_transforms = transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 49 | ]) 50 | 51 | @torch.no_grad() 52 | def forward(self, img_tensor, 53 | t=101, 54 | up_ft_index=4, 55 | ensemble_size=8): 56 | model = self.model 57 | diffusion = self.diffusion 58 | 59 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 60 | t = torch.ones((img_tensor.shape[0],), device='cuda', dtype=torch.int64) * t 61 | x_t = diffusion.q_sample(img_tensor, t, noise=None) 62 | 63 | # get layer-wise features 64 | hs = [] 65 | emb = model.time_embed(timestep_embedding(t, model.model_channels)) 66 | h = x_t.type(model.dtype) 67 | for module in model.input_blocks: 68 | h = module(h, emb) 69 | hs.append(h) 70 | h = model.middle_block(h, emb) 71 | for i, module in enumerate(model.output_blocks): 72 | h = torch.cat([h, hs.pop()], dim=1) 73 | h = module(h, emb) 74 | 75 | if i == up_ft_index: 76 | ft = h.mean(0, keepdim=True).detach() 77 | return ft 78 | 79 | 80 | class ADMFeaturizer4Eval(ADMFeaturizer): 81 | 82 | @torch.no_grad() 83 | def forward(self, img, 84 | img_size=[512, 512], 85 | t=101, 86 | up_ft_index=4, 87 | ensemble_size=8, 88 | **kwargs): 89 | 90 | img_tensor = self.adm_transforms(img.resize(img_size)) 91 | ft = super().forward(img_tensor, 92 | t=t, 93 | 
up_ft_index=up_ft_index, 94 | ensemble_size=ensemble_size) 95 | return ft -------------------------------------------------------------------------------- /src/models/dift_sd.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | import torch.nn as nn 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from typing import Any, Callable, Dict, List, Optional, Union 7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 8 | from diffusers import DDIMScheduler 9 | import gc 10 | import os 11 | from PIL import Image 12 | from torchvision.transforms import PILToTensor 13 | 14 | class MyUNet2DConditionModel(UNet2DConditionModel): 15 | def forward( 16 | self, 17 | sample: torch.FloatTensor, 18 | timestep: Union[torch.Tensor, float, int], 19 | up_ft_indices, 20 | encoder_hidden_states: torch.Tensor, 21 | class_labels: Optional[torch.Tensor] = None, 22 | timestep_cond: Optional[torch.Tensor] = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | cross_attention_kwargs: Optional[Dict[str, Any]] = None): 25 | r""" 26 | Args: 27 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 28 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 29 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 30 | cross_attention_kwargs (`dict`, *optional*): 31 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 32 | `self.processor` in 33 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 34 | """ 35 | # By default samples have to be AT least a multiple of the overall upsampling factor. 36 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 37 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 38 | # on the fly if necessary. 39 | default_overall_up_factor = 2**self.num_upsamplers 40 | 41 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 42 | forward_upsample_size = False 43 | upsample_size = None 44 | 45 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 46 | # logger.info("Forward upsample size to force interpolation output size.") 47 | forward_upsample_size = True 48 | 49 | # prepare attention_mask 50 | if attention_mask is not None: 51 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 52 | attention_mask = attention_mask.unsqueeze(1) 53 | 54 | # 0. center input if necessary 55 | if self.config.center_input_sample: 56 | sample = 2 * sample - 1.0 57 | 58 | # 1. time 59 | timesteps = timestep 60 | if not torch.is_tensor(timesteps): 61 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 62 | # This would be a good case for the `match` statement (Python 3.10+) 63 | is_mps = sample.device.type == "mps" 64 | if isinstance(timestep, float): 65 | dtype = torch.float32 if is_mps else torch.float64 66 | else: 67 | dtype = torch.int32 if is_mps else torch.int64 68 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 69 | elif len(timesteps.shape) == 0: 70 | timesteps = timesteps[None].to(sample.device) 71 | 72 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 73 | timesteps = timesteps.expand(sample.shape[0]) 74 | 75 | t_emb = self.time_proj(timesteps) 76 | 77 | # timesteps does not contain any weights and will always return f32 tensors 78 | # but time_embedding might actually be running in fp16. so we need to cast here. 79 | # there might be better ways to encapsulate this. 80 | t_emb = t_emb.to(dtype=self.dtype) 81 | 82 | emb = self.time_embedding(t_emb, timestep_cond) 83 | 84 | if self.class_embedding is not None: 85 | if class_labels is None: 86 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 87 | 88 | if self.config.class_embed_type == "timestep": 89 | class_labels = self.time_proj(class_labels) 90 | 91 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 92 | emb = emb + class_emb 93 | 94 | # 2. pre-process 95 | sample = self.conv_in(sample) 96 | 97 | # 3. down 98 | down_block_res_samples = (sample,) 99 | for downsample_block in self.down_blocks: 100 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 101 | sample, res_samples = downsample_block( 102 | hidden_states=sample, 103 | temb=emb, 104 | encoder_hidden_states=encoder_hidden_states, 105 | attention_mask=attention_mask, 106 | cross_attention_kwargs=cross_attention_kwargs, 107 | ) 108 | else: 109 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 110 | 111 | down_block_res_samples += res_samples 112 | 113 | # 4. mid 114 | if self.mid_block is not None: 115 | sample = self.mid_block( 116 | sample, 117 | emb, 118 | encoder_hidden_states=encoder_hidden_states, 119 | attention_mask=attention_mask, 120 | cross_attention_kwargs=cross_attention_kwargs, 121 | ) 122 | 123 | # 5. 
up 124 | up_ft = {} 125 | for i, upsample_block in enumerate(self.up_blocks): 126 | 127 | if i > np.max(up_ft_indices): 128 | break 129 | 130 | is_final_block = i == len(self.up_blocks) - 1 131 | 132 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 133 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 134 | 135 | # if we have not reached the final block and need to forward the 136 | # upsample size, we do it here 137 | if not is_final_block and forward_upsample_size: 138 | upsample_size = down_block_res_samples[-1].shape[2:] 139 | 140 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 141 | sample = upsample_block( 142 | hidden_states=sample, 143 | temb=emb, 144 | res_hidden_states_tuple=res_samples, 145 | encoder_hidden_states=encoder_hidden_states, 146 | cross_attention_kwargs=cross_attention_kwargs, 147 | upsample_size=upsample_size, 148 | attention_mask=attention_mask, 149 | ) 150 | else: 151 | sample = upsample_block( 152 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 153 | ) 154 | 155 | if i in up_ft_indices: 156 | up_ft[i] = sample.detach() 157 | 158 | output = {} 159 | output['up_ft'] = up_ft 160 | return output 161 | 162 | class OneStepSDPipeline(StableDiffusionPipeline): 163 | @torch.no_grad() 164 | def __call__( 165 | self, 166 | img_tensor, 167 | t, 168 | up_ft_indices, 169 | negative_prompt: Optional[Union[str, List[str]]] = None, 170 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 171 | prompt_embeds: Optional[torch.FloatTensor] = None, 172 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 173 | callback_steps: int = 1, 174 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 175 | ): 176 | 177 | device = self._execution_device 178 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor 179 | t = torch.tensor(t, dtype=torch.long, device=device) 180 | noise = torch.randn_like(latents).to(device) 181 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 182 | unet_output = self.unet(latents_noisy, 183 | t, 184 | up_ft_indices, 185 | encoder_hidden_states=prompt_embeds, 186 | cross_attention_kwargs=cross_attention_kwargs) 187 | return unet_output 188 | 189 | 190 | class SDFeaturizer: 191 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt=''): 192 | unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet") 193 | onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None) 194 | onestep_pipe.vae.decoder = None 195 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler") 196 | gc.collect() 197 | onestep_pipe = onestep_pipe.to("cuda") 198 | onestep_pipe.enable_attention_slicing() 199 | onestep_pipe.enable_xformers_memory_efficient_attention() 200 | null_prompt_embeds = onestep_pipe._encode_prompt( 201 | prompt=null_prompt, 202 | device='cuda', 203 | num_images_per_prompt=1, 204 | do_classifier_free_guidance=False) # [1, 77, dim] 205 | 206 | self.null_prompt_embeds = null_prompt_embeds 207 | self.null_prompt = null_prompt 208 | self.pipe = onestep_pipe 209 | 210 | @torch.no_grad() 211 | def forward(self, 212 | img_tensor, 213 | prompt='', 214 | t=261, 215 | up_ft_index=1, 216 | ensemble_size=8): 217 | ''' 218 | Args: 219 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] 220 | prompt: the prompt to use, a 
string 221 | t: the time step to use, should be an int in the range of [0, 1000] 222 | up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] 223 | ensemble_size: the number of repeated images used in the batch to extract features 224 | Return: 225 | unet_ft: a torch tensor in the shape of [1, c, h, w] 226 | ''' 227 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 228 | if prompt == self.null_prompt: 229 | prompt_embeds = self.null_prompt_embeds 230 | else: 231 | prompt_embeds = self.pipe._encode_prompt( 232 | prompt=prompt, 233 | device='cuda', 234 | num_images_per_prompt=1, 235 | do_classifier_free_guidance=False) # [1, 77, dim] 236 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1) 237 | unet_ft_all = self.pipe( 238 | img_tensor=img_tensor, 239 | t=t, 240 | up_ft_indices=[up_ft_index], 241 | prompt_embeds=prompt_embeds) 242 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 243 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 244 | return unet_ft 245 | 246 | 247 | class SDFeaturizer4Eval(SDFeaturizer): 248 | def __init__(self, sd_id='stabilityai/stable-diffusion-2-1', null_prompt='', cat_list=[]): 249 | super().__init__(sd_id, null_prompt) 250 | with torch.no_grad(): 251 | cat2prompt_embeds = {} 252 | for cat in cat_list: 253 | prompt = f"a photo of a {cat}" 254 | prompt_embeds = self.pipe._encode_prompt( 255 | prompt=prompt, 256 | device='cuda', 257 | num_images_per_prompt=1, 258 | do_classifier_free_guidance=False) # [1, 77, dim] 259 | cat2prompt_embeds[cat] = prompt_embeds 260 | self.cat2prompt_embeds = cat2prompt_embeds 261 | 262 | self.pipe.tokenizer = None 263 | self.pipe.text_encoder = None 264 | gc.collect() 265 | torch.cuda.empty_cache() 266 | 267 | 268 | @torch.no_grad() 269 | def forward(self, 270 | img, 271 | category=None, 272 | img_size=[768, 768], 273 | t=261, 274 | up_ft_index=1, 275 | ensemble_size=8): 276 | if img_size is not None: 277 | img = img.resize(img_size) 278 | img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2 279 | img_tensor = img_tensor.unsqueeze(0).repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 280 | if category in self.cat2prompt_embeds: 281 | prompt_embeds = self.cat2prompt_embeds[category] 282 | else: 283 | prompt_embeds = self.null_prompt_embeds 284 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1).cuda() 285 | unet_ft_all = self.pipe( 286 | img_tensor=img_tensor, 287 | t=t, 288 | up_ft_indices=[up_ft_index], 289 | prompt_embeds=prompt_embeds) 290 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 291 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 292 | return unet_ft -------------------------------------------------------------------------------- /src/models/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DINOFeaturizer: 4 | def __init__(self, dino_id='dino_vitb8'): 5 | self.model = torch.hub.load('facebookresearch/dino:main', dino_id).eval().cuda() 6 | 7 | @torch.no_grad() 8 | def forward(self, img_tensor, block_index): 9 | h = img_tensor.shape[2] // 8 10 | w = img_tensor.shape[3] // 8 11 | n = 12 - block_index 12 | out = self.model.get_intermediate_layers(img_tensor, n=n)[0][0, 1:, :] # hw, c 13 | dim = out.shape[1] 14 | out = out.transpose(0, 1).view(1, dim, h, w) 15 | return out -------------------------------------------------------------------------------- /src/utils/visualization.py: 
-------------------------------------------------------------------------------- 1 | import gc 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | class Demo: 9 | 10 | def __init__(self, imgs, ft, img_size): 11 | self.ft = ft # N+1, C, H, W 12 | self.imgs = imgs 13 | self.num_imgs = len(imgs) 14 | self.img_size = img_size 15 | 16 | def plot_img_pairs(self, fig_size=3, alpha=0.45, scatter_size=70): 17 | 18 | fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size)) 19 | 20 | plt.tight_layout() 21 | 22 | for i in range(self.num_imgs): 23 | axes[i].imshow(self.imgs[i]) 24 | axes[i].axis('off') 25 | if i == 0: 26 | axes[i].set_title('source image') 27 | else: 28 | axes[i].set_title('target image') 29 | 30 | num_channel = self.ft.size(1) 31 | 32 | def onclick(event): 33 | if event.inaxes == axes[0]: 34 | with torch.no_grad(): 35 | 36 | x, y = int(np.round(event.xdata)), int(np.round(event.ydata)) 37 | 38 | src_ft = self.ft[0].unsqueeze(0) 39 | src_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(src_ft) 40 | src_vec = src_ft[0, :, y, x].view(1, num_channel) # 1, C 41 | 42 | del src_ft 43 | gc.collect() 44 | torch.cuda.empty_cache() 45 | 46 | trg_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(self.ft[1:]) # N, C, H, W 47 | trg_vec = trg_ft.view(self.num_imgs - 1, num_channel, -1) # N, C, HW 48 | 49 | del trg_ft 50 | gc.collect() 51 | torch.cuda.empty_cache() 52 | 53 | src_vec = F.normalize(src_vec) # 1, C 54 | trg_vec = F.normalize(trg_vec) # N, C, HW 55 | cos_map = torch.matmul(src_vec, trg_vec).view(self.num_imgs - 1, self.img_size, self.img_size).cpu().numpy() # N, H, W 56 | 57 | axes[0].clear() 58 | axes[0].imshow(self.imgs[0]) 59 | axes[0].axis('off') 60 | axes[0].scatter(x, y, c='r', s=scatter_size) 61 | axes[0].set_title('source image') 62 | 63 | for i in range(1, self.num_imgs): 64 | max_yx = np.unravel_index(cos_map[i-1].argmax(), cos_map[i-1].shape) 65 | axes[i].clear() 66 | 67 | heatmap = cos_map[i-1] 68 | heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1] 69 | axes[i].imshow(self.imgs[i]) 70 | axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis') 71 | axes[i].axis('off') 72 | axes[i].scatter(max_yx[1].item(), max_yx[0].item(), c='r', s=scatter_size) 73 | axes[i].set_title('target image') 74 | 75 | del cos_map 76 | del heatmap 77 | gc.collect() 78 | 79 | fig.canvas.mpl_connect('button_press_event', onclick) 80 | plt.show() --------------------------------------------------------------------------------
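For reference, here is a minimal end-to-end sketch (not a file in the repository) of how the pieces above compose: `SDFeaturizer` from `src/models/dift_sd.py` extracts one DIFT feature map per image, and the `Demo` class from `src/utils/visualization.py` browses correspondences interactively. The second image path, the square 768 resolution, and the cat prompt are illustrative assumptions, not values mandated by the code.

```python
import torch
from PIL import Image
from torchvision.transforms import PILToTensor

from src.models.dift_sd import SDFeaturizer
from src.utils.visualization import Demo

# Hypothetical inputs: any two roughly corresponding RGB images work.
paths = ['./assets/cat.png', './assets/target_cat.png']  # second path is a placeholder
prompt = 'a photo of a cat'
img_size = 768  # Demo upsamples feature maps to an img_size x img_size grid

dift = SDFeaturizer()
imgs, fts = [], []
for path in paths:
    img = Image.open(path).convert('RGB').resize((img_size, img_size))
    imgs.append(img)
    img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2          # [c, h, w] in [-1, 1]
    fts.append(dift.forward(img_tensor, prompt=prompt, t=261,
                            up_ft_index=1, ensemble_size=8))     # [1, c, h', w']

demo = Demo(imgs, torch.cat(fts, dim=0), img_size)  # features stacked as [N, c, h', w']
demo.plot_img_pairs(fig_size=5)                     # click a source point to see its match
```

The first feature map acts as the source and the rest as targets, matching the `N+1, C, H, W` layout that `Demo` expects.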