├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── assets ├── 3d_generation_video.mp4 ├── car.png ├── dragon.png ├── gradio.png ├── kunkun.png └── rocket.png ├── cldm ├── ddim_hacked.py ├── hack.py ├── logger.py ├── model.py └── toss.py ├── config.py ├── datasets ├── __init__.py ├── base.py ├── colmap.py ├── colmap_utils.py ├── depth_utils.py ├── nerfpp.py ├── nsvf.py ├── objaverse.py ├── objaverse800k.py ├── objaverse_car.py ├── ray_utils.py └── rtmv.py ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── personalized.py │ └── util.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── embedding_manager.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── models └── toss_vae.yaml ├── opt.py ├── outputs ├── a dragon toy with fire on the back.png ├── a dragon toy with ice on the back.png ├── anya │ └── 0_95.png ├── backview of a dragon toy with fire on the back.png ├── backview of a dragon toy with ice on the back.png ├── dragon │ ├── a purple dragon with fire on the back.png │ ├── a dragon toy with fire on the back.png │ ├── a dragon with fire on its back.png │ ├── a dragon with ice on its back.png │ ├── a dragon with ice on the back.png │ └── a purple dragon with fire on the back.png └── minion │ ├── a dragon toy with fire on its back.png │ └── a minion with a rocket on the back.png ├── requirements.txt ├── share.py ├── streamlit_app.py ├── train.py ├── tutorial_dataset.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | gradio_res/ 163 | ckpt/toss.ckpt 164 | exp/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TOSS: High-quality Text-guided Novel View Synthesis from a Single Image (ICLR2024) 2 | 3 | #####

Yukai Shi, Jianan Wang, He Cao, Boshi Tang, Xianbiao Qi, Tianyu Yang, Yukun Huang, Shilong Liu, Lei Zhang, Heung-Yeung Shum 4 |

5 | 6 |

7 | 8 | Official implementation for *TOSS: High-quality Text-guided Novel View Synthesis from a Single Image*. 9 | 10 | **TOSS introduces text as high-level sementic information to constraint the NVS solution space for more controllable and more plausible results.** 11 | 12 | ## [Project Page](https://toss3d.github.io/) | [ArXiv](https://arxiv.org/abs/2310.10644) | [Weights](https://drive.google.com/drive/folders/15URQHblOVi_7YXZtgdFpjZAlKsoHylsq?usp=sharing) 13 | 14 | 15 | https://github.com/IDEA-Research/TOSS/assets/54578597/cd64c6c5-fef8-43c2-a223-7930ad6a71b7 16 | 17 | 18 | ## Install 19 | 20 | ### Create environment 21 | ```bash 22 | conda create -n toss python=3.9 23 | conda activate toss 24 | ``` 25 | 26 | ### Install packages 27 | ```bash 28 | pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118 29 | pip install -r requirements.txt 30 | git clone https://github.com/openai/CLIP.git 31 | pip install -e CLIP/ 32 | ``` 33 | ### Weights 34 | Download pretrain weights from [this link](https://drive.google.com/drive/folders/15URQHblOVi_7YXZtgdFpjZAlKsoHylsq?usp=sharing) to sub-directory ./ckpt 35 | 36 | ## Inference 37 | 38 | We suggest gradio for a visualized inference and test this demo on a single RTX3090. 39 | 40 | ``` 41 | python app.py 42 | ``` 43 | 44 | ![image](assets/gradio.png) 45 | 46 | 47 | ## Todo List 48 | - [x] Release inference code. 49 | - [x] Release pretrained models. 50 | - [ ] Upload 3D generation code. 51 | - [ ] Upload training code. 52 | 53 | ## Acknowledgement 54 | - [ControlNet](https://github.com/lllyasviel/ControlNet/) 55 | - [Zero123](https://github.com/cvlab-columbia/zero123/) 56 | - [threestudio](https://github.com/threestudio-project/threestudio) 57 | 58 | ## Citation 59 | 60 | ``` 61 | @article{shi2023toss, 62 | title={Toss: High-quality text-guided novel view synthesis from a single image}, 63 | author={Shi, Yukai and Wang, Jianan and Cao, He and Tang, Boshi and Qi, Xianbiao and Yang, Tianyu and Huang, Yukun and Liu, Shilong and Zhang, Lei and Shum, Heung-Yeung}, 64 | journal={arXiv preprint arXiv:2310.10644}, 65 | year={2023} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /assets/3d_generation_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/3d_generation_video.mp4 -------------------------------------------------------------------------------- /assets/car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/car.png -------------------------------------------------------------------------------- /assets/dragon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/dragon.png -------------------------------------------------------------------------------- /assets/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/gradio.png -------------------------------------------------------------------------------- /assets/kunkun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/kunkun.png -------------------------------------------------------------------------------- /assets/rocket.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/rocket.png -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | import time 10 | 11 | 12 | class ImageLogger(Callback): 13 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 14 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 15 | log_images_kwargs=None, epoch_frequency=1): 16 | super().__init__() 17 | self.rescale = rescale 18 | self.batch_freq = batch_frequency 19 | self.max_images = max_images 20 | if not increase_log_steps: 21 | self.log_steps = [self.batch_freq] 22 | self.clamp = clamp 23 | self.disabled = disabled 24 | self.log_on_batch_idx = log_on_batch_idx 25 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 26 | self.log_first_step = log_first_step 27 | self.epoch = 0 28 | self.epoch_frequency = epoch_frequency 29 | self.start_time = time.time() 30 | 31 | @rank_zero_only 32 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 33 | root = os.path.join(save_dir, "image_log", split) 34 | for k in images: 35 | grid = torchvision.utils.make_grid(images[k], nrow=4) 36 | if self.rescale: 37 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 38 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 39 | grid = grid.numpy() 40 | grid = (grid * 255).astype(np.uint8) 41 | filename = "epoch{:06}/{}_gs-{:06}_e-{:06}_b-{:06}.png".format(current_epoch,k, global_step, current_epoch, batch_idx) 42 | path = os.path.join(root, filename) 43 | os.makedirs(os.path.split(path)[0], exist_ok=True) 44 | Image.fromarray(grid).save(path) 45 | if current_epoch > self.epoch: 46 | self.epoch += 1 47 | 48 | def log_img(self, pl_module, batch, batch_idx, split="train"): 49 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step 50 | if (self.check_frequency(check_idx) and 51 | hasattr(pl_module, "log_images") and 52 | callable(pl_module.log_images) and 53 | self.max_images > 0): 54 | logger = type(pl_module.logger) 55 | 56 | is_train = pl_module.training 57 | if is_train: 58 | pl_module.eval() 59 | 60 | with torch.no_grad(): 61 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 62 | 63 | for k in images: 64 | N = min(images[k].shape[0], self.max_images) 65 | images[k] = images[k][:N] 66 | if isinstance(images[k], torch.Tensor): 67 | images[k] = images[k].detach().cpu() 68 | if self.clamp: 69 | images[k] = torch.clamp(images[k], -1., 1.) 70 | 71 | self.log_local(pl_module.logger.save_dir, split, images, 72 | pl_module.global_step, pl_module.current_epoch, batch_idx) 73 | 74 | if is_train: 75 | pl_module.train() 76 | 77 | def check_frequency(self, check_idx): 78 | return check_idx % self.batch_freq == 0 79 | 80 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 81 | if not self.disabled: 82 | self.log_img(pl_module, batch, batch_idx, split="train") 83 | 84 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 85 | if not self.disabled: 86 | self.log_img(pl_module, batch, batch_idx, split="val") 87 | 88 | def on_train_epoch_end(self, trainer, pl_module): 89 | end_time = time.time() 90 | self.time_counter(self.start_time, end_time, pl_module.global_step) 91 | 92 | def time_counter(self, begin_time, end_time, step): 93 | run_time = round(end_time - begin_time) 94 | # transfer time 95 | hour = run_time // 3600 96 | minute = (run_time - 3600 * hour) // 60 97 | second = run_time - 3600 * hour - 60 * minute 98 | print(f'Time cost until step {step}:{hour}h:{minute}miin:{second}s') -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | save_memory = False 2 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nsvf import NSVFDataset, NSVFDataset_v2, NSVFDataset_all 2 | from .colmap import ColmapDataset 3 | from .nerfpp import NeRFPPDataset 4 | from .objaverse import ObjaverseData 5 | 6 | 7 | dataset_dict = {'nsvf': NSVFDataset, 8 | 'nsvf_v2': NSVFDataset_v2, 9 | "nsvf_all": NSVFDataset_all, 10 | 'colmap': ColmapDataset, 11 | 'nerfpp': NeRFPPDataset, 12 | 'objaverse': ObjaverseData,} -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as T 2 | 3 | from torchvision import transforms 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | import torch 7 | import random 8 | import math 9 | import pdb 10 | from viz import save_image_tensor2cv2 11 | 12 | 13 | class BaseDataset(Dataset): 14 | """ 15 | Define length and sampling method 16 | """ 17 | def __init__(self, root_dir, split='train', text="a green chair", img_size=512, downsample=1.0): 18 | self.img_w, self.img_h = img_size, img_size 19 | self.root_dir = root_dir 20 | self.split = split 21 | self.downsample = downsample 22 | self.define_transforms() 23 | self.text = text 24 | 25 | def read_intrinsics(self): 26 | raise NotImplementedError 27 | 28 | def define_transforms(self): 29 | self.transform = transforms.Compose([T.ToTensor(), transforms.Resize(size=(self.img_w, self.img_h))]) 30 | # self.transform = T.ToTensor() 31 | 32 | def __len__(self): 33 | # if self.split.startswith('train'): 34 | # return 1000 35 | return len(self.poses) 36 | 37 | def cartesian_to_spherical(self, xyz): 38 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) 39 | xy = xyz[:,0]**2 + xyz[:,1]**2 40 | z = np.sqrt(xy + xyz[:,2]**2) 41 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down 42 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up 43 | azimuth = np.arctan2(xyz[:,1], xyz[:,0]) 44 | return np.array([theta, azimuth, z]) 45 | 46 | def get_T(self, target_RT, cond_RT): 47 | R, T = target_RT[:3, :3], target_RT[:, -1] 48 | T_target = -R.T @ T 49 | 50 | R, T = cond_RT[:3, :3], cond_RT[:, -1] 51 | T_cond = -R.T @ T 52 | 53 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) 54 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) 55 | 56 | d_theta = theta_target - theta_cond 57 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 58 | d_z = z_target - z_cond 59 | 60 | d_T = torch.tensor([d_theta.item(), d_azimuth.item(), d_z.item()]) 61 | return d_T 62 | 63 | def get_T_w2c(self, target_RT, cond_RT): 64 | T_target = target_RT[:, -1] 65 | T_cond = cond_RT[:, -1] 66 | 67 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) 68 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) 69 | 70 | d_theta = theta_target - theta_cond 71 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 72 | d_z = z_target - z_cond 73 | # d_z = (z_target - z_cond) / np.max(z_cond) * 2 74 | # pdb.set_trace() 75 | 76 | d_T = torch.tensor([d_theta.item(), d_azimuth.item(), d_z.item()]) 77 | return d_T 78 | 79 | def __getitem__(self, idx): 80 | # camera pose and img 81 | poses = self.poses[idx] 82 | img = self.imgs[idx] 83 | prompt = self.text 84 | 85 | # condition 86 | # idx_cond = idx % 5 87 | # idx_cond = 1 88 | idx_cond = random.randint(0, 4) 89 | # idx_cond = random.randint(0, len(self.poses)-1) 90 | poses_cond = self.poses[idx_cond] 91 | img_cond = self.imgs[idx_cond] 92 | 93 | # if len(self.imgs)>0: # if ground truth available 94 | # img_rgb = imgs[:, :,:3] 95 | # img_rgb_cond = imgs_cond[:, :,:3] 96 | # if imgs.shape[-1] == 4: # HDR-NeRF data 97 | # sample['exposure'] = rays[0, 3] # same exposure for all rays 98 | 99 | # Normalize target images to [-1, 1]. 100 | target = (img.float() / 127.5) - 1.0 101 | # Normalize source images to [0, 1]. 102 | condition = img_cond.float() / 255.0 103 | # save_image_tensor2cv2(condition.permute(2,0,1), "./viz_fig/input_cond.png") 104 | 105 | # get delta pose 106 | # delta_pose = self.get_T(target_RT=poses, cond_RT=poses_cond) 107 | delta_pose = self.get_T_w2c(target_RT=poses, cond_RT=poses_cond) 108 | 109 | return dict(jpg=target, txt=prompt, hint=condition, delta_pose=delta_pose) -------------------------------------------------------------------------------- /datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import glob 5 | from PIL import Image 6 | from einops import rearrange 7 | from tqdm import tqdm 8 | 9 | from .ray_utils import * 10 | from .colmap_utils import \ 11 | read_cameras_binary, read_images_binary, read_points3d_binary 12 | 13 | from .base import BaseDataset 14 | 15 | 16 | class ColmapDataset(BaseDataset): 17 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 18 | super().__init__(root_dir, split, downsample) 19 | 20 | self.read_intrinsics() 21 | 22 | if kwargs.get('read_meta', True): 23 | self.read_meta(split, **kwargs) 24 | 25 | def read_intrinsics(self): 26 | # Step 1: read and scale intrinsics (same for all images) 27 | camdata = read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin')) 28 | h = int(camdata[1].height*self.downsample) 29 | w = int(camdata[1].width*self.downsample) 30 | self.img_wh = (w, h) 31 | 32 | if camdata[1].model == 'SIMPLE_RADIAL': 33 | fx = fy = camdata[1].params[0]*self.downsample 34 | cx = camdata[1].params[1]*self.downsample 35 | cy = camdata[1].params[2]*self.downsample 36 | elif camdata[1].model in ['PINHOLE', 'OPENCV']: 37 | fx = camdata[1].params[0]*self.downsample 38 | fy = camdata[1].params[1]*self.downsample 39 | cx = camdata[1].params[2]*self.downsample 40 | cy = camdata[1].params[3]*self.downsample 41 | else: 42 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") 43 | self.K = torch.FloatTensor([[fx, 0, cx], 44 | [0, fy, cy], 45 | [0, 0, 1]]) 46 | self.directions = get_ray_directions(h, w, self.K) 47 | 48 | def read_meta(self, split, **kwargs): 49 | # Step 2: correct poses 50 | # read extrinsics (of successfully reconstructed images) 51 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin')) 52 | img_names = [imdata[k].name for k in imdata] 53 | perm = np.argsort(img_names) 54 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data 55 | folder = f'images_{int(1/self.downsample)}' 56 | else: 57 | folder = 'images' 58 | # read successfully reconstructed images and ignore others 59 | img_paths = [os.path.join(self.root_dir, folder, name) 60 | for name in sorted(img_names)] 61 | w2c_mats = [] 62 | bottom = np.array([[0, 0, 0, 1.]]) 63 | for k in imdata: 64 | im = imdata[k] 65 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1) 66 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 67 | w2c_mats = np.stack(w2c_mats, 0) 68 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices 69 | 70 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin')) 71 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3) 72 | 73 | self.poses, self.pts3d = center_poses(poses, pts3d) 74 | 75 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min() 76 | self.poses[..., 3] /= scale 77 | self.pts3d /= scale 78 | 79 | self.rays = [] 80 | if split == 'test_traj': # use precomputed test poses 81 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean()) 82 | self.poses = torch.FloatTensor(self.poses) 83 | return 84 | 85 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data 86 | if 'syndata' in self.root_dir: # synthetic 87 | # first 17 are test, last 18 are train 88 | self.unit_exposure_rgb = 0.73 89 | if split=='train': 90 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 91 | f'train/*[024].png'))) 92 | self.poses = np.repeat(self.poses[-18:], 3, 0) 93 | elif split=='test': 94 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 95 | f'test/*[13].png'))) 96 | self.poses = np.repeat(self.poses[:17], 2, 0) 97 | else: 98 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 99 | else: # real 100 | self.unit_exposure_rgb = 0.5 101 | # even numbers are train, odd numbers are test 102 | if split=='train': 103 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 104 | f'input_images/*0.jpg')))[::2] 105 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 106 | f'input_images/*2.jpg')))[::2] 107 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 108 | f'input_images/*4.jpg')))[::2] 109 | self.poses = np.tile(self.poses[::2], (3, 1, 1)) 110 | elif split=='test': 111 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 112 | f'input_images/*1.jpg')))[1::2] 113 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 114 | f'input_images/*3.jpg')))[1::2] 115 | self.poses = np.tile(self.poses[1::2], (2, 1, 1)) 116 | else: 117 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 118 | else: 119 | # use every 8th image as test set 120 | if split=='train': 121 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0] 122 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0]) 123 | elif split=='test': 124 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0] 125 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0]) 126 | 127 | print(f'Loading {len(img_paths)} {split} images ...') 128 | for img_path in tqdm(img_paths): 129 | buf = [] # buffer for ray attributes: rgb, etc 130 | 131 | img = Image.open(img_path).convert('RGB').resize(self.img_wh, Image.LANCZOS) 132 | img = rearrange(self.transform(img), 'c h w -> (h w) c') 133 | buf += [img] 134 | 135 | if 'HDR-NeRF' in self.root_dir: # get exposure 136 | folder = self.root_dir.split('/') 137 | scene = folder[-1] if folder[-1] != '' else folder[-2] 138 | if scene in ['bathroom', 'bear', 'chair', 'desk']: 139 | e_dict = {e: 1/8*4**e for e in range(5)} 140 | elif scene in ['diningroom', 'dog']: 141 | e_dict = {e: 1/16*4**e for e in range(5)} 142 | elif scene in ['sofa']: 143 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16} 144 | elif scene in ['sponza']: 145 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32} 146 | elif scene in ['box']: 147 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05} 148 | elif scene in ['computer']: 149 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60} 150 | elif scene in ['flower']: 151 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45} 152 | elif scene in ['luckycat']: 153 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125} 154 | e = int(img_path.split('.')[0][-1]) 155 | buf += [e_dict[e]*torch.ones_like(img[:, :1])] 156 | 157 | self.rays += [torch.cat(buf, 1)] 158 | 159 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 160 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) -------------------------------------------------------------------------------- /datasets/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | 5 | def read_pfm(path): 6 | """Read pfm file. 7 | 8 | Args: 9 | path (str): path to file 10 | 11 | Returns: 12 | tuple: (data, scale) 13 | """ 14 | with open(path, "rb") as file: 15 | 16 | color = None 17 | width = None 18 | height = None 19 | scale = None 20 | endian = None 21 | 22 | header = file.readline().rstrip() 23 | if header.decode("ascii") == "PF": 24 | color = True 25 | elif header.decode("ascii") == "Pf": 26 | color = False 27 | else: 28 | raise Exception("Not a PFM file: " + path) 29 | 30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 31 | if dim_match: 32 | width, height = list(map(int, dim_match.groups())) 33 | else: 34 | raise Exception("Malformed PFM header.") 35 | 36 | scale = float(file.readline().decode("ascii").rstrip()) 37 | if scale < 0: 38 | # little-endian 39 | endian = "<" 40 | scale = -scale 41 | else: 42 | # big-endian 43 | endian = ">" 44 | 45 | data = np.fromfile(file, endian + "f") 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | 51 | return data, scale -------------------------------------------------------------------------------- /datasets/nerfpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from einops import rearrange 7 | from tqdm import tqdm 8 | 9 | from .ray_utils import get_ray_directions 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class NeRFPPDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0], 25 | dtype=np.float32).reshape(4, 4)[:3, :3] 26 | K[:2] *= self.downsample 27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size 28 | w, h = int(w*self.downsample), int(h*self.downsample) 29 | self.K = torch.FloatTensor(K) 30 | self.directions = get_ray_directions(h, w, self.K) 31 | self.img_wh = (w, h) 32 | 33 | def read_meta(self, split): 34 | self.rays = [] 35 | self.poses = [] 36 | 37 | if split == 'test_traj': 38 | poses_path = \ 39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt'))) 40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path] 41 | else: 42 | if split=='trainval': 43 | imgs = sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\ 44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*'))) 45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\ 46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt'))) 47 | else: 48 | imgs = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*'))) 49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt'))) 50 | 51 | print(f'Loading {len(imgs)} {split} images ...') 52 | for img, pose in tqdm(zip(imgs, poses)): 53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]] 54 | 55 | img = Image.open(img).convert('RGB').resize(self.img_wh, Image.LANCZOS) 56 | img = rearrange(self.transform(img), 'c h w -> (h w) c') 57 | 58 | self.rays += [img] 59 | 60 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 61 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 62 | -------------------------------------------------------------------------------- /datasets/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from kornia import create_meshgrid 4 | from einops import rearrange 5 | 6 | 7 | @torch.cuda.amp.autocast(dtype=torch.float32) 8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True): 9 | """ 10 | Get ray directions for all pixels in camera coordinate [right down front]. 11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 12 | ray-tracing-generating-camera-rays/standard-coordinate-systems 13 | 14 | Inputs: 15 | H, W: image height and width 16 | K: (3, 3) camera intrinsics 17 | random: whether the ray passes randomly inside the pixel 18 | return_uv: whether to return uv image coordinates 19 | 20 | Outputs: (shape depends on @flatten) 21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate 22 | uv: (H, W, 2) or (H*W, 2) image coordinates 23 | """ 24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2) 25 | u, v = grid.unbind(-1) 26 | 27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 28 | if random: 29 | directions = \ 30 | torch.stack([(u-cx+torch.rand_like(u))/fx, 31 | (v-cy+torch.rand_like(v))/fy, 32 | torch.ones_like(u)], -1) 33 | else: # pass by the center 34 | directions = \ 35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1) 36 | if flatten: 37 | directions = directions.reshape(-1, 3) 38 | grid = grid.reshape(-1, 2) 39 | 40 | if return_uv: 41 | return directions, grid 42 | return directions 43 | 44 | 45 | @torch.cuda.amp.autocast(dtype=torch.float32) 46 | def get_rays(directions, c2w): 47 | """ 48 | Get ray origin and directions in world coordinate for all pixels in one image. 49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 50 | ray-tracing-generating-camera-rays/standard-coordinate-systems 51 | 52 | Inputs: 53 | directions: (N, 3) ray directions in camera coordinate 54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate 55 | 56 | Outputs: 57 | rays_o: (N, 3), the origin of the rays in world coordinate 58 | rays_d: (N, 3), the direction of the rays in world coordinate 59 | """ 60 | if c2w.ndim==2: 61 | # Rotate ray directions from camera coordinate to the world coordinate 62 | rays_d = directions @ c2w[:, :3].T 63 | else: 64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ c2w[..., :3].mT 65 | rays_d = rearrange(rays_d, 'n 1 c -> n c') 66 | # The origin of all rays is the camera origin in world coordinate 67 | rays_o = c2w[..., 3].expand_as(rays_d) 68 | 69 | return rays_o, rays_d 70 | 71 | 72 | @torch.cuda.amp.autocast(dtype=torch.float32) 73 | def axisangle_to_R(v): 74 | """ 75 | Convert an axis-angle vector to rotation matrix 76 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47 77 | 78 | Inputs: 79 | v: (B, 3) 80 | 81 | Outputs: 82 | R: (B, 3, 3) 83 | """ 84 | zero = torch.zeros_like(v[:, :1]) # (B, 1) 85 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3) 86 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1) 87 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1) 88 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3) 89 | 90 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1') 91 | eye = torch.eye(3, device=v.device) 92 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \ 93 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v) 94 | return R 95 | 96 | 97 | def normalize(v): 98 | """Normalize a vector.""" 99 | return v/np.linalg.norm(v) 100 | 101 | 102 | def average_poses(poses, pts3d=None): 103 | """ 104 | Calculate the average pose, which is then used to center all poses 105 | using @center_poses. Its computation is as follows: 106 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras). 107 | 2. Compute the z axis: the normalized average z axis. 108 | 3. Compute axis y': the average y axis. 109 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 110 | 5. Compute the y axis: z cross product x. 111 | 112 | Note that at step 3, we cannot directly use y' as y axis since it's 113 | not necessarily orthogonal to z axis. We need to pass from x to y. 114 | Inputs: 115 | poses: (N_images, 3, 4) 116 | pts3d: (N, 3) 117 | 118 | Outputs: 119 | pose_avg: (3, 4) the average pose 120 | """ 121 | # 1. Compute the center 122 | if pts3d is not None: 123 | center = pts3d.mean(0) 124 | else: 125 | center = poses[..., 3].mean(0) 126 | 127 | # 2. Compute the z axis 128 | z = normalize(poses[..., 2].mean(0)) # (3) 129 | 130 | # 3. Compute axis y' (no need to normalize as it's not the final output) 131 | y_ = poses[..., 1].mean(0) # (3) 132 | 133 | # 4. Compute the x axis 134 | x = normalize(np.cross(y_, z)) # (3) 135 | 136 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 137 | y = np.cross(z, x) # (3) 138 | 139 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 140 | 141 | return pose_avg 142 | 143 | 144 | def center_poses(poses, pts3d=None): 145 | """ 146 | See https://github.com/bmild/nerf/issues/34 147 | Inputs: 148 | poses: (N_images, 3, 4) 149 | pts3d: (N, 3) reconstructed point cloud 150 | 151 | Outputs: 152 | poses_centered: (N_images, 3, 4) the centered poses 153 | pts3d_centered: (N, 3) centered point cloud 154 | """ 155 | 156 | pose_avg = average_poses(poses, pts3d) # (3, 4) 157 | pose_avg_homo = np.eye(4) 158 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 159 | # by simply adding 0, 0, 0, 1 as the last row 160 | pose_avg_inv = np.linalg.inv(pose_avg_homo) 161 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 162 | poses_homo = \ 163 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 164 | 165 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4) 166 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 167 | 168 | if pts3d is not None: 169 | pts3d_centered = pts3d @ pose_avg_inv[:, :3].T + pose_avg_inv[:, 3:].T 170 | return poses_centered, pts3d_centered 171 | 172 | return poses_centered 173 | 174 | def create_spheric_poses(radius, mean_h, n_poses=120): 175 | """ 176 | Create circular poses around z axis. 177 | Inputs: 178 | radius: the (negative) height and the radius of the circle. 179 | mean_h: mean camera height 180 | Outputs: 181 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 182 | """ 183 | def spheric_pose(theta, phi, radius): 184 | trans_t = lambda t : np.array([ 185 | [1,0,0,0], 186 | [0,1,0,2*mean_h], 187 | [0,0,1,-t] 188 | ]) 189 | 190 | rot_phi = lambda phi : np.array([ 191 | [1,0,0], 192 | [0,np.cos(phi),-np.sin(phi)], 193 | [0,np.sin(phi), np.cos(phi)] 194 | ]) 195 | 196 | rot_theta = lambda th : np.array([ 197 | [np.cos(th),0,-np.sin(th)], 198 | [0,1,0], 199 | [np.sin(th),0, np.cos(th)] 200 | ]) 201 | 202 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 203 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w 204 | return c2w 205 | 206 | spheric_poses = [] 207 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: 208 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)] 209 | return np.stack(spheric_poses, 0) -------------------------------------------------------------------------------- /datasets/rtmv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import json 4 | #### Under construction. Don't use now 5 | 6 | import numpy as np 7 | import os 8 | import imageio 9 | import cv2 10 | from einops import rearrange 11 | from tqdm import tqdm 12 | 13 | from .ray_utils import get_ray_directions 14 | 15 | from .base import BaseDataset 16 | 17 | 18 | def srgb_to_linear(img): 19 | limit = 0.04045 20 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) 21 | 22 | 23 | def linear_to_srgb(img): 24 | limit = 0.0031308 25 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 26 | 27 | 28 | class RTMVDataset(BaseDataset): 29 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 30 | super().__init__(root_dir, split, downsample) 31 | 32 | with open(os.path.join(self.root_dir, '00000.json'), 'r') as f: 33 | meta = json.load(f)['camera_data'] 34 | self.shift = np.array(meta['scene_center_3d_box']) 35 | self.scale = (np.array(meta['scene_max_3d_box'])- 36 | np.array(meta['scene_min_3d_box'])).max()/2 * 1.05 # enlarge a little 37 | 38 | fx = meta['intrinsics']['fx'] * downsample 39 | fy = meta['intrinsics']['fy'] * downsample 40 | cx = meta['intrinsics']['cx'] * downsample 41 | cy = meta['intrinsics']['cy'] * downsample 42 | w = int(meta['width']*downsample) 43 | h = int(meta['height']*downsample) 44 | K = np.float32([[fx, 0, cx], 45 | [0, fy, cy], 46 | [0, 0, 1]]) 47 | self.K = torch.FloatTensor(K) 48 | self.directions = get_ray_directions(h, w, self.K) 49 | self.img_wh = (w, h) 50 | 51 | self.read_meta(split) 52 | 53 | def read_meta(self, split): 54 | self.rays = [] 55 | self.poses = [] 56 | 57 | if split == 'train': start_idx, end_idx = 0, 100 58 | elif split == 'trainval': start_idx, end_idx = 0, 105 59 | elif split == 'test': start_idx, end_idx = 105, 150 60 | else: raise ValueError(f'{split} split not recognized!') 61 | imgs = sorted(glob.glob(os.path.join(self.root_dir, '*[0-9].exr')))[start_idx:end_idx] 62 | poses = sorted(glob.glob(os.path.join(self.root_dir, '*.json')))[start_idx:end_idx] 63 | 64 | print(f'Loading {len(imgs)} {split} images ...') 65 | for img, pose in tqdm(zip(imgs, poses)): 66 | with open(pose, 'r') as f: 67 | m = json.load(f)['camera_data'] 68 | c2w = np.zeros((3, 4), dtype=np.float32) 69 | c2w[:, :3] = -np.array(m['cam2world'])[:3, :3].T 70 | c2w[:, 3] = np.array(m['location_world'])-self.shift 71 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5] 72 | self.poses += [c2w] 73 | 74 | img = imageio.imread(img)[..., :3] 75 | img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_LANCZOS4) 76 | # img = np.clip(linear_to_srgb(img), 0, 1) 77 | img = rearrange(torch.FloatTensor(img), 'h w c -> (h w) c') 78 | 79 | self.rays += [img] 80 | 81 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 82 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 83 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/personalized.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | 10 | import random 11 | 12 | imagenet_templates_smallest = [ 13 | 'a photo of a {}', 14 | ] 15 | 16 | imagenet_templates_small = [ 17 | 'a photo of a {}', 18 | 'a rendering of a {}', 19 | 'a cropped photo of the {}', 20 | 'the photo of a {}', 21 | 'a photo of a clean {}', 22 | 'a photo of a dirty {}', 23 | 'a dark photo of the {}', 24 | 'a photo of my {}', 25 | 'a photo of the cool {}', 26 | 'a close-up photo of a {}', 27 | 'a bright photo of the {}', 28 | 'a cropped photo of a {}', 29 | 'a photo of the {}', 30 | 'a good photo of the {}', 31 | 'a photo of one {}', 32 | 'a close-up photo of the {}', 33 | 'a rendition of the {}', 34 | 'a photo of the clean {}', 35 | 'a rendition of a {}', 36 | 'a photo of a nice {}', 37 | 'a good photo of a {}', 38 | 'a photo of the nice {}', 39 | 'a photo of the small {}', 40 | 'a photo of the weird {}', 41 | 'a photo of the large {}', 42 | 'a photo of a cool {}', 43 | 'a photo of a small {}', 44 | 'an illustration of a {}', 45 | 'a rendering of a {}', 46 | 'a cropped photo of the {}', 47 | 'the photo of a {}', 48 | 'an illustration of a clean {}', 49 | 'an illustration of a dirty {}', 50 | 'a dark photo of the {}', 51 | 'an illustration of my {}', 52 | 'an illustration of the cool {}', 53 | 'a close-up photo of a {}', 54 | 'a bright photo of the {}', 55 | 'a cropped photo of a {}', 56 | 'an illustration of the {}', 57 | 'a good photo of the {}', 58 | 'an illustration of one {}', 59 | 'a close-up photo of the {}', 60 | 'a rendition of the {}', 61 | 'an illustration of the clean {}', 62 | 'a rendition of a {}', 63 | 'an illustration of a nice {}', 64 | 'a good photo of a {}', 65 | 'an illustration of the nice {}', 66 | 'an illustration of the small {}', 67 | 'an illustration of the weird {}', 68 | 'an illustration of the large {}', 69 | 'an illustration of a cool {}', 70 | 'an illustration of a small {}', 71 | 'a depiction of a {}', 72 | 'a rendering of a {}', 73 | 'a cropped photo of the {}', 74 | 'the photo of a {}', 75 | 'a depiction of a clean {}', 76 | 'a depiction of a dirty {}', 77 | 'a dark photo of the {}', 78 | 'a depiction of my {}', 79 | 'a depiction of the cool {}', 80 | 'a close-up photo of a {}', 81 | 'a bright photo of the {}', 82 | 'a cropped photo of a {}', 83 | 'a depiction of the {}', 84 | 'a good photo of the {}', 85 | 'a depiction of one {}', 86 | 'a close-up photo of the {}', 87 | 'a rendition of the {}', 88 | 'a depiction of the clean {}', 89 | 'a rendition of a {}', 90 | 'a depiction of a nice {}', 91 | 'a good photo of a {}', 92 | 'a depiction of the nice {}', 93 | 'a depiction of the small {}', 94 | 'a depiction of the weird {}', 95 | 'a depiction of the large {}', 96 | 'a depiction of a cool {}', 97 | 'a depiction of a small {}', 98 | ] 99 | 100 | imagenet_dual_templates_small = [ 101 | 'a photo of a {} with {}', 102 | 'a rendering of a {} with {}', 103 | 'a cropped photo of the {} with {}', 104 | 'the photo of a {} with {}', 105 | 'a photo of a clean {} with {}', 106 | 'a photo of a dirty {} with {}', 107 | 'a dark photo of the {} with {}', 108 | 'a photo of my {} with {}', 109 | 'a photo of the cool {} with {}', 110 | 'a close-up photo of a {} with {}', 111 | 'a bright photo of the {} with {}', 112 | 'a cropped photo of a {} with {}', 113 | 'a photo of the {} with {}', 114 | 'a good photo of the {} with {}', 115 | 'a photo of one {} with {}', 116 | 'a close-up photo of the {} with {}', 117 | 'a rendition of the {} with {}', 118 | 'a photo of the clean {} with {}', 119 | 'a rendition of a {} with {}', 120 | 'a photo of a nice {} with {}', 121 | 'a good photo of a {} with {}', 122 | 'a photo of the nice {} with {}', 123 | 'a photo of the small {} with {}', 124 | 'a photo of the weird {} with {}', 125 | 'a photo of the large {} with {}', 126 | 'a photo of a cool {} with {}', 127 | 'a photo of a small {} with {}', 128 | ] 129 | 130 | per_img_token_list = [ 131 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', 132 | ] 133 | 134 | class PersonalizedBase(Dataset): 135 | def __init__(self, 136 | data_root, 137 | size=None, 138 | repeats=100, 139 | interpolation="bicubic", 140 | flip_p=0.5, 141 | set="train", 142 | placeholder_token="*", 143 | per_image_tokens=False, 144 | center_crop=False, 145 | mixing_prob=0.25, 146 | coarse_class_text=None, 147 | ): 148 | 149 | self.data_root = data_root 150 | 151 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 152 | 153 | # self._length = len(self.image_paths) 154 | self.num_images = len(self.image_paths) 155 | self._length = self.num_images 156 | 157 | self.placeholder_token = placeholder_token 158 | 159 | self.per_image_tokens = per_image_tokens 160 | self.center_crop = center_crop 161 | self.mixing_prob = mixing_prob 162 | 163 | self.coarse_class_text = coarse_class_text 164 | 165 | if per_image_tokens: 166 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 167 | 168 | if set == "train": 169 | self._length = self.num_images * repeats 170 | 171 | self.size = size 172 | self.interpolation = { 173 | "bilinear": Image.Resampling.BILINEAR, 174 | "bicubic": Image.Resampling.BICUBIC, 175 | "lanczos": Image.Resampling.LANCZOS, 176 | }[interpolation] 177 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 178 | 179 | def __len__(self): 180 | return self._length 181 | 182 | 183 | def __getitem__(self, i): 184 | example = {} 185 | img = plt.imread(self.image_paths[i % self.num_images]) 186 | 187 | placeholder_string = self.placeholder_token 188 | if self.coarse_class_text: 189 | placeholder_string = f"{self.coarse_class_text} {placeholder_string}" 190 | 191 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob: 192 | text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) 193 | else: 194 | text = random.choice(imagenet_templates_small).format(placeholder_string) 195 | 196 | # default to score-sde preprocessing 197 | # alpha channel affects 198 | if img.shape[-1] == 4: 199 | img = img[:,:,:3]*img[:,:, -1:]+(1-img[:,:, -1:]) 200 | 201 | if self.center_crop: 202 | crop = min(img.shape[0], img.shape[1]) 203 | h, w, = img.shape[0], img.shape[1] 204 | img = img[(h - crop) // 2:(h + crop) // 2, 205 | (w - crop) // 2:(w + crop) // 2] 206 | image = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB") 207 | if self.size is not None: 208 | image = image.resize((self.size, self.size), resample=self.interpolation) 209 | 210 | image = self.flip(image) 211 | image = np.array(image).astype(np.uint8) 212 | example["jpg"] = (image / 127.5 - 1.0).astype(np.float32) 213 | example["txt"] = text 214 | example["hint"] = (image / 255.0).astype(np.float32) 215 | example["delta_pose"] = torch.tensor([0., 0., 0.]) 216 | return example -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. - mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | import pdb 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = (torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2) 24 | 25 | elif schedule == "cosine": 26 | timesteps = ( 27 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 28 | ) 29 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 30 | alphas = torch.cos(alphas).pow(2) 31 | alphas = alphas / alphas[0] 32 | betas = 1 - alphas[1:] / alphas[:-1] 33 | betas = np.clip(betas, a_min=0, a_max=0.999) 34 | 35 | elif schedule == "sqrt_linear": 36 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 37 | elif schedule == "sqrt": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 39 | else: 40 | raise ValueError(f"schedule '{schedule}' unknown.") 41 | return betas.numpy() 42 | 43 | 44 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 45 | if ddim_discr_method == 'uniform': 46 | c = num_ddpm_timesteps // num_ddim_timesteps 47 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 48 | elif ddim_discr_method == 'quad': 49 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 50 | else: 51 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 52 | 53 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 54 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 55 | steps_out = ddim_timesteps + 1 56 | if verbose: 57 | print(f'Selected timesteps for ddim sampler: {steps_out}') 58 | return steps_out 59 | 60 | 61 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 62 | # select alphas for computing the variance schedule 63 | alphas = alphacums[ddim_timesteps] 64 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 65 | 66 | # according the the formula provided in https://arxiv.org/abs/2010.02502 67 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 68 | if verbose: 69 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 70 | print(f'For the chosen value of eta, which is {eta}, ' 71 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 72 | return sigmas, alphas, alphas_prev 73 | 74 | 75 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 76 | """ 77 | Create a beta schedule that discretizes the given alpha_t_bar function, 78 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 79 | :param num_diffusion_timesteps: the number of betas to produce. 80 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 81 | produces the cumulative product of (1-beta) up to that 82 | part of the diffusion process. 83 | :param max_beta: the maximum beta to use; use values lower than 1 to 84 | prevent singularities. 85 | """ 86 | betas = [] 87 | for i in range(num_diffusion_timesteps): 88 | t1 = i / num_diffusion_timesteps 89 | t2 = (i + 1) / num_diffusion_timesteps 90 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 91 | return np.array(betas) 92 | 93 | 94 | def extract_into_tensor(a, t, x_shape): 95 | b, *_ = t.shape 96 | out = a.gather(-1, t) 97 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 98 | 99 | 100 | def checkpoint(func, inputs, params, flag): 101 | """ 102 | Evaluate a function without caching intermediate activations, allowing for 103 | reduced memory at the expense of extra compute in the backward pass. 104 | :param func: the function to evaluate. 105 | :param inputs: the argument sequence to pass to `func`. 106 | :param params: a sequence of parameters `func` depends on but does not 107 | explicitly take as arguments. 108 | :param flag: if False, disable gradient checkpointing. 109 | """ 110 | if flag: 111 | args = tuple(inputs) + tuple(params) 112 | return CheckpointFunction.apply(func, len(inputs), *args) 113 | else: 114 | return func(*inputs) 115 | 116 | 117 | class CheckpointFunction(torch.autograd.Function): 118 | @staticmethod 119 | def forward(ctx, run_function, length, *args): 120 | ctx.run_function = run_function 121 | ctx.input_tensors = list(args[:length]) 122 | ctx.input_params = list(args[length:]) 123 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 124 | "dtype": torch.get_autocast_gpu_dtype(), 125 | "cache_enabled": torch.is_autocast_cache_enabled()} 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(), \ 134 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 135 | # Fixes a bug where the first op in run_function modifies the 136 | # Tensor storage in place, which is not allowed for detach()'d 137 | # Tensors. 138 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 139 | output_tensors = ctx.run_function(*shallow_copies) 140 | input_grads = torch.autograd.grad( 141 | output_tensors, 142 | ctx.input_tensors + ctx.input_params, 143 | output_grads, 144 | allow_unused=True, 145 | ) 146 | del ctx.input_tensors 147 | del ctx.input_params 148 | del output_tensors 149 | return (None, None) + input_grads 150 | 151 | 152 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 153 | """ 154 | Create sinusoidal timestep embeddings. 155 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 156 | These may be fractional. 157 | :param dim: the dimension of the output. 158 | :param max_period: controls the minimum frequency of the embeddings. 159 | :return: an [N x dim] Tensor of positional embeddings. 160 | """ 161 | if not repeat_only: 162 | half = dim // 2 163 | freqs = torch.exp( 164 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 165 | ).to(device=timesteps.device) 166 | 167 | # fp16 168 | # freqs = torch.exp( 169 | # -math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device, dtype=torch.float16) / half 170 | # ) 171 | 172 | args = timesteps[:, None].float() * freqs[None] 173 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 174 | if dim % 2: 175 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 176 | else: 177 | embedding = repeat(timesteps, 'b -> b d', d=dim) 178 | return embedding 179 | 180 | 181 | def zero_module(module): 182 | """ 183 | Zero out the parameters of a module and return it. 184 | """ 185 | for p in module.parameters(): 186 | p.detach().zero_() 187 | return module 188 | 189 | 190 | def scale_module(module, scale): 191 | """ 192 | Scale the parameters of a module and return it. 193 | """ 194 | for p in module.parameters(): 195 | p.detach().mul_(scale) 196 | return module 197 | 198 | 199 | def mean_flat(tensor): 200 | """ 201 | Take the mean over all non-batch dimensions. 202 | """ 203 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 204 | 205 | 206 | def normalization(channels): 207 | """ 208 | Make a standard normalization layer. 209 | :param channels: number of input channels. 210 | :return: an nn.Module for normalization. 211 | """ 212 | return GroupNorm32(32, channels) 213 | 214 | 215 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 216 | class SiLU(nn.Module): 217 | def forward(self, x): 218 | return x * torch.sigmoid(x) 219 | 220 | 221 | class GroupNorm32(nn.GroupNorm): 222 | def forward(self, x): 223 | return super().forward(x.float()).type(x.dtype) 224 | 225 | def conv_nd(dims, *args, **kwargs): 226 | """ 227 | Create a 1D, 2D, or 3D convolution module. 228 | """ 229 | if dims == 1: 230 | return nn.Conv1d(*args, **kwargs) 231 | elif dims == 2: 232 | return nn.Conv2d(*args, **kwargs) 233 | elif dims == 3: 234 | return nn.Conv3d(*args, **kwargs) 235 | raise ValueError(f"unsupported dimensions: {dims}") 236 | 237 | 238 | def linear(*args, **kwargs): 239 | """ 240 | Create a linear module. 241 | """ 242 | return nn.Linear(*args, **kwargs) 243 | 244 | 245 | def avg_pool_nd(dims, *args, **kwargs): 246 | """ 247 | Create a 1D, 2D, or 3D average pooling module. 248 | """ 249 | if dims == 1: 250 | return nn.AvgPool1d(*args, **kwargs) 251 | elif dims == 2: 252 | return nn.AvgPool2d(*args, **kwargs) 253 | elif dims == 3: 254 | return nn.AvgPool3d(*args, **kwargs) 255 | raise ValueError(f"unsupported dimensions: {dims}") 256 | 257 | 258 | class HybridConditioner(nn.Module): 259 | 260 | def __init__(self, c_concat_config, c_crossattn_config): 261 | super().__init__() 262 | self.concat_conditioner = instantiate_from_config(c_concat_config) 263 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 264 | 265 | def forward(self, c_concat, c_crossattn): 266 | c_concat = self.concat_conditioner(c_concat) 267 | c_crossattn = self.crossattn_conditioner(c_crossattn) 268 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 269 | 270 | 271 | def noise_like(shape, device, repeat=False): 272 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 273 | noise = lambda: torch.randn(shape, device=device) 274 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/embedding_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ldm.data.personalized import per_img_token_list 5 | from transformers import CLIPTokenizer 6 | from functools import partial 7 | 8 | DEFAULT_PLACEHOLDER_TOKEN = ["*"] 9 | 10 | PROGRESSIVE_SCALE = 2000 11 | 12 | def get_clip_token_for_string(tokenizer, string): 13 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 14 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 15 | tokens = batch_encoding["input_ids"] 16 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" 17 | 18 | return tokens[0, 1] 19 | 20 | def get_bert_token_for_string(tokenizer, string): 21 | token = tokenizer(string) 22 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" 23 | 24 | token = token[0, 1] 25 | 26 | return token 27 | 28 | def get_embedding_for_clip_token(embedder, token): 29 | return embedder(token.unsqueeze(0))[0, 0] 30 | 31 | 32 | class EmbeddingManager(nn.Module): 33 | def __init__( 34 | self, 35 | embedder, 36 | placeholder_strings=None, 37 | initializer_words=None, 38 | per_image_tokens=False, 39 | num_vectors_per_token=1, 40 | progressive_words=False, 41 | **kwargs 42 | ): 43 | super().__init__() 44 | 45 | self.string_to_token_dict = {} 46 | 47 | self.string_to_param_dict = nn.ParameterDict() 48 | 49 | self.initial_embeddings = nn.ParameterDict() # These should not be optimized 50 | 51 | self.progressive_words = progressive_words 52 | self.progressive_counter = 0 53 | 54 | self.max_vectors_per_token = num_vectors_per_token 55 | 56 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder 57 | self.is_clip = True 58 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) 59 | get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) 60 | token_dim = 768 61 | else: # using LDM's BERT encoder 62 | self.is_clip = False 63 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) 64 | get_embedding_for_tkn = embedder.transformer.token_emb 65 | token_dim = 1280 66 | 67 | if per_image_tokens: 68 | placeholder_strings.extend(per_img_token_list) 69 | 70 | for idx, placeholder_string in enumerate(placeholder_strings): 71 | 72 | token = get_token_for_string(placeholder_string) 73 | 74 | if initializer_words and idx < len(initializer_words): 75 | init_word_token = get_token_for_string(initializer_words[idx]) 76 | 77 | with torch.no_grad(): 78 | init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) 79 | 80 | token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) 81 | self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) 82 | else: 83 | token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) 84 | 85 | self.string_to_token_dict[placeholder_string] = token 86 | self.string_to_param_dict[placeholder_string] = token_params 87 | 88 | def forward( 89 | self, 90 | tokenized_text, 91 | embedded_text, 92 | ): 93 | b, n, device = *tokenized_text.shape, tokenized_text.device 94 | 95 | for placeholder_string, placeholder_token in self.string_to_token_dict.items(): 96 | 97 | placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) 98 | 99 | if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement 100 | placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) 101 | embedded_text[placeholder_idx] = placeholder_embedding 102 | else: # otherwise, need to insert and keep track of changing indices 103 | if self.progressive_words: 104 | self.progressive_counter += 1 105 | max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE 106 | else: 107 | max_step_tokens = self.max_vectors_per_token 108 | 109 | num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) 110 | 111 | placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) 112 | 113 | if placeholder_rows.nelement() == 0: 114 | continue 115 | 116 | sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) 117 | sorted_rows = placeholder_rows[sort_idx] 118 | 119 | for idx in range(len(sorted_rows)): 120 | row = sorted_rows[idx] 121 | col = sorted_cols[idx] 122 | 123 | new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] 124 | new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] 125 | 126 | embedded_text[row] = new_embed_row 127 | tokenized_text[row] = new_token_row 128 | 129 | return embedded_text 130 | 131 | def save(self, ckpt_path): 132 | torch.save({"string_to_token": self.string_to_token_dict, 133 | "string_to_param": self.string_to_param_dict}, ckpt_path) 134 | 135 | def load(self, ckpt_path): 136 | ckpt = torch.load(ckpt_path, map_location='cpu') 137 | 138 | self.string_to_token_dict = ckpt["string_to_token"] 139 | self.string_to_param_dict = ckpt["string_to_param"] 140 | 141 | def get_embedding_norms_squared(self): 142 | all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim 143 | param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders 144 | 145 | return param_norm_squared 146 | 147 | def embedding_parameters(self): 148 | return self.string_to_param_dict.parameters() 149 | 150 | def embedding_to_coarse_loss(self): 151 | 152 | loss = 0. 153 | num_embeddings = len(self.initial_embeddings) 154 | 155 | for key in self.initial_embeddings: 156 | optimized = self.string_to_param_dict[key] 157 | coarse = self.initial_embeddings[key].clone().to(optimized.device) 158 | 159 | loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings 160 | 161 | return loss -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Cant encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions. 61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config, **kwargs): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /models/toss_vae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.toss.TOSS 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: True 19 | only_mid_control: False 20 | ucg_txt: 0.5 21 | max_timesteps: 1000 22 | min_timesteps: 0 23 | finetune: True 24 | ucg_img: 0.05 25 | 26 | scheduler_config: # 10000 warmup steps 27 | target: ldm.lr_scheduler.LambdaLinearScheduler 28 | params: 29 | warm_up_steps: [ 100 ] 30 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 31 | f_start: [ 1.e-6 ] 32 | f_max: [ 1. ] 33 | f_min: [ 1. ] 34 | 35 | unet_config: 36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel_toss 37 | params: 38 | image_size: 32 # unused 39 | in_channels: 4 40 | out_channels: 4 41 | model_channels: 320 42 | attention_resolutions: [ 4, 2, 1 ] 43 | num_res_blocks: 2 44 | channel_mult: [ 1, 2, 4, 4 ] 45 | num_heads: 8 46 | use_spatial_transformer: True 47 | transformer_depth: 1 48 | context_dim: 768 49 | use_checkpoint: True 50 | legacy: False 51 | temp_attn: "CA_vae" 52 | pose_enc: "vae" 53 | 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: 78 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 79 | 80 | 81 | data800k: 82 | target: datasets.objaverse800k.ObjaverseDataModuleFromConfig 83 | params: 84 | root_dir: '/comp_robot/mm_generative/data/.objaverse/hf-objaverse-v1/views_release' 85 | batch_size: 128 86 | num_workers: 12 87 | total_view: 12 88 | caption: "rerank" 89 | pose_enc: "freq" 90 | train: 91 | validation: False 92 | image_transforms: 93 | size: 256 94 | 95 | validation: 96 | validation: True 97 | image_transforms: 98 | size: 256 99 | 100 | 101 | data_car: 102 | target: datasets.objaverse_car.ObjaverseDataModuleFromConfig 103 | params: 104 | root_dir: '/comp_robot/mm_generative/data/.objaverse/hf-objaverse-v1/views_release' 105 | batch_size: 128 106 | num_workers: 12 107 | total_view: 12 108 | caption: 'rerank' 109 | pose_enc: "freq" 110 | train: 111 | validation: False 112 | image_transforms: 113 | size: 256 114 | 115 | validation: 116 | validation: True 117 | image_transforms: 118 | size: 256 119 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | # common args for all datasets 6 | parser.add_argument('--root_dir', type=str, default="/comp_robot/shiyukai/dataset/nerf/Synthetic_NeRF/Chair/", 7 | help='root directory of dataset') 8 | parser.add_argument('--eval_root_dir', type=str, default='/comp_robot/mm_generative/data/GSO/views', 9 | help='root directory of dataset') 10 | parser.add_argument('--dataset_name', type=str, default='nsvf', help='which dataset to train/test') 11 | parser.add_argument('--split', type=str, default='train', 12 | choices=['train', 'trainval'], 13 | help='use which split to train') 14 | parser.add_argument('--downsample', type=float, default=1.0, 15 | help='downsample factor (<=1.0) for the images') 16 | 17 | # model parameters 18 | parser.add_argument('--model_cfg', type=str, default='./models/cldm_pose_v15.yaml', 19 | help='cfg path of model') 20 | parser.add_argument('--model_low_cfg', type=str, default='./models/cldm_pose_v15.yaml', 21 | help='cfg path of low-level model') 22 | parser.add_argument('--scale', type=float, default=0.5, 23 | help='scene scale (whole scene must lie in [-scale, scale]^3') 24 | parser.add_argument('--use_exposure', action='store_true', default=False, 25 | help='whether to train in HDR-NeRF setting') 26 | parser.add_argument('--resume_path', type=str, default='./models/control_sd15_pose_ini.ckpt', 27 | help='resume path') 28 | parser.add_argument('--resume_path_low', type=str, default='./models/control_sd15_pose_ini.ckpt', 29 | help='resume path for low-level model') 30 | parser.add_argument('--resume', action='store_true', default=False, 31 | help='train from resume') 32 | parser.add_argument('--text', type=str, default="a yellow lego bulldozer sitting on top of a table", 33 | help='text prompt') 34 | parser.add_argument('--uncond_pose', type=bool, default=False, 35 | help='set delta pose zero') 36 | parser.add_argument('--img_size', type=int, default=512, 37 | help='size of img') 38 | parser.add_argument('--acc_grad', type=int, default=None, 39 | help='accumulate grad') 40 | # parser.add_argument('--eval_guidance_scale1', type=float, default=1, 41 | # help='prompt guidance scale for eval') 42 | # parser.add_argument('--eval_guidance_scale2', type=float, default=1, 43 | # help='img guidance scale for eval') 44 | # WERN: duplicate 45 | parser.add_argument('--eval_prompt_guidance_scale', type=float, default=1, 46 | help='prompt guidance scale for eval') 47 | parser.add_argument('--eval_img_guidance_scale', type=float, default=1, 48 | help='img guidance scale for eval') 49 | parser.add_argument('--eval_guidance_scale_low', type=float, default=1, 50 | help='guidance scale for eval') 51 | parser.add_argument('--eval_use_ema_scope', action="store_true", 52 | help='ema_scop for eval') 53 | parser.add_argument('--eval_caption', type=str, default="origin", 54 | help='caption mode for eval') 55 | parser.add_argument('--inf_img_path', type=str, default="./exp/inference/img/008.png", 56 | help='caption mode for eval') 57 | parser.add_argument('--test_sub', action='store_true', default=False, 58 | help='test on part of eval') 59 | parser.add_argument('--divide_steps', type=int, default=800, 60 | help='divide steps for stage model') 61 | parser.add_argument('--attn_t', type=int, default=800, 62 | help='timesteps in viz attn') 63 | parser.add_argument('--layer_name', type=str, default="", 64 | help='timesteps in viz attn') 65 | parser.add_argument('--output_mode_attn', type=str, default="masked", 66 | help='timesteps in viz attn') 67 | parser.add_argument('--img_ucg', type=float, default=0.0, 68 | help='ucg for img') 69 | parser.add_argument('--register_scheduler', action='store_true', default=False, 70 | help='whether to register noise scheduler') 71 | parser.add_argument('--pose_enc', type=str, default="freq", 72 | help='encoding for camera pose') 73 | 74 | 75 | # training options 76 | parser.add_argument('--batch_size', type=int, default=1, 77 | help='number of rays in a batch') 78 | parser.add_argument('--log_interval_epoch', type=int, default=1, 79 | help='interval of logging info') 80 | parser.add_argument('--ckpt_interval', type=int, default=10, 81 | help='interval of ckpt') 82 | parser.add_argument('--logger_freq', type=int, default=20, 83 | help='logger_freq') 84 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images', 85 | choices=['all_images', 'same_image'], 86 | help=''' 87 | all_images: uniformly from all pixels of ALL images 88 | same_image: uniformly from all pixels of a SAME image 89 | ''') 90 | parser.add_argument('--max_steps', type=int, default=100000, 91 | help='max number of steps to train') 92 | parser.add_argument('--num_epochs', type=int, default=30, 93 | help='number of training epochs') 94 | parser.add_argument('--num_gpus', type=int, default=1, 95 | help='number of gpus') 96 | parser.add_argument('--lr', type=float, default=1e-4, 97 | help='learning rate') 98 | # experimental training options 99 | parser.add_argument('--optimize_ext', action='store_true', default=False, 100 | help='whether to optimize extrinsics (experimental)') 101 | parser.add_argument('--random_bg', action='store_true', default=False, 102 | help='''whether to train with random bg color (real dataset only) 103 | to avoid objects with black color to be predicted as transparent 104 | ''') 105 | 106 | # validation options 107 | parser.add_argument('--eval_lpips', action='store_true', default=False, 108 | help='evaluate lpips metric (consumes more VRAM)') 109 | parser.add_argument('--val_only', action='store_true', default=False, 110 | help='run only validation (need to provide ckpt_path)') 111 | parser.add_argument('--no_save_test', action='store_true', default=False, 112 | help='whether to save test image and video') 113 | parser.add_argument("--eval_image", type=str, default="", help="path to eval image") 114 | parser.add_argument("--eval_prompt", type=str, default="", help="prompt for eval image") 115 | 116 | # misc 117 | parser.add_argument('--exp_name', type=str, default='exp', 118 | help='experiment name') 119 | parser.add_argument('--ckpt_path', type=str, default=None, 120 | help='pretrained checkpoint to load (including optimizers, etc)') 121 | parser.add_argument('--weight_path', type=str, default=None, 122 | help='pretrained checkpoint to load (excluding optimizers, etc)') 123 | 124 | # modified model paras 125 | parser.add_argument("--fuse_fn", type=str, default="trilinear_interp", 126 | help='fuse function for codebook') 127 | parser.add_argument("--deformable_hash", type=str, default="no_deformable", 128 | help='use deformable hash or not, deformable_codebook / deformable_sample') 129 | parser.add_argument("--deformable_hash_speedup", type=str, default="no_speedup", 130 | help='use deformable hash speedup or not, no_speedup / sampling / clustering') 131 | parser.add_argument("--n_levels", type=int, default=16, help='n_levels of codebook') 132 | parser.add_argument("--finest_res", type=int, default=1024, help='finest resolultion for hashed embedding') 133 | parser.add_argument("--base_res", type=int, default=16, help='base resolultion for hashed embedding') 134 | parser.add_argument("--n_features_per_level", type=int, default=2, help='n_features_per_level') 135 | parser.add_argument("--log2_hashmap_size", type=int, default=19, help='log2 of hashmap size') 136 | parser.add_argument("--max_samples", type=int, default=1024, help='max sample points in a ray') 137 | parser.add_argument('--offset_mode', action='store_true', default=False, 138 | help='use offset in codebook or not') 139 | parser.add_argument('--record_offset', action='store_true', default=False, 140 | help='record offset in codebook or not') 141 | parser.add_argument('--update_interval', type=int, default=16, 142 | help='update interval for density map') 143 | parser.add_argument('--deformable_lr', type=float, 144 | help='learning rate of offset') 145 | parser.add_argument('--multi_scale_lr', type=float, 146 | help='adaptive learning rate of multi scale offset') 147 | parser.add_argument('--mlp_lr', type=float, 148 | help='adaptive learning rate of mlp') 149 | parser.add_argument('--feature_lr', type=float, default=0.01, 150 | help='learning rate of feature') 151 | parser.add_argument('--position_loss', type=float, 152 | help='apply l2 loss on position of codebook points or not') 153 | parser.add_argument('--var_loss', type=float, 154 | help='apply loss on variance of codebook points position or not') 155 | parser.add_argument('--offset_loss', type=float, 156 | help='apply l2 loss on offsets of codebook points or not') 157 | parser.add_argument('--warmup', type=int, default=256, 158 | help='warmup steps') 159 | parser.add_argument('--check_codebook', type=int, 160 | help='check codebook steps') 161 | parser.add_argument('--reinit_start', type=int, default=1000, 162 | help='start reinit for codebook check') 163 | parser.add_argument('--reinit_end', type=int, default=2000, 164 | help='end reinit for codebook check') 165 | parser.add_argument('--threshold', type=float, default=0.5, 166 | help='threshold for codebook check') 167 | parser.add_argument('--noise', type=float, default=0., 168 | help='random noise for codebook check') 169 | parser.add_argument('--degree', type=float, default=1., 170 | help='degree for distance inverse interpolation') 171 | parser.add_argument('--limit_func', type=str, default="sigmoid", 172 | help='function on limit of offset') 173 | parser.add_argument('--table_size', type=int, default=5, 174 | help='table_size for grid') 175 | parser.add_argument('--offset_grid', type=int, default=2, 176 | help='area allowed to offset') 177 | parser.add_argument('--grid_hashmap_size', type=int, default=19, 178 | help='hashmap_size for grid') 179 | parser.add_argument('--warmup_epochs', type=float, default=0, 180 | help='warmup epochs for lr scheduler') 181 | 182 | # deformable sample 183 | parser.add_argument('--multi_offset', type=int, default=8, 184 | help='num of deformable offsets for each sample') 185 | 186 | # textual inversion 187 | parser.add_argument("--data_root", type=str, 188 | help='root directory of dataset for textual inversion') 189 | parser.add_argument("--placeholder_string", 190 | type=str, 191 | help="Placeholder string which will be used to denote the concept in future prompts. Overwrites the config options.") 192 | parser.add_argument("--init_word", 193 | type=str, 194 | help="Word to use as source for initial token embedding") 195 | parser.add_argument("--embedding_manager_ckpt", 196 | type=str, 197 | default="", 198 | help="Initialize embedding manager from a checkpoint") 199 | parser.add_argument("--embedding_path", 200 | type=str, 201 | help="Path to a pre-trained embedding manager checkpoint") 202 | 203 | 204 | return parser.parse_args() 205 | -------------------------------------------------------------------------------- /outputs/a dragon toy with fire on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/a dragon toy with fire on the back.png -------------------------------------------------------------------------------- /outputs/a dragon toy with ice on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/a dragon toy with ice on the back.png -------------------------------------------------------------------------------- /outputs/anya/0_95.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/anya/0_95.png -------------------------------------------------------------------------------- /outputs/backview of a dragon toy with fire on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/backview of a dragon toy with fire on the back.png -------------------------------------------------------------------------------- /outputs/backview of a dragon toy with ice on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/backview of a dragon toy with ice on the back.png -------------------------------------------------------------------------------- /outputs/dragon/ a purple dragon with fire on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/ a purple dragon with fire on the back.png -------------------------------------------------------------------------------- /outputs/dragon/a dragon toy with fire on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon toy with fire on the back.png -------------------------------------------------------------------------------- /outputs/dragon/a dragon with fire on its back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with fire on its back.png -------------------------------------------------------------------------------- /outputs/dragon/a dragon with ice on its back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with ice on its back.png -------------------------------------------------------------------------------- /outputs/dragon/a dragon with ice on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with ice on the back.png -------------------------------------------------------------------------------- /outputs/dragon/a purple dragon with fire on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a purple dragon with fire on the back.png -------------------------------------------------------------------------------- /outputs/minion/a dragon toy with fire on its back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/minion/a dragon toy with fire on its back.png -------------------------------------------------------------------------------- /outputs/minion/a minion with a rocket on the back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/minion/a minion with a rocket on the back.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio==3.40.1 2 | albumentations==1.3.0 3 | opencv-python==4.5.5.64 4 | imageio==2.9.0 5 | imageio-ffmpeg==0.4.2 6 | pytorch-lightning==1.5.0 7 | omegaconf==2.1.1 8 | test-tube>=0.7.5 9 | streamlit==1.12.1 10 | einops==0.3.0 11 | transformers==4.22.2 12 | webdataset==0.2.5 13 | kornia==0.6 14 | open_clip_torch==2.0.2 15 | invisible-watermark>=0.1.5 16 | streamlit-drawable-canvas==0.8.0 17 | torchmetrics==0.6.0 18 | timm==0.6.12 19 | addict==2.4.0 20 | yapf==0.32.0 21 | prettytable==3.6.0 22 | safetensors==0.2.7 23 | basicsr==1.4.2 24 | carvekit-colab==4.1.0 -------------------------------------------------------------------------------- /share.py: -------------------------------------------------------------------------------- 1 | import config 2 | from cldm.hack import disable_verbosity, enable_sliced_attention 3 | 4 | 5 | disable_verbosity() 6 | 7 | if config.save_memory: 8 | enable_sliced_attention() 9 | -------------------------------------------------------------------------------- /streamlit_app.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple 3 | import altair as alt 4 | import math 5 | import pandas as pd 6 | import streamlit as st 7 | 8 | """ 9 | # Welcome to Streamlit! 10 | 11 | Edit `/streamlit_app.py` to customize this app to your heart's desire :heart: 12 | 13 | If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community 14 | forums](https://discuss.streamlit.io). 15 | 16 | In the meantime, below is an example of what you can do with just a few lines of code: 17 | """ 18 | 19 | 20 | with st.echo(code_location='below'): 21 | total_points = st.slider("Number of points in spiral", 1, 5000, 2000) 22 | num_turns = st.slider("Number of turns in spiral", 1, 100, 9) 23 | 24 | Point = namedtuple('Point', 'x y') 25 | data = [] 26 | 27 | points_per_turn = total_points / num_turns 28 | 29 | for curr_point_num in range(total_points): 30 | curr_turn, i = divmod(curr_point_num, points_per_turn) 31 | angle = (curr_turn + 1) * 2 * math.pi * i / points_per_turn 32 | radius = curr_point_num / total_points 33 | x = radius * math.cos(angle) 34 | y = radius * math.sin(angle) 35 | data.append(Point(x, y)) 36 | 37 | st.altair_chart(alt.Chart(pd.DataFrame(data), height=500, width=500) 38 | .mark_circle(color='#0068c9', opacity=0.5) 39 | .encode(x='x:Q', y='y:Q')) 40 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from share import * 2 | 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from tutorial_dataset import MyDataset 6 | from cldm.logger import ImageLogger 7 | from cldm.model import create_model, load_state_dict 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.plugins import DDPPlugin 11 | from ldm.util import instantiate_from_config 12 | import torch 13 | from omegaconf import OmegaConf 14 | 15 | # data 16 | from torch.utils.data import DataLoader 17 | from datasets import dataset_dict 18 | from datasets.ray_utils import axisangle_to_R, get_rays 19 | 20 | # configure 21 | from opt import get_opts 22 | import pdb 23 | 24 | 25 | if __name__ == '__main__': 26 | # Configs 27 | hparams = get_opts() 28 | logger_freq = hparams.logger_freq 29 | sd_locked = True 30 | only_mid_control = False 31 | cfgs = OmegaConf.load(hparams.model_cfg) 32 | 33 | 34 | # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. 35 | model = instantiate_from_config(cfgs.model) 36 | 37 | # get missing keys 38 | if hparams.resume_path != './models/control_sd15_pose_ini.ckpt': 39 | missing, unexpected = model.load_state_dict(load_state_dict(hparams.resume_path, location='cpu'), strict=False) 40 | print(f"Restored from {hparams.resume_path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 41 | if len(missing) > 0: 42 | print(f"Missing Keys:\n {missing}") 43 | if len(unexpected) > 0: 44 | print(f"\nUnexpected Keys:\n {unexpected}") 45 | 46 | # model.load_state_dict(load_state_dict(hparams.resume_path, location='cpu'), strict=False) 47 | 48 | # reweight noise scheduer 49 | if hparams.register_scheduler: 50 | model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.016) 51 | model.learning_rate = hparams.lr 52 | model.sd_locked = sd_locked 53 | model.only_mid_control = only_mid_control 54 | # model.control_model.uncond_pose = hparams.uncond_pose 55 | model.eval() 56 | 57 | # data 58 | if 'objaverse' in hparams.dataset_name: 59 | if "car" in hparams.dataset_name: 60 | dataloader = instantiate_from_config(cfgs.data_car) 61 | print("Using objaverse car data!") 62 | elif "800k" in hparams.dataset_name: 63 | dataloader = instantiate_from_config(cfgs.data800k) 64 | else: 65 | dataloader = instantiate_from_config(cfgs.data) 66 | dataloader.prepare_data() 67 | dataloader.setup() 68 | elif "srn" in hparams.dataset_name: 69 | if "chair" in hparams.dataset_name: 70 | dataloader = instantiate_from_config(cfgs.srn_chairs) 71 | else: 72 | dataloader = instantiate_from_config(cfgs.srn_data) 73 | dataloader.prepare_data() 74 | dataloader.setup() 75 | else: 76 | kwargs = {'root_dir': hparams.root_dir} 77 | dataset = dataset_dict[hparams.dataset_name](split=hparams.split, text=hparams.text, img_size=hparams.img_size, **kwargs) 78 | dataloader = DataLoader(dataset, num_workers=0, batch_size=hparams.batch_size, shuffle=True) 79 | 80 | 81 | # Train! 82 | save_dir = f'./exp/{hparams.dataset_name}/{hparams.exp_name}/' 83 | logger = TensorBoardLogger(save_dir=save_dir, 84 | name=hparams.exp_name, default_hp_metric=False) 85 | img_logger = ImageLogger(batch_frequency=logger_freq, epoch_frequency=hparams.log_interval_epoch) 86 | ckpt_cb = ModelCheckpoint(dirpath=f'{save_dir}/ckpt/', 87 | filename='{epoch:d}', 88 | save_weights_only=False, 89 | every_n_epochs=hparams.ckpt_interval, 90 | save_on_train_epoch_end=True, 91 | save_top_k=-1) 92 | callbacks = [img_logger, ckpt_cb] 93 | trainer = pl.Trainer(gpus=hparams.num_gpus, callbacks=callbacks, \ 94 | precision=16, \ 95 | # amp_backend='apex', amp_level="O2", \ 96 | check_val_every_n_epoch=20, 97 | logger=logger, max_epochs=hparams.num_epochs, \ 98 | resume_from_checkpoint=hparams.resume_path if hparams.resume else None, 99 | # strategy="ddp", 100 | accumulate_grad_batches=8//hparams.num_gpus \ 101 | if hparams.acc_grad==None else hparams.acc_grad, 102 | plugins=DDPPlugin(find_unused_parameters=False), 103 | accelerator="ddp" 104 | ) 105 | 106 | # trainer.validate(model, dataloader) 107 | trainer.fit(model, dataloader) 108 | 109 | -------------------------------------------------------------------------------- /tutorial_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MyDataset(Dataset): 9 | def __init__(self): 10 | self.data = [] 11 | with open('./training/fill50k/prompt.json', 'rt') as f: 12 | for line in f: 13 | self.data.append(json.loads(line)) 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, idx): 19 | item = self.data[idx] 20 | 21 | source_filename = item['source'] 22 | target_filename = item['target'] 23 | prompt = item['prompt'] 24 | 25 | source = cv2.imread('./training/fill50k/' + source_filename) 26 | target = cv2.imread('./training/fill50k/' + target_filename) 27 | 28 | # Do not forget that OpenCV read images in BGR order. 29 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB) 30 | target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB) 31 | 32 | # Normalize source images to [0, 1]. 33 | source = source.astype(np.float32) / 255.0 34 | 35 | # Normalize target images to [-1, 1]. 36 | target = (target.astype(np.float32) / 127.5) - 1.0 37 | 38 | return dict(jpg=target, txt=prompt, hint=source) 39 | 40 | -------------------------------------------------------------------------------- /viz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | 4 | 5 | def save_image_tensor2cv2(input_tensor: torch.Tensor, filename): 6 | """ 7 | Save a tensor to a file as an image. 8 | :param input_tensor: tensor to save [C, H, W] 9 | :param filename: file to save to 10 | """ 11 | assert (len(input_tensor.shape) == 3) 12 | input_tensor = input_tensor.clone().detach() 13 | input_tensor = input_tensor.to(torch.device('cpu')) 14 | input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy() 15 | input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR) 16 | cv2.imwrite(filename, input_tensor) 17 | --------------------------------------------------------------------------------