├── .gitignore ├── Dockerfile ├── LICENCE ├── README.md ├── configs └── inference_256_v1.1.yaml ├── dataset_prep └── llama_prompts.py ├── download_dataset.sh ├── download_weights.sh ├── lvdm ├── basics.py ├── common.py ├── distributions.py ├── ema.py ├── models │ ├── autoencoder.py │ ├── ddpm3d.py │ ├── samplers │ │ ├── ddim.py │ │ └── ddim_multiplecond.py │ └── utils_diffusion.py ├── modules │ ├── attention.py │ ├── encoders │ │ ├── condition.py │ │ └── resampler.py │ ├── networks │ │ ├── ae_modules.py │ │ └── openaimodel3d.py │ └── x_transformer.py └── utils.py ├── predict.py ├── test_data ├── img01.jpg └── prompt_file.txt ├── train.py └── video_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | weights 3 | output 4 | .ipynb_checkpoints 5 | 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel 2 | 3 | RUN pip install --no-cache-dir \ 4 | einops==0.8.0 \ 5 | omegaconf==2.3.0 \ 6 | pillow \ 7 | transformers==4.46.3 \ 8 | open_clip_torch==2.22.0 \ 9 | kornia==0.7.4 \ 10 | tqdm \ 11 | timm==1.0.11 12 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tomáš Souček 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ShowHowTo: Generating Scene-Conditioned Step-by-Step Visual Instructions 2 | 3 | ### [[Project Website :dart:]](https://soczech.github.io/showhowto/)   [[Paper :page_with_curl:]](https://arxiv.org/abs/2412.01987)   [Code :octocat:] 4 | 5 | This repository contains code for the CVPR'25 paper [ShowHowTo: Generating Scene-Conditioned Step-by-Step Visual Instructions](https://arxiv.org/abs/2412.01987). 6 | 7 | 8 | ## Run the model on your images and prompts 9 | 1. **Environment setup** 10 | - Use provided `Dockerfile` to build the environment or install the [packages](https://github.com/soCzech/ShowHowTo/blob/main/Dockerfile) manually. 11 | ``` 12 | docker build -t showhowto . 
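# the next command bind-mounts the repository (so downloaded weights and outputs are visible inside the container) and exposes a single GPU via --gpus=1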
13 | docker run -it --rm -v $(pwd):$(pwd) -w $(pwd) --gpus=1 showhowto:latest bash 14 | ``` 15 | - The code, as written, requires a GPU. 16 | 17 | 2. **Download ShowHowTo model weights** 18 | - Use the `download_weights.sh` script or download the [ShowHowTo weights](https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/weights/) manually. 19 | 20 | 3. **Get predictions** 21 | - Run the following command to get example predictions. 22 | ``` 23 | python predict.py --ckpt_path ./weights/showhowto_2to8steps.pt 24 | --prompt_file ./test_data/prompt_file.txt 25 | --unconditional_guidance_scale 7.5 26 | ``` 27 | - To run the model on your images and prompts, replace `./test_data/prompt_file.txt` with your prompt file. 28 | 29 | 30 | ## Training 31 | 1. **Environment setup** 32 | - Use the same environment as for the prediction (see above). 33 | 34 | 2. **Download DynamiCrafter model weights** 35 | - Use the `download_weights.sh` script or download the [DynamiCrafter weights](https://huggingface.co/Doubiiu/DynamiCrafter/blob/main/model.ckpt) manually. 36 | 37 | 3. **Get the dataset** 38 | - To replicate our experiments, use the ShowHowTo dataset (see the Dataset section below), or use your own dataset. 39 | - The dataset must have the following directory structure. 40 | ``` 41 | dataset_root 42 | ├── prompts.json 43 | └── imgseqs 44 | ├── <video_id>.jpg 45 | │ ... 46 | └── ... 47 | ``` 48 | There can be multiple directories with names starting with `imgseqs`. 49 | - The `prompts.json` file must have the following structure. 50 | ``` 51 | { 52 | "<video_id>": ["prompt for the 1st frame", "prompt for the 2nd frame", ...], 53 | ... 54 | } 55 | ``` 56 | - The sequence image `<video_id>.jpg` must be of width `N*W` (`W` is the width of each image in the sequence) and arbitrary height `H`. 57 | The number of images in the sequence `N` must match the length of the prompt list in the `prompts.json` file. 58 | 4. **Train** 59 | - Run the training code. 60 | ``` 61 | python train.py --local_batch_size 2 62 | --dataset_root /path/to/ShowHowToTrain 63 | --ckpt_path weights/dynamicrafter_256_v1.ckpt 64 | ``` 65 | - We trained on a single node with 8 GPUs with a batch size of 2 videos per GPU. Be advised that more than 40 GB of VRAM per GPU may be required to train with a batch size larger than 1. 66 | 67 | 68 | ## Dataset 69 | You can download the ShowHowTo dataset using the `download_dataset.sh` script. To also download the image sequences from our servers, you need a username and a password. 70 | You can obtain them by sending an email to *tomas.soucek at cvut dot cz* specifying your name and affiliation. Please use your institutional email (i.e., not gmail, etc.). 71 | 72 | You can also extract the dataset from the raw original videos with the following steps. 73 | 74 | 1. **Download the HowTo100M videos and the ShowHowTo prompts** 75 | - The list of all video ids for both the train set and test set can be found [here](https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset/). 76 | - For each video, the `keyframes.json` file contains information on which video frames are part of the dataset. 77 | - There you can also find the prompts for each video in the `prompts.json` file. 78 | 2. **Extract the video frames of the ShowHowTo dataset** 79 | - To extract the frames from the videos, we used ffmpeg v7.0.1 with the following function. 
80 | ```python 81 | def extract_frame(video, start_sec, frame_idx, width, height): 82 | ffmpeg_args = ['ffmpeg', '-i', video, '-f', 'rawvideo', '-pix_fmt', 'rgb24', 83 | '-vf', f'fps=5,select=gte(t\\,{start_sec}),select=eq(n\\,{frame_idx})', 84 | '-s', f'{width}x{height}', '-vframes', '1', 'pipe:'] 85 | video_stream = subprocess.Popen(ffmpeg_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) 86 | 87 | in_bytes = video_stream.stdout.read(width * height * 3) 88 | return np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3]) 89 | ``` 90 | The function arguments are: `video` is the path to the video, `start_sec` and `frame_idx` are the values from `keyframes.json`, and `width` and `height` specify the output image size (we used the native video resolution here). 91 | 3. **Prepare the image sequences** 92 | - Concatenate all frames from a video in the horizontal dimension and place the resulting concatenated image into `dataset_root/imgseqs/<video_id>.jpg`. The `<video_id>` is the YouTube video id. A minimal example of this step is sketched in the *Example: preparing an image sequence* section at the end of this README. 93 | 94 | 95 | 96 | 97 | ## Citation 98 | ```bibtex 99 | @inproceedings{soucek2025showhowto, 100 | title={ShowHowTo: Generating Scene-Conditioned Step-by-Step Visual Instructions}, 101 | author={Sou\v{c}ek, Tom\'{a}\v{s} and Gatti, Prajwal and Wray, Michael and Laptev, Ivan and Damen, Dima and Sivic, Josef}, 102 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 103 | month = {June}, 104 | year = {2025} 105 | } 106 | ``` 107 | 108 | ## Acknowledgements 109 | 110 | The code has been adapted from the ECCV 2024 paper [DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors](https://arxiv.org/abs/2310.12190), available on [GitHub](https://github.com/Doubiiu/DynamiCrafter). Please refer to its license before use. 
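## Example: preparing an image sequence (illustrative)

The following is a minimal, unofficial sketch of step 3 of the dataset preparation above: it horizontally concatenates the extracted frames of one video into `dataset_root/imgseqs/<video_id>.jpg` and writes the matching `prompts.json` entry. The helper name `build_imgseq` and the use of Pillow are illustrative assumptions and are not part of the released code.

```python
import json
import os

from PIL import Image


def build_imgseq(frame_paths, prompts, video_id, dataset_root):
    # illustrative helper, not part of the released ShowHowTo code
    # one prompt per frame is required (the prompt list length must equal N)
    assert len(frame_paths) == len(prompts)

    frames = [Image.open(p).convert("RGB") for p in frame_paths]
    width, height = frames[0].size  # native video resolution, identical for all frames

    # the sequence image must be N*W wide and H tall (see the structure above)
    sheet = Image.new("RGB", (width * len(frames), height))
    for i, frame in enumerate(frames):
        sheet.paste(frame, (i * width, 0))
    os.makedirs(f"{dataset_root}/imgseqs", exist_ok=True)
    sheet.save(f"{dataset_root}/imgseqs/{video_id}.jpg", quality=95)

    # add or update the entry for this video in prompts.json
    prompts_path = f"{dataset_root}/prompts.json"
    try:
        with open(prompts_path) as f:
            all_prompts = json.load(f)
    except FileNotFoundError:
        all_prompts = {}
    all_prompts[video_id] = prompts
    with open(prompts_path, "w") as f:
        json.dump(all_prompts, f, indent=2)
```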
111 | 112 | -------------------------------------------------------------------------------- /configs/inference_256_v1.1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | timesteps: 1000 8 | first_stage_key: video 9 | cond_stage_key: caption 10 | cond_stage_trainable: False 11 | conditioning_key: hybrid 12 | image_size: [32, 32] 13 | channels: 4 14 | scale_by_std: False 15 | scale_factor: 0.18215 16 | use_ema: False 17 | perframe_ae: True 18 | uncond_type: 'empty_seq' 19 | unet_config: 20 | target: lvdm.modules.networks.openaimodel3d.UNetModel 21 | params: 22 | in_channels: 8 23 | out_channels: 4 24 | model_channels: 320 25 | attention_resolutions: 26 | - 4 27 | - 2 28 | - 1 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 4 34 | - 4 35 | dropout: 0.1 36 | num_head_channels: 64 37 | transformer_depth: 1 38 | context_dim: 1024 39 | use_linear: true 40 | use_checkpoint: False 41 | temporal_conv: True 42 | temporal_attention: True 43 | temporal_selfatt_only: true 44 | use_relative_position: false 45 | use_causal_attention: False 46 | temporal_length: 16 47 | addition_attention: true 48 | image_cross_attention: true 49 | image_cross_attention_scale_learnable: true 50 | default_fs: 3 51 | fs_condition: true 52 | 53 | first_stage_config: 54 | target: lvdm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: True 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 78 | params: 79 | freeze: true 80 | layer: "penultimate" 81 | 82 | img_cond_stage_config: 83 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 84 | params: 85 | freeze: true 86 | 87 | image_proj_stage_config: 88 | target: lvdm.modules.encoders.resampler.Resampler 89 | params: 90 | dim: 1024 91 | depth: 4 92 | dim_head: 64 93 | heads: 12 94 | num_queries: 16 95 | embedding_dim: 1280 96 | output_dim: 1024 97 | ff_mult: 4 98 | video_length: 16 99 | 100 | -------------------------------------------------------------------------------- /dataset_prep/llama_prompts.py: -------------------------------------------------------------------------------- 1 | def prompt_to_filter_noninstructional_videos(video_title, video_transcript, max_transcript_length=1500): 2 | """ 3 | Generate a prompt to determine if a video is instructional based on its title and transcript. 
4 | 5 | Args: 6 | video_title: The title of the video 7 | video_transcript: List of transcript segments with start, end, and text fields 8 | max_transcript_length: Maximum length of transcript to include in prompt 9 | 10 | Returns: 11 | str: Formatted prompt for video classification 12 | """ 13 | # Extract and join transcript text 14 | formatted_transcript = " ".join(item['text'] for item in video_transcript) 15 | 16 | # Increase length limit for longer videos 17 | if video_transcript[-1]['end'] > 400: 18 | max_transcript_length += 2000 19 | 20 | # Truncate transcript if needed 21 | if len(formatted_transcript) > max_transcript_length: 22 | formatted_transcript = formatted_transcript[:max_transcript_length] + "..." 23 | 24 | prompt = f""" 25 | Based on the following video title and partial transcript segment, determine if the video is instructional in nature, where "instructional" means it involves actively demonstrating or teaching how to perform a specific task or activity with physical steps (e.g., cooking a recipe, repairing something, crafting, etc.). 26 | Respond with 'Yes' if the video is actively demonstrating or teaching how to perform a specific task, or 'No' if it is not. Then provide a single sentence explanation. 27 | 28 | Examples of instructional videos: 29 | - How to Bake a Chocolate Cake 30 | - Reparing a Leaky Faucet 31 | - Learn to Knit a Scarf 32 | 33 | Examples of non-instructional videos: 34 | - Discussing Fashion Trends 35 | - Product Reviews and Opinions 36 | - A Vlog of My Daily Life 37 | 38 | Example 1: 39 | Video Title: Red Dead Redemption 2 - Herbert Moon and Strange man Eastereggs In Armadillo [SPOILERS] 40 | Partial Video Transcript Segment: 41 | "oh you're back I feared the worst it's all here waiting for you who's that I don't know it's just a little portrait somebody gave me once I always quite liked it why no reason just seem familiar anyway this area is closed to the public if you want to shop here you better act right move you long streak of piss who do you think you are for God's sake get out you degenerate you blew it get out of my store if you don't leave there will be problems okay okay stay calm oh you'll..." 42 | 43 | Is this video actively demonstrating or teaching how to perform a specific task? No 44 | Explanation: The video is not actively demonstrating or teaching how to perform a specific task; it appears to be showcasing or discussing Easter eggs in the game Red Dead Redemption 2. 45 | 46 | Example 2: 47 | Video Title: Fantastic VEGAN Cupcakes with Raspberry Frosting 48 | Partial Video Transcript Segment: 49 | "hey there I'm chef Annie and tomorrow is Valentine's Day so we are making some extra special cupcakes for this occasion can you believe that we have not made cupcakes on this channel it's about time so today I'm going to show you how to present these cupcakes so they look impressive and absolutely beautiful so enough copy let's cook it so we're going to start by mixing together our wet ingredients..." 50 | 51 | Is this video actively demonstrating or teaching how to perform a specific task? Yes 52 | Explanation: The video actively demonstrates and teaches how to make vegan cupcakes with raspberry frosting, as indicated by the detailed steps and instructions given by the chef. 
53 | 54 | Example 3: 55 | Video Title: How To: Piston Ring Install 56 | Partial Video Transcript Segment: 57 | "hey it's Matt from how to motorcycle repair comm just got done doing a top end on a YZF 250 or yz250 F and I thought I'd do a quick video on how to install a piston ring the easy way now I've done this in the past too but most people will take the ends here and spread it and put it on but you can potentially damage the ring so an easier way to do that is just to take this right here incident in the groove that you need then you bend one up..." 58 | 59 | Is this video actively demonstrating or teaching how to perform a specific task? Yes 60 | Explanation: The video is actively demonstrating or teaching how to install a piston ring, which is a specific task. 61 | 62 | Example 4: 63 | Video Title: Best gas weed eater reviews Husqvarna 128DJ with 28cc Cycle Gas Powered String Trimmer 64 | Partial Video Transcript Segment: 65 | "Video Transcript: guys i'm shanley today i'm going to tell you about this straight shaft gas-powered trimmer from husqvarna this trimmer runs on a 28 CC two cycle engine it features 1.1 horsepower and a three-piece crankshaft it also has a smart start system as well as an auto return to stop switch and this trimmer is air purge design for easier starting it has a 17 inch cutting path..." 66 | 67 | Is this video actively demonstrating or teaching how to perform a specific task? No 68 | Explanation: This video is reviewing the features of a gas-powered trimmer rather than actively demonstrating or teaching how to use it. 69 | 70 | Now, determine if the following video is instructional in nature: 71 | 72 | Video Title: {video_title} 73 | 74 | Partial Video Transcript Segment: 75 | {formatted_transcript} 76 | 77 | Is this video actively demonstrating or teaching how to perform a specific task? 78 | """ 79 | return prompt 80 | 81 | 82 | def prompt_to_generate_keysteps(video_title, video_transcript): 83 | """ 84 | Generate a prompt to extract key steps from an instructional video transcript. 85 | 86 | Args: 87 | video_title: The title of the video 88 | video_transcript: List of transcript segments with start, end, and text fields 89 | 90 | Returns: 91 | str: Formatted prompt for step extraction 92 | """ 93 | 94 | formatted_transcript = '' 95 | for item in video_transcript: 96 | formatted_transcript += f"{item['start']:.2f} - {item['end']:.2f}: \"{item['text']}\"\n" 97 | 98 | prompt = f""" 99 | Below are transcripts from YouTube instructional videos and their corresponding extracted steps in a clear, third-person, step-by-step format like WikiHow. Each step is concise, actionable, and temporally ordered as they occur in the video. The steps include start and end timestamps indicating when the steps are carried out in the video. Follow this format to extract and summarize the key steps from the provided transcript. 100 | 101 | Example 1: 102 | YouTube Video Title: 103 | "BÁNH TÁO MINI - How To Make Apple Turnovers | Episode 11 | Taste From Home" 104 | 105 | YouTube Video Transcript: 106 | [start of transcript] 107 | [start of transcript] 108 | 0.87 - 7.79: "Hey little muffins, today we will make together a super easy, quick and delicious apple turnovers." 109 | 7.79 - 9.35: "40 minutes for all the process." 110 | 9.35 - 11.95: "Seriously, can someone deny them?" 111 | 11.95 - 13.63: "Ok, let's begin." 112 | 13.63 - 18.82: "First of all, combine the apple cubes, lemon juice, cinnamon and sugar in a bowl." 113 | 26.69 - 29.59: " Mixing, mixing, mixing." 
114 | 29.59 - 32.62: "Apple and cinnamon always go perfectly together." 115 | 32.62 - 43.52: "Now using a round cutter or glass like me, cut 15 rounds from the pastry sheet." 116 | 57.86 - 64.99: " Here comes the fun part." 117 | 64.99 - 69.97: "Spoon about 2 teaspoons apple mixture in the center of one round." 118 | 69.97 - 74.41: "Using your fingers, gently fold the pastry over to enclose filling." 119 | 88.47 - 104.48: " After that, use a fork and press around the edges to seal and make your apple turnovers look more beautiful." 120 | 104.48 - 105.84: "This is how it looks like." 121 | 109.99 - 113.53: " I will show you one more time to make sure that you understand the technique." 122 | 113.53 - 117.20: "And if you still find my apple turnovers too ugly, I'm really sorry." 123 | 117.20 - 121.66: "Anyways, just have fun making them with your family and friends." 124 | 121.66 - 124.43: "Then it doesn't matter if they are beautiful or not." 125 | 124.43 - 125.45: "At least they are tasty." 126 | 127.35 - 133.35: " I finally finished my 15 turnovers." 127 | 151.62 - 157.46: " Now we will lightly beat one egg in a small bowl and then egg wash our apple turnovers." 128 | 157.46 - 164.10: "The egg wash will give them a gorgeous light brown color after baking in the oven." 129 | 164.10 - 174.87: "We will bake the apple turnovers at 180°C in about 18-20 minutes until golden." 130 | 174.87 - 177.35: "Your kitchen should smell amazing by now." 131 | 178.17 - 181.33: " Okay, taking photos like this is like a torture." 132 | 181.33 - 185.55: "I just can't resist grabbing one of them and enjoying immediately." 133 | 185.55 - 186.17: "Yay!" 134 | 186.17 - 189.00: "So finally I can try my apple turnovers." 135 | 189.00 - 201.34: "I'm really really excited." 136 | 201.34 - 203.22: "Oh my god!" 137 | 203.22 - 204.46: "This is really really good." 138 | 204.46 - 207.26: "You guys should definitely try out this recipe." 139 | 208.69 - 215.49: " So I hope you enjoyed this video and I will see you in my next video for a new recipe." 140 | 215.49 - 216.03: "Bye guys!" 
141 | [end of transcript] 142 | 143 | Extracted Steps: 144 | [ 145 | {{ "WikiHow Title": "How to Make Apple Turnovers" }}, 146 | {{ "steps": [ 147 | {{ "step": 1, "instruction": "Combine apple cubes, lemon juice, cinnamon, and sugar in a bowl.", "start_timestamp": 13.63, "end_timestamp": 18.82 }}, 148 | {{ "step": 2, "instruction": "Mix the ingredients thoroughly.", "start_timestamp": 26.69, "end_timestamp": 29.59 }}, 149 | {{ "step": 3, "instruction": "Cut 15 rounds from the pastry sheet using a round cutter or a glass.", "start_timestamp": 32.62, "end_timestamp": 43.52 }}, 150 | {{ "step": 4, "instruction": "Spoon about 2 teaspoons of the apple mixture into the center of one round.", "start_timestamp": 64.99, "end_timestamp": 69.97 }}, 151 | {{ "step": 5, "instruction": "Gently fold the pastry over to enclose the filling using your fingers.", "start_timestamp": 69.97, "end_timestamp": 74.41 }}, 152 | {{ "step": 6, "instruction": "Press around the edges with a fork to seal and beautify the turnovers.", "start_timestamp": 88.47, "end_timestamp": 104.48 }}, 153 | {{ "step": 7, "instruction": "Repeat the technique until all turnovers are formed.", "start_timestamp": 109.99, "end_timestamp": 113.53 }}, 154 | {{ "step": 8, "instruction": "Lightly beat one egg in a small bowl.", "start_timestamp": 151.62, "end_timestamp": 157.46 }}, 155 | {{ "step": 9, "instruction": "Egg wash the apple turnovers to give them a gorgeous light brown color after baking.", "start_timestamp": 157.46, "end_timestamp": 164.10 }}, 156 | {{ "step": 10, "instruction": "Bake the apple turnovers at 180°C for 18-20 minutes until golden.", "start_timestamp": 164.10, "end_timestamp": 174.87 }}, 157 | {{ "step": 11, "instruction": "Enjoy the freshly baked apple turnovers.", "start_timestamp": 178.17, "end_timestamp": 185.55 }} 158 | ] 159 | }} 160 | ] 161 | 162 | Example 2: 163 | YouTube Video Title: 164 | "How to Clean Your Car Tires & Whitewalls" 165 | 166 | YouTube Video Transcript: 167 | [start of transcript] 168 | 0.17 - 5.42: " Cleaning your tires properly requires more than just using car wash soap and then rinsing it down." 169 | 5.42 - 11.84: "Today I want to show you the proper way to clean your tires to give it that nice new natural darkened look." 170 | 17.69 - 22.31: " Now, although you may be sitting there thinking to yourself, do I really need a specific tire cleaner?" 171 | 22.31 - 25.17: "Well, if you're an auto detailer, we highly recommend it." 172 | 25.17 - 30.74: "But even if you're a car enthusiast, there's a few reasons why this type of product can come in handy." 173 | 30.74 - 39.78: "Now, our tire and white wall cleaner is super concentrated, making it aggressive enough to really break down any of that gunk that gets stuck to the rubber of your tire." 174 | 39.78 - 45.30: "Not to mention, can safely and really help you clean any of those white walls or white letters." 175 | 45.80 - 55.29: " One ingredient that really makes this product unique is that it actually contains darkening agents that really help to bring out that dark, rich color of the rubber." 176 | 55.29 - 59.41: "This can also help with assisting when doing your tire dressing." 177 | 59.41 - 65.24: "It allows the tire dressing to set up properly, giving you a nice shine, and helps it to last even longer." 178 | 65.62 - 75.37: " Now, the first thing we want to do here is not just rinse down the tire, but as well as some of the surrounding areas like the wheel, wheel well, as well as the body of the vehicle here." 
179 | 95.34 - 103.47: " Now that we've thoroughly rinsed down the surrounding areas as well as the tire, we can go ahead and take our tire and whitewall and spray an even mist all on the tire." 180 | 109.80 - 113.99: " Now, what we want to do is we want to allow the product to set up for about 30 seconds." 181 | 113.99 - 120.05: "This is really going to help to break down any of that gunk or dirt that is just stuck onto the rubber of the tire here." 182 | 120.05 - 125.30: "I mean, as you can already see, just by applying it on within a few seconds here, it's already turning brown." 183 | 125.30 - 129.98: "So it's showing that the product is activating, pulling off any of that dirt." 184 | 129.98 - 132.70: "So now we've allowed the product to set up for about 30 seconds or so." 185 | 132.70 - 137.59: "We can go ahead, take our brush, dip it in a bucket of water, and then begin with the scrubbing process." 186 | 156.52 - 163.69: " Seeing all these brown suds just like this is exactly what you want to see when using this product." 187 | 163.69 - 173.68: "Again that's showing that it's activated and it's really breaking down everything that is just on the rubber of this tire." 188 | 173.68 - 176.08: "Now we've scrubbed it we can go ahead and rinse it right down." 189 | 196.36 - 208.82: " So as you guys can see, it's honestly just as simple and as easy as that, with just a little bit of product, a little bit of scrubbing, and then a simple rinse, we were able to really bring back that dark, rich color to the rubber of this tire once again." 190 | 208.82 - 222.94: "So if this video was helpful, make sure you definitely give it a thumbs up, subscribe to our channel for more product and how-to videos, and to order your very own Tire and White Wall Cleaner, make sure you visit us at detailking.com, where we have everything you need to keep your car clean like a detail king." 191 | 222.94 - 223.84: "We'll see you guys next time." 
192 | [end of transcript] 193 | 194 | Extracted Steps: 195 | [ 196 | {{ "WikiHow Title": "How to Clean Your Car Tires & Whitewalls"}}, 197 | {{ "steps": [ 198 | {{ "step": 1, "instruction": "Rinse down the tire and surrounding areas such as the wheel, wheel well, and the body of the vehicle.", "start_timestamp": 65.62, "end_timestamp": 75.37 }}, 199 | {{ "step": 2, "instruction": "Spray an even mist of tire and whitewall cleaner all over the tire.", "start_timestamp": 95.34, "end_timestamp": 103.47 }}, 200 | {{ "step": 3, "instruction": "Allow the cleaner to set up for about 30 seconds to break down dirt and gunk on the tire.", "start_timestamp": 109.80, "end_timestamp": 120.05 }}, 201 | {{ "step": 4, "instruction": "Observe the cleaner turning brown, indicating it is working.", "start_timestamp": 120.05, "end_timestamp": 125.30 }}, 202 | {{ "step": 6, "instruction": "Dip the brush in a bucket of water and begin scrubbing the tire.", "start_timestamp": 132.70, "end_timestamp": 137.59 }}, 203 | {{ "step": 7, "instruction": "Check for brown suds while scrubbing, which shows the cleaner is breaking down dirt.", "start_timestamp": 156.52, "end_timestamp": 163.69 }}, 204 | {{ "step": 8, "instruction": "Rinse the tire thoroughly to remove all cleaner and dirt.", "start_timestamp": 173.68, "end_timestamp": 176.08 }}, 205 | {{ "step": 9, "instruction": "Enjoy the clean, dark, and rich color of your tire's rubber.", "start_timestamp": 196.36, "end_timestamp": 208.82 }} 206 | ] 207 | }} 208 | ] 209 | 210 | Now, extract the steps from the following transcript: 211 | 212 | YouTube Video Title: 213 | {video_title} 214 | 215 | YouTube Video Transcript: 216 | [start of transcript] 217 | {formatted_transcript} 218 | [end of transcript] 219 | 220 | Extracted Steps: 221 | ```json 222 | """ 223 | return prompt 224 | 225 | 226 | if __name__ == "__main__": 227 | # An example video and its transcript 228 | example_title = "How to Baked a Potato in the pressure cooker." 229 | example_transcript = [ 230 | { 231 | "start": 0.504, 232 | "end": 2.605, 233 | "text": " Hi, it's Matthew in his pressure cooker again." 234 | }, 235 | { 236 | "start": 2.605, 237 | "end": 6.086, 238 | "text": "And today I'm going to make baked potatoes in the pressure cooker." 239 | }, 240 | { 241 | "start": 6.086, 242 | "end": 11.708, 243 | "text": "Obviously they're not baked potatoes, just an easy way to make something like a baked potato." 244 | }, 245 | { 246 | "start": 11.708, 247 | "end": 13.289, 248 | "text": "It's basically going to be steaming them." 249 | }, 250 | { 251 | "start": 13.289, 252 | "end": 18.051, 253 | "text": "So I use my aluminum foil trick." 254 | }, 255 | { 256 | "start": 18.051, 257 | "end": 19.532, 258 | "text": "Put some aluminum foil in there." 259 | }, 260 | { 261 | "start": 19.532, 262 | "end": 27.835, 263 | "text": "Now a lot of pressure cookers do come with standoffs so that you can put things in where they're sitting above the water." 264 | }, 265 | { 266 | "start": 28.857, 267 | "end": 29.698, 268 | "text": " I'm going to make two little rings." 269 | }, 270 | { 271 | "start": 29.698, 272 | "end": 38.503, 273 | "text": "I'm going to take my potatoes, put them in so that they're going to be steamed nicely." 274 | }, 275 | { 276 | "start": 38.503, 277 | "end": 40.424, 278 | "text": "Because you don't want to actually boil them." 279 | }, 280 | { 281 | "start": 40.424, 282 | "end": 42.285, 283 | "text": "You want to steam them." 
284 | }, 285 | { 286 | "start": 42.285, 287 | "end": 44.987, 288 | "text": "And I've got my cup and a half of water." 289 | }, 290 | { 291 | "start": 44.987, 292 | "end": 50.23, 293 | "text": "I'll just dump that in." 294 | }, 295 | { 296 | "start": 50.23, 297 | "end": 52.272, 298 | "text": "See what I've got in here?" 299 | }, 300 | { 301 | "start": 52.272, 302 | "end": 53.472, 303 | "text": "So they're just sitting there." 304 | }, 305 | { 306 | "start": 53.472, 307 | "end": 54.313, 308 | "text": "They're above the water." 309 | }, 310 | { 311 | "start": 54.313, 312 | "end": 57.495, 313 | "text": "Now, the amount of time varies quite a bit." 314 | }, 315 | { 316 | "start": 59.853, 317 | "end": 62.054, 318 | "text": " For these ones I'm going to put them in for about 15 minutes." 319 | }, 320 | { 321 | "start": 62.054, 322 | "end": 63.755, 323 | "text": "Smaller potatoes maybe a little bit less." 324 | }, 325 | { 326 | "start": 63.755, 327 | "end": 65.016, 328 | "text": "Larger potatoes a little bit more." 329 | }, 330 | { 331 | "start": 65.016, 332 | "end": 68.518, 333 | "text": "The nice thing about this is it doesn't tie up anything else." 334 | }, 335 | { 336 | "start": 68.518, 337 | "end": 69.898, 338 | "text": "It doesn't heat up your kitchen." 339 | }, 340 | { 341 | "start": 69.898, 342 | "end": 73.881, 343 | "text": "It's really good for in the summer when you want to have something like a baked potato." 344 | }, 345 | { 346 | "start": 73.881, 347 | "end": 75.481, 348 | "text": "It's just going to be steamed." 349 | }, 350 | { 351 | "start": 75.481, 352 | "end": 76.082, 353 | "text": "So there it is." 354 | }, 355 | { 356 | "start": 76.082, 357 | "end": 76.942, 358 | "text": "15 minutes." 359 | }, 360 | { 361 | "start": 76.942, 362 | "end": 80.484, 363 | "text": "When they're done you just pop them out and eat them like a normal baked potato." 364 | }, 365 | { 366 | "start": 80.484, 367 | "end": 81.465, 368 | "text": "I hope you find this useful." 369 | }, 370 | { 371 | "start": 81.465, 372 | "end": 87.928, 373 | "text": "If you want to hear more ideas or have any questions leave a comment, send me an email and I'll see what I can do for you." 374 | }, 375 | { 376 | "start": 89.452, 377 | "end": 92.674, 378 | "text": " I hope you're enjoying your pressure cooker as much as I'm enjoying mine." 379 | }, 380 | { 381 | "start": 92.674, 382 | "end": 93.915, 383 | "text": "Bye." 384 | }, 385 | { 386 | "start": 93.915, 387 | "end": 96.837, 388 | "text": "And here we go." 389 | }, 390 | { 391 | "start": 96.837, 392 | "end": 101.699, 393 | "text": "I'll pull it out." 394 | }, 395 | { 396 | "start": 101.699, 397 | "end": 104.601, 398 | "text": "There is the baked potato or steamed potato actually." 399 | }, 400 | { 401 | "start": 104.601, 402 | "end": 107.523, 403 | "text": "It's nice and it's pretty dry on the outside." 404 | }, 405 | { 406 | "start": 108.55, 407 | "end": 110.371, 408 | "text": " Just cut it open." 409 | }, 410 | { 411 | "start": 110.371, 412 | "end": 112.431, 413 | "text": "It's nice and soft." 414 | }, 415 | { 416 | "start": 112.431, 417 | "end": 115.072, 418 | "text": "Well cooked on the inside." 419 | }, 420 | { 421 | "start": 115.072, 422 | "end": 117.412, 423 | "text": "Great to cook it up however you're going to make a meal." 424 | }, 425 | { 426 | "start": 117.412, 427 | "end": 119.353, 428 | "text": "If you're going to eat it like a traditional baked potato." 
429 | }, 430 | { 431 | "start": 119.353, 432 | "end": 121.633, 433 | "text": "If you're going to use it for potato salad or whatever." 434 | }, 435 | { 436 | "start": 121.633, 437 | "end": 123.794, 438 | "text": "I hope you enjoyed it." 439 | }, 440 | { 441 | "start": 123.794, 442 | "end": 126.755, 443 | "text": "If you want to see any other ideas, check my channel." 444 | }, 445 | { 446 | "start": 126.755, 447 | "end": 128.355, 448 | "text": "See what other things I've got posted." 449 | }, 450 | { 451 | "start": 128.355, 452 | "end": 134.157, 453 | "text": "If you've got ideas that you don't know how to do, send me an email or leave a comment and I'll see what I can do." 454 | }, 455 | { 456 | "start": 134.157, 457 | "end": 134.697, 458 | "text": "Hope you enjoyed it." 459 | }, 460 | { 461 | "start": 134.697, 462 | "end": 135.037, 463 | "text": "Bye." 464 | } 465 | ] 466 | 467 | print("=== Llama 3 prompt to filter non-instructional videos ===") 468 | print(prompt_to_filter_noninstructional_videos(example_title, example_transcript)) 469 | 470 | print("\n" + "=" * 80 + "\n") 471 | 472 | print("=== Llama 3 prompt to generate keysteps ===") 473 | print(prompt_to_generate_keysteps(example_title, example_transcript)) 474 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading ShowHowTo dataset to './data'." 2 | 3 | echo -n "Downloading ShowHowTo test set ... " 4 | mkdir -p data/ShowHowToTest 5 | cd data/ShowHowToTest || exit 1 6 | wget https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset/ShowHowToTest.prompts.tar.gz --no-check-certificate -q 7 | wget https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset/ShowHowToTest.keyframes.tar.gz --no-check-certificate -q 8 | 9 | tar xf ShowHowToTest.prompts.tar.gz && rm ShowHowToTest.prompts.tar.gz && \ 10 | tar xf ShowHowToTest.keyframes.tar.gz && rm ShowHowToTest.keyframes.tar.gz && \ 11 | echo "OK ✓" || echo "ERROR ✗" 12 | 13 | cd ../.. 14 | 15 | echo -n "Downloading ShowHowTo train set ... " 16 | mkdir -p data/ShowHowToTrain 17 | cd data/ShowHowToTrain || exit 1 18 | wget https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset/ShowHowToTrain.prompts.tar.gz --no-check-certificate -q 19 | wget https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset/ShowHowToTrain.keyframes.tar.gz --no-check-certificate -q 20 | 21 | tar xf ShowHowToTrain.prompts.tar.gz && rm ShowHowToTrain.prompts.tar.gz && \ 22 | tar xf ShowHowToTrain.keyframes.tar.gz && rm ShowHowToTrain.keyframes.tar.gz && \ 23 | echo "OK ✓" || echo "ERROR ✗" 24 | 25 | cd ../.. 26 | cd data 27 | 28 | echo -n 'Do you want to also download image sequences? Approximately 200GB will be downloaded, password is required. (y/n) ' 29 | read -r answer 30 | if [[ "${answer}" != "${answer#[Yy]}" ]];then 31 | echo -n 'Username: ' 32 | read -r username 33 | echo -n 'Password: ' 34 | read -r password 35 | 36 | echo -n "Downloading ShowHowTo test set image sequences ... " 37 | wget --user="${username}" --password="${password}" https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset_sequences/ShowHowToTest.images.tar --no-check-certificate -q 38 | tar xf ShowHowToTest.images.tar && rm ShowHowToTest.images.tar && echo "OK ✓" || echo "ERROR ✗" 39 | 40 | echo -n "Downloading ShowHowTo train set image sequences ... 
" 41 | wget --user="${username}" --password="${password}" https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/dataset_sequences/ShowHowToTrain.images.tar --no-check-certificate -q 42 | tar xf ShowHowToTrain.images.tar && rm ShowHowToTrain.images.tar && echo "OK ✓" || echo "ERROR ✗" 43 | fi 44 | -------------------------------------------------------------------------------- /download_weights.sh: -------------------------------------------------------------------------------- 1 | mkdir -p weights 2 | 3 | echo -n 'Do you want to use our trained model (y/n)? ' 4 | read -r answer 5 | if [[ "${answer}" != "${answer#[Yy]}" ]];then 6 | echo -n "Downloading ShowHowTo weights ... " 7 | wget https://data.ciirc.cvut.cz/public/projects/2024ShowHowTo/weights/showhowto_2to8steps.pt -O weights/showhowto_2to8steps.pt --no-check-certificate 8 | SHA256SUM=$(sha256sum weights/showhowto_2to8steps.pt | cut -d' ' -f1) 9 | if [[ ${SHA256SUM} == "5759609fde82dc394a3e9872f145c50bed229d9d22d24dd682065e4e724ac47c" ]]; then 10 | echo "OK ✓" 11 | else 12 | echo "ERROR ✗" 13 | exit 1 14 | fi 15 | fi 16 | 17 | echo -n 'Do you want to train your own model (y/n)? ' 18 | read -r answer 19 | if [[ "${answer}" != "${answer#[Yy]}" ]];then 20 | echo -n "Downloading DynamiCrafter weights ... " 21 | wget https://huggingface.co/Doubiiu/DynamiCrafter/resolve/main/model.ckpt -q -O ./weights/dynamicrafter_256_v1.ckpt 22 | SHA256SUM=$(sha256sum weights/dynamicrafter_256_v1.ckpt | cut -d' ' -f1) 23 | if [[ ${SHA256SUM} == "328d23963f1fe5af1324793117dfa80c8f5d3d31a2a7d5a6089a1c8aa72fb2da" ]]; then 24 | echo "OK ✓" 25 | else 26 | echo "ERROR ✗" 27 | exit 1 28 | fi 29 | fi 30 | -------------------------------------------------------------------------------- /lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from lvdm.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 
59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 
54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | 81 | def checkpoint(func, inputs, params, flag): 82 | from torch.utils import checkpoint 83 | """ 84 | Evaluate a function without caching intermediate activations, allowing for 85 | reduced memory at the expense of extra compute in the backward pass. 86 | :param func: the function to evaluate. 87 | :param inputs: the argument sequence to pass to `func`. 88 | :param params: a sequence of parameters `func` depends on but does not 89 | explicitly take as arguments. 90 | :param flag: if False, disable gradient checkpointing. 91 | """ 92 | if flag: 93 | return checkpoint.checkpoint(func, *inputs, use_reentrant=False) 94 | else: 95 | return func(*inputs) -------------------------------------------------------------------------------- /lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. 
After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | from lvdm.modules.networks.ae_modules import Encoder, Decoder 8 | from lvdm.distributions import DiagonalGaussianDistribution 9 | from lvdm.utils import instantiate_from_config 10 | 11 | 12 | class AutoencoderKL(torch.nn.Module): 13 | def __init__(self, 14 | ddconfig, 15 | lossconfig, 16 | embed_dim, 17 | ckpt_path=None, 18 | ignore_keys=[], 19 | image_key="image", 20 | colorize_nlabels=None, 21 | monitor=None, 22 | test=False, 23 | logdir=None, 24 | input_dim=4, 25 | test_args=None, 26 | ): 27 | super().__init__() 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | self.input_dim = input_dim 37 | self.test = test 38 | self.test_args = test_args 39 | self.logdir = logdir 40 | if colorize_nlabels is not None: 41 | assert type(colorize_nlabels)==int 42 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 43 | if monitor is not None: 44 | self.monitor = monitor 45 | if ckpt_path is not None: 46 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 47 | if self.test: 48 | self.init_test() 49 | 50 | def init_test(self,): 51 | self.test = True 52 | save_dir = os.path.join(self.logdir, "test") 53 | if 'ckpt' in self.test_args: 54 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 55 | self.root = os.path.join(save_dir, ckpt_name) 56 | else: 57 | self.root = save_dir 58 | if 'test_subdir' in self.test_args: 59 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 60 | 61 | self.root_zs = os.path.join(self.root, "zs") 62 | self.root_dec = os.path.join(self.root, "reconstructions") 63 | self.root_inputs = os.path.join(self.root, "inputs") 64 | os.makedirs(self.root, exist_ok=True) 65 | 66 | if self.test_args.save_z: 67 | os.makedirs(self.root_zs, exist_ok=True) 68 | if self.test_args.save_reconstruction: 69 | os.makedirs(self.root_dec, exist_ok=True) 70 | if self.test_args.save_input: 71 | os.makedirs(self.root_inputs, exist_ok=True) 72 | assert(self.test_args is not None) 73 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) 74 | self.count = 0 75 | self.eval_metrics = {} 76 | self.decodes = [] 77 | self.save_decode_samples = 2048 78 | 79 | def init_from_ckpt(self, path, ignore_keys=list()): 80 | sd = torch.load(path, map_location="cpu") 81 | try: 82 | self._cur_epoch = sd['epoch'] 83 | sd = sd["state_dict"] 84 | except: 85 | self._cur_epoch = 'null' 86 | keys = list(sd.keys()) 87 | for k in keys: 88 | for ik in ignore_keys: 89 | if k.startswith(ik): 90 | print("Deleting key {} from 
state_dict.".format(k)) 91 | del sd[k] 92 | self.load_state_dict(sd, strict=False) 93 | # self.load_state_dict(sd, strict=True) 94 | print(f"Restored from {path}") 95 | 96 | def encode(self, x, **kwargs): 97 | 98 | h = self.encoder(x) 99 | moments = self.quant_conv(h) 100 | posterior = DiagonalGaussianDistribution(moments) 101 | return posterior 102 | 103 | def decode(self, z, **kwargs): 104 | z = self.post_quant_conv(z) 105 | dec = self.decoder(z) 106 | return dec 107 | 108 | def forward(self, input, sample_posterior=True): 109 | posterior = self.encode(input) 110 | if sample_posterior: 111 | z = posterior.sample() 112 | else: 113 | z = posterior.mode() 114 | dec = self.decode(z) 115 | return dec, posterior 116 | 117 | def get_input(self, batch, k): 118 | x = batch[k] 119 | if x.dim() == 5 and self.input_dim == 4: 120 | b,c,t,h,w = x.shape 121 | self.b = b 122 | self.t = t 123 | x = rearrange(x, 'b c t h w -> (b t) c h w') 124 | 125 | return x 126 | 127 | def training_step(self, batch, batch_idx, optimizer_idx): 128 | inputs = self.get_input(batch, self.image_key) 129 | reconstructions, posterior = self(inputs) 130 | 131 | if optimizer_idx == 0: 132 | # train encoder+decoder+logvar 133 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 134 | last_layer=self.get_last_layer(), split="train") 135 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 136 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 137 | return aeloss 138 | 139 | if optimizer_idx == 1: 140 | # train the discriminator 141 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 142 | last_layer=self.get_last_layer(), split="train") 143 | 144 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 145 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 146 | return discloss 147 | 148 | def validation_step(self, batch, batch_idx): 149 | inputs = self.get_input(batch, self.image_key) 150 | reconstructions, posterior = self(inputs) 151 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 152 | last_layer=self.get_last_layer(), split="val") 153 | 154 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 155 | last_layer=self.get_last_layer(), split="val") 156 | 157 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 158 | self.log_dict(log_dict_ae) 159 | self.log_dict(log_dict_disc) 160 | return self.log_dict 161 | 162 | def configure_optimizers(self): 163 | lr = self.learning_rate 164 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 165 | list(self.decoder.parameters())+ 166 | list(self.quant_conv.parameters())+ 167 | list(self.post_quant_conv.parameters()), 168 | lr=lr, betas=(0.5, 0.9)) 169 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 170 | lr=lr, betas=(0.5, 0.9)) 171 | return [opt_ae, opt_disc], [] 172 | 173 | def get_last_layer(self): 174 | return self.decoder.conv_out.weight 175 | 176 | @torch.no_grad() 177 | def log_images(self, batch, only_inputs=False, **kwargs): 178 | log = dict() 179 | x = self.get_input(batch, self.image_key) 180 | x = x.to(self.device) 181 | if not only_inputs: 182 | xrec, posterior = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec.shape[1] > 3 186 | x = self.to_rgb(x) 187 | xrec = self.to_rgb(xrec) 
188 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 189 | log["reconstructions"] = xrec 190 | log["inputs"] = x 191 | return log 192 | 193 | def to_rgb(self, x): 194 | assert self.image_key == "segmentation" 195 | if not hasattr(self, "colorize"): 196 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 197 | x = F.conv2d(x, weight=self.colorize) 198 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 199 | return x 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from lvdm.common import noise_like 6 | from lvdm.common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | **kwargs 87 | ): 88 | 89 | # check condition bs 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | try: 93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 94 | except: 95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 96 | 97 | if cbs != batch_size: 98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 104 | 105 | # make shape 106 | if len(shape) == 3: 107 | C, H, W = shape 108 | size = (batch_size, C, H, W) 109 | elif len(shape) == 4: 110 | C, T, H, W = shape 111 | size = (batch_size, C, T, H, W) 112 | 113 | samples, intermediates = self.ddim_sampling(conditioning, size, 114 | callback=callback, 115 | img_callback=img_callback, 116 | quantize_denoised=quantize_x0, 117 | mask=mask, x0=x0, 118 | ddim_use_original_steps=False, 119 | noise_dropout=noise_dropout, 120 | temperature=temperature, 121 | score_corrector=score_corrector, 122 | corrector_kwargs=corrector_kwargs, 123 | x_T=x_T, 124 | log_every_t=log_every_t, 125 | unconditional_guidance_scale=unconditional_guidance_scale, 126 | unconditional_conditioning=unconditional_conditioning, 127 | verbose=verbose, 128 | precision=precision, 129 | fs=fs, 130 | guidance_rescale=guidance_rescale, 131 | **kwargs) 132 | return samples, intermediates 133 | 134 | @torch.no_grad() 135 | def ddim_sampling(self, cond, shape, 136 | x_T=None, ddim_use_original_steps=False, 137 | callback=None, timesteps=None, quantize_denoised=False, 138 | mask=None, x0=None, img_callback=None, log_every_t=100, 139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 140 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 141 | **kwargs): 142 | device = self.model.betas.device 143 | b = 
shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | if precision is not None: 149 | if precision == 16: 150 | img = img.to(dtype=torch.float16) 151 | 152 | if timesteps is None: 153 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 154 | elif timesteps is not None and not ddim_use_original_steps: 155 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 156 | timesteps = self.ddim_timesteps[:subset_end] 157 | 158 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 161 | if verbose: 162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 163 | else: 164 | iterator = time_range 165 | 166 | clean_cond = kwargs.pop("clean_cond", False) 167 | 168 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 169 | for i, step in enumerate(iterator): 170 | index = total_steps - i - 1 171 | ts = torch.full((b,), step, device=device, dtype=torch.long) 172 | 173 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img) 174 | if mask is not None: 175 | assert x0 is not None 176 | if clean_cond: 177 | img_orig = x0 178 | else: 179 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 180 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 181 | 182 | 183 | 184 | 185 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 186 | quantize_denoised=quantize_denoised, temperature=temperature, 187 | noise_dropout=noise_dropout, score_corrector=score_corrector, 188 | corrector_kwargs=corrector_kwargs, 189 | unconditional_guidance_scale=unconditional_guidance_scale, 190 | unconditional_conditioning=unconditional_conditioning, 191 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 192 | **kwargs) 193 | 194 | 195 | img, pred_x0 = outs 196 | if callback: callback(i) 197 | if img_callback: img_callback(pred_x0, i) 198 | 199 | if index % log_every_t == 0 or index == total_steps - 1: 200 | intermediates['x_inter'].append(img) 201 | intermediates['pred_x0'].append(pred_x0) 202 | 203 | return img, intermediates 204 | 205 | @torch.no_grad() 206 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 207 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 208 | unconditional_guidance_scale=1., unconditional_conditioning=None, 209 | uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs): 210 | b, *_, device = *x.shape, x.device 211 | if x.dim() == 5: 212 | is_video = True 213 | else: 214 | is_video = False 215 | 216 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 217 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 218 | else: 219 | ### do_classifier_free_guidance 220 | if isinstance(c, torch.Tensor) or isinstance(c, dict): 221 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 222 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 223 | else: 224 | raise NotImplementedError 225 | 226 | model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond) 227 | 228 | 
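# The line above is standard classifier-free guidance:
#   eps = eps_uncond + s * (eps_cond - eps_uncond).
# When guidance_rescale > 0, rescale_noise_cfg (see lvdm/models/utils_diffusion.py)
# additionally rescales the guided prediction following arXiv:2305.08891 to reduce
# the over-exposure caused by large guidance scales.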
if guidance_rescale > 0.0: 229 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 230 | 231 | if self.model.parameterization == "v": 232 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 233 | else: 234 | e_t = model_output 235 | 236 | if score_corrector is not None: 237 | assert self.model.parameterization == "eps", 'not implemented' 238 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 239 | 240 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 241 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 242 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 243 | # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 244 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 245 | # select parameters corresponding to the currently considered timestep 246 | 247 | if is_video: 248 | size = (b, 1, 1, 1, 1) 249 | else: 250 | size = (b, 1, 1, 1) 251 | a_t = torch.full(size, alphas[index], device=device) 252 | a_prev = torch.full(size, alphas_prev[index], device=device) 253 | sigma_t = torch.full(size, sigmas[index], device=device) 254 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 255 | 256 | # current prediction for x_0 257 | if self.model.parameterization != "v": 258 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 259 | else: 260 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 261 | 262 | if self.model.use_dynamic_rescale: 263 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 264 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 265 | rescale = (prev_scale_t / scale_t) 266 | pred_x0 *= rescale 267 | 268 | if quantize_denoised: 269 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 270 | # direction pointing to x_t 271 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 272 | 273 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 274 | if noise_dropout > 0.: 275 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 276 | 277 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 278 | 279 | return x_prev, pred_x0 280 | 281 | @torch.no_grad() 282 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 283 | use_original_steps=False, callback=None): 284 | 285 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 286 | timesteps = timesteps[:t_start] 287 | 288 | time_range = np.flip(timesteps) 289 | total_steps = timesteps.shape[0] 290 | print(f"Running DDIM Sampling with {total_steps} timesteps") 291 | 292 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 293 | x_dec = x_latent 294 | for i, step in enumerate(iterator): 295 | index = total_steps - i - 1 296 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 297 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 298 | unconditional_guidance_scale=unconditional_guidance_scale, 299 | unconditional_conditioning=unconditional_conditioning) 300 | if callback: callback(i) 301 | return x_dec 302 | 303 | @torch.no_grad() 304 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 305 | # fast, but does not allow for exact reconstruction 306 | # t serves as an index to gather the correct alphas 307 | if use_original_steps: 308 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 309 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 310 | else: 311 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 312 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 313 | 314 | if noise is None: 315 | noise = torch.randn_like(x0) 316 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 317 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 318 | -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim_multiplecond.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from lvdm.common import noise_like 6 | from lvdm.common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: 
x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
87 | **kwargs 88 | ): 89 | 90 | # check condition bs 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | try: 94 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 95 | except: 96 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 97 | 98 | if cbs != batch_size: 99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 100 | else: 101 | if conditioning.shape[0] != batch_size: 102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 103 | 104 | # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale) 105 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 106 | 107 | # make shape 108 | if len(shape) == 3: 109 | C, H, W = shape 110 | size = (batch_size, C, H, W) 111 | elif len(shape) == 4: 112 | C, T, H, W = shape 113 | size = (batch_size, C, T, H, W) 114 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 115 | 116 | samples, intermediates = self.ddim_sampling(conditioning, size, 117 | callback=callback, 118 | img_callback=img_callback, 119 | quantize_denoised=quantize_x0, 120 | mask=mask, x0=x0, 121 | ddim_use_original_steps=False, 122 | noise_dropout=noise_dropout, 123 | temperature=temperature, 124 | score_corrector=score_corrector, 125 | corrector_kwargs=corrector_kwargs, 126 | x_T=x_T, 127 | log_every_t=log_every_t, 128 | unconditional_guidance_scale=unconditional_guidance_scale, 129 | unconditional_conditioning=unconditional_conditioning, 130 | verbose=verbose, 131 | precision=precision, 132 | fs=fs, 133 | guidance_rescale=guidance_rescale, 134 | **kwargs) 135 | return samples, intermediates 136 | 137 | @torch.no_grad() 138 | def ddim_sampling(self, cond, shape, 139 | x_T=None, ddim_use_original_steps=False, 140 | callback=None, timesteps=None, quantize_denoised=False, 141 | mask=None, x0=None, img_callback=None, log_every_t=100, 142 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 143 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 144 | **kwargs): 145 | device = self.model.betas.device 146 | b = shape[0] 147 | if x_T is None: 148 | img = torch.randn(shape, device=device) 149 | else: 150 | img = x_T 151 | if precision is not None: 152 | if precision == 16: 153 | img = img.to(dtype=torch.float16) 154 | 155 | 156 | if timesteps is None: 157 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 158 | elif timesteps is not None and not ddim_use_original_steps: 159 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 160 | timesteps = self.ddim_timesteps[:subset_end] 161 | 162 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 163 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 164 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 165 | if verbose: 166 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 167 | else: 168 | iterator = time_range 169 | 170 | clean_cond = kwargs.pop("clean_cond", False) 171 | 172 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 173 | for i, step in enumerate(iterator): 174 | index = total_steps - i - 1 175 | ts = torch.full((b,), step, device=device, dtype=torch.long) 176 | 177 | ## use mask to blend noised original 
latent (img_orig) & new sampled latent (img) 178 | if mask is not None: 179 | assert x0 is not None 180 | if clean_cond: 181 | img_orig = x0 182 | else: 183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 184 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 185 | 186 | 187 | 188 | 189 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 190 | quantize_denoised=quantize_denoised, temperature=temperature, 191 | noise_dropout=noise_dropout, score_corrector=score_corrector, 192 | corrector_kwargs=corrector_kwargs, 193 | unconditional_guidance_scale=unconditional_guidance_scale, 194 | unconditional_conditioning=unconditional_conditioning, 195 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 196 | **kwargs) 197 | 198 | 199 | 200 | img, pred_x0 = outs 201 | if callback: callback(i) 202 | if img_callback: img_callback(pred_x0, i) 203 | 204 | if index % log_every_t == 0 or index == total_steps - 1: 205 | intermediates['x_inter'].append(img) 206 | intermediates['pred_x0'].append(pred_x0) 207 | 208 | return img, intermediates 209 | 210 | @torch.no_grad() 211 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 212 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 213 | unconditional_guidance_scale=1., unconditional_conditioning=None, 214 | uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs): 215 | b, *_, device = *x.shape, x.device 216 | if x.dim() == 5: 217 | is_video = True 218 | else: 219 | is_video = False 220 | if cfg_img is None: 221 | cfg_img = unconditional_guidance_scale 222 | 223 | unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext'] 224 | 225 | 226 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 227 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 228 | else: 229 | ### with unconditional condition 230 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 232 | e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs) 233 | # text cfg 234 | model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img) 235 | if guidance_rescale > 0.0: 236 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 237 | 238 | if self.model.parameterization == "v": 239 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 240 | else: 241 | e_t = model_output 242 | 243 | if score_corrector is not None: 244 | assert self.model.parameterization == "eps", 'not implemented' 245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 246 | 247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 250 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 251 | # select parameters corresponding to the currently considered timestep 252 | 253 | if is_video: 254 | size = (b, 1, 1, 1, 1) 255 | else: 256 | size = (b, 1, 1, 1) 257 | a_t = 
torch.full(size, alphas[index], device=device) 258 | a_prev = torch.full(size, alphas_prev[index], device=device) 259 | sigma_t = torch.full(size, sigmas[index], device=device) 260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 261 | 262 | # current prediction for x_0 263 | if self.model.parameterization != "v": 264 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 265 | else: 266 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 267 | 268 | if self.model.use_dynamic_rescale: 269 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 270 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 271 | rescale = (prev_scale_t / scale_t) 272 | pred_x0 *= rescale 273 | 274 | if quantize_denoised: 275 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 276 | # direction pointing to x_t 277 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 278 | 279 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 280 | if noise_dropout > 0.: 281 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 282 | 283 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 284 | 285 | return x_prev, pred_x0 286 | 287 | @torch.no_grad() 288 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 289 | use_original_steps=False, callback=None): 290 | 291 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 292 | timesteps = timesteps[:t_start] 293 | 294 | time_range = np.flip(timesteps) 295 | total_steps = timesteps.shape[0] 296 | print(f"Running DDIM Sampling with {total_steps} timesteps") 297 | 298 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 299 | x_dec = x_latent 300 | for i, step in enumerate(iterator): 301 | index = total_steps - i - 1 302 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 303 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 304 | unconditional_guidance_scale=unconditional_guidance_scale, 305 | unconditional_conditioning=unconditional_conditioning) 306 | if callback: callback(i) 307 | return x_dec 308 | 309 | @torch.no_grad() 310 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 311 | # fast, but does not allow for exact reconstruction 312 | # t serves as an index to gather the correct alphas 313 | if use_original_steps: 314 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 315 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 316 | else: 317 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 318 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 319 | 320 | if noise is None: 321 | noise = torch.randn_like(x0) 322 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 323 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) -------------------------------------------------------------------------------- /lvdm/models/utils_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import repeat 6 | 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 
12 | These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 16 | """ 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | else: 27 | embedding = repeat(timesteps, 'b -> b d', d=dim) 28 | return embedding 29 | 30 | 31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 32 | if schedule == "linear": 33 | betas = ( 34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 35 | ) 36 | 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 40 | ) 41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 42 | alphas = torch.cos(alphas).pow(2) 43 | alphas = alphas / alphas[0] 44 | betas = 1 - alphas[1:] / alphas[:-1] 45 | betas = np.clip(betas, a_min=0, a_max=0.999) 46 | 47 | elif schedule == "sqrt_linear": 48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | elif schedule == "sqrt": 50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 57 | if ddim_discr_method == 'uniform': 58 | c = num_ddpm_timesteps // num_ddim_timesteps 59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 60 | steps_out = ddim_timesteps + 1 61 | elif ddim_discr_method == 'uniform_trailing': 62 | c = num_ddpm_timesteps / num_ddim_timesteps 63 | ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64) 64 | steps_out = ddim_timesteps - 1 65 | elif ddim_discr_method == 'quad': 66 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 67 | steps_out = ddim_timesteps + 1 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | # steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 87 | if verbose: 88 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 89 | print(f'For the chosen value of eta, which is {eta}, ' 90 | 
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 91 | return sigmas, alphas, alphas_prev 92 | 93 | 94 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 95 | """ 96 | Create a beta schedule that discretizes the given alpha_t_bar function, 97 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 98 | :param num_diffusion_timesteps: the number of betas to produce. 99 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 100 | produces the cumulative product of (1-beta) up to that 101 | part of the diffusion process. 102 | :param max_beta: the maximum beta to use; use values lower than 1 to 103 | prevent singularities. 104 | """ 105 | betas = [] 106 | for i in range(num_diffusion_timesteps): 107 | t1 = i / num_diffusion_timesteps 108 | t2 = (i + 1) / num_diffusion_timesteps 109 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 110 | return np.array(betas) 111 | 112 | def rescale_zero_terminal_snr(betas): 113 | """ 114 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 115 | 116 | Args: 117 | betas (`numpy.ndarray`): 118 | the betas that the scheduler is being initialized with. 119 | 120 | Returns: 121 | `numpy.ndarray`: rescaled betas with zero terminal SNR 122 | """ 123 | # Convert betas to alphas_bar_sqrt 124 | alphas = 1.0 - betas 125 | alphas_cumprod = np.cumprod(alphas, axis=0) 126 | alphas_bar_sqrt = np.sqrt(alphas_cumprod) 127 | 128 | # Store old values. 129 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() 130 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() 131 | 132 | # Shift so the last timestep is zero. 133 | alphas_bar_sqrt -= alphas_bar_sqrt_T 134 | 135 | # Scale so the first timestep is back to the old value. 136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 137 | 138 | # Convert alphas_bar_sqrt to betas 139 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 140 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 141 | alphas = np.concatenate([alphas_bar[0:1], alphas]) 142 | betas = 1 - alphas 143 | 144 | return betas 145 | 146 | 147 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 148 | """ 149 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 150 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 151 | """ 152 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 153 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 154 | # rescale the results from guidance (fixes overexposure) 155 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 156 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 157 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 158 | return noise_cfg -------------------------------------------------------------------------------- /lvdm/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from functools import partial 6 | try: 7 | import xformers 8 | import xformers.ops 9 | XFORMERS_IS_AVAILBLE = True 10 | except: 11 | XFORMERS_IS_AVAILBLE = False 12 | from lvdm.common import ( 13 | checkpoint, 14 | exists, 15 | default, 16 | ) 17 | from lvdm.basics import zero_module 18 | 19 | 20 | class RelativePosition(nn.Module): 21 | """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ 22 | 23 | def __init__(self, num_units, max_relative_position): 24 | super().__init__() 25 | self.num_units = num_units 26 | self.max_relative_position = max_relative_position 27 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) 28 | nn.init.xavier_uniform_(self.embeddings_table) 29 | 30 | def forward(self, length_q, length_k): 31 | device = self.embeddings_table.device 32 | range_vec_q = torch.arange(length_q, device=device) 33 | range_vec_k = torch.arange(length_k, device=device) 34 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 35 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 36 | final_mat = distance_mat_clipped + self.max_relative_position 37 | final_mat = final_mat.long() 38 | embeddings = self.embeddings_table[final_mat] 39 | return embeddings 40 | 41 | 42 | class CrossAttention(nn.Module): 43 | 44 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., 45 | relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | context_dim = default(context_dim, query_dim) 49 | 50 | self.scale = dim_head**-0.5 51 | self.heads = heads 52 | self.dim_head = dim_head 53 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 54 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 55 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 56 | 57 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 58 | 59 | self.relative_position = relative_position 60 | if self.relative_position: 61 | assert(temporal_length is not None) 62 | self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 63 | self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 64 | else: 65 | ## only used for spatial attention, while NOT for temporal attention 66 | if XFORMERS_IS_AVAILBLE and temporal_length is None: 67 | self.forward = self.efficient_forward 68 
| 69 | self.video_length = video_length 70 | self.image_cross_attention = image_cross_attention 71 | self.image_cross_attention_scale = image_cross_attention_scale 72 | self.text_context_len = text_context_len 73 | self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable 74 | if self.image_cross_attention: 75 | self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) 76 | self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) 77 | if image_cross_attention_scale_learnable: 78 | self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) ) 79 | 80 | 81 | def forward(self, x, context=None, mask=None): 82 | spatial_self_attn = (context is None) 83 | k_ip, v_ip, out_ip = None, None, None 84 | 85 | h = self.heads 86 | q = self.to_q(x) 87 | context = default(context, x) 88 | 89 | if self.image_cross_attention and not spatial_self_attn: 90 | context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 91 | k = self.to_k(context) 92 | v = self.to_v(context) 93 | k_ip = self.to_k_ip(context_image) 94 | v_ip = self.to_v_ip(context_image) 95 | else: 96 | if not spatial_self_attn: 97 | context = context[:,:self.text_context_len,:] 98 | k = self.to_k(context) 99 | v = self.to_v(context) 100 | 101 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 102 | 103 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 104 | if self.relative_position: 105 | len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] 106 | k2 = self.relative_position_k(len_q, len_k) 107 | sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check 108 | sim += sim2 109 | del k 110 | 111 | if exists(mask): 112 | ## feasible for causal attention mask only 113 | max_neg_value = -torch.finfo(sim.dtype).max 114 | mask = repeat(mask, 'b i j -> (b h) i j', h=h) 115 | sim.masked_fill_(~(mask>0.5), max_neg_value) 116 | 117 | # attention, what we cannot get enough of 118 | sim = sim.softmax(dim=-1) 119 | 120 | out = torch.einsum('b i j, b j d -> b i d', sim, v) 121 | if self.relative_position: 122 | v2 = self.relative_position_v(len_q, len_v) 123 | out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check 124 | out += out2 125 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 126 | 127 | 128 | ## for image cross-attention 129 | if k_ip is not None: 130 | k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) 131 | sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale 132 | del k_ip 133 | sim_ip = sim_ip.softmax(dim=-1) 134 | out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) 135 | out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) 136 | 137 | 138 | if out_ip is not None: 139 | if self.image_cross_attention_scale_learnable: 140 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) 141 | else: 142 | out = out + self.image_cross_attention_scale * out_ip 143 | 144 | return self.to_out(out) 145 | 146 | def efficient_forward(self, x, context=None, mask=None): 147 | spatial_self_attn = (context is None) 148 | k_ip, v_ip, out_ip = None, None, None 149 | 150 | q = self.to_q(x) 151 | context = default(context, x) 152 | 153 | if self.image_cross_attention and not spatial_self_attn: 154 | context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 155 | k = self.to_k(context) 156 | v = self.to_v(context) 157 | k_ip = self.to_k_ip(context_image) 158 | v_ip = self.to_v_ip(context_image) 159 | 
else: 160 | if not spatial_self_attn: 161 | context = context[:,:self.text_context_len,:] 162 | k = self.to_k(context) 163 | v = self.to_v(context) 164 | 165 | b, _, _ = q.shape 166 | q, k, v = map( 167 | lambda t: t.unsqueeze(3) 168 | .reshape(b, t.shape[1], self.heads, self.dim_head) 169 | .permute(0, 2, 1, 3) 170 | .reshape(b * self.heads, t.shape[1], self.dim_head) 171 | .contiguous(), 172 | (q, k, v), 173 | ) 174 | # actually compute the attention, what we cannot get enough of 175 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) 176 | 177 | ## for image cross-attention 178 | if k_ip is not None: 179 | k_ip, v_ip = map( 180 | lambda t: t.unsqueeze(3) 181 | .reshape(b, t.shape[1], self.heads, self.dim_head) 182 | .permute(0, 2, 1, 3) 183 | .reshape(b * self.heads, t.shape[1], self.dim_head) 184 | .contiguous(), 185 | (k_ip, v_ip), 186 | ) 187 | out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) 188 | out_ip = ( 189 | out_ip.unsqueeze(0) 190 | .reshape(b, self.heads, out.shape[1], self.dim_head) 191 | .permute(0, 2, 1, 3) 192 | .reshape(b, out.shape[1], self.heads * self.dim_head) 193 | ) 194 | 195 | if exists(mask): 196 | raise NotImplementedError 197 | out = ( 198 | out.unsqueeze(0) 199 | .reshape(b, self.heads, out.shape[1], self.dim_head) 200 | .permute(0, 2, 1, 3) 201 | .reshape(b, out.shape[1], self.heads * self.dim_head) 202 | ) 203 | if out_ip is not None: 204 | if self.image_cross_attention_scale_learnable: 205 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) 206 | else: 207 | out = out + self.image_cross_attention_scale * out_ip 208 | 209 | return self.to_out(out) 210 | 211 | 212 | class BasicTransformerBlock(nn.Module): 213 | 214 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 215 | disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77): 216 | super().__init__() 217 | attn_cls = CrossAttention if attention_cls is None else attention_cls 218 | self.disable_self_attn = disable_self_attn 219 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 220 | context_dim=context_dim if self.disable_self_attn else None) 221 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 222 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len) 223 | self.image_cross_attention = image_cross_attention 224 | 225 | self.norm1 = nn.LayerNorm(dim) 226 | self.norm2 = nn.LayerNorm(dim) 227 | self.norm3 = nn.LayerNorm(dim) 228 | self.checkpoint = checkpoint 229 | 230 | 231 | def forward(self, x, context=None, mask=None, **kwargs): 232 | ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. 
None or scalar) arguments 233 | input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments 234 | if context is not None: 235 | input_tuple = (x, context) 236 | if mask is not None: 237 | forward_mask = partial(self._forward, mask=mask) 238 | return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) 239 | return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) 240 | 241 | 242 | def _forward(self, x, context=None, mask=None): 243 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x 244 | x = self.attn2(self.norm2(x), context=context, mask=mask) + x 245 | x = self.ff(self.norm3(x)) + x 246 | return x 247 | 248 | 249 | class SpatialTransformer(nn.Module): 250 | """ 251 | Transformer block for image-like data in spatial axis. 252 | First, project the input (aka embedding) 253 | and reshape to b, t, d. 254 | Then apply standard transformer action. 255 | Finally, reshape to image 256 | NEW: use_linear for more efficiency instead of the 1x1 convs 257 | """ 258 | 259 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 260 | use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None, 261 | image_cross_attention=False, image_cross_attention_scale_learnable=False): 262 | super().__init__() 263 | self.in_channels = in_channels 264 | inner_dim = n_heads * d_head 265 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 266 | if not use_linear: 267 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 268 | else: 269 | self.proj_in = nn.Linear(in_channels, inner_dim) 270 | 271 | attention_cls = None 272 | self.transformer_blocks = nn.ModuleList([ 273 | BasicTransformerBlock( 274 | inner_dim, 275 | n_heads, 276 | d_head, 277 | dropout=dropout, 278 | context_dim=context_dim, 279 | disable_self_attn=disable_self_attn, 280 | checkpoint=use_checkpoint, 281 | attention_cls=attention_cls, 282 | video_length=video_length, 283 | image_cross_attention=image_cross_attention, 284 | image_cross_attention_scale_learnable=image_cross_attention_scale_learnable, 285 | ) for d in range(depth) 286 | ]) 287 | if not use_linear: 288 | self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 289 | else: 290 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 291 | self.use_linear = use_linear 292 | 293 | 294 | def forward(self, x, context=None, **kwargs): 295 | b, c, h, w = x.shape 296 | x_in = x 297 | x = self.norm(x) 298 | if not self.use_linear: 299 | x = self.proj_in(x) 300 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 301 | if self.use_linear: 302 | x = self.proj_in(x) 303 | for i, block in enumerate(self.transformer_blocks): 304 | x = block(x, context=context, **kwargs) 305 | if self.use_linear: 306 | x = self.proj_out(x) 307 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 308 | if not self.use_linear: 309 | x = self.proj_out(x) 310 | return x + x_in 311 | 312 | 313 | class TemporalTransformer(nn.Module): 314 | """ 315 | Transformer block for image-like data in temporal axis. 316 | First, reshape to b, t, d. 317 | Then apply standard transformer action. 
318 | Finally, reshape to image 319 | """ 320 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 321 | use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1, 322 | relative_position=False, temporal_length=None): 323 | super().__init__() 324 | self.only_self_att = only_self_att 325 | self.relative_position = relative_position 326 | self.causal_attention = causal_attention 327 | self.causal_block_size = causal_block_size 328 | 329 | self.in_channels = in_channels 330 | inner_dim = n_heads * d_head 331 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 332 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 333 | if not use_linear: 334 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 335 | else: 336 | self.proj_in = nn.Linear(in_channels, inner_dim) 337 | 338 | if relative_position: 339 | assert(temporal_length is not None) 340 | attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length) 341 | else: 342 | attention_cls = partial(CrossAttention, temporal_length=temporal_length) 343 | if self.causal_attention: 344 | assert(temporal_length is not None) 345 | self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) 346 | 347 | if self.only_self_att: 348 | context_dim = None 349 | self.transformer_blocks = nn.ModuleList([ 350 | BasicTransformerBlock( 351 | inner_dim, 352 | n_heads, 353 | d_head, 354 | dropout=dropout, 355 | context_dim=context_dim, 356 | attention_cls=attention_cls, 357 | checkpoint=use_checkpoint) for d in range(depth) 358 | ]) 359 | if not use_linear: 360 | self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 361 | else: 362 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 363 | self.use_linear = use_linear 364 | 365 | def forward(self, x, context=None): 366 | b, c, t, h, w = x.shape 367 | x_in = x 368 | x = self.norm(x) 369 | x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() 370 | if not self.use_linear: 371 | x = self.proj_in(x) 372 | x = rearrange(x, 'bhw c t -> bhw t c').contiguous() 373 | if self.use_linear: 374 | x = self.proj_in(x) 375 | 376 | temp_mask = None 377 | if self.causal_attention: 378 | # slice the from mask map 379 | temp_mask = self.mask[:,:t,:t].to(x.device) 380 | 381 | if temp_mask is not None: 382 | mask = temp_mask.to(x.device) 383 | mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) 384 | else: 385 | mask = None 386 | 387 | if self.only_self_att: 388 | ## note: if no context is given, cross-attention defaults to self-attention 389 | for i, block in enumerate(self.transformer_blocks): 390 | x = block(x, mask=mask) 391 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 392 | else: 393 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 394 | context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() 395 | for i, block in enumerate(self.transformer_blocks): 396 | # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) 397 | for j in range(b): 398 | context_j = repeat( 399 | context[j], 400 | 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() 401 | ## note: causal mask will not applied in cross-attention case 402 | x[j] = block(x[j], context=context_j) 403 | 404 | if self.use_linear: 405 | x = self.proj_out(x) 406 | x = rearrange(x, 
'b (h w) t c -> b c t h w', h=h, w=w).contiguous() 407 | if not self.use_linear: 408 | x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() 409 | x = self.proj_out(x) 410 | x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() 411 | 412 | return x + x_in 413 | 414 | 415 | class GEGLU(nn.Module): 416 | def __init__(self, dim_in, dim_out): 417 | super().__init__() 418 | self.proj = nn.Linear(dim_in, dim_out * 2) 419 | 420 | def forward(self, x): 421 | x, gate = self.proj(x).chunk(2, dim=-1) 422 | return x * F.gelu(gate) 423 | 424 | 425 | class FeedForward(nn.Module): 426 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 427 | super().__init__() 428 | inner_dim = int(dim * mult) 429 | dim_out = default(dim_out, dim) 430 | project_in = nn.Sequential( 431 | nn.Linear(dim, inner_dim), 432 | nn.GELU() 433 | ) if not glu else GEGLU(dim, inner_dim) 434 | 435 | self.net = nn.Sequential( 436 | project_in, 437 | nn.Dropout(dropout), 438 | nn.Linear(inner_dim, dim_out) 439 | ) 440 | 441 | def forward(self, x): 442 | return self.net(x) 443 | 444 | 445 | class LinearAttention(nn.Module): 446 | def __init__(self, dim, heads=4, dim_head=32): 447 | super().__init__() 448 | self.heads = heads 449 | hidden_dim = dim_head * heads 450 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 451 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 452 | 453 | def forward(self, x): 454 | b, c, h, w = x.shape 455 | qkv = self.to_qkv(x) 456 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 457 | k = k.softmax(dim=-1) 458 | context = torch.einsum('bhdn,bhen->bhde', k, v) 459 | out = torch.einsum('bhde,bhdn->bhen', context, q) 460 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 461 | return self.to_out(out) 462 | 463 | 464 | class SpatialSelfAttention(nn.Module): 465 | def __init__(self, in_channels): 466 | super().__init__() 467 | self.in_channels = in_channels 468 | 469 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 470 | self.q = torch.nn.Conv2d(in_channels, 471 | in_channels, 472 | kernel_size=1, 473 | stride=1, 474 | padding=0) 475 | self.k = torch.nn.Conv2d(in_channels, 476 | in_channels, 477 | kernel_size=1, 478 | stride=1, 479 | padding=0) 480 | self.v = torch.nn.Conv2d(in_channels, 481 | in_channels, 482 | kernel_size=1, 483 | stride=1, 484 | padding=0) 485 | self.proj_out = torch.nn.Conv2d(in_channels, 486 | in_channels, 487 | kernel_size=1, 488 | stride=1, 489 | padding=0) 490 | 491 | def forward(self, x): 492 | h_ = x 493 | h_ = self.norm(h_) 494 | q = self.q(h_) 495 | k = self.k(h_) 496 | v = self.v(h_) 497 | 498 | # compute attention 499 | b,c,h,w = q.shape 500 | q = rearrange(q, 'b c h w -> b (h w) c') 501 | k = rearrange(k, 'b c h w -> b c (h w)') 502 | w_ = torch.einsum('bij,bjk->bik', q, k) 503 | 504 | w_ = w_ * (int(c)**(-0.5)) 505 | w_ = torch.nn.functional.softmax(w_, dim=2) 506 | 507 | # attend to values 508 | v = rearrange(v, 'b c h w -> b c (h w)') 509 | w_ = rearrange(w_, 'b i j -> b j i') 510 | h_ = torch.einsum('bij,bjk->bik', v, w_) 511 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 512 | h_ = self.proj_out(h_) 513 | 514 | return x+h_ 515 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | 
import open_clip 5 | from torch.utils.checkpoint import checkpoint 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | from lvdm.common import autocast 8 | 9 | 10 | class AbstractEncoder(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def encode(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | 18 | class IdentityEncoder(AbstractEncoder): 19 | def encode(self, x): 20 | return x 21 | 22 | 23 | class ClassEmbedder(nn.Module): 24 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 25 | super().__init__() 26 | self.key = key 27 | self.embedding = nn.Embedding(n_classes, embed_dim) 28 | self.n_classes = n_classes 29 | self.ucg_rate = ucg_rate 30 | 31 | def forward(self, batch, key=None, disable_dropout=False): 32 | if key is None: 33 | key = self.key 34 | # this is for use in crossattn 35 | c = batch[key][:, None] 36 | if self.ucg_rate > 0. and not disable_dropout: 37 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 38 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 39 | c = c.long() 40 | c = self.embedding(c) 41 | return c 42 | 43 | def get_unconditional_conditioning(self, bs, device="cuda"): 44 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 45 | uc = torch.ones((bs,), device=device) * uc_class 46 | uc = {self.key: uc} 47 | return uc 48 | 49 | 50 | def disabled_train(self, mode=True): 51 | """Overwrite model.train with this function to make sure train/eval mode 52 | does not change anymore.""" 53 | return self 54 | 55 | 56 | class FrozenT5Embedder(AbstractEncoder): 57 | """Uses the T5 transformer encoder for text""" 58 | 59 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 60 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
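# 77 mirrors the CLIP text context length; T5 itself uses relative position
# embeddings, so this cap is a tokenizer-side truncation choice rather than a
# hard model limit.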
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | # self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | 96 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 97 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 98 | super().__init__() 99 | assert layer in self.LAYERS 100 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 101 | self.transformer = CLIPTextModel.from_pretrained(version) 102 | self.device = device 103 | self.max_length = max_length 104 | if freeze: 105 | self.freeze() 106 | self.layer = layer 107 | self.layer_idx = layer_idx 108 | if layer == "hidden": 109 | assert layer_idx is not None 110 | assert 0 <= abs(layer_idx) <= 12 111 | 112 | def freeze(self): 113 | self.transformer = self.transformer.eval() 114 | # self.train = disabled_train 115 | for param in self.parameters(): 116 | param.requires_grad = False 117 | 118 | def forward(self, text): 119 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 120 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 121 | tokens = batch_encoding["input_ids"].to(self.device) 122 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 123 | if self.layer == "last": 124 | z = outputs.last_hidden_state 125 | elif self.layer == "pooled": 126 | z = outputs.pooler_output[:, None, :] 127 | else: 128 | z = outputs.hidden_states[self.layer_idx] 129 | return z 130 | 131 | def encode(self, text): 132 | return self(text) 133 | 134 | 135 | class ClipImageEmbedder(nn.Module): 136 | def __init__( 137 | self, 138 | model, 139 | jit=False, 140 | device='cuda' if torch.cuda.is_available() else 'cpu', 141 | antialias=True, 142 | ucg_rate=0. 143 | ): 144 | super().__init__() 145 | from clip import load as load_clip 146 | self.model, _ = load_clip(name=model, device=device, jit=jit) 147 | 148 | self.antialias = antialias 149 | 150 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 151 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 152 | self.ucg_rate = ucg_rate 153 | 154 | def preprocess(self, x): 155 | # normalize to [0,1] 156 | x = kornia.geometry.resize(x, (224, 224), 157 | interpolation='bicubic', align_corners=True, 158 | antialias=self.antialias) 159 | x = (x + 1.) / 2. 160 | # re-normalize according to clip 161 | x = kornia.enhance.normalize(x, self.mean, self.std) 162 | return x 163 | 164 | def forward(self, x, no_dropout=False): 165 | # x is assumed to be in range [-1,1] 166 | out = self.model.encode_image(self.preprocess(x)) 167 | out = out.to(x.dtype) 168 | if self.ucg_rate > 0. 
and not no_dropout: 169 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 170 | return out 171 | 172 | 173 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 174 | """ 175 | Uses the OpenCLIP transformer encoder for text 176 | """ 177 | LAYERS = [ 178 | # "pooled", 179 | "last", 180 | "penultimate" 181 | ] 182 | 183 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 184 | freeze=True, layer="last"): 185 | super().__init__() 186 | assert layer in self.LAYERS 187 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 188 | del model.visual 189 | self.model = model 190 | 191 | self.device = device 192 | self.max_length = max_length 193 | if freeze: 194 | self.freeze() 195 | self.layer = layer 196 | if self.layer == "last": 197 | self.layer_idx = 0 198 | elif self.layer == "penultimate": 199 | self.layer_idx = 1 200 | else: 201 | raise NotImplementedError() 202 | 203 | def freeze(self): 204 | self.model = self.model.eval() 205 | for param in self.parameters(): 206 | param.requires_grad = False 207 | 208 | def forward(self, text): 209 | tokens = open_clip.tokenize(text) ## all clip models use 77 as context length 210 | z = self.encode_with_transformer(tokens.to(self.device)) 211 | return z 212 | 213 | def encode_with_transformer(self, text): 214 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 215 | x = x + self.model.positional_embedding 216 | x = x.permute(1, 0, 2) # NLD -> LND 217 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 218 | x = x.permute(1, 0, 2) # LND -> NLD 219 | x = self.model.ln_final(x) 220 | return x 221 | 222 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 223 | for i, r in enumerate(self.model.transformer.resblocks): 224 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 225 | break 226 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 227 | x = checkpoint(r, x, attn_mask) 228 | else: 229 | x = r(x, attn_mask=attn_mask) 230 | return x 231 | 232 | def encode(self, text): 233 | return self(text) 234 | 235 | 236 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 237 | """ 238 | Uses the OpenCLIP vision transformer encoder for images 239 | """ 240 | 241 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 242 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 243 | super().__init__() 244 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 245 | pretrained=version, ) 246 | del model.transformer 247 | self.model = model 248 | # self.mapper = torch.nn.Linear(1280, 1024) 249 | self.device = device 250 | self.max_length = max_length 251 | if freeze: 252 | self.freeze() 253 | self.layer = layer 254 | if self.layer == "penultimate": 255 | raise NotImplementedError() 256 | self.layer_idx = 1 257 | 258 | self.antialias = antialias 259 | 260 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 261 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 262 | self.ucg_rate = ucg_rate 263 | 264 | def preprocess(self, x): 265 | # normalize to [0,1] 266 | x = kornia.geometry.resize(x, (224, 224), 267 | interpolation='bicubic', align_corners=True, 268 | antialias=self.antialias) 269 | x = (x + 1.) / 2. 
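        # the input image is expected in [-1, 1] (the diffusion-side
        # convention); the line above maps it to [0, 1] so that the CLIP
        # mean/std normalization below matches the statistics used at
        # pre-training time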
270 | # renormalize according to clip 271 | x = kornia.enhance.normalize(x, self.mean, self.std) 272 | return x 273 | 274 | def freeze(self): 275 | self.model = self.model.eval() 276 | for param in self.model.parameters(): 277 | param.requires_grad = False 278 | 279 | @autocast 280 | def forward(self, image, no_dropout=False): 281 | z = self.encode_with_vision_transformer(image) 282 | if self.ucg_rate > 0. and not no_dropout: 283 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 284 | return z 285 | 286 | def encode_with_vision_transformer(self, img): 287 | img = self.preprocess(img) 288 | x = self.model.visual(img) 289 | return x 290 | 291 | def encode(self, text): 292 | return self(text) 293 | 294 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 295 | """ 296 | Uses the OpenCLIP vision transformer encoder for images 297 | """ 298 | 299 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 300 | freeze=True, layer="pooled", antialias=True): 301 | super().__init__() 302 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 303 | pretrained=version, ) 304 | del model.transformer 305 | self.model = model 306 | self.device = device 307 | 308 | if freeze: 309 | self.freeze() 310 | self.layer = layer 311 | if self.layer == "penultimate": 312 | raise NotImplementedError() 313 | self.layer_idx = 1 314 | 315 | self.antialias = antialias 316 | 317 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 318 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 319 | 320 | 321 | def preprocess(self, x): 322 | # normalize to [0,1] 323 | x = kornia.geometry.resize(x, (224, 224), 324 | interpolation='bicubic', align_corners=True, 325 | antialias=self.antialias) 326 | x = (x + 1.) / 2. 
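        # preprocessing is identical to FrozenOpenCLIPImageEmbedder above; this
        # V2 variant differs in encode_with_vision_transformer below, which
        # keeps the full sequence of visual tokens (CLS + patches) instead of
        # the pooled embedding, so a downstream projection/resampling module
        # can attend to per-patch image features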
327 | # renormalize according to clip 328 | x = kornia.enhance.normalize(x, self.mean, self.std) 329 | return x 330 | 331 | def freeze(self): 332 | self.model = self.model.eval() 333 | for param in self.model.parameters(): 334 | param.requires_grad = False 335 | 336 | def forward(self, image, no_dropout=False): 337 | ## image: b c h w 338 | z = self.encode_with_vision_transformer(image) 339 | return z 340 | 341 | def encode_with_vision_transformer(self, x): 342 | x = self.preprocess(x) 343 | 344 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 345 | if self.model.visual.input_patchnorm: 346 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 347 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 348 | x = x.permute(0, 2, 4, 1, 3, 5) 349 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 350 | x = self.model.visual.patchnorm_pre_ln(x) 351 | x = self.model.visual.conv1(x) 352 | else: 353 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 354 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 355 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 356 | 357 | # class embeddings and positional embeddings 358 | x = torch.cat( 359 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 360 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 361 | x = x + self.model.visual.positional_embedding.to(x.dtype) 362 | 363 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 364 | x = self.model.visual.patch_dropout(x) 365 | x = self.model.visual.ln_pre(x) 366 | 367 | x = x.permute(1, 0, 2) # NLD -> LND 368 | x = self.model.visual.transformer(x) 369 | x = x.permute(1, 0, 2) # LND -> NLD 370 | 371 | return x 372 | 373 | class FrozenCLIPT5Encoder(AbstractEncoder): 374 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 375 | clip_max_length=77, t5_max_length=77): 376 | super().__init__() 377 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 378 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 379 | # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 380 | # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 381 | 382 | def encode(self, text): 383 | return self(text) 384 | 385 | def forward(self, text): 386 | clip_z = self.clip_encoder.encode(text) 387 | t5_z = self.t5_encoder.encode(text) 388 | return [clip_z, t5_z] 389 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, 
cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | 
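        # shape sketch (illustrative only; numbers assume a ViT-H/14 image
        # tower with 257 visual tokens of width 1280 and the frame-wise variant):
        #   resampler = Resampler(dim=1024, num_queries=16, embedding_dim=1280,
        #                         output_dim=1024, video_length=8)
        #   out = resampler(img_tokens)   # (B, 257, 1280) -> (B, 8*16, 1024)
        # i.e. a fixed budget of num_queries latents per frame cross-attends to
        # the image features and is projected to the conditioning width below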
self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /lvdm/modules/networks/openaimodel3d.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from abc import abstractmethod 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | from lvdm.models.utils_diffusion import timestep_embedding 8 | from lvdm.common import checkpoint 9 | from lvdm.basics import ( 10 | zero_module, 11 | conv_nd, 12 | linear, 13 | avg_pool_nd, 14 | normalization 15 | ) 16 | from lvdm.modules.attention import SpatialTransformer, TemporalTransformer 17 | 18 | 19 | class TimestepBlock(nn.Module): 20 | """ 21 | Any module where forward() takes timestep embeddings as a second argument. 22 | """ 23 | @abstractmethod 24 | def forward(self, x, emb): 25 | """ 26 | Apply the module to `x` given `emb` timestep embeddings. 27 | """ 28 | 29 | 30 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 31 | """ 32 | A sequential module that passes timestep embeddings to the children that 33 | support it as an extra input. 34 | """ 35 | 36 | def forward(self, x, emb, context=None, batch_size=None): 37 | for layer in self: 38 | if isinstance(layer, TimestepBlock): 39 | x = layer(x, emb, batch_size=batch_size) 40 | elif isinstance(layer, SpatialTransformer): 41 | x = layer(x, context) 42 | elif isinstance(layer, TemporalTransformer): 43 | x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size) 44 | x = layer(x, context) 45 | x = rearrange(x, 'b c f h w -> (b f) c h w') 46 | else: 47 | x = layer(x) 48 | return x 49 | 50 | 51 | class Downsample(nn.Module): 52 | """ 53 | A downsampling layer with an optional convolution. 54 | :param channels: channels in the inputs and outputs. 55 | :param use_conv: a bool determining if a convolution is applied. 56 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 57 | downsampling occurs in the inner-two dimensions. 58 | """ 59 | 60 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 61 | super().__init__() 62 | self.channels = channels 63 | self.out_channels = out_channels or channels 64 | self.use_conv = use_conv 65 | self.dims = dims 66 | stride = 2 if dims != 3 else (1, 2, 2) 67 | if use_conv: 68 | self.op = conv_nd( 69 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 70 | ) 71 | else: 72 | assert self.channels == self.out_channels 73 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 74 | 75 | def forward(self, x): 76 | assert x.shape[1] == self.channels 77 | return self.op(x) 78 | 79 | 80 | class Upsample(nn.Module): 81 | """ 82 | An upsampling layer with an optional convolution. 
83 | :param channels: channels in the inputs and outputs. 84 | :param use_conv: a bool determining if a convolution is applied. 85 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 86 | upsampling occurs in the inner-two dimensions. 87 | """ 88 | 89 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 90 | super().__init__() 91 | self.channels = channels 92 | self.out_channels = out_channels or channels 93 | self.use_conv = use_conv 94 | self.dims = dims 95 | if use_conv: 96 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) 97 | 98 | def forward(self, x): 99 | assert x.shape[1] == self.channels 100 | if self.dims == 3: 101 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') 102 | else: 103 | x = F.interpolate(x, scale_factor=2, mode='nearest') 104 | if self.use_conv: 105 | x = self.conv(x) 106 | return x 107 | 108 | 109 | class ResBlock(TimestepBlock): 110 | """ 111 | A residual block that can optionally change the number of channels. 112 | :param channels: the number of input channels. 113 | :param emb_channels: the number of timestep embedding channels. 114 | :param dropout: the rate of dropout. 115 | :param out_channels: if specified, the number of out channels. 116 | :param use_conv: if True and out_channels is specified, use a spatial 117 | convolution instead of a smaller 1x1 convolution to change the 118 | channels in the skip connection. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. 120 | :param up: if True, use this block for upsampling. 121 | :param down: if True, use this block for downsampling. 122 | :param use_temporal_conv: if True, use the temporal convolution. 123 | :param use_image_dataset: if True, the temporal parameters will not be optimized. 
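    Note: `use_image_dataset` is described above but is not an argument of this
    constructor; only the parameters in the `__init__` signature below apply.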
124 | """ 125 | 126 | def __init__( 127 | self, 128 | channels, 129 | emb_channels, 130 | dropout, 131 | out_channels=None, 132 | use_scale_shift_norm=False, 133 | dims=2, 134 | use_checkpoint=False, 135 | use_conv=False, 136 | up=False, 137 | down=False, 138 | use_temporal_conv=False, 139 | tempspatial_aware=False 140 | ): 141 | super().__init__() 142 | self.channels = channels 143 | self.emb_channels = emb_channels 144 | self.dropout = dropout 145 | self.out_channels = out_channels or channels 146 | self.use_conv = use_conv 147 | self.use_checkpoint = use_checkpoint 148 | self.use_scale_shift_norm = use_scale_shift_norm 149 | self.use_temporal_conv = use_temporal_conv 150 | 151 | self.in_layers = nn.Sequential( 152 | normalization(channels), 153 | nn.SiLU(), 154 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 155 | ) 156 | 157 | self.updown = up or down 158 | 159 | if up: 160 | self.h_upd = Upsample(channels, False, dims) 161 | self.x_upd = Upsample(channels, False, dims) 162 | elif down: 163 | self.h_upd = Downsample(channels, False, dims) 164 | self.x_upd = Downsample(channels, False, dims) 165 | else: 166 | self.h_upd = self.x_upd = nn.Identity() 167 | 168 | self.emb_layers = nn.Sequential( 169 | nn.SiLU(), 170 | nn.Linear( 171 | emb_channels, 172 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 173 | ), 174 | ) 175 | self.out_layers = nn.Sequential( 176 | normalization(self.out_channels), 177 | nn.SiLU(), 178 | nn.Dropout(p=dropout), 179 | zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), 180 | ) 181 | 182 | if self.out_channels == channels: 183 | self.skip_connection = nn.Identity() 184 | elif use_conv: 185 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) 186 | else: 187 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 188 | 189 | if self.use_temporal_conv: 190 | self.temopral_conv = TemporalConvBlock( 191 | self.out_channels, 192 | self.out_channels, 193 | dropout=0.1, 194 | spatial_aware=tempspatial_aware 195 | ) 196 | 197 | def forward(self, x, emb, batch_size=None): 198 | """ 199 | Apply the block to a Tensor, conditioned on a timestep embedding. 200 | :param x: an [N x C x ...] Tensor of features. 201 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 202 | :return: an [N x C x ...] Tensor of outputs. 
203 | """ 204 | input_tuple = (x, emb) 205 | if batch_size: 206 | forward_batchsize = partial(self._forward, batch_size=batch_size) 207 | return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint) 208 | return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint) 209 | 210 | def _forward(self, x, emb, batch_size=None): 211 | if self.updown: 212 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 213 | h = in_rest(x) 214 | h = self.h_upd(h) 215 | x = self.x_upd(x) 216 | h = in_conv(h) 217 | else: 218 | h = self.in_layers(x) 219 | emb_out = self.emb_layers(emb).type(h.dtype) 220 | while len(emb_out.shape) < len(h.shape): 221 | emb_out = emb_out[..., None] 222 | if self.use_scale_shift_norm: 223 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 224 | scale, shift = torch.chunk(emb_out, 2, dim=1) 225 | h = out_norm(h) * (1 + scale) + shift 226 | h = out_rest(h) 227 | else: 228 | h = h + emb_out 229 | h = self.out_layers(h) 230 | h = self.skip_connection(x) + h 231 | 232 | if self.use_temporal_conv and batch_size: 233 | h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size) 234 | h = self.temopral_conv(h) 235 | h = rearrange(h, 'b c t h w -> (b t) c h w') 236 | return h 237 | 238 | 239 | class TemporalConvBlock(nn.Module): 240 | """ 241 | Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py 242 | """ 243 | def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False): 244 | super(TemporalConvBlock, self).__init__() 245 | if out_channels is None: 246 | out_channels = in_channels 247 | self.in_channels = in_channels 248 | self.out_channels = out_channels 249 | th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1) 250 | th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0) 251 | tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3) 252 | tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1) 253 | 254 | # conv layers 255 | self.conv1 = nn.Sequential( 256 | nn.GroupNorm(32, in_channels), nn.SiLU(), 257 | nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape)) 258 | self.conv2 = nn.Sequential( 259 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 260 | nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) 261 | self.conv3 = nn.Sequential( 262 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 263 | nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape)) 264 | self.conv4 = nn.Sequential( 265 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 266 | nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) 267 | 268 | # zero out the last layer params,so the conv block is identity 269 | nn.init.zeros_(self.conv4[-1].weight) 270 | nn.init.zeros_(self.conv4[-1].bias) 271 | 272 | def forward(self, x): 273 | identity = x 274 | x = self.conv1(x) 275 | x = self.conv2(x) 276 | x = self.conv3(x) 277 | x = self.conv4(x) 278 | 279 | return identity + x 280 | 281 | class UNetModel(nn.Module): 282 | """ 283 | The full UNet model with attention and timestep embedding. 284 | :param in_channels: in_channels in the input Tensor. 285 | :param model_channels: base channel count for the model. 286 | :param out_channels: channels in the output Tensor. 287 | :param num_res_blocks: number of residual blocks per downsample. 
288 | :param attention_resolutions: a collection of downsample rates at which 289 | attention will take place. May be a set, list, or tuple. 290 | For example, if this contains 4, then at 4x downsampling, attention 291 | will be used. 292 | :param dropout: the dropout probability. 293 | :param channel_mult: channel multiplier for each level of the UNet. 294 | :param conv_resample: if True, use learned convolutions for upsampling and 295 | downsampling. 296 | :param dims: determines if the signal is 1D, 2D, or 3D. 297 | :param num_classes: if specified (as an int), then this model will be 298 | class-conditional with `num_classes` classes. 299 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 300 | :param num_heads: the number of attention heads in each attention layer. 301 | :param num_heads_channels: if specified, ignore num_heads and instead use 302 | a fixed channel width per attention head. 303 | :param num_heads_upsample: works with num_heads to set a different number 304 | of heads for upsampling. Deprecated. 305 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 306 | :param resblock_updown: use residual blocks for up/downsampling. 307 | :param use_new_attention_order: use a different attention pattern for potentially 308 | increased efficiency. 309 | """ 310 | 311 | def __init__(self, 312 | in_channels, 313 | model_channels, 314 | out_channels, 315 | num_res_blocks, 316 | attention_resolutions, 317 | dropout=0.0, 318 | channel_mult=(1, 2, 4, 8), 319 | conv_resample=True, 320 | dims=2, 321 | context_dim=None, 322 | use_scale_shift_norm=False, 323 | resblock_updown=False, 324 | num_heads=-1, 325 | num_head_channels=-1, 326 | transformer_depth=1, 327 | use_linear=False, 328 | use_checkpoint=False, 329 | temporal_conv=False, 330 | tempspatial_aware=False, 331 | temporal_attention=True, 332 | use_relative_position=True, 333 | use_causal_attention=False, 334 | temporal_length=None, 335 | use_fp16=False, 336 | addition_attention=False, 337 | temporal_selfatt_only=True, 338 | image_cross_attention=False, 339 | image_cross_attention_scale_learnable=False, 340 | default_fs=4, 341 | fs_condition=False, 342 | ): 343 | super(UNetModel, self).__init__() 344 | if num_heads == -1: 345 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 346 | if num_head_channels == -1: 347 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 348 | 349 | self.in_channels = in_channels 350 | self.model_channels = model_channels 351 | self.out_channels = out_channels 352 | self.num_res_blocks = num_res_blocks 353 | self.attention_resolutions = attention_resolutions 354 | self.dropout = dropout 355 | self.channel_mult = channel_mult 356 | self.conv_resample = conv_resample 357 | self.temporal_attention = temporal_attention 358 | time_embed_dim = model_channels * 4 359 | self.use_checkpoint = use_checkpoint 360 | self.dtype = torch.float16 if use_fp16 else torch.float32 361 | temporal_self_att_only = True 362 | self.addition_attention = addition_attention 363 | self.temporal_length = temporal_length 364 | self.image_cross_attention = image_cross_attention 365 | self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable 366 | self.default_fs = default_fs 367 | self.fs_condition = fs_condition 368 | 369 | ## Time embedding blocks 370 | self.time_embed = nn.Sequential( 371 | linear(model_channels, time_embed_dim), 372 | nn.SiLU(), 373 | linear(time_embed_dim, time_embed_dim), 374 | ) 375 | 
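        ## Frame-stride ("fs") conditioning: an extra sinusoidal embedding that
        ## is added to the timestep embedding in forward(); its last linear
        ## layer is zero-initialized, so right after initialization it is a
        ## no-op and only takes effect once training updates it.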
if fs_condition: 376 | self.fps_embedding = nn.Sequential( 377 | linear(model_channels, time_embed_dim), 378 | nn.SiLU(), 379 | linear(time_embed_dim, time_embed_dim), 380 | ) 381 | nn.init.zeros_(self.fps_embedding[-1].weight) 382 | nn.init.zeros_(self.fps_embedding[-1].bias) 383 | ## Input Block 384 | self.input_blocks = nn.ModuleList( 385 | [ 386 | TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) 387 | ] 388 | ) 389 | if self.addition_attention: 390 | self.init_attn=TimestepEmbedSequential( 391 | TemporalTransformer( 392 | model_channels, 393 | n_heads=8, 394 | d_head=num_head_channels, 395 | depth=transformer_depth, 396 | context_dim=context_dim, 397 | use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, 398 | causal_attention=False, relative_position=use_relative_position, 399 | temporal_length=temporal_length)) 400 | 401 | input_block_chans = [model_channels] 402 | ch = model_channels 403 | ds = 1 404 | for level, mult in enumerate(channel_mult): 405 | for _ in range(num_res_blocks): 406 | layers = [ 407 | ResBlock(ch, time_embed_dim, dropout, 408 | out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, 409 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 410 | use_temporal_conv=temporal_conv 411 | ) 412 | ] 413 | ch = mult * model_channels 414 | if ds in attention_resolutions: 415 | if num_head_channels == -1: 416 | dim_head = ch // num_heads 417 | else: 418 | num_heads = ch // num_head_channels 419 | dim_head = num_head_channels 420 | layers.append( 421 | SpatialTransformer(ch, num_heads, dim_head, 422 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 423 | use_checkpoint=use_checkpoint, disable_self_attn=False, 424 | video_length=temporal_length, image_cross_attention=self.image_cross_attention, 425 | image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable, 426 | ) 427 | ) 428 | if self.temporal_attention: 429 | layers.append( 430 | TemporalTransformer(ch, num_heads, dim_head, 431 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 432 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 433 | causal_attention=use_causal_attention, relative_position=use_relative_position, 434 | temporal_length=temporal_length 435 | ) 436 | ) 437 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 438 | input_block_chans.append(ch) 439 | if level != len(channel_mult) - 1: 440 | out_ch = ch 441 | self.input_blocks.append( 442 | TimestepEmbedSequential( 443 | ResBlock(ch, time_embed_dim, dropout, 444 | out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, 445 | use_scale_shift_norm=use_scale_shift_norm, 446 | down=True 447 | ) 448 | if resblock_updown 449 | else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) 450 | ) 451 | ) 452 | ch = out_ch 453 | input_block_chans.append(ch) 454 | ds *= 2 455 | 456 | if num_head_channels == -1: 457 | dim_head = ch // num_heads 458 | else: 459 | num_heads = ch // num_head_channels 460 | dim_head = num_head_channels 461 | layers = [ 462 | ResBlock(ch, time_embed_dim, dropout, 463 | dims=dims, use_checkpoint=use_checkpoint, 464 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 465 | use_temporal_conv=temporal_conv 466 | ), 467 | SpatialTransformer(ch, num_heads, dim_head, 468 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 469 | use_checkpoint=use_checkpoint, disable_self_attn=False, 
video_length=temporal_length, 470 | image_cross_attention=self.image_cross_attention,image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable 471 | ) 472 | ] 473 | if self.temporal_attention: 474 | layers.append( 475 | TemporalTransformer(ch, num_heads, dim_head, 476 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 477 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 478 | causal_attention=use_causal_attention, relative_position=use_relative_position, 479 | temporal_length=temporal_length 480 | ) 481 | ) 482 | layers.append( 483 | ResBlock(ch, time_embed_dim, dropout, 484 | dims=dims, use_checkpoint=use_checkpoint, 485 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 486 | use_temporal_conv=temporal_conv 487 | ) 488 | ) 489 | 490 | ## Middle Block 491 | self.middle_block = TimestepEmbedSequential(*layers) 492 | 493 | ## Output Block 494 | self.output_blocks = nn.ModuleList([]) 495 | for level, mult in list(enumerate(channel_mult))[::-1]: 496 | for i in range(num_res_blocks + 1): 497 | ich = input_block_chans.pop() 498 | layers = [ 499 | ResBlock(ch + ich, time_embed_dim, dropout, 500 | out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, 501 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 502 | use_temporal_conv=temporal_conv 503 | ) 504 | ] 505 | ch = model_channels * mult 506 | if ds in attention_resolutions: 507 | if num_head_channels == -1: 508 | dim_head = ch // num_heads 509 | else: 510 | num_heads = ch // num_head_channels 511 | dim_head = num_head_channels 512 | layers.append( 513 | SpatialTransformer(ch, num_heads, dim_head, 514 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 515 | use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length, 516 | image_cross_attention=self.image_cross_attention,image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable 517 | ) 518 | ) 519 | if self.temporal_attention: 520 | layers.append( 521 | TemporalTransformer(ch, num_heads, dim_head, 522 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 523 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 524 | causal_attention=use_causal_attention, relative_position=use_relative_position, 525 | temporal_length=temporal_length 526 | ) 527 | ) 528 | if level and i == num_res_blocks: 529 | out_ch = ch 530 | layers.append( 531 | ResBlock(ch, time_embed_dim, dropout, 532 | out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, 533 | use_scale_shift_norm=use_scale_shift_norm, 534 | up=True 535 | ) 536 | if resblock_updown 537 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 538 | ) 539 | ds //= 2 540 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 541 | 542 | self.out = nn.Sequential( 543 | normalization(ch), 544 | nn.SiLU(), 545 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 546 | ) 547 | 548 | def forward(self, x, timesteps, context=None, features_adapter=None, fs=None, **kwargs): 549 | b,_,t,_,_ = x.shape 550 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype) 551 | emb = self.time_embed(t_emb) 552 | 553 | ## repeat t times for context [(b t) 77 768] & time embedding 554 | ## check if we use per-frame image conditioning 555 | _, l_context, _ = context.shape 556 | # if l_context == 77 + t*16: ## !!! 
HARD CODE here 557 | # context_text, context_img = context[:,:77,:], context[:,77:,:] 558 | # context_text = context_text.repeat_interleave(repeats=t, dim=0) 559 | # context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t) 560 | # context = torch.cat([context_text, context_img], dim=1) 561 | # else: 562 | # print('context shape:', context.shape, "x.shape:", x.shape, flush=True) 563 | if len(context) == b * t: 564 | pass 565 | else: 566 | context = context.repeat_interleave(repeats=t, dim=0) 567 | emb = emb.repeat_interleave(repeats=t, dim=0) 568 | 569 | ## always in shape (b t) c h w, except for temporal layer 570 | x = rearrange(x, 'b c t h w -> (b t) c h w') 571 | 572 | ## combine emb 573 | if self.fs_condition: 574 | if fs is None: 575 | fs = torch.tensor( 576 | [self.default_fs] * b, dtype=torch.long, device=x.device) 577 | fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype) 578 | 579 | fs_embed = self.fps_embedding(fs_emb) 580 | fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) 581 | emb = emb + fs_embed 582 | 583 | h = x.type(self.dtype) 584 | adapter_idx = 0 585 | hs = [] 586 | for id, module in enumerate(self.input_blocks): 587 | h = module(h, emb, context=context, batch_size=b) 588 | if id ==0 and self.addition_attention: 589 | h = self.init_attn(h, emb, context=context, batch_size=b) 590 | ## plug-in adapter features 591 | if ((id+1)%3 == 0) and features_adapter is not None: 592 | h = h + features_adapter[adapter_idx] 593 | adapter_idx += 1 594 | hs.append(h) 595 | if features_adapter is not None: 596 | assert len(features_adapter)==adapter_idx, 'Wrong features_adapter' 597 | 598 | h = self.middle_block(h, emb, context=context, batch_size=b) 599 | for module in self.output_blocks: 600 | h = torch.cat([h, hs.pop()], dim=1) 601 | h = module(h, emb, context=context, batch_size=b) 602 | h = h.type(x.dtype) 603 | y = self.out(h) 604 | 605 | # reshape back to (b c t h w) 606 | y = rearrange(y, '(b t) c h w -> b c t h w', b=b) 607 | return y -------------------------------------------------------------------------------- /lvdm/modules/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | from functools import partial 3 | from inspect import isfunction 4 | from collections import namedtuple 5 | from einops import rearrange, repeat 6 | import torch 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | # constants 11 | DEFAULT_DIM_HEAD = 64 12 | 13 | Intermediates = namedtuple('Intermediates', [ 14 | 'pre_softmax_attn', 15 | 'post_softmax_attn' 16 | ]) 17 | 18 | LayerIntermediates = namedtuple('Intermediates', [ 19 | 'hiddens', 20 | 'attn_intermediates' 21 | ]) 22 | 23 | 24 | class AbsolutePositionalEmbedding(nn.Module): 25 | def __init__(self, dim, max_seq_len): 26 | super().__init__() 27 | self.emb = nn.Embedding(max_seq_len, dim) 28 | self.init_() 29 | 30 | def init_(self): 31 | nn.init.normal_(self.emb.weight, std=0.02) 32 | 33 | def forward(self, x): 34 | n = torch.arange(x.shape[1], device=x.device) 35 | return self.emb(n)[None, :, :] 36 | 37 | 38 | class FixedPositionalEmbedding(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 42 | self.register_buffer('inv_freq', inv_freq) 43 | 44 | def forward(self, x, seq_dim=1, offset=0): 45 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 46 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 47 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 48 | return emb[None, :, :] 49 | 50 | 51 | # helpers 52 | 53 | def exists(val): 54 | return val is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def always(val): 64 | def inner(*args, **kwargs): 65 | return val 66 | return inner 67 | 68 | 69 | def not_equals(val): 70 | def inner(x): 71 | return x != val 72 | return inner 73 | 74 | 75 | def equals(val): 76 | def inner(x): 77 | return x == val 78 | return inner 79 | 80 | 81 | def max_neg_value(tensor): 82 | return -torch.finfo(tensor.dtype).max 83 | 84 | 85 | # keyword argument helpers 86 | 87 | def pick_and_pop(keys, d): 88 | values = list(map(lambda key: d.pop(key), keys)) 89 | return dict(zip(keys, values)) 90 | 91 | 92 | def group_dict_by_key(cond, d): 93 | return_val = [dict(), dict()] 94 | for key in d.keys(): 95 | match = bool(cond(key)) 96 | ind = int(not match) 97 | return_val[ind][key] = d[key] 98 | return (*return_val,) 99 | 100 | 101 | def string_begins_with(prefix, str): 102 | return str.startswith(prefix) 103 | 104 | 105 | def group_by_key_prefix(prefix, d): 106 | return group_dict_by_key(partial(string_begins_with, prefix), d) 107 | 108 | 109 | def groupby_prefix_and_trim(prefix, d): 110 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 111 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 112 | return kwargs_without_prefix, kwargs 113 | 114 | 115 | # classes 116 | class Scale(nn.Module): 117 | def __init__(self, value, fn): 118 | super().__init__() 119 | self.value = value 120 | self.fn = fn 121 | 122 | def forward(self, x, **kwargs): 123 | x, *rest = self.fn(x, **kwargs) 124 | return (x * self.value, *rest) 125 | 126 | 127 | class Rezero(nn.Module): 128 | def __init__(self, fn): 129 | super().__init__() 130 | self.fn = fn 131 | self.g = nn.Parameter(torch.zeros(1)) 132 | 133 | def forward(self, x, **kwargs): 134 | x, *rest = self.fn(x, **kwargs) 135 | return (x * self.g, *rest) 136 | 137 | 138 | class ScaleNorm(nn.Module): 139 | def __init__(self, dim, eps=1e-5): 140 | super().__init__() 141 | self.scale = dim ** -0.5 142 | self.eps = eps 143 | self.g = nn.Parameter(torch.ones(1)) 144 | 145 | def forward(self, x): 146 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 147 | return x / norm.clamp(min=self.eps) * self.g 148 | 149 | 150 | class RMSNorm(nn.Module): 151 | def __init__(self, dim, eps=1e-8): 152 | super().__init__() 153 | self.scale = dim ** -0.5 154 | self.eps = eps 155 | self.g = nn.Parameter(torch.ones(dim)) 156 | 157 | def forward(self, x): 158 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 159 | return x / norm.clamp(min=self.eps) * self.g 160 | 161 | 162 | class Residual(nn.Module): 163 | def forward(self, x, residual): 164 | return x + residual 165 | 166 | 167 | class GRUGating(nn.Module): 168 | def __init__(self, dim): 169 | super().__init__() 170 | self.gru = nn.GRUCell(dim, dim) 171 | 172 | def forward(self, x, residual): 173 | gated_output = self.gru( 174 | rearrange(x, 'b n d -> (b n) d'), 175 | rearrange(residual, 'b n d -> 
(b n) d') 176 | ) 177 | 178 | return gated_output.reshape_as(x) 179 | 180 | 181 | # feedforward 182 | 183 | class GEGLU(nn.Module): 184 | def __init__(self, dim_in, dim_out): 185 | super().__init__() 186 | self.proj = nn.Linear(dim_in, dim_out * 2) 187 | 188 | def forward(self, x): 189 | x, gate = self.proj(x).chunk(2, dim=-1) 190 | return x * F.gelu(gate) 191 | 192 | 193 | class FeedForward(nn.Module): 194 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 195 | super().__init__() 196 | inner_dim = int(dim * mult) 197 | dim_out = default(dim_out, dim) 198 | project_in = nn.Sequential( 199 | nn.Linear(dim, inner_dim), 200 | nn.GELU() 201 | ) if not glu else GEGLU(dim, inner_dim) 202 | 203 | self.net = nn.Sequential( 204 | project_in, 205 | nn.Dropout(dropout), 206 | nn.Linear(inner_dim, dim_out) 207 | ) 208 | 209 | def forward(self, x): 210 | return self.net(x) 211 | 212 | 213 | # attention. 214 | class Attention(nn.Module): 215 | def __init__( 216 | self, 217 | dim, 218 | dim_head=DEFAULT_DIM_HEAD, 219 | heads=8, 220 | causal=False, 221 | mask=None, 222 | talking_heads=False, 223 | sparse_topk=None, 224 | use_entmax15=False, 225 | num_mem_kv=0, 226 | dropout=0., 227 | on_attn=False 228 | ): 229 | super().__init__() 230 | if use_entmax15: 231 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 232 | self.scale = dim_head ** -0.5 233 | self.heads = heads 234 | self.causal = causal 235 | self.mask = mask 236 | 237 | inner_dim = dim_head * heads 238 | 239 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 240 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 242 | self.dropout = nn.Dropout(dropout) 243 | 244 | # talking heads 245 | self.talking_heads = talking_heads 246 | if talking_heads: 247 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 248 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | 250 | # explicit topk sparse attention 251 | self.sparse_topk = sparse_topk 252 | 253 | # entmax 254 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 255 | self.attn_fn = F.softmax 256 | 257 | # add memory key / values 258 | self.num_mem_kv = num_mem_kv 259 | if num_mem_kv > 0: 260 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 261 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | 263 | # attention on attention 264 | self.attn_on_attn = on_attn 265 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 266 | 267 | def forward( 268 | self, 269 | x, 270 | context=None, 271 | mask=None, 272 | context_mask=None, 273 | rel_pos=None, 274 | sinusoidal_emb=None, 275 | prev_attn=None, 276 | mem=None 277 | ): 278 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 279 | kv_input = default(context, x) 280 | 281 | q_input = x 282 | k_input = kv_input 283 | v_input = kv_input 284 | 285 | if exists(mem): 286 | k_input = torch.cat((mem, k_input), dim=-2) 287 | v_input = torch.cat((mem, v_input), dim=-2) 288 | 289 | if exists(sinusoidal_emb): 290 | # in shortformer, the query would start at a position offset depending on the past cached memory 291 | offset = k_input.shape[-2] - q_input.shape[-2] 292 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 293 | k_input = k_input + sinusoidal_emb(k_input) 294 | 295 | q = self.to_q(q_input) 296 | k = self.to_k(k_input) 297 | v = self.to_v(v_input) 298 | 299 | 
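        # split heads: (b, n, h*d) -> (b, h, n, d) so attention scores are
        # computed independently per head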
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 300 | 301 | input_mask = None 302 | if any(map(exists, (mask, context_mask))): 303 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 304 | k_mask = q_mask if not exists(context) else context_mask 305 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 306 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 307 | k_mask = rearrange(k_mask, 'b j -> b () () j') 308 | input_mask = q_mask * k_mask 309 | 310 | if self.num_mem_kv > 0: 311 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 312 | k = torch.cat((mem_k, k), dim=-2) 313 | v = torch.cat((mem_v, v), dim=-2) 314 | if exists(input_mask): 315 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 316 | 317 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 318 | mask_value = max_neg_value(dots) 319 | 320 | if exists(prev_attn): 321 | dots = dots + prev_attn 322 | 323 | pre_softmax_attn = dots 324 | 325 | if talking_heads: 326 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 327 | 328 | if exists(rel_pos): 329 | dots = rel_pos(dots) 330 | 331 | if exists(input_mask): 332 | dots.masked_fill_(~input_mask, mask_value) 333 | del input_mask 334 | 335 | if self.causal: 336 | i, j = dots.shape[-2:] 337 | r = torch.arange(i, device=device) 338 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 339 | mask = F.pad(mask, (j - i, 0), value=False) 340 | dots.masked_fill_(mask, mask_value) 341 | del mask 342 | 343 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 344 | top, _ = dots.topk(self.sparse_topk, dim=-1) 345 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 346 | mask = dots < vk 347 | dots.masked_fill_(mask, mask_value) 348 | del mask 349 | 350 | attn = self.attn_fn(dots, dim=-1) 351 | post_softmax_attn = attn 352 | 353 | attn = self.dropout(attn) 354 | 355 | if talking_heads: 356 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 357 | 358 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 359 | out = rearrange(out, 'b h n d -> b n (h d)') 360 | 361 | intermediates = Intermediates( 362 | pre_softmax_attn=pre_softmax_attn, 363 | post_softmax_attn=post_softmax_attn 364 | ) 365 | 366 | return self.to_out(out), intermediates 367 | 368 | 369 | class AttentionLayers(nn.Module): 370 | def __init__( 371 | self, 372 | dim, 373 | depth, 374 | heads=8, 375 | causal=False, 376 | cross_attend=False, 377 | only_cross=False, 378 | use_scalenorm=False, 379 | use_rmsnorm=False, 380 | use_rezero=False, 381 | rel_pos_num_buckets=32, 382 | rel_pos_max_distance=128, 383 | position_infused_attn=False, 384 | custom_layers=None, 385 | sandwich_coef=None, 386 | par_ratio=None, 387 | residual_attn=False, 388 | cross_residual_attn=False, 389 | macaron=False, 390 | pre_norm=True, 391 | gate_residual=False, 392 | **kwargs 393 | ): 394 | super().__init__() 395 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 396 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 397 | 398 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 399 | 400 | self.dim = dim 401 | self.depth = depth 402 | self.layers = nn.ModuleList([]) 403 | 404 | self.has_pos_emb = position_infused_attn 405 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 406 | self.rotary_pos_emb = always(None) 407 | 408 | assert rel_pos_num_buckets 
<= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 409 | self.rel_pos = None 410 | 411 | self.pre_norm = pre_norm 412 | 413 | self.residual_attn = residual_attn 414 | self.cross_residual_attn = cross_residual_attn 415 | 416 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 417 | norm_class = RMSNorm if use_rmsnorm else norm_class 418 | norm_fn = partial(norm_class, dim) 419 | 420 | norm_fn = nn.Identity if use_rezero else norm_fn 421 | branch_fn = Rezero if use_rezero else None 422 | 423 | if cross_attend and not only_cross: 424 | default_block = ('a', 'c', 'f') 425 | elif cross_attend and only_cross: 426 | default_block = ('c', 'f') 427 | else: 428 | default_block = ('a', 'f') 429 | 430 | if macaron: 431 | default_block = ('f',) + default_block 432 | 433 | if exists(custom_layers): 434 | layer_types = custom_layers 435 | elif exists(par_ratio): 436 | par_depth = depth * len(default_block) 437 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 438 | default_block = tuple(filter(not_equals('f'), default_block)) 439 | par_attn = par_depth // par_ratio 440 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 441 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 442 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 443 | par_block = default_block + ('f',) * (par_width - len(default_block)) 444 | par_head = par_block * par_attn 445 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 446 | elif exists(sandwich_coef): 447 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 448 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 449 | else: 450 | layer_types = default_block * depth 451 | 452 | self.layer_types = layer_types 453 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 454 | 455 | for layer_type in self.layer_types: 456 | if layer_type == 'a': 457 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 458 | elif layer_type == 'c': 459 | layer = Attention(dim, heads=heads, **attn_kwargs) 460 | elif layer_type == 'f': 461 | layer = FeedForward(dim, **ff_kwargs) 462 | layer = layer if not macaron else Scale(0.5, layer) 463 | else: 464 | raise Exception(f'invalid layer type {layer_type}') 465 | 466 | if isinstance(layer, Attention) and exists(branch_fn): 467 | layer = branch_fn(layer) 468 | 469 | if gate_residual: 470 | residual_fn = GRUGating(dim) 471 | else: 472 | residual_fn = Residual() 473 | 474 | self.layers.append(nn.ModuleList([ 475 | norm_fn(), 476 | layer, 477 | residual_fn 478 | ])) 479 | 480 | def forward( 481 | self, 482 | x, 483 | context=None, 484 | mask=None, 485 | context_mask=None, 486 | mems=None, 487 | return_hiddens=False 488 | ): 489 | hiddens = [] 490 | intermediates = [] 491 | prev_attn = None 492 | prev_cross_attn = None 493 | 494 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 495 | 496 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 497 | is_last = ind == (len(self.layers) - 1) 498 | 499 | if layer_type == 'a': 500 | hiddens.append(x) 501 | layer_mem = mems.pop(0) 502 | 503 | residual = x 504 | 505 | if self.pre_norm: 506 | x = norm(x) 507 | 508 | if layer_type == 'a': 509 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 510 | 
prev_attn=prev_attn, mem=layer_mem) 511 | elif layer_type == 'c': 512 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 513 | elif layer_type == 'f': 514 | out = block(x) 515 | 516 | x = residual_fn(out, residual) 517 | 518 | if layer_type in ('a', 'c'): 519 | intermediates.append(inter) 520 | 521 | if layer_type == 'a' and self.residual_attn: 522 | prev_attn = inter.pre_softmax_attn 523 | elif layer_type == 'c' and self.cross_residual_attn: 524 | prev_cross_attn = inter.pre_softmax_attn 525 | 526 | if not self.pre_norm and not is_last: 527 | x = norm(x) 528 | 529 | if return_hiddens: 530 | intermediates = LayerIntermediates( 531 | hiddens=hiddens, 532 | attn_intermediates=intermediates 533 | ) 534 | 535 | return x, intermediates 536 | 537 | return x 538 | 539 | 540 | class Encoder(AttentionLayers): 541 | def __init__(self, **kwargs): 542 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 543 | super().__init__(causal=False, **kwargs) 544 | 545 | 546 | 547 | class TransformerWrapper(nn.Module): 548 | def __init__( 549 | self, 550 | *, 551 | num_tokens, 552 | max_seq_len, 553 | attn_layers, 554 | emb_dim=None, 555 | max_mem_len=0., 556 | emb_dropout=0., 557 | num_memory_tokens=None, 558 | tie_embedding=False, 559 | use_pos_emb=True 560 | ): 561 | super().__init__() 562 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 563 | 564 | dim = attn_layers.dim 565 | emb_dim = default(emb_dim, dim) 566 | 567 | self.max_seq_len = max_seq_len 568 | self.max_mem_len = max_mem_len 569 | self.num_tokens = num_tokens 570 | 571 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 572 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 573 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 574 | self.emb_dropout = nn.Dropout(emb_dropout) 575 | 576 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 577 | self.attn_layers = attn_layers 578 | self.norm = nn.LayerNorm(dim) 579 | 580 | self.init_() 581 | 582 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 583 | 584 | # memory tokens (like [cls]) from Memory Transformers paper 585 | num_memory_tokens = default(num_memory_tokens, 0) 586 | self.num_memory_tokens = num_memory_tokens 587 | if num_memory_tokens > 0: 588 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 589 | 590 | # let funnel encoder know number of memory tokens, if specified 591 | if hasattr(attn_layers, 'num_memory_tokens'): 592 | attn_layers.num_memory_tokens = num_memory_tokens 593 | 594 | def init_(self): 595 | nn.init.normal_(self.token_emb.weight, std=0.02) 596 | 597 | def forward( 598 | self, 599 | x, 600 | return_embeddings=False, 601 | mask=None, 602 | return_mems=False, 603 | return_attn=False, 604 | mems=None, 605 | **kwargs 606 | ): 607 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 608 | x = self.token_emb(x) 609 | x += self.pos_emb(x) 610 | x = self.emb_dropout(x) 611 | 612 | x = self.project_emb(x) 613 | 614 | if num_mem > 0: 615 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 616 | x = torch.cat((mem, x), dim=1) 617 | 618 | # auto-handle masking after appending memory tokens 619 | if exists(mask): 620 | mask = F.pad(mask, (num_mem, 0), value=True) 621 | 622 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 623 | x = self.norm(x) 624 | 625 | 
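        # split off the memory tokens that were prepended above; only the real
        # sequence positions go through the logits head / are returned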
mem, x = x[:, :num_mem], x[:, num_mem:] 626 | 627 | out = self.to_logits(x) if not return_embeddings else x 628 | 629 | if return_mems: 630 | hiddens = intermediates.hiddens 631 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 632 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 633 | return out, new_mems 634 | 635 | if return_attn: 636 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 637 | return out, attn_maps 638 | 639 | return out -------------------------------------------------------------------------------- /lvdm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import importlib 5 | 6 | 7 | def instantiate_from_config(config): 8 | if not "target" in config: 9 | if config == '__is_first_stage__': 10 | return None 11 | elif config == "__is_unconditional__": 12 | return None 13 | raise KeyError("Expected key `target` to instantiate.") 14 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 15 | 16 | 17 | def get_obj_from_str(string, reload=False): 18 | module, cls = string.rsplit(".", 1) 19 | if reload: 20 | module_imp = importlib.import_module(module) 21 | importlib.reload(module_imp) 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | 24 | 25 | def load_model_checkpoint(model, ckpt): 26 | assert os.path.exists(ckpt), f"Error: checkpoint file '{ckpt}' not found!" 27 | 28 | state_dict = torch.load(ckpt, map_location="cpu", weights_only=True) 29 | if "state_dict" in list(state_dict.keys()): 30 | state_dict = state_dict["state_dict"] 31 | for k in list(state_dict.keys()): 32 | if "framestride_embed" in k: 33 | new_key = k.replace("framestride_embed", "fps_embedding") 34 | state_dict[new_key] = state_dict[k] 35 | del state_dict[k] 36 | model.load_state_dict(state_dict, strict=True) 37 | print(f">>> model checkpoint '{ckpt}' loaded.", flush=True) 38 | return model 39 | 40 | 41 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 42 | 43 | def lr_lambda(current_step): 44 | if current_step < num_warmup_steps: 45 | return float(current_step) / float(max(1, num_warmup_steps)) 46 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 47 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 48 | 49 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 50 | 51 | 52 | class AverageMeter: 53 | 54 | def __init__(self): 55 | self.reset() 56 | 57 | def reset(self): 58 | self._sum = 0 59 | self._count = 0 60 | 61 | def update(self, val, n=1): 62 | self._sum += val * n 63 | self._count += n 64 | 65 | @property 66 | def value(self): 67 | if self._count == 0: 68 | return 0 69 | return self._sum / self._count 70 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from omegaconf import OmegaConf 8 | from einops import rearrange, repeat 9 | 10 | from lvdm.models.samplers.ddim import DDIMSampler 11 | from lvdm.utils import instantiate_from_config 12 | 13 | 14 | def get_latent_z(model, videos): 15 | b, c, t, h, w = videos.shape 16 | x = 
rearrange(videos, 'b c t h w -> (b t) c h w') 17 | z = model.encode_first_stage(x) 18 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 19 | return z 20 | 21 | 22 | @torch.no_grad() 23 | def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., 24 | unconditional_guidance_scale=1.0, **kwargs): 25 | ddim_sampler = DDIMSampler(model) 26 | batch_size = noise_shape[0] 27 | fs = torch.tensor([1.] * batch_size, dtype=torch.long, device=model.device) 28 | 29 | img = videos[:, :, 0] # bchw 30 | img_emb = model.embedder(img) # blc 31 | img_emb = model.image_proj_model(img_emb) 32 | 33 | cond_emb = model.get_learned_conditioning(prompts) 34 | (_B, _, _T, _, _), tB = videos.shape, cond_emb.shape[0] 35 | if tB != _B: # in case we have multiple prompts for a single video 36 | assert _B * _T == tB, f"{_B} * {_T} != {tB}" 37 | img_emb = img_emb.repeat_interleave(repeats=_T, dim=0) 38 | cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]} 39 | 40 | z = get_latent_z(model, videos) # b c t h w 41 | img_cat_cond = z[:, :, :1, :, :] 42 | img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2]) 43 | cond["c_concat"] = [img_cat_cond] # b c 1 h w 44 | 45 | if unconditional_guidance_scale != 1.0: 46 | uc_emb = model.get_learned_conditioning(batch_size * [""]) 47 | 48 | uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c 49 | uc_img_emb = model.image_proj_model(uc_img_emb) 50 | uc = { 51 | "c_crossattn": [torch.cat([uc_emb,uc_img_emb], dim=1)], 52 | "c_concat": [img_cat_cond] 53 | } 54 | else: 55 | uc = None 56 | 57 | z0 = None 58 | cond_mask = None 59 | x_T = None 60 | timesteps = None 61 | 62 | batch_variants = [] 63 | for _ in range(n_samples): 64 | if z0 is not None: 65 | cond_z0 = z0.clone() 66 | kwargs.update({"clean_cond": True}) 67 | else: 68 | cond_z0 = None 69 | if ddim_sampler is not None: 70 | samples, _ = ddim_sampler.sample(S=ddim_steps, 71 | conditioning=cond, 72 | batch_size=batch_size, 73 | shape=noise_shape[1:], 74 | verbose=True, 75 | unconditional_guidance_scale=unconditional_guidance_scale, 76 | unconditional_conditioning=uc, 77 | eta=ddim_eta, 78 | mask=cond_mask, 79 | x0=cond_z0, 80 | fs=fs, 81 | x_T=x_T, 82 | timesteps=timesteps, 83 | **kwargs) 84 | # reconstruct from latent to pixel space 85 | batch_images = model.decode_first_stage(samples) 86 | batch_variants.append(batch_images) 87 | 88 | # variants, batch, c, t, h, w 89 | batch_variants = torch.stack(batch_variants) 90 | return batch_variants.permute(1, 0, 2, 3, 4, 5) 91 | 92 | 93 | def main(args): 94 | config = OmegaConf.load("./configs/inference_256_v1.1.yaml")["model"] 95 | model = instantiate_from_config(config) 96 | model = model.cuda() 97 | 98 | state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=False) 99 | model.load_state_dict(state_dict, strict=True) 100 | model.eval() 101 | 102 | with open(args.prompt_file, "r") as f: 103 | dataset = [] 104 | for line in f.readlines(): 105 | if line.strip() == "" or line.startswith("#"): 106 | continue 107 | image_path, prompts = line.split(args.delimiter, 1) 108 | prompts = [p.strip().strip('"') for p in prompts.split(args.delimiter)] 109 | 110 | dataset.append((image_path.strip(), prompts)) 111 | 112 | for idx, (image_path, prompts) in enumerate(dataset): 113 | print(f"Processing {idx + 1}/{len(dataset)}: {image_path}", flush=True) 114 | assert os.path.exists(image_path), f"Image not found: {image_path}" 115 | 116 | img = Image.open(image_path).convert("RGB") 117 | 
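# center-crop the input image to a square and resize it to the 256x256 resolution expected by the model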
w, h = img.size 118 | if w > h: 119 | img = img.crop(((w - h) // 2, 0, h + (w - h) // 2, h)) 120 | elif h > w: 121 | img = img.crop((0, (h - w) // 2, w, w + (h - w) // 2)) 122 | img = img.resize((256, 256)) 123 | 124 | n_frames = len(prompts) 125 | noise_shape = [1, 4, n_frames, 32, 32] # B, C, T, H, W 126 | 127 | torch_img = torch.from_numpy(np.array(img).copy()).permute(2, 0, 1).float().div_(255 / 2).sub_(1) 128 | torch_img = torch_img.unsqueeze(1) # add temporal dimension: 3, 1, 256, 256 129 | torch_img = torch_img.unsqueeze(0) # add batch dimension: 1, 3, 1, 256, 256 130 | torch_img = repeat(torch_img, 'b c t h w -> b c (repeat t) h w', repeat=n_frames) 131 | 132 | torch_img = torch_img.cuda() 133 | samples = image_guided_synthesis( 134 | model, prompts, torch_img, noise_shape, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, 135 | unconditional_guidance_scale=args.unconditional_guidance_scale) 136 | 137 | # n_samples=1, B=1, C, T, h, w 138 | samples = samples.squeeze(0).squeeze(0) 139 | samples = samples.clamp_(-1, 1).add_(1.).mul_(255 / 2) 140 | samples = samples.to(torch.uint8).permute(1, 2, 3, 0) 141 | samples = samples.cpu().numpy() 142 | 143 | output_image = Image.fromarray(np.concatenate(samples, axis=1)) 144 | os.makedirs(args.output_dir, exist_ok=True) 145 | output_image.save(os.path.join(args.output_dir, f"{idx:05d}.jpg")) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--ckpt_path", type=str, required=True) 151 | parser.add_argument("--prompt_file", type=str, required=True, help="text file with image paths and prompts") 152 | parser.add_argument("--delimiter", type=str, required=True, help="delimiter for image paths and prompts") 153 | parser.add_argument("--output_dir", type=str, default="output", help="output directory for generated images") 154 | 155 | parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM") 156 | parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)") 157 | parser.add_argument("--unconditional_guidance_scale", type=float, default=7.5, help="prompt classifier-free guidance") 158 | args = parser.parse_args() 159 | 160 | main(args) 161 | -------------------------------------------------------------------------------- /test_data/img01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soCzech/ShowHowTo/29700e2633ed7d3f7958c411a82a76585cd64536/test_data/img01.jpg -------------------------------------------------------------------------------- /test_data/prompt_file.txt: -------------------------------------------------------------------------------- 1 | # path/to/image | prompt for the input frame | prompt for the first generated frame | ... 2 | test_data/img01.jpg|Boil a pot of water to skin the tomato and potato.|Cut a thin slice on the tomato skin and make a cross over on the other side to remove the skin.|Remove the skin from the potato by cutting a thin slice and making a cross over on the other side.|Add the vegetables to the pot of boiling water and let them cook for about 1 minute.|Remove the tomato and potato from the pot and peel off the skin.|Cut the other ingredients into smaller pieces.|Throw all the ingredients into a pot of boiling water.|Add salt and bring the heat to medium, then let the soup boil for about 3 to 4 hours.
3 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import torch 4 | import random 5 | import socket 6 | import argparse 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | 10 | from datetime import datetime 11 | from omegaconf import OmegaConf 12 | from einops import rearrange, repeat 13 | 14 | from lvdm.utils import instantiate_from_config, load_model_checkpoint, get_cosine_schedule_with_warmup, AverageMeter 15 | from video_dataset import sequence_collate, RepeatedDataset, ShowHowToDataset 16 | 17 | 18 | def get_latent_z(model, videos): 19 | b, c, t, h, w = videos.shape 20 | x = rearrange(videos, 'b c t h w -> (b t) c h w') 21 | z = model.encode_first_stage(x) 22 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 23 | return z 24 | 25 | 26 | def main(args): 27 | ngpus_per_node = torch.cuda.device_count() 28 | node_count = int(os.environ.get("SLURM_NPROCS", "1")) 29 | node_rank = int(os.environ.get("SLURM_PROCID", "0")) 30 | job_id = os.environ.get("SLURM_JOBID", "".join([str(random.randint(0, 9)) for _ in range(5)])) 31 | 32 | dist_url = "file://{}.{}".format(os.path.realpath("distfile"), job_id) 33 | print(f"Hi from node {socket.gethostname()} ({node_rank}/{node_count} with {ngpus_per_node} GPUs)!", flush=True) 34 | 35 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=({ 36 | "ngpus_per_node": ngpus_per_node, 37 | "node_count": node_count, 38 | "node_rank": node_rank, 39 | "dist_url": dist_url, 40 | "job_id": job_id 41 | }, args)) 42 | 43 | 44 | def main_worker(local_rank, cluster_args, args): 45 | world_size = cluster_args["node_count"] * cluster_args["ngpus_per_node"] 46 | global_rank = cluster_args["node_rank"] * cluster_args["ngpus_per_node"] + local_rank 47 | dist.init_process_group( 48 | backend="nccl", 49 | init_method=cluster_args["dist_url"], 50 | world_size=world_size, 51 | rank=global_rank, 52 | ) 53 | 54 | if global_rank == 0: 55 | store_dir = "logs/" + datetime.strftime(datetime.now(), "%Y-%m-%d_%H%M%S") 56 | for k, v in sorted(vars(args).items(), key=lambda x: x[0]): 57 | print(f"# {k}: {v}") 58 | print(f"# effective_batch_size: {world_size * args.local_batch_size}", flush=True) 59 | 60 | ############### 61 | # DATASET 62 | ############### 63 | n_epochs = 200 64 | save_every_n_epochs = 1 65 | 66 | train_ds = [] 67 | for i in range(2, args.max_seq_len + 1): 68 | train_ds.append(RepeatedDataset( 69 | ShowHowToDataset(args.dataset_root, video_length=i), epoch_len=16000)) 70 | 71 | train_samplers = [None for _ in train_ds] 72 | if world_size > 1: 73 | train_samplers = [torch.utils.data.distributed.DistributedSampler(ds, shuffle=True, drop_last=True) for ds in train_ds] 74 | 75 | train_ds_iters = [torch.utils.data.DataLoader( 76 | ds, batch_size=args.local_batch_size, shuffle=world_size == 1, drop_last=True, num_workers=1, 77 | pin_memory=True, sampler=train_sampler, collate_fn=sequence_collate) for ds, train_sampler in zip(train_ds, train_samplers)] 78 | 79 | ############### 80 | # MODEL 81 | ############### 82 | learning_rate = 2e-5 83 | 84 | config = OmegaConf.load("./configs/inference_256_v1.1.yaml")["model"] 85 | model = instantiate_from_config(config) 86 | model = load_model_checkpoint(model, args.ckpt_path) 87 | model.image_proj_model.requires_grad_(False) # torch sometimes thinks there are missing gradients w.r.t Resampler 88 | 89 | torch.cuda.set_device(local_rank) 90 | 
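# move the model onto the GPU selected above before wrapping it in DistributedDataParallel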
model.cuda(local_rank) 91 | model_parallel = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False) 92 | print(f"Model distributed to gpu {global_rank}!", flush=True) 93 | 94 | ############### 95 | # OPTIMIZER 96 | ############### 97 | parameters2train = model_parallel.module.model.parameters() 98 | 99 | optim = torch.optim.AdamW(parameters2train, lr=learning_rate) 100 | scheduler = get_cosine_schedule_with_warmup(optim, len(train_ds_iters[0]), len(train_ds_iters[0]) * n_epochs) 101 | loss_metric = AverageMeter() 102 | 103 | for epoch in range(1, n_epochs + 1): 104 | if world_size > 1: 105 | for train_sampler in train_samplers: 106 | train_sampler.set_epoch(epoch) 107 | 108 | iterator = tqdm.tqdm(train_ds_iters[-1]) if global_rank == 0 else train_ds_iters[-1] 109 | other_iterators = [iter(ds) for ds in train_ds_iters[:-1]] 110 | for video_frames, prompts in iterator: 111 | # gather data for all lengths 112 | iterator_data = [(video_frames, prompts)] 113 | for other_iterator in other_iterators: 114 | iterator_data.append(next(other_iterator)) 115 | 116 | for video_frames, prompts in iterator_data: 117 | B, C, T, H, W = video_frames.shape 118 | frame_stride = torch.ones((B,), dtype=torch.long, device=model.device) 119 | 120 | with torch.no_grad(): 121 | img_emb = model.image_proj_model(model.embedder(video_frames[:, :, 0].to(model.device))) 122 | text_emb = model.get_learned_conditioning(prompts) 123 | z = get_latent_z(model, video_frames.to(model.device)) 124 | 125 | tB, tL, tC = text_emb.shape 126 | if tB != B: # in case we have multiple prompts for a single video 127 | assert B * T == tB, f"{B} * {T} != {tB}" 128 | img_emb = img_emb.repeat_interleave(repeats=T, dim=0) 129 | cond = { 130 | "c_crossattn": [torch.cat([text_emb, img_emb], dim=1)], 131 | "c_concat": [repeat(z[:, :, :1], 'b c t h w -> b c (repeat t) h w', repeat=T)] 132 | } 133 | 134 | t = torch.randint(0, model.num_timesteps, (z.shape[0],), device=model.device).long() 135 | noise = torch.randn_like(z) 136 | x_noisy = model.q_sample(x_start=z, t=t, noise=noise) 137 | 138 | model_output = model_parallel(x_noisy, t, cond, fs=frame_stride) 139 | 140 | loss = torch.nn.functional.mse_loss(noise, model_output, reduction='none') 141 | loss = loss.mean([1, 2, 3, 4]) 142 | loss = loss.mean() 143 | 144 | optim.zero_grad() 145 | loss.backward() # DistributedDataParallel does gradient averaging, i.e. 
loss is x-times smaller when trained on more GPUs 146 | optim.step() 147 | loss_metric.update(loss.item()) 148 | 149 | scheduler.step() 150 | 151 | if global_rank == 0: 152 | print(f"Epoch {epoch:03d} | Loss: {loss_metric.value:.4f}", flush=True) 153 | 154 | os.makedirs(store_dir, exist_ok=True) 155 | with open(os.path.join(store_dir, "losses.txt"), "a") as f: 156 | f.write(f"{epoch:03d}{loss_metric.value:12.4f}\n") 157 | loss_metric.reset() 158 | 159 | if epoch % save_every_n_epochs == 0: 160 | torch.save(model_parallel.module.state_dict(), os.path.join(store_dir, f"model_{epoch:03d}.pt")) 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument("--local_batch_size", type=int, default=1) 166 | parser.add_argument("--ckpt_path", type=str, default="./weights/dynamicrafter_256_v1.ckpt") 167 | parser.add_argument("--dataset_root", type=str, required=True) 168 | parser.add_argument("--max_seq_len", type=int, default=8) 169 | 170 | main(parser.parse_args()) 171 | -------------------------------------------------------------------------------- /video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import random 6 | import numpy as np 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class ShowHowToDataset(Dataset): 12 | def __init__(self, root_path, video_length=None): 13 | self.root_path = root_path 14 | self.video_length = video_length 15 | 16 | with open(os.path.join(root_path, "prompts.json"), "r") as f: 17 | self.prompts = json.load(f) 18 | 19 | data = sorted(glob.glob(os.path.join(root_path, "imgseqs*", "*.jpg"))) 20 | self.data = [x for x in data if os.path.basename(x).replace(".jpg", "") in self.prompts] 21 | if self.video_length is not None: 22 | self.data = [x for x in self.data if len(self.prompts[os.path.basename(x).replace(".jpg", "")]) >= self.video_length] 23 | print(f"Found {len(self.data)} images with valid prompts and length >= {video_length}, other {len(data) - len(self.data)} removed") 24 | else: 25 | print(f"Found {len(self.data)} images with valid prompts, other {len(data) - len(self.data)} removed") 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, idx): 31 | img_fn = self.data[idx] 32 | vid_id = os.path.basename(img_fn).replace(".jpg", "") 33 | assert vid_id in self.prompts, f"prompt file not found for {img_fn} in prompts file!" 34 | prompts = self.prompts[vid_id] 35 | 36 | img = np.array(Image.open(img_fn)) 37 | h, w = img.shape[:2] 38 | w = w // len(prompts) 39 | 40 | imgs = [img[:, i * w:(i + 1) * w] for i in range(len(prompts))] 41 | imgs = np.stack(imgs, axis=0) 42 | if w < h: 43 | print(f"{img_fn} has width {w} and height {h}! 
Skipping...", flush=True) 44 | return self.__getitem__(random.randint(0, len(self.data) - 1)) 45 | else: 46 | imgs = imgs[:, :, (w - h) // 2:][:, :, :h] 47 | 48 | indices = np.arange(len(prompts)) 49 | if self.video_length is not None: 50 | indices = indices[np.random.randint(0, len(prompts) - self.video_length + 1):][:self.video_length] 51 | 52 | selected_prompts = [prompts[i] for i in indices] 53 | selected_imgs = imgs[indices] 54 | 55 | selected_imgs = np.stack([np.array(Image.fromarray(fr).resize((256, 256))) for fr in selected_imgs], axis=0) 56 | video_frames = torch.from_numpy(selected_imgs.copy()).float().div_(255 / 2).sub_(1).permute(3, 0, 1, 2) 57 | 58 | return video_frames, selected_prompts 59 | 60 | def __repr__(self): 61 | string = f"ShowHowToDataset(n_samples: {self.__len__()})" 62 | return string 63 | 64 | 65 | class RepeatedDataset(Dataset): 66 | 67 | def __init__(self, ds, epoch_len): 68 | self.ds = ds 69 | self.epoch_len = epoch_len 70 | 71 | def __getitem__(self, idx): 72 | return self.ds[random.randint(0, len(self.ds) - 1)] 73 | 74 | def __len__(self): 75 | return self.epoch_len 76 | 77 | def __repr__(self): 78 | string = f"RepeatedDataset(ds: {self.ds}, epoch_len: {self.epoch_len})" 79 | return string 80 | 81 | 82 | def sequence_collate(batch): 83 | video_frames, prompt = zip(*batch) 84 | video_frames = torch.stack(video_frames) 85 | if isinstance(prompt[0], list): 86 | prompt = [] 87 | for i in range(len(batch)): 88 | prompt.extend(batch[i][1]) 89 | return video_frames, prompt 90 | --------------------------------------------------------------------------------
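Note: the following is a minimal, illustrative sketch (not part of the repository) of how the dataset classes defined in `video_dataset.py` above can be combined with a standard PyTorch `DataLoader`, mirroring the setup in `train.py`. The dataset path and batch size are assumed placeholders.

```
import torch
from video_dataset import ShowHowToDataset, RepeatedDataset, sequence_collate

# assumed placeholder path; the directory must follow the layout described in the README
root = "/path/to/ShowHowToTrain"

# sample sequences of exactly 4 frames; RepeatedDataset fixes the epoch length as in train.py
dataset = RepeatedDataset(ShowHowToDataset(root, video_length=4), epoch_len=16000)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, drop_last=True,
    num_workers=1, pin_memory=True, collate_fn=sequence_collate)

video_frames, prompts = next(iter(loader))
# video_frames: (B, 3, T, 256, 256) tensor scaled to [-1, 1]
# prompts: flat list of B*T strings, one per frame
print(video_frames.shape, len(prompts))
```

Because `sequence_collate` flattens the per-frame prompt lists, downstream code such as `image_guided_synthesis` and the training loop detects the B*T-sized prompt batch and repeats the image embeddings accordingly.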