├── .gitignore ├── LICENSE ├── README.md ├── annotator ├── canny │ └── __init__.py ├── midas │ ├── __init__.py │ ├── api.py │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ └── utils.py └── util.py ├── app.py ├── ckpt └── cldm_v15.yaml ├── cldm ├── cldm.py ├── hack.py ├── logger.py └── model.py ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── log └── LOG.md ├── requirements.txt └── stablevideo ├── aggnet.py ├── atlas_data.py ├── atlas_utils.py └── implicit_neural_networks.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | training/ 4 | lightning_logs/ 5 | image_log/ 6 | 7 | *.pth 8 | *.pt 9 | *.ckpt 10 | *.safetensors 11 | 12 | gradio_pose2image_private.py 13 | gradio_canny2image_private.py 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableVideo 2 | 3 | [![](http://img.shields.io/badge/cs.CV-arXiv%3A2308.09592-B31B1B.svg)](https://arxiv.org/abs/2308.09592) 4 | [![](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-orange)](https://huggingface.co/spaces/Reself/StableVideo) 5 | 6 | > **StableVideo: Text-driven Consistency-aware Diffusion Video Editing** 7 | > Wenhao Chai, Xun Guo✉️, Gaoang Wang, Yan Lu 8 | > _ICCV 2023_ 9 | 10 | https://github.com/rese1f/StableVideo/assets/58205475/558555f1-711c-46f0-85bc-9c229ff1f511 11 | 12 | https://github.com/rese1f/StableVideo/assets/58205475/c152d0fa-16d3-4528-b9c2-ad2ec53944b9 13 | 14 | https://github.com/rese1f/StableVideo/assets/58205475/0edbefdd-9b5f-4868-842c-9bf3156a54d3 15 | 16 | 17 | ## VRAM requirement 18 | |Setting|VRAM (MiB)| 19 | |---|---| 20 | |float32|29145| 21 | |amp|23005| 22 | |amp + cpu|17639| 23 | |amp + cpu + xformers|14185| 24 | 25 | - cpu: use CPU cache, enabled with the `save_memory` argument 26 | 27 | Measured under the default settings (*e.g.* resolution) in `app.py`. 28 | 29 | ## Installation 30 | ``` 31 | git clone https://github.com/rese1f/StableVideo.git 32 | conda create -n stablevideo python=3.11 33 | pip install -r requirements.txt 34 | (optional) pip install xformers 35 | ``` 36 | 37 | (optional) We also provide a CPU-only version as a [Hugging Face demo](https://huggingface.co/spaces/Reself/StableVideo). 38 | ``` 39 | git lfs install 40 | git clone https://huggingface.co/spaces/Reself/StableVideo 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ## Download Pretrained Models 45 | 46 | All models and detectors can be downloaded from the ControlNet Hugging Face page at [Download Link](https://huggingface.co/lllyasviel/ControlNet). 47 | 48 | 49 | ## Download example videos 50 | Download the example atlases for car-turn, boat, libby, blackswan, bear, bicycle_tali, giraffe, kite-surf, lucia and motorbike at [Download Link](https://www.dropbox.com/s/oiyhbiqdws2p6r1/nla_share.zip?dl=0) shared by the [Text2LIVE](https://github.com/omerbt/Text2LIVE) authors. 51 | 52 | You can also train on your own video following [NLA](https://github.com/ykasten/layered-neural-atlases). 53 | 54 | This will create a `data` folder: 55 | ``` 56 | StableVideo 57 | ├── ... 58 | ├── ckpt 59 | │ ├── cldm_v15.yaml 60 | │ ├── dpt_hybrid-midas-501f0c75.pt 61 | │ ├── control_sd15_canny.pth 62 | │ └── control_sd15_depth.pth 63 | ├── data 64 | │ └── car-turn 65 | │ ├── checkpoint # NLA models are stored here 66 | │ ├── car-turn # contains video frames 67 | │ ├── ... 68 | │ ├── blackswan 69 | │ ├── ... 70 | └── ... 71 | ``` 72 | 73 | ## Run and Play! 74 | Run the following command to start. 75 | ``` 76 | python app.py 77 | ``` 78 | The resulting `.mp4` video and keyframe will be stored in the directory `./log` after clicking the `render` button. 79 | 80 | You can also edit the mask region for the foreground atlas as follows. Currently there might be a bug in Gradio. Please carefully check that the `editable output foreground atlas block` looks the same as the one above. If not, try to restart the entire program. 81 | 82 | 83 | 84 | 85 | ## Citation 86 | If our work is useful for your research, please consider citing as below.
Many thanks :) 87 | ``` 88 | @article{chai2023stablevideo, 89 | title={StableVideo: Text-driven Consistency-aware Diffusion Video Editing}, 90 | author={Chai, Wenhao and Guo, Xun and Wang, Gaoang and Lu, Yan}, 91 | journal={arXiv preprint arXiv:2308.09592}, 92 | year={2023} 93 | } 94 | ``` 95 | 96 | ## Acknowledgement 97 | 98 | This implementation is built partly on [Text2LIVE](https://github.com/omerbt/Text2LIVE) and [ControlNet](https://github.com/lllyasviel/ControlNet). 99 | 100 | 101 | -------------------------------------------------------------------------------- /annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | class CannyDetector: 5 | def __call__(self, img, low_threshold, high_threshold): 6 | return cv2.Canny(img, low_threshold, high_threshold) 7 | -------------------------------------------------------------------------------- /annotator/midas/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from einops import rearrange 6 | from .api import MiDaSInference 7 | 8 | 9 | class MidasDetector: 10 | def __init__(self): 11 | self.model = MiDaSInference(model_type="dpt_hybrid").cuda() 12 | 13 | def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): 14 | assert input_image.ndim == 3 15 | image_depth = input_image 16 | with torch.no_grad(): 17 | image_depth = torch.from_numpy(image_depth).float().cuda() 18 | image_depth = image_depth / 127.5 - 1.0 19 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 20 | depth = self.model(image_depth)[0] 21 | 22 | depth_pt = depth.clone() 23 | depth_pt -= torch.min(depth_pt) 24 | depth_pt /= torch.max(depth_pt) 25 | depth_pt = depth_pt.cpu().numpy() 26 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 27 | 28 | depth_np = depth.cpu().numpy() 29 | x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) 30 | y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) 31 | z = np.ones_like(x) * a 32 | x[depth_pt < bg_th] = 0 33 | y[depth_pt < bg_th] = 0 34 | normal = np.stack([x, y, z], axis=2) 35 | normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 36 | normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) 37 | 38 | return depth_image, normal_image 39 | -------------------------------------------------------------------------------- /annotator/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.transforms import Compose 8 | 9 | from .midas.dpt_depth import DPTDepthModel 10 | from .midas.midas_net import MidasNet 11 | from .midas.midas_net_custom import MidasNet_small 12 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet 13 | from annotator.util import annotator_ckpts_path 14 | 15 | 16 | ISL_PATHS = { 17 | "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), 18 | "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), 19 | "midas_v21": "", 20 | "midas_v21_small": "", 21 | } 22 | 23 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" 24 | 25 | 26 | def disabled_train(self, mode=True): 27 | """Overwrite model.train with this function to make sure train/eval mode 28 | does not change anymore.""" 29 | 
return self 30 | 31 | 32 | def load_midas_transform(model_type): 33 | # https://github.com/isl-org/MiDaS/blob/master/run.py 34 | # load transform only 35 | if model_type == "dpt_large": # DPT-Large 36 | net_w, net_h = 384, 384 37 | resize_mode = "minimal" 38 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 39 | 40 | elif model_type == "dpt_hybrid": # DPT-Hybrid 41 | net_w, net_h = 384, 384 42 | resize_mode = "minimal" 43 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 44 | 45 | elif model_type == "midas_v21": 46 | net_w, net_h = 384, 384 47 | resize_mode = "upper_bound" 48 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 49 | 50 | elif model_type == "midas_v21_small": 51 | net_w, net_h = 256, 256 52 | resize_mode = "upper_bound" 53 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 54 | 55 | else: 56 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 57 | 58 | transform = Compose( 59 | [ 60 | Resize( 61 | net_w, 62 | net_h, 63 | resize_target=None, 64 | keep_aspect_ratio=True, 65 | ensure_multiple_of=32, 66 | resize_method=resize_mode, 67 | image_interpolation_method=cv2.INTER_CUBIC, 68 | ), 69 | normalization, 70 | PrepareForNet(), 71 | ] 72 | ) 73 | 74 | return transform 75 | 76 | 77 | def load_model(model_type): 78 | # https://github.com/isl-org/MiDaS/blob/master/run.py 79 | # load network 80 | model_path = ISL_PATHS[model_type] 81 | if model_type == "dpt_large": # DPT-Large 82 | model = DPTDepthModel( 83 | path=model_path, 84 | backbone="vitl16_384", 85 | non_negative=True, 86 | ) 87 | net_w, net_h = 384, 384 88 | resize_mode = "minimal" 89 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 90 | 91 | elif model_type == "dpt_hybrid": # DPT-Hybrid 92 | if not os.path.exists(model_path): 93 | from basicsr.utils.download_util import load_file_from_url 94 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 95 | 96 | model = DPTDepthModel( 97 | path=model_path, 98 | backbone="vitb_rn50_384", 99 | non_negative=True, 100 | ) 101 | net_w, net_h = 384, 384 102 | resize_mode = "minimal" 103 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 104 | 105 | elif model_type == "midas_v21": 106 | model = MidasNet(model_path, non_negative=True) 107 | net_w, net_h = 384, 384 108 | resize_mode = "upper_bound" 109 | normalization = NormalizeImage( 110 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 111 | ) 112 | 113 | elif model_type == "midas_v21_small": 114 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 115 | non_negative=True, blocks={'expand': True}) 116 | net_w, net_h = 256, 256 117 | resize_mode = "upper_bound" 118 | normalization = NormalizeImage( 119 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 120 | ) 121 | 122 | else: 123 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 124 | assert False 125 | 126 | transform = Compose( 127 | [ 128 | Resize( 129 | net_w, 130 | net_h, 131 | resize_target=None, 132 | keep_aspect_ratio=True, 133 | ensure_multiple_of=32, 134 | resize_method=resize_mode, 135 | image_interpolation_method=cv2.INTER_CUBIC, 136 | ), 137 | normalization, 138 | PrepareForNet(), 139 | ] 140 | ) 141 | 142 | return model.eval(), transform 143 | 144 | 145 | class MiDaSInference(nn.Module): 146 | MODEL_TYPES_TORCH_HUB = [ 147 | "DPT_Large", 148 | "DPT_Hybrid", 149 | 
"MiDaS_small" 150 | ] 151 | MODEL_TYPES_ISL = [ 152 | "dpt_large", 153 | "dpt_hybrid", 154 | "midas_v21", 155 | "midas_v21_small", 156 | ] 157 | 158 | def __init__(self, model_type): 159 | super().__init__() 160 | assert (model_type in self.MODEL_TYPES_ISL) 161 | model, _ = load_model(model_type) 162 | self.model = model 163 | self.model.train = disabled_train 164 | 165 | def forward(self, x): 166 | with torch.no_grad(): 167 | prediction = self.model(x) 168 | return prediction 169 | 170 | -------------------------------------------------------------------------------- /annotator/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/annotator/midas/midas/__init__.py -------------------------------------------------------------------------------- /annotator/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /annotator/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = 
out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 
179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /annotator/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /annotator/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /annotator/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /annotator/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /annotator/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | -------------------------------------------------------------------------------- /ckpt/cldm_v15.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: cldm.cldm.ControlNet 23 | params: 24 | image_size: 32 # unused 25 | in_channels: 4 26 | hint_channels: 3 27 | model_channels: 320 28 | attention_resolutions: [ 4, 2, 1 ] 29 | num_res_blocks: 2 30 | channel_mult: [ 1, 2, 4, 4 ] 31 | num_heads: 8 32 | use_spatial_transformer: True 33 | transformer_depth: 1 34 | context_dim: 768 35 | use_checkpoint: True 36 | legacy: False 37 | 38 | unet_config: 39 | target: cldm.cldm.ControlledUnetModel 40 | params: 41 | image_size: 32 # unused 42 | in_channels: 4 43 | out_channels: 4 44 | model_channels: 320 
45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n 
(h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = 
images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
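The pt2np/np2pt pair above converts between HWC tensors in [-1, 1] and numpy arrays in [0, 1]. A minimal, self-contained round-trip check (an illustrative sketch, not part of the repository files) looks like this:

import torch

x = torch.rand(64, 64, 3) * 2 - 1                  # HWC tensor in [-1, 1], like sample['jpg']
x_np = ((x + 1.0) * 0.5).detach().cpu().numpy()    # pt2np: [-1, 1] -> [0, 1] numpy array
x_back = torch.from_numpy(x_np) * 2 - 1.           # np2pt: [0, 1] -> [-1, 1] tensor
assert torch.allclose(x, x_back, atol=1e-6)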
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
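For context, the AddMiDaS transform shown above is applied as a per-sample callable. The sketch below is illustrative only: the model_type value is an assumed example, and importing the MiDaS modules requires their dependencies (e.g. timm) to be installed.

import torch
from ldm.data.util import AddMiDaS

midas_tf = AddMiDaS(model_type="dpt_hybrid")        # model_type is an assumed example value
sample = {"jpg": torch.rand(384, 384, 3) * 2 - 1}   # HWC tensor in [-1, 1]
sample = midas_tf(sample)                           # adds sample['midas_in'] ready for the depth network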
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
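A hedged end-to-end sketch of the encode/decode path defined above, reusing the ddconfig values from ckpt/cldm_v15.yaml; the Identity loss is only a stand-in so the module can be instantiated without the full training loss:

import torch
from ldm.models.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
ae = AutoencoderKL(ddconfig=ddconfig, lossconfig={"target": "torch.nn.Identity"}, embed_dim=4)

x = torch.randn(1, 3, 256, 256)       # BCHW image in [-1, 1]
posterior = ae.encode(x)              # DiagonalGaussianDistribution over the latent
z = posterior.sample()                # (1, 4, 32, 32) latent
x_rec = ae.decode(z)                  # reconstruction, same shape as x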
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps 
else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. - mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters 
corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | from typing import Optional, Any 8 | 9 | from ldm.modules.diffusionmodules.util import checkpoint 10 | 11 | 12 | try: 13 | import xformers 14 | import xformers.ops 15 | XFORMERS_IS_AVAILBLE = True 16 | except: 17 | XFORMERS_IS_AVAILBLE = False 18 | 19 | # CrossAttn precision handling 20 | import os 21 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | 27 | def uniq(arr): 28 | return{el: True for el in arr}.keys() 29 | 30 | 31 | def default(val, d): 32 | if exists(val): 33 | return val 34 | return d() if isfunction(d) else d 35 | 36 | 37 | def max_neg_value(t): 38 | return -torch.finfo(t.dtype).max 39 | 40 | 41 | def init_(tensor): 42 | dim = tensor.shape[-1] 43 | std = 1 / math.sqrt(dim) 44 | tensor.uniform_(-std, std) 45 | return tensor 46 | 47 | 48 | # feedforward 49 | class GEGLU(nn.Module): 50 | def __init__(self, dim_in, dim_out): 51 | super().__init__() 52 | self.proj = nn.Linear(dim_in, dim_out * 2) 53 | 54 | def forward(self, x): 55 | x, gate = self.proj(x).chunk(2, dim=-1) 56 | return x * F.gelu(gate) 57 | 58 | 59 | class FeedForward(nn.Module): 60 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 61 | super().__init__() 62 | inner_dim = int(dim * mult) 63 | dim_out = default(dim_out, dim) 64 | project_in = nn.Sequential( 65 | nn.Linear(dim, inner_dim), 66 | nn.GELU() 67 | ) if not glu else GEGLU(dim, inner_dim) 68 | 69 | self.net = nn.Sequential( 70 | project_in, 71 | nn.Dropout(dropout), 72 | nn.Linear(inner_dim, dim_out) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | def zero_module(module): 80 | """ 81 | Zero out the parameters of a module and return it. 
82 | """ 83 | for p in module.parameters(): 84 | p.detach().zero_() 85 | return module 86 | 87 | 88 | def Normalize(in_channels): 89 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 90 | 91 | 92 | class SpatialSelfAttention(nn.Module): 93 | def __init__(self, in_channels): 94 | super().__init__() 95 | self.in_channels = in_channels 96 | 97 | self.norm = Normalize(in_channels) 98 | self.q = torch.nn.Conv2d(in_channels, 99 | in_channels, 100 | kernel_size=1, 101 | stride=1, 102 | padding=0) 103 | self.k = torch.nn.Conv2d(in_channels, 104 | in_channels, 105 | kernel_size=1, 106 | stride=1, 107 | padding=0) 108 | self.v = torch.nn.Conv2d(in_channels, 109 | in_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | self.proj_out = torch.nn.Conv2d(in_channels, 114 | in_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | 119 | def forward(self, x): 120 | h_ = x 121 | h_ = self.norm(h_) 122 | q = self.q(h_) 123 | k = self.k(h_) 124 | v = self.v(h_) 125 | 126 | # compute attention 127 | b,c,h,w = q.shape 128 | q = rearrange(q, 'b c h w -> b (h w) c') 129 | k = rearrange(k, 'b c h w -> b c (h w)') 130 | w_ = torch.einsum('bij,bjk->bik', q, k) 131 | 132 | w_ = w_ * (int(c)**(-0.5)) 133 | w_ = torch.nn.functional.softmax(w_, dim=2) 134 | 135 | # attend to values 136 | v = rearrange(v, 'b c h w -> b c (h w)') 137 | w_ = rearrange(w_, 'b i j -> b j i') 138 | h_ = torch.einsum('bij,bjk->bik', v, w_) 139 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 140 | h_ = self.proj_out(h_) 141 | 142 | return x+h_ 143 | 144 | 145 | class CrossAttention(nn.Module): 146 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 147 | super().__init__() 148 | inner_dim = dim_head * heads 149 | context_dim = default(context_dim, query_dim) 150 | 151 | self.scale = dim_head ** -0.5 152 | self.heads = heads 153 | 154 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 155 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 156 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 157 | 158 | self.to_out = nn.Sequential( 159 | nn.Linear(inner_dim, query_dim), 160 | nn.Dropout(dropout) 161 | ) 162 | 163 | def forward(self, x, context=None, mask=None): 164 | h = self.heads 165 | 166 | q = self.to_q(x) 167 | context = default(context, x) 168 | k = self.to_k(context) 169 | v = self.to_v(context) 170 | 171 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 172 | 173 | # force cast to fp32 to avoid overflowing 174 | if _ATTN_PRECISION =="fp32": 175 | with torch.autocast(enabled=False, device_type = 'cuda'): 176 | q, k = q.float(), k.float() 177 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 178 | else: 179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 180 | 181 | del q, k 182 | 183 | if exists(mask): 184 | mask = rearrange(mask, 'b ... 
-> b (...)') 185 | max_neg_value = -torch.finfo(sim.dtype).max 186 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 187 | sim.masked_fill_(~mask, max_neg_value) 188 | 189 | # attention, what we cannot get enough of 190 | sim = sim.softmax(dim=-1) 191 | 192 | out = einsum('b i j, b j d -> b i d', sim, v) 193 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 194 | return self.to_out(out) 195 | 196 | 197 | class MemoryEfficientCrossAttention(nn.Module): 198 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 199 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 200 | super().__init__() 201 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " 202 | f"{heads} heads.") 203 | inner_dim = dim_head * heads 204 | context_dim = default(context_dim, query_dim) 205 | 206 | self.heads = heads 207 | self.dim_head = dim_head 208 | 209 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 210 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 211 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 212 | 213 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 214 | self.attention_op: Optional[Any] = None 215 | 216 | def forward(self, x, context=None, mask=None): 217 | q = self.to_q(x) 218 | context = default(context, x) 219 | k = self.to_k(context) 220 | v = self.to_v(context) 221 | 222 | b, _, _ = q.shape 223 | q, k, v = map( 224 | lambda t: t.unsqueeze(3) 225 | .reshape(b, t.shape[1], self.heads, self.dim_head) 226 | .permute(0, 2, 1, 3) 227 | .reshape(b * self.heads, t.shape[1], self.dim_head) 228 | .contiguous(), 229 | (q, k, v), 230 | ) 231 | 232 | # actually compute the attention, what we cannot get enough of 233 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 234 | 235 | if exists(mask): 236 | raise NotImplementedError 237 | out = ( 238 | out.unsqueeze(0) 239 | .reshape(b, self.heads, out.shape[1], self.dim_head) 240 | .permute(0, 2, 1, 3) 241 | .reshape(b, out.shape[1], self.heads * self.dim_head) 242 | ) 243 | return self.to_out(out) 244 | 245 | 246 | class BasicTransformerBlock(nn.Module): 247 | ATTENTION_MODES = { 248 | "softmax": CrossAttention, # vanilla attention 249 | "softmax-xformers": MemoryEfficientCrossAttention 250 | } 251 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 252 | disable_self_attn=False): 253 | super().__init__() 254 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" 255 | assert attn_mode in self.ATTENTION_MODES 256 | attn_cls = self.ATTENTION_MODES[attn_mode] 257 | self.disable_self_attn = disable_self_attn 258 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 259 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 260 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 261 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 262 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 263 | self.norm1 = nn.LayerNorm(dim) 264 | self.norm2 = nn.LayerNorm(dim) 265 | self.norm3 = nn.LayerNorm(dim) 266 | self.checkpoint = checkpoint 267 | 268 | def forward(self, x, context=None): 269 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 270 | 271 | 
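For reference, a small shape sketch for the vanilla CrossAttention used by attn1/attn2 above; the dimensions are illustrative (320-channel queries attending over 77 text tokens of width 768, matching the context_dim in the config):

import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64 * 64, 320)      # (batch, h*w tokens, query_dim)
ctx = torch.randn(2, 77, 768)         # (batch, text tokens, context_dim)
out = attn(x, context=ctx)            # -> (2, 4096, 320)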
def _forward(self, x, context=None): 272 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 273 | x = self.attn2(self.norm2(x), context=context) + x 274 | x = self.ff(self.norm3(x)) + x 275 | return x 276 | 277 | 278 | class SpatialTransformer(nn.Module): 279 | """ 280 | Transformer block for image-like data. 281 | First, project the input (aka embedding) 282 | and reshape to b, t, d. 283 | Then apply standard transformer action. 284 | Finally, reshape to image 285 | NEW: use_linear for more efficiency instead of the 1x1 convs 286 | """ 287 | def __init__(self, in_channels, n_heads, d_head, 288 | depth=1, dropout=0., context_dim=None, 289 | disable_self_attn=False, use_linear=False, 290 | use_checkpoint=True): 291 | super().__init__() 292 | if exists(context_dim) and not isinstance(context_dim, list): 293 | context_dim = [context_dim] 294 | self.in_channels = in_channels 295 | inner_dim = n_heads * d_head 296 | self.norm = Normalize(in_channels) 297 | if not use_linear: 298 | self.proj_in = nn.Conv2d(in_channels, 299 | inner_dim, 300 | kernel_size=1, 301 | stride=1, 302 | padding=0) 303 | else: 304 | self.proj_in = nn.Linear(in_channels, inner_dim) 305 | 306 | self.transformer_blocks = nn.ModuleList( 307 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 308 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 309 | for d in range(depth)] 310 | ) 311 | if not use_linear: 312 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 313 | in_channels, 314 | kernel_size=1, 315 | stride=1, 316 | padding=0)) 317 | else: 318 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 319 | self.use_linear = use_linear 320 | 321 | def forward(self, x, context=None): 322 | # note: if no context is given, cross-attention defaults to self-attention 323 | if not isinstance(context, list): 324 | context = [context] 325 | b, c, h, w = x.shape 326 | x_in = x 327 | x = self.norm(x) 328 | if not self.use_linear: 329 | x = self.proj_in(x) 330 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 331 | if self.use_linear: 332 | x = self.proj_in(x) 333 | for i, block in enumerate(self.transformer_blocks): 334 | x = block(x, context=context[i]) 335 | if self.use_linear: 336 | x = self.proj_out(x) 337 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 338 | if not self.use_linear: 339 | x = self.proj_out(x) 340 | return x + x_in 341 | 342 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | 
self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
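A hedged usage sketch for ImageConcatWithNoiseAugmentation defined just above; the schedule parameters and max_noise_level are illustrative assumptions, not values taken from a repo config:

import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

noise_aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config=dict(beta_schedule="linear", timesteps=1000,
                               linear_start=1e-4, linear_end=2e-2),
    max_noise_level=350,                 # assumed value; the class default is 1000
)
x_low = torch.randn(4, 3, 64, 64)        # downsampled conditioning image, BCHW
z, noise_level = noise_aug(x_low)        # noised copy plus per-sample noise level in [0, 350)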
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 126 | "dtype": torch.get_autocast_gpu_dtype(), 127 | "cache_enabled": torch.is_autocast_cache_enabled()} 128 | with torch.no_grad(): 129 | output_tensors = ctx.run_function(*ctx.input_tensors) 130 | return output_tensors 131 | 132 | @staticmethod 133 | def backward(ctx, *output_grads): 134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 135 | with torch.enable_grad(), \ 136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 137 | # Fixes a bug where the first op in run_function modifies the 138 | # Tensor storage in place, which is not allowed for detach()'d 139 | # Tensors. 140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 141 | output_tensors = ctx.run_function(*shallow_copies) 142 | input_grads = torch.autograd.grad( 143 | output_tensors, 144 | ctx.input_tensors + ctx.input_params, 145 | output_grads, 146 | allow_unused=True, 147 | ) 148 | del ctx.input_tensors 149 | del ctx.input_params 150 | del output_tensors 151 | return (None, None) + input_grads 152 | 153 | 154 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 155 | """ 156 | Create sinusoidal timestep embeddings. 157 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 158 | These may be fractional. 159 | :param dim: the dimension of the output. 160 | :param max_period: controls the minimum frequency of the embeddings. 161 | :return: an [N x dim] Tensor of positional embeddings. 162 | """ 163 | if not repeat_only: 164 | half = dim // 2 165 | freqs = torch.exp( 166 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 167 | ).to(device=timesteps.device) 168 | args = timesteps[:, None].float() * freqs[None] 169 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 170 | if dim % 2: 171 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 172 | else: 173 | embedding = repeat(timesteps, 'b -> b d', d=dim) 174 | return embedding 175 | 176 | 177 | def zero_module(module): 178 | """ 179 | Zero out the parameters of a module and return it. 
180 | """ 181 | for p in module.parameters(): 182 | p.detach().zero_() 183 | return module 184 | 185 | 186 | def scale_module(module, scale): 187 | """ 188 | Scale the parameters of a module and return it. 189 | """ 190 | for p in module.parameters(): 191 | p.detach().mul_(scale) 192 | return module 193 | 194 | 195 | def mean_flat(tensor): 196 | """ 197 | Take the mean over all non-batch dimensions. 198 | """ 199 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 200 | 201 | 202 | def normalization(channels): 203 | """ 204 | Make a standard normalization layer. 205 | :param channels: number of input channels. 206 | :return: an nn.Module for normalization. 207 | """ 208 | return GroupNorm32(32, channels) 209 | 210 | 211 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 212 | class SiLU(nn.Module): 213 | def forward(self, x): 214 | return x * torch.sigmoid(x) 215 | 216 | 217 | class GroupNorm32(nn.GroupNorm): 218 | def forward(self, x): 219 | return super().forward(x.float()).type(x.dtype) 220 | 221 | def conv_nd(dims, *args, **kwargs): 222 | """ 223 | Create a 1D, 2D, or 3D convolution module. 224 | """ 225 | if dims == 1: 226 | return nn.Conv1d(*args, **kwargs) 227 | elif dims == 2: 228 | return nn.Conv2d(*args, **kwargs) 229 | elif dims == 3: 230 | return nn.Conv3d(*args, **kwargs) 231 | raise ValueError(f"unsupported dimensions: {dims}") 232 | 233 | 234 | def linear(*args, **kwargs): 235 | """ 236 | Create a linear module. 237 | """ 238 | return nn.Linear(*args, **kwargs) 239 | 240 | 241 | def avg_pool_nd(dims, *args, **kwargs): 242 | """ 243 | Create a 1D, 2D, or 3D average pooling module. 244 | """ 245 | if dims == 1: 246 | return nn.AvgPool1d(*args, **kwargs) 247 | elif dims == 2: 248 | return nn.AvgPool2d(*args, **kwargs) 249 | elif dims == 3: 250 | return nn.AvgPool3d(*args, **kwargs) 251 | raise ValueError(f"unsupported dimensions: {dims}") 252 | 253 | 254 | class HybridConditioner(nn.Module): 255 | 256 | def __init__(self, c_concat_config, c_crossattn_config): 257 | super().__init__() 258 | self.concat_conditioner = instantiate_from_config(c_concat_config) 259 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 260 | 261 | def forward(self, c_concat, c_crossattn): 262 | c_concat = self.concat_conditioner(c_concat) 263 | c_crossattn = self.crossattn_conditioner(c_crossattn) 264 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 265 | 266 | 267 | def noise_like(shape, device, repeat=False): 268 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 269 | noise = lambda: torch.randn(shape, device=device) 270 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | 
self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
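        Typical validation pattern (illustrative):
            ema.store(model.parameters())
            ema.copy_to(model)
            # ... run validation with the EMA weights ...
            ema.restore(model.parameters())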
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, text): 169 | tokens = open_clip.tokenize(text) 170 | z = self.encode_with_transformer(tokens.to(self.device)) 171 | return z 172 | 173 | def encode_with_transformer(self, text): 174 | x = 
self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 175 | x = x + self.model.positional_embedding 176 | x = x.permute(1, 0, 2) # NLD -> LND 177 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 178 | x = x.permute(1, 0, 2) # LND -> NLD 179 | x = self.model.ln_final(x) 180 | return x 181 | 182 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 183 | for i, r in enumerate(self.model.transformer.resblocks): 184 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 185 | break 186 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 187 | x = checkpoint(r, x, attn_mask) 188 | else: 189 | x = r(x, attn_mask=attn_mask) 190 | return x 191 | 192 | def encode(self, text): 193 | return self(text) 194 | 195 | 196 | class FrozenCLIPT5Encoder(AbstractEncoder): 197 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 198 | clip_max_length=77, t5_max_length=77): 199 | super().__init__() 200 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 201 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 202 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 203 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 204 | 205 | def encode(self, text): 206 | return self(text) 207 | 208 | def forward(self, text): 209 | clip_z = self.clip_encoder.encode(text) 210 | t5_z = self.t5_encoder.encode(text) 211 | return [clip_z, t5_z] 212 | 213 | 214 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": 
"midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class 
MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rese1f/StableVideo/298b52f8df5a1edb93ee46c59001833a2b291aa5/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, 
expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 
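    Computes x + conv2(relu(conv1(relu(x)))), i.e. a pre-activation residual block without normalisation.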
157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 
297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, 
layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
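    Example (illustrative; this is the configuration built for the DPT models in ldm/modules/midas/api.py):
        Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
               ensure_multiple_of=32, resize_method="minimal",
               image_interpolation_method=cv2.INTER_CUBIC)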
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Cant encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions. 
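    E.g. a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,).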
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /log/LOG.md: -------------------------------------------------------------------------------- 1 | # for saving result -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiofiles==23.2.1 3 | aiohttp==3.8.5 4 | aiosignal==1.3.1 5 | altair==5.0.1 6 | annotated-types==0.5.0 7 | antlr4-python3-runtime==4.9.3 8 | anyio==3.7.1 9 | async-timeout==4.0.3 10 | attrs==23.1.0 11 | cachetools==5.3.1 12 | certifi==2023.7.22 13 | charset-normalizer==3.2.0 14 | click==8.1.6 15 | cmake==3.27.2 16 | contourpy==1.1.0 17 | cycler==0.11.0 18 | einops==0.6.1 19 | fastapi==0.101.1 20 | ffmpy==0.3.1 21 | filelock==3.12.2 22 | fonttools==4.42.0 23 | frozenlist==1.4.0 24 | fsspec==2023.6.0 25 | ftfy==6.1.1 26 | future==0.18.3 27 | google-auth==2.22.0 28 | google-auth-oauthlib==1.0.0 29 | gradio==3.40.1 30 | gradio_client==0.4.0 31 | grpcio==1.57.0 32 | h11==0.14.0 33 | httpcore==0.17.3 34 | httpx==0.24.1 35 | huggingface-hub==0.16.4 36 | idna==3.4 37 | 
imageio==2.31.1 38 | imageio-ffmpeg==0.4.8 39 | importlib-resources==6.0.1 40 | Jinja2==3.1.2 41 | jsonschema==4.19.0 42 | jsonschema-specifications==2023.7.1 43 | kiwisolver==1.4.4 44 | lightning-utilities==0.9.0 45 | linkify-it-py==2.0.2 46 | lit==16.0.6 47 | Markdown==3.4.4 48 | markdown-it-py==2.2.0 49 | MarkupSafe==2.1.3 50 | matplotlib==3.7.2 51 | mdit-py-plugins==0.3.3 52 | mdurl==0.1.2 53 | mpmath==1.3.0 54 | multidict==6.0.4 55 | networkx==3.1 56 | numpy==1.25.2 57 | nvidia-cublas-cu11==11.10.3.66 58 | nvidia-cuda-cupti-cu11==11.7.101 59 | nvidia-cuda-nvrtc-cu11==11.7.99 60 | nvidia-cuda-runtime-cu11==11.7.99 61 | nvidia-cudnn-cu11==8.5.0.96 62 | nvidia-cufft-cu11==10.9.0.58 63 | nvidia-curand-cu11==10.2.10.91 64 | nvidia-cusolver-cu11==11.4.0.1 65 | nvidia-cusparse-cu11==11.7.4.91 66 | nvidia-nccl-cu11==2.14.3 67 | nvidia-nvtx-cu11==11.7.91 68 | oauthlib==3.2.2 69 | omegaconf==2.3.0 70 | open-clip-torch==2.20.0 71 | opencv-python==4.8.0.76 72 | orjson==3.9.5 73 | packaging==23.1 74 | pandas==2.0.3 75 | Pillow==10.0.0 76 | protobuf==3.20.3 77 | pyasn1==0.5.0 78 | pyasn1-modules==0.3.0 79 | pydantic==2.1.1 80 | pydantic_core==2.4.0 81 | pyDeprecate==0.3.1 82 | pydub==0.25.1 83 | pyparsing==3.0.9 84 | python-dateutil==2.8.2 85 | python-multipart==0.0.6 86 | pytorch-lightning==1.5.0 87 | pytz==2023.3 88 | PyYAML==6.0.1 89 | referencing==0.30.2 90 | regex==2023.8.8 91 | requests==2.31.0 92 | requests-oauthlib==1.3.1 93 | rpds-py==0.9.2 94 | rsa==4.9 95 | safetensors==0.3.2 96 | scipy==1.11.1 97 | semantic-version==2.10.0 98 | sentencepiece==0.1.99 99 | six==1.16.0 100 | sniffio==1.3.0 101 | starlette==0.27.0 102 | sympy==1.12 103 | tensorboard==2.14.0 104 | tensorboard-data-server==0.7.1 105 | timm==0.9.5 106 | tokenizers==0.13.3 107 | toolz==0.12.0 108 | torch==2.0.1 109 | torchmetrics==1.0.3 110 | torchvision==0.15.2 111 | tqdm==4.66.1 112 | transformers==4.31.0 113 | triton==2.0.0 114 | typing_extensions==4.7.1 115 | tzdata==2023.3 116 | uc-micro-py==1.0.2 117 | urllib3==1.26.16 118 | uvicorn==0.23.2 119 | wcwidth==0.2.6 120 | websockets==11.0.3 121 | Werkzeug==2.3.7 122 | yarl==1.9.2 123 | -------------------------------------------------------------------------------- /stablevideo/aggnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class AGGNet(nn.Module): 4 | def __init__(self) -> None: 5 | super().__init__() 6 | self.stage1=nn.Sequential( 7 | nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,bias=False), 8 | nn.ReLU() 9 | ) 10 | self.stage2=nn.Sequential( 11 | nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=3,padding=1,bias=False), 12 | ) 13 | 14 | def forward(self, x): 15 | x1 = self.stage1(x) 16 | x2 = self.stage2(x1) 17 | return x + x2 -------------------------------------------------------------------------------- /stablevideo/atlas_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from torchvision.transforms.functional import crop 9 | 10 | from stablevideo.atlas_utils import ( 11 | load_neural_atlases_models, 12 | get_frames_data, 13 | get_high_res_atlas, 14 | get_atlas_crops, 15 | reconstruct_video_layer, 16 | create_uv_mask, 17 | get_masks_boundaries, 18 | get_random_crop_params, 19 | get_atlas_bounding_box, 20 | load_video 21 | ) 22 | 23 | class AtlasData(): 
24 | def __init__(self, video_name): 25 | with open(f"data/{video_name}/config.json", "r") as f: 26 | json_dict = json.load(f) 27 | try: 28 | maximum_number_of_frames = json_dict["number_of_frames"] 29 | except: 30 | maximum_number_of_frames = json_dict["maximum_number_of_frames"] 31 | 32 | config = { 33 | "device": "cuda", 34 | "checkpoint_path": f"data/{video_name}/checkpoint", 35 | "resx": json_dict["resx"], 36 | "resy": json_dict["resy"], 37 | "maximum_number_of_frames": maximum_number_of_frames, 38 | "return_atlas_alpha": False, 39 | "grid_atlas_resolution": 2000, 40 | "num_scales": 7, 41 | "masks_border_expansion": 30, 42 | "mask_alpha_threshold": 0.99, # 0.95 43 | "align_corners": False 44 | } 45 | self.config = config 46 | self.device = config["device"] 47 | 48 | self.min_size = min(self.config["resx"], self.config["resy"]) 49 | self.max_size = max(self.config["resx"], self.config["resy"]) 50 | data_folder = f"data/{video_name}/{video_name}" 51 | self.original_video = load_video( 52 | data_folder, 53 | resize=(self.config["resy"], self.config["resx"]), 54 | num_frames=self.config["maximum_number_of_frames"], 55 | ) 56 | self.original_video = self.original_video.to(self.device) # tensor 57 | 58 | ( 59 | foreground_mapping, 60 | background_mapping, 61 | foreground_atlas_model, 62 | background_atlas_model, 63 | alpha_model, 64 | ) = load_neural_atlases_models(config) 65 | ( 66 | original_background_all_uvs, 67 | original_foreground_all_uvs, 68 | self.all_alpha, 69 | foreground_atlas_alpha, 70 | ) = get_frames_data( 71 | config, 72 | foreground_mapping, 73 | background_mapping, 74 | alpha_model, 75 | ) 76 | 77 | self.background_reconstruction = reconstruct_video_layer(original_background_all_uvs, background_atlas_model) 78 | # using original video for the foreground layer 79 | self.foreground_reconstruction = self.original_video * self.all_alpha 80 | 81 | ( 82 | self.background_all_uvs, 83 | self.scaled_background_uvs, 84 | self.background_min_u, 85 | self.background_min_v, 86 | self.background_max_u, 87 | self.background_max_v, 88 | ) = self.preprocess_uv_values( 89 | original_background_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="background" 90 | ) 91 | ( 92 | self.foreground_all_uvs, 93 | self.scaled_foreground_uvs, 94 | self.foreground_min_u, 95 | self.foreground_min_v, 96 | self.foreground_max_u, 97 | self.foreground_max_v, 98 | ) = self.preprocess_uv_values( 99 | original_foreground_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="foreground" 100 | ) 101 | 102 | self.background_uv_mask = create_uv_mask( 103 | config, 104 | background_mapping, 105 | self.background_min_u, 106 | self.background_min_v, 107 | self.background_max_u, 108 | self.background_max_v, 109 | uv_shift=-0.5, 110 | resolution_shift=1, 111 | ) 112 | self.foreground_uv_mask = create_uv_mask( 113 | config, 114 | foreground_mapping, 115 | self.foreground_min_u, 116 | self.foreground_min_v, 117 | self.foreground_max_u, 118 | self.foreground_max_v, 119 | uv_shift=0.5, 120 | resolution_shift=0, 121 | ) 122 | self.background_grid_atlas = get_high_res_atlas( 123 | background_atlas_model, 124 | self.background_min_v, 125 | self.background_min_u, 126 | self.background_max_v, 127 | self.background_max_u, 128 | config["grid_atlas_resolution"], 129 | device=config["device"], 130 | layer="background", 131 | ) 132 | self.foreground_grid_atlas = get_high_res_atlas( 133 | foreground_atlas_model, 134 | self.foreground_min_v, 135 | self.foreground_min_u, 136 | self.foreground_max_v, 137 
| self.foreground_max_u, 138 | config["grid_atlas_resolution"], 139 | device=config["device"], 140 | layer="foreground", 141 | ) 142 | if config["return_atlas_alpha"]: 143 | self.foreground_atlas_alpha = foreground_atlas_alpha # used for visualizations 144 | self.cnn_min_crop_size = 2 ** self.config["num_scales"] + 1 145 | 146 | self.mask_boundaries = get_masks_boundaries( 147 | alpha_video=self.all_alpha.cpu(), 148 | border=self.config["masks_border_expansion"], 149 | threshold=self.config["mask_alpha_threshold"], 150 | min_crop_size=self.cnn_min_crop_size, 151 | ) 152 | self.cropped_foreground_atlas, self.foreground_atlas_bbox = get_atlas_bounding_box( 153 | self.mask_boundaries, self.foreground_grid_atlas, self.foreground_all_uvs 154 | ) 155 | 156 | self.step = -1 157 | self.edited_atlas_dict, self.edit_dict, self.uv_mask = {}, {}, {} 158 | 159 | @staticmethod 160 | def preprocess_uv_values(layer_uv_values, resolution, device="cuda", layer="background"): 161 | if layer == "background": 162 | shift = 1 163 | else: 164 | shift = 0 165 | uv_values = (layer_uv_values + shift) * resolution 166 | min_u, min_v = uv_values.reshape(-1, 2).min(dim=0).values.long() 167 | uv_values -= torch.tensor([min_u, min_v], device=device) 168 | max_u, max_v = uv_values.reshape(-1, 2).max(dim=0).values.ceil().long() 169 | 170 | edge_size = torch.tensor([max_u, max_v], device=device) 171 | scaled_uv_values = ((uv_values.reshape(-1, 2) / edge_size) * 2 - 1).unsqueeze(1).unsqueeze(0) 172 | 173 | return uv_values, scaled_uv_values, min_u, min_v, max_u, max_v 174 | 175 | def get_random_crop_data(self, crop_size): 176 | t = random.randint(0, self.config["maximum_number_of_frames"] - 1) 177 | y_start, x_start, h_crop, w_crop = get_random_crop_params((self.config["resx"], self.config["resy"]), crop_size) 178 | return y_start, x_start, h_crop, w_crop, t 179 | 180 | def get_global_crops_multi(self, keyframes, res): 181 | foreground_atlas_crops = [] 182 | background_atlas_crops = [] 183 | foreground_uvs = [] 184 | background_uvs = [] 185 | background_alpha_crops = [] 186 | foreground_alpha_crops = [] 187 | original_background_crops = [] 188 | original_foreground_crops = [] 189 | output_dict = {} 190 | 191 | self.config["crops_min_cover"] = 0.95 192 | self.config["grid_atlas_resolution"] = res 193 | 194 | for cur_frame in keyframes: 195 | y_start, x_start, frame_h, frame_w = self.mask_boundaries[cur_frame].tolist() 196 | crop_size = ( 197 | max( 198 | random.randint(round(self.config["crops_min_cover"] * frame_h), frame_h), 199 | self.cnn_min_crop_size, 200 | ), 201 | max( 202 | random.randint(round(self.config["crops_min_cover"] * frame_w), frame_w), 203 | self.cnn_min_crop_size, 204 | ), 205 | ) 206 | y_crop, x_crop, h_crop, w_crop = get_random_crop_params((frame_w, frame_h), crop_size) 207 | 208 | foreground_uv = self.foreground_all_uvs[ 209 | cur_frame, 210 | y_start + y_crop : y_start + y_crop + h_crop, 211 | x_start + x_crop : x_start + x_crop + w_crop, 212 | ] 213 | alpha = self.all_alpha[ 214 | [cur_frame], 215 | :, 216 | y_start + y_crop : y_start + y_crop + h_crop, 217 | x_start + x_crop : x_start + x_crop + w_crop, 218 | ] 219 | 220 | original_foreground_crop = self.foreground_reconstruction[ 221 | [cur_frame], 222 | :, 223 | y_start + y_crop : y_start + y_crop + h_crop, 224 | x_start + x_crop : x_start + x_crop + w_crop, 225 | ] 226 | 227 | foreground_alpha_crops.append(alpha) 228 | foreground_uvs.append(foreground_uv) # not scaled 229 | original_foreground_crops.append(original_foreground_crop) 230 | 231 | 
foreground_max_vals = torch.tensor( 232 | [self.config["grid_atlas_resolution"]] * 2, device=self.device, dtype=torch.long 233 | ) 234 | foreground_min_vals = torch.tensor([0] * 2, device=self.device, dtype=torch.long) 235 | for uv_values in foreground_uvs: 236 | min_uv = uv_values.amin(dim=[0, 1]).long() 237 | max_uv = uv_values.amax(dim=[0, 1]).ceil().long() 238 | foreground_min_vals = torch.minimum(foreground_min_vals, min_uv) 239 | foreground_max_vals = torch.maximum(foreground_max_vals, max_uv) 240 | 241 | h_v = foreground_max_vals[1] - foreground_min_vals[1] 242 | w_u = foreground_max_vals[0] - foreground_min_vals[0] 243 | foreground_atlas_crop = crop( 244 | self.foreground_grid_atlas, 245 | foreground_min_vals[1], 246 | foreground_min_vals[0], 247 | h_v, 248 | w_u, 249 | ) 250 | foreground_atlas_crops.append(foreground_atlas_crop) 251 | 252 | for i, uv_values in enumerate(foreground_uvs): 253 | foreground_uvs[i] = ( 254 | 2 * (uv_values - foreground_min_vals) / (foreground_max_vals - foreground_min_vals) - 1 255 | ).unsqueeze(0) 256 | 257 | 258 | crop_size = ( 259 | random.randint(round(self.config["crops_min_cover"] * self.min_size), self.min_size), 260 | random.randint(round(self.config["crops_min_cover"] * self.max_size), self.max_size), 261 | ) 262 | crop_data = self.get_random_crop_data(crop_size) 263 | y, x, h, w, _ = crop_data 264 | background_uv = self.background_all_uvs[keyframes, y : y + h, x : x + w] 265 | original_background_crop = self.background_reconstruction[ 266 | keyframes, :, y : y + h, x : x + w 267 | ] 268 | alpha = self.all_alpha[keyframes, :, y : y + h, x : x + w] 269 | 270 | 271 | original_background_crops = [el.unsqueeze(0) for el in original_background_crop] 272 | background_alpha_crops = [el.unsqueeze(0) for el in alpha] 273 | 274 | background_atlas_crop, background_min_vals, background_max_vals = get_atlas_crops( 275 | background_uv, 276 | self.background_grid_atlas, 277 | ) 278 | background_uv = 2 * (background_uv - background_min_vals) / (background_max_vals - background_min_vals) - 1 279 | background_atlas_crops = [el.unsqueeze(0) for el in background_atlas_crop] 280 | background_uvs = [el.unsqueeze(0) for el in background_uv] 281 | 282 | output_dict["foreground_alpha"] = foreground_alpha_crops 283 | output_dict["foreground_uvs"] = foreground_uvs 284 | output_dict["original_foreground_crops"] = original_foreground_crops 285 | output_dict["foreground_atlas_crops"] = foreground_atlas_crops 286 | 287 | output_dict["background_alpha"] = background_alpha_crops 288 | output_dict["background_uvs"] = background_uvs 289 | output_dict["original_background_crops"] = original_background_crops 290 | output_dict["background_atlas_crops"] = background_atlas_crops 291 | return output_dict -------------------------------------------------------------------------------- /stablevideo/implicit_neural_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | # code taken from https://github.com/ykasten/layered-neural-atlases 7 | 8 | 9 | def count_parameters(model): 10 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | 12 | 13 | def positionalEncoding_vec(in_tensor, b): 14 | proj = torch.einsum("ij, k -> ijk", in_tensor, b) # shape (batch, in_tensor.size(1), freqNum) 15 | mapped_coords = torch.cat((torch.sin(proj), torch.cos(proj)), dim=1) # shape (batch, 2*in_tensor.size(1), freqNum) 16 | output = 
mapped_coords.transpose(2, 1).contiguous().view(mapped_coords.size(0), -1) 17 | return output 18 | 19 | 20 | class IMLP(nn.Module): 21 | def __init__( 22 | self, 23 | input_dim, 24 | output_dim, 25 | hidden_dim=256, 26 | use_positional=True, 27 | positional_dim=10, 28 | skip_layers=[4, 6], 29 | num_layers=8, # includes the output layer 30 | verbose=True, 31 | use_tanh=True, 32 | apply_softmax=False, 33 | ): 34 | super(IMLP, self).__init__() 35 | self.verbose = verbose 36 | self.use_tanh = use_tanh 37 | self.apply_softmax = apply_softmax 38 | if apply_softmax: 39 | self.softmax = nn.Softmax() 40 | if use_positional: 41 | encoding_dimensions = 2 * input_dim * positional_dim 42 | self.b = torch.tensor([(2 ** j) * np.pi for j in range(positional_dim)], requires_grad=False) 43 | else: 44 | encoding_dimensions = input_dim 45 | 46 | self.hidden = nn.ModuleList() 47 | for i in range(num_layers): 48 | if i == 0: 49 | input_dims = encoding_dimensions 50 | elif i in skip_layers: 51 | input_dims = hidden_dim + encoding_dimensions 52 | else: 53 | input_dims = hidden_dim 54 | 55 | if i == num_layers - 1: 56 | # last layer 57 | self.hidden.append(nn.Linear(input_dims, output_dim, bias=True)) 58 | else: 59 | self.hidden.append(nn.Linear(input_dims, hidden_dim, bias=True)) 60 | 61 | self.skip_layers = skip_layers 62 | self.num_layers = num_layers 63 | 64 | self.positional_dim = positional_dim 65 | self.use_positional = use_positional 66 | 67 | if self.verbose: 68 | print(f"Model has {count_parameters(self)} params") 69 | 70 | def forward(self, x): 71 | if self.use_positional: 72 | if self.b.device != x.device: 73 | self.b = self.b.to(x.device) 74 | pos = positionalEncoding_vec(x, self.b) 75 | x = pos 76 | 77 | input = x.detach().clone() 78 | for i, layer in enumerate(self.hidden): 79 | if i > 0: 80 | x = F.relu(x) 81 | if i in self.skip_layers: 82 | x = torch.cat((x, input), 1) 83 | x = layer(x) 84 | if self.use_tanh: 85 | x = torch.tanh(x) 86 | 87 | if self.apply_softmax: 88 | x = self.softmax(x) 89 | return x 90 | --------------------------------------------------------------------------------
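
For reference, a minimal usage sketch of the IMLP class from stablevideo/implicit_neural_networks.py above, assuming the repository root is on PYTHONPATH. The hyperparameters are illustrative assumptions for a small (u, v) -> RGB atlas network; the values StableVideo actually uses come from its per-video configs and checkpoints, not from this snippet.

import torch
from stablevideo.implicit_neural_networks import IMLP

# Hypothetical atlas network: maps 2-D (u, v) coordinates to RGB.
# hidden_dim, positional_dim and num_layers below are assumptions for illustration.
atlas_model = IMLP(
    input_dim=2,
    output_dim=3,
    hidden_dim=256,
    use_positional=True,
    positional_dim=10,
    num_layers=8,
    use_tanh=True,
    verbose=False,
)

# 4096 random (u, v) samples in [-1, 1]. With use_positional=True, each sample is
# expanded by the sin/cos encoding to 2 * input_dim * positional_dim = 40 features
# before the first linear layer.
uv = torch.rand(4096, 2) * 2 - 1
rgb = atlas_model(uv)  # tanh output, so values lie in [-1, 1]
print(rgb.shape)       # torch.Size([4096, 3])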