├── LICENSE ├── README.md ├── configs ├── sample_i2v.yaml └── sample_transition.yaml ├── datasets └── video_transforms.py ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py ├── respace.py └── timestep_sampler.py ├── examples ├── Close-up essence is poured from bottleKodak Vision3 50,slow motion_0000_001.gif ├── The picture shows the beauty of the sea and at the sam,slow motion_0000_11301.gif ├── The picture shows the beauty of the sea and at the sam,slow motion_0000_6600.gif ├── Travel from Earth's spring blossoms to the alien cherry blossom forestssmooth transition, slow motion_0000_003.gif ├── flying through fantasy landscapes in the cloud, 4k, high resolution._0000_885.gif ├── orange-flower.gif └── spiderman-becomes-a-sand-sculpture.gif ├── input ├── i2v │ ├── Close-up_essence_is_poured_from_bottleKodak_Vision.png │ ├── The_picture_shows_the_beauty_of_the_sea.png │ └── The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png └── transition │ ├── 1 │ ├── 1-Close-up shot of a blooming cherry tree, realism-1.png │ └── 2-Wide angle shot of an alien planet with cherry blossom forest-2.png │ ├── 2 │ ├── 1-Overhead view of a bustling city street at night, realism-1.png │ └── 2-Aerial view of a futuristic city bathed in neon lights-2.png │ └── 3 │ ├── 1-Close-up shot of a candle lit in the darkness, realism-1png.png │ └── 2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png ├── models ├── __init__.py ├── attention.py ├── clip.py ├── resnet.py ├── unet.py ├── unet_blocks.py └── utils.py ├── requirement.txt ├── sample_scripts └── with_mask_sample.py ├── seine.gif └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SEINE 2 | [![arXiv](https://img.shields.io/badge/arXiv-2310.20700-b31b1b.svg)](https://arxiv.org/abs/2310.20700) 3 | [![Project Page](https://img.shields.io/badge/SEINE-Website-green)](https://vchitect.github.io/SEINE-project/) 4 | [![Replicate](https://replicate.com/lucataco/seine/badge)](https://replicate.com/lucataco/seine) 5 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/Vchitect/SEINE) 6 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVchitect%2FSEINE&count_bg=%23F59352&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=visitors&edge_flat=false)](https://hits.seeyoufarm.com) 7 | 8 | This repository is the official implementation of [SEINE](https://arxiv.org/abs/2310.20700): 9 | 10 | **[SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction (ICLR2024)](https://arxiv.org/abs/2310.20700)** 11 | 12 | **SEINE** is a video diffusion model and is part of the video generation system [Vchitect](http://vchitect.intern-ai.org.cn/). 13 | You can also check out our Text-to-Video (T2V) framework [LaVie](https://github.com/Vchitect/LaVie). 14 | 15 | 16 | 17 | 18 | 19 | 20 | ## Setup 21 | 22 | ### Prepare Environment 23 | ``` 24 | conda create -n seine python==3.9.16 25 | conda activate seine 26 | pip install -r requirement.txt 27 | ``` 28 | 29 | ### Download our model and the T2I base model 30 | 31 | Our model is built on Stable Diffusion v1.4. Download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) into the ```pretrained``` directory. 32 | 33 | Then download our model checkpoint (from [Google Drive](https://drive.google.com/drive/folders/1cWfeDzKJhpb0m6HA5DoMOH0_ItuUY95b?usp=sharing) or [Hugging Face](https://huggingface.co/xinyuanc91/SEINE/tree/main)) and save it to the ```pretrained``` directory as well. 34 | 35 | 36 | Now, under `./pretrained`, you should see the following layout: 37 | ``` 38 | ├── pretrained 39 | │ ├── seine.pt 40 | │ ├── stable-diffusion-v1-4 41 | │ │ ├── ... 42 | └── └── ├── ... 43 | ├── ... 44 | ``` 45 | ## Usage 46 | ### Inference for I2V 47 | Run the following command to generate the I2V results: 48 | ```bash 49 | python sample_scripts/with_mask_sample.py --config configs/sample_i2v.yaml 50 | ``` 51 | The generated videos will be saved in ```./results/i2v```. 52 | 53 | #### More Details 54 | You may modify ```./configs/sample_i2v.yaml``` to change the generation conditions. 55 | For example: 56 | 57 | ```ckpt``` specifies the model checkpoint to load. 58 | 59 | ```text_prompt``` describes the content of the generated video. 60 | 61 | ```input_path``` specifies the path to the input image. 62 | A small scripting sketch for batching several runs is included after the Results section below. 63 | ### Inference for Transition 64 | ```bash 65 | python sample_scripts/with_mask_sample.py --config configs/sample_transition.yaml 66 | ``` 67 | The generated videos will be saved in ```./results/transition```. 68 | 69 | 70 | 71 | 72 | ## Results 73 | ### I2V Results 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 |
Input Image | Output Video
95 | 96 | 97 | ### Transition Results 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 |
Input Images | Output Video
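
### Example: Scripting Multiple Runs
The snippet below is an illustrative sketch, not part of the repository: it shows one way to batch several I2V generations by cloning `configs/sample_i2v.yaml` with a different image and prompt per run and invoking the sampling script from the Usage section. It assumes PyYAML is installed and touches only the config fields documented above (`input_path`, `text_prompt`); the prompts and the temporary config path are made up for the example.
```python
# Illustrative sketch (assumes PyYAML; uses only fields documented in this README).
import copy
import subprocess
from pathlib import Path

import yaml

base_cfg = yaml.safe_load(Path("configs/sample_i2v.yaml").read_text())

# (image, prompt) pairs; the images below ship with the repo, the prompts are illustrative.
jobs = [
    ("input/i2v/The_picture_shows_the_beauty_of_the_sea.png",
     "The picture shows the beauty of the sea"),
    ("input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png",
     "Close-up essence is poured from bottle"),
]

for image_path, prompt in jobs:
    cfg = copy.deepcopy(base_cfg)
    cfg["input_path"] = image_path   # path to the conditioning image
    cfg["text_prompt"] = [prompt]    # description of the video content
    tmp_cfg = Path("configs/_tmp_i2v.yaml")  # hypothetical temporary config file
    tmp_cfg.write_text(yaml.safe_dump(cfg))
    # Same entry point as shown in the Usage section.
    subprocess.run(
        ["python", "sample_scripts/with_mask_sample.py", "--config", str(tmp_cfg)],
        check=True,
    )
```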
115 | 116 | ## BibTeX 117 | ```bibtex 118 | @inproceedings{chen2023seine, 119 | title={Seine: Short-to-long video diffusion model for generative transition and prediction}, 120 | author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei}, 121 | booktitle={ICLR}, 122 | year={2023} 123 | } 124 | ``` 125 | 126 | ```bibtex 127 | @article{wang2023lavie, 128 | title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models}, 129 | author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others}, 130 | journal={IJCV}, 131 | year={2024} 132 | } 133 | ``` 134 | 135 | ## Disclaimer 136 | We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond its capabilities. Generating pornographic, violent, or gory content, or content that is demeaning or harmful to people, their environment, culture, or religion, is prohibited. Users are solely liable for their actions. The project contributors are not legally affiliated with, nor accountable for, users' behavior. Please use the generative model responsibly and adhere to ethical and legal standards. 137 | 138 | ## Contact Us 139 | **Xinyuan Chen**: [chenxinyuan@pjlab.org.cn](mailto:chenxinyuan@pjlab.org.cn) 140 | **Yaohui Wang**: [wangyaohui@pjlab.org.cn](mailto:wangyaohui@pjlab.org.cn) 141 | 142 | ## Acknowledgements 143 | The code is built upon [LaVie](https://github.com/Vchitect/LaVie), [diffusers](https://github.com/huggingface/diffusers), and [Stable Diffusion](https://github.com/CompVis/stable-diffusion); we thank all of the contributors for open-sourcing their work. 144 | 145 | 146 | ## License 147 | The code is licensed under Apache-2.0. The model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please contact vchitect@pjlab.org.cn. 148 | -------------------------------------------------------------------------------- /configs/sample_i2v.yaml: -------------------------------------------------------------------------------- 1 | # path config: 2 | ckpt: "pretrained/seine.pt" 3 | pretrained_model_path: "pretrained/stable-diffusion-v1-4/" 4 | #input_path: 'input/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg' 5 | input_path: 'input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png' 6 | save_path: "./results/i2v/" 7 | 8 | # model config: 9 | model: UNet 10 | num_frames: 16 11 | image_size: [240, 560] 12 | #image_size: [320, 512] 13 | # image_size: [512, 512] 14 | 15 | # model speedup 16 | use_fp16: True 17 | enable_xformers_memory_efficient_attention: True 18 | 19 | # sample config: 20 | seed: 21 | run_time: 13 22 | cfg_scale: 8.0 23 | sample_method: 'ddpm' 24 | num_sampling_steps: 250 25 | text_prompt: [] 26 | additional_prompt: ", slow motion."
27 | negative_prompt: "" 28 | do_classifier_free_guidance: True 29 | mask_type: "first1" 30 | use_mask: True 31 | -------------------------------------------------------------------------------- /configs/sample_transition.yaml: -------------------------------------------------------------------------------- 1 | #path config: 2 | ckpt: "pretrained/seine.pt" 3 | pretrained_model_path: "pretrained/stable-diffusion-v1-4/" 4 | input_path: "input/transition/1" 5 | save_path: "./results/transition/" 6 | 7 | # model config: 8 | model: UNet 9 | num_frames: 16 10 | #image_size: [320, 512] 11 | image_size: [512, 512] 12 | 13 | # model speedup 14 | use_fp16: True 15 | enable_xformers_memory_efficient_attention: True 16 | 17 | # sample config: 18 | seed: 0 19 | run_time: 13 20 | cfg_scale: 8.0 21 | sample_method: 'ddpm' 22 | num_sampling_steps: 250 23 | text_prompt: ['smooth transition'] 24 | additional_prompt: "smooth transition." 25 | negative_prompt: "" 26 | do_classifier_free_guidance: True 27 | mask_type: "onelast1" 28 | use_mask: True 29 | -------------------------------------------------------------------------------- /datasets/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numbers 4 | from torchvision.transforms import RandomCrop, RandomResizedCrop 5 | from PIL import Image 6 | import numpy as np 7 | def _is_tensor_video_clip(clip): 8 | if not torch.is_tensor(clip): 9 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 10 | 11 | if not clip.ndimension() == 4: 12 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 13 | 14 | return True 15 | 16 | 17 | def center_crop_arr(pil_image, image_size): 18 | """ 19 | Center cropping implementation from ADM. 20 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 21 | """ 22 | while min(*pil_image.size) >= 2 * image_size: 23 | pil_image = pil_image.resize( 24 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 25 | ) 26 | 27 | scale = image_size / min(*pil_image.size) 28 | pil_image = pil_image.resize( 29 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 30 | ) 31 | 32 | arr = np.array(pil_image) 33 | crop_y = (arr.shape[0] - image_size) // 2 34 | crop_x = (arr.shape[1] - image_size) // 2 35 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 36 | 37 | 38 | def crop(clip, i, j, h, w): 39 | """ 40 | Args: 41 | clip (torch.tensor): Video clip to be cropped.
Size is (T, C, H, W) 42 | """ 43 | if len(clip.size()) != 4: 44 | raise ValueError("clip should be a 4D tensor") 45 | return clip[..., i : i + h, j : j + w] 46 | 47 | 48 | def resize(clip, target_size, interpolation_mode): 49 | if len(target_size) != 2: 50 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 51 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) 52 | 53 | def resize_scale(clip, target_size, interpolation_mode): 54 | if len(target_size) != 2: 55 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 56 | H, W = clip.size(-2), clip.size(-1) 57 | scale_ = target_size[0] / min(H, W) 58 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 59 | 60 | def resize_with_scale_factor(clip, scale_factor, interpolation_mode): 61 | return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) 62 | 63 | def resize_scale_with_height(clip, target_size, interpolation_mode): 64 | H, W = clip.size(-2), clip.size(-1) 65 | scale_ = target_size / H 66 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 67 | 68 | def resize_scale_with_weight(clip, target_size, interpolation_mode): 69 | H, W = clip.size(-2), clip.size(-1) 70 | scale_ = target_size / W 71 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 72 | 73 | 74 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 75 | """ 76 | Do spatial cropping and resizing to the video clip 77 | Args: 78 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 79 | i (int): i in (i,j) i.e coordinates of the upper left corner. 80 | j (int): j in (i,j) i.e coordinates of the upper left corner. 81 | h (int): Height of the cropped region. 82 | w (int): Width of the cropped region. 83 | size (tuple(int, int)): height and width of resized clip 84 | Returns: 85 | clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) 86 | """ 87 | if not _is_tensor_video_clip(clip): 88 | raise ValueError("clip should be a 4D torch.tensor") 89 | clip = crop(clip, i, j, h, w) 90 | clip = resize(clip, size, interpolation_mode) 91 | return clip 92 | 93 | 94 | def center_crop(clip, crop_size): 95 | if not _is_tensor_video_clip(clip): 96 | raise ValueError("clip should be a 4D torch.tensor") 97 | h, w = clip.size(-2), clip.size(-1) 98 | # print(clip.shape) 99 | th, tw = crop_size 100 | if h < th or w < tw: 101 | # print(h, w) 102 | raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) 103 | 104 | i = int(round((h - th) / 2.0)) 105 | j = int(round((w - tw) / 2.0)) 106 | return crop(clip, i, j, th, tw) 107 | 108 | 109 | def center_crop_using_short_edge(clip): 110 | if not _is_tensor_video_clip(clip): 111 | raise ValueError("clip should be a 4D torch.tensor") 112 | h, w = clip.size(-2), clip.size(-1) 113 | if h < w: 114 | th, tw = h, h 115 | i = 0 116 | j = int(round((w - tw) / 2.0)) 117 | else: 118 | th, tw = w, w 119 | i = int(round((h - th) / 2.0)) 120 | j = 0 121 | return crop(clip, i, j, th, tw) 122 | 123 | 124 | def random_shift_crop(clip): 125 | ''' 126 | Slide along the long edge, with the short edge as crop size 127 | ''' 128 | if not _is_tensor_video_clip(clip): 129 | raise ValueError("clip should be a 4D torch.tensor") 130 | h, w = clip.size(-2), clip.size(-1) 131 | 132 | if h <= w: 133 | long_edge = w 134 | short_edge = h 135 | else: 136 | long_edge = h 137 | short_edge =w 138 | 139 | th, tw = short_edge, short_edge 140 | 141 | i = torch.randint(0, h - th + 1, size=(1,)).item() 142 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 143 | return crop(clip, i, j, th, tw) 144 | 145 | 146 | def to_tensor(clip): 147 | """ 148 | Convert tensor data type from uint8 to float, divide value by 255.0 and 149 | permute the dimensions of clip tensor 150 | Args: 151 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 152 | Return: 153 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 154 | """ 155 | _is_tensor_video_clip(clip) 156 | if not clip.dtype == torch.uint8: 157 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 158 | # return clip.float().permute(3, 0, 1, 2) / 255.0 159 | return clip.float() / 255.0 160 | 161 | 162 | def normalize(clip, mean, std, inplace=False): 163 | """ 164 | Args: 165 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 166 | mean (tuple): pixel RGB mean. Size is (3) 167 | std (tuple): pixel standard deviation. Size is (3) 168 | Returns: 169 | normalized clip (torch.tensor): Size is (T, C, H, W) 170 | """ 171 | if not _is_tensor_video_clip(clip): 172 | raise ValueError("clip should be a 4D torch.tensor") 173 | if not inplace: 174 | clip = clip.clone() 175 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 176 | # print(mean) 177 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 178 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 179 | return clip 180 | 181 | 182 | def hflip(clip): 183 | """ 184 | Args: 185 | clip (torch.tensor): Video clip to be normalized. 
Size is (T, C, H, W) 186 | Returns: 187 | flipped clip (torch.tensor): Size is (T, C, H, W) 188 | """ 189 | if not _is_tensor_video_clip(clip): 190 | raise ValueError("clip should be a 4D torch.tensor") 191 | return clip.flip(-1) 192 | 193 | 194 | class RandomCropVideo: 195 | def __init__(self, size): 196 | if isinstance(size, numbers.Number): 197 | self.size = (int(size), int(size)) 198 | else: 199 | self.size = size 200 | 201 | def __call__(self, clip): 202 | """ 203 | Args: 204 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 205 | Returns: 206 | torch.tensor: randomly cropped video clip. 207 | size is (T, C, OH, OW) 208 | """ 209 | i, j, h, w = self.get_params(clip) 210 | return crop(clip, i, j, h, w) 211 | 212 | def get_params(self, clip): 213 | h, w = clip.shape[-2:] 214 | th, tw = self.size 215 | 216 | if h < th or w < tw: 217 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") 218 | 219 | if w == tw and h == th: 220 | return 0, 0, h, w 221 | 222 | i = torch.randint(0, h - th + 1, size=(1,)).item() 223 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 224 | 225 | return i, j, th, tw 226 | 227 | def __repr__(self) -> str: 228 | return f"{self.__class__.__name__}(size={self.size})" 229 | 230 | class CenterCropResizeVideo: 231 | ''' 232 | First use the short side for cropping length, 233 | center crop video, then resize to the specified size 234 | ''' 235 | def __init__( 236 | self, 237 | size, 238 | interpolation_mode="bilinear", 239 | ): 240 | if isinstance(size, tuple): 241 | if len(size) != 2: 242 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 243 | self.size = size 244 | else: 245 | self.size = (size, size) 246 | 247 | self.interpolation_mode = interpolation_mode 248 | 249 | 250 | def __call__(self, clip): 251 | """ 252 | Args: 253 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 254 | Returns: 255 | torch.tensor: scale resized / center cropped video clip. 256 | size is (T, C, crop_size, crop_size) 257 | """ 258 | # print(clip.shape) 259 | clip_center_crop = center_crop_using_short_edge(clip) 260 | # print(clip_center_crop.shape) 320 512 261 | clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) 262 | return clip_center_crop_resize 263 | 264 | def __repr__(self) -> str: 265 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 266 | 267 | 268 | class CenterCropVideo: 269 | def __init__( 270 | self, 271 | size, 272 | interpolation_mode="bilinear", 273 | ): 274 | if isinstance(size, tuple): 275 | if len(size) != 2: 276 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 277 | self.size = size 278 | else: 279 | self.size = (size, size) 280 | 281 | self.interpolation_mode = interpolation_mode 282 | 283 | 284 | def __call__(self, clip): 285 | """ 286 | Args: 287 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 288 | Returns: 289 | torch.tensor: center cropped video clip. 
290 | size is (T, C, crop_size, crop_size) 291 | """ 292 | clip_center_crop = center_crop(clip, self.size) 293 | return clip_center_crop 294 | 295 | def __repr__(self) -> str: 296 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 297 | 298 | 299 | class NormalizeVideo: 300 | """ 301 | Normalize the video clip by mean subtraction and division by standard deviation 302 | Args: 303 | mean (3-tuple): pixel RGB mean 304 | std (3-tuple): pixel RGB standard deviation 305 | inplace (boolean): whether do in-place normalization 306 | """ 307 | 308 | def __init__(self, mean, std, inplace=False): 309 | self.mean = mean 310 | self.std = std 311 | self.inplace = inplace 312 | 313 | def __call__(self, clip): 314 | """ 315 | Args: 316 | clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) 317 | """ 318 | return normalize(clip, self.mean, self.std, self.inplace) 319 | 320 | def __repr__(self) -> str: 321 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 322 | 323 | 324 | class ToTensorVideo: 325 | """ 326 | Convert tensor data type from uint8 to float, divide value by 255.0 and 327 | permute the dimensions of clip tensor 328 | """ 329 | 330 | def __init__(self): 331 | pass 332 | 333 | def __call__(self, clip): 334 | """ 335 | Args: 336 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 337 | Return: 338 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 339 | """ 340 | return to_tensor(clip) 341 | 342 | def __repr__(self) -> str: 343 | return self.__class__.__name__ 344 | 345 | 346 | class ResizeVideo(): 347 | ''' 348 | First use the short side for cropping length, 349 | center crop video, then resize to the specified size 350 | ''' 351 | def __init__( 352 | self, 353 | size, 354 | interpolation_mode="bilinear", 355 | ): 356 | if isinstance(size, tuple): 357 | if len(size) != 2: 358 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 359 | self.size = size 360 | else: 361 | self.size = (size, size) 362 | 363 | self.interpolation_mode = interpolation_mode 364 | 365 | 366 | def __call__(self, clip): 367 | """ 368 | Args: 369 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 370 | Returns: 371 | torch.tensor: scale resized / center cropped video clip. 372 | size is (T, C, crop_size, crop_size) 373 | """ 374 | clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) 375 | return clip_resize 376 | 377 | def __repr__(self) -> str: 378 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 379 | 380 | # ------------------------------------------------------------ 381 | # --------------------- Sampling --------------------------- 382 | # ------------------------------------------------------------ 383 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | # learn_sigma=True, 17 | learn_sigma=False, # for unet 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch as th 11 | import enum 12 | 13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 14 | 15 | 16 | def mean_flat(tensor): 17 | """ 18 | Take the mean over all non-batch dimensions. 19 | """ 20 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 21 | 22 | 23 | class ModelMeanType(enum.Enum): 24 | """ 25 | Which type of output the model predicts. 26 | """ 27 | 28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 29 | START_X = enum.auto() # the model predicts x_0 30 | EPSILON = enum.auto() # the model predicts epsilon 31 | 32 | 33 | class ModelVarType(enum.Enum): 34 | """ 35 | What is used as the model's output variance. 36 | The LEARNED_RANGE option has been added to allow the model to predict 37 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
38 | """ 39 | 40 | LEARNED = enum.auto() 41 | FIXED_SMALL = enum.auto() 42 | FIXED_LARGE = enum.auto() 43 | LEARNED_RANGE = enum.auto() 44 | 45 | 46 | class LossType(enum.Enum): 47 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 48 | RESCALED_MSE = ( 49 | enum.auto() 50 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 51 | KL = enum.auto() # use the variational lower-bound 52 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 53 | 54 | def is_vb(self): 55 | return self == LossType.KL or self == LossType.RESCALED_KL 56 | 57 | 58 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 59 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 60 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 61 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 62 | return betas 63 | 64 | 65 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 66 | """ 67 | This is the deprecated API for creating beta schedules. 68 | See get_named_beta_schedule() for the new library of schedules. 69 | """ 70 | if beta_schedule == "quad": 71 | betas = ( 72 | np.linspace( 73 | beta_start ** 0.5, 74 | beta_end ** 0.5, 75 | num_diffusion_timesteps, 76 | dtype=np.float64, 77 | ) 78 | ** 2 79 | ) 80 | elif beta_schedule == "linear": 81 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 82 | elif beta_schedule == "warmup10": 83 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 84 | elif beta_schedule == "warmup50": 85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 86 | elif beta_schedule == "const": 87 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 88 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 89 | betas = 1.0 / np.linspace( 90 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 91 | ) 92 | else: 93 | raise NotImplementedError(beta_schedule) 94 | assert betas.shape == (num_diffusion_timesteps,) 95 | return betas 96 | 97 | 98 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 99 | """ 100 | Get a pre-defined beta schedule for the given name. 101 | The beta schedule library consists of beta schedules which remain similar 102 | in the limit of num_diffusion_timesteps. 103 | Beta schedules may be added, but should not be removed or changed once 104 | they are committed to maintain backwards compatibility. 105 | """ 106 | if schedule_name == "linear": 107 | # Linear schedule from Ho et al, extended to work for any number of 108 | # diffusion steps. 
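# With the original 1000-step setting, scale == 1 and beta runs linearly from
# 1e-4 to 0.02 (the Ho et al. values); for other step counts both endpoints are
# rescaled by 1000 / num_diffusion_timesteps so the total noise injected over
# the whole chain stays roughly comparable.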
109 | scale = 1000 / num_diffusion_timesteps 110 | return get_beta_schedule( 111 | "linear", 112 | beta_start=scale * 0.0001, 113 | beta_end=scale * 0.02, 114 | # diffuser stable diffusion 115 | # beta_start=scale * 0.00085, 116 | # beta_end=scale * 0.012, 117 | num_diffusion_timesteps=num_diffusion_timesteps, 118 | ) 119 | elif schedule_name == "squaredcos_cap_v2": 120 | return betas_for_alpha_bar( 121 | num_diffusion_timesteps, 122 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 123 | ) 124 | else: 125 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 126 | 127 | 128 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 129 | """ 130 | Create a beta schedule that discretizes the given alpha_t_bar function, 131 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 132 | :param num_diffusion_timesteps: the number of betas to produce. 133 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 134 | produces the cumulative product of (1-beta) up to that 135 | part of the diffusion process. 136 | :param max_beta: the maximum beta to use; use values lower than 1 to 137 | prevent singularities. 138 | """ 139 | betas = [] 140 | for i in range(num_diffusion_timesteps): 141 | t1 = i / num_diffusion_timesteps 142 | t2 = (i + 1) / num_diffusion_timesteps 143 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 144 | return np.array(betas) 145 | 146 | 147 | class GaussianDiffusion: 148 | """ 149 | Utilities for training and sampling diffusion models. 150 | Original ported from this codebase: 151 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 152 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 153 | starting at T and going to 1. 154 | """ 155 | 156 | def __init__( 157 | self, 158 | *, 159 | betas, 160 | model_mean_type, 161 | model_var_type, 162 | loss_type 163 | ): 164 | 165 | self.model_mean_type = model_mean_type 166 | self.model_var_type = model_var_type 167 | self.loss_type = loss_type 168 | 169 | # Use float64 for accuracy. 
170 | betas = np.array(betas, dtype=np.float64) 171 | self.betas = betas 172 | assert len(betas.shape) == 1, "betas must be 1-D" 173 | assert (betas > 0).all() and (betas <= 1).all() 174 | 175 | self.num_timesteps = int(betas.shape[0]) 176 | 177 | alphas = 1.0 - betas 178 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 179 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 180 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 181 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 182 | 183 | # calculations for diffusion q(x_t | x_{t-1}) and others 184 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 185 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 186 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 187 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 188 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 189 | 190 | # calculations for posterior q(x_{t-1} | x_t, x_0) 191 | self.posterior_variance = ( 192 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 193 | ) 194 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 195 | self.posterior_log_variance_clipped = np.log( 196 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 197 | ) if len(self.posterior_variance) > 1 else np.array([]) 198 | 199 | self.posterior_mean_coef1 = ( 200 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 201 | ) 202 | self.posterior_mean_coef2 = ( 203 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 204 | ) 205 | 206 | def q_mean_variance(self, x_start, t): 207 | """ 208 | Get the distribution q(x_t | x_0). 209 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 210 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 211 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 212 | """ 213 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 214 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 215 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 216 | return mean, variance, log_variance 217 | 218 | def q_sample(self, x_start, t, noise=None): 219 | """ 220 | Diffuse the data for a given number of diffusion steps. 221 | In other words, sample from q(x_t | x_0). 222 | :param x_start: the initial data batch. 223 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 224 | :param noise: if specified, the split-out normal noise. 225 | :return: A noisy version of x_start. 
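Concretely, this draws x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise,
the closed-form marginal of the forward process, which is exactly what the return
statement below computes.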
226 | """ 227 | if noise is None: 228 | noise = th.randn_like(x_start) 229 | assert noise.shape == x_start.shape 230 | return ( 231 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 232 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 233 | ) 234 | 235 | def q_posterior_mean_variance(self, x_start, x_t, t): 236 | """ 237 | Compute the mean and variance of the diffusion posterior: 238 | q(x_{t-1} | x_t, x_0) 239 | """ 240 | assert x_start.shape == x_t.shape 241 | posterior_mean = ( 242 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 243 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 244 | ) 245 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 246 | posterior_log_variance_clipped = _extract_into_tensor( 247 | self.posterior_log_variance_clipped, t, x_t.shape 248 | ) 249 | assert ( 250 | posterior_mean.shape[0] 251 | == posterior_variance.shape[0] 252 | == posterior_log_variance_clipped.shape[0] 253 | == x_start.shape[0] 254 | ) 255 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 256 | 257 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, 258 | mask=None, x_start=None, use_concat=False): 259 | """ 260 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 261 | the initial x, x_0. 262 | :param model: the model, which takes a signal and a batch of timesteps 263 | as input. 264 | :param x: the [N x C x ...] tensor at time t. 265 | :param t: a 1-D Tensor of timesteps. 266 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 267 | :param denoised_fn: if not None, a function which applies to the 268 | x_start prediction before it is used to sample. Applies before 269 | clip_denoised. 270 | :param model_kwargs: if not None, a dict of extra keyword arguments to 271 | pass to the model. This can be used for conditioning. 272 | :return: a dict with the following keys: 273 | - 'mean': the model mean output. 274 | - 'variance': the model variance output. 275 | - 'log_variance': the log of 'variance'. 276 | - 'pred_xstart': the prediction for x_0. 277 | """ 278 | if model_kwargs is None: 279 | model_kwargs = {} 280 | 281 | B, F, C = x.shape[:3] 282 | assert t.shape == (B,) 283 | if use_concat: 284 | model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs) 285 | else: 286 | model_output = model(x, t, **model_kwargs) 287 | try: 288 | model_output = model_output.sample # for tav unet 289 | except: 290 | pass 291 | # model_output = model(x, t, **model_kwargs) 292 | if isinstance(model_output, tuple): 293 | model_output, extra = model_output 294 | else: 295 | extra = None 296 | 297 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 298 | assert model_output.shape == (B, F, C * 2, *x.shape[3:]) 299 | model_output, model_var_values = th.split(model_output, C, dim=2) 300 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 301 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 302 | # The model_var_values is [-1, 1] for [min_var, max_var]. 303 | frac = (model_var_values + 1) / 2 304 | model_log_variance = frac * max_log + (1 - frac) * min_log 305 | model_variance = th.exp(model_log_variance) 306 | else: 307 | model_variance, model_log_variance = { 308 | # for fixedlarge, we set the initial (log-)variance like so 309 | # to get a better decoder log likelihood. 
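# (Concretely, beta_0 is replaced by posterior_variance[1], which is strictly
# positive, while every later timestep uses the full beta_t as its variance.)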
310 | ModelVarType.FIXED_LARGE: ( 311 | np.append(self.posterior_variance[1], self.betas[1:]), 312 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 313 | ), 314 | ModelVarType.FIXED_SMALL: ( 315 | self.posterior_variance, 316 | self.posterior_log_variance_clipped, 317 | ), 318 | }[self.model_var_type] 319 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 320 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 321 | 322 | def process_xstart(x): 323 | if denoised_fn is not None: 324 | x = denoised_fn(x) 325 | if clip_denoised: 326 | return x.clamp(-1, 1) 327 | return x 328 | 329 | if self.model_mean_type == ModelMeanType.START_X: 330 | pred_xstart = process_xstart(model_output) 331 | else: 332 | pred_xstart = process_xstart( 333 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 334 | ) 335 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 336 | 337 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 338 | return { 339 | "mean": model_mean, 340 | "variance": model_variance, 341 | "log_variance": model_log_variance, 342 | "pred_xstart": pred_xstart, 343 | "extra": extra, 344 | } 345 | 346 | def _predict_xstart_from_eps(self, x_t, t, eps): 347 | assert x_t.shape == eps.shape 348 | return ( 349 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 350 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 351 | ) 352 | 353 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 354 | return ( 355 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 356 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 357 | 358 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 359 | """ 360 | Compute the mean for the previous step, given a function cond_fn that 361 | computes the gradient of a conditional log probability with respect to 362 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 363 | condition on y. 364 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 365 | """ 366 | gradient = cond_fn(x, t, **model_kwargs) 367 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 368 | return new_mean 369 | 370 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 371 | """ 372 | Compute what the p_mean_variance output would have been, should the 373 | model's score function be conditioned by cond_fn. 374 | See condition_mean() for details on cond_fn. 375 | Unlike condition_mean(), this instead uses the conditioning strategy 376 | from Song et al (2020). 377 | """ 378 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 379 | 380 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 381 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 382 | 383 | out = p_mean_var.copy() 384 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 385 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 386 | return out 387 | 388 | def p_sample( 389 | self, 390 | model, 391 | x, 392 | t, 393 | clip_denoised=True, 394 | denoised_fn=None, 395 | cond_fn=None, 396 | model_kwargs=None, 397 | mask=None, 398 | x_start=None, 399 | use_concat=False 400 | ): 401 | """ 402 | Sample x_{t-1} from the model at the given timestep. 403 | :param model: the model to sample from. 
404 | :param x: the current tensor at x_{t-1}. 405 | :param t: the value of t, starting at 0 for the first diffusion step. 406 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 407 | :param denoised_fn: if not None, a function which applies to the 408 | x_start prediction before it is used to sample. 409 | :param cond_fn: if not None, this is a gradient function that acts 410 | similarly to the model. 411 | :param model_kwargs: if not None, a dict of extra keyword arguments to 412 | pass to the model. This can be used for conditioning. 413 | :return: a dict containing the following keys: 414 | - 'sample': a random sample from the model. 415 | - 'pred_xstart': a prediction of x_0. 416 | """ 417 | out = self.p_mean_variance( 418 | model, 419 | x, 420 | t, 421 | clip_denoised=clip_denoised, 422 | denoised_fn=denoised_fn, 423 | model_kwargs=model_kwargs, 424 | mask=mask, 425 | x_start=x_start, 426 | use_concat=use_concat 427 | ) 428 | noise = th.randn_like(x) 429 | nonzero_mask = ( 430 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 431 | ) # no noise when t == 0 432 | if cond_fn is not None: 433 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 434 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 435 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 436 | 437 | def p_sample_loop( 438 | self, 439 | model, 440 | shape, 441 | noise=None, 442 | clip_denoised=True, 443 | denoised_fn=None, 444 | cond_fn=None, 445 | model_kwargs=None, 446 | device=None, 447 | progress=False, 448 | mask=None, 449 | x_start=None, 450 | use_concat=False, 451 | ): 452 | """ 453 | Generate samples from the model. 454 | :param model: the model module. 455 | :param shape: the shape of the samples, (N, C, H, W). 456 | :param noise: if specified, the noise from the encoder to sample. 457 | Should be of the same shape as `shape`. 458 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 459 | :param denoised_fn: if not None, a function which applies to the 460 | x_start prediction before it is used to sample. 461 | :param cond_fn: if not None, this is a gradient function that acts 462 | similarly to the model. 463 | :param model_kwargs: if not None, a dict of extra keyword arguments to 464 | pass to the model. This can be used for conditioning. 465 | :param device: if specified, the device to create the samples on. 466 | If not specified, use a model parameter's device. 467 | :param progress: if True, show a tqdm progress bar. 468 | :return: a non-differentiable batch of samples. 469 | """ 470 | final = None 471 | for sample in self.p_sample_loop_progressive( 472 | model, 473 | shape, 474 | noise=noise, 475 | clip_denoised=clip_denoised, 476 | denoised_fn=denoised_fn, 477 | cond_fn=cond_fn, 478 | model_kwargs=model_kwargs, 479 | device=device, 480 | progress=progress, 481 | mask=mask, 482 | x_start=x_start, 483 | use_concat=use_concat 484 | ): 485 | final = sample 486 | return final["sample"] 487 | 488 | def p_sample_loop_progressive( 489 | self, 490 | model, 491 | shape, 492 | noise=None, 493 | clip_denoised=True, 494 | denoised_fn=None, 495 | cond_fn=None, 496 | model_kwargs=None, 497 | device=None, 498 | progress=False, 499 | mask=None, 500 | x_start=None, 501 | use_concat=False 502 | ): 503 | """ 504 | Generate samples from the model and yield intermediate samples from 505 | each timestep of diffusion. 506 | Arguments are the same as p_sample_loop(). 
507 | Returns a generator over dicts, where each dict is the return value of 508 | p_sample(). 509 | """ 510 | if device is None: 511 | device = next(model.parameters()).device 512 | assert isinstance(shape, (tuple, list)) 513 | if noise is not None: 514 | img = noise 515 | else: 516 | img = th.randn(*shape, device=device) 517 | indices = list(range(self.num_timesteps))[::-1] 518 | 519 | if progress: 520 | # Lazy import so that we don't depend on tqdm. 521 | from tqdm.auto import tqdm 522 | 523 | indices = tqdm(indices) 524 | 525 | for i in indices: 526 | t = th.tensor([i] * shape[0], device=device) 527 | with th.no_grad(): 528 | out = self.p_sample( 529 | model, 530 | img, 531 | t, 532 | clip_denoised=clip_denoised, 533 | denoised_fn=denoised_fn, 534 | cond_fn=cond_fn, 535 | model_kwargs=model_kwargs, 536 | mask=mask, 537 | x_start=x_start, 538 | use_concat=use_concat 539 | ) 540 | yield out 541 | img = out["sample"] 542 | 543 | def ddim_sample( 544 | self, 545 | model, 546 | x, 547 | t, 548 | clip_denoised=True, 549 | denoised_fn=None, 550 | cond_fn=None, 551 | model_kwargs=None, 552 | eta=0.0, 553 | mask=None, 554 | x_start=None, 555 | use_concat=False 556 | ): 557 | """ 558 | Sample x_{t-1} from the model using DDIM. 559 | Same usage as p_sample(). 560 | """ 561 | out = self.p_mean_variance( 562 | model, 563 | x, 564 | t, 565 | clip_denoised=clip_denoised, 566 | denoised_fn=denoised_fn, 567 | model_kwargs=model_kwargs, 568 | mask=mask, 569 | x_start=x_start, 570 | use_concat=use_concat 571 | ) 572 | if cond_fn is not None: 573 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 574 | 575 | # Usually our model outputs epsilon, but we re-derive it 576 | # in case we used x_start or x_prev prediction. 577 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 578 | 579 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 580 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 581 | sigma = ( 582 | eta 583 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 584 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 585 | ) 586 | # Equation 12. 587 | noise = th.randn_like(x) 588 | mean_pred = ( 589 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 590 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 591 | ) 592 | nonzero_mask = ( 593 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 594 | ) # no noise when t == 0 595 | sample = mean_pred + nonzero_mask * sigma * noise 596 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 597 | 598 | def ddim_reverse_sample( 599 | self, 600 | model, 601 | x, 602 | t, 603 | clip_denoised=True, 604 | denoised_fn=None, 605 | cond_fn=None, 606 | model_kwargs=None, 607 | eta=0.0, 608 | ): 609 | """ 610 | Sample x_{t+1} from the model using DDIM reverse ODE. 611 | """ 612 | assert eta == 0.0, "Reverse ODE only for deterministic path" 613 | out = self.p_mean_variance( 614 | model, 615 | x, 616 | t, 617 | clip_denoised=clip_denoised, 618 | denoised_fn=denoised_fn, 619 | model_kwargs=model_kwargs, 620 | ) 621 | if cond_fn is not None: 622 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 623 | # Usually our model outputs epsilon, but we re-derive it 624 | # in case we used x_start or x_prev prediction. 
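# This is the same inversion implemented by _predict_eps_from_xstart:
# eps = (sqrt(1 / alpha_bar_t) * x_t - pred_xstart) / sqrt(1 / alpha_bar_t - 1)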
625 | eps = ( 626 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 627 | - out["pred_xstart"] 628 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 629 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 630 | 631 | # Equation 12. reversed 632 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 633 | 634 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 635 | 636 | def ddim_sample_loop( 637 | self, 638 | model, 639 | shape, 640 | noise=None, 641 | clip_denoised=True, 642 | denoised_fn=None, 643 | cond_fn=None, 644 | model_kwargs=None, 645 | device=None, 646 | progress=False, 647 | eta=0.0, 648 | mask=None, 649 | x_start=None, 650 | use_concat=False 651 | ): 652 | """ 653 | Generate samples from the model using DDIM. 654 | Same usage as p_sample_loop(). 655 | """ 656 | final = None 657 | for sample in self.ddim_sample_loop_progressive( 658 | model, 659 | shape, 660 | noise=noise, 661 | clip_denoised=clip_denoised, 662 | denoised_fn=denoised_fn, 663 | cond_fn=cond_fn, 664 | model_kwargs=model_kwargs, 665 | device=device, 666 | progress=progress, 667 | eta=eta, 668 | mask=mask, 669 | x_start=x_start, 670 | use_concat=use_concat 671 | ): 672 | final = sample 673 | return final["sample"] 674 | 675 | def ddim_sample_loop_progressive( 676 | self, 677 | model, 678 | shape, 679 | noise=None, 680 | clip_denoised=True, 681 | denoised_fn=None, 682 | cond_fn=None, 683 | model_kwargs=None, 684 | device=None, 685 | progress=False, 686 | eta=0.0, 687 | mask=None, 688 | x_start=None, 689 | use_concat=False 690 | ): 691 | """ 692 | Use DDIM to sample from the model and yield intermediate samples from 693 | each timestep of DDIM. 694 | Same usage as p_sample_loop_progressive(). 695 | """ 696 | if device is None: 697 | device = next(model.parameters()).device 698 | assert isinstance(shape, (tuple, list)) 699 | if noise is not None: 700 | img = noise 701 | else: 702 | img = th.randn(*shape, device=device) 703 | indices = list(range(self.num_timesteps))[::-1] 704 | 705 | if progress: 706 | # Lazy import so that we don't depend on tqdm. 707 | from tqdm.auto import tqdm 708 | 709 | indices = tqdm(indices) 710 | 711 | for i in indices: 712 | t = th.tensor([i] * shape[0], device=device) 713 | with th.no_grad(): 714 | out = self.ddim_sample( 715 | model, 716 | img, 717 | t, 718 | clip_denoised=clip_denoised, 719 | denoised_fn=denoised_fn, 720 | cond_fn=cond_fn, 721 | model_kwargs=model_kwargs, 722 | eta=eta, 723 | mask=mask, 724 | x_start=x_start, 725 | use_concat=use_concat 726 | ) 727 | yield out 728 | img = out["sample"] 729 | 730 | def _vb_terms_bpd( 731 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 732 | ): 733 | """ 734 | Get a term for the variational lower-bound. 735 | The resulting units are bits (rather than nats, as one might expect). 736 | This allows for comparison to other papers. 737 | :return: a dict with the following keys: 738 | - 'output': a shape [N] tensor of NLLs or KLs. 739 | - 'pred_xstart': the x_0 predictions. 
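At t == 0, 'output' is the decoder NLL; at all other timesteps it is KL(q(x_{t-1}|x_t, x_0) || p(x_{t-1}|x_t)).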
740 | """ 741 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 742 | x_start=x_start, x_t=x_t, t=t 743 | ) 744 | out = self.p_mean_variance( 745 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 746 | ) 747 | kl = normal_kl( 748 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 749 | ) 750 | kl = mean_flat(kl) / np.log(2.0) 751 | 752 | decoder_nll = -discretized_gaussian_log_likelihood( 753 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 754 | ) 755 | assert decoder_nll.shape == x_start.shape 756 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 757 | 758 | # At the first timestep return the decoder NLL, 759 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 760 | output = th.where((t == 0), decoder_nll, kl) 761 | return {"output": output, "pred_xstart": out["pred_xstart"]} 762 | 763 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False): 764 | """ 765 | Compute training losses for a single timestep. 766 | :param model: the model to evaluate loss on. 767 | :param x_start: the [N x C x ...] tensor of inputs. 768 | :param t: a batch of timestep indices. 769 | :param model_kwargs: if not None, a dict of extra keyword arguments to 770 | pass to the model. This can be used for conditioning. 771 | :param noise: if specified, the specific Gaussian noise to try to remove. 772 | :return: a dict with the key "loss" containing a tensor of shape [N]. 773 | Some mean or variance settings may also have other keys. 774 | """ 775 | if model_kwargs is None: 776 | model_kwargs = {} 777 | if noise is None: 778 | noise = th.randn_like(x_start) 779 | x_t = self.q_sample(x_start, t, noise=noise) 780 | if use_mask: 781 | x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1) 782 | terms = {} 783 | 784 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 785 | terms["loss"] = self._vb_terms_bpd( 786 | model=model, 787 | x_start=x_start, 788 | x_t=x_t, 789 | t=t, 790 | clip_denoised=False, 791 | model_kwargs=model_kwargs, 792 | )["output"] 793 | if self.loss_type == LossType.RESCALED_KL: 794 | terms["loss"] *= self.num_timesteps 795 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 796 | model_output = model(x_t, t, **model_kwargs) 797 | try: 798 | # model_output = model(x_t, t, **model_kwargs).sample 799 | model_output = model_output.sample # for tav unet 800 | except: 801 | pass 802 | # model_output = model(x_t, t, **model_kwargs) 803 | 804 | if self.model_var_type in [ 805 | ModelVarType.LEARNED, 806 | ModelVarType.LEARNED_RANGE, 807 | ]: 808 | B, F, C = x_t.shape[:3] 809 | assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) 810 | model_output, model_var_values = th.split(model_output, C, dim=2) 811 | # Learn the variance using the variational bound, but don't let 812 | # it affect our mean prediction. 813 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) 814 | terms["vb"] = self._vb_terms_bpd( 815 | model=lambda *args, r=frozen_out: r, 816 | x_start=x_start, 817 | x_t=x_t, 818 | t=t, 819 | clip_denoised=False, 820 | )["output"] 821 | if self.loss_type == LossType.RESCALED_MSE: 822 | # Divide by 1000 for equivalence with initial implementation. 823 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
824 | terms["vb"] *= self.num_timesteps / 1000.0 825 | 826 | target = { 827 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 828 | x_start=x_start, x_t=x_t, t=t 829 | )[0], 830 | ModelMeanType.START_X: x_start, 831 | ModelMeanType.EPSILON: noise, 832 | }[self.model_mean_type] 833 | # assert model_output.shape == target.shape == x_start.shape 834 | if use_mask: 835 | terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2) 836 | else: 837 | terms["mse"] = mean_flat((target - model_output) ** 2) 838 | if "vb" in terms: 839 | terms["loss"] = terms["mse"] + terms["vb"] 840 | else: 841 | terms["loss"] = terms["mse"] 842 | else: 843 | raise NotImplementedError(self.loss_type) 844 | 845 | return terms 846 | 847 | def _prior_bpd(self, x_start): 848 | """ 849 | Get the prior KL term for the variational lower-bound, measured in 850 | bits-per-dim. 851 | This term can't be optimized, as it only depends on the encoder. 852 | :param x_start: the [N x C x ...] tensor of inputs. 853 | :return: a batch of [N] KL values (in bits), one per batch element. 854 | """ 855 | batch_size = x_start.shape[0] 856 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 857 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 858 | kl_prior = normal_kl( 859 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 860 | ) 861 | return mean_flat(kl_prior) / np.log(2.0) 862 | 863 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 864 | """ 865 | Compute the entire variational lower-bound, measured in bits-per-dim, 866 | as well as other related quantities. 867 | :param model: the model to evaluate loss on. 868 | :param x_start: the [N x C x ...] tensor of inputs. 869 | :param clip_denoised: if True, clip denoised samples. 870 | :param model_kwargs: if not None, a dict of extra keyword arguments to 871 | pass to the model. This can be used for conditioning. 872 | :return: a dict containing the following keys: 873 | - total_bpd: the total variational lower-bound, per batch element. 874 | - prior_bpd: the prior term in the lower-bound. 875 | - vb: an [N x T] tensor of terms in the lower-bound. 876 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 877 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
878 | """ 879 | device = x_start.device 880 | batch_size = x_start.shape[0] 881 | 882 | vb = [] 883 | xstart_mse = [] 884 | mse = [] 885 | for t in list(range(self.num_timesteps))[::-1]: 886 | t_batch = th.tensor([t] * batch_size, device=device) 887 | noise = th.randn_like(x_start) 888 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 889 | # Calculate VLB term at the current timestep 890 | with th.no_grad(): 891 | out = self._vb_terms_bpd( 892 | model, 893 | x_start=x_start, 894 | x_t=x_t, 895 | t=t_batch, 896 | clip_denoised=clip_denoised, 897 | model_kwargs=model_kwargs, 898 | ) 899 | vb.append(out["output"]) 900 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 901 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 902 | mse.append(mean_flat((eps - noise) ** 2)) 903 | 904 | vb = th.stack(vb, dim=1) 905 | xstart_mse = th.stack(xstart_mse, dim=1) 906 | mse = th.stack(mse, dim=1) 907 | 908 | prior_bpd = self._prior_bpd(x_start) 909 | total_bpd = vb.sum(dim=1) + prior_bpd 910 | return { 911 | "total_bpd": total_bpd, 912 | "prior_bpd": prior_bpd, 913 | "vb": vb, 914 | "xstart_mse": xstart_mse, 915 | "mse": mse, 916 | } 917 | 918 | 919 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 920 | """ 921 | Extract values from a 1-D numpy array for a batch of indices. 922 | :param arr: the 1-D numpy array. 923 | :param timesteps: a tensor of indices into the array to extract. 924 | :param broadcast_shape: a larger shape of K dimensions with the batch 925 | dimension equal to the length of timesteps. 926 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 927 | """ 928 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 929 | while len(res.shape) < len(broadcast_shape): 930 | res = res[..., None] 931 | return res + th.zeros(broadcast_shape, device=timesteps.device) 932 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | import torch 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 
29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | # @torch.compile 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def condition_mean(self, cond_fn, *args, **kwargs): 101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def condition_score(self, cond_fn, *args, **kwargs): 104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def _wrap_model(self, model): 107 | if isinstance(model, _WrappedModel): 108 | return model 109 | return _WrappedModel( 110 | model, self.timestep_map, self.original_num_steps 111 | ) 112 | 113 | def _scale_timesteps(self, t): 114 | # Scaling is done by the wrapped model. 
115 | return t 116 | 117 | 118 | class _WrappedModel: 119 | def __init__(self, model, timestep_map, original_num_steps): 120 | self.model = model 121 | self.timestep_map = timestep_map 122 | # self.rescale_timesteps = rescale_timesteps 123 | self.original_num_steps = original_num_steps 124 | 125 | def __call__(self, x, ts, **kwargs): 126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 127 | new_ts = map_tensor[ts] 128 | # if self.rescale_timesteps: 129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term.
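# The history acts as a per-timestep FIFO holding the most recent `history_per_term` losses.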
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /examples/Close-up essence is poured from bottleKodak Vision3 50,slow motion_0000_001.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/Close-up essence is poured from bottleKodak Vision3 50,slow motion_0000_001.gif -------------------------------------------------------------------------------- /examples/The picture shows the beauty of the sea and at the sam,slow motion_0000_11301.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/The picture shows the beauty of the sea and at the sam,slow motion_0000_11301.gif -------------------------------------------------------------------------------- /examples/The picture shows the beauty of the sea and at the sam,slow motion_0000_6600.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/The picture shows the beauty of the sea and at the sam,slow motion_0000_6600.gif -------------------------------------------------------------------------------- /examples/Travel from Earth's spring blossoms to the alien cherry blossom forestssmooth transition, slow motion_0000_003.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/Travel from Earth's spring blossoms to the alien cherry blossom forestssmooth transition, slow motion_0000_003.gif -------------------------------------------------------------------------------- /examples/flying through fantasy landscapes in the cloud, 4k, high resolution._0000_885.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/flying through fantasy landscapes in the cloud, 4k, high resolution._0000_885.gif -------------------------------------------------------------------------------- /examples/orange-flower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/orange-flower.gif -------------------------------------------------------------------------------- /examples/spiderman-becomes-a-sand-sculpture.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/examples/spiderman-becomes-a-sand-sculpture.gif -------------------------------------------------------------------------------- /input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png -------------------------------------------------------------------------------- /input/i2v/The_picture_shows_the_beauty_of_the_sea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/i2v/The_picture_shows_the_beauty_of_the_sea.png -------------------------------------------------------------------------------- /input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png -------------------------------------------------------------------------------- /input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png -------------------------------------------------------------------------------- /input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png -------------------------------------------------------------------------------- /input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png -------------------------------------------------------------------------------- /input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png -------------------------------------------------------------------------------- /input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png -------------------------------------------------------------------------------- /input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png -------------------------------------------------------------------------------- 
/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.split(sys.path[0])[0]) 4 | 5 | from .unet import UNet3DConditionModel 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit 9 | from torch.optim.lr_scheduler import LambdaLR 10 | def fn(step): 11 | if warmup_steps > 0: 12 | return min(step / warmup_steps, 1) 13 | else: 14 | return 1 15 | return LambdaLR(optimizer, fn) 16 | 17 | 18 | def get_lr_scheduler(optimizer, name, **kwargs): 19 | if name == 'warmup': 20 | return customized_lr_scheduler(optimizer, **kwargs) 21 | elif name == 'cosine': 22 | from torch.optim.lr_scheduler import CosineAnnealingLR 23 | return CosineAnnealingLR(optimizer, **kwargs) 24 | else: 25 | raise NotImplementedError(name) 26 | 27 | def get_models(args): 28 | if 'UNet' in args.model: 29 | pretrained_model_path = args.pretrained_model_path 30 | return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask) 31 | else: 32 | raise NotImplementedError('{} Model Not Supported!'.format(args.model)) 33 | -------------------------------------------------------------------------------- /models/clip.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch.nn as nn 3 | from transformers import CLIPTokenizer, CLIPTextModel 4 | 5 | import transformers 6 | transformers.logging.set_verbosity_error() 7 | 8 | """ 9 | You will encounter the following warning: 10 | - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task 11 | or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). 12 | - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model 13 | that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). 14 | 15 | https://github.com/CompVis/stable-diffusion/issues/97 16 | According to this issue, this warning is safe. 17 | 18 | This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. 19 | You can safely ignore the warning; it is not an error. 20 | 21 | This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
22 | """ 23 | 24 | class AbstractEncoder(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | def encode(self, *args, **kwargs): 29 | raise NotImplementedError 30 | 31 | 32 | class FrozenCLIPEmbedder(AbstractEncoder): 33 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 34 | # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): 35 | def __init__(self, path, device="cuda", max_length=77): 36 | super().__init__() 37 | self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") 38 | self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') 39 | self.device = device 40 | self.max_length = max_length 41 | self.freeze() 42 | 43 | def freeze(self): 44 | self.transformer = self.transformer.eval() 45 | for param in self.parameters(): 46 | param.requires_grad = False 47 | 48 | def forward(self, text): 49 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 50 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 51 | tokens = batch_encoding["input_ids"].to(self.device) 52 | outputs = self.transformer(input_ids=tokens) 53 | 54 | z = outputs.last_hidden_state 55 | return z 56 | 57 | def encode(self, text): 58 | return self(text) 59 | 60 | 61 | class TextEmbedder(nn.Module): 62 | """ 63 | Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. 64 | """ 65 | def __init__(self, path, dropout_prob=0.1): 66 | super().__init__() 67 | self.text_encodder = FrozenCLIPEmbedder(path=path) 68 | self.dropout_prob = dropout_prob 69 | 70 | def token_drop(self, text_prompts, force_drop_ids=None): 71 | """ 72 | Drops text to enable classifier-free guidance. 
73 | """ 74 | if force_drop_ids is None: 75 | drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob 76 | else: 77 | # TODO 78 | drop_ids = force_drop_ids == 1 79 | labels = list(numpy.where(drop_ids, "", text_prompts)) 80 | # print(labels) 81 | return labels 82 | 83 | def forward(self, text_prompts, train, force_drop_ids=None): 84 | use_dropout = self.dropout_prob > 0 85 | if (train and use_dropout) or (force_drop_ids is not None): 86 | text_prompts = self.token_drop(text_prompts, force_drop_ids) 87 | embeddings = self.text_encodder(text_prompts) 88 | return embeddings 89 | 90 | 91 | if __name__ == '__main__': 92 | 93 | r""" 94 | Returns: 95 | 96 | Examples from CLIPTextModel: 97 | 98 | ```python 99 | >>> from transformers import AutoTokenizer, CLIPTextModel 100 | 101 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") 102 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 103 | 104 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 105 | 106 | >>> outputs = model(**inputs) 107 | >>> last_hidden_state = outputs.last_hidden_state 108 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 109 | ```""" 110 | 111 | import torch 112 | 113 | device = "cuda" if torch.cuda.is_available() else "cpu" 114 | 115 | text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', 116 | dropout_prob=0.00001).to(device) 117 | 118 | text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] 119 | # text_prompt = ('None', 'None', 'None') 120 | output = text_encoder(text_prompts=text_prompt, train=False) 121 | # print(output) 122 | print(output.shape) 123 | # print(output.shape) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | import os 3 | import sys 4 | sys.path.append(os.path.split(sys.path[0])[0]) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | 13 | class InflatedConv3d(nn.Conv2d): 14 | def forward(self, x): 15 | video_length = x.shape[2] 16 | 17 | x = rearrange(x, "b c f h w -> (b f) c h w") 18 | x = super().forward(x) 19 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 20 | 21 | return x 22 | 23 | 24 | class Upsample3D(nn.Module): 25 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 26 | super().__init__() 27 | self.channels = channels 28 | self.out_channels = out_channels or channels 29 | self.use_conv = use_conv 30 | self.use_conv_transpose = use_conv_transpose 31 | self.name = name 32 | 33 | conv = None 34 | if use_conv_transpose: 35 | raise NotImplementedError 36 | elif use_conv: 37 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 38 | 39 | if name == "conv": 40 | self.conv = conv 41 | else: 42 | self.Conv2d_0 = conv 43 | 44 | def forward(self, hidden_states, output_size=None): 45 | assert hidden_states.shape[1] == self.channels 46 | 47 | if self.use_conv_transpose: 48 | raise NotImplementedError 49 | 50 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 51 | dtype = hidden_states.dtype 52 | if 
dtype == torch.bfloat16: 53 | hidden_states = hidden_states.to(torch.float32) 54 | 55 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 56 | if hidden_states.shape[0] >= 64: 57 | hidden_states = hidden_states.contiguous() 58 | 59 | # if `output_size` is passed we force the interpolation output 60 | # size and do not make use of `scale_factor=2` 61 | if output_size is None: 62 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 63 | else: 64 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 65 | 66 | # If the input is bfloat16, we cast back to bfloat16 67 | if dtype == torch.bfloat16: 68 | hidden_states = hidden_states.to(dtype) 69 | 70 | if self.use_conv: 71 | if self.name == "conv": 72 | hidden_states = self.conv(hidden_states) 73 | else: 74 | hidden_states = self.Conv2d_0(hidden_states) 75 | 76 | return hidden_states 77 | 78 | 79 | class Downsample3D(nn.Module): 80 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 81 | super().__init__() 82 | self.channels = channels 83 | self.out_channels = out_channels or channels 84 | self.use_conv = use_conv 85 | self.padding = padding 86 | stride = 2 87 | self.name = name 88 | 89 | if use_conv: 90 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 91 | else: 92 | raise NotImplementedError 93 | 94 | if name == "conv": 95 | self.Conv2d_0 = conv 96 | self.conv = conv 97 | elif name == "Conv2d_0": 98 | self.conv = conv 99 | else: 100 | self.conv = conv 101 | 102 | def forward(self, hidden_states): 103 | assert hidden_states.shape[1] == self.channels 104 | if self.use_conv and self.padding == 0: 105 | raise NotImplementedError 106 | 107 | assert hidden_states.shape[1] == self.channels 108 | hidden_states = self.conv(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | class ResnetBlock3D(nn.Module): 114 | def __init__( 115 | self, 116 | *, 117 | in_channels, 118 | out_channels=None, 119 | conv_shortcut=False, 120 | dropout=0.0, 121 | temb_channels=512, 122 | groups=32, 123 | groups_out=None, 124 | pre_norm=True, 125 | eps=1e-6, 126 | non_linearity="swish", 127 | time_embedding_norm="default", 128 | output_scale_factor=1.0, 129 | use_in_shortcut=None, 130 | ): 131 | super().__init__() 132 | self.pre_norm = pre_norm 133 | self.pre_norm = True 134 | self.in_channels = in_channels 135 | out_channels = in_channels if out_channels is None else out_channels 136 | self.out_channels = out_channels 137 | self.use_conv_shortcut = conv_shortcut 138 | self.time_embedding_norm = time_embedding_norm 139 | self.output_scale_factor = output_scale_factor 140 | 141 | if groups_out is None: 142 | groups_out = groups 143 | 144 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 145 | 146 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 147 | 148 | if temb_channels is not None: 149 | if self.time_embedding_norm == "default": 150 | time_emb_proj_out_channels = out_channels 151 | elif self.time_embedding_norm == "scale_shift": 152 | time_emb_proj_out_channels = out_channels * 2 153 | else: 154 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 155 | 156 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 157 | else: 158 | self.time_emb_proj = None 159 | 160 | self.norm2 = 
torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 161 | self.dropout = torch.nn.Dropout(dropout) 162 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 163 | 164 | if non_linearity == "swish": 165 | self.nonlinearity = lambda x: F.silu(x) 166 | elif non_linearity == "mish": 167 | self.nonlinearity = Mish() 168 | elif non_linearity == "silu": 169 | self.nonlinearity = nn.SiLU() 170 | 171 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 172 | 173 | self.conv_shortcut = None 174 | if self.use_in_shortcut: 175 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 176 | 177 | def forward(self, input_tensor, temb): 178 | hidden_states = input_tensor 179 | 180 | hidden_states = self.norm1(hidden_states) 181 | hidden_states = self.nonlinearity(hidden_states) 182 | 183 | hidden_states = self.conv1(hidden_states) 184 | 185 | if temb is not None: 186 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 187 | 188 | if temb is not None and self.time_embedding_norm == "default": 189 | hidden_states = hidden_states + temb 190 | 191 | hidden_states = self.norm2(hidden_states) 192 | 193 | if temb is not None and self.time_embedding_norm == "scale_shift": 194 | scale, shift = torch.chunk(temb, 2, dim=1) 195 | hidden_states = hidden_states * (1 + scale) + shift 196 | 197 | hidden_states = self.nonlinearity(hidden_states) 198 | 199 | hidden_states = self.dropout(hidden_states) 200 | hidden_states = self.conv2(hidden_states) 201 | 202 | if self.conv_shortcut is not None: 203 | input_tensor = self.conv_shortcut(input_tensor) 204 | 205 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 206 | 207 | return output_tensor 208 | 209 | 210 | class Mish(torch.nn.Module): 211 | def forward(self, hidden_states): 212 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | import math 11 | import json 12 | import torch 13 | import einops 14 | import torch.nn as nn 15 | import torch.utils.checkpoint 16 | 17 | from diffusers.configuration_utils import ConfigMixin, register_to_config 18 | from diffusers.utils import BaseOutput, logging 19 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 20 | 21 | try: 22 | from diffusers.models.modeling_utils import ModelMixin 23 | except: 24 | from diffusers.modeling_utils import ModelMixin # 0.11.1 25 | 26 | try: 27 | from .unet_blocks import ( 28 | CrossAttnDownBlock3D, 29 | CrossAttnUpBlock3D, 30 | DownBlock3D, 31 | UNetMidBlock3DCrossAttn, 32 | UpBlock3D, 33 | get_down_block, 34 | get_up_block, 35 | ) 36 | from .resnet import InflatedConv3d 37 | except: 38 | from unet_blocks import ( 39 | CrossAttnDownBlock3D, 40 | CrossAttnUpBlock3D, 41 | DownBlock3D, 42 | UNetMidBlock3DCrossAttn, 43 | UpBlock3D, 44 | get_down_block, 45 | get_up_block, 46 | ) 47 | from resnet import InflatedConv3d 48 | 49 | from rotary_embedding_torch import 
RotaryEmbedding 50 | 51 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 52 | 53 | class RelativePositionBias(nn.Module): 54 | def __init__( 55 | self, 56 | heads=8, 57 | num_buckets=32, 58 | max_distance=128, 59 | ): 60 | super().__init__() 61 | self.num_buckets = num_buckets 62 | self.max_distance = max_distance 63 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 64 | 65 | @staticmethod 66 | def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): 67 | ret = 0 68 | n = -relative_position 69 | 70 | num_buckets //= 2 71 | ret += (n < 0).long() * num_buckets 72 | n = torch.abs(n) 73 | 74 | max_exact = num_buckets // 2 75 | is_small = n < max_exact 76 | 77 | val_if_large = max_exact + ( 78 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 79 | ).long() 80 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 81 | 82 | ret += torch.where(is_small, n, val_if_large) 83 | return ret 84 | 85 | def forward(self, n, device): 86 | q_pos = torch.arange(n, dtype = torch.long, device = device) 87 | k_pos = torch.arange(n, dtype = torch.long, device = device) 88 | rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') 89 | rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) 90 | values = self.relative_attention_bias(rp_bucket) 91 | return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames 92 | 93 | @dataclass 94 | class UNet3DConditionOutput(BaseOutput): 95 | sample: torch.FloatTensor 96 | 97 | 98 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 99 | _supports_gradient_checkpointing = True 100 | 101 | @register_to_config 102 | def __init__( 103 | self, 104 | sample_size: Optional[int] = None, # 64 105 | in_channels: int = 4, 106 | out_channels: int = 4, 107 | center_input_sample: bool = False, 108 | flip_sin_to_cos: bool = True, 109 | freq_shift: int = 0, 110 | down_block_types: Tuple[str] = ( 111 | "CrossAttnDownBlock3D", 112 | "CrossAttnDownBlock3D", 113 | "CrossAttnDownBlock3D", 114 | "DownBlock3D", 115 | ), 116 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 117 | up_block_types: Tuple[str] = ( 118 | "UpBlock3D", 119 | "CrossAttnUpBlock3D", 120 | "CrossAttnUpBlock3D", 121 | "CrossAttnUpBlock3D" 122 | ), 123 | only_cross_attention: Union[bool, Tuple[bool]] = False, 124 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 125 | layers_per_block: int = 2, 126 | downsample_padding: int = 1, 127 | mid_block_scale_factor: float = 1, 128 | act_fn: str = "silu", 129 | norm_num_groups: int = 32, 130 | norm_eps: float = 1e-5, 131 | cross_attention_dim: int = 1280, 132 | attention_head_dim: Union[int, Tuple[int]] = 8, 133 | dual_cross_attention: bool = False, 134 | use_linear_projection: bool = False, 135 | class_embed_type: Optional[str] = None, 136 | num_class_embeds: Optional[int] = None, 137 | upcast_attention: bool = False, 138 | resnet_time_scale_shift: str = "default", 139 | use_first_frame: bool = False, 140 | use_relative_position: bool = False, 141 | ): 142 | super().__init__() 143 | 144 | # print(use_first_frame) 145 | 146 | self.sample_size = sample_size 147 | time_embed_dim = block_out_channels[0] * 4 148 | 149 | # input 150 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 151 | 152 | # time 153 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, 
freq_shift) 154 | timestep_input_dim = block_out_channels[0] 155 | 156 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 157 | 158 | # class embedding 159 | if class_embed_type is None and num_class_embeds is not None: 160 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 161 | elif class_embed_type == "timestep": 162 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 163 | elif class_embed_type == "identity": 164 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 165 | else: 166 | self.class_embedding = None 167 | 168 | self.down_blocks = nn.ModuleList([]) 169 | self.mid_block = None 170 | self.up_blocks = nn.ModuleList([]) 171 | 172 | # print(only_cross_attention) 173 | # print(type(only_cross_attention)) 174 | # exit() 175 | if isinstance(only_cross_attention, bool): 176 | only_cross_attention = [only_cross_attention] * len(down_block_types) 177 | # print(only_cross_attention) 178 | # exit() 179 | 180 | if isinstance(attention_head_dim, int): 181 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 182 | # print(attention_head_dim) 183 | # exit() 184 | 185 | rotary_emb = RotaryEmbedding(32) 186 | 187 | # down 188 | output_channel = block_out_channels[0] 189 | for i, down_block_type in enumerate(down_block_types): 190 | input_channel = output_channel 191 | output_channel = block_out_channels[i] 192 | is_final_block = i == len(block_out_channels) - 1 193 | 194 | down_block = get_down_block( 195 | down_block_type, 196 | num_layers=layers_per_block, 197 | in_channels=input_channel, 198 | out_channels=output_channel, 199 | temb_channels=time_embed_dim, 200 | add_downsample=not is_final_block, 201 | resnet_eps=norm_eps, 202 | resnet_act_fn=act_fn, 203 | resnet_groups=norm_num_groups, 204 | cross_attention_dim=cross_attention_dim, 205 | attn_num_head_channels=attention_head_dim[i], 206 | downsample_padding=downsample_padding, 207 | dual_cross_attention=dual_cross_attention, 208 | use_linear_projection=use_linear_projection, 209 | only_cross_attention=only_cross_attention[i], 210 | upcast_attention=upcast_attention, 211 | resnet_time_scale_shift=resnet_time_scale_shift, 212 | use_first_frame=use_first_frame, 213 | use_relative_position=use_relative_position, 214 | rotary_emb=rotary_emb, 215 | ) 216 | self.down_blocks.append(down_block) 217 | 218 | # mid 219 | if mid_block_type == "UNetMidBlock3DCrossAttn": 220 | self.mid_block = UNetMidBlock3DCrossAttn( 221 | in_channels=block_out_channels[-1], 222 | temb_channels=time_embed_dim, 223 | resnet_eps=norm_eps, 224 | resnet_act_fn=act_fn, 225 | output_scale_factor=mid_block_scale_factor, 226 | resnet_time_scale_shift=resnet_time_scale_shift, 227 | cross_attention_dim=cross_attention_dim, 228 | attn_num_head_channels=attention_head_dim[-1], 229 | resnet_groups=norm_num_groups, 230 | dual_cross_attention=dual_cross_attention, 231 | use_linear_projection=use_linear_projection, 232 | upcast_attention=upcast_attention, 233 | use_first_frame=use_first_frame, 234 | use_relative_position=use_relative_position, 235 | rotary_emb=rotary_emb, 236 | ) 237 | else: 238 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 239 | 240 | # count how many layers upsample the videos 241 | self.num_upsamplers = 0 242 | 243 | # up 244 | reversed_block_out_channels = list(reversed(block_out_channels)) 245 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 246 | only_cross_attention = list(reversed(only_cross_attention)) 247 | 
output_channel = reversed_block_out_channels[0] 248 | for i, up_block_type in enumerate(up_block_types): 249 | is_final_block = i == len(block_out_channels) - 1 250 | 251 | prev_output_channel = output_channel 252 | output_channel = reversed_block_out_channels[i] 253 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 254 | 255 | # add upsample block for all BUT final layer 256 | if not is_final_block: 257 | add_upsample = True 258 | self.num_upsamplers += 1 259 | else: 260 | add_upsample = False 261 | 262 | up_block = get_up_block( 263 | up_block_type, 264 | num_layers=layers_per_block + 1, 265 | in_channels=input_channel, 266 | out_channels=output_channel, 267 | prev_output_channel=prev_output_channel, 268 | temb_channels=time_embed_dim, 269 | add_upsample=add_upsample, 270 | resnet_eps=norm_eps, 271 | resnet_act_fn=act_fn, 272 | resnet_groups=norm_num_groups, 273 | cross_attention_dim=cross_attention_dim, 274 | attn_num_head_channels=reversed_attention_head_dim[i], 275 | dual_cross_attention=dual_cross_attention, 276 | use_linear_projection=use_linear_projection, 277 | only_cross_attention=only_cross_attention[i], 278 | upcast_attention=upcast_attention, 279 | resnet_time_scale_shift=resnet_time_scale_shift, 280 | use_first_frame=use_first_frame, 281 | use_relative_position=use_relative_position, 282 | rotary_emb=rotary_emb, 283 | ) 284 | self.up_blocks.append(up_block) 285 | prev_output_channel = output_channel 286 | 287 | # out 288 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 289 | self.conv_act = nn.SiLU() 290 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 291 | 292 | # relative time positional embeddings 293 | self.use_relative_position = use_relative_position 294 | if self.use_relative_position: 295 | self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet 296 | 297 | def set_attention_slice(self, slice_size): 298 | r""" 299 | Enable sliced attention computation. 300 | 301 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 302 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 303 | 304 | Args: 305 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 306 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 307 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 308 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 309 | must be a multiple of `slice_size`. 
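For example, `set_attention_slice("auto")` halves each sliceable head dimension, while an integer such as `set_attention_slice(2)` applies that slice size to every sliceable layer.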
310 | """ 311 | sliceable_head_dims = [] 312 | 313 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 314 | if hasattr(module, "set_attention_slice"): 315 | sliceable_head_dims.append(module.sliceable_head_dim) 316 | 317 | for child in module.children(): 318 | fn_recursive_retrieve_slicable_dims(child) 319 | 320 | # retrieve number of attention layers 321 | for module in self.children(): 322 | fn_recursive_retrieve_slicable_dims(module) 323 | 324 | num_slicable_layers = len(sliceable_head_dims) 325 | 326 | if slice_size == "auto": 327 | # half the attention head size is usually a good trade-off between 328 | # speed and memory 329 | slice_size = [dim // 2 for dim in sliceable_head_dims] 330 | elif slice_size == "max": 331 | # make smallest slice possible 332 | slice_size = num_slicable_layers * [1] 333 | 334 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 335 | 336 | if len(slice_size) != len(sliceable_head_dims): 337 | raise ValueError( 338 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 339 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 340 | ) 341 | 342 | for i in range(len(slice_size)): 343 | size = slice_size[i] 344 | dim = sliceable_head_dims[i] 345 | if size is not None and size > dim: 346 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 347 | 348 | # Recursively walk through all the children. 349 | # Any children which exposes the set_attention_slice method 350 | # gets the message 351 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 352 | if hasattr(module, "set_attention_slice"): 353 | module.set_attention_slice(slice_size.pop()) 354 | 355 | for child in module.children(): 356 | fn_recursive_set_attention_slice(child, slice_size) 357 | 358 | reversed_slice_size = list(reversed(slice_size)) 359 | for module in self.children(): 360 | fn_recursive_set_attention_slice(module, reversed_slice_size) 361 | 362 | def _set_gradient_checkpointing(self, module, value=False): 363 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 364 | module.gradient_checkpointing = value 365 | 366 | def forward( 367 | self, 368 | sample: torch.FloatTensor, 369 | timestep: Union[torch.Tensor, float, int], 370 | encoder_hidden_states: torch.Tensor = None, 371 | class_labels: Optional[torch.Tensor] = None, 372 | attention_mask: Optional[torch.Tensor] = None, 373 | use_image_num: int = 0, 374 | return_dict: bool = True, 375 | ) -> Union[UNet3DConditionOutput, Tuple]: 376 | r""" 377 | Args: 378 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 379 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 380 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 381 | return_dict (`bool`, *optional*, defaults to `True`): 382 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 383 | 384 | Returns: 385 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 386 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 387 | returning a tuple, the first element is the sample tensor. 388 | """ 389 | # By default samples have to be AT least a multiple of the overall upsampling factor. 
390 | # The overall upsampling factor is equal to 2 ** (number of upsampling layers). 391 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 392 | # on the fly if necessary. 393 | default_overall_up_factor = 2**self.num_upsamplers 394 | 395 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 396 | forward_upsample_size = False 397 | upsample_size = None 398 | 399 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 400 | logger.info("Forward upsample size to force interpolation output size.") 401 | forward_upsample_size = True 402 | 403 | # prepare attention_mask 404 | if attention_mask is not None: 405 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 406 | attention_mask = attention_mask.unsqueeze(1) 407 | 408 | # center input if necessary 409 | if self.config.center_input_sample: 410 | sample = 2 * sample - 1.0 411 | 412 | # time 413 | timesteps = timestep 414 | if not torch.is_tensor(timesteps): 415 | # This would be a good case for the `match` statement (Python 3.10+) 416 | is_mps = sample.device.type == "mps" 417 | if isinstance(timestep, float): 418 | dtype = torch.float32 if is_mps else torch.float64 419 | else: 420 | dtype = torch.int32 if is_mps else torch.int64 421 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 422 | elif len(timesteps.shape) == 0: 423 | timesteps = timesteps[None].to(sample.device) 424 | 425 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 426 | timesteps = timesteps.expand(sample.shape[0]) 427 | 428 | t_emb = self.time_proj(timesteps) 429 | 430 | # timesteps does not contain any weights and will always return f32 tensors 431 | # but time_embedding might actually be running in fp16. so we need to cast here. 432 | # there might be better ways to encapsulate this.
433 | t_emb = t_emb.to(dtype=self.dtype) 434 | emb = self.time_embedding(t_emb) 435 | if self.class_embedding is not None: 436 | if class_labels is None: 437 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 438 | 439 | if self.config.class_embed_type == "timestep": 440 | class_labels = self.time_proj(class_labels) 441 | 442 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 443 | # print(emb.shape) # torch.Size([3, 1280]) 444 | # print(class_emb.shape) # torch.Size([3, 1280]) 445 | emb = emb + class_emb 446 | 447 | if self.use_relative_position: 448 | frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device) 449 | else: 450 | frame_rel_pos_bias = None 451 | 452 | # pre-process 453 | sample = self.conv_in(sample) 454 | 455 | # down 456 | down_block_res_samples = (sample,) 457 | for downsample_block in self.down_blocks: 458 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 459 | sample, res_samples = downsample_block( 460 | hidden_states=sample, 461 | temb=emb, 462 | encoder_hidden_states=encoder_hidden_states, 463 | attention_mask=attention_mask, 464 | use_image_num=use_image_num, 465 | ) 466 | else: 467 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 468 | 469 | down_block_res_samples += res_samples 470 | 471 | # mid 472 | sample = self.mid_block( 473 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, 474 | ) 475 | 476 | # up 477 | for i, upsample_block in enumerate(self.up_blocks): 478 | is_final_block = i == len(self.up_blocks) - 1 479 | 480 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 481 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 482 | 483 | # if we have not reached the final block and need to forward the 484 | # upsample size, we do it here 485 | if not is_final_block and forward_upsample_size: 486 | upsample_size = down_block_res_samples[-1].shape[2:] 487 | 488 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 489 | sample = upsample_block( 490 | hidden_states=sample, 491 | temb=emb, 492 | res_hidden_states_tuple=res_samples, 493 | encoder_hidden_states=encoder_hidden_states, 494 | upsample_size=upsample_size, 495 | attention_mask=attention_mask, 496 | use_image_num=use_image_num, 497 | ) 498 | else: 499 | sample = upsample_block( 500 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 501 | ) 502 | # post-process 503 | sample = self.conv_norm_out(sample) 504 | sample = self.conv_act(sample) 505 | sample = self.conv_out(sample) 506 | # print(sample.shape) 507 | 508 | if not return_dict: 509 | return (sample,) 510 | sample = UNet3DConditionOutput(sample=sample) 511 | return sample 512 | 513 | def forward_with_cfg(self, 514 | x, 515 | t, 516 | encoder_hidden_states = None, 517 | class_labels: Optional[torch.Tensor] = None, 518 | cfg_scale=4.0, 519 | use_fp16=False): 520 | """ 521 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 
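        In condensed form (an editor-added sketch of the guidance math performed below), the batch
        is assumed to be ordered [conditional half, unconditional half] and the guided prediction is:

            cond_eps, uncond_eps = eps.chunk(2, dim=0)
            guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)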
522 | """ 523 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 524 | half = x[: len(x) // 2] 525 | combined = torch.cat([half, half], dim=0) 526 | if use_fp16: 527 | combined = combined.to(dtype=torch.float16) 528 | model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample 529 | # For exact reproducibility reasons, we apply classifier-free guidance on only 530 | # three channels by default. The standard approach to cfg applies it to all channels. 531 | # This can be done by uncommenting the following line and commenting-out the line following that. 532 | eps, rest = model_out[:, :4], model_out[:, 4:] 533 | # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w 534 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 535 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 536 | eps = torch.cat([half_eps, half_eps], dim=0) 537 | return torch.cat([eps, rest], dim=1) 538 | 539 | @classmethod 540 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False): 541 | if subfolder is not None: 542 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 543 | 544 | 545 | # the content of the config file 546 | # { 547 | # "_class_name": "UNet2DConditionModel", 548 | # "_diffusers_version": "0.2.2", 549 | # "act_fn": "silu", 550 | # "attention_head_dim": 8, 551 | # "block_out_channels": [ 552 | # 320, 553 | # 640, 554 | # 1280, 555 | # 1280 556 | # ], 557 | # "center_input_sample": false, 558 | # "cross_attention_dim": 768, 559 | # "down_block_types": [ 560 | # "CrossAttnDownBlock2D", 561 | # "CrossAttnDownBlock2D", 562 | # "CrossAttnDownBlock2D", 563 | # "DownBlock2D" 564 | # ], 565 | # "downsample_padding": 1, 566 | # "flip_sin_to_cos": true, 567 | # "freq_shift": 0, 568 | # "in_channels": 4, 569 | # "layers_per_block": 2, 570 | # "mid_block_scale_factor": 1, 571 | # "norm_eps": 1e-05, 572 | # "norm_num_groups": 32, 573 | # "out_channels": 4, 574 | # "sample_size": 64, 575 | # "up_block_types": [ 576 | # "UpBlock2D", 577 | # "CrossAttnUpBlock2D", 578 | # "CrossAttnUpBlock2D", 579 | # "CrossAttnUpBlock2D" 580 | # ] 581 | # } 582 | config_file = os.path.join(pretrained_model_path, 'config.json') 583 | if not os.path.isfile(config_file): 584 | raise RuntimeError(f"{config_file} does not exist") 585 | with open(config_file, "r") as f: 586 | config = json.load(f) 587 | config["_class_name"] = cls.__name__ 588 | config["down_block_types"] = [ 589 | "CrossAttnDownBlock3D", 590 | "CrossAttnDownBlock3D", 591 | "CrossAttnDownBlock3D", 592 | "DownBlock3D" 593 | ] 594 | config["up_block_types"] = [ 595 | "UpBlock3D", 596 | "CrossAttnUpBlock3D", 597 | "CrossAttnUpBlock3D", 598 | "CrossAttnUpBlock3D" 599 | ] 600 | 601 | # config["use_first_frame"] = True 602 | 603 | config["use_first_frame"] = False 604 | if use_concat: 605 | config["in_channels"] = 9 606 | # config["use_relative_position"] = True 607 | 608 | # # tmp 609 | # config["class_embed_type"] = "timestep" 610 | # config["num_class_embeds"] = 100 611 | 612 | from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin 613 | 614 | # {'_class_name': 'UNet3DConditionModel', 615 | # '_diffusers_version': '0.2.2', 616 | # 'act_fn': 'silu', 617 | # 'attention_head_dim': 8, 618 | # 'block_out_channels': [320, 640, 1280, 1280], 619 | # 'center_input_sample': False, 620 | # 'cross_attention_dim': 768, 621 | # 'down_block_types': 622 | # ['CrossAttnDownBlock3D', 623 | # 'CrossAttnDownBlock3D', 624 | # 'CrossAttnDownBlock3D', 625 | # 
'DownBlock3D'], 626 | # 'downsample_padding': 1, 627 | # 'flip_sin_to_cos': True, 628 | # 'freq_shift': 0, 629 | # 'in_channels': 4, 630 | # 'layers_per_block': 2, 631 | # 'mid_block_scale_factor': 1, 632 | # 'norm_eps': 1e-05, 633 | # 'norm_num_groups': 32, 634 | # 'out_channels': 4, 635 | # 'sample_size': 64, 636 | # 'up_block_types': 637 | # ['UpBlock3D', 638 | # 'CrossAttnUpBlock3D', 639 | # 'CrossAttnUpBlock3D', 640 | # 'CrossAttnUpBlock3D']} 641 | 642 | model = cls.from_config(config) 643 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 644 | if not os.path.isfile(model_file): 645 | raise RuntimeError(f"{model_file} does not exist") 646 | state_dict = torch.load(model_file, map_location="cpu") 647 | 648 | if use_concat: 649 | new_state_dict = {} 650 | conv_in_weight = state_dict["conv_in.weight"] 651 | new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype) 652 | 653 | for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]): 654 | new_conv_weight[:, j] = conv_in_weight[:, i] 655 | new_state_dict["conv_in.weight"] = new_conv_weight 656 | new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"] 657 | for k, v in model.state_dict().items(): 658 | # print(k) 659 | if '_temp.' in k: 660 | new_state_dict.update({k: v}) 661 | if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross 662 | k = k.replace('attn_fcross', 'attn1') 663 | state_dict.update({k: state_dict[k]}) 664 | if 'norm_fcross' in k: 665 | k = k.replace('norm_fcross', 'norm1') 666 | state_dict.update({k: state_dict[k]}) 667 | 668 | if 'conv_in' in k: 669 | continue 670 | else: 671 | new_state_dict[k] = v 672 | # # tmp 673 | # if 'class_embedding' in k: 674 | # state_dict.update({k: v}) 675 | # breakpoint() 676 | model.load_state_dict(new_state_dict) 677 | else: 678 | for k, v in model.state_dict().items(): 679 | # print(k) 680 | if '_temp' in k: 681 | state_dict.update({k: v}) 682 | if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross 683 | k = k.replace('attn_fcross', 'attn1') 684 | state_dict.update({k: state_dict[k]}) 685 | if 'norm_fcross' in k: 686 | k = k.replace('norm_fcross', 'norm1') 687 | state_dict.update({k: state_dict[k]}) 688 | 689 | model.load_state_dict(state_dict) 690 | 691 | return model 692 | -------------------------------------------------------------------------------- /models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | import os 3 | import sys 4 | sys.path.append(os.path.split(sys.path[0])[0]) 5 | 6 | import torch 7 | from torch import nn 8 | 9 | try: 10 | from .attention import Transformer3DModel 11 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 12 | except: 13 | from attention import Transformer3DModel 14 | from resnet import Downsample3D, ResnetBlock3D, Upsample3D 15 | 16 | 17 | def get_down_block( 18 | down_block_type, 19 | num_layers, 20 | in_channels, 21 | out_channels, 22 | temb_channels, 23 | add_downsample, 24 | resnet_eps, 25 | resnet_act_fn, 26 | attn_num_head_channels, 27 | resnet_groups=None, 28 | cross_attention_dim=None, 29 | downsample_padding=None, 30 | dual_cross_attention=False, 31 | use_linear_projection=False, 32 | only_cross_attention=False, 33 | upcast_attention=False, 34 | resnet_time_scale_shift="default", 35 | use_first_frame=False, 36 | use_relative_position=False, 37 | rotary_emb=False, 38 | ): 39 | # 
print(down_block_type) 40 | # print(use_first_frame) 41 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 42 | if down_block_type == "DownBlock3D": 43 | return DownBlock3D( 44 | num_layers=num_layers, 45 | in_channels=in_channels, 46 | out_channels=out_channels, 47 | temb_channels=temb_channels, 48 | add_downsample=add_downsample, 49 | resnet_eps=resnet_eps, 50 | resnet_act_fn=resnet_act_fn, 51 | resnet_groups=resnet_groups, 52 | downsample_padding=downsample_padding, 53 | resnet_time_scale_shift=resnet_time_scale_shift, 54 | ) 55 | elif down_block_type == "CrossAttnDownBlock3D": 56 | if cross_attention_dim is None: 57 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 58 | return CrossAttnDownBlock3D( 59 | num_layers=num_layers, 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | temb_channels=temb_channels, 63 | add_downsample=add_downsample, 64 | resnet_eps=resnet_eps, 65 | resnet_act_fn=resnet_act_fn, 66 | resnet_groups=resnet_groups, 67 | downsample_padding=downsample_padding, 68 | cross_attention_dim=cross_attention_dim, 69 | attn_num_head_channels=attn_num_head_channels, 70 | dual_cross_attention=dual_cross_attention, 71 | use_linear_projection=use_linear_projection, 72 | only_cross_attention=only_cross_attention, 73 | upcast_attention=upcast_attention, 74 | resnet_time_scale_shift=resnet_time_scale_shift, 75 | use_first_frame=use_first_frame, 76 | use_relative_position=use_relative_position, 77 | rotary_emb=rotary_emb, 78 | ) 79 | raise ValueError(f"{down_block_type} does not exist.") 80 | 81 | 82 | def get_up_block( 83 | up_block_type, 84 | num_layers, 85 | in_channels, 86 | out_channels, 87 | prev_output_channel, 88 | temb_channels, 89 | add_upsample, 90 | resnet_eps, 91 | resnet_act_fn, 92 | attn_num_head_channels, 93 | resnet_groups=None, 94 | cross_attention_dim=None, 95 | dual_cross_attention=False, 96 | use_linear_projection=False, 97 | only_cross_attention=False, 98 | upcast_attention=False, 99 | resnet_time_scale_shift="default", 100 | use_first_frame=False, 101 | use_relative_position=False, 102 | rotary_emb=False, 103 | ): 104 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 105 | if up_block_type == "UpBlock3D": 106 | return UpBlock3D( 107 | num_layers=num_layers, 108 | in_channels=in_channels, 109 | out_channels=out_channels, 110 | prev_output_channel=prev_output_channel, 111 | temb_channels=temb_channels, 112 | add_upsample=add_upsample, 113 | resnet_eps=resnet_eps, 114 | resnet_act_fn=resnet_act_fn, 115 | resnet_groups=resnet_groups, 116 | resnet_time_scale_shift=resnet_time_scale_shift, 117 | ) 118 | elif up_block_type == "CrossAttnUpBlock3D": 119 | if cross_attention_dim is None: 120 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 121 | return CrossAttnUpBlock3D( 122 | num_layers=num_layers, 123 | in_channels=in_channels, 124 | out_channels=out_channels, 125 | prev_output_channel=prev_output_channel, 126 | temb_channels=temb_channels, 127 | add_upsample=add_upsample, 128 | resnet_eps=resnet_eps, 129 | resnet_act_fn=resnet_act_fn, 130 | resnet_groups=resnet_groups, 131 | cross_attention_dim=cross_attention_dim, 132 | attn_num_head_channels=attn_num_head_channels, 133 | dual_cross_attention=dual_cross_attention, 134 | use_linear_projection=use_linear_projection, 135 | only_cross_attention=only_cross_attention, 136 | upcast_attention=upcast_attention, 137 | 
resnet_time_scale_shift=resnet_time_scale_shift, 138 | use_first_frame=use_first_frame, 139 | use_relative_position=use_relative_position, 140 | rotary_emb=rotary_emb, 141 | ) 142 | raise ValueError(f"{up_block_type} does not exist.") 143 | 144 | 145 | class UNetMidBlock3DCrossAttn(nn.Module): 146 | def __init__( 147 | self, 148 | in_channels: int, 149 | temb_channels: int, 150 | dropout: float = 0.0, 151 | num_layers: int = 1, 152 | resnet_eps: float = 1e-6, 153 | resnet_time_scale_shift: str = "default", 154 | resnet_act_fn: str = "swish", 155 | resnet_groups: int = 32, 156 | resnet_pre_norm: bool = True, 157 | attn_num_head_channels=1, 158 | output_scale_factor=1.0, 159 | cross_attention_dim=1280, 160 | dual_cross_attention=False, 161 | use_linear_projection=False, 162 | upcast_attention=False, 163 | use_first_frame=False, 164 | use_relative_position=False, 165 | rotary_emb=False, 166 | ): 167 | super().__init__() 168 | 169 | self.has_cross_attention = True 170 | self.attn_num_head_channels = attn_num_head_channels 171 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 172 | 173 | # there is always at least one resnet 174 | resnets = [ 175 | ResnetBlock3D( 176 | in_channels=in_channels, 177 | out_channels=in_channels, 178 | temb_channels=temb_channels, 179 | eps=resnet_eps, 180 | groups=resnet_groups, 181 | dropout=dropout, 182 | time_embedding_norm=resnet_time_scale_shift, 183 | non_linearity=resnet_act_fn, 184 | output_scale_factor=output_scale_factor, 185 | pre_norm=resnet_pre_norm, 186 | ) 187 | ] 188 | attentions = [] 189 | 190 | for _ in range(num_layers): 191 | if dual_cross_attention: 192 | raise NotImplementedError 193 | attentions.append( 194 | Transformer3DModel( 195 | attn_num_head_channels, 196 | in_channels // attn_num_head_channels, 197 | in_channels=in_channels, 198 | num_layers=1, 199 | cross_attention_dim=cross_attention_dim, 200 | norm_num_groups=resnet_groups, 201 | use_linear_projection=use_linear_projection, 202 | upcast_attention=upcast_attention, 203 | use_first_frame=use_first_frame, 204 | use_relative_position=use_relative_position, 205 | rotary_emb=rotary_emb, 206 | ) 207 | ) 208 | resnets.append( 209 | ResnetBlock3D( 210 | in_channels=in_channels, 211 | out_channels=in_channels, 212 | temb_channels=temb_channels, 213 | eps=resnet_eps, 214 | groups=resnet_groups, 215 | dropout=dropout, 216 | time_embedding_norm=resnet_time_scale_shift, 217 | non_linearity=resnet_act_fn, 218 | output_scale_factor=output_scale_factor, 219 | pre_norm=resnet_pre_norm, 220 | ) 221 | ) 222 | 223 | self.attentions = nn.ModuleList(attentions) 224 | self.resnets = nn.ModuleList(resnets) 225 | 226 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): 227 | hidden_states = self.resnets[0](hidden_states, temb) 228 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 229 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample 230 | hidden_states = resnet(hidden_states, temb) 231 | 232 | return hidden_states 233 | 234 | 235 | class CrossAttnDownBlock3D(nn.Module): 236 | def __init__( 237 | self, 238 | in_channels: int, 239 | out_channels: int, 240 | temb_channels: int, 241 | dropout: float = 0.0, 242 | num_layers: int = 1, 243 | resnet_eps: float = 1e-6, 244 | resnet_time_scale_shift: str = "default", 245 | resnet_act_fn: str = "swish", 246 | resnet_groups: int = 32, 247 | resnet_pre_norm: bool = True, 248 | 
attn_num_head_channels=1, 249 | cross_attention_dim=1280, 250 | output_scale_factor=1.0, 251 | downsample_padding=1, 252 | add_downsample=True, 253 | dual_cross_attention=False, 254 | use_linear_projection=False, 255 | only_cross_attention=False, 256 | upcast_attention=False, 257 | use_first_frame=False, 258 | use_relative_position=False, 259 | rotary_emb=False, 260 | ): 261 | super().__init__() 262 | resnets = [] 263 | attentions = [] 264 | 265 | # print(use_first_frame) 266 | 267 | self.has_cross_attention = True 268 | self.attn_num_head_channels = attn_num_head_channels 269 | 270 | for i in range(num_layers): 271 | in_channels = in_channels if i == 0 else out_channels 272 | resnets.append( 273 | ResnetBlock3D( 274 | in_channels=in_channels, 275 | out_channels=out_channels, 276 | temb_channels=temb_channels, 277 | eps=resnet_eps, 278 | groups=resnet_groups, 279 | dropout=dropout, 280 | time_embedding_norm=resnet_time_scale_shift, 281 | non_linearity=resnet_act_fn, 282 | output_scale_factor=output_scale_factor, 283 | pre_norm=resnet_pre_norm, 284 | ) 285 | ) 286 | if dual_cross_attention: 287 | raise NotImplementedError 288 | attentions.append( 289 | Transformer3DModel( 290 | attn_num_head_channels, 291 | out_channels // attn_num_head_channels, 292 | in_channels=out_channels, 293 | num_layers=1, 294 | cross_attention_dim=cross_attention_dim, 295 | norm_num_groups=resnet_groups, 296 | use_linear_projection=use_linear_projection, 297 | only_cross_attention=only_cross_attention, 298 | upcast_attention=upcast_attention, 299 | use_first_frame=use_first_frame, 300 | use_relative_position=use_relative_position, 301 | rotary_emb=rotary_emb, 302 | ) 303 | ) 304 | self.attentions = nn.ModuleList(attentions) 305 | self.resnets = nn.ModuleList(resnets) 306 | 307 | if add_downsample: 308 | self.downsamplers = nn.ModuleList( 309 | [ 310 | Downsample3D( 311 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 312 | ) 313 | ] 314 | ) 315 | else: 316 | self.downsamplers = None 317 | 318 | self.gradient_checkpointing = False 319 | 320 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): 321 | output_states = () 322 | 323 | for resnet, attn in zip(self.resnets, self.attentions): 324 | if self.training and self.gradient_checkpointing: 325 | 326 | def create_custom_forward(module, return_dict=None): 327 | def custom_forward(*inputs): 328 | if return_dict is not None: 329 | return module(*inputs, return_dict=return_dict) 330 | else: 331 | return module(*inputs) 332 | 333 | return custom_forward 334 | 335 | def create_custom_forward_attn(module, return_dict=None, use_image_num=None): 336 | def custom_forward(*inputs): 337 | if return_dict is not None: 338 | return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) 339 | else: 340 | return module(*inputs, use_image_num=use_image_num) 341 | 342 | return custom_forward 343 | 344 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 345 | hidden_states = torch.utils.checkpoint.checkpoint( 346 | create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), 347 | hidden_states, 348 | encoder_hidden_states, 349 | )[0] 350 | else: 351 | hidden_states = resnet(hidden_states, temb) 352 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample 353 | 354 | output_states += (hidden_states,) 355 | 356 | if self.downsamplers 
is not None: 357 | for downsampler in self.downsamplers: 358 | hidden_states = downsampler(hidden_states) 359 | 360 | output_states += (hidden_states,) 361 | 362 | return hidden_states, output_states 363 | 364 | 365 | class DownBlock3D(nn.Module): 366 | def __init__( 367 | self, 368 | in_channels: int, 369 | out_channels: int, 370 | temb_channels: int, 371 | dropout: float = 0.0, 372 | num_layers: int = 1, 373 | resnet_eps: float = 1e-6, 374 | resnet_time_scale_shift: str = "default", 375 | resnet_act_fn: str = "swish", 376 | resnet_groups: int = 32, 377 | resnet_pre_norm: bool = True, 378 | output_scale_factor=1.0, 379 | add_downsample=True, 380 | downsample_padding=1, 381 | ): 382 | super().__init__() 383 | resnets = [] 384 | 385 | for i in range(num_layers): 386 | in_channels = in_channels if i == 0 else out_channels 387 | resnets.append( 388 | ResnetBlock3D( 389 | in_channels=in_channels, 390 | out_channels=out_channels, 391 | temb_channels=temb_channels, 392 | eps=resnet_eps, 393 | groups=resnet_groups, 394 | dropout=dropout, 395 | time_embedding_norm=resnet_time_scale_shift, 396 | non_linearity=resnet_act_fn, 397 | output_scale_factor=output_scale_factor, 398 | pre_norm=resnet_pre_norm, 399 | ) 400 | ) 401 | 402 | self.resnets = nn.ModuleList(resnets) 403 | 404 | if add_downsample: 405 | self.downsamplers = nn.ModuleList( 406 | [ 407 | Downsample3D( 408 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 409 | ) 410 | ] 411 | ) 412 | else: 413 | self.downsamplers = None 414 | 415 | self.gradient_checkpointing = False 416 | 417 | def forward(self, hidden_states, temb=None): 418 | output_states = () 419 | 420 | for resnet in self.resnets: 421 | if self.training and self.gradient_checkpointing: 422 | 423 | def create_custom_forward(module): 424 | def custom_forward(*inputs): 425 | return module(*inputs) 426 | 427 | return custom_forward 428 | 429 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 430 | else: 431 | hidden_states = resnet(hidden_states, temb) 432 | 433 | output_states += (hidden_states,) 434 | 435 | if self.downsamplers is not None: 436 | for downsampler in self.downsamplers: 437 | hidden_states = downsampler(hidden_states) 438 | 439 | output_states += (hidden_states,) 440 | 441 | return hidden_states, output_states 442 | 443 | 444 | class CrossAttnUpBlock3D(nn.Module): 445 | def __init__( 446 | self, 447 | in_channels: int, 448 | out_channels: int, 449 | prev_output_channel: int, 450 | temb_channels: int, 451 | dropout: float = 0.0, 452 | num_layers: int = 1, 453 | resnet_eps: float = 1e-6, 454 | resnet_time_scale_shift: str = "default", 455 | resnet_act_fn: str = "swish", 456 | resnet_groups: int = 32, 457 | resnet_pre_norm: bool = True, 458 | attn_num_head_channels=1, 459 | cross_attention_dim=1280, 460 | output_scale_factor=1.0, 461 | add_upsample=True, 462 | dual_cross_attention=False, 463 | use_linear_projection=False, 464 | only_cross_attention=False, 465 | upcast_attention=False, 466 | use_first_frame=False, 467 | use_relative_position=False, 468 | rotary_emb=False 469 | ): 470 | super().__init__() 471 | resnets = [] 472 | attentions = [] 473 | 474 | self.has_cross_attention = True 475 | self.attn_num_head_channels = attn_num_head_channels 476 | 477 | for i in range(num_layers): 478 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 479 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 480 | 481 | resnets.append( 482 | 
ResnetBlock3D( 483 | in_channels=resnet_in_channels + res_skip_channels, 484 | out_channels=out_channels, 485 | temb_channels=temb_channels, 486 | eps=resnet_eps, 487 | groups=resnet_groups, 488 | dropout=dropout, 489 | time_embedding_norm=resnet_time_scale_shift, 490 | non_linearity=resnet_act_fn, 491 | output_scale_factor=output_scale_factor, 492 | pre_norm=resnet_pre_norm, 493 | ) 494 | ) 495 | if dual_cross_attention: 496 | raise NotImplementedError 497 | attentions.append( 498 | Transformer3DModel( 499 | attn_num_head_channels, 500 | out_channels // attn_num_head_channels, 501 | in_channels=out_channels, 502 | num_layers=1, 503 | cross_attention_dim=cross_attention_dim, 504 | norm_num_groups=resnet_groups, 505 | use_linear_projection=use_linear_projection, 506 | only_cross_attention=only_cross_attention, 507 | upcast_attention=upcast_attention, 508 | use_first_frame=use_first_frame, 509 | use_relative_position=use_relative_position, 510 | rotary_emb=rotary_emb, 511 | ) 512 | ) 513 | 514 | self.attentions = nn.ModuleList(attentions) 515 | self.resnets = nn.ModuleList(resnets) 516 | 517 | if add_upsample: 518 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 519 | else: 520 | self.upsamplers = None 521 | 522 | self.gradient_checkpointing = False 523 | 524 | def forward( 525 | self, 526 | hidden_states, 527 | res_hidden_states_tuple, 528 | temb=None, 529 | encoder_hidden_states=None, 530 | upsample_size=None, 531 | attention_mask=None, 532 | use_image_num=None, 533 | ): 534 | for resnet, attn in zip(self.resnets, self.attentions): 535 | # pop res hidden states 536 | res_hidden_states = res_hidden_states_tuple[-1] 537 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 538 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 539 | 540 | if self.training and self.gradient_checkpointing: 541 | 542 | def create_custom_forward(module, return_dict=None): 543 | def custom_forward(*inputs): 544 | if return_dict is not None: 545 | return module(*inputs, return_dict=return_dict) 546 | else: 547 | return module(*inputs) 548 | 549 | return custom_forward 550 | 551 | def create_custom_forward_attn(module, return_dict=None, use_image_num=None): 552 | def custom_forward(*inputs): 553 | if return_dict is not None: 554 | return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) 555 | else: 556 | return module(*inputs, use_image_num=use_image_num) 557 | 558 | return custom_forward 559 | 560 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 561 | hidden_states = torch.utils.checkpoint.checkpoint( 562 | create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), 563 | hidden_states, 564 | encoder_hidden_states, 565 | )[0] 566 | else: 567 | hidden_states = resnet(hidden_states, temb) 568 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample 569 | 570 | if self.upsamplers is not None: 571 | for upsampler in self.upsamplers: 572 | hidden_states = upsampler(hidden_states, upsample_size) 573 | 574 | return hidden_states 575 | 576 | 577 | class UpBlock3D(nn.Module): 578 | def __init__( 579 | self, 580 | in_channels: int, 581 | prev_output_channel: int, 582 | out_channels: int, 583 | temb_channels: int, 584 | dropout: float = 0.0, 585 | num_layers: int = 1, 586 | resnet_eps: float = 1e-6, 587 | resnet_time_scale_shift: str = "default", 588 | resnet_act_fn: str = "swish", 589 | 
resnet_groups: int = 32, 590 | resnet_pre_norm: bool = True, 591 | output_scale_factor=1.0, 592 | add_upsample=True, 593 | ): 594 | super().__init__() 595 | resnets = [] 596 | 597 | for i in range(num_layers): 598 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 599 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 600 | 601 | resnets.append( 602 | ResnetBlock3D( 603 | in_channels=resnet_in_channels + res_skip_channels, 604 | out_channels=out_channels, 605 | temb_channels=temb_channels, 606 | eps=resnet_eps, 607 | groups=resnet_groups, 608 | dropout=dropout, 609 | time_embedding_norm=resnet_time_scale_shift, 610 | non_linearity=resnet_act_fn, 611 | output_scale_factor=output_scale_factor, 612 | pre_norm=resnet_pre_norm, 613 | ) 614 | ) 615 | 616 | self.resnets = nn.ModuleList(resnets) 617 | 618 | if add_upsample: 619 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 620 | else: 621 | self.upsamplers = None 622 | 623 | self.gradient_checkpointing = False 624 | 625 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 626 | for resnet in self.resnets: 627 | # pop res hidden states 628 | res_hidden_states = res_hidden_states_tuple[-1] 629 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 630 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 631 | 632 | if self.training and self.gradient_checkpointing: 633 | 634 | def create_custom_forward(module): 635 | def custom_forward(*inputs): 636 | return module(*inputs) 637 | 638 | return custom_forward 639 | 640 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 641 | else: 642 | hidden_states = resnet(hidden_states, temb) 643 | 644 | if self.upsamplers is not None: 645 | for upsampler in self.upsamplers: 646 | hidden_states = upsampler(hidden_states, upsample_size) 647 | 648 | return hidden_states 649 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | 15 | import numpy as np 16 | import torch.nn as nn 17 | 18 | from einops import repeat 19 | 20 | 21 | ################################################################################# 22 | # Unet Utils # 23 | ################################################################################# 24 | 25 | def checkpoint(func, inputs, params, flag): 26 | """ 27 | Evaluate a function without caching intermediate activations, allowing for 28 | reduced memory at the expense of extra compute in the backward pass. 29 | :param func: the function to evaluate. 30 | :param inputs: the argument sequence to pass to `func`. 31 | :param params: a sequence of parameters `func` depends on but does not 32 | explicitly take as arguments. 33 | :param flag: if False, disable gradient checkpointing. 
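    Example (an editor-added sketch; `block` and `x` are hypothetical placeholders):

        out = checkpoint(block, (x,), list(block.parameters()), True)

    which recomputes `block(x)` during the backward pass instead of caching its activations.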
34 | """ 35 | if flag: 36 | args = tuple(inputs) + tuple(params) 37 | return CheckpointFunction.apply(func, len(inputs), *args) 38 | else: 39 | return func(*inputs) 40 | 41 | 42 | class CheckpointFunction(torch.autograd.Function): 43 | @staticmethod 44 | def forward(ctx, run_function, length, *args): 45 | ctx.run_function = run_function 46 | ctx.input_tensors = list(args[:length]) 47 | ctx.input_params = list(args[length:]) 48 | 49 | with torch.no_grad(): 50 | output_tensors = ctx.run_function(*ctx.input_tensors) 51 | return output_tensors 52 | 53 | @staticmethod 54 | def backward(ctx, *output_grads): 55 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 56 | with torch.enable_grad(): 57 | # Fixes a bug where the first op in run_function modifies the 58 | # Tensor storage in place, which is not allowed for detach()'d 59 | # Tensors. 60 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 61 | output_tensors = ctx.run_function(*shallow_copies) 62 | input_grads = torch.autograd.grad( 63 | output_tensors, 64 | ctx.input_tensors + ctx.input_params, 65 | output_grads, 66 | allow_unused=True, 67 | ) 68 | del ctx.input_tensors 69 | del ctx.input_params 70 | del output_tensors 71 | return (None, None) + input_grads 72 | 73 | 74 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 75 | """ 76 | Create sinusoidal timestep embeddings. 77 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 78 | These may be fractional. 79 | :param dim: the dimension of the output. 80 | :param max_period: controls the minimum frequency of the embeddings. 81 | :return: an [N x dim] Tensor of positional embeddings. 82 | """ 83 | if not repeat_only: 84 | half = dim // 2 85 | freqs = torch.exp( 86 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 87 | ).to(device=timesteps.device) 88 | args = timesteps[:, None].float() * freqs[None] 89 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 90 | if dim % 2: 91 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 92 | else: 93 | embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() 94 | return embedding 95 | 96 | 97 | def zero_module(module): 98 | """ 99 | Zero out the parameters of a module and return it. 100 | """ 101 | for p in module.parameters(): 102 | p.detach().zero_() 103 | return module 104 | 105 | 106 | def scale_module(module, scale): 107 | """ 108 | Scale the parameters of a module and return it. 109 | """ 110 | for p in module.parameters(): 111 | p.detach().mul_(scale) 112 | return module 113 | 114 | 115 | def mean_flat(tensor): 116 | """ 117 | Take the mean over all non-batch dimensions. 118 | """ 119 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 120 | 121 | 122 | def normalization(channels): 123 | """ 124 | Make a standard normalization layer. 125 | :param channels: number of input channels. 126 | :return: an nn.Module for normalization. 127 | """ 128 | return GroupNorm32(32, channels) 129 | 130 | 131 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 132 | class SiLU(nn.Module): 133 | def forward(self, x): 134 | return x * torch.sigmoid(x) 135 | 136 | 137 | class GroupNorm32(nn.GroupNorm): 138 | def forward(self, x): 139 | return super().forward(x.float()).type(x.dtype) 140 | 141 | def conv_nd(dims, *args, **kwargs): 142 | """ 143 | Create a 1D, 2D, or 3D convolution module. 
144 | """ 145 | if dims == 1: 146 | return nn.Conv1d(*args, **kwargs) 147 | elif dims == 2: 148 | return nn.Conv2d(*args, **kwargs) 149 | elif dims == 3: 150 | return nn.Conv3d(*args, **kwargs) 151 | raise ValueError(f"unsupported dimensions: {dims}") 152 | 153 | 154 | def linear(*args, **kwargs): 155 | """ 156 | Create a linear module. 157 | """ 158 | return nn.Linear(*args, **kwargs) 159 | 160 | 161 | def avg_pool_nd(dims, *args, **kwargs): 162 | """ 163 | Create a 1D, 2D, or 3D average pooling module. 164 | """ 165 | if dims == 1: 166 | return nn.AvgPool1d(*args, **kwargs) 167 | elif dims == 2: 168 | return nn.AvgPool2d(*args, **kwargs) 169 | elif dims == 3: 170 | return nn.AvgPool3d(*args, **kwargs) 171 | raise ValueError(f"unsupported dimensions: {dims}") 172 | 173 | 174 | # class HybridConditioner(nn.Module): 175 | 176 | # def __init__(self, c_concat_config, c_crossattn_config): 177 | # super().__init__() 178 | # self.concat_conditioner = instantiate_from_config(c_concat_config) 179 | # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 180 | 181 | # def forward(self, c_concat, c_crossattn): 182 | # c_concat = self.concat_conditioner(c_concat) 183 | # c_crossattn = self.crossattn_conditioner(c_crossattn) 184 | # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 185 | 186 | 187 | def noise_like(shape, device, repeat=False): 188 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 189 | noise = lambda: torch.randn(shape, device=device) 190 | return repeat_noise() if repeat else noise() 191 | 192 | def count_flops_attn(model, _x, y): 193 | """ 194 | A counter for the `thop` package to count the operations in an 195 | attention operation. 196 | Meant to be used like: 197 | macs, params = thop.profile( 198 | model, 199 | inputs=(inputs, timestamps), 200 | custom_ops={QKVAttention: QKVAttention.count_flops}, 201 | ) 202 | """ 203 | b, c, *spatial = y[0].shape 204 | num_spatial = int(np.prod(spatial)) 205 | # We perform two matmuls with the same number of ops. 206 | # The first computes the weight matrix, the second computes 207 | # the combination of the value vectors. 208 | matmul_ops = 2 * b * (num_spatial ** 2) * c 209 | model.total_ops += torch.DoubleTensor([matmul_ops]) 210 | 211 | def count_params(model, verbose=False): 212 | total_params = sum(p.numel() for p in model.parameters()) 213 | if verbose: 214 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 215 | return total_params -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchaudio==2.0.2 3 | torchvision==0.15.2 4 | decord==0.6.0 5 | diffusers==0.15.0 6 | imageio==2.29.0 7 | transformers==4.29.2 8 | xformers==0.0.20 9 | einops 10 | omegaconf 11 | tensorboard==2.15.1 12 | timm==0.9.10 13 | rotary-embedding-torch==0.3.5 14 | natsort==8.4.0 -------------------------------------------------------------------------------- /sample_scripts/with_mask_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Sample new images from a pre-trained DiT. 
9 | """ 10 | import os 11 | import sys 12 | import math 13 | try: 14 | import utils 15 | from diffusion import create_diffusion 16 | except: 17 | # sys.path.append(os.getcwd()) 18 | sys.path.append(os.path.split(sys.path[0])[0]) 19 | # sys.path[0] 20 | # os.path.split(sys.path[0]) 21 | import utils 22 | 23 | from diffusion import create_diffusion 24 | 25 | import torch 26 | torch.backends.cuda.matmul.allow_tf32 = True 27 | torch.backends.cudnn.allow_tf32 = True 28 | import argparse 29 | import torchvision 30 | 31 | from einops import rearrange 32 | from models import get_models 33 | from torchvision.utils import save_image 34 | from diffusers.models import AutoencoderKL 35 | from models.clip import TextEmbedder 36 | from omegaconf import OmegaConf 37 | from PIL import Image 38 | import numpy as np 39 | from torchvision import transforms 40 | sys.path.append("..") 41 | from datasets import video_transforms 42 | from utils import mask_generation_before 43 | from natsort import natsorted 44 | from diffusers.utils.import_utils import is_xformers_available 45 | import pdb 46 | 47 | def get_input(args): 48 | input_path = args.input_path 49 | transform_video = transforms.Compose([ 50 | video_transforms.ToTensorVideo(), # TCHW 51 | video_transforms.ResizeVideo((args.image_h, args.image_w)), 52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 53 | ]) 54 | if input_path is not None: 55 | print(f'loading video from {input_path}') 56 | if os.path.isdir(input_path): 57 | file_list = os.listdir(input_path) 58 | video_frames = [] 59 | if args.mask_type.startswith('onelast'): 60 | num = int(args.mask_type.split('onelast')[-1]) 61 | # get first and last frame 62 | first_frame_path = os.path.join(input_path, natsorted(file_list)[0]) 63 | last_frame_path = os.path.join(input_path, natsorted(file_list)[-1]) 64 | first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) 65 | last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) 66 | for i in range(num): 67 | video_frames.append(first_frame) 68 | # add zeros to frames 69 | num_zeros = args.num_frames-2*num 70 | for i in range(num_zeros): 71 | zeros = torch.zeros_like(first_frame) 72 | video_frames.append(zeros) 73 | for i in range(num): 74 | video_frames.append(last_frame) 75 | n = 0 76 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w 77 | video_frames = transform_video(video_frames) 78 | else: 79 | for file in file_list: 80 | if file.endswith('jpg') or file.endswith('png'): 81 | image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0) 82 | video_frames.append(image) 83 | else: 84 | continue 85 | n = 0 86 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w 87 | video_frames = transform_video(video_frames) 88 | return video_frames, n 89 | elif os.path.isfile(input_path): 90 | _, full_file_name = os.path.split(input_path) 91 | file_name, extension = os.path.splitext(full_file_name) 92 | if extension == '.jpg' or extension == '.png': 93 | print("loading the input image") 94 | video_frames = [] 95 | num = int(args.mask_type.split('first')[-1]) 96 | first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0) 97 | for i in range(num): 98 | video_frames.append(first_frame) 99 | num_zeros = args.num_frames-num 100 | for i in range(num_zeros): 101 | zeros = torch.zeros_like(first_frame) 102 | 
video_frames.append(zeros) 103 | n = 0 104 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w 105 | video_frames = transform_video(video_frames) 106 | return video_frames, n 107 | else: 108 | raise TypeError(f'{extension} is not supported !!') 109 | else: 110 | raise ValueError('Please check your path input!!') 111 | else: 112 | raise ValueError('Need to give a video or some images') 113 | 114 | def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): 115 | b,f,c,h,w=video_input.shape 116 | latent_h = args.image_size[0] // 8 117 | latent_w = args.image_size[1] // 8 118 | 119 | # prepare inputs 120 | if args.use_fp16: 121 | z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w 122 | masked_video = masked_video.to(dtype=torch.float16) 123 | mask = mask.to(dtype=torch.float16) 124 | else: 125 | z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w 126 | 127 | 128 | masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() 129 | masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) 130 | masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() 131 | mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) 132 | 133 | # classifier_free_guidance 134 | if args.do_classifier_free_guidance: 135 | masked_video = torch.cat([masked_video] * 2) 136 | mask = torch.cat([mask] * 2) 137 | z = torch.cat([z] * 2) 138 | prompt_all = [prompt] + [args.negative_prompt] 139 | 140 | else: 141 | masked_video = masked_video 142 | mask = mask 143 | z = z 144 | prompt_all = [prompt] 145 | 146 | text_prompt = text_encoder(text_prompts=prompt_all, train=False) 147 | model_kwargs = dict(encoder_hidden_states=text_prompt, 148 | class_labels=None, 149 | cfg_scale=args.cfg_scale, 150 | use_fp16=args.use_fp16,) # tav unet 151 | 152 | # Sample video: 153 | if args.sample_method == 'ddim': 154 | samples = diffusion.ddim_sample_loop( 155 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ 156 | mask=mask, x_start=masked_video, use_concat=args.use_mask 157 | ) 158 | elif args.sample_method == 'ddpm': 159 | samples = diffusion.p_sample_loop( 160 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ 161 | mask=mask, x_start=masked_video, use_concat=args.use_mask 162 | ) 163 | samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] 164 | if args.use_fp16: 165 | samples = samples.to(dtype=torch.float16) 166 | 167 | video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] 168 | video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] 169 | return video_clip 170 | 171 | def main(args): 172 | # Setup PyTorch: 173 | if args.seed: 174 | torch.manual_seed(args.seed) 175 | torch.set_grad_enabled(False) 176 | device = "cuda" if torch.cuda.is_available() else "cpu" 177 | # device = "cpu" 178 | 179 | if args.ckpt is None: 180 | raise ValueError("Please specify a checkpoint path using --ckpt ") 181 | 182 | # Load model: 183 | latent_h = args.image_size[0] // 8 184 | latent_w = args.image_size[1] // 8 185 | args.image_h = args.image_size[0] 186 | args.image_w = args.image_size[1] 187 | args.latent_h = latent_h 188 | args.latent_w = latent_w 189 | print('loading model') 190 | model = 
get_models(args).to(device) 191 | 192 | if args.enable_xformers_memory_efficient_attention: 193 | if is_xformers_available(): 194 | model.enable_xformers_memory_efficient_attention() 195 | else: 196 | raise ValueError("xformers is not available. Make sure it is installed correctly") 197 | 198 | # load model 199 | ckpt_path = args.ckpt 200 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema'] 201 | model.load_state_dict(state_dict) 202 | print('loading succeed') 203 | 204 | model.eval() 205 | pretrained_model_path = args.pretrained_model_path 206 | diffusion = create_diffusion(str(args.num_sampling_steps)) 207 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) 208 | text_encoder = TextEmbedder(pretrained_model_path).to(device) 209 | if args.use_fp16: 210 | print('Warnning: using half percision for inferencing!') 211 | vae.to(dtype=torch.float16) 212 | model.to(dtype=torch.float16) 213 | text_encoder.to(dtype=torch.float16) 214 | 215 | # prompt: 216 | prompt = args.text_prompt 217 | if prompt ==[]: 218 | prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ') 219 | else: 220 | prompt = prompt[0] 221 | prompt_base = prompt.replace(' ','_') 222 | prompt = prompt + args.additional_prompt 223 | 224 | if not os.path.exists(os.path.join(args.save_path)): 225 | os.makedirs(os.path.join(args.save_path)) 226 | video_input, researve_frames = get_input(args) # f,c,h,w 227 | video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w 228 | mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w 229 | masked_video = video_input * (mask == 0) 230 | 231 | video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) 232 | video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) 233 | save_video_path = os.path.join(args.save_path, prompt_base+ '.mp4') 234 | torchvision.io.write_video(save_video_path, video_, fps=8) 235 | print(f'save in {save_video_path}') 236 | 237 | 238 | if __name__ == "__main__": 239 | parser = argparse.ArgumentParser() 240 | parser.add_argument("--config", type=str, default="./configs/sample_mask.yaml") 241 | args = parser.parse_args() 242 | omega_conf = OmegaConf.load(args.config) 243 | main(omega_conf) 244 | -------------------------------------------------------------------------------- /seine.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/SEINE/bd8d519ef4e7ce761f0b1064f906bb96f4cd3ebe/seine.gif -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import logging 5 | import subprocess 6 | import numpy as np 7 | import torch.distributed as dist 8 | 9 | # from torch._six import inf 10 | from torch import inf 11 | from PIL import Image 12 | from typing import Union, Iterable 13 | from collections import OrderedDict 14 | from torch.utils.tensorboard import SummaryWriter 15 | _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] 16 | 17 | ################################################################################# 18 | # Training Helper Functions # 19 | ################################################################################# 20 | def 
fetch_files_by_numbers(start_number, count, file_list): 21 | file_numbers = range(start_number, start_number + count) 22 | found_files = [] 23 | for file_number in file_numbers: 24 | file_number_padded = str(file_number).zfill(2) 25 | for file_name in file_list: 26 | if file_name.endswith(file_number_padded + '.csv'): 27 | found_files.append(file_name) 28 | break # Stop searching once a file is found for the current number 29 | return found_files 30 | 31 | ################################################################################# 32 | # Training Clip Gradients # 33 | ################################################################################# 34 | 35 | def get_grad_norm( 36 | parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: 37 | r""" 38 | Copy from torch.nn.utils.clip_grad_norm_ 39 | 40 | Clips gradient norm of an iterable of parameters. 41 | 42 | The norm is computed over all gradients together, as if they were 43 | concatenated into a single vector. Gradients are modified in-place. 44 | 45 | Args: 46 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 47 | single Tensor that will have gradients normalized 48 | max_norm (float or int): max norm of the gradients 49 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 50 | infinity norm. 51 | error_if_nonfinite (bool): if True, an error is thrown if the total 52 | norm of the gradients from :attr:`parameters` is ``nan``, 53 | ``inf``, or ``-inf``. Default: False (will switch to True in the future) 54 | 55 | Returns: 56 | Total norm of the parameter gradients (viewed as a single vector). 57 | """ 58 | if isinstance(parameters, torch.Tensor): 59 | parameters = [parameters] 60 | grads = [p.grad for p in parameters if p.grad is not None] 61 | norm_type = float(norm_type) 62 | if len(grads) == 0: 63 | return torch.tensor(0.) 64 | device = grads[0].device 65 | if norm_type == inf: 66 | norms = [g.detach().abs().max().to(device) for g in grads] 67 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) 68 | else: 69 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 70 | return total_norm 71 | 72 | def clip_grad_norm_( 73 | parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, 74 | error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor: 75 | r""" 76 | Copy from torch.nn.utils.clip_grad_norm_ 77 | 78 | Clips gradient norm of an iterable of parameters. 79 | 80 | The norm is computed over all gradients together, as if they were 81 | concatenated into a single vector. Gradients are modified in-place. 82 | 83 | Args: 84 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 85 | single Tensor that will have gradients normalized 86 | max_norm (float or int): max norm of the gradients 87 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 88 | infinity norm. 89 | error_if_nonfinite (bool): if True, an error is thrown if the total 90 | norm of the gradients from :attr:`parameters` is ``nan``, 91 | ``inf``, or ``-inf``. Default: False (will switch to True in the future) 92 | 93 | Returns: 94 | Total norm of the parameter gradients (viewed as a single vector). 95 | """ 96 | if isinstance(parameters, torch.Tensor): 97 | parameters = [parameters] 98 | grads = [p.grad for p in parameters if p.grad is not None] 99 | max_norm = float(max_norm) 100 | norm_type = float(norm_type) 101 | if len(grads) == 0: 102 | return torch.tensor(0.) 
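    # (editor's note) past this point at least one parameter carries a gradient, so grads[0] can be indexed safely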
103 | device = grads[0].device 104 | if norm_type == inf: 105 | norms = [g.detach().abs().max().to(device) for g in grads] 106 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) 107 | else: 108 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 109 | # print(total_norm) 110 | 111 | if clip_grad: 112 | if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): 113 | raise RuntimeError( 114 | f'The total norm of order {norm_type} for gradients from ' 115 | '`parameters` is non-finite, so it cannot be clipped. To disable ' 116 | 'this error and scale the gradients by the non-finite norm anyway, ' 117 | 'set `error_if_nonfinite=False`') 118 | clip_coef = max_norm / (total_norm + 1e-6) 119 | # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so 120 | # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization 121 | # when the gradients do not reside in CPU memory. 122 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 123 | for g in grads: 124 | g.detach().mul_(clip_coef_clamped.to(g.device)) 125 | # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 126 | # print(gradient_cliped) 127 | return total_norm 128 | 129 | def separation_content_motion(video_clip): 130 | """ 131 | Separate content and motion in a given video. 132 | Args: 133 | video_clip, a given video clip, [B F C H W] 134 | 135 | Return: 136 | base frame, [B, 1, C, H, W] 137 | motions, [B, F-1, C, H, W], 138 | the first is the base frame, 139 | the second is the motion of each frame relative to the base frame 140 | """ 141 | total_frames = video_clip.shape[1] 142 | base_frame = video_clip[:, 0:1] # [B, 1, C, H, W] 143 | motions = [video_clip[:, i:i+1] - base_frame for i in range(1, total_frames)] 144 | motions = torch.cat(motions, dim=1) # [B, F-1, C, H, W] 145 | return base_frame, motions 146 | 147 | def get_experiment_dir(root_dir, args): 148 | if args.use_compile: 149 | root_dir += '-Compile' # speed-up from torch.compile 150 | if args.fixed_spatial: 151 | root_dir += '-FixedSpa' 152 | if args.enable_xformers_memory_efficient_attention: 153 | root_dir += '-Xfor' 154 | if args.gradient_checkpointing: 155 | root_dir += '-Gc' 156 | if args.mixed_precision: 157 | root_dir += '-Amp' 158 | if args.image_size == 512: 159 | root_dir += '-512' 160 | return root_dir 161 | 162 | ################################################################################# 163 | # Training Logger # 164 | ################################################################################# 165 | 166 | def create_logger(logging_dir): 167 | """ 168 | Create a logger that writes to a log file and stdout. 169 | """ 170 | if dist.get_rank() == 0: # real logger 171 | logging.basicConfig( 172 | level=logging.INFO, 173 | # format='[\033[34m%(asctime)s\033[0m] %(message)s', 174 | format='[%(asctime)s] %(message)s', 175 | datefmt='%Y-%m-%d %H:%M:%S', 176 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 177 | ) 178 | logger = logging.getLogger(__name__) 179 | 180 | else: # dummy logger (does nothing) 181 | logger = logging.getLogger(__name__) 182 | logger.addHandler(logging.NullHandler()) 183 | return logger 184 | 185 | def create_accelerate_logger(logging_dir, is_main_process=False): 186 | """ 187 | Create a logger that writes to a log file and stdout.
188 | """ 189 | if is_main_process: # real logger 190 | logging.basicConfig( 191 | level=logging.INFO, 192 | # format='[\033[34m%(asctime)s\033[0m] %(message)s', 193 | format='[%(asctime)s] %(message)s', 194 | datefmt='%Y-%m-%d %H:%M:%S', 195 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 196 | ) 197 | logger = logging.getLogger(__name__) 198 | else: # dummy logger (does nothing) 199 | logger = logging.getLogger(__name__) 200 | logger.addHandler(logging.NullHandler()) 201 | return logger 202 | 203 | 204 | def create_tensorboard(tensorboard_dir): 205 | """ 206 | Create a tensorboard that saves losses. 207 | """ 208 | if dist.get_rank() == 0: # real tensorboard 209 | # tensorboard 210 | writer = SummaryWriter(tensorboard_dir) 211 | 212 | return writer 213 | 214 | def write_tensorboard(writer, *args): 215 | ''' 216 | write the loss information to a tensorboard file. 217 | Only for pytorch DDP mode. 218 | ''' 219 | if dist.get_rank() == 0: # real tensorboard 220 | writer.add_scalar(args[0], args[1], args[2]) 221 | 222 | ################################################################################# 223 | # EMA Update/ DDP Training Utils # 224 | ################################################################################# 225 | 226 | @torch.no_grad() 227 | def update_ema(ema_model, model, decay=0.9999): 228 | """ 229 | Step the EMA model towards the current model. 230 | """ 231 | ema_params = OrderedDict(ema_model.named_parameters()) 232 | model_params = OrderedDict(model.named_parameters()) 233 | 234 | for name, param in model_params.items(): 235 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 236 | if param.requires_grad: 237 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 238 | 239 | def requires_grad(model, flag=True): 240 | """ 241 | Set requires_grad flag for all parameters in a model. 242 | """ 243 | for p in model.parameters(): 244 | p.requires_grad = flag 245 | 246 | def cleanup(): 247 | """ 248 | End DDP training. 249 | """ 250 | dist.destroy_process_group() 251 | 252 | 253 | def setup_distributed(backend="nccl", port=None): 254 | """Initialize distributed training environment. 
255 | Supports both SLURM and torch.distributed.launch; 256 | see torch.distributed.init_process_group() for more details. 257 | """ 258 | num_gpus = torch.cuda.device_count() 259 | 260 | if "SLURM_JOB_ID" in os.environ: 261 | rank = int(os.environ["SLURM_PROCID"]) 262 | world_size = int(os.environ["SLURM_NTASKS"]) 263 | node_list = os.environ["SLURM_NODELIST"] 264 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 265 | # specify master port 266 | if port is not None: 267 | os.environ["MASTER_PORT"] = str(port) 268 | elif "MASTER_PORT" not in os.environ: 269 | # os.environ["MASTER_PORT"] = "29566" 270 | os.environ["MASTER_PORT"] = str(29566 + num_gpus) 271 | if "MASTER_ADDR" not in os.environ: 272 | os.environ["MASTER_ADDR"] = addr 273 | os.environ["WORLD_SIZE"] = str(world_size) 274 | os.environ["LOCAL_RANK"] = str(rank % num_gpus) 275 | os.environ["RANK"] = str(rank) 276 | else: 277 | rank = int(os.environ["RANK"]) 278 | world_size = int(os.environ["WORLD_SIZE"]) 279 | 280 | # torch.cuda.set_device(rank % num_gpus) 281 | 282 | dist.init_process_group( 283 | backend=backend, 284 | world_size=world_size, 285 | rank=rank, 286 | ) 287 | 288 | ################################################################################# 289 | # Testing Utils # 290 | ################################################################################# 291 | 292 | def save_video_grid(video, nrow=None): 293 | b, t, h, w, c = video.shape 294 | 295 | if nrow is None: 296 | nrow = math.ceil(math.sqrt(b)) 297 | ncol = math.ceil(b / nrow) 298 | padding = 1 299 | video_grid = torch.zeros((t, (padding + h) * nrow + padding, 300 | (padding + w) * ncol + padding, c), dtype=torch.uint8) 301 | 302 | # print(video_grid.shape) 303 | for i in range(b): 304 | r = i // ncol 305 | c = i % ncol 306 | start_r = (padding + h) * r 307 | start_c = (padding + w) * c 308 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 309 | 310 | return video_grid 311 | 312 | def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 313 | from einops import rearrange 314 | import imageio 315 | import torchvision 316 | 317 | videos = rearrange(videos, "b c t h w -> t b c h w") 318 | outputs = [] 319 | for x in videos: 320 | x = torchvision.utils.make_grid(x, nrow=n_rows) 321 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 322 | if rescale: 323 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 324 | x = (x * 255).numpy().astype(np.uint8) 325 | outputs.append(x) 326 | 327 | # os.makedirs(os.path.dirname(path), exist_ok=True) 328 | imageio.mimsave(path, outputs, fps=fps) 329 | 330 | 331 | ################################################################################# 332 | # MMCV Utils # 333 | ################################################################################# 334 | 335 | 336 | def collect_env(): 337 | # Copyright (c) OpenMMLab. All rights reserved.
338 | """Collect the information of the running environments.""" 339 | from mmcv.utils import collect_env as collect_base_env 340 | from mmcv.utils import get_git_hash 341 | 342 | env_info = collect_base_env() 343 | env_info['MMClassification'] = get_git_hash()[:7] 344 | 345 | for name, val in env_info.items(): 346 | print(f'{name}: {val}') 347 | 348 | print(torch.cuda.get_arch_list()) 349 | print(torch.version.cuda) 350 | 351 | 352 | ################################################################################# 353 | # Long video generation Utils # 354 | ################################################################################# 355 | 356 | def mask_generation_before(mask_type, shape, dtype, device, dropout_prob=0.0, use_image_num=0): 357 | b, f, c, h, w = shape 358 | if mask_type.startswith('first'): # e.g. 'first1': keep the first num frames as condition (mask == 0), generate the rest (mask == 1) 359 | num = int(mask_type.split('first')[-1]) 360 | mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), 361 | torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1) 362 | mask = mask_f.expand(b, -1, c, h, w) 363 | elif mask_type.startswith('all'): # generate every frame 364 | mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device) 365 | elif mask_type.startswith('onelast'): # keep num frames at the start and at the end, generate the middle 366 | num = int(mask_type.split('onelast')[-1]) 367 | mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) 368 | mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device) 369 | mask_last = torch.zeros_like(mask_one) 370 | mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1) 371 | mask = mask.expand(b, -1, c, h, w) 372 | else: 373 | raise ValueError(f"Invalid mask type: {mask_type}") 374 | return mask 375 | --------------------------------------------------------------------------------
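A minimal usage sketch for mask_generation_before (utils.py), mirroring how sample_scripts/with_mask_sample.py builds the conditioning mask and the masked video before calling auto_inpainting; the tensor shape and the 'first1' mask type below are illustrative assumptions, not values fixed by the repository.

import torch
from utils import mask_generation_before

# Illustrative sketch; the batch/frame/resolution sizes are assumptions.
device = "cuda" if torch.cuda.is_available() else "cpu"
video_input = torch.randn(1, 16, 3, 256, 256, device=device)  # b, f, c, h, w

# 'first1' keeps the first frame as the image condition (mask == 0) and marks
# the remaining frames (mask == 1) as content for the diffusion model to generate.
mask = mask_generation_before("first1", video_input.shape, video_input.dtype, device)
masked_video = video_input * (mask == 0)  # zero out the frames to be generated

print(mask.shape, masked_video.shape)  # both torch.Size([1, 16, 3, 256, 256])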