├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── 3d_generation_video.mp4
│   ├── car.png
│   ├── dragon.png
│   ├── gradio.png
│   ├── kunkun.png
│   └── rocket.png
├── cldm
│   ├── ddim_hacked.py
│   ├── hack.py
│   ├── logger.py
│   ├── model.py
│   └── toss.py
├── config.py
├── datasets
│   ├── __init__.py
│   ├── base.py
│   ├── colmap.py
│   ├── colmap_utils.py
│   ├── depth_utils.py
│   ├── nerfpp.py
│   ├── nsvf.py
│   ├── objaverse.py
│   ├── objaverse800k.py
│   ├── objaverse_car.py
│   ├── ray_utils.py
│   └── rtmv.py
├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── personalized.py
│   │   └── util.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── dpm_solver
│   │       │   ├── __init__.py
│   │       │   ├── dpm_solver.py
│   │       │   └── sampler.py
│   │       ├── plms.py
│   │       └── sampling_util.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── upscaling.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── embedding_manager.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   └── midas
│   │       ├── __init__.py
│   │       ├── api.py
│   │       ├── midas
│   │       │   ├── __init__.py
│   │       │   ├── base_model.py
│   │       │   ├── blocks.py
│   │       │   ├── dpt_depth.py
│   │       │   ├── midas_net.py
│   │       │   ├── midas_net_custom.py
│   │       │   ├── transforms.py
│   │       │   └── vit.py
│   │       └── utils.py
│   └── util.py
├── models
│   └── toss_vae.yaml
├── opt.py
├── outputs
│   ├── a dragon toy with fire on the back.png
│   ├── a dragon toy with ice on the back.png
│   ├── anya
│   │   └── 0_95.png
│   ├── backview of a dragon toy with fire on the back.png
│   ├── backview of a dragon toy with ice on the back.png
│   ├── dragon
│   │   ├── a dragon toy with fire on the back.png
│   │   ├── a dragon with fire on its back.png
│   │   ├── a dragon with ice on its back.png
│   │   ├── a dragon with ice on the back.png
│   │   └── a purple dragon with fire on the back.png
│   └── minion
│       ├── a dragon toy with fire on its back.png
│       └── a minion with a rocket on the back.png
├── requirements.txt
├── share.py
├── streamlit_app.py
├── train.py
├── tutorial_dataset.py
└── viz.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | gradio_res/
163 | ckpt/toss.ckpt
164 | exp/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TOSS: High-quality Text-guided Novel View Synthesis from a Single Image (ICLR2024)
2 |
3 | ##### Yukai Shi, Jianan Wang, He Cao, Boshi Tang, Xianbiao Qi, Tianyu Yang, Yukun Huang, Shilong Liu, Lei Zhang, Heung-Yeung Shum
4 |
5 |
6 |
7 |
8 | Official implementation for *TOSS: High-quality Text-guided Novel View Synthesis from a Single Image*.
9 |
10 | **TOSS introduces text as high-level semantic information to constrain the NVS solution space for more controllable and more plausible results.**
11 |
12 | ## [Project Page](https://toss3d.github.io/) | [ArXiv](https://arxiv.org/abs/2310.10644) | [Weights](https://drive.google.com/drive/folders/15URQHblOVi_7YXZtgdFpjZAlKsoHylsq?usp=sharing)
13 |
14 |
15 | https://github.com/IDEA-Research/TOSS/assets/54578597/cd64c6c5-fef8-43c2-a223-7930ad6a71b7
16 |
17 |
18 | ## Install
19 |
20 | ### Create environment
21 | ```bash
22 | conda create -n toss python=3.9
23 | conda activate toss
24 | ```
25 |
26 | ### Install packages
27 | ```bash
28 | pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118
29 | pip install -r requirements.txt
30 | git clone https://github.com/openai/CLIP.git
31 | pip install -e CLIP/
32 | ```
33 | ### Weights
34 | Download the pretrained weights from [this link](https://drive.google.com/drive/folders/15URQHblOVi_7YXZtgdFpjZAlKsoHylsq?usp=sharing) into the sub-directory ./ckpt.
35 |
36 | ## Inference
37 |
38 | We suggest using the Gradio app for visualized inference; the demo has been tested on a single RTX 3090.
39 |
40 | ```bash
41 | python app.py
42 | ```
43 |
44 | 
45 |
46 |
47 | ## Todo List
48 | - [x] Release inference code.
49 | - [x] Release pretrained models.
50 | - [ ] Upload 3D generation code.
51 | - [ ] Upload training code.
52 |
53 | ## Acknowledgement
54 | - [ControlNet](https://github.com/lllyasviel/ControlNet/)
55 | - [Zero123](https://github.com/cvlab-columbia/zero123/)
56 | - [threestudio](https://github.com/threestudio-project/threestudio)
57 |
58 | ## Citation
59 |
60 | ```
61 | @article{shi2023toss,
62 | title={Toss: High-quality text-guided novel view synthesis from a single image},
63 | author={Shi, Yukai and Wang, Jianan and Cao, He and Tang, Boshi and Qi, Xianbiao and Yang, Tianyu and Huang, Yukun and Liu, Shilong and Zhang, Lei and Shum, Heung-Yeung},
64 | journal={arXiv preprint arXiv:2310.10644},
65 | year={2023}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/assets/3d_generation_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/3d_generation_video.mp4
--------------------------------------------------------------------------------
/assets/car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/car.png
--------------------------------------------------------------------------------
/assets/dragon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/dragon.png
--------------------------------------------------------------------------------
/assets/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/gradio.png
--------------------------------------------------------------------------------
/assets/kunkun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/kunkun.png
--------------------------------------------------------------------------------
/assets/rocket.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/assets/rocket.png
--------------------------------------------------------------------------------
/cldm/hack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 |
4 | import ldm.modules.encoders.modules
5 | import ldm.modules.attention
6 |
7 | from transformers import logging
8 | from ldm.modules.attention import default
9 |
10 |
11 | def disable_verbosity():
12 | logging.set_verbosity_error()
13 | print('logging improved.')
14 | return
15 |
16 |
17 | def enable_sliced_attention():
18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19 | print('Enabled sliced_attention.')
20 | return
21 |
22 |
23 | def hack_everything(clip_skip=0):
24 | disable_verbosity()
25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27 | print('Enabled clip hacks.')
28 | return
29 |
30 |
31 | # Written by Lvmin
32 | def _hacked_clip_forward(self, text):
33 | PAD = self.tokenizer.pad_token_id
34 | EOS = self.tokenizer.eos_token_id
35 | BOS = self.tokenizer.bos_token_id
36 |
37 | def tokenize(t):
38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39 |
40 | def transformer_encode(t):
41 | if self.clip_skip > 1:
42 | rt = self.transformer(input_ids=t, output_hidden_states=True)
43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44 | else:
45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46 |
47 | def split(x):
48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49 |
50 | def pad(x, p, i):
51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52 |
53 | raw_tokens_list = tokenize(text)
54 | tokens_list = []
55 |
56 | for raw_tokens in raw_tokens_list:
57 | raw_tokens_123 = split(raw_tokens)
58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60 | tokens_list.append(raw_tokens_123)
61 |
62 | tokens_list = torch.IntTensor(tokens_list).to(self.device)
63 |
64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65 | y = transformer_encode(feed)
66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67 |
68 | return z
69 |
70 |
71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73 | h = self.heads
74 |
75 | q = self.to_q(x)
76 | context = default(context, x)
77 | k = self.to_k(context)
78 | v = self.to_v(context)
79 | del context, x
80 |
81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82 |
83 | limit = k.shape[0]
84 | att_step = 1
85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88 |
89 | q_chunks.reverse()
90 | k_chunks.reverse()
91 | v_chunks.reverse()
92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93 | del k, q, v
94 | for i in range(0, limit, att_step):
95 | q_buffer = q_chunks.pop()
96 | k_buffer = k_chunks.pop()
97 | v_buffer = v_chunks.pop()
98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99 |
100 | del k_buffer, q_buffer
101 | # attention, what we cannot get enough of, by chunks
102 |
103 | sim_buffer = sim_buffer.softmax(dim=-1)
104 |
105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106 | del v_buffer
107 | sim[i:i + att_step, :, :] = sim_buffer
108 |
109 | del sim_buffer
110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111 | return self.to_out(sim)
112 |
--------------------------------------------------------------------------------
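A minimal usage sketch (not part of the repository) for the patches defined in cldm/hack.py above; the call order and the clip_skip value are assumptions:

```python
# Hypothetical usage of cldm/hack.py: apply the CLIP and attention patches
# before the model is built. clip_skip=2 is an arbitrary example value.
from cldm.hack import disable_verbosity, hack_everything, enable_sliced_attention

disable_verbosity()           # silence transformers logging
hack_everything(clip_skip=2)  # patch FrozenCLIPEmbedder.forward to chunk prompts longer than 77 tokens
enable_sliced_attention()     # swap CrossAttention.forward for the memory-saving sliced variant
```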
/cldm/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torchvision
6 | from PIL import Image
7 | from pytorch_lightning.callbacks import Callback
8 | from pytorch_lightning.utilities.distributed import rank_zero_only
9 | import time
10 |
11 |
12 | class ImageLogger(Callback):
13 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
14 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
15 | log_images_kwargs=None, epoch_frequency=1):
16 | super().__init__()
17 | self.rescale = rescale
18 | self.batch_freq = batch_frequency
19 | self.max_images = max_images
20 | if not increase_log_steps:
21 | self.log_steps = [self.batch_freq]
22 | self.clamp = clamp
23 | self.disabled = disabled
24 | self.log_on_batch_idx = log_on_batch_idx
25 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
26 | self.log_first_step = log_first_step
27 | self.epoch = 0
28 | self.epoch_frequency = epoch_frequency
29 | self.start_time = time.time()
30 |
31 | @rank_zero_only
32 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
33 | root = os.path.join(save_dir, "image_log", split)
34 | for k in images:
35 | grid = torchvision.utils.make_grid(images[k], nrow=4)
36 | if self.rescale:
37 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
38 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
39 | grid = grid.numpy()
40 | grid = (grid * 255).astype(np.uint8)
41 | filename = "epoch{:06}/{}_gs-{:06}_e-{:06}_b-{:06}.png".format(current_epoch,k, global_step, current_epoch, batch_idx)
42 | path = os.path.join(root, filename)
43 | os.makedirs(os.path.split(path)[0], exist_ok=True)
44 | Image.fromarray(grid).save(path)
45 | if current_epoch > self.epoch:
46 | self.epoch += 1
47 |
48 | def log_img(self, pl_module, batch, batch_idx, split="train"):
49 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
50 | if (self.check_frequency(check_idx) and
51 | hasattr(pl_module, "log_images") and
52 | callable(pl_module.log_images) and
53 | self.max_images > 0):
54 | logger = type(pl_module.logger)
55 |
56 | is_train = pl_module.training
57 | if is_train:
58 | pl_module.eval()
59 |
60 | with torch.no_grad():
61 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
62 |
63 | for k in images:
64 | N = min(images[k].shape[0], self.max_images)
65 | images[k] = images[k][:N]
66 | if isinstance(images[k], torch.Tensor):
67 | images[k] = images[k].detach().cpu()
68 | if self.clamp:
69 | images[k] = torch.clamp(images[k], -1., 1.)
70 |
71 | self.log_local(pl_module.logger.save_dir, split, images,
72 | pl_module.global_step, pl_module.current_epoch, batch_idx)
73 |
74 | if is_train:
75 | pl_module.train()
76 |
77 | def check_frequency(self, check_idx):
78 | return check_idx % self.batch_freq == 0
79 |
80 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
81 | if not self.disabled:
82 | self.log_img(pl_module, batch, batch_idx, split="train")
83 |
84 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
85 | if not self.disabled:
86 | self.log_img(pl_module, batch, batch_idx, split="val")
87 |
88 | def on_train_epoch_end(self, trainer, pl_module):
89 | end_time = time.time()
90 | self.time_counter(self.start_time, end_time, pl_module.global_step)
91 |
92 | def time_counter(self, begin_time, end_time, step):
93 | run_time = round(end_time - begin_time)
94 | # transfer time
95 | hour = run_time // 3600
96 | minute = (run_time - 3600 * hour) // 60
97 | second = run_time - 3600 * hour - 60 * minute
98 |         print(f'Time cost until step {step}: {hour}h:{minute}min:{second}s')
--------------------------------------------------------------------------------
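A minimal sketch (not part of the repository) of attaching ImageLogger above to a pytorch_lightning (1.x) Trainer; the trainer settings and a LightningModule implementing log_images(batch, split=...) are assumptions:

```python
# Hypothetical wiring of cldm/logger.ImageLogger into a PyTorch Lightning trainer.
import pytorch_lightning as pl
from cldm.logger import ImageLogger

image_logger = ImageLogger(batch_frequency=500, max_images=4, clamp=True)
trainer = pl.Trainer(gpus=1, precision=32, max_epochs=10, callbacks=[image_logger])
# trainer.fit(model, train_dataloader)  # grids are written to <save_dir>/image_log/train/
```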
/cldm/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from omegaconf import OmegaConf
5 | from ldm.util import instantiate_from_config
6 |
7 |
8 | def get_state_dict(d):
9 | return d.get('state_dict', d)
10 |
11 |
12 | def load_state_dict(ckpt_path, location='cpu'):
13 | _, extension = os.path.splitext(ckpt_path)
14 | if extension.lower() == ".safetensors":
15 | import safetensors.torch
16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17 | else:
18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19 | state_dict = get_state_dict(state_dict)
20 | print(f'Loaded state_dict from [{ckpt_path}]')
21 | return state_dict
22 |
23 |
24 | def create_model(config_path):
25 | config = OmegaConf.load(config_path)
26 | model = instantiate_from_config(config.model).cpu()
27 | print(f'Loaded model config from [{config_path}]')
28 | return model
29 |
--------------------------------------------------------------------------------
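A minimal sketch (not part of the repository) combining the two helpers in cldm/model.py above; the config and checkpoint paths are assumptions (models/toss_vae.yaml and ckpt/toss.ckpt are referenced elsewhere in this repo but may not be the exact files app.py uses):

```python
# Hypothetical model construction + checkpoint restore using cldm/model.py.
from cldm.model import create_model, load_state_dict

model = create_model('./models/toss_vae.yaml')                    # instantiate from OmegaConf config
state_dict = load_state_dict('./ckpt/toss.ckpt', location='cpu')  # handles .ckpt and .safetensors
model.load_state_dict(state_dict, strict=False)
model = model.cuda().eval()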
/config.py:
--------------------------------------------------------------------------------
1 | save_memory = False
2 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .nsvf import NSVFDataset, NSVFDataset_v2, NSVFDataset_all
2 | from .colmap import ColmapDataset
3 | from .nerfpp import NeRFPPDataset
4 | from .objaverse import ObjaverseData
5 |
6 |
7 | dataset_dict = {'nsvf': NSVFDataset,
8 | 'nsvf_v2': NSVFDataset_v2,
9 | "nsvf_all": NSVFDataset_all,
10 | 'colmap': ColmapDataset,
11 | 'nerfpp': NeRFPPDataset,
12 | 'objaverse': ObjaverseData,}
--------------------------------------------------------------------------------
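A minimal sketch (not part of the repository) of using the dataset_dict registry above; the dataset key, root_dir, and downsample value are assumptions:

```python
# Hypothetical dataset construction through datasets/__init__.dataset_dict.
from datasets import dataset_dict

DatasetCls = dataset_dict['colmap']   # -> ColmapDataset
train_set = DatasetCls(root_dir='/path/to/scene', split='train', downsample=0.5)
print(f'{len(train_set)} training views')
```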
/datasets/base.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms as T
2 |
3 | from torchvision import transforms
4 | from torch.utils.data import Dataset
5 | import numpy as np
6 | import torch
7 | import random
8 | import math
9 | import pdb
10 | from viz import save_image_tensor2cv2
11 |
12 |
13 | class BaseDataset(Dataset):
14 | """
15 | Define length and sampling method
16 | """
17 | def __init__(self, root_dir, split='train', text="a green chair", img_size=512, downsample=1.0):
18 | self.img_w, self.img_h = img_size, img_size
19 | self.root_dir = root_dir
20 | self.split = split
21 | self.downsample = downsample
22 | self.define_transforms()
23 | self.text = text
24 |
25 | def read_intrinsics(self):
26 | raise NotImplementedError
27 |
28 | def define_transforms(self):
29 | self.transform = transforms.Compose([T.ToTensor(), transforms.Resize(size=(self.img_w, self.img_h))])
30 | # self.transform = T.ToTensor()
31 |
32 | def __len__(self):
33 | # if self.split.startswith('train'):
34 | # return 1000
35 | return len(self.poses)
36 |
37 | def cartesian_to_spherical(self, xyz):
38 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
39 | xy = xyz[:,0]**2 + xyz[:,1]**2
40 | z = np.sqrt(xy + xyz[:,2]**2)
41 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
42 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
43 | azimuth = np.arctan2(xyz[:,1], xyz[:,0])
44 | return np.array([theta, azimuth, z])
45 |
46 | def get_T(self, target_RT, cond_RT):
47 | R, T = target_RT[:3, :3], target_RT[:, -1]
48 | T_target = -R.T @ T
49 |
50 | R, T = cond_RT[:3, :3], cond_RT[:, -1]
51 | T_cond = -R.T @ T
52 |
53 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
54 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
55 |
56 | d_theta = theta_target - theta_cond
57 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
58 | d_z = z_target - z_cond
59 |
60 | d_T = torch.tensor([d_theta.item(), d_azimuth.item(), d_z.item()])
61 | return d_T
62 |
63 | def get_T_w2c(self, target_RT, cond_RT):
64 | T_target = target_RT[:, -1]
65 | T_cond = cond_RT[:, -1]
66 |
67 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
68 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
69 |
70 | d_theta = theta_target - theta_cond
71 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
72 | d_z = z_target - z_cond
73 | # d_z = (z_target - z_cond) / np.max(z_cond) * 2
74 | # pdb.set_trace()
75 |
76 | d_T = torch.tensor([d_theta.item(), d_azimuth.item(), d_z.item()])
77 | return d_T
78 |
79 | def __getitem__(self, idx):
80 | # camera pose and img
81 | poses = self.poses[idx]
82 | img = self.imgs[idx]
83 | prompt = self.text
84 |
85 | # condition
86 | # idx_cond = idx % 5
87 | # idx_cond = 1
88 | idx_cond = random.randint(0, 4)
89 | # idx_cond = random.randint(0, len(self.poses)-1)
90 | poses_cond = self.poses[idx_cond]
91 | img_cond = self.imgs[idx_cond]
92 |
93 | # if len(self.imgs)>0: # if ground truth available
94 | # img_rgb = imgs[:, :,:3]
95 | # img_rgb_cond = imgs_cond[:, :,:3]
96 | # if imgs.shape[-1] == 4: # HDR-NeRF data
97 | # sample['exposure'] = rays[0, 3] # same exposure for all rays
98 |
99 | # Normalize target images to [-1, 1].
100 | target = (img.float() / 127.5) - 1.0
101 | # Normalize source images to [0, 1].
102 | condition = img_cond.float() / 255.0
103 | # save_image_tensor2cv2(condition.permute(2,0,1), "./viz_fig/input_cond.png")
104 |
105 | # get delta pose
106 | # delta_pose = self.get_T(target_RT=poses, cond_RT=poses_cond)
107 | delta_pose = self.get_T_w2c(target_RT=poses, cond_RT=poses_cond)
108 |
109 | return dict(jpg=target, txt=prompt, hint=condition, delta_pose=delta_pose)
--------------------------------------------------------------------------------
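A worked sketch (not part of the repository) of the relative-pose encoding used by BaseDataset.get_T_w2c above: both camera centres are mapped to spherical coordinates and the target view is expressed as (d_theta, d_azimuth, d_z) relative to the conditioning view. The two example camera centres are assumptions:

```python
import numpy as np

def cartesian_to_spherical(xyz):
    # Same math as BaseDataset.cartesian_to_spherical: polar angle from +Z, azimuth, radius.
    xy = xyz[:, 0]**2 + xyz[:, 1]**2
    radius = np.sqrt(xy + xyz[:, 2]**2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, radius])

cond = np.array([[1.0, 0.0, 1.0]])    # conditioning camera centre (example)
target = np.array([[0.0, 1.0, 1.0]])  # target camera centre (example)

t_c, a_c, r_c = cartesian_to_spherical(cond)
t_t, a_t, r_t = cartesian_to_spherical(target)
d_pose = (t_t - t_c, (a_t - a_c) % (2 * np.pi), r_t - r_c)
print(d_pose)  # ~(0, pi/2, 0): same elevation and radius, azimuth differs by 90 degrees
```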
/datasets/colmap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import glob
5 | from PIL import Image
6 | from einops import rearrange
7 | from tqdm import tqdm
8 |
9 | from .ray_utils import *
10 | from .colmap_utils import \
11 | read_cameras_binary, read_images_binary, read_points3d_binary
12 |
13 | from .base import BaseDataset
14 |
15 |
16 | class ColmapDataset(BaseDataset):
17 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
18 |         super().__init__(root_dir, split, downsample=downsample)
19 |
20 | self.read_intrinsics()
21 |
22 | if kwargs.get('read_meta', True):
23 | self.read_meta(split, **kwargs)
24 |
25 | def read_intrinsics(self):
26 | # Step 1: read and scale intrinsics (same for all images)
27 | camdata = read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin'))
28 | h = int(camdata[1].height*self.downsample)
29 | w = int(camdata[1].width*self.downsample)
30 | self.img_wh = (w, h)
31 |
32 | if camdata[1].model == 'SIMPLE_RADIAL':
33 | fx = fy = camdata[1].params[0]*self.downsample
34 | cx = camdata[1].params[1]*self.downsample
35 | cy = camdata[1].params[2]*self.downsample
36 | elif camdata[1].model in ['PINHOLE', 'OPENCV']:
37 | fx = camdata[1].params[0]*self.downsample
38 | fy = camdata[1].params[1]*self.downsample
39 | cx = camdata[1].params[2]*self.downsample
40 | cy = camdata[1].params[3]*self.downsample
41 | else:
42 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")
43 | self.K = torch.FloatTensor([[fx, 0, cx],
44 | [0, fy, cy],
45 | [0, 0, 1]])
46 | self.directions = get_ray_directions(h, w, self.K)
47 |
48 | def read_meta(self, split, **kwargs):
49 | # Step 2: correct poses
50 | # read extrinsics (of successfully reconstructed images)
51 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin'))
52 | img_names = [imdata[k].name for k in imdata]
53 | perm = np.argsort(img_names)
54 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data
55 | folder = f'images_{int(1/self.downsample)}'
56 | else:
57 | folder = 'images'
58 | # read successfully reconstructed images and ignore others
59 | img_paths = [os.path.join(self.root_dir, folder, name)
60 | for name in sorted(img_names)]
61 | w2c_mats = []
62 | bottom = np.array([[0, 0, 0, 1.]])
63 | for k in imdata:
64 | im = imdata[k]
65 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1)
66 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)]
67 | w2c_mats = np.stack(w2c_mats, 0)
68 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices
69 |
70 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin'))
71 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3)
72 |
73 | self.poses, self.pts3d = center_poses(poses, pts3d)
74 |
75 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min()
76 | self.poses[..., 3] /= scale
77 | self.pts3d /= scale
78 |
79 | self.rays = []
80 | if split == 'test_traj': # use precomputed test poses
81 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean())
82 | self.poses = torch.FloatTensor(self.poses)
83 | return
84 |
85 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data
86 | if 'syndata' in self.root_dir: # synthetic
87 | # first 17 are test, last 18 are train
88 | self.unit_exposure_rgb = 0.73
89 | if split=='train':
90 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
91 | f'train/*[024].png')))
92 | self.poses = np.repeat(self.poses[-18:], 3, 0)
93 | elif split=='test':
94 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
95 | f'test/*[13].png')))
96 | self.poses = np.repeat(self.poses[:17], 2, 0)
97 | else:
98 | raise ValueError(f"split {split} is invalid for HDR-NeRF!")
99 | else: # real
100 | self.unit_exposure_rgb = 0.5
101 | # even numbers are train, odd numbers are test
102 | if split=='train':
103 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
104 | f'input_images/*0.jpg')))[::2]
105 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
106 | f'input_images/*2.jpg')))[::2]
107 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
108 | f'input_images/*4.jpg')))[::2]
109 | self.poses = np.tile(self.poses[::2], (3, 1, 1))
110 | elif split=='test':
111 | img_paths = sorted(glob.glob(os.path.join(self.root_dir,
112 | f'input_images/*1.jpg')))[1::2]
113 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir,
114 | f'input_images/*3.jpg')))[1::2]
115 | self.poses = np.tile(self.poses[1::2], (2, 1, 1))
116 | else:
117 | raise ValueError(f"split {split} is invalid for HDR-NeRF!")
118 | else:
119 | # use every 8th image as test set
120 | if split=='train':
121 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0]
122 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0])
123 | elif split=='test':
124 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0]
125 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0])
126 |
127 | print(f'Loading {len(img_paths)} {split} images ...')
128 | for img_path in tqdm(img_paths):
129 | buf = [] # buffer for ray attributes: rgb, etc
130 |
131 | img = Image.open(img_path).convert('RGB').resize(self.img_wh, Image.LANCZOS)
132 | img = rearrange(self.transform(img), 'c h w -> (h w) c')
133 | buf += [img]
134 |
135 | if 'HDR-NeRF' in self.root_dir: # get exposure
136 | folder = self.root_dir.split('/')
137 | scene = folder[-1] if folder[-1] != '' else folder[-2]
138 | if scene in ['bathroom', 'bear', 'chair', 'desk']:
139 | e_dict = {e: 1/8*4**e for e in range(5)}
140 | elif scene in ['diningroom', 'dog']:
141 | e_dict = {e: 1/16*4**e for e in range(5)}
142 | elif scene in ['sofa']:
143 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16}
144 | elif scene in ['sponza']:
145 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32}
146 | elif scene in ['box']:
147 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05}
148 | elif scene in ['computer']:
149 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60}
150 | elif scene in ['flower']:
151 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45}
152 | elif scene in ['luckycat']:
153 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125}
154 | e = int(img_path.split('.')[0][-1])
155 | buf += [e_dict[e]*torch.ones_like(img[:, :1])]
156 |
157 | self.rays += [torch.cat(buf, 1)]
158 |
159 | self.rays = torch.stack(self.rays) # (N_images, hw, ?)
160 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
--------------------------------------------------------------------------------
/datasets/colmap_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31 |
32 | import os
33 | import sys
34 | import collections
35 | import numpy as np
36 | import struct
37 |
38 |
39 | CameraModel = collections.namedtuple(
40 | "CameraModel", ["model_id", "model_name", "num_params"])
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"])
43 | BaseImage = collections.namedtuple(
44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45 | Point3D = collections.namedtuple(
46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47 |
48 | class Image(BaseImage):
49 | def qvec2rotmat(self):
50 | return qvec2rotmat(self.qvec)
51 |
52 |
53 | CAMERA_MODELS = {
54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
61 | CameraModel(model_id=7, model_name="FOV", num_params=5),
62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
65 | }
66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
67 | for camera_model in CAMERA_MODELS])
68 |
69 |
70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
71 | """Read and unpack the next bytes from a binary file.
72 | :param fid:
73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
75 | :param endian_character: Any of {@, =, <, >, !}
76 | :return: Tuple of read and unpacked values.
77 | """
78 | data = fid.read(num_bytes)
79 | return struct.unpack(endian_character + format_char_sequence, data)
80 |
81 |
82 | def read_cameras_text(path):
83 | """
84 | see: src/base/reconstruction.cc
85 | void Reconstruction::WriteCamerasText(const std::string& path)
86 | void Reconstruction::ReadCamerasText(const std::string& path)
87 | """
88 | cameras = {}
89 | with open(path, "r") as fid:
90 | while True:
91 | line = fid.readline()
92 | if not line:
93 | break
94 | line = line.strip()
95 | if len(line) > 0 and line[0] != "#":
96 | elems = line.split()
97 | camera_id = int(elems[0])
98 | model = elems[1]
99 | width = int(elems[2])
100 | height = int(elems[3])
101 | params = np.array(tuple(map(float, elems[4:])))
102 | cameras[camera_id] = Camera(id=camera_id, model=model,
103 | width=width, height=height,
104 | params=params)
105 | return cameras
106 |
107 |
108 | def read_cameras_binary(path_to_model_file):
109 | """
110 | see: src/base/reconstruction.cc
111 | void Reconstruction::WriteCamerasBinary(const std::string& path)
112 | void Reconstruction::ReadCamerasBinary(const std::string& path)
113 | """
114 | cameras = {}
115 | with open(path_to_model_file, "rb") as fid:
116 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
117 | for camera_line_index in range(num_cameras):
118 | camera_properties = read_next_bytes(
119 | fid, num_bytes=24, format_char_sequence="iiQQ")
120 | camera_id = camera_properties[0]
121 | model_id = camera_properties[1]
122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
123 | width = camera_properties[2]
124 | height = camera_properties[3]
125 | num_params = CAMERA_MODEL_IDS[model_id].num_params
126 | params = read_next_bytes(fid, num_bytes=8*num_params,
127 | format_char_sequence="d"*num_params)
128 | cameras[camera_id] = Camera(id=camera_id,
129 | model=model_name,
130 | width=width,
131 | height=height,
132 | params=np.array(params))
133 | assert len(cameras) == num_cameras
134 | return cameras
135 |
136 |
137 | def read_images_text(path):
138 | """
139 | see: src/base/reconstruction.cc
140 | void Reconstruction::ReadImagesText(const std::string& path)
141 | void Reconstruction::WriteImagesText(const std::string& path)
142 | """
143 | images = {}
144 | with open(path, "r") as fid:
145 | while True:
146 | line = fid.readline()
147 | if not line:
148 | break
149 | line = line.strip()
150 | if len(line) > 0 and line[0] != "#":
151 | elems = line.split()
152 | image_id = int(elems[0])
153 | qvec = np.array(tuple(map(float, elems[1:5])))
154 | tvec = np.array(tuple(map(float, elems[5:8])))
155 | camera_id = int(elems[8])
156 | image_name = elems[9]
157 | elems = fid.readline().split()
158 | xys = np.column_stack([tuple(map(float, elems[0::3])),
159 | tuple(map(float, elems[1::3]))])
160 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
161 | images[image_id] = Image(
162 | id=image_id, qvec=qvec, tvec=tvec,
163 | camera_id=camera_id, name=image_name,
164 | xys=xys, point3D_ids=point3D_ids)
165 | return images
166 |
167 |
168 | def read_images_binary(path_to_model_file):
169 | """
170 | see: src/base/reconstruction.cc
171 | void Reconstruction::ReadImagesBinary(const std::string& path)
172 | void Reconstruction::WriteImagesBinary(const std::string& path)
173 | """
174 | images = {}
175 | with open(path_to_model_file, "rb") as fid:
176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
177 | for image_index in range(num_reg_images):
178 | binary_image_properties = read_next_bytes(
179 | fid, num_bytes=64, format_char_sequence="idddddddi")
180 | image_id = binary_image_properties[0]
181 | qvec = np.array(binary_image_properties[1:5])
182 | tvec = np.array(binary_image_properties[5:8])
183 | camera_id = binary_image_properties[8]
184 | image_name = ""
185 | current_char = read_next_bytes(fid, 1, "c")[0]
186 | while current_char != b"\x00": # look for the ASCII 0 entry
187 | image_name += current_char.decode("utf-8")
188 | current_char = read_next_bytes(fid, 1, "c")[0]
189 | num_points2D = read_next_bytes(fid, num_bytes=8,
190 | format_char_sequence="Q")[0]
191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
192 | format_char_sequence="ddq"*num_points2D)
193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
194 | tuple(map(float, x_y_id_s[1::3]))])
195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
196 | images[image_id] = Image(
197 | id=image_id, qvec=qvec, tvec=tvec,
198 | camera_id=camera_id, name=image_name,
199 | xys=xys, point3D_ids=point3D_ids)
200 | return images
201 |
202 |
203 | def read_points3D_text(path):
204 | """
205 | see: src/base/reconstruction.cc
206 | void Reconstruction::ReadPoints3DText(const std::string& path)
207 | void Reconstruction::WritePoints3DText(const std::string& path)
208 | """
209 | points3D = {}
210 | with open(path, "r") as fid:
211 | while True:
212 | line = fid.readline()
213 | if not line:
214 | break
215 | line = line.strip()
216 | if len(line) > 0 and line[0] != "#":
217 | elems = line.split()
218 | point3D_id = int(elems[0])
219 | xyz = np.array(tuple(map(float, elems[1:4])))
220 | rgb = np.array(tuple(map(int, elems[4:7])))
221 | error = float(elems[7])
222 | image_ids = np.array(tuple(map(int, elems[8::2])))
223 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
225 | error=error, image_ids=image_ids,
226 | point2D_idxs=point2D_idxs)
227 | return points3D
228 |
229 |
230 | def read_points3d_binary(path_to_model_file):
231 | """
232 | see: src/base/reconstruction.cc
233 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
234 | void Reconstruction::WritePoints3DBinary(const std::string& path)
235 | """
236 | points3D = {}
237 | with open(path_to_model_file, "rb") as fid:
238 | num_points = read_next_bytes(fid, 8, "Q")[0]
239 | for point_line_index in range(num_points):
240 | binary_point_line_properties = read_next_bytes(
241 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
242 | point3D_id = binary_point_line_properties[0]
243 | xyz = np.array(binary_point_line_properties[1:4])
244 | rgb = np.array(binary_point_line_properties[4:7])
245 | error = np.array(binary_point_line_properties[7])
246 | track_length = read_next_bytes(
247 | fid, num_bytes=8, format_char_sequence="Q")[0]
248 | track_elems = read_next_bytes(
249 | fid, num_bytes=8*track_length,
250 | format_char_sequence="ii"*track_length)
251 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
253 | points3D[point3D_id] = Point3D(
254 | id=point3D_id, xyz=xyz, rgb=rgb,
255 | error=error, image_ids=image_ids,
256 | point2D_idxs=point2D_idxs)
257 | return points3D
258 |
259 |
260 | def read_model(path, ext):
261 | if ext == ".txt":
262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
263 | images = read_images_text(os.path.join(path, "images" + ext))
264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
265 | else:
266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
267 | images = read_images_binary(os.path.join(path, "images" + ext))
268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
269 | return cameras, images, points3D
270 |
271 |
272 | def qvec2rotmat(qvec):
273 | return np.array([
274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
283 |
284 |
285 | def rotmat2qvec(R):
286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
287 | K = np.array([
288 | [Rxx - Ryy - Rzz, 0, 0, 0],
289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
292 | eigvals, eigvecs = np.linalg.eigh(K)
293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
294 | if qvec[0] < 0:
295 | qvec *= -1
296 | return qvec
--------------------------------------------------------------------------------
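A minimal sketch (not part of the repository) that loads a COLMAP sparse reconstruction with the readers above; the scene path is an assumption (the sparse/0 layout mirrors datasets/colmap.py):

```python
# Hypothetical use of datasets/colmap_utils.read_model on a COLMAP "sparse/0" folder.
from datasets.colmap_utils import read_model

cameras, images, points3D = read_model('/path/to/scene/sparse/0', ext='.bin')
print(len(cameras), 'cameras |', len(images), 'images |', len(points3D), 'points')

first_image = next(iter(images.values()))
R = first_image.qvec2rotmat()       # (3, 3) world-to-camera rotation
t = first_image.tvec.reshape(3, 1)  # (3, 1) translation
```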
/datasets/depth_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 |
5 | def read_pfm(path):
6 | """Read pfm file.
7 |
8 | Args:
9 | path (str): path to file
10 |
11 | Returns:
12 | tuple: (data, scale)
13 | """
14 | with open(path, "rb") as file:
15 |
16 | color = None
17 | width = None
18 | height = None
19 | scale = None
20 | endian = None
21 |
22 | header = file.readline().rstrip()
23 | if header.decode("ascii") == "PF":
24 | color = True
25 | elif header.decode("ascii") == "Pf":
26 | color = False
27 | else:
28 | raise Exception("Not a PFM file: " + path)
29 |
30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
31 | if dim_match:
32 | width, height = list(map(int, dim_match.groups()))
33 | else:
34 | raise Exception("Malformed PFM header.")
35 |
36 | scale = float(file.readline().decode("ascii").rstrip())
37 | if scale < 0:
38 | # little-endian
39 | endian = "<"
40 | scale = -scale
41 | else:
42 | # big-endian
43 | endian = ">"
44 |
45 | data = np.fromfile(file, endian + "f")
46 | shape = (height, width, 3) if color else (height, width)
47 |
48 | data = np.reshape(data, shape)
49 | data = np.flipud(data)
50 |
51 | return data, scale
--------------------------------------------------------------------------------
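A minimal sketch (not part of the repository) of reading a PFM depth map with read_pfm above; the file path is an assumption:

```python
# Hypothetical use of datasets/depth_utils.read_pfm.
from datasets.depth_utils import read_pfm

depth, scale = read_pfm('/path/to/depth.pfm')  # (H, W) or (H, W, 3) float array, plus the PFM scale
print(depth.shape, depth.dtype, 'scale =', scale)
depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)  # normalize for visualization
```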
/datasets/nerfpp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import numpy as np
4 | import os
5 | from PIL import Image
6 | from einops import rearrange
7 | from tqdm import tqdm
8 |
9 | from .ray_utils import get_ray_directions
10 |
11 | from .base import BaseDataset
12 |
13 |
14 | class NeRFPPDataset(BaseDataset):
15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
16 |         super().__init__(root_dir, split, downsample=downsample)
17 |
18 | self.read_intrinsics()
19 |
20 | if kwargs.get('read_meta', True):
21 | self.read_meta(split)
22 |
23 | def read_intrinsics(self):
24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0],
25 | dtype=np.float32).reshape(4, 4)[:3, :3]
26 | K[:2] *= self.downsample
27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size
28 | w, h = int(w*self.downsample), int(h*self.downsample)
29 | self.K = torch.FloatTensor(K)
30 | self.directions = get_ray_directions(h, w, self.K)
31 | self.img_wh = (w, h)
32 |
33 | def read_meta(self, split):
34 | self.rays = []
35 | self.poses = []
36 |
37 | if split == 'test_traj':
38 | poses_path = \
39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt')))
40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path]
41 | else:
42 | if split=='trainval':
43 | imgs = sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\
44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*')))
45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\
46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt')))
47 | else:
48 | imgs = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*')))
49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt')))
50 |
51 | print(f'Loading {len(imgs)} {split} images ...')
52 | for img, pose in tqdm(zip(imgs, poses)):
53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]]
54 |
55 | img = Image.open(img).convert('RGB').resize(self.img_wh, Image.LANCZOS)
56 | img = rearrange(self.transform(img), 'c h w -> (h w) c')
57 |
58 | self.rays += [img]
59 |
60 |         self.rays = torch.stack(self.rays) # (N_images, h*w, 3)
61 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
62 |
--------------------------------------------------------------------------------
/datasets/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from kornia import create_meshgrid
4 | from einops import rearrange
5 |
6 |
7 | @torch.cuda.amp.autocast(dtype=torch.float32)
8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True):
9 | """
10 | Get ray directions for all pixels in camera coordinate [right down front].
11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
12 | ray-tracing-generating-camera-rays/standard-coordinate-systems
13 |
14 | Inputs:
15 | H, W: image height and width
16 | K: (3, 3) camera intrinsics
17 | random: whether the ray passes randomly inside the pixel
18 | return_uv: whether to return uv image coordinates
19 |
20 | Outputs: (shape depends on @flatten)
21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate
22 | uv: (H, W, 2) or (H*W, 2) image coordinates
23 | """
24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2)
25 | u, v = grid.unbind(-1)
26 |
27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
28 | if random:
29 | directions = \
30 | torch.stack([(u-cx+torch.rand_like(u))/fx,
31 | (v-cy+torch.rand_like(v))/fy,
32 | torch.ones_like(u)], -1)
33 | else: # pass by the center
34 | directions = \
35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1)
36 | if flatten:
37 | directions = directions.reshape(-1, 3)
38 | grid = grid.reshape(-1, 2)
39 |
40 | if return_uv:
41 | return directions, grid
42 | return directions
43 |
44 |
45 | @torch.cuda.amp.autocast(dtype=torch.float32)
46 | def get_rays(directions, c2w):
47 | """
48 | Get ray origin and directions in world coordinate for all pixels in one image.
49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
50 | ray-tracing-generating-camera-rays/standard-coordinate-systems
51 |
52 | Inputs:
53 | directions: (N, 3) ray directions in camera coordinate
54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate
55 |
56 | Outputs:
57 | rays_o: (N, 3), the origin of the rays in world coordinate
58 | rays_d: (N, 3), the direction of the rays in world coordinate
59 | """
60 | if c2w.ndim==2:
61 | # Rotate ray directions from camera coordinate to the world coordinate
62 | rays_d = directions @ c2w[:, :3].T
63 | else:
64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ c2w[..., :3].mT
65 | rays_d = rearrange(rays_d, 'n 1 c -> n c')
66 | # The origin of all rays is the camera origin in world coordinate
67 | rays_o = c2w[..., 3].expand_as(rays_d)
68 |
69 | return rays_o, rays_d
70 |
71 |
72 | @torch.cuda.amp.autocast(dtype=torch.float32)
73 | def axisangle_to_R(v):
74 | """
75 | Convert an axis-angle vector to rotation matrix
76 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47
77 |
78 | Inputs:
79 | v: (B, 3)
80 |
81 | Outputs:
82 | R: (B, 3, 3)
83 | """
84 | zero = torch.zeros_like(v[:, :1]) # (B, 1)
85 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3)
86 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1)
87 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1)
88 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3)
89 |
90 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1')
91 | eye = torch.eye(3, device=v.device)
92 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \
93 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v)
94 | return R
95 |
96 |
97 | def normalize(v):
98 | """Normalize a vector."""
99 | return v/np.linalg.norm(v)
100 |
101 |
102 | def average_poses(poses, pts3d=None):
103 | """
104 | Calculate the average pose, which is then used to center all poses
105 | using @center_poses. Its computation is as follows:
106 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras).
107 | 2. Compute the z axis: the normalized average z axis.
108 | 3. Compute axis y': the average y axis.
109 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
110 | 5. Compute the y axis: z cross product x.
111 |
112 | Note that at step 3, we cannot directly use y' as y axis since it's
113 |     not necessarily orthogonal to the z axis; we therefore compute x first and re-derive y from z and x.
114 | Inputs:
115 | poses: (N_images, 3, 4)
116 | pts3d: (N, 3)
117 |
118 | Outputs:
119 | pose_avg: (3, 4) the average pose
120 | """
121 | # 1. Compute the center
122 | if pts3d is not None:
123 | center = pts3d.mean(0)
124 | else:
125 | center = poses[..., 3].mean(0)
126 |
127 | # 2. Compute the z axis
128 | z = normalize(poses[..., 2].mean(0)) # (3)
129 |
130 | # 3. Compute axis y' (no need to normalize as it's not the final output)
131 | y_ = poses[..., 1].mean(0) # (3)
132 |
133 | # 4. Compute the x axis
134 | x = normalize(np.cross(y_, z)) # (3)
135 |
136 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
137 | y = np.cross(z, x) # (3)
138 |
139 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
140 |
141 | return pose_avg
142 |
143 |
144 | def center_poses(poses, pts3d=None):
145 | """
146 | See https://github.com/bmild/nerf/issues/34
147 | Inputs:
148 | poses: (N_images, 3, 4)
149 | pts3d: (N, 3) reconstructed point cloud
150 |
151 | Outputs:
152 | poses_centered: (N_images, 3, 4) the centered poses
153 | pts3d_centered: (N, 3) centered point cloud
154 | """
155 |
156 | pose_avg = average_poses(poses, pts3d) # (3, 4)
157 | pose_avg_homo = np.eye(4)
158 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
159 | # by simply adding 0, 0, 0, 1 as the last row
160 | pose_avg_inv = np.linalg.inv(pose_avg_homo)
161 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
162 | poses_homo = \
163 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
164 |
165 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4)
166 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
167 |
168 | if pts3d is not None:
169 |         pts3d_centered = pts3d @ pose_avg_inv[:3, :3].T + pose_avg_inv[:3, 3:].T
170 | return poses_centered, pts3d_centered
171 |
172 | return poses_centered
173 |
174 | def create_spheric_poses(radius, mean_h, n_poses=120):
175 | """
176 | Create circular poses around z axis.
177 | Inputs:
178 |         radius: radius of the circular path (the camera is offset by -radius along z before the rotations are applied)
179 | mean_h: mean camera height
180 | Outputs:
181 | spheric_poses: (n_poses, 3, 4) the poses in the circular path
182 | """
183 | def spheric_pose(theta, phi, radius):
184 | trans_t = lambda t : np.array([
185 | [1,0,0,0],
186 | [0,1,0,2*mean_h],
187 | [0,0,1,-t]
188 | ])
189 |
190 | rot_phi = lambda phi : np.array([
191 | [1,0,0],
192 | [0,np.cos(phi),-np.sin(phi)],
193 | [0,np.sin(phi), np.cos(phi)]
194 | ])
195 |
196 | rot_theta = lambda th : np.array([
197 | [np.cos(th),0,-np.sin(th)],
198 | [0,1,0],
199 | [np.sin(th),0, np.cos(th)]
200 | ])
201 |
202 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius)
203 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w
204 | return c2w
205 |
206 | spheric_poses = []
207 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]:
208 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)]
209 | return np.stack(spheric_poses, 0)
--------------------------------------------------------------------------------
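A minimal end-to-end sketch of the two helpers above (assuming the module imports as datasets.ray_utils; on a CPU-only machine the CUDA autocast decorator only emits a warning): pixel (u, v) maps to the camera-frame direction ((u - cx + 0.5)/fx, (v - cy + 0.5)/fy, 1), and get_rays then rotates these directions into world space and attaches the camera origin.

    import torch
    from datasets.ray_utils import get_ray_directions, get_rays

    # Toy pinhole intrinsics for a 4x6 image.
    H, W = 4, 6
    K = torch.tensor([[100.0,   0.0, W / 2],
                      [  0.0, 100.0, H / 2],
                      [  0.0,   0.0,   1.0]])

    directions = get_ray_directions(H, W, K)   # (H*W, 3) in the camera frame [right down front]
    c2w = torch.eye(4)[:3]                     # identity camera-to-world pose, (3, 4)
    rays_o, rays_d = get_rays(directions, c2w)

    assert rays_o.shape == rays_d.shape == (H * W, 3)
    assert torch.allclose(rays_o, torch.zeros_like(rays_o))   # camera sits at the world origin
    assert torch.allclose(rays_d[:, 2], torch.ones(H * W))    # unit z component before normalization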
/datasets/rtmv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import json
4 | #### Under construction. Don't use now
5 |
6 | import numpy as np
7 | import os
8 | import imageio
9 | import cv2
10 | from einops import rearrange
11 | from tqdm import tqdm
12 |
13 | from .ray_utils import get_ray_directions
14 |
15 | from .base import BaseDataset
16 |
17 |
18 | def srgb_to_linear(img):
19 | limit = 0.04045
20 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)
21 |
22 |
23 | def linear_to_srgb(img):
24 | limit = 0.0031308
25 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
26 |
27 |
28 | class RTMVDataset(BaseDataset):
29 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs):
30 | super().__init__(root_dir, split, downsample)
31 |
32 | with open(os.path.join(self.root_dir, '00000.json'), 'r') as f:
33 | meta = json.load(f)['camera_data']
34 | self.shift = np.array(meta['scene_center_3d_box'])
35 | self.scale = (np.array(meta['scene_max_3d_box'])-
36 | np.array(meta['scene_min_3d_box'])).max()/2 * 1.05 # enlarge a little
37 |
38 | fx = meta['intrinsics']['fx'] * downsample
39 | fy = meta['intrinsics']['fy'] * downsample
40 | cx = meta['intrinsics']['cx'] * downsample
41 | cy = meta['intrinsics']['cy'] * downsample
42 | w = int(meta['width']*downsample)
43 | h = int(meta['height']*downsample)
44 | K = np.float32([[fx, 0, cx],
45 | [0, fy, cy],
46 | [0, 0, 1]])
47 | self.K = torch.FloatTensor(K)
48 | self.directions = get_ray_directions(h, w, self.K)
49 | self.img_wh = (w, h)
50 |
51 | self.read_meta(split)
52 |
53 | def read_meta(self, split):
54 | self.rays = []
55 | self.poses = []
56 |
57 | if split == 'train': start_idx, end_idx = 0, 100
58 | elif split == 'trainval': start_idx, end_idx = 0, 105
59 | elif split == 'test': start_idx, end_idx = 105, 150
60 | else: raise ValueError(f'{split} split not recognized!')
61 | imgs = sorted(glob.glob(os.path.join(self.root_dir, '*[0-9].exr')))[start_idx:end_idx]
62 | poses = sorted(glob.glob(os.path.join(self.root_dir, '*.json')))[start_idx:end_idx]
63 |
64 | print(f'Loading {len(imgs)} {split} images ...')
65 | for img, pose in tqdm(zip(imgs, poses)):
66 | with open(pose, 'r') as f:
67 | m = json.load(f)['camera_data']
68 | c2w = np.zeros((3, 4), dtype=np.float32)
69 | c2w[:, :3] = -np.array(m['cam2world'])[:3, :3].T
70 | c2w[:, 3] = np.array(m['location_world'])-self.shift
71 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5]
72 | self.poses += [c2w]
73 |
74 | img = imageio.imread(img)[..., :3]
75 | img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_LANCZOS4)
76 | # img = np.clip(linear_to_srgb(img), 0, 1)
77 | img = rearrange(torch.FloatTensor(img), 'h w c -> (h w) c')
78 |
79 | self.rays += [img]
80 |
81 |         self.rays = torch.stack(self.rays) # (N_images, h*w, 3)
82 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4)
83 |
--------------------------------------------------------------------------------
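The two transfer functions above are the standard piecewise sRGB conversions and invert each other away from the piecewise threshold. A quick standalone check (a sketch; importing the module also requires the repo's imageio and cv2 dependencies):

    import numpy as np
    from datasets.rtmv import srgb_to_linear, linear_to_srgb

    linear = np.linspace(0.0, 1.0, 11)   # sample points away from the ~0.003 threshold
    roundtrip = srgb_to_linear(linear_to_srgb(linear))
    assert np.allclose(roundtrip, linear, atol=1e-6)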
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
/ldm/data/personalized.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 |
10 | import random
11 |
12 | imagenet_templates_smallest = [
13 | 'a photo of a {}',
14 | ]
15 |
16 | imagenet_templates_small = [
17 | 'a photo of a {}',
18 | 'a rendering of a {}',
19 | 'a cropped photo of the {}',
20 | 'the photo of a {}',
21 | 'a photo of a clean {}',
22 | 'a photo of a dirty {}',
23 | 'a dark photo of the {}',
24 | 'a photo of my {}',
25 | 'a photo of the cool {}',
26 | 'a close-up photo of a {}',
27 | 'a bright photo of the {}',
28 | 'a cropped photo of a {}',
29 | 'a photo of the {}',
30 | 'a good photo of the {}',
31 | 'a photo of one {}',
32 | 'a close-up photo of the {}',
33 | 'a rendition of the {}',
34 | 'a photo of the clean {}',
35 | 'a rendition of a {}',
36 | 'a photo of a nice {}',
37 | 'a good photo of a {}',
38 | 'a photo of the nice {}',
39 | 'a photo of the small {}',
40 | 'a photo of the weird {}',
41 | 'a photo of the large {}',
42 | 'a photo of a cool {}',
43 | 'a photo of a small {}',
44 | 'an illustration of a {}',
45 | 'a rendering of a {}',
46 | 'a cropped photo of the {}',
47 | 'the photo of a {}',
48 | 'an illustration of a clean {}',
49 | 'an illustration of a dirty {}',
50 | 'a dark photo of the {}',
51 | 'an illustration of my {}',
52 | 'an illustration of the cool {}',
53 | 'a close-up photo of a {}',
54 | 'a bright photo of the {}',
55 | 'a cropped photo of a {}',
56 | 'an illustration of the {}',
57 | 'a good photo of the {}',
58 | 'an illustration of one {}',
59 | 'a close-up photo of the {}',
60 | 'a rendition of the {}',
61 | 'an illustration of the clean {}',
62 | 'a rendition of a {}',
63 | 'an illustration of a nice {}',
64 | 'a good photo of a {}',
65 | 'an illustration of the nice {}',
66 | 'an illustration of the small {}',
67 | 'an illustration of the weird {}',
68 | 'an illustration of the large {}',
69 | 'an illustration of a cool {}',
70 | 'an illustration of a small {}',
71 | 'a depiction of a {}',
72 | 'a rendering of a {}',
73 | 'a cropped photo of the {}',
74 | 'the photo of a {}',
75 | 'a depiction of a clean {}',
76 | 'a depiction of a dirty {}',
77 | 'a dark photo of the {}',
78 | 'a depiction of my {}',
79 | 'a depiction of the cool {}',
80 | 'a close-up photo of a {}',
81 | 'a bright photo of the {}',
82 | 'a cropped photo of a {}',
83 | 'a depiction of the {}',
84 | 'a good photo of the {}',
85 | 'a depiction of one {}',
86 | 'a close-up photo of the {}',
87 | 'a rendition of the {}',
88 | 'a depiction of the clean {}',
89 | 'a rendition of a {}',
90 | 'a depiction of a nice {}',
91 | 'a good photo of a {}',
92 | 'a depiction of the nice {}',
93 | 'a depiction of the small {}',
94 | 'a depiction of the weird {}',
95 | 'a depiction of the large {}',
96 | 'a depiction of a cool {}',
97 | 'a depiction of a small {}',
98 | ]
99 |
100 | imagenet_dual_templates_small = [
101 | 'a photo of a {} with {}',
102 | 'a rendering of a {} with {}',
103 | 'a cropped photo of the {} with {}',
104 | 'the photo of a {} with {}',
105 | 'a photo of a clean {} with {}',
106 | 'a photo of a dirty {} with {}',
107 | 'a dark photo of the {} with {}',
108 | 'a photo of my {} with {}',
109 | 'a photo of the cool {} with {}',
110 | 'a close-up photo of a {} with {}',
111 | 'a bright photo of the {} with {}',
112 | 'a cropped photo of a {} with {}',
113 | 'a photo of the {} with {}',
114 | 'a good photo of the {} with {}',
115 | 'a photo of one {} with {}',
116 | 'a close-up photo of the {} with {}',
117 | 'a rendition of the {} with {}',
118 | 'a photo of the clean {} with {}',
119 | 'a rendition of a {} with {}',
120 | 'a photo of a nice {} with {}',
121 | 'a good photo of a {} with {}',
122 | 'a photo of the nice {} with {}',
123 | 'a photo of the small {} with {}',
124 | 'a photo of the weird {} with {}',
125 | 'a photo of the large {} with {}',
126 | 'a photo of a cool {} with {}',
127 | 'a photo of a small {} with {}',
128 | ]
129 |
130 | per_img_token_list = [
131 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
132 | ]
133 |
134 | class PersonalizedBase(Dataset):
135 | def __init__(self,
136 | data_root,
137 | size=None,
138 | repeats=100,
139 | interpolation="bicubic",
140 | flip_p=0.5,
141 | set="train",
142 | placeholder_token="*",
143 | per_image_tokens=False,
144 | center_crop=False,
145 | mixing_prob=0.25,
146 | coarse_class_text=None,
147 | ):
148 |
149 | self.data_root = data_root
150 |
151 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
152 |
153 | # self._length = len(self.image_paths)
154 | self.num_images = len(self.image_paths)
155 | self._length = self.num_images
156 |
157 | self.placeholder_token = placeholder_token
158 |
159 | self.per_image_tokens = per_image_tokens
160 | self.center_crop = center_crop
161 | self.mixing_prob = mixing_prob
162 |
163 | self.coarse_class_text = coarse_class_text
164 |
165 | if per_image_tokens:
166 |             assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
167 |
168 | if set == "train":
169 | self._length = self.num_images * repeats
170 |
171 | self.size = size
172 | self.interpolation = {
173 | "bilinear": Image.Resampling.BILINEAR,
174 | "bicubic": Image.Resampling.BICUBIC,
175 | "lanczos": Image.Resampling.LANCZOS,
176 | }[interpolation]
177 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
178 |
179 | def __len__(self):
180 | return self._length
181 |
182 |
183 | def __getitem__(self, i):
184 | example = {}
185 | img = plt.imread(self.image_paths[i % self.num_images])
186 |
187 | placeholder_string = self.placeholder_token
188 | if self.coarse_class_text:
189 | placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
190 |
191 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
192 | text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
193 | else:
194 | text = random.choice(imagenet_templates_small).format(placeholder_string)
195 |
196 | # default to score-sde preprocessing
197 |         # composite RGBA over a white background using the alpha channel
198 | if img.shape[-1] == 4:
199 | img = img[:,:,:3]*img[:,:, -1:]+(1-img[:,:, -1:])
200 |
201 | if self.center_crop:
202 | crop = min(img.shape[0], img.shape[1])
203 |             h, w = img.shape[0], img.shape[1]
204 | img = img[(h - crop) // 2:(h + crop) // 2,
205 | (w - crop) // 2:(w + crop) // 2]
206 | image = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
207 | if self.size is not None:
208 | image = image.resize((self.size, self.size), resample=self.interpolation)
209 |
210 | image = self.flip(image)
211 | image = np.array(image).astype(np.uint8)
212 | example["jpg"] = (image / 127.5 - 1.0).astype(np.float32)
213 | example["txt"] = text
214 | example["hint"] = (image / 255.0).astype(np.float32)
215 | example["delta_pose"] = torch.tensor([0., 0., 0.])
216 | return example
--------------------------------------------------------------------------------
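A small standalone sketch of the array conventions produced by __getitem__ above (toy values only, no image files): RGBA inputs are composited over a white background, and the 8-bit result is mapped to "jpg" in [-1, 1] and "hint" in [0, 1].

    import numpy as np

    rgb = np.full((2, 2, 3), 0.25, dtype=np.float32)                          # RGB in [0, 1]
    alpha = np.array([[1.0, 0.5], [0.0, 1.0]], dtype=np.float32)[..., None]   # per-pixel alpha
    rgba = np.concatenate([rgb, alpha], axis=-1)                              # (2, 2, 4)

    composited = rgba[:, :, :3] * rgba[:, :, -1:] + (1 - rgba[:, :, -1:])     # alpha over white
    image = np.uint8(composited * 255.0)

    jpg = (image / 127.5 - 1.0).astype(np.float32)    # target image in [-1, 1]
    hint = (image / 255.0).astype(np.float32)         # control/hint image in [0, 1]
    assert jpg.min() >= -1.0 and jpg.max() <= 1.0
    assert 0.0 <= hint.min() <= hint.max() <= 1.0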
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
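A minimal sketch of the warmup-cosine multiplier defined above (used with base_lr = 1.0, so the returned value is a multiplier rather than an absolute learning rate):

    from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

    # Linear warmup from 0.0 to 1.0 over 100 steps, then cosine decay to 0.01 by step 1000.
    sched = LambdaWarmUpCosineScheduler(warm_up_steps=100, lr_min=0.01, lr_max=1.0,
                                        lr_start=0.0, max_decay_steps=1000)

    assert abs(sched(0) - 0.0) < 1e-8      # start of warmup
    assert abs(sched(100) - 1.0) < 1e-8    # warmup finished, at lr_max
    assert abs(sched(550) - 0.505) < 1e-8  # halfway through the cosine decay
    assert abs(sched(1000) - 0.01) < 1e-8  # fully decayed to lr_min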
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8 |
9 | from ldm.util import instantiate_from_config
10 | from ldm.modules.ema import LitEma
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | ema_decay=None,
24 | learn_logvar=False
25 | ):
26 | super().__init__()
27 | self.learn_logvar = learn_logvar
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | self.use_ema = ema_decay is not None
43 | if self.use_ema:
44 | self.ema_decay = ema_decay
45 | assert 0. < ema_decay < 1.
46 | self.model_ema = LitEma(self, decay=ema_decay)
47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48 |
49 | if ckpt_path is not None:
50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | sd = torch.load(path, map_location="cpu")["state_dict"]
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def on_train_batch_end(self, *args, **kwargs):
79 | if self.use_ema:
80 | self.model_ema(self)
81 |
82 | def encode(self, x):
83 | h = self.encoder(x)
84 | moments = self.quant_conv(h)
85 | posterior = DiagonalGaussianDistribution(moments)
86 | return posterior
87 |
88 | def decode(self, z):
89 | z = self.post_quant_conv(z)
90 | dec = self.decoder(z)
91 | return dec
92 |
93 | def forward(self, input, sample_posterior=True):
94 | posterior = self.encode(input)
95 | if sample_posterior:
96 | z = posterior.sample()
97 | else:
98 | z = posterior.mode()
99 | dec = self.decode(z)
100 | return dec, posterior
101 |
102 | def get_input(self, batch, k):
103 | x = batch[k]
104 | if len(x.shape) == 3:
105 | x = x[..., None]
106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107 | return x
108 |
109 | def training_step(self, batch, batch_idx, optimizer_idx):
110 | inputs = self.get_input(batch, self.image_key)
111 | reconstructions, posterior = self(inputs)
112 |
113 | if optimizer_idx == 0:
114 | # train encoder+decoder+logvar
115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116 | last_layer=self.get_last_layer(), split="train")
117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119 | return aeloss
120 |
121 | if optimizer_idx == 1:
122 | # train the discriminator
123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124 | last_layer=self.get_last_layer(), split="train")
125 |
126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return discloss
129 |
130 | def validation_step(self, batch, batch_idx):
131 | log_dict = self._validation_step(batch, batch_idx)
132 | with self.ema_scope():
133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134 | return log_dict
135 |
136 | def _validation_step(self, batch, batch_idx, postfix=""):
137 | inputs = self.get_input(batch, self.image_key)
138 | reconstructions, posterior = self(inputs)
139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140 | last_layer=self.get_last_layer(), split="val"+postfix)
141 |
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143 | last_layer=self.get_last_layer(), split="val"+postfix)
144 |
145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146 | self.log_dict(log_dict_ae)
147 | self.log_dict(log_dict_disc)
148 | return self.log_dict
149 |
150 | def configure_optimizers(self):
151 | lr = self.learning_rate
152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154 | if self.learn_logvar:
155 | print(f"{self.__class__.__name__}: Learning logvar")
156 | ae_params_list.append(self.loss.logvar)
157 | opt_ae = torch.optim.Adam(ae_params_list,
158 | lr=lr, betas=(0.5, 0.9))
159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160 | lr=lr, betas=(0.5, 0.9))
161 | return [opt_ae, opt_disc], []
162 |
163 | def get_last_layer(self):
164 | return self.decoder.conv_out.weight
165 |
166 | @torch.no_grad()
167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168 | log = dict()
169 | x = self.get_input(batch, self.image_key)
170 | x = x.to(self.device)
171 | if not only_inputs:
172 | xrec, posterior = self(x)
173 | if x.shape[1] > 3:
174 | # colorize with random projection
175 | assert xrec.shape[1] > 3
176 | x = self.to_rgb(x)
177 | xrec = self.to_rgb(xrec)
178 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179 | log["reconstructions"] = xrec
180 | if log_ema or self.use_ema:
181 | with self.ema_scope():
182 | xrec_ema, posterior_ema = self(x)
183 | if x.shape[1] > 3:
184 | # colorize with random projection
185 | assert xrec_ema.shape[1] > 3
186 | xrec_ema = self.to_rgb(xrec_ema)
187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188 | log["reconstructions_ema"] = xrec_ema
189 | log["inputs"] = x
190 | return log
191 |
192 | def to_rgb(self, x):
193 | assert self.image_key == "segmentation"
194 | if not hasattr(self, "colorize"):
195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196 | x = F.conv2d(x, weight=self.colorize)
197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198 | return x
199 |
200 |
201 | class IdentityFirstStage(torch.nn.Module):
202 | def __init__(self, *args, vq_interface=False, **kwargs):
203 | self.vq_interface = vq_interface
204 | super().__init__()
205 |
206 | def encode(self, x, *args, **kwargs):
207 | return x
208 |
209 | def decode(self, x, *args, **kwargs):
210 | return x
211 |
212 | def quantize(self, x, *args, **kwargs):
213 | if self.vq_interface:
214 | return x, None, [None, None, None]
215 | return x
216 |
217 | def forward(self, x, *args, **kwargs):
218 | return x
219 |
220 |
--------------------------------------------------------------------------------
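For reference, a minimal sketch of the moment-splitting that encode() relies on (shapes only; no config or pretrained weights needed). quant_conv produces 2*embed_dim channels, which DiagonalGaussianDistribution (from ldm/modules/distributions) splits into a mean and a log-variance; sample() draws a reparameterized latent while mode() returns the mean.

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    embed_dim = 4
    moments = torch.randn(1, 2 * embed_dim, 8, 8)   # stand-in for the quant_conv output
    posterior = DiagonalGaussianDistribution(moments)

    z_sample = posterior.sample()   # stochastic latent, (1, 4, 8, 8)
    z_mode = posterior.mode()       # deterministic latent (the mean)
    assert z_sample.shape == z_mode.shape == (1, embed_dim, 8, 8)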
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 |
7 | MODEL_TYPES = {
8 | "eps": "noise",
9 | "v": "v"
10 | }
11 |
12 |
13 | class DPMSolverSampler(object):
14 | def __init__(self, model, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != torch.device("cuda"):
23 | attr = attr.to(torch.device("cuda"))
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54 | if cbs != batch_size:
55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56 | else:
57 | if conditioning.shape[0] != batch_size:
58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59 |
60 | # sampling
61 | C, H, W = shape
62 | size = (batch_size, C, H, W)
63 |
64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65 |
66 | device = self.model.betas.device
67 | if x_T is None:
68 | img = torch.randn(size, device=device)
69 | else:
70 | img = x_T
71 |
72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73 |
74 | model_fn = model_wrapper(
75 | lambda x, t, c: self.model.apply_model(x, t, c),
76 | ns,
77 | model_type=MODEL_TYPES[self.model.parameterization],
78 | guidance_type="classifier-free",
79 | condition=conditioning,
80 | unconditional_condition=unconditional_conditioning,
81 | guidance_scale=unconditional_guidance_scale,
82 | )
83 |
84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86 |
87 | return x.to(device), None
--------------------------------------------------------------------------------
/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | if isinstance(conditioning, dict):
86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87 | if cbs != batch_size:
88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89 | else:
90 | if conditioning.shape[0] != batch_size:
91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92 |
93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94 | # sampling
95 | C, H, W = shape
96 | size = (batch_size, C, H, W)
97 | print(f'Data shape for PLMS sampling is {size}')
98 |
99 | samples, intermediates = self.plms_sampling(conditioning, size,
100 | callback=callback,
101 | img_callback=img_callback,
102 | quantize_denoised=quantize_x0,
103 | mask=mask, x0=x0,
104 | ddim_use_original_steps=False,
105 | noise_dropout=noise_dropout,
106 | temperature=temperature,
107 | score_corrector=score_corrector,
108 | corrector_kwargs=corrector_kwargs,
109 | x_T=x_T,
110 | log_every_t=log_every_t,
111 | unconditional_guidance_scale=unconditional_guidance_scale,
112 | unconditional_conditioning=unconditional_conditioning,
113 | dynamic_threshold=dynamic_threshold,
114 | )
115 | return samples, intermediates
116 |
117 | @torch.no_grad()
118 | def plms_sampling(self, cond, shape,
119 | x_T=None, ddim_use_original_steps=False,
120 | callback=None, timesteps=None, quantize_denoised=False,
121 | mask=None, x0=None, img_callback=None, log_every_t=100,
122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123 | unconditional_guidance_scale=1., unconditional_conditioning=None,
124 | dynamic_threshold=None):
125 | device = self.model.betas.device
126 | b = shape[0]
127 | if x_T is None:
128 | img = torch.randn(shape, device=device)
129 | else:
130 | img = x_T
131 |
132 | if timesteps is None:
133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134 | elif timesteps is not None and not ddim_use_original_steps:
135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136 | timesteps = self.ddim_timesteps[:subset_end]
137 |
138 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141 | print(f"Running PLMS Sampling with {total_steps} timesteps")
142 |
143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144 | old_eps = []
145 |
146 | for i, step in enumerate(iterator):
147 | index = total_steps - i - 1
148 | ts = torch.full((b,), step, device=device, dtype=torch.long)
149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150 |
151 | if mask is not None:
152 | assert x0 is not None
153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154 | img = img_orig * mask + (1. - mask) * img
155 |
156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157 | quantize_denoised=quantize_denoised, temperature=temperature,
158 | noise_dropout=noise_dropout, score_corrector=score_corrector,
159 | corrector_kwargs=corrector_kwargs,
160 | unconditional_guidance_scale=unconditional_guidance_scale,
161 | unconditional_conditioning=unconditional_conditioning,
162 | old_eps=old_eps, t_next=ts_next,
163 | dynamic_threshold=dynamic_threshold)
164 | img, pred_x0, e_t = outs
165 | old_eps.append(e_t)
166 | if len(old_eps) >= 4:
167 | old_eps.pop(0)
168 | if callback: callback(i)
169 | if img_callback: img_callback(pred_x0, i)
170 |
171 | if index % log_every_t == 0 or index == total_steps - 1:
172 | intermediates['x_inter'].append(img)
173 | intermediates['pred_x0'].append(pred_x0)
174 |
175 | return img, intermediates
176 |
177 | @torch.no_grad()
178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181 | dynamic_threshold=None):
182 | b, *_, device = *x.shape, x.device
183 |
184 | def get_model_output(x, t):
185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186 | e_t = self.model.apply_model(x, t, c)
187 | else:
188 | x_in = torch.cat([x] * 2)
189 | t_in = torch.cat([t] * 2)
190 | c_in = torch.cat([unconditional_conditioning, c])
191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193 |
194 | if score_corrector is not None:
195 | assert self.model.parameterization == "eps"
196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197 |
198 | return e_t
199 |
200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204 |
205 | def get_x_prev_and_pred_x0(e_t, index):
206 | # select parameters corresponding to the currently considered timestep
207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211 |
212 | # current prediction for x_0
213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214 | if quantize_denoised:
215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216 | if dynamic_threshold is not None:
217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218 | # direction pointing to x_t
219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221 | if noise_dropout > 0.:
222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224 | return x_prev, pred_x0
225 |
226 | e_t = get_model_output(x, t)
227 | if len(old_eps) == 0:
228 | # Pseudo Improved Euler (2nd order)
229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230 | e_t_next = get_model_output(x_prev, t_next)
231 | e_t_prime = (e_t + e_t_next) / 2
232 | elif len(old_eps) == 1:
233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
235 | elif len(old_eps) == 2:
236 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238 | elif len(old_eps) >= 3:
239 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241 |
242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243 |
244 | return x_prev, pred_x0, e_t
245 |
--------------------------------------------------------------------------------
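The e_t_prime combinations above are classical Adams-Bashforth linear-multistep coefficients: (3, -1)/2, (23, -16, 5)/12 and (55, -59, 37, -9)/24. A quick standalone check that the 3rd-order rule is exact when the integrand is a quadratic polynomial sampled at unit-spaced steps:

    def ab_step(coeffs, samples):
        # One explicit Adams-Bashforth step with unit step size.
        return sum(c * s for c, s in zip(coeffs, samples))

    f = lambda t: 3 * t**2 - 2 * t + 1          # derivative of F(t) = t**3 - t**2 + t
    F = lambda t: t**3 - t**2 + t
    exact = F(11) - F(10)                       # integral of f from 10 to 11

    ab3 = ab_step([23 / 12, -16 / 12, 5 / 12], [f(10), f(9), f(8)])
    assert abs(ab3 - exact) < 1e-9              # exact for quadratics (up to float error)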
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
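A minimal sketch of what norm_thresholding above does: it caps the per-sample RMS norm at `value`, leaving samples already under the threshold untouched.

    import torch
    from ldm.models.diffusion.sampling_util import norm_thresholding

    x0 = torch.randn(2, 4, 8, 8)
    x0[0] *= 0.1      # RMS well below the threshold
    x0[1] *= 100.0    # RMS well above the threshold

    out = norm_thresholding(x0, 1.0)
    rms = out.pow(2).flatten(1).mean(1).sqrt()
    assert torch.allclose(out[0], x0[0])               # small sample passes through unchanged
    assert torch.isclose(rms[1], torch.tensor(1.0))    # large sample rescaled onto the threshold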
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
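q_sample above applies the standard closed-form forward process x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps. A compact standalone restatement (the beta values here are illustrative, not the exact schedule built by make_beta_schedule):

    import torch

    betas = torch.linspace(1e-4, 2e-2, 1000)             # illustrative linear schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)    # alphabar_t

    x0 = torch.randn(1, 3, 8, 8)
    t = 250
    eps = torch.randn_like(x0)
    x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps
    assert x_t.shape == x0.shape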
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 | import pdb
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2)
24 |
25 | elif schedule == "cosine":
26 | timesteps = (
27 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
28 | )
29 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
30 | alphas = torch.cos(alphas).pow(2)
31 | alphas = alphas / alphas[0]
32 | betas = 1 - alphas[1:] / alphas[:-1]
33 | betas = np.clip(betas, a_min=0, a_max=0.999)
34 |
35 | elif schedule == "sqrt_linear":
36 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
37 | elif schedule == "sqrt":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
39 | else:
40 | raise ValueError(f"schedule '{schedule}' unknown.")
41 | return betas.numpy()
42 |
43 |
44 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
45 | if ddim_discr_method == 'uniform':
46 | c = num_ddpm_timesteps // num_ddim_timesteps
47 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
48 | elif ddim_discr_method == 'quad':
49 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
50 | else:
51 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
52 |
53 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
54 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
55 | steps_out = ddim_timesteps + 1
56 | if verbose:
57 | print(f'Selected timesteps for ddim sampler: {steps_out}')
58 | return steps_out
59 |
60 |
61 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
62 | # select alphas for computing the variance schedule
63 | alphas = alphacums[ddim_timesteps]
64 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
65 |
66 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
67 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
68 | if verbose:
69 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
70 | print(f'For the chosen value of eta, which is {eta}, '
71 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
72 | return sigmas, alphas, alphas_prev
73 |
74 |
75 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
76 | """
77 | Create a beta schedule that discretizes the given alpha_t_bar function,
78 | which defines the cumulative product of (1-beta) over time from t = [0,1].
79 | :param num_diffusion_timesteps: the number of betas to produce.
80 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
81 | produces the cumulative product of (1-beta) up to that
82 | part of the diffusion process.
83 | :param max_beta: the maximum beta to use; use values lower than 1 to
84 | prevent singularities.
85 | """
86 | betas = []
87 | for i in range(num_diffusion_timesteps):
88 | t1 = i / num_diffusion_timesteps
89 | t2 = (i + 1) / num_diffusion_timesteps
90 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
91 | return np.array(betas)
92 |
93 |
94 | def extract_into_tensor(a, t, x_shape):
95 | b, *_ = t.shape
96 | out = a.gather(-1, t)
97 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
98 |
99 |
100 | def checkpoint(func, inputs, params, flag):
101 | """
102 | Evaluate a function without caching intermediate activations, allowing for
103 | reduced memory at the expense of extra compute in the backward pass.
104 | :param func: the function to evaluate.
105 | :param inputs: the argument sequence to pass to `func`.
106 | :param params: a sequence of parameters `func` depends on but does not
107 | explicitly take as arguments.
108 | :param flag: if False, disable gradient checkpointing.
109 | """
110 | if flag:
111 | args = tuple(inputs) + tuple(params)
112 | return CheckpointFunction.apply(func, len(inputs), *args)
113 | else:
114 | return func(*inputs)
115 |
116 |
117 | class CheckpointFunction(torch.autograd.Function):
118 | @staticmethod
119 | def forward(ctx, run_function, length, *args):
120 | ctx.run_function = run_function
121 | ctx.input_tensors = list(args[:length])
122 | ctx.input_params = list(args[length:])
123 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
124 | "dtype": torch.get_autocast_gpu_dtype(),
125 | "cache_enabled": torch.is_autocast_cache_enabled()}
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad(), \
134 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
135 | # Fixes a bug where the first op in run_function modifies the
136 | # Tensor storage in place, which is not allowed for detach()'d
137 | # Tensors.
138 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
139 | output_tensors = ctx.run_function(*shallow_copies)
140 | input_grads = torch.autograd.grad(
141 | output_tensors,
142 | ctx.input_tensors + ctx.input_params,
143 | output_grads,
144 | allow_unused=True,
145 | )
146 | del ctx.input_tensors
147 | del ctx.input_params
148 | del output_tensors
149 | return (None, None) + input_grads
150 |
151 |
152 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
153 | """
154 | Create sinusoidal timestep embeddings.
155 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
156 | These may be fractional.
157 | :param dim: the dimension of the output.
158 | :param max_period: controls the minimum frequency of the embeddings.
159 | :return: an [N x dim] Tensor of positional embeddings.
160 | """
161 | if not repeat_only:
162 | half = dim // 2
163 | freqs = torch.exp(
164 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
165 | ).to(device=timesteps.device)
166 |
167 | # fp16
168 | # freqs = torch.exp(
169 | # -math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device, dtype=torch.float16) / half
170 | # )
171 |
172 | args = timesteps[:, None].float() * freqs[None]
173 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
174 | if dim % 2:
175 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
176 | else:
177 | embedding = repeat(timesteps, 'b -> b d', d=dim)
178 | return embedding
179 |
180 |
181 | def zero_module(module):
182 | """
183 | Zero out the parameters of a module and return it.
184 | """
185 | for p in module.parameters():
186 | p.detach().zero_()
187 | return module
188 |
189 |
190 | def scale_module(module, scale):
191 | """
192 | Scale the parameters of a module and return it.
193 | """
194 | for p in module.parameters():
195 | p.detach().mul_(scale)
196 | return module
197 |
198 |
199 | def mean_flat(tensor):
200 | """
201 | Take the mean over all non-batch dimensions.
202 | """
203 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
204 |
205 |
206 | def normalization(channels):
207 | """
208 | Make a standard normalization layer.
209 | :param channels: number of input channels.
210 | :return: an nn.Module for normalization.
211 | """
212 | return GroupNorm32(32, channels)
213 |
214 |
215 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
216 | class SiLU(nn.Module):
217 | def forward(self, x):
218 | return x * torch.sigmoid(x)
219 |
220 |
221 | class GroupNorm32(nn.GroupNorm):
222 | def forward(self, x):
223 | return super().forward(x.float()).type(x.dtype)
224 |
225 | def conv_nd(dims, *args, **kwargs):
226 | """
227 | Create a 1D, 2D, or 3D convolution module.
228 | """
229 | if dims == 1:
230 | return nn.Conv1d(*args, **kwargs)
231 | elif dims == 2:
232 | return nn.Conv2d(*args, **kwargs)
233 | elif dims == 3:
234 | return nn.Conv3d(*args, **kwargs)
235 | raise ValueError(f"unsupported dimensions: {dims}")
236 |
237 |
238 | def linear(*args, **kwargs):
239 | """
240 | Create a linear module.
241 | """
242 | return nn.Linear(*args, **kwargs)
243 |
244 |
245 | def avg_pool_nd(dims, *args, **kwargs):
246 | """
247 | Create a 1D, 2D, or 3D average pooling module.
248 | """
249 | if dims == 1:
250 | return nn.AvgPool1d(*args, **kwargs)
251 | elif dims == 2:
252 | return nn.AvgPool2d(*args, **kwargs)
253 | elif dims == 3:
254 | return nn.AvgPool3d(*args, **kwargs)
255 | raise ValueError(f"unsupported dimensions: {dims}")
256 |
257 |
258 | class HybridConditioner(nn.Module):
259 |
260 | def __init__(self, c_concat_config, c_crossattn_config):
261 | super().__init__()
262 | self.concat_conditioner = instantiate_from_config(c_concat_config)
263 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
264 |
265 | def forward(self, c_concat, c_crossattn):
266 | c_concat = self.concat_conditioner(c_concat)
267 | c_crossattn = self.crossattn_conditioner(c_crossattn)
268 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
269 |
270 |
271 | def noise_like(shape, device, repeat=False):
272 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
273 | noise = lambda: torch.randn(shape, device=device)
274 | return repeat_noise() if repeat else noise()
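
A minimal sketch of how the schedule and embedding helpers above compose (assuming only torch and numpy; the step counts and dimensions are illustrative):

    import numpy as np
    import torch
    from ldm.modules.diffusionmodules.util import (
        make_beta_schedule, make_ddim_timesteps,
        make_ddim_sampling_parameters, timestep_embedding)

    # 1000-step linear beta schedule and the matching cumulative alphas.
    betas = make_beta_schedule("linear", n_timestep=1000)
    alphas_cumprod = np.cumprod(1.0 - betas)

    # Subsample 50 DDIM steps and derive the (sigma, alpha, alpha_prev) triplet.
    ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                     num_ddpm_timesteps=1000, verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
        alphacums=alphas_cumprod, ddim_timesteps=ddim_steps, eta=0.0, verbose=False)

    # Sinusoidal embeddings for a batch of 4 timesteps, 320 channels each.
    t = torch.tensor([1, 250, 500, 999])
    emb = timestep_embedding(t, dim=320)   # shape (4, 320)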
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
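
A minimal sketch of how the diagonal Gaussian posterior is typically consumed downstream (shapes are illustrative; a VAE encoder would normally supply the 2*C-channel moments tensor):

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    # 2*C channels along dim=1: C means followed by C log-variances (batch=2, C=4, 8x8 latents).
    moments = torch.randn(2, 8, 8, 8)
    posterior = DiagonalGaussianDistribution(moments)

    z = posterior.sample()    # reparameterized sample, shape (2, 4, 8, 8)
    kl = posterior.kl()       # KL against a standard normal, shape (2,)
    nll = posterior.nll(z)    # negative log-likelihood of z, shape (2,)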
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 |                 # remove '.' since that character is not allowed in buffer names
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
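
A minimal sketch of the usual LitEma lifecycle (a toy model and training loop, for illustration only):

    import torch
    from torch import nn
    from ldm.modules.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.999)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema(model)                      # update the shadow weights after each optimizer step

    ema.store(model.parameters())       # stash the raw weights
    ema.copy_to(model)                  # evaluate / save with EMA weights
    ema.restore(model.parameters())     # put the raw weights back for further training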
--------------------------------------------------------------------------------
/ldm/modules/embedding_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ldm.data.personalized import per_img_token_list
5 | from transformers import CLIPTokenizer
6 | from functools import partial
7 |
8 | DEFAULT_PLACEHOLDER_TOKEN = ["*"]
9 |
10 | PROGRESSIVE_SCALE = 2000
11 |
12 | def get_clip_token_for_string(tokenizer, string):
13 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
14 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
15 | tokens = batch_encoding["input_ids"]
16 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
17 |
18 | return tokens[0, 1]
19 |
20 | def get_bert_token_for_string(tokenizer, string):
21 | token = tokenizer(string)
22 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
23 |
24 | token = token[0, 1]
25 |
26 | return token
27 |
28 | def get_embedding_for_clip_token(embedder, token):
29 | return embedder(token.unsqueeze(0))[0, 0]
30 |
31 |
32 | class EmbeddingManager(nn.Module):
33 | def __init__(
34 | self,
35 | embedder,
36 | placeholder_strings=None,
37 | initializer_words=None,
38 | per_image_tokens=False,
39 | num_vectors_per_token=1,
40 | progressive_words=False,
41 | **kwargs
42 | ):
43 | super().__init__()
44 |
45 | self.string_to_token_dict = {}
46 |
47 | self.string_to_param_dict = nn.ParameterDict()
48 |
49 | self.initial_embeddings = nn.ParameterDict() # These should not be optimized
50 |
51 | self.progressive_words = progressive_words
52 | self.progressive_counter = 0
53 |
54 | self.max_vectors_per_token = num_vectors_per_token
55 |
56 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
57 | self.is_clip = True
58 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
59 | get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
60 | token_dim = 768
61 | else: # using LDM's BERT encoder
62 | self.is_clip = False
63 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
64 | get_embedding_for_tkn = embedder.transformer.token_emb
65 | token_dim = 1280
66 |
67 | if per_image_tokens:
68 | placeholder_strings.extend(per_img_token_list)
69 |
70 | for idx, placeholder_string in enumerate(placeholder_strings):
71 |
72 | token = get_token_for_string(placeholder_string)
73 |
74 | if initializer_words and idx < len(initializer_words):
75 | init_word_token = get_token_for_string(initializer_words[idx])
76 |
77 | with torch.no_grad():
78 | init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
79 |
80 | token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
81 | self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
82 | else:
83 | token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
84 |
85 | self.string_to_token_dict[placeholder_string] = token
86 | self.string_to_param_dict[placeholder_string] = token_params
87 |
88 | def forward(
89 | self,
90 | tokenized_text,
91 | embedded_text,
92 | ):
93 | b, n, device = *tokenized_text.shape, tokenized_text.device
94 |
95 | for placeholder_string, placeholder_token in self.string_to_token_dict.items():
96 |
97 | placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
98 |
99 | if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
100 | placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
101 | embedded_text[placeholder_idx] = placeholder_embedding
102 | else: # otherwise, need to insert and keep track of changing indices
103 | if self.progressive_words:
104 | self.progressive_counter += 1
105 | max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
106 | else:
107 | max_step_tokens = self.max_vectors_per_token
108 |
109 | num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
110 |
111 | placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
112 |
113 | if placeholder_rows.nelement() == 0:
114 | continue
115 |
116 | sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
117 | sorted_rows = placeholder_rows[sort_idx]
118 |
119 | for idx in range(len(sorted_rows)):
120 | row = sorted_rows[idx]
121 | col = sorted_cols[idx]
122 |
123 | new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
124 | new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
125 |
126 | embedded_text[row] = new_embed_row
127 | tokenized_text[row] = new_token_row
128 |
129 | return embedded_text
130 |
131 | def save(self, ckpt_path):
132 | torch.save({"string_to_token": self.string_to_token_dict,
133 | "string_to_param": self.string_to_param_dict}, ckpt_path)
134 |
135 | def load(self, ckpt_path):
136 | ckpt = torch.load(ckpt_path, map_location='cpu')
137 |
138 | self.string_to_token_dict = ckpt["string_to_token"]
139 | self.string_to_param_dict = ckpt["string_to_param"]
140 |
141 | def get_embedding_norms_squared(self):
142 | all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
143 | param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
144 |
145 | return param_norm_squared
146 |
147 | def embedding_parameters(self):
148 | return self.string_to_param_dict.parameters()
149 |
150 | def embedding_to_coarse_loss(self):
151 |
152 | loss = 0.
153 | num_embeddings = len(self.initial_embeddings)
154 |
155 | for key in self.initial_embeddings:
156 | optimized = self.string_to_param_dict[key]
157 | coarse = self.initial_embeddings[key].clone().to(optimized.device)
158 |
159 | loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
160 |
161 | return loss
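
For the common single-vector case (max_vectors_per_token == 1), the forward pass above reduces to finding every occurrence of the placeholder token id and overwriting the matching rows of the embedded text. A toy illustration of just that mechanism, with made-up shapes and a fake token id rather than the class itself:

    import torch

    token_dim, placeholder_id = 8, 42
    tokenized_text = torch.tensor([[1, 42, 5, 0],
                                   [3, 7, 42, 0]])        # (b, n) token ids
    embedded_text = torch.randn(2, 4, token_dim)          # (b, n, d) embeddings
    placeholder_embedding = torch.randn(token_dim)        # the learned vector

    idx = torch.where(tokenized_text == placeholder_id)
    embedded_text[idx] = placeholder_embedding            # broadcast into every hit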
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
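
A minimal sketch of using the transform pipeline on its own (no checkpoint needed; the input image is a random stand-in in [0, 1]):

    import numpy as np
    from ldm.modules.midas.api import load_midas_transform

    transform = load_midas_transform("dpt_hybrid")
    sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
    sample = transform(sample)
    x = sample["image"]   # CHW float32 array, sides resized to multiples of 32

Running MiDaSInference on top of this additionally requires the corresponding checkpoint at the path listed in ISL_PATHS; the model then expects the batched tensor torch.from_numpy(x).unsqueeze(0).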
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-B/16 + ResNet-50 hybrid (backbone)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
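
A minimal sketch of a single custom fusion step with random feature maps (channel count and spatial size are illustrative):

    import torch
    from torch import nn
    from ldm.modules.midas.midas.blocks import FeatureFusionBlock_custom

    fuse = FeatureFusionBlock_custom(features=64, activation=nn.ReLU(False), bn=False)

    deeper = torch.randn(1, 64, 24, 24)   # output of the coarser refinenet stage
    lateral = torch.randn(1, 64, 24, 24)  # encoder features at the same resolution
    out = fuse(deeper, lateral)           # fused and 2x upsampled: (1, 64, 48, 48)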
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False,  # use_pretrained: set to True to initialize the backbone with ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
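
A minimal shape check for the depth head (assuming timm is installed; with path=None the weights stay randomly initialized, so only the shapes are meaningful):

    import torch
    from ldm.modules.midas.midas.dpt_depth import DPTDepthModel

    model = DPTDepthModel(path=None, backbone="vitl16_384", non_negative=True).eval()
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 384, 384))   # (1, 384, 384) depth map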
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): If True, clamp the predicted depth to be non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by the given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
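
A minimal sketch of Resize and PrepareForNet on their own (random image; the sizes are illustrative):

    import numpy as np
    from ldm.modules.midas.midas.transforms import Resize, PrepareForNet

    resize = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
                    ensure_multiple_of=32, resize_method="upper_bound")
    prep = PrepareForNet()

    sample = {"image": np.random.rand(720, 1280, 3).astype(np.float32)}
    sample = prep(resize(sample))
    print(sample["image"].shape)   # (3, 224, 384): HWC -> CHW, both sides multiples of 32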
--------------------------------------------------------------------------------
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
--------------------------------------------------------------------------------
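
A minimal round-trip sketch for the PFM helpers above; the array contents and the /tmp paths are illustrative assumptions:

    import numpy as np
    from ldm.modules.midas.utils import read_pfm, write_pfm, write_depth

    depth = np.random.rand(240, 320).astype(np.float32)

    # Write a greyscale PFM and read it back; read_pfm returns (data, scale).
    write_pfm("/tmp/depth.pfm", depth)
    data, scale = read_pfm("/tmp/depth.pfm")
    assert data.shape == depth.shape

    # write_depth additionally saves a min-max normalised PNG next to the given stem.
    write_depth("/tmp/depth_vis", depth, bits=2)  # 16-bit PNG
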
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def log_txt_as_img(wh, xc, size=10):
12 | # wh a tuple of (width, height)
13 | # xc a list of captions to plot
14 | b = len(xc)
15 | txts = list()
16 | for bi in range(b):
17 | txt = Image.new("RGB", wh, color="white")
18 | draw = ImageDraw.Draw(txt)
19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20 | nc = int(40 * (wh[0] / 256))
21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22 |
23 | try:
24 | draw.text((0, 0), lines, fill="black", font=font)
25 | except UnicodeEncodeError:
26 | print("Cant encode string for logging. Skipping.")
27 |
28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29 | txts.append(txt)
30 | txts = np.stack(txts)
31 | txts = torch.tensor(txts)
32 | return txts
33 |
34 |
35 | def ismap(x):
36 | if not isinstance(x, torch.Tensor):
37 | return False
38 | return (len(x.shape) == 4) and (x.shape[1] > 3)
39 |
40 |
41 | def isimage(x):
42 | if not isinstance(x,torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45 |
46 |
47 | def exists(x):
48 | return x is not None
49 |
50 |
51 | def default(val, d):
52 | if exists(val):
53 | return val
54 | return d() if isfunction(d) else d
55 |
56 |
57 | def mean_flat(tensor):
58 | """
59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60 | Take the mean over all non-batch dimensions.
61 | """
62 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
63 |
64 |
65 | def count_params(model, verbose=False):
66 | total_params = sum(p.numel() for p in model.parameters())
67 | if verbose:
68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69 | return total_params
70 |
71 |
72 | def instantiate_from_config(config, **kwargs):
73 | if not "target" in config:
74 | if config == '__is_first_stage__':
75 | return None
76 | elif config == "__is_unconditional__":
77 | return None
78 | raise KeyError("Expected key `target` to instantiate.")
79 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
80 |
81 |
82 | def get_obj_from_str(string, reload=False):
83 | module, cls = string.rsplit(".", 1)
84 | if reload:
85 | module_imp = importlib.import_module(module)
86 | importlib.reload(module_imp)
87 | return getattr(importlib.import_module(module, package=None), cls)
88 |
89 |
90 | class AdamWwithEMAandWings(optim.Optimizer):
91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94 | ema_power=1., param_names=()):
95 | """AdamW that saves EMA versions of the parameters."""
96 | if not 0.0 <= lr:
97 | raise ValueError("Invalid learning rate: {}".format(lr))
98 | if not 0.0 <= eps:
99 | raise ValueError("Invalid epsilon value: {}".format(eps))
100 | if not 0.0 <= betas[0] < 1.0:
101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102 | if not 0.0 <= betas[1] < 1.0:
103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104 | if not 0.0 <= weight_decay:
105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106 | if not 0.0 <= ema_decay <= 1.0:
107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108 | defaults = dict(lr=lr, betas=betas, eps=eps,
109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110 | ema_power=ema_power, param_names=param_names)
111 | super().__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super().__setstate__(state)
115 | for group in self.param_groups:
116 | group.setdefault('amsgrad', False)
117 |
118 | @torch.no_grad()
119 | def step(self, closure=None):
120 | """Performs a single optimization step.
121 | Args:
122 | closure (callable, optional): A closure that reevaluates the model
123 | and returns the loss.
124 | """
125 | loss = None
126 | if closure is not None:
127 | with torch.enable_grad():
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 | params_with_grad = []
132 | grads = []
133 | exp_avgs = []
134 | exp_avg_sqs = []
135 | ema_params_with_grad = []
136 | state_sums = []
137 | max_exp_avg_sqs = []
138 | state_steps = []
139 | amsgrad = group['amsgrad']
140 | beta1, beta2 = group['betas']
141 | ema_decay = group['ema_decay']
142 | ema_power = group['ema_power']
143 |
144 | for p in group['params']:
145 | if p.grad is None:
146 | continue
147 | params_with_grad.append(p)
148 | if p.grad.is_sparse:
149 | raise RuntimeError('AdamW does not support sparse gradients')
150 | grads.append(p.grad)
151 |
152 | state = self.state[p]
153 |
154 | # State initialization
155 | if len(state) == 0:
156 | state['step'] = 0
157 | # Exponential moving average of gradient values
158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159 | # Exponential moving average of squared gradient values
160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161 | if amsgrad:
162 | # Maintains max of all exp. moving avg. of sq. grad. values
163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164 | # Exponential moving average of parameter values
165 | state['param_exp_avg'] = p.detach().float().clone()
166 |
167 | exp_avgs.append(state['exp_avg'])
168 | exp_avg_sqs.append(state['exp_avg_sq'])
169 | ema_params_with_grad.append(state['param_exp_avg'])
170 |
171 | if amsgrad:
172 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173 |
174 | # update the steps for each param group update
175 | state['step'] += 1
176 | # record the step after step update
177 | state_steps.append(state['step'])
178 |
179 | optim._functional.adamw(params_with_grad,
180 | grads,
181 | exp_avgs,
182 | exp_avg_sqs,
183 | max_exp_avg_sqs,
184 | state_steps,
185 | amsgrad=amsgrad,
186 | beta1=beta1,
187 | beta2=beta2,
188 | lr=group['lr'],
189 | weight_decay=group['weight_decay'],
190 | eps=group['eps'],
191 | maximize=False)
192 |
193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196 |
197 | return loss
--------------------------------------------------------------------------------
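
instantiate_from_config is the glue used throughout this repo: it resolves the dotted `target` path via get_obj_from_str and constructs the object with the accompanying `params` (plus any extra kwargs). A minimal sketch; the torch.nn.Linear target and its params are illustrative, not taken from the repo's own configs:

    from ldm.util import instantiate_from_config

    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 16, "out_features": 4},
    }
    layer = instantiate_from_config(config)
    print(type(layer).__name__)  # Linear
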
/models/toss_vae.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: cldm.toss.TOSS
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: True
19 | only_mid_control: False
20 | ucg_txt: 0.5
21 | max_timesteps: 1000
22 | min_timesteps: 0
23 | finetune: True
24 | ucg_img: 0.05
25 |
26 | scheduler_config: # 10000 warmup steps
27 | target: ldm.lr_scheduler.LambdaLinearScheduler
28 | params:
29 | warm_up_steps: [ 100 ]
30 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
31 | f_start: [ 1.e-6 ]
32 | f_max: [ 1. ]
33 | f_min: [ 1. ]
34 |
35 | unet_config:
36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel_toss
37 | params:
38 | image_size: 32 # unused
39 | in_channels: 4
40 | out_channels: 4
41 | model_channels: 320
42 | attention_resolutions: [ 4, 2, 1 ]
43 | num_res_blocks: 2
44 | channel_mult: [ 1, 2, 4, 4 ]
45 | num_heads: 8
46 | use_spatial_transformer: True
47 | transformer_depth: 1
48 | context_dim: 768
49 | use_checkpoint: True
50 | legacy: False
51 | temp_attn: "CA_vae"
52 | pose_enc: "vae"
53 |
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config:
78 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
79 |
80 |
81 | data800k:
82 | target: datasets.objaverse800k.ObjaverseDataModuleFromConfig
83 | params:
84 | root_dir: '/comp_robot/mm_generative/data/.objaverse/hf-objaverse-v1/views_release'
85 | batch_size: 128
86 | num_workers: 12
87 | total_view: 12
88 | caption: "rerank"
89 | pose_enc: "freq"
90 | train:
91 | validation: False
92 | image_transforms:
93 | size: 256
94 |
95 | validation:
96 | validation: True
97 | image_transforms:
98 | size: 256
99 |
100 |
101 | data_car:
102 | target: datasets.objaverse_car.ObjaverseDataModuleFromConfig
103 | params:
104 | root_dir: '/comp_robot/mm_generative/data/.objaverse/hf-objaverse-v1/views_release'
105 | batch_size: 128
106 | num_workers: 12
107 | total_view: 12
108 | caption: 'rerank'
109 | pose_enc: "freq"
110 | train:
111 | validation: False
112 | image_transforms:
113 | size: 256
114 |
115 | validation:
116 | validation: True
117 | image_transforms:
118 | size: 256
119 |
--------------------------------------------------------------------------------
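
train.py (below) consumes this YAML through OmegaConf and instantiate_from_config. A minimal sketch mirroring those calls; building the model pulls in the pretrained first-stage and CLIP weights, so treat this as illustrative rather than something to run without the corresponding checkpoints:

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    cfgs = OmegaConf.load("./models/toss_vae.yaml")
    model = instantiate_from_config(cfgs.model)       # cldm.toss.TOSS
    data = instantiate_from_config(cfgs.data800k)     # Objaverse 800k datamodule
    data.prepare_data()
    data.setup()
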
/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_opts():
4 | parser = argparse.ArgumentParser()
5 | # common args for all datasets
6 | parser.add_argument('--root_dir', type=str, default="/comp_robot/shiyukai/dataset/nerf/Synthetic_NeRF/Chair/",
7 | help='root directory of dataset')
8 | parser.add_argument('--eval_root_dir', type=str, default='/comp_robot/mm_generative/data/GSO/views',
9 | help='root directory of dataset')
10 | parser.add_argument('--dataset_name', type=str, default='nsvf', help='which dataset to train/test')
11 | parser.add_argument('--split', type=str, default='train',
12 | choices=['train', 'trainval'],
13 | help='use which split to train')
14 | parser.add_argument('--downsample', type=float, default=1.0,
15 | help='downsample factor (<=1.0) for the images')
16 |
17 | # model parameters
18 | parser.add_argument('--model_cfg', type=str, default='./models/cldm_pose_v15.yaml',
19 | help='cfg path of model')
20 | parser.add_argument('--model_low_cfg', type=str, default='./models/cldm_pose_v15.yaml',
21 | help='cfg path of low-level model')
22 | parser.add_argument('--scale', type=float, default=0.5,
23 |                         help='scene scale (whole scene must lie in [-scale, scale]^3)')
24 | parser.add_argument('--use_exposure', action='store_true', default=False,
25 | help='whether to train in HDR-NeRF setting')
26 | parser.add_argument('--resume_path', type=str, default='./models/control_sd15_pose_ini.ckpt',
27 | help='resume path')
28 | parser.add_argument('--resume_path_low', type=str, default='./models/control_sd15_pose_ini.ckpt',
29 | help='resume path for low-level model')
30 | parser.add_argument('--resume', action='store_true', default=False,
31 | help='train from resume')
32 | parser.add_argument('--text', type=str, default="a yellow lego bulldozer sitting on top of a table",
33 | help='text prompt')
34 |     parser.add_argument('--uncond_pose', action='store_true', default=False,
35 | help='set delta pose zero')
36 | parser.add_argument('--img_size', type=int, default=512,
37 | help='size of img')
38 | parser.add_argument('--acc_grad', type=int, default=None,
39 | help='accumulate grad')
40 | # parser.add_argument('--eval_guidance_scale1', type=float, default=1,
41 | # help='prompt guidance scale for eval')
42 | # parser.add_argument('--eval_guidance_scale2', type=float, default=1,
43 | # help='img guidance scale for eval')
44 |     # WARN: duplicates of the two eval guidance-scale options below
45 | parser.add_argument('--eval_prompt_guidance_scale', type=float, default=1,
46 | help='prompt guidance scale for eval')
47 | parser.add_argument('--eval_img_guidance_scale', type=float, default=1,
48 | help='img guidance scale for eval')
49 | parser.add_argument('--eval_guidance_scale_low', type=float, default=1,
50 | help='guidance scale for eval')
51 | parser.add_argument('--eval_use_ema_scope', action="store_true",
52 |                         help='use EMA scope for eval')
53 | parser.add_argument('--eval_caption', type=str, default="origin",
54 | help='caption mode for eval')
55 | parser.add_argument('--inf_img_path', type=str, default="./exp/inference/img/008.png",
56 |                         help='path to input image for inference')
57 | parser.add_argument('--test_sub', action='store_true', default=False,
58 | help='test on part of eval')
59 | parser.add_argument('--divide_steps', type=int, default=800,
60 | help='divide steps for stage model')
61 | parser.add_argument('--attn_t', type=int, default=800,
62 | help='timesteps in viz attn')
63 | parser.add_argument('--layer_name', type=str, default="",
64 |                         help='layer name for attention visualization')
65 | parser.add_argument('--output_mode_attn', type=str, default="masked",
66 |                         help='output mode for attention visualization')
67 | parser.add_argument('--img_ucg', type=float, default=0.0,
68 | help='ucg for img')
69 | parser.add_argument('--register_scheduler', action='store_true', default=False,
70 | help='whether to register noise scheduler')
71 | parser.add_argument('--pose_enc', type=str, default="freq",
72 | help='encoding for camera pose')
73 |
74 |
75 | # training options
76 | parser.add_argument('--batch_size', type=int, default=1,
77 |                         help='training batch size')
78 | parser.add_argument('--log_interval_epoch', type=int, default=1,
79 | help='interval of logging info')
80 | parser.add_argument('--ckpt_interval', type=int, default=10,
81 | help='interval of ckpt')
82 | parser.add_argument('--logger_freq', type=int, default=20,
83 | help='logger_freq')
84 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images',
85 | choices=['all_images', 'same_image'],
86 | help='''
87 | all_images: uniformly from all pixels of ALL images
88 | same_image: uniformly from all pixels of a SAME image
89 | ''')
90 | parser.add_argument('--max_steps', type=int, default=100000,
91 | help='max number of steps to train')
92 | parser.add_argument('--num_epochs', type=int, default=30,
93 | help='number of training epochs')
94 | parser.add_argument('--num_gpus', type=int, default=1,
95 | help='number of gpus')
96 | parser.add_argument('--lr', type=float, default=1e-4,
97 | help='learning rate')
98 | # experimental training options
99 | parser.add_argument('--optimize_ext', action='store_true', default=False,
100 | help='whether to optimize extrinsics (experimental)')
101 | parser.add_argument('--random_bg', action='store_true', default=False,
102 | help='''whether to train with random bg color (real dataset only)
103 | to avoid objects with black color to be predicted as transparent
104 | ''')
105 |
106 | # validation options
107 | parser.add_argument('--eval_lpips', action='store_true', default=False,
108 | help='evaluate lpips metric (consumes more VRAM)')
109 | parser.add_argument('--val_only', action='store_true', default=False,
110 | help='run only validation (need to provide ckpt_path)')
111 | parser.add_argument('--no_save_test', action='store_true', default=False,
112 | help='whether to save test image and video')
113 | parser.add_argument("--eval_image", type=str, default="", help="path to eval image")
114 | parser.add_argument("--eval_prompt", type=str, default="", help="prompt for eval image")
115 |
116 | # misc
117 | parser.add_argument('--exp_name', type=str, default='exp',
118 | help='experiment name')
119 | parser.add_argument('--ckpt_path', type=str, default=None,
120 | help='pretrained checkpoint to load (including optimizers, etc)')
121 | parser.add_argument('--weight_path', type=str, default=None,
122 | help='pretrained checkpoint to load (excluding optimizers, etc)')
123 |
124 | # modified model paras
125 | parser.add_argument("--fuse_fn", type=str, default="trilinear_interp",
126 | help='fuse function for codebook')
127 | parser.add_argument("--deformable_hash", type=str, default="no_deformable",
128 | help='use deformable hash or not, deformable_codebook / deformable_sample')
129 | parser.add_argument("--deformable_hash_speedup", type=str, default="no_speedup",
130 | help='use deformable hash speedup or not, no_speedup / sampling / clustering')
131 | parser.add_argument("--n_levels", type=int, default=16, help='n_levels of codebook')
132 | parser.add_argument("--finest_res", type=int, default=1024, help='finest resolultion for hashed embedding')
133 | parser.add_argument("--base_res", type=int, default=16, help='base resolultion for hashed embedding')
134 | parser.add_argument("--n_features_per_level", type=int, default=2, help='n_features_per_level')
135 | parser.add_argument("--log2_hashmap_size", type=int, default=19, help='log2 of hashmap size')
136 | parser.add_argument("--max_samples", type=int, default=1024, help='max sample points in a ray')
137 | parser.add_argument('--offset_mode', action='store_true', default=False,
138 | help='use offset in codebook or not')
139 | parser.add_argument('--record_offset', action='store_true', default=False,
140 | help='record offset in codebook or not')
141 | parser.add_argument('--update_interval', type=int, default=16,
142 | help='update interval for density map')
143 | parser.add_argument('--deformable_lr', type=float,
144 | help='learning rate of offset')
145 | parser.add_argument('--multi_scale_lr', type=float,
146 | help='adaptive learning rate of multi scale offset')
147 | parser.add_argument('--mlp_lr', type=float,
148 | help='adaptive learning rate of mlp')
149 | parser.add_argument('--feature_lr', type=float, default=0.01,
150 | help='learning rate of feature')
151 | parser.add_argument('--position_loss', type=float,
152 | help='apply l2 loss on position of codebook points or not')
153 | parser.add_argument('--var_loss', type=float,
154 | help='apply loss on variance of codebook points position or not')
155 | parser.add_argument('--offset_loss', type=float,
156 | help='apply l2 loss on offsets of codebook points or not')
157 | parser.add_argument('--warmup', type=int, default=256,
158 | help='warmup steps')
159 | parser.add_argument('--check_codebook', type=int,
160 | help='check codebook steps')
161 | parser.add_argument('--reinit_start', type=int, default=1000,
162 | help='start reinit for codebook check')
163 | parser.add_argument('--reinit_end', type=int, default=2000,
164 | help='end reinit for codebook check')
165 | parser.add_argument('--threshold', type=float, default=0.5,
166 | help='threshold for codebook check')
167 | parser.add_argument('--noise', type=float, default=0.,
168 | help='random noise for codebook check')
169 | parser.add_argument('--degree', type=float, default=1.,
170 | help='degree for distance inverse interpolation')
171 | parser.add_argument('--limit_func', type=str, default="sigmoid",
172 | help='function on limit of offset')
173 | parser.add_argument('--table_size', type=int, default=5,
174 | help='table_size for grid')
175 | parser.add_argument('--offset_grid', type=int, default=2,
176 | help='area allowed to offset')
177 | parser.add_argument('--grid_hashmap_size', type=int, default=19,
178 | help='hashmap_size for grid')
179 | parser.add_argument('--warmup_epochs', type=float, default=0,
180 | help='warmup epochs for lr scheduler')
181 |
182 | # deformable sample
183 | parser.add_argument('--multi_offset', type=int, default=8,
184 | help='num of deformable offsets for each sample')
185 |
186 | # textual inversion
187 | parser.add_argument("--data_root", type=str,
188 | help='root directory of dataset for textual inversion')
189 | parser.add_argument("--placeholder_string",
190 | type=str,
191 | help="Placeholder string which will be used to denote the concept in future prompts. Overwrites the config options.")
192 | parser.add_argument("--init_word",
193 | type=str,
194 | help="Word to use as source for initial token embedding")
195 | parser.add_argument("--embedding_manager_ckpt",
196 | type=str,
197 | default="",
198 | help="Initialize embedding manager from a checkpoint")
199 | parser.add_argument("--embedding_path",
200 | type=str,
201 | help="Path to a pre-trained embedding manager checkpoint")
202 |
203 |
204 | return parser.parse_args()
205 |
--------------------------------------------------------------------------------
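
get_opts() parses sys.argv directly, so it is normally driven from the command line by train.py. A minimal sketch of exercising it programmatically; the argument values are illustrative:

    import sys
    from opt import get_opts

    sys.argv = [
        "train.py",
        "--dataset_name", "objaverse800k",
        "--model_cfg", "./models/toss_vae.yaml",
        "--exp_name", "toss_debug",
        "--num_gpus", "1",
    ]
    hparams = get_opts()
    print(hparams.dataset_name, hparams.lr, hparams.num_epochs)  # objaverse800k 0.0001 30
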
/outputs/a dragon toy with fire on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/a dragon toy with fire on the back.png
--------------------------------------------------------------------------------
/outputs/a dragon toy with ice on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/a dragon toy with ice on the back.png
--------------------------------------------------------------------------------
/outputs/anya/0_95.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/anya/0_95.png
--------------------------------------------------------------------------------
/outputs/backview of a dragon toy with fire on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/backview of a dragon toy with fire on the back.png
--------------------------------------------------------------------------------
/outputs/backview of a dragon toy with ice on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/backview of a dragon toy with ice on the back.png
--------------------------------------------------------------------------------
/outputs/dragon/ a purple dragon with fire on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/ a purple dragon with fire on the back.png
--------------------------------------------------------------------------------
/outputs/dragon/a dragon toy with fire on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon toy with fire on the back.png
--------------------------------------------------------------------------------
/outputs/dragon/a dragon with fire on its back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with fire on its back.png
--------------------------------------------------------------------------------
/outputs/dragon/a dragon with ice on its back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with ice on its back.png
--------------------------------------------------------------------------------
/outputs/dragon/a dragon with ice on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a dragon with ice on the back.png
--------------------------------------------------------------------------------
/outputs/dragon/a purple dragon with fire on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/dragon/a purple dragon with fire on the back.png
--------------------------------------------------------------------------------
/outputs/minion/a dragon toy with fire on its back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/minion/a dragon toy with fire on its back.png
--------------------------------------------------------------------------------
/outputs/minion/a minion with a rocket on the back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-Research/TOSS/451a1d30f55d2f54f0421db3d0dfd6ed741b23f3/outputs/minion/a minion with a rocket on the back.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gradio==3.40.1
2 | albumentations==1.3.0
3 | opencv-python==4.5.5.64
4 | imageio==2.9.0
5 | imageio-ffmpeg==0.4.2
6 | pytorch-lightning==1.5.0
7 | omegaconf==2.1.1
8 | test-tube>=0.7.5
9 | streamlit==1.12.1
10 | einops==0.3.0
11 | transformers==4.22.2
12 | webdataset==0.2.5
13 | kornia==0.6
14 | open_clip_torch==2.0.2
15 | invisible-watermark>=0.1.5
16 | streamlit-drawable-canvas==0.8.0
17 | torchmetrics==0.6.0
18 | timm==0.6.12
19 | addict==2.4.0
20 | yapf==0.32.0
21 | prettytable==3.6.0
22 | safetensors==0.2.7
23 | basicsr==1.4.2
24 | carvekit-colab==4.1.0
--------------------------------------------------------------------------------
/share.py:
--------------------------------------------------------------------------------
1 | import config
2 | from cldm.hack import disable_verbosity, enable_sliced_attention
3 |
4 |
5 | disable_verbosity()
6 |
7 | if config.save_memory:
8 | enable_sliced_attention()
9 |
--------------------------------------------------------------------------------
/streamlit_app.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import namedtuple
3 | import altair as alt
4 | import math
5 | import pandas as pd
6 | import streamlit as st
7 |
8 | """
9 | # Welcome to Streamlit!
10 |
11 | Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:
12 |
13 | If you have any questions, check out our [documentation](https://docs.streamlit.io) and [community
14 | forums](https://discuss.streamlit.io).
15 |
16 | In the meantime, below is an example of what you can do with just a few lines of code:
17 | """
18 |
19 |
20 | with st.echo(code_location='below'):
21 | total_points = st.slider("Number of points in spiral", 1, 5000, 2000)
22 | num_turns = st.slider("Number of turns in spiral", 1, 100, 9)
23 |
24 | Point = namedtuple('Point', 'x y')
25 | data = []
26 |
27 | points_per_turn = total_points / num_turns
28 |
29 | for curr_point_num in range(total_points):
30 | curr_turn, i = divmod(curr_point_num, points_per_turn)
31 | angle = (curr_turn + 1) * 2 * math.pi * i / points_per_turn
32 | radius = curr_point_num / total_points
33 | x = radius * math.cos(angle)
34 | y = radius * math.sin(angle)
35 | data.append(Point(x, y))
36 |
37 | st.altair_chart(alt.Chart(pd.DataFrame(data), height=500, width=500)
38 | .mark_circle(color='#0068c9', opacity=0.5)
39 | .encode(x='x:Q', y='y:Q'))
40 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from share import *
2 |
3 | import pytorch_lightning as pl
4 | from torch.utils.data import DataLoader
5 | from tutorial_dataset import MyDataset
6 | from cldm.logger import ImageLogger
7 | from cldm.model import create_model, load_state_dict
8 | from pytorch_lightning.callbacks import ModelCheckpoint
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from pytorch_lightning.plugins import DDPPlugin
11 | from ldm.util import instantiate_from_config
12 | import torch
13 | from omegaconf import OmegaConf
14 |
15 | # data
16 | from torch.utils.data import DataLoader
17 | from datasets import dataset_dict
18 | from datasets.ray_utils import axisangle_to_R, get_rays
19 |
20 | # configure
21 | from opt import get_opts
22 | import pdb
23 |
24 |
25 | if __name__ == '__main__':
26 | # Configs
27 | hparams = get_opts()
28 | logger_freq = hparams.logger_freq
29 | sd_locked = True
30 | only_mid_control = False
31 | cfgs = OmegaConf.load(hparams.model_cfg)
32 |
33 |
34 | # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
35 | model = instantiate_from_config(cfgs.model)
36 |
37 | # get missing keys
38 | if hparams.resume_path != './models/control_sd15_pose_ini.ckpt':
39 | missing, unexpected = model.load_state_dict(load_state_dict(hparams.resume_path, location='cpu'), strict=False)
40 | print(f"Restored from {hparams.resume_path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
41 | if len(missing) > 0:
42 | print(f"Missing Keys:\n {missing}")
43 | if len(unexpected) > 0:
44 | print(f"\nUnexpected Keys:\n {unexpected}")
45 |
46 | # model.load_state_dict(load_state_dict(hparams.resume_path, location='cpu'), strict=False)
47 |
48 |     # re-register the noise scheduler with a modified beta schedule
49 | if hparams.register_scheduler:
50 | model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.016)
51 | model.learning_rate = hparams.lr
52 | model.sd_locked = sd_locked
53 | model.only_mid_control = only_mid_control
54 | # model.control_model.uncond_pose = hparams.uncond_pose
55 | model.eval()
56 |
57 | # data
58 | if 'objaverse' in hparams.dataset_name:
59 | if "car" in hparams.dataset_name:
60 | dataloader = instantiate_from_config(cfgs.data_car)
61 | print("Using objaverse car data!")
62 | elif "800k" in hparams.dataset_name:
63 | dataloader = instantiate_from_config(cfgs.data800k)
64 | else:
65 | dataloader = instantiate_from_config(cfgs.data)
66 | dataloader.prepare_data()
67 | dataloader.setup()
68 | elif "srn" in hparams.dataset_name:
69 | if "chair" in hparams.dataset_name:
70 | dataloader = instantiate_from_config(cfgs.srn_chairs)
71 | else:
72 | dataloader = instantiate_from_config(cfgs.srn_data)
73 | dataloader.prepare_data()
74 | dataloader.setup()
75 | else:
76 | kwargs = {'root_dir': hparams.root_dir}
77 | dataset = dataset_dict[hparams.dataset_name](split=hparams.split, text=hparams.text, img_size=hparams.img_size, **kwargs)
78 | dataloader = DataLoader(dataset, num_workers=0, batch_size=hparams.batch_size, shuffle=True)
79 |
80 |
81 | # Train!
82 | save_dir = f'./exp/{hparams.dataset_name}/{hparams.exp_name}/'
83 | logger = TensorBoardLogger(save_dir=save_dir,
84 | name=hparams.exp_name, default_hp_metric=False)
85 | img_logger = ImageLogger(batch_frequency=logger_freq, epoch_frequency=hparams.log_interval_epoch)
86 | ckpt_cb = ModelCheckpoint(dirpath=f'{save_dir}/ckpt/',
87 | filename='{epoch:d}',
88 | save_weights_only=False,
89 | every_n_epochs=hparams.ckpt_interval,
90 | save_on_train_epoch_end=True,
91 | save_top_k=-1)
92 | callbacks = [img_logger, ckpt_cb]
93 | trainer = pl.Trainer(gpus=hparams.num_gpus, callbacks=callbacks, \
94 | precision=16, \
95 | # amp_backend='apex', amp_level="O2", \
96 | check_val_every_n_epoch=20,
97 | logger=logger, max_epochs=hparams.num_epochs, \
98 | resume_from_checkpoint=hparams.resume_path if hparams.resume else None,
99 | # strategy="ddp",
100 | accumulate_grad_batches=8//hparams.num_gpus \
101 |                             if hparams.acc_grad is None else hparams.acc_grad,
102 | plugins=DDPPlugin(find_unused_parameters=False),
103 | accelerator="ddp"
104 | )
105 |
106 | # trainer.validate(model, dataloader)
107 | trainer.fit(model, dataloader)
108 |
109 |
--------------------------------------------------------------------------------
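
The accumulate_grad_batches default above scales gradient accumulation inversely with the GPU count (8 // num_gpus when --acc_grad is unset), so the effective batch size is independent of how many GPUs are used. A minimal sketch of that arithmetic; the per-GPU batch size of 128 comes from the data800k block in toss_vae.yaml, and the GPU counts are illustrative:

    def effective_batch(per_gpu_batch, num_gpus, acc_grad=None):
        # Mirrors the accumulate_grad_batches expression in train.py above.
        accumulate = 8 // num_gpus if acc_grad is None else acc_grad
        return per_gpu_batch * num_gpus * accumulate

    print(effective_batch(128, num_gpus=1))               # 1024
    print(effective_batch(128, num_gpus=8))               # 1024
    print(effective_batch(128, num_gpus=4, acc_grad=1))   # 512
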
/tutorial_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import cv2
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class MyDataset(Dataset):
9 | def __init__(self):
10 | self.data = []
11 | with open('./training/fill50k/prompt.json', 'rt') as f:
12 | for line in f:
13 | self.data.append(json.loads(line))
14 |
15 | def __len__(self):
16 | return len(self.data)
17 |
18 | def __getitem__(self, idx):
19 | item = self.data[idx]
20 |
21 | source_filename = item['source']
22 | target_filename = item['target']
23 | prompt = item['prompt']
24 |
25 | source = cv2.imread('./training/fill50k/' + source_filename)
26 | target = cv2.imread('./training/fill50k/' + target_filename)
27 |
28 |         # Do not forget that OpenCV reads images in BGR order.
29 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
30 | target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
31 |
32 | # Normalize source images to [0, 1].
33 | source = source.astype(np.float32) / 255.0
34 |
35 | # Normalize target images to [-1, 1].
36 | target = (target.astype(np.float32) / 127.5) - 1.0
37 |
38 | return dict(jpg=target, txt=prompt, hint=source)
39 |
40 |
--------------------------------------------------------------------------------
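
MyDataset yields dicts keyed jpg/txt/hint, matching first_stage_key, cond_stage_key and control_key in the model config. A minimal sketch of wrapping it in a DataLoader; it requires the fill50k data under ./training/ referenced above, so the printed shapes are illustrative:

    from torch.utils.data import DataLoader
    from tutorial_dataset import MyDataset

    dataset = MyDataset()
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    print(batch["jpg"].shape, batch["hint"].shape, len(batch["txt"]))  # HWC tensors and a list of prompts
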
/viz.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 |
4 |
5 | def save_image_tensor2cv2(input_tensor: torch.Tensor, filename):
6 | """
7 | Save a tensor to a file as an image.
8 | :param input_tensor: tensor to save [C, H, W]
9 | :param filename: file to save to
10 | """
11 | assert (len(input_tensor.shape) == 3)
12 | input_tensor = input_tensor.clone().detach()
13 | input_tensor = input_tensor.to(torch.device('cpu'))
14 | input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
15 | input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)
16 | cv2.imwrite(filename, input_tensor)
17 |
--------------------------------------------------------------------------------
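
A minimal usage sketch for save_image_tensor2cv2; the random tensor and output path are illustrative. The function expects a C x H x W tensor with values in [0, 1]:

    import torch
    from viz import save_image_tensor2cv2

    img = torch.rand(3, 256, 256)  # C x H x W in [0, 1]
    save_image_tensor2cv2(img, "/tmp/example.png")
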