├── requirements.txt ├── error.mp4 ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yml │ └── bug_report.yml ├── workflows │ └── issue_checker.yaml ├── FUNDING.yml └── scripts │ └── issue_checker.py ├── install.py ├── javascript └── t2v_progressbar.js ├── scripts ├── videocrafter │ ├── sample_text2video.sh │ ├── lvdm │ │ ├── utils │ │ │ ├── dist_utils.py │ │ │ ├── common_utils.py │ │ │ └── saving_utils.py │ │ ├── models │ │ │ ├── modules │ │ │ │ ├── condition_modules.py │ │ │ │ ├── distributions.py │ │ │ │ ├── adapter.py │ │ │ │ └── util.py │ │ │ └── autoencoder.py │ │ ├── data │ │ │ └── webvid.py │ │ └── samplers │ │ │ └── ddim.py │ ├── ddp_wrapper.py │ ├── base_t2v │ │ └── model_config.yaml │ ├── sample_utils.py │ ├── process_videocrafter.py │ ├── sample_text2video_adapter.py │ └── sample_text2video.py ├── t2v_helpers │ ├── general_utils.py │ ├── extensions_utils.py │ ├── render.py │ ├── key_frames.py │ └── video_audio_utils.py ├── samplers │ ├── uni_pc │ │ └── sampler.py │ ├── samplers_common.py │ └── ddim │ │ └── gaussian_sampler.py ├── text2vid.py ├── api_t2v.py ├── stable_lora │ ├── scripts │ │ └── lora_webui.py │ └── stable_utils │ │ └── lora_processor.py └── modelscope │ └── process_modelscope.py ├── style.css ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio_ffmpeg 2 | av 3 | moviepy 4 | numexpr 5 | mutagen 6 | -------------------------------------------------------------------------------- /error.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabachuha/sd-webui-text2video/HEAD/error.mp4 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Github discussions 4 | url: https://github.com/deforum-art/sd-webui-modelscope-text2video/discussions 5 | about: Please ask and answer questions here as well as share your art. But if you want to complain about something, don't try to circumvent issue filling by starting a discussion here 🙃 6 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 
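# Installs the extension's Python requirements listed in requirements.txt at webui startup, using the webui launch helper and skipping packages that are already installed.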
3 | 4 | import launch 5 | 6 | import os 7 | 8 | req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") 9 | 10 | with open(req_file) as file: 11 | 12 | for lib in file: 13 | 14 | lib = lib.strip() 15 | 16 | if not launch.is_installed(lib): 17 | 18 | launch.run_pip(f"install {lib}", f"text2video requirement: {lib}") 19 | -------------------------------------------------------------------------------- /javascript/t2v_progressbar.js: -------------------------------------------------------------------------------- 1 | function submit_txt2vid(){ 2 | // rememberGallerySelection('txt2img_gallery') 3 | showSubmitButtons('text2vid', false) 4 | 5 | var id = randomId() 6 | // Using progressbar without the gallery 7 | requestProgress(id, gradioApp().getElementById('text2vid_results_panel'), null, function(){ 8 | showSubmitButtons('text2vid', true) 9 | }) 10 | 11 | var res = create_submit_args(arguments) 12 | 13 | res[0] = id 14 | 15 | return res 16 | } 17 | -------------------------------------------------------------------------------- /scripts/videocrafter/sample_text2video.sh: -------------------------------------------------------------------------------- 1 | 2 | PROMPT="astronaut riding a horse" # OR: PROMPT="input/prompts.txt" for sampling multiple prompts 3 | OUTDIR="results/" 4 | 5 | BASE_PATH="models/base_t2v/model.ckpt" 6 | CONFIG_PATH="models/base_t2v/model_config.yaml" 7 | 8 | python scripts/sample_text2video.py \ 9 | --ckpt_path $BASE_PATH \ 10 | --config_path $CONFIG_PATH \ 11 | --prompt "$PROMPT" \ 12 | --save_dir $OUTDIR \ 13 | --n_samples 1 \ 14 | --batch_size 1 \ 15 | --seed 1000 \ 16 | --show_denoising_progress 17 | -------------------------------------------------------------------------------- /.github/workflows/issue_checker.yaml: -------------------------------------------------------------------------------- 1 | name: Issue Checker 2 | 3 | on: 4 | issues: 5 | types: [opened, reopened, edited] 6 | 7 | jobs: 8 | check_issue: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout repository 12 | uses: actions/checkout@v3 13 | - name: Set up Python 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: pip install PyGithub 19 | - name: Check issue 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | ISSUE_NUMBER: ${{ github.event.number }} 23 | run: python .github/scripts/issue_checker.py 24 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | def setup_dist(local_rank): 5 | if dist.is_initialized(): 6 | return 7 | torch.cuda.set_device(local_rank) 8 | torch.distributed.init_process_group( 9 | 'nccl', 10 | init_method='env://' 11 | ) 12 | 13 | def gather_data(data, return_np=True): 14 | ''' gather data from multiple processes to one list ''' 15 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 16 | dist.all_gather(data_list, data) # gather not supported with NCCL 17 | if return_np: 18 | data_list = [data.cpu().numpy() for data in data_list] 19 | return data_list 20 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up 
to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: [https://etherscan.io/address/0x4c92637b8d3587383d50812f64a0dbd2a5426e81] 14 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2023 by Artem Khrapov (kabachuha) 3 | Read LICENSE for usage terms. 4 | */ 5 | 6 | #vid_to_vid_chosen_file .w-full, #inpainting_chosen_file .w-full, #metadata_chosen_file .w-full { 7 | display: flex !important; 8 | align-items: flex-start !important; 9 | justify-content: center !important; 10 | height: 85px !important; 11 | } 12 | 13 | #vid_to_vid_chosen_file, #inpainting_chosen_file { 14 | height: 85px !important; 15 | } 16 | 17 | .generate-box{ 18 | position: relative; 19 | } 20 | .gradio-button.generate-box-skip, .gradio-button.generate-box-interrupt{ 21 | position: absolute; 22 | width: 50%; 23 | height: 100%; 24 | display: none; 25 | background: #b4c0cc; 26 | } 27 | .gradio-button.generate-box-skip:hover, .gradio-button.generate-box-interrupt:hover{ 28 | background: #c2cfdb; 29 | } 30 | .gradio-button.generate-box-interrupt{ 31 | left: 0; 32 | border-radius: 0.5rem 0 0 0.5rem; 33 | } 34 | .gradio-button.generate-box-skip{ 35 | right: 0; 36 | border-radius: 0 0.5rem 0.5rem 0; 37 | } 38 | -------------------------------------------------------------------------------- /scripts/t2v_helpers/general_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 
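# Helper utilities shared across the extension: extension version lookup, model folder resolution, and per-step cond/uncond reconstruction.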
3 | from modules.prompt_parser import reconstruct_cond_batch 4 | import os 5 | import modules.paths as ph 6 | 7 | def get_t2v_version(): 8 | from modules import extensions as mext 9 | try: 10 | for ext in mext.extensions: 11 | if (ext.name in ["sd-webui-modelscope-text2video"] or ext.name in ["sd-webui-text2video"]) and ext.enabled: 12 | return ext.version 13 | return "Unknown" 14 | except: 15 | return "Unknown" 16 | 17 | def get_model_location(model_name): 18 | assert model_name is not None 19 | 20 | if model_name == "": 21 | return os.path.join(ph.models_path, 'ModelScope/t2v') 22 | elif model_name == "": 23 | return os.path.join(ph.models_path, 'VideoCrafter') 24 | else: 25 | return os.path.join(ph.models_path, 'text2video/', model_name) 26 | 27 | def reconstruct_conds(cond, uncond, step): 28 | c = reconstruct_cond_batch(cond, step) 29 | uc = reconstruct_cond_batch(uncond, step) 30 | return c, uc 31 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/models/modules/condition_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import logging 3 | from transformers import CLIPTokenizer, CLIPTextModel 4 | logging.set_verbosity_error() 5 | 6 | 7 | class AbstractEncoder(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def encode(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class FrozenCLIPEmbedder(AbstractEncoder): 16 | """Uses the CLIP transformer encoder for text (from huggingface)""" 17 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 18 | super().__init__() 19 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 20 | self.transformer = CLIPTextModel.from_pretrained(version) 21 | self.device = device 22 | self.max_length = max_length 23 | self.freeze() 24 | 25 | def freeze(self): 26 | self.transformer = self.transformer.eval() 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, text): 31 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 32 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 33 | tokens = batch_encoding["input_ids"].to(self.device) 34 | outputs = self.transformer(input_ids=tokens) 35 | 36 | z = outputs.last_hidden_state 37 | return z 38 | 39 | def encode(self, text): 40 | return self(text) 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for the ModelScope text2video extension 3 | title: "[Feature Request]: " 4 | labels: ["enhancement"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do ? 
22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here. 41 | 42 | -------------------------------------------------------------------------------- /scripts/videocrafter/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup_dist(local_rank): 10 | if dist.is_initialized(): 11 | return 12 | torch.cuda.set_device(local_rank) 13 | torch.distributed.init_process_group('nccl', init_method='env://') 14 | 15 | 16 | def get_dist_info(): 17 | if dist.is_available(): 18 | initialized = dist.is_initialized() 19 | else: 20 | initialized = False 21 | if initialized: 22 | rank = dist.get_rank() 23 | world_size = dist.get_world_size() 24 | else: 25 | rank = 0 26 | world_size = 1 27 | return rank, world_size 28 | 29 | 30 | if __name__ == '__main__': 31 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--module", type=str, help="module name", default="inference") 34 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 35 | args, unknown = parser.parse_known_args() 36 | inference_api = importlib.import_module(args.module, package=None) 37 | 38 | inference_parser = inference_api.get_parser() 39 | inference_args, unknown = inference_parser.parse_known_args() 40 | 41 | seed_everything(inference_args.seed) 42 | setup_dist(args.local_rank) 43 | torch.backends.cudnn.benchmark = True 44 | rank, gpu_num = get_dist_info() 45 | 46 | print("@CoVideoGen Inference [rank%d]: %s"%(rank, now)) 47 | inference_api.run_inference(inference_args, rank) -------------------------------------------------------------------------------- /scripts/videocrafter/base_t2v/model_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: video 10 | cond_stage_key: caption 11 | image_size: 12 | - 32 13 | - 32 14 | video_length: 16 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | scale_by_std: false 19 | scale_factor: 0.18215 20 | 21 | unet_config: 22 | target: lvdm.models.modules.openaimodel3d.UNetModel 23 | params: 24 | image_size: 32 25 | in_channels: 4 26 | out_channels: 4 27 | model_channels: 320 28 | attention_resolutions: 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | - 4 38 | num_heads: 8 39 | transformer_depth: 1 40 | context_dim: 768 41 | use_checkpoint: true 42 | legacy: false 43 | kernel_size_t: 1 44 | padding_t: 0 45 | temporal_length: 16 46 | use_relative_position: true 47 | 48 | first_stage_config: 49 | 
target: lvdm.models.autoencoder.AutoencoderKL 50 | params: 51 | embed_dim: 4 52 | monitor: val/rec_loss 53 | ddconfig: 54 | double_z: true 55 | z_channels: 4 56 | resolution: 256 57 | in_channels: 3 58 | out_ch: 3 59 | ch: 128 60 | ch_mult: 61 | - 1 62 | - 2 63 | - 4 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: [] 67 | dropout: 0.0 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder -------------------------------------------------------------------------------- /scripts/t2v_helpers/extensions_utils.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | class Text2VideoExtension(object): 4 | """ 5 | A simple base class that sets a definitive way to process extensions 6 | """ 7 | 8 | def __init__(self, extension_name: str = '', extension_title: str = ''): 9 | 10 | self.extension_name = extension_name 11 | self.extension_title = extension_title 12 | self.return_args_delimiter = f"extension_{extension_name}" 13 | 14 | def return_ui_inputs(self, return_args: list = [] ): 15 | """ 16 | All extensions should use this method to return Gradio inputs. 17 | This allows for tracking the inputs using a delimiter. 18 | Arguments are automatically processed and returned. 19 | 20 | Output: + [arg1, arg2, arg3] + 21 | """ 22 | 23 | delimiter = gr.State(self.return_args_delimiter) 24 | return [delimiter] + return_args + [delimiter] 25 | 26 | def process_extension_args(self, all_args: list = []): 27 | """ 28 | Processes all extension arguments and appends them into a list. 29 | The filtered arguments are piped into the extension's process method. 30 | """ 31 | 32 | can_append = False 33 | extension_args = [] 34 | 35 | for value in all_args: 36 | if value == self.return_args_delimiter and not can_append: 37 | can_append = True 38 | continue 39 | 40 | if can_append: 41 | if value == self.return_args_delimiter: 42 | break 43 | else: 44 | extension_args.append(value) 45 | 46 | return extension_args 47 | 48 | def log(self, message: str = '', *args): 49 | """ 50 | Choose to print a log specific to the extension. 51 | """ 52 | OKGREEN = '\033[92m' 53 | ENDC = '\033[0m' 54 | 55 | title = self.extension_title 56 | message = f"Extension {title}: {message} " + ', '.join(args) 57 | print(OKGREEN + message + ENDC) 58 | -------------------------------------------------------------------------------- /scripts/t2v_helpers/render.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 
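# Entry point for a generation request: dispatches to the ModelScope or VideoCrafter pipeline, returns the rendered video(s), and falls back to the bundled error clip if anything fails.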
3 | 4 | import traceback 5 | from modelscope.process_modelscope import process_modelscope 6 | import modelscope.process_modelscope as pm 7 | from videocrafter.process_videocrafter import process_videocrafter 8 | from modules.shared import opts 9 | from .error_hardcode import get_error 10 | from modules import lowvram, devices, sd_hijack 11 | import logging 12 | import gc 13 | import t2v_helpers.args as t2v_helpers_args 14 | 15 | def run(*args): 16 | dataurl = get_error() 17 | vids_pack = [dataurl] 18 | component_names = t2v_helpers_args.get_component_names() 19 | # api check 20 | num_components = len(component_names) 21 | affected_args = args[2:] if len(args) > num_components else args 22 | # TODO: change to i+2 when we will add the progress bar 23 | args_dict = {component_names[i]: affected_args[i] for i in range(0, num_components)} 24 | model_type = args_dict['model_type'] 25 | t2v_helpers_args.i1_store_t2v = f'
text2video extension for auto1111 — version 1.3b
' 26 | keep_pipe_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None' 27 | try: 28 | print(f'text2video — The model selected is: {args_dict["model"]} ({args_dict["model_type"]}-like)') 29 | if model_type == 'ModelScope': 30 | vids_pack = process_modelscope(args_dict, args) 31 | elif model_type == 'VideoCrafter (WIP)': 32 | vids_pack = process_videocrafter(args_dict) 33 | else: 34 | raise NotImplementedError(f"Unknown model type: {model_type}") 35 | except Exception as e: 36 | traceback.print_exc() 37 | print('Exception occurred:', e) 38 | finally: 39 | #optionally store pipe in global between runs, if not, remove it 40 | if keep_pipe_in_vram == 'None': 41 | pm.pipe = None 42 | devices.torch_gc() 43 | gc.collect() 44 | return vids_pack 45 | 46 | 47 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class DiagonalGaussianDistribution(object): 6 | def __init__(self, parameters, deterministic=False): 7 | self.parameters = parameters 8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 10 | self.deterministic = deterministic 11 | self.std = torch.exp(0.5 * self.logvar) 12 | self.var = torch.exp(self.logvar) 13 | if self.deterministic: 14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 15 | 16 | def sample(self, noise=None): 17 | if noise is None: 18 | noise = torch.randn(self.mean.shape) 19 | 20 | x = self.mean + self.std * noise.to(device=self.parameters.device) 21 | return x 22 | 23 | def kl(self, other=None): 24 | if self.deterministic: 25 | return torch.Tensor([0.]) 26 | else: 27 | if other is None: 28 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 29 | + self.var - 1.0 - self.logvar, 30 | dim=[1, 2, 3]) 31 | else: 32 | return 0.5 * torch.sum( 33 | torch.pow(self.mean - other.mean, 2) / other.var 34 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 35 | dim=[1, 2, 3]) 36 | 37 | def nll(self, sample, dims=[1,2,3]): 38 | if self.deterministic: 39 | return torch.Tensor([0.]) 40 | logtwopi = np.log(2.0 * np.pi) 41 | return 0.5 * torch.sum( 42 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 43 | dim=dims) 44 | 45 | def mode(self): 46 | return self.mean 47 | 48 | 49 | def normal_kl(mean1, logvar1, mean2, logvar2): 50 | """ 51 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 52 | Compute the KL divergence between two gaussians. 53 | Shapes are automatically broadcasted, so batches can be compared to 54 | scalars, among other use cases. 55 | """ 56 | tensor = None 57 | for obj in (mean1, logvar1, mean2, logvar2): 58 | if isinstance(obj, torch.Tensor): 59 | tensor = obj 60 | break 61 | assert tensor is not None, "at least one argument must be a Tensor" 62 | 63 | # Force variances to be Tensors. Broadcasting helps convert scalars to 64 | # Tensors, but it does not work for torch.exp(). 
65 | logvar1, logvar2 = [ 66 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 67 | for x in (logvar1, logvar2) 68 | ] 69 | 70 | return 0.5 * ( 71 | -1.0 72 | + logvar2 73 | - logvar1 74 | + torch.exp(logvar1 - logvar2) 75 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 76 | ) 77 | -------------------------------------------------------------------------------- /scripts/samplers/uni_pc/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC 6 | 7 | class UniPCSampler(object): 8 | def __init__(self, model, **kwargs): 9 | super().__init__() 10 | self.model = model 11 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 12 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 13 | 14 | def register_buffer(self, name, attr): 15 | if type(attr) == torch.Tensor: 16 | if attr.device != torch.device("cuda"): 17 | attr = attr.to(torch.device("cuda")) 18 | setattr(self, name, attr) 19 | 20 | def unipc_encode(self, latent, device, strength, steps, noise=None): 21 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 22 | uni_pc = UniPC(None, ns, predict_x0=True, thresholding=False, variant='bh1') 23 | t_0 = 1. / ns.total_N 24 | 25 | timesteps = uni_pc.get_time_steps("time_uniform", strength, t_0, steps, device) 26 | timesteps = timesteps[0].expand((latent.shape[0])) 27 | 28 | noisy_latent = uni_pc.unipc_encode(latent, timesteps, noise=noise) 29 | return noisy_latent 30 | 31 | @torch.no_grad() 32 | def sample(self, 33 | S, 34 | batch_size, 35 | shape, 36 | conditioning=None, 37 | callback=None, 38 | normals_sequence=None, 39 | img_callback=None, 40 | quantize_x0=False, 41 | strength=None, 42 | eta=0., 43 | mask=None, 44 | x0=None, 45 | temperature=1., 46 | score_corrector=None, 47 | corrector_kwargs=None, 48 | verbose=True, 49 | x_T=None, 50 | log_every_t=100, 51 | unconditional_guidance_scale=1., 52 | unconditional_conditioning=None, 53 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
54 | **kwargs 55 | ): 56 | 57 | # sampling 58 | B, C, F, H, W = shape 59 | size = (B, C, F, H, W) 60 | 61 | if x_T is None: 62 | img = torch.randn(size, device=self.model.device) 63 | else: 64 | img = x_T 65 | 66 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 67 | model_fn = model_wrapper( 68 | lambda x, t, c: self.model(x, t, c), 69 | ns, 70 | model_type="noise", 71 | guidance_type="classifier-free", 72 | condition=conditioning, 73 | unconditional_condition=unconditional_conditioning, 74 | guidance_scale=unconditional_guidance_scale, 75 | ) 76 | 77 | uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant='bh1') 78 | x = uni_pc.sample( 79 | img, 80 | steps=S, 81 | t_start=strength, 82 | skip_type="time_uniform", 83 | method="multistep", 84 | order=3, 85 | lower_order_final=True, 86 | initial_corrector=True, 87 | callback=callback 88 | ) 89 | 90 | return x.to(self.model.device) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | -------------------------------------------------------------------------------- /scripts/t2v_helpers/key_frames.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 3 | 4 | import re 5 | import numpy as np 6 | import numexpr 7 | import pandas as pd 8 | 9 | class T2VAnimKeys(): 10 | def __init__(self, anim_args, seed=-1, max_i_frames=1): 11 | self.fi = FrameInterpolater(anim_args.max_frames, seed, max_i_frames) 12 | self.inpainting_weights_series = self.fi.get_inbetweens(self.fi.parse_key_frames(anim_args.inpainting_weights)) 13 | 14 | def check_is_number(value): 15 | float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' 16 | return re.match(float_pattern, value) 17 | 18 | class FrameInterpolater(): 19 | def __init__(self, max_frames=0, seed=-1, max_i_frames=1) -> None: 20 | self.max_frames = max_frames 21 | self.seed = seed 22 | self.max_i_frames = max_i_frames 23 | 24 | def sanitize_value(self, value): 25 | return value.replace("'","").replace('"',"").replace('(',"").replace(')',"") 26 | 27 | def get_inbetweens(self, key_frames, integer=False, interp_method='Linear', is_single_string = False): 28 | key_frame_series = pd.Series([np.nan for a in range(self.max_frames)]) 29 | # get our ui variables set for numexpr.evaluate 30 | max_f = self.max_frames -1 31 | max_i_f = self.max_i_frames - 1 32 | s = self.seed 33 | for i in range(0, self.max_frames): 34 | if i in key_frames: 35 | value = key_frames[i] 36 | value_is_number = check_is_number(self.sanitize_value(value)) 37 | if value_is_number: # if it's only a number, leave the rest for the default interpolation 38 | key_frame_series[i] = self.sanitize_value(value) 39 | if not value_is_number: 40 | t = i 41 | # workaround for values formatted like 0:("I am test") //used for sampler schedules 42 | key_frame_series[i] = numexpr.evaluate(value) if not is_single_string else self.sanitize_value(value) 43 | elif is_single_string:# take previous string value and replicate it 44 | key_frame_series[i] = key_frame_series[i-1] 45 | key_frame_series = key_frame_series.astype(float) if not is_single_string else key_frame_series # as string 46 | 47 | if interp_method == 'Cubic' and len(key_frames.items()) <= 3: 48 | interp_method = 'Quadratic' 49 | if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: 50 | interp_method 
= 'Linear' 51 | 52 | key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] 53 | key_frame_series[self.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] 54 | key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both') 55 | if integer: 56 | return key_frame_series.astype(int) 57 | return key_frame_series 58 | 59 | def parse_key_frames(self, string): 60 | # because math functions (i.e. sin(t)) can utilize brackets 61 | # it extracts the value in form of some stuff 62 | # which has previously been enclosed with brackets and 63 | # with a comma or end of line existing after the closing one 64 | frames = dict() 65 | for match_object in string.split(","): 66 | frameParam = match_object.split(":") 67 | max_f = self.max_frames - 1 68 | max_i_f = self.max_i_frames - 1 69 | s = self.seed 70 | frame = int(self.sanitize_value(frameParam[0])) if check_is_number(self.sanitize_value(frameParam[0].strip())) else int(numexpr.evaluate(frameParam[0].strip().replace("'","",1).replace('"',"",1)[::-1].replace("'","",1).replace('"',"",1)[::-1])) 71 | frames[frame] = frameParam[1].strip() 72 | if frames == {} and len(string) != 0: 73 | raise RuntimeError('Key Frame string not correctly formatted') 74 | return frames 75 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/models/modules/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from videocrafter.lvdm.models.modules.util import ( 5 | zero_module, 6 | conv_nd, 7 | avg_pool_nd 8 | ) 9 | 10 | class Downsample(nn.Module): 11 | """ 12 | A downsampling layer with an optional convolution. 13 | :param channels: channels in the inputs and outputs. 14 | :param use_conv: a bool determining if a convolution is applied. 15 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 16 | downsampling occurs in the inner-two dimensions. 
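:param out_channels: if given, the number of output channels; defaults to `channels`.
:param padding: padding applied by the convolution when `use_conv` is True.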
17 | """ 18 | 19 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 20 | super().__init__() 21 | self.channels = channels 22 | self.out_channels = out_channels or channels 23 | self.use_conv = use_conv 24 | self.dims = dims 25 | stride = 2 if dims != 3 else (1, 2, 2) 26 | if use_conv: 27 | self.op = conv_nd( 28 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 29 | ) 30 | else: 31 | assert self.channels == self.out_channels 32 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 33 | 34 | def forward(self, x): 35 | assert x.shape[1] == self.channels 36 | return self.op(x) 37 | 38 | 39 | class ResnetBlock(nn.Module): 40 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): 41 | super().__init__() 42 | ps = ksize // 2 43 | if in_c != out_c or sk == False: 44 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) 45 | else: 46 | # print('n_in') 47 | self.in_conv = None 48 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) 49 | self.act = nn.ReLU() 50 | self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps) 51 | if sk == False: 52 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 53 | else: 54 | self.skep = None 55 | 56 | self.down = down 57 | if self.down == True: 58 | self.down_opt = Downsample(in_c, use_conv=use_conv) 59 | 60 | def forward(self, x): 61 | if self.down == True: 62 | x = self.down_opt(x) 63 | if self.in_conv is not None: # edit 64 | x = self.in_conv(x) 65 | 66 | h = self.block1(x) 67 | h = self.act(h) 68 | h = self.block2(h) 69 | if self.skep is not None: 70 | return h + self.skep(x) 71 | else: 72 | return h + x 73 | 74 | 75 | class Adapter(nn.Module): 76 | def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): 77 | super(Adapter, self).__init__() 78 | self.unshuffle = nn.PixelUnshuffle(8) 79 | self.channels = channels 80 | self.nums_rb = nums_rb 81 | self.body = [] 82 | for i in range(len(channels)): 83 | for j in range(nums_rb): 84 | if (i != 0) and (j == 0): 85 | self.body.append( 86 | ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) 87 | else: 88 | self.body.append( 89 | ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 90 | self.body = nn.ModuleList(self.body) 91 | self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1) 92 | 93 | def forward(self, x): 94 | # unshuffle 95 | x = self.unshuffle(x) 96 | # extract features 97 | features = [] 98 | x = self.conv_in(x) 99 | for i in range(len(self.channels)): 100 | for j in range(self.nums_rb): 101 | idx = i * self.nums_rb + j 102 | x = self.body[idx](x) 103 | features.append(x) 104 | 105 | return features -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def str2bool(v): 12 | if isinstance(v, bool): 13 | return v 14 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 15 | return True 16 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 17 | return False 18 | else: 19 | raise ValueError('Boolean value expected.') 20 | 21 | 22 | def instantiate_from_config(config): 23 | if not "target" in config: 24 | if config == '__is_first_stage__': 25 | return None 26 | elif config == "__is_unconditional__": 
27 | return None 28 | raise KeyError("Expected key `target` to instantiate.") 29 | 30 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 31 | 32 | def get_obj_from_str(string, reload=False): 33 | module, cls = string.rsplit(".", 1) 34 | if reload: 35 | module_imp = importlib.import_module('videocrafter.'+module if not 'torch' in module else module) 36 | importlib.reload(module_imp) 37 | return getattr(importlib.import_module('videocrafter.'+module if not 'torch' in module else module, package=None), cls) 38 | 39 | def log_txt_as_img(wh, xc, size=10): 40 | # wh a tuple of (width, height) 41 | # xc a list of captions to plot 42 | b = len(xc) 43 | txts = list() 44 | for bi in range(b): 45 | txt = Image.new("RGB", wh, color="white") 46 | draw = ImageDraw.Draw(txt) 47 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 48 | nc = int(40 * (wh[0] / 256)) 49 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 50 | 51 | try: 52 | draw.text((0, 0), lines, fill="black", font=font) 53 | except UnicodeEncodeError: 54 | print("Cant encode string for logging. Skipping.") 55 | 56 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 57 | txts.append(txt) 58 | txts = np.stack(txts) 59 | txts = torch.tensor(txts) 60 | return txts 61 | 62 | 63 | def ismap(x): 64 | if not isinstance(x, torch.Tensor): 65 | return False 66 | return (len(x.shape) == 4) and (x.shape[1] > 3) 67 | 68 | 69 | def isimage(x): 70 | if not isinstance(x,torch.Tensor): 71 | return False 72 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 73 | 74 | 75 | def exists(x): 76 | return x is not None 77 | 78 | 79 | def default(val, d): 80 | if exists(val): 81 | return val 82 | return d() if isfunction(d) else d 83 | 84 | 85 | def mean_flat(tensor): 86 | """ 87 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 88 | Take the mean over all non-batch dimensions. 
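Collapses a per-element loss tensor to a single scalar per batch item.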
89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def count_params(model, verbose=False): 94 | total_params = sum(p.numel() for p in model.parameters()) 95 | if verbose: 96 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 97 | return total_params 98 | 99 | 100 | def instantiate_from_config(config): 101 | if not "target" in config: 102 | if config == '__is_first_stage__': 103 | return None 104 | elif config == "__is_unconditional__": 105 | return None 106 | raise KeyError("Expected key `target` to instantiate.") 107 | 108 | if "instantiate_with_dict" in config and config["instantiate_with_dict"]: 109 | # input parameter is one dict 110 | return get_obj_from_str(config["target"])(config.get("params", dict()), **kwargs) 111 | else: 112 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 113 | 114 | 115 | def get_obj_from_str(string, reload=False): 116 | module, cls = string.rsplit(".", 1) 117 | if reload: 118 | module_imp = importlib.import_module('videocrafter.'+module if not 'torch' in module else module) 119 | importlib.reload(module_imp) 120 | return getattr(importlib.import_module('videocrafter.'+module if not 'torch' in module else module, package=None), cls) 121 | 122 | 123 | def check_istarget(name, para_list): 124 | """ 125 | name: full name of source para 126 | para_list: partial name of target para 127 | """ 128 | istarget=False 129 | for para in para_list: 130 | if para in name: 131 | return True 132 | return istarget -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Create a bug report for the ModelScope text2video extension 3 | title: "[Bug]: " 4 | labels: ["bug"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the bug you encountered (including the closed issues). 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui 13 | required: true 14 | - type: checkboxes 15 | attributes: 16 | label: Are you using the latest version of the extension? 17 | description: Please, check if your text2video setup is based on the latest repo commit (git log) or update it through the 'Extensions' tab and check if the issue still persist. Otherwise, check this box. 18 | options: 19 | - label: I have the modelscope text2video extension updated to the lastest version and I still have the issue. 20 | required: true 21 | - type: markdown 22 | attributes: 23 | value: | 24 | *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** 25 | - type: textarea 26 | id: what-did 27 | attributes: 28 | label: What happened? 29 | description: Tell us what happened in a very clear and simple way 30 | validations: 31 | required: true 32 | - type: textarea 33 | id: steps 34 | attributes: 35 | label: Steps to reproduce the problem 36 | description: Please provide us with precise step by step information on how to reproduce the bug 37 | value: | 38 | 1. Go to .... 39 | 2. Press .... 40 | 3. ... 41 | validations: 42 | required: true 43 | - type: textarea 44 | id: what-should 45 | attributes: 46 | label: What should have happened? 
47 | description: Tell what you think the normal behavior should be 48 | - type: textarea 49 | id: commits 50 | attributes: 51 | label: WebUI and Deforum extension Commit IDs 52 | description: Which commit of the webui/text2video extension are you running on? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or if you can't launch the webui at all, enter your cmd/terminal, CD into the main webui folder to get the webui commit id, and cd into the extensions/sd-webui-modelscope-text2video folder to get the text2video commit id, both using the command 'git rev-parse HEAD'.) 53 | value: | 54 | webui commit id - 55 | txt2vid commit id - 56 | validations: 57 | required: true 58 | - type: textarea 59 | id: what-torch 60 | attributes: 61 | label: Torch version 62 | description: Which Torch version your WebUI is working with 63 | validations: 64 | required: true 65 | - type: textarea 66 | id: what-gpu 67 | attributes: 68 | label: What GPU were you using for launching? 69 | description: The model and the amount of available VRAM 70 | validations: 71 | required: true 72 | - type: dropdown 73 | id: where 74 | validations: 75 | required: true 76 | attributes: 77 | label: On which platform are you launching the webui backend with the extension? 78 | multiple: true 79 | options: 80 | - Local PC setup (Windows) 81 | - Local PC setup (Linux) 82 | - Local PC setup (Mac) 83 | - Google Colab (The Last Ben's) 84 | - Google Colab (Other) 85 | - Cloud server (Linux) 86 | - Other (please specify in "additional information") 87 | - type: textarea 88 | id: deforumsettings 89 | attributes: 90 | label: Settings 91 | description: Send here a link to your used settings (since the repo is new, a screenshot is enough) 92 | validations: 93 | required: true 94 | - type: textarea 95 | id: logs 96 | attributes: 97 | label: Console logs 98 | description: Please provide **FULL cmd/terminal logs FROM THE MOMENT YOU STARTED UI to the end of it**, after your bug happened. If it's very long, provide a link to GitHub gists or similar service. 99 | render: Shell 100 | validations: 101 | required: true 102 | - type: textarea 103 | id: misc 104 | attributes: 105 | label: Additional information 106 | description: Please provide us with any relevant additional info or context. 
107 | -------------------------------------------------------------------------------- /scripts/videocrafter/sample_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | 5 | from videocrafter.lvdm.models.modules.lora import net_load_lora 6 | from videocrafter.lvdm.utils.common_utils import instantiate_from_config 7 | 8 | 9 | # ------------------------------------------------------------------------------------------ 10 | def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''): 11 | print(f"Loading model from {ckpt_path}") 12 | 13 | # load sd 14 | pl_sd = torch.load(ckpt_path, map_location="cpu") 15 | try: 16 | global_step = pl_sd["global_step"] 17 | epoch = pl_sd["epoch"] 18 | except: 19 | global_step = -1 20 | epoch = -1 21 | 22 | # load sd to model 23 | try: 24 | sd = pl_sd["state_dict"] 25 | except: 26 | sd = pl_sd 27 | model = instantiate_from_config(config.model) 28 | model.load_state_dict(sd, strict=True) 29 | 30 | if inject_lora: 31 | net_load_lora(model, lora_path, alpha=lora_scale) 32 | 33 | # move to device & eval 34 | if gpu_id is not None: 35 | model.to(f"cuda:{gpu_id}") 36 | else: 37 | model.cuda() 38 | model.eval() 39 | 40 | return model, global_step, epoch 41 | 42 | 43 | # ------------------------------------------------------------------------------------------ 44 | @torch.no_grad() 45 | def get_conditions(prompts, model, batch_size, cond_fps=None,): 46 | 47 | if isinstance(prompts, str) or isinstance(prompts, int): 48 | prompts = [prompts] 49 | if isinstance(prompts, list): 50 | if len(prompts) == 1: 51 | prompts = prompts * batch_size 52 | elif len(prompts) == batch_size: 53 | pass 54 | else: 55 | raise ValueError(f"invalid prompts length: {len(prompts)}") 56 | else: 57 | raise ValueError(f"invalid prompts: {prompts}") 58 | assert(len(prompts) == batch_size) 59 | 60 | # content condition: text / class label 61 | c = model.get_learned_conditioning(prompts) 62 | key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn' 63 | c = {key: [c]} 64 | 65 | # temporal condition: fps 66 | if getattr(model, 'cond_stage2_config', None) is not None: 67 | if model.cond_stage2_key == "temporal_context": 68 | assert(cond_fps is not None) 69 | batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)} 70 | fps_embd = model.cond_stage2_model(batch) 71 | c[model.cond_stage2_key] = fps_embd 72 | 73 | return c 74 | 75 | 76 | # ------------------------------------------------------------------------------------------ 77 | def make_model_input_shape(model, batch_size, T=None): 78 | image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size 79 | C = model.model.diffusion_model.in_channels 80 | if T is None: 81 | T = model.model.diffusion_model.temporal_length 82 | shape = [batch_size, C, T, *image_size] 83 | return shape 84 | 85 | 86 | # ------------------------------------------------------------------------------------------ 87 | def custom_to_pil(x): 88 | x = x.detach().cpu() 89 | x = torch.clamp(x, -1., 1.) 90 | x = (x + 1.) / 2. 
91 | x = x.permute(1, 2, 0).numpy() 92 | x = (255 * x).astype(np.uint8) 93 | x = Image.fromarray(x) 94 | if not x.mode == "RGB": 95 | x = x.convert("RGB") 96 | return x 97 | 98 | def torch_to_np(x): 99 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py 100 | sample = x.detach().cpu() 101 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 102 | if sample.dim() == 5: 103 | sample = sample.permute(0, 2, 3, 4, 1) 104 | else: 105 | sample = sample.permute(0, 2, 3, 1) 106 | sample = sample.contiguous() 107 | return sample 108 | 109 | def make_sample_dir(opt, global_step=None, epoch=None): 110 | if not getattr(opt, 'not_automatic_logdir', False): 111 | gs_str = f"globalstep{global_step:09}" if global_step is not None else "None" 112 | e_str = f"epoch{epoch:06}" if epoch is not None else "None" 113 | ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}") 114 | 115 | # subdir name 116 | if opt.prompt_file is not None: 117 | subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}" 118 | else: 119 | subdir = f"prompt_{opt.prompt[:10]}" 120 | subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps" 121 | subdir += f"_CfgScale{opt.scale}" 122 | if opt.cond_fps is not None: 123 | subdir += f"_fps{opt.cond_fps}" 124 | if opt.seed is not None: 125 | subdir += f"_seed{opt.seed}" 126 | 127 | return os.path.join(ckpt_dir, subdir) 128 | else: 129 | return opt.logdir 130 | -------------------------------------------------------------------------------- /.github/scripts/issue_checker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from github import Github 4 | 5 | # Get GitHub token from environment variables 6 | token = os.environ['GITHUB_TOKEN'] 7 | g = Github(token) 8 | 9 | # Get the current repository 10 | print(f"Repo is {os.environ['GITHUB_REPOSITORY']}") 11 | repo = g.get_repo(os.environ['GITHUB_REPOSITORY']) 12 | 13 | # Get the issue number from the event payload 14 | #issue_number = int(os.environ['ISSUE_NUMBER']) 15 | 16 | detox = True 17 | 18 | for issue in repo.get_issues(): 19 | print(f"Processing issue №{issue.number}") 20 | if issue.pull_request: 21 | continue 22 | 23 | if detox: 24 | continue 25 | 26 | # Get the issue object 27 | #issue = repo.get_issue(issue_number) 28 | 29 | # Define the keywords to search for in the issue 30 | keywords = ['Python', 'Commit hash', 'Launching Web UI with arguments', 'text2video'] 31 | 32 | # Check if ALL of the keywords are present in the issue 33 | def check_keywords(issue_body, keywords): 34 | for keyword in keywords: 35 | if not re.search(r'\b' + re.escape(keyword) + r'\b', issue_body, re.IGNORECASE): 36 | return False 37 | return True 38 | 39 | # Check if the issue title has at least a specified number of words 40 | def check_title_word_count(issue_title, min_word_count): 41 | words = issue_title.replace("/", " ").replace("\\\\", " ").split() 42 | return len(words) >= min_word_count 43 | 44 | # Check if the issue title is concise 45 | def check_title_concise(issue_title, max_word_count): 46 | words = issue_title.replace("/", " ").replace("\\\\", " ").split() 47 | return len(words) <= max_word_count 48 | 49 | # Check if the commit ID is in the correct hash form 50 | def check_commit_id_format(issue_body): 51 | match = re.search(r'webui commit id - ([a-fA-F0-9]+|\[[a-fA-F0-9]+\])', issue_body) 52 | if not match: 53 | return False 54 | webui_commit_id = match.group(1) 55 | 
webui_commit_id = webui_commit_id.replace("[", "").replace("]", "") 56 | if not (7 <= len(webui_commit_id) <= 40): 57 | return False 58 | match = re.search(r'txt2vid commit id - ([a-fA-F0-9]+|\[[a-fA-F0-9]+\])', issue_body) 59 | if match: 60 | return False 61 | t2v_commit_id = match.group(1) 62 | t2v_commit_id = t2v_commit_id.replace("[", "").replace("]", "") 63 | if not (7 <= len(t2v_commit_id) <= 40): 64 | return False 65 | return True 66 | 67 | # Only if a bug report 68 | if '[Bug]' in issue.title and not '[Feature Request]' in issue.title and not 'Repos for Training and Finetuning' in issue.title: 69 | print('The issue is eligible') 70 | # Initialize an empty list to store error messages 71 | error_messages = [] 72 | 73 | # Check for each condition and add the corresponding error message if the condition is not met 74 | if not check_keywords(issue.body, keywords): 75 | error_messages.append("Include **THE FULL LOG FROM THE START OF THE WEBUI** in the issue description.") 76 | 77 | if not check_title_word_count(issue.title, 3): 78 | error_messages.append("Make sure the issue title has at least 3 words.") 79 | 80 | if not check_title_concise(issue.title, 13): 81 | error_messages.append("The issue title should be concise and contain no more than 13 words.") 82 | 83 | # if not check_commit_id_format(issue.body): 84 | # error_messages.append("Provide a valid commit ID in the format 'commit id - [commit_hash]' **both** for the WebUI and the Extension.") 85 | 86 | # If there are any error messages, close the issue and send a comment with the error messages 87 | if error_messages: 88 | print('Invalid issue, closing') 89 | # Add the "not planned" label to the issue 90 | not_planned_label = repo.get_label("wrong format") 91 | issue.add_to_labels(not_planned_label) 92 | 93 | # Close the issue 94 | issue.edit(state='closed') 95 | 96 | # Generate the comment by concatenating the error messages 97 | comment = "This issue has been closed due to incorrect formatting. Please address the following mistakes and reopen the issue:\n\n" 98 | comment += "\n".join(f"- {error_message}" for error_message in error_messages) 99 | 100 | # Add the comment to the issue 101 | issue.create_comment(comment) 102 | elif repo.get_label("wrong format") in issue.labels: 103 | print('Issue is fine') 104 | issue.edit(state='open') 105 | issue.delete_labels() 106 | bug_label = repo.get_label("bug") 107 | issue.add_to_labels(bug_label) 108 | comment = "Thanks for addressing your formatting mistakes. The issue has been reopened now." 109 | issue.create_comment(comment) 110 | -------------------------------------------------------------------------------- /scripts/videocrafter/process_videocrafter.py: -------------------------------------------------------------------------------- 1 | from base64 import b64encode 2 | from tqdm import tqdm 3 | from omegaconf import OmegaConf 4 | import time, os 5 | from t2v_helpers.general_utils import get_t2v_version 6 | from t2v_helpers.args import get_outdir, process_args 7 | import modules.paths as ph 8 | import t2v_helpers.args as t2v_helpers_args 9 | from modules.shared import state 10 | 11 | # VideoCrafter support is heavy WIP and sketchy, needs help and more devs! 
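# Loads the VideoCrafter checkpoint and its config, runs DDIM sampling once per requested batch, and writes each result as vid.mp4 (with an optional soundtrack) into a timestamped output folder.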
12 | def process_videocrafter(args_dict): 13 | args, video_args = process_args(args_dict) 14 | print(f"\033[4;33m text2video extension for auto1111 webui\033[0m") 15 | print(f"Git commit: {get_t2v_version()}") 16 | init_timestring = time.strftime('%Y%m%d%H%M%S') 17 | outdir_current = os.path.join(get_outdir(), f"{init_timestring}") 18 | 19 | os.makedirs(outdir_current, exist_ok=True) 20 | 21 | # load & merge config 22 | 23 | config_path = os.path.join(ph.models_path, "models/VideoCrafter/model_config.yaml") 24 | if not os.path.exists(config_path): 25 | config_path = os.path.join(os.getcwd(), "extensions/sd-webui-modelscope-text2video/scripts/videocrafter/base_t2v/model_config.yaml") 26 | if not os.path.exists(config_path): 27 | config_path = os.path.join(os.getcwd(), "extensions/sd-webui-text2video/scripts/videocrafter/base_t2v/model_config.yaml") 28 | if not os.path.exists(config_path): 29 | raise FileNotFoundError(f'Could not find config file at {os.path.join(ph.models_path, "models/VideoCrafter/model_config.yaml")}, nor at {os.path.join(os.getcwd(), "extensions/sd-webui-modelscope-text2video/scripts/videocrafter/base_t2v/model_config.yaml")}, nor at {os.path.join(os.getcwd(), "extensions/sd-webui-text2video/scripts/videocrafter/base_t2v/model_config.yaml")}') 30 | 31 | config = OmegaConf.load(config_path) 32 | print("VideoCrafter config: \n", config) 33 | 34 | from videocrafter.lvdm.samplers.ddim import DDIMSampler 35 | from videocrafter.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np 36 | from videocrafter.sample_text2video import sample_text2video 37 | from videocrafter.lvdm.utils.saving_utils import npz_to_video_grid 38 | from t2v_helpers.video_audio_utils import add_soundtrack 39 | 40 | # get model & sampler 41 | model, _, _ = load_model(config, ph.models_path+'/VideoCrafter/model.ckpt', #TODO: support safetensors and stuff 42 | inject_lora=False, # TODO 43 | lora_scale=1, # TODO 44 | lora_path=ph.models_path+'/VideoCrafter/LoRA/LoRA.ckpt', #TODO: support LoRA and stuff 45 | ) 46 | ddim_sampler = DDIMSampler(model)# if opt.sample_type == "ddim" else None 47 | 48 | # if opt.inject_lora: 49 | # assert(opt.lora_trigger_word != '') 50 | # prompts = [p + opt.lora_trigger_word for p in prompts] 51 | 52 | # go 53 | start = time.time() 54 | 55 | pbar = tqdm(range(args.batch_count), leave=False) 56 | if args.batch_count == 1: 57 | pbar.disable=True 58 | 59 | state.job_count = args.batch_count 60 | 61 | for batch in pbar: 62 | state.job_no = batch + 1 63 | if state.skipped: 64 | state.skipped = False 65 | 66 | if state.interrupted: 67 | break 68 | 69 | state.job = f"Batch {batch+1} out of {args.batch_count}" 70 | ddim_sampler.noise_gen.manual_seed(args.seed + batch if args.seed != -1 else -1) 71 | # sample 72 | samples = sample_text2video(model, args.prompt, args.n_prompt, 1, 1,# todo:add batch size support 73 | sample_type='ddim', sampler=ddim_sampler, 74 | ddim_steps=args.steps, eta=args.eta, 75 | cfg_scale=args.cfg_scale, 76 | decode_frame_bs=1, 77 | ddp=False, show_denoising_progress=False, 78 | num_frames=args.frames 79 | ) 80 | # save 81 | if batch > 0: 82 | outdir_current = os.path.join(get_outdir(), f"{init_timestring}_{batch}") 83 | print(f'text2video finished, saving frames to {outdir_current}') 84 | 85 | npz_to_video_grid(samples[0:1,...], # TODO: is this the reason only 1 second is saved? 
86 | os.path.join(outdir_current, f"vid.mp4"), 87 | fps=video_args.fps) 88 | if video_args.add_soundtrack != 'None': 89 | add_soundtrack(video_args.ffmpeg_location, video_args.fps, os.path.join(outdir_current, f"vid.mp4"), 0, -1, None, video_args.add_soundtrack, video_args.soundtrack_path, video_args.ffmpeg_crf, video_args.ffmpeg_preset) 90 | print(f't2v complete, result saved at {outdir_current}') 91 | 92 | mp4 = open(outdir_current + os.path.sep + f"vid.mp4", 'rb').read() 93 | dataurl = "data:video/mp4;base64," + b64encode(mp4).decode() 94 | t2v_helpers_args.i1_store_t2v = f'

text2video extension for auto1111 — version 1.1b

' 95 | print("Finish sampling!") 96 | print(f"Run time = {(time.time() - start):.2f} seconds") 97 | pbar.close() 98 | return [dataurl] 99 | -------------------------------------------------------------------------------- /scripts/text2vid.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 3 | 4 | import sys, os 5 | 6 | basedirs = [os.getcwd()] 7 | if 'google.colab' in sys.modules: 8 | basedirs.append('/content/gdrive/MyDrive/sd/stable-diffusion-webui') # hardcode as TheLastBen's colab seems to be the primal source 9 | 10 | for basedir in basedirs: 11 | deforum_paths_to_ensure = [basedir + '/extensions/sd-webui-text2video/scripts', basedir + '/extensions/sd-webui-modelscope-text2video/scripts', basedir] 12 | 13 | for deforum_scripts_path_fix in deforum_paths_to_ensure: 14 | if not deforum_scripts_path_fix in sys.path: 15 | sys.path.extend([deforum_scripts_path_fix]) 16 | 17 | current_directory = os.path.dirname(os.path.abspath(__file__)) 18 | if current_directory not in sys.path: 19 | sys.path.append(current_directory) 20 | 21 | import gradio as gr 22 | from modules import script_callbacks, shared 23 | from modules.shared import cmd_opts, opts 24 | from t2v_helpers.render import run 25 | import t2v_helpers.args as args 26 | from t2v_helpers.args import setup_text2video_settings_dictionary 27 | from modules.call_queue import wrap_gradio_gpu_call 28 | from stable_lora.scripts.lora_webui import StableLoraScriptInstance 29 | StableLoraScript = StableLoraScriptInstance 30 | 31 | def process(*argss): 32 | # weird PATH stuff 33 | for basedir in basedirs: 34 | sys.path.extend([ 35 | basedir + '/scripts', 36 | basedir + '/extensions/sd-webui-text2video/scripts', 37 | basedir + '/extensions/sd-webui-modelscope-text2video/scripts', 38 | ]) 39 | if current_directory not in sys.path: 40 | sys.path.append(current_directory) 41 | 42 | run(*argss) 43 | return [args.i1_store_t2v] 44 | 45 | def on_ui_tabs(): 46 | with gr.Blocks(analytics_enabled=False) as deforum_interface: 47 | components = {} 48 | with gr.Row(elem_id='t2v-core').style(equal_height=False, variant='compact'): 49 | with gr.Column(scale=1, variant='panel'): 50 | components = setup_text2video_settings_dictionary() 51 | stable_lora_ui = StableLoraScript.ui() 52 | with gr.Column(scale=1, variant='compact'): 53 | with gr.Row(elem_id=f"text2vid_generate_box", variant='compact', elem_classes="generate-box"): 54 | interrupt = gr.Button('Interrupt', elem_id=f"text2vid_interrupt", elem_classes="generate-box-interrupt") 55 | skip = gr.Button('Skip', elem_id=f"text2vid_skip", elem_classes="generate-box-skip") 56 | run_button = gr.Button('Generate', elem_id=f"text2vid_generate", variant='primary') 57 | 58 | skip.click( 59 | fn=lambda: shared.state.skip(), 60 | inputs=[], 61 | outputs=[], 62 | ) 63 | 64 | interrupt.click( 65 | fn=lambda: shared.state.interrupt(), 66 | inputs=[], 67 | outputs=[], 68 | ) 69 | with gr.Row(variant='compact'): 70 | i1 = gr.HTML(args.i1_store_t2v, elem_id='deforum_header') 71 | with gr.Row(visible=False): 72 | dummy_component1 = gr.Label("") 73 | dummy_component2 = gr.Label("") 74 | with gr.Row(variant='compact', elem_id='text2vid_results_panel'): 75 | ... 
76 | # gr.Label("", visible=False) 77 | with gr.Row(variant='compact'): 78 | i1 = gr.HTML(args.i1_store_t2v, elem_id='deforum_header') 79 | 80 | run_button.click( 81 | # , extra_outputs=[None, '', '']), 82 | fn=wrap_gradio_gpu_call(process), 83 | _js="submit_txt2vid", 84 | inputs=[dummy_component1, dummy_component2] + [components[name] for name in args.get_component_names()] + stable_lora_ui, 85 | outputs=[ 86 | i1 87 | ], 88 | ) 89 | return [(deforum_interface, "txt2video", "t2v_interface")] 90 | 91 | def on_ui_settings(): 92 | section = ('modelscope_deforum', "Text2Video") 93 | shared.opts.add_option("modelscope_deforum_keep_model_in_vram", shared.OptionInfo( 94 | 'None', "Keep model in VRAM between runs", gr.Radio, 95 | {"interactive": True, "choices": ['None', 'Main Model Only', 'All'], "visible": True if not (cmd_opts.lowvram or cmd_opts.medvram) else False}, section=section)) 96 | shared.opts.add_option("modelscope_deforum_vae_settings", shared.OptionInfo( 97 | "GPU (half precision)", "VAE Mode:", gr.Radio, {"interactive": True, "choices": ['GPU (half precision)', 'GPU', 'CPU (Low VRAM)']}, section=section)) 98 | shared.opts.add_option("modelscope_deforum_show_n_videos", shared.OptionInfo( 99 | -1, "How many videos to show on the right panel on completion (-1 = show all)", gr.Number, {"interactive": True, "visible": True}, section=section)) 100 | shared.opts.add_option("modelscope_save_info_to_file", shared.OptionInfo( 101 | False, "Save generation params to a text file near the video", gr.Checkbox, {'interactive':True, 'visible':True}, section=section)) 102 | shared.opts.add_option("modelscope_save_metadata", shared.OptionInfo( 103 | True, "Save generation params as video metadata", gr.Checkbox, {'interactive':True, 'visible':True}, section=section)) 104 | 105 | script_callbacks.on_ui_tabs(on_ui_tabs) 106 | script_callbacks.on_ui_settings(on_ui_settings) 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # text2video Extension for AUTOMATIC1111's StableDiffusion WebUI 2 | 3 | **~~Warning: as of 2023-11-21 this extension is not maintained. If you'd like to continue devving/remaking it, please contact me on Discord @kabachuha (you can also find me on [camenduru's server's text2video channel](https://discord.gg/TYk6rfT9)) and we'll figure it out~~** 4 | 5 | **~~Maintained starting on 2023-11-21 by [Deforum-art](https://github.com/deforum-art)~~** 6 | 7 | **Maintained by me again** 8 | 9 | Auto1111 extension implementing various text2video models, such as ModelScope and VideoCrafter, using only Auto1111 webui dependencies and downloadable models (so no logins required anywhere). 10 | 11 | ## Requirements 12 | 13 | ### ModelScope 14 | 15 | 6 GB of VRAM should be enough to run on GPU with the low-VRAM VAE enabled at 256x256 (and we are already getting reports of people launching 192x192 videos [with 4 GB of VRAM](https://github.com/deforum-art/sd-webui-modelscope-text2video/discussions/27)). A 24-frame 256x256 video definitely fits into the 12 GB of an NVIDIA GeForce RTX 2080 Ti, and if you have a videocard that supports the Torch2 attention optimization, you can fit a whopping 125-frame (8-second) video into the same 12 GB of VRAM! 250 frames (16 seconds) under the same conditions take 20 GB.
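For reference, a low-VRAM run like the one described above (256x256, 24 frames) can also be driven headlessly through the extension's web API (`scripts/api_t2v.py`, included later in this repo). A minimal sketch, assuming a locally running WebUI on the default `127.0.0.1:7860` address; the host, port and output filenames here are placeholders rather than something the extension prescribes:

```python
# Minimal sketch: request a 256x256, 24-frame generation from the /t2v/run endpoint
# registered by scripts/api_t2v.py. The endpoint takes its settings as query parameters
# and answers with {"mp4s": [...]}, a list of base64 "data:video/mp4;..." data URLs.
import base64
import requests

params = {
    "prompt": "best quality, anime girl dancing",
    "width": 256,
    "height": 256,
    "frames": 24,
    "seed": -1,  # -1 = random seed
}
resp = requests.post("http://127.0.0.1:7860/t2v/run", params=params, timeout=3600)
resp.raise_for_status()

# Decode each returned data URL back into an .mp4 file on disk.
for i, dataurl in enumerate(resp.json()["mp4s"]):
    with open(f"t2v_result_{i}.mp4", "wb") as f:
        f.write(base64.b64decode(dataurl.split(",", 1)[1]))
```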
16 | 17 | Prompt: `best quality, anime girl dancing` 18 | 19 | https://user-images.githubusercontent.com/14872007/232229730-82df36cc-ac8b-46b3-949d-0e1dfc10a975.mp4 20 | 21 | 22 | We will appreciate *any* help with this extension, *especially* pull-requests. 23 | 24 | ### LoRA Support 25 | 26 | Currently, there is support for trained LoRAs using this finetune repository. Please follow instructions there on how to train them. 27 | https://github.com/ExponentialML/Text-To-Video-Finetuning#updates 28 | 29 | After training, simply place them into your default LoRA directory defined by your webui installation. 30 | 31 | ### VideoCrafter (WIP, needs more devs to maintain properly as well) 32 | 33 | VideoCrafter runs with around 9.2 GBs of VRAM with the settings set on Default. 34 | 35 | ## Major changes between versions 36 | 37 | Update 2023-03-27: VAE settings and "Keep model in VRAM" moved to general webui setting under 'ModelScopeTxt2Vid' section. 38 | 39 | Update 2023-03-26: prompt weights **implemented**! (ModelScope only yet, as of 2023-04-05) 40 | 41 | Update 2023-04-05: added VideoCrafter support, renamed the extension to plainly 'sd-webui-text2video' 42 | 43 | Update 2023-04-13: in-framing/in-painting support: allows to 'animate' an existing pic or even seamlessly loop the videos! 44 | 45 | Update 2023-04-15: **MEGA-UPDATE**: Torch2/xformers optimizations, possible to make 125 frames long video on 12 gbs of VRAM. CPU offloading doesn't happen now if keep_pipe_in_vram is checked. 46 | 47 | Update 2023-04-16: WebAPI is available! 48 | 49 | Update 2023-07-02: Alternate samplers, model hotswitch. 50 | 51 | ## Test examples: 52 | 53 | ### ModelScope 54 | 55 | Prompt: `cinematic explosion by greg rutkowski` 56 | 57 | https://user-images.githubusercontent.com/14872007/226345611-a1f0601f-db32-41bd-b983-80d363eca4d5.mp4 58 | 59 | Prompt: `really attractive anime girl skating, by makoto shinkai, cinematic lighting` 60 | 61 | https://user-images.githubusercontent.com/14872007/226468406-ce43fa0c-35f2-4625-a892-9fb3411d96bb.mp4 62 | 63 | **'Continuing' an existing image** 64 | 65 | Prompt: `best quality, astronaut dog` 66 | 67 | https://user-images.githubusercontent.com/14872007/232073361-bdb87a47-85ec-44d8-9dc4-40dab0bd0555.mp4 68 | 69 | Prompt: `explosion` 70 | 71 | https://user-images.githubusercontent.com/14872007/232073687-b7e78b06-182b-4ce6-b565-d6738c4890d1.mp4 72 | 73 | **In-painting and looping back the videos** 74 | 75 | Prompt: `nuclear explosion` 76 | 77 | https://user-images.githubusercontent.com/14872007/232073842-84860a3e-fa82-43a6-a411-5cfc509b5355.mp4 78 | 79 | Prompt: `best quality, lots of cheese` 80 | 81 | https://user-images.githubusercontent.com/14872007/232073876-16895cae-0f26-41bc-a575-0c811219cf88.mp4 82 | 83 | ### VideoCrafter 84 | 85 | Prompt: `anime 1girl reimu touhou` 86 | 87 | https://user-images.githubusercontent.com/14872007/230231253-2fd9b7af-3f05-41c8-8c92-51042b269116.mp4 88 | 89 | ## Where to get the weights 90 | 91 | ### ModelScope 92 | 93 | Download the following files from the [original HuggingFace repository](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis/tree/main). 
Alternatively, [download half-precision fp16 pruned weights (they are smaller and use less vram on loading)](https://huggingface.co/kabachuha/modelscope-damo-text2video-pruned-weights/tree/main): 94 | - VQGAN_autoencoder.pth 95 | - configuration.json 96 | - open_clip_pytorch_model.bin 97 | - text2video_pytorch_model.pth 98 | 99 | And put them in `stable-diffusion-webui/models/ModelScope/t2v`. Create those 2 folders if they are missing. 100 | 101 | ### VideoCrafter 102 | 103 | Download pretrained T2V models either via [this link](https://drive.google.com/file/d/13ZZTXyAKM3x0tObRQOQWdtnrI2ARWYf_/view?usp=share_link) or download [the pruned half precision weights](https://huggingface.co/kabachuha/videocrafter-pruned-weights/tree/main), and put the `model.ckpt` in `models/VideoCrafter/model.ckpt`. 104 | 105 | ## Fine-tunes and how to use them 106 | 107 | Thanks to https://github.com/ExponentialML/Text-To-Video-Finetuning you can fine-tune your models! 108 | 109 | To utilize a fine-tuned model here, use [this script](https://github.com/ExponentialML/Text-To-Video-Finetuning/pull/52) which will convert the Diffusers-formatted model that repo outputs into the original weights format. 110 | 111 | ### Prominent Fine-tunes 112 | 113 | **ZeroScope v2** 114 | 115 | Trained by @cerspense on high quality YouTube videos. Download the files from the folder named `zs2_XL` at [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL/tree/main/zs2_XL) and then add the missing `VQGAN_autoencoder.pth` and `configuration.json` from [any other ModelScope model](https://huggingface.co/kabachuha/modelscope-damo-text2video-pruned-weights/tree/main). 116 | 117 | https://github.com/kabachuha/sd-webui-text2video/assets/14872007/6fa39221-3608-415e-b8ce-04a2bad11d30 118 | 119 | **Potat1** 120 | 121 | [Potat1](https://huggingface.co/camenduru/potat1) is a ModelScope-based model trained by @camenduru on 2197 clips with the resolution of 1024x576 which makes it the first open source hi-res text2video model. 122 | 123 | https://github.com/kabachuha/sd-webui-text2video/assets/14872007/ff01c6cb-0000-40a2-ac7e-ec3edc5f9713 124 | 125 | To download the plug-and-play weights for the extension use this link https://huggingface.co/kabachuha/potat1-with-text-encoder-original-format. 126 | 127 | **Animov-0.1** 128 | 129 | [Animov-0.1 by strangeman3107](https://huggingface.co/datasets/strangeman3107/animov-0.1). The converted weights for this model reside [here](https://huggingface.co/kabachuha/animov-0.1-modelscope-original-format). 
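Whichever of the weight sets above you choose (the base ModelScope files or a fine-tune published in the same original-weights format), the files can be fetched straight from HuggingFace. A minimal sketch, assuming the `huggingface_hub` package is available in your Python environment; the paths simply mirror the folder layout described in this section:

```python
# Minimal sketch: download the four ModelScope files listed above into the folder the
# extension reads them from. huggingface_hub being installed is an assumption here;
# swap repo_id for a fine-tune distributed in the same original weights format.
import os
from huggingface_hub import hf_hub_download

repo_id = "damo-vilab/modelscope-damo-text-to-video-synthesis"
target_dir = "stable-diffusion-webui/models/ModelScope/t2v"
os.makedirs(target_dir, exist_ok=True)

for filename in (
    "VQGAN_autoencoder.pth",
    "configuration.json",
    "open_clip_pytorch_model.bin",
    "text2video_pytorch_model.pth",
):
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
```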
130 | 131 | https://user-images.githubusercontent.com/14872007/232611542-600cec38-d944-4530-bc5c-3595a115c2be.mp4 132 | 133 | ## Screenshots 134 | 135 | txt2vid with img2vid 136 | 137 | ![Screenshot 2023-04-15 at 17-53-36 Stable Diffusion](https://user-images.githubusercontent.com/14872007/232232319-c3a443ee-1a8a-4504-a114-d9da2ae916c2.png) 138 | 139 | vid2vid 140 | 141 | ![Screenshot 2023-04-15 at 17-33-32 Stable Diffusion](https://user-images.githubusercontent.com/14872007/232232338-a2aa4b78-35d0-4c9b-850b-15edc90c0c9f.png) 142 | 143 | ## Dev resources 144 | 145 | ### ModelScope 146 | 147 | HuggingFace space: 148 | 149 | https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis 150 | 151 | The model PyTorch implementation from ModelScope: 152 | 153 | https://github.com/modelscope/modelscope/tree/master/modelscope/models/multi_modal/video_synthesis 154 | 155 | Google Colab from the devs: 156 | 157 | https://colab.research.google.com/drive/1uW1ZqswkQ9Z9bp5Nbo5z59cAn7I0hE6R?usp=sharing 158 | 159 | ### VideoCrafter 160 | 161 | Github: 162 | 163 | https://github.com/VideoCrafter/VideoCrafter 164 | -------------------------------------------------------------------------------- /scripts/samplers/samplers_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from samplers.ddim.sampler import DDIMSampler 3 | from samplers.ddim.gaussian_sampler import GaussianDiffusion 4 | from samplers.uni_pc.sampler import UniPCSampler 5 | from tqdm import tqdm 6 | from modules.shared import state 7 | from modules.sd_samplers_common import InterruptedException 8 | 9 | def get_height_width(h, w, divisor): 10 | return h // divisor, w // divisor 11 | 12 | def get_tensor_shape(batch_size, channels, frames, h, w, latents=None): 13 | if latents is None: 14 | return (batch_size, channels, frames, h, w) 15 | return latents.shape 16 | 17 | def inpaint_masking(xt, step, steps, mask, add_noise_cb, noise_cb_args): 18 | if mask is not None and step < steps - 1: 19 | 20 | #convert mask to 0,1 valued based on step 21 | v = (steps - step - 1) / steps 22 | binary_mask = torch.where(mask <= v, torch.zeros_like(mask), torch.ones_like(mask)) 23 | 24 | noise_to_add = add_noise_cb(**noise_cb_args) 25 | to_inpaint = noise_to_add 26 | xt = to_inpaint * (1 - binary_mask) + xt * binary_mask 27 | 28 | class SamplerStepCallback(object): 29 | def __init__(self, sampler_name: str, total_steps: int): 30 | self.sampler_name = sampler_name 31 | self.total_steps = total_steps 32 | self.current_step = 0 33 | self.progress_bar = tqdm(desc=self.progress_msg(sampler_name, total_steps), total=total_steps) 34 | 35 | def progress_msg(self, name, total_steps=None): 36 | total_steps = total_steps if total_steps is not None else self.total_steps 37 | state.sampling_steps = total_steps 38 | return f"Sampling using {name} for {total_steps} steps." 
39 | 40 | def set_webui_step(self, step): 41 | state.sampling_step = step 42 | 43 | def is_finished(self, step): 44 | if step >= self.total_steps: 45 | self.progress_bar.close() 46 | self.current_step = 0 47 | 48 | def interrupt(self): 49 | return state.interrupted or state.skipped 50 | 51 | def cancel(self): 52 | raise InterruptedException 53 | 54 | def update(self, step): 55 | self.set_webui_step(step) 56 | 57 | if self.interrupt(): 58 | self.cancel() 59 | 60 | self.progress_bar.set_description(self.progress_msg(self.sampler_name)) 61 | self.progress_bar.update(1) 62 | 63 | self.is_finished(step) 64 | 65 | def __call__(self,*args, **kwargs): 66 | self.current_step += 1 67 | step = self.current_step 68 | 69 | self.update(step) 70 | 71 | class SamplerBase(object): 72 | def __init__(self, name: str, Sampler, frame_inpaint_support=False): 73 | self.name = name 74 | self.Sampler = Sampler 75 | self.frame_inpaint_support = frame_inpaint_support 76 | 77 | def register_buffers_to_model(self, sd_model, betas, device): 78 | self.alphas = 1. - betas 79 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 80 | 81 | setattr(sd_model, 'device', device) 82 | setattr(sd_model, 'betas', betas) 83 | setattr(sd_model, 'alphas_cumprod', self.alphas_cumprod) 84 | 85 | def init_sampler(self, sd_model, betas, device, **kwargs): 86 | self.register_buffers_to_model(sd_model, betas, device) 87 | return self.Sampler(sd_model, betas=betas, **kwargs) 88 | 89 | available_samplers = [ 90 | SamplerBase("DDIM_Gaussian", GaussianDiffusion, True), 91 | SamplerBase("DDIM", DDIMSampler), 92 | SamplerBase("UniPC", UniPCSampler), 93 | ] 94 | 95 | class Txt2VideoSampler(object): 96 | def __init__(self, sd_model, device, betas=None, sampler_name="UniPC"): 97 | self.sd_model = sd_model 98 | self.device = device 99 | self.noise_gen = torch.Generator(device='cpu') 100 | self.sampler_name = sampler_name 101 | self.betas = betas 102 | self.sampler = self.get_sampler(sampler_name, betas=self.betas) 103 | 104 | def get_noise(self, num_sample, channels, frames, height, width, latents=None, seed=1): 105 | if latents is not None: 106 | latents.to(self.device) 107 | 108 | print(f"Using input latents. 
Shape: {latents.shape}, Mean: {torch.mean(latents)}, Std: {torch.std(latents)}") 109 | else: 110 | print("Sampling random noise.") 111 | 112 | num_sample = 1 113 | max_frames = frames 114 | 115 | latent_h, latent_w = get_height_width(height, width, 8) 116 | shape = get_tensor_shape(num_sample, channels, max_frames, latent_h, latent_w, latents) 117 | 118 | self.noise_gen.manual_seed(seed) 119 | noise = torch.randn(shape, generator=self.noise_gen).to(self.device) 120 | 121 | return latents, noise, shape 122 | 123 | def encode_latent(self, latent, noise, strength, steps): 124 | encoded_latent = None 125 | denoise_steps = None 126 | 127 | if hasattr(self.sampler, 'unipc_encode'): 128 | encoded_latent = self.sampler.unipc_encode(latent, self.device, strength, steps, noise=noise) 129 | 130 | if hasattr(self.sampler, 'stochastic_encode'): 131 | denoise_steps = int(strength * steps) 132 | timestep = torch.tensor([denoise_steps] * int(latent.shape[0])).to(self.device) 133 | self.sampler.make_schedule(steps) 134 | encoded_latent = self.sampler.stochastic_encode(latent, timestep, noise=noise).to(dtype=latent.dtype) 135 | self.sampler.sample = self.sampler.decode 136 | 137 | if hasattr(self.sampler, 'add_noise'): 138 | denoise_steps = int(strength * steps) 139 | timestep = self.sampler.get_time_steps(denoise_steps, latent.shape[0]) 140 | encoded_latent = self.sampler.add_noise(latent, noise, timestep[0].cpu()) 141 | 142 | if encoded_latent is None: 143 | raise ValueError("Could not find the appropriate function to encode the input latents") 144 | 145 | return encoded_latent, denoise_steps 146 | 147 | def get_sampler(self, sampler_name: str, betas=None, return_sampler=True): 148 | betas = betas if betas is not None else self.betas 149 | 150 | for Sampler in available_samplers: 151 | if sampler_name == Sampler.name: 152 | sampler = Sampler.init_sampler(self.sd_model, betas=betas, device=self.device) 153 | 154 | if Sampler.frame_inpaint_support: 155 | setattr(sampler, 'inpaint_masking', inpaint_masking) 156 | 157 | if return_sampler: 158 | return sampler 159 | else: 160 | self.sampler = sampler 161 | return 162 | 163 | raise ValueError(f"Sampler {sampler_name} does not exist.") 164 | 165 | def sample_loop( 166 | self, 167 | steps, 168 | strength, 169 | conditioning, 170 | unconditional_conditioning, 171 | batch_size, 172 | latents=None, 173 | shape=None, 174 | noise=None, 175 | is_vid2vid=False, 176 | guidance_scale=1, 177 | eta=0, 178 | mask=None, 179 | sampler_name="DDIM" 180 | ): 181 | denoise_steps = None 182 | # Assume that we are adding noise to existing latents (Image, Video, etc.)
183 | if latents is not None and is_vid2vid: 184 | latents, denoise_steps = self.encode_latent(latents, noise, strength, steps) 185 | 186 | # Create a callback that handles counting each step 187 | sampler_callback = SamplerStepCallback(sampler_name, steps) 188 | 189 | # Predict the noise sample 190 | x0 = self.sampler.sample( 191 | S=steps, 192 | conditioning=conditioning, 193 | strength=strength, 194 | unconditional_conditioning=unconditional_conditioning, 195 | batch_size=batch_size, 196 | x_T=latents if latents is not None else noise, 197 | x_latent=latents, 198 | t_start=denoise_steps, 199 | unconditional_guidance_scale=guidance_scale, 200 | shape=shape, 201 | callback=sampler_callback, 202 | cond=conditioning, 203 | eta=eta, 204 | mask=mask 205 | ) 206 | 207 | return x0 -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/data/webvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import bisect 4 | 5 | import pandas as pd 6 | 7 | import omegaconf 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | from decord import VideoReader, cpu 12 | import torchvision.transforms._transforms_video as transforms_video 13 | 14 | class WebVid(Dataset): 15 | """ 16 | WebVid Dataset. 17 | Assumes webvid data is structured as follows. 18 | Webvid/ 19 | videos/ 20 | 000001_000050/ ($page_dir) 21 | 1.mp4 (videoid.mp4) 22 | ... 23 | 5000.mp4 24 | ... 25 | """ 26 | def __init__(self, 27 | meta_path, 28 | data_dir, 29 | subsample=None, 30 | video_length=16, 31 | resolution=[256, 512], 32 | frame_stride=1, 33 | spatial_transform=None, 34 | crop_resolution=None, 35 | fps_max=None, 36 | load_raw_resolution=False, 37 | fps_schedule=None, 38 | fs_probs=None, 39 | bs_per_gpu=None, 40 | trigger_word='', 41 | dataname='', 42 | ): 43 | self.meta_path = meta_path 44 | self.data_dir = data_dir 45 | self.subsample = subsample 46 | self.video_length = video_length 47 | self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution 48 | self.frame_stride = frame_stride 49 | self.fps_max = fps_max 50 | self.load_raw_resolution = load_raw_resolution 51 | self.fs_probs = fs_probs 52 | self.trigger_word = trigger_word 53 | self.dataname = dataname 54 | 55 | self._load_metadata() 56 | if spatial_transform is not None: 57 | if spatial_transform == "random_crop": 58 | self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution) 59 | elif spatial_transform == "resize_center_crop": 60 | assert(self.resolution[0] == self.resolution[1]) 61 | self.spatial_transform = transforms.Compose([ 62 | transforms.Resize(resolution), 63 | transforms_video.CenterCropVideo(resolution), 64 | ]) 65 | else: 66 | raise NotImplementedError 67 | else: 68 | self.spatial_transform = None 69 | 70 | self.fps_schedule = fps_schedule 71 | self.bs_per_gpu = bs_per_gpu 72 | if self.fps_schedule is not None: 73 | assert(self.bs_per_gpu is not None) 74 | self.counter = 0 75 | self.stage_idx = 0 76 | 77 | def _load_metadata(self): 78 | metadata = pd.read_csv(self.meta_path) 79 | if self.subsample is not None: 80 | metadata = metadata.sample(self.subsample, random_state=0) 81 | metadata['caption'] = metadata['name'] 82 | del metadata['name'] 83 | self.metadata = metadata 84 | self.metadata.dropna(inplace=True) 85 | # self.metadata['caption'] = self.metadata['caption'].str[:350] 86 | 87 | def _get_video_path(self, sample): 88 | if self.dataname == 
"loradata": 89 | rel_video_fp = str(sample['videoid']) + '.mp4' 90 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 91 | else: 92 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 93 | full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) 94 | return full_video_fp, rel_video_fp 95 | 96 | def get_fs_based_on_schedule(self, frame_strides, schedule): 97 | assert(len(frame_strides) == len(schedule) + 1) # nstage=len_fps_schedule + 1 98 | global_step = self.counter // self.bs_per_gpu # TODO: support resume. 99 | stage_idx = bisect.bisect(schedule, global_step) 100 | frame_stride = frame_strides[stage_idx] 101 | # log stage change 102 | if stage_idx != self.stage_idx: 103 | print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}') 104 | self.stage_idx = stage_idx 105 | return frame_stride 106 | 107 | def get_fs_based_on_probs(self, frame_strides, probs): 108 | assert(len(frame_strides) == len(probs)) 109 | return random.choices(frame_strides, weights=probs)[0] 110 | 111 | def get_fs_randomly(self, frame_strides): 112 | return random.choice(frame_strides) 113 | 114 | def __getitem__(self, index): 115 | 116 | if isinstance(self.frame_stride, list) or isinstance(self.frame_stride, omegaconf.listconfig.ListConfig): 117 | if self.fps_schedule is not None: 118 | frame_stride = self.get_fs_based_on_schedule(self.frame_stride, self.fps_schedule) 119 | elif self.fs_probs is not None: 120 | frame_stride = self.get_fs_based_on_probs(self.frame_stride, self.fs_probs) 121 | else: 122 | frame_stride = self.get_fs_randomly(self.frame_stride) 123 | else: 124 | frame_stride = self.frame_stride 125 | assert(isinstance(frame_stride, int)), type(frame_stride) 126 | 127 | while True: 128 | index = index % len(self.metadata) 129 | sample = self.metadata.iloc[index] 130 | video_path, rel_fp = self._get_video_path(sample) 131 | caption = sample['caption']+self.trigger_word 132 | 133 | # make reader 134 | try: 135 | if self.load_raw_resolution: 136 | video_reader = VideoReader(video_path, ctx=cpu(0)) 137 | else: 138 | video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0]) 139 | if len(video_reader) < self.video_length: 140 | print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})") 141 | index += 1 142 | continue 143 | else: 144 | pass 145 | except: 146 | index += 1 147 | print(f"Load video failed! path = {video_path}") 148 | continue 149 | 150 | # sample strided frames 151 | all_frames = list(range(0, len(video_reader), frame_stride)) 152 | if len(all_frames) < self.video_length: # recal a max fs 153 | frame_stride = len(video_reader) // self.video_length 154 | assert(frame_stride != 0) 155 | all_frames = list(range(0, len(video_reader), frame_stride)) 156 | 157 | # select a random clip 158 | rand_idx = random.randint(0, len(all_frames) - self.video_length) 159 | frame_indices = all_frames[rand_idx:rand_idx+self.video_length] 160 | try: 161 | frames = video_reader.get_batch(frame_indices) 162 | break 163 | except: 164 | print(f"Get frames failed! 
path = {video_path}") 165 | index += 1 166 | continue 167 | 168 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}' 169 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] 170 | if self.spatial_transform is not None: 171 | frames = self.spatial_transform(frames) 172 | if self.resolution is not None: 173 | assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}' 174 | frames = (frames / 255 - 0.5) * 2 175 | 176 | fps_ori = video_reader.get_avg_fps() 177 | fps_clip = fps_ori // frame_stride 178 | if self.fps_max is not None and fps_clip > self.fps_max: 179 | fps_clip = self.fps_max 180 | 181 | data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride} 182 | 183 | if self.fps_schedule is not None: 184 | self.counter += 1 185 | return data 186 | 187 | def __len__(self): 188 | return len(self.metadata) 189 | -------------------------------------------------------------------------------- /scripts/api_t2v.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 3 | 4 | import sys, os 5 | basedirs = [os.getcwd()] 6 | if 'google.colab' in sys.modules: 7 | basedirs.append('/content/gdrive/MyDrive/sd/stable-diffusion-webui') #hardcode as TheLastBen's colab seems to be the primal source 8 | 9 | for basedir in basedirs: 10 | deforum_paths_to_ensure = [basedir + '/extensions/sd-webui-text2video/scripts', basedir + '/extensions/sd-webui-modelscope-text2video/scripts', basedir] 11 | 12 | for deforum_scripts_path_fix in deforum_paths_to_ensure: 13 | if not deforum_scripts_path_fix in sys.path: 14 | sys.path.extend([deforum_scripts_path_fix]) 15 | 16 | current_directory = os.path.dirname(os.path.abspath(__file__)) 17 | if current_directory not in sys.path: 18 | sys.path.append(current_directory) 19 | 20 | import hashlib 21 | import io 22 | import json 23 | import logging 24 | import os 25 | import sys 26 | import tempfile 27 | from PIL import Image 28 | import urllib 29 | from typing import Union 30 | import traceback 31 | from types import SimpleNamespace 32 | 33 | from fastapi import FastAPI, Query, Request, UploadFile 34 | from fastapi.encoders import jsonable_encoder 35 | from fastapi.exceptions import RequestValidationError 36 | from fastapi.responses import JSONResponse 37 | from t2v_helpers.video_audio_utils import find_ffmpeg_binary 38 | from t2v_helpers.general_utils import get_t2v_version 39 | from t2v_helpers.args import T2VArgs_sanity_check, T2VArgs, T2VOutputArgs 40 | from t2v_helpers.render import run 41 | import uuid 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | current_directory = os.path.dirname(os.path.abspath(__file__)) 46 | if current_directory not in sys.path: 47 | sys.path.append(current_directory) 48 | 49 | def t2v_api(_, app: FastAPI): 50 | logger.debug(f"text2video extension for auto1111 webui") 51 | logger.debug(f"Git commit: {get_t2v_version()}") 52 | logger.debug("Loading text2video API endpoints") 53 | 54 | @app.exception_handler(RequestValidationError) 55 | async def validation_exception_handler(request: Request, exc: RequestValidationError): 56 | return JSONResponse( 57 | status_code=422, 58 | content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), 59 | ) 60 | 61 | @app.get("/t2v/api_version") 62 | 
async def t2v_api_version(): 63 | return JSONResponse(content={"version": '1.0'}) 64 | 65 | @app.get("/t2v/version") 66 | async def t2v_version(): 67 | return JSONResponse(content={"version": get_t2v_version()}) 68 | 69 | @app.post("/t2v/run") 70 | async def t2v_run(prompt: str, n_prompt: Union[str, None] = None, model: Union[str, None] = None, sampler: Union[str, None] = None, steps: Union[int, None] = None, frames: Union[int, None] = None, seed: Union[int, None] = None, \ 71 | cfg_scale: Union[int, None] = None, width: Union[int, None] = None, height: Union[int, None] = None, eta: Union[float, None] = None, batch_count: Union[int, None] = None, \ 72 | do_vid2vid:bool = False, vid2vid_input: Union[UploadFile, None] = None,strength: Union[float, None] = None,vid2vid_startFrame: Union[int, None] = None, \ 73 | inpainting_image: Union[UploadFile, None] = None, inpainting_frames: Union[int, None] = None, inpainting_weights: Union[str, None] = None, \ 74 | fps: Union[int, None] = None, add_soundtrack: Union[str, None] = None, soundtrack_path: Union[str, None] = None, ): 75 | for basedir in basedirs: 76 | sys.path.extend([ 77 | basedir + '/scripts', 78 | basedir + '/extensions/sd-webui-text2video/scripts', 79 | basedir + '/extensions/sd-webui-modelscope-text2video/scripts', 80 | ]) 81 | 82 | locals_args_dict = locals() 83 | args_dict = T2VArgs() 84 | video_args_dict = T2VOutputArgs() 85 | for k, v in locals_args_dict.items(): 86 | if v is not None: 87 | if k in args_dict: 88 | args_dict[k] = locals_args_dict[k] 89 | elif k in video_args_dict: 90 | video_args_dict[k] = locals_args_dict[k] 91 | 92 | """ 93 | Run t2v over api 94 | @return: 95 | """ 96 | d = SimpleNamespace(**args_dict) 97 | dv = SimpleNamespace(**video_args_dict) 98 | 99 | tmp_inpainting = None 100 | tmp_inpainting_name = f'outputs/t2v_temp/{str(uuid.uuid4())}.png' 101 | tmp_vid2vid = None 102 | temp_vid2vid_name = f'outputs/t2v_temp/{str(uuid.uuid4())}.mp4' 103 | os.makedirs('outputs/t2v_temp', exist_ok=True) 104 | 105 | # Wrap the process call in a try-except block to handle potential errors 106 | try: 107 | T2VArgs_sanity_check(d) 108 | 109 | if d.inpainting_frames > 0 and inpainting_image: 110 | img_content = await inpainting_image.read() 111 | img = Image.open(io.BytesIO(img_content)) 112 | img.save(tmp_inpainting_name) 113 | tmp_inpainting = open(tmp_inpainting_name, "r") 114 | 115 | if do_vid2vid and vid2vid_input: 116 | vid2vid_input_content = await vid2vid_input.read() 117 | tmp_vid2vid = open(temp_vid2vid_name, "wb") 118 | tmp_vid2vid.write(io.BytesIO(vid2vid_input_content).getbuffer()) 119 | tmp_vid2vid.close() 120 | tmp_vid2vid = open(temp_vid2vid_name, "r") 121 | 122 | videodat = run( 123 | # ffmpeg params 124 | dv.skip_video_creation, #skip_video_creation 125 | find_ffmpeg_binary(), #ffmpeg_location 126 | dv.ffmpeg_crf, #ffmpeg_crf 127 | dv.ffmpeg_preset,#ffmpeg_preset 128 | dv.fps,#fps 129 | dv.add_soundtrack,#add_soundtrack 130 | dv.soundtrack_path,#soundtrack_paths 131 | 132 | d.prompt,#prompt 133 | d.n_prompt,#n_prompt 134 | d.sampler,#sampler 135 | d.steps,#steps 136 | d.frames,#frames 137 | d.seed,#seed 138 | d.cfg_scale,#cfg_scale 139 | d.width,#width 140 | d.height,#height 141 | d.eta,#eta 142 | d.batch_count,#batch_count 143 | 144 | # The same, but for vid2vid. 
Will deduplicate later 145 | d.prompt,#prompt 146 | d.n_prompt,#n_prompt 147 | d.sampler,#sampler 148 | d.steps,#steps 149 | d.frames,#frames 150 | d.seed,#seed 151 | d.cfg_scale,#cfg_scale 152 | d.width,#width 153 | d.height,#height 154 | d.eta,#eta 155 | d.batch_count,#batch_count_v 156 | 157 | do_vid2vid,#do_vid2vid 158 | tmp_vid2vid,#vid2vid_frames 159 | "",#vid2vid_frames_path 160 | d.strength,#strength 161 | d.vid2vid_startFrame,#vid2vid_startFrame 162 | tmp_inpainting,#inpainting_image 163 | d.inpainting_frames,#inpainting_frames 164 | d.inpainting_weights,#inpainting_weights 165 | "ModelScope",#model_type. Only one has stable support at this moment 166 | d.model, 167 | ) 168 | 169 | return JSONResponse(content={"mp4s": videodat}) 170 | except Exception as e: 171 | # Log the error and return a JSON response with an appropriate status code and error message 172 | logger.error(f"Error processing the video: {e}") 173 | traceback.print_exc() 174 | return JSONResponse( 175 | status_code=500, 176 | content={"detail": "An error occurred while processing the video."}, 177 | ) 178 | finally: 179 | if tmp_inpainting is not None: 180 | tmp_inpainting.close() 181 | # delete temporary file 182 | try: 183 | os.remove(tmp_inpainting_name) 184 | except Exception as e: 185 | ... 186 | except Exception as e: 187 | ... 188 | if tmp_vid2vid is not None: 189 | tmp_vid2vid.close() 190 | try: 191 | os.remove(temp_vid2vid_name) 192 | except Exception as e: 193 | ... 194 | 195 | 196 | try: 197 | import modules.script_callbacks as script_callbacks 198 | 199 | script_callbacks.on_app_started(t2v_api) 200 | logger.debug("SD-Webui text2video API layer loaded") 201 | except ImportError: 202 | logger.debug("Unable to import script callbacks.XXX") 203 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | import os 5 | from einops import rearrange 6 | 7 | from videocrafter.lvdm.models.modules.autoencoder_modules import Encoder, Decoder 8 | from videocrafter.lvdm.models.modules.distributions import DiagonalGaussianDistribution 9 | from videocrafter.lvdm.utils.common_utils import instantiate_from_config 10 | 11 | class AutoencoderKL(pl.LightningModule): 12 | def __init__(self, 13 | ddconfig, 14 | lossconfig, 15 | embed_dim, 16 | ckpt_path=None, 17 | ignore_keys=[], 18 | image_key="image", 19 | colorize_nlabels=None, 20 | monitor=None, 21 | test=False, 22 | logdir=None, 23 | input_dim=4, 24 | test_args=None, 25 | ): 26 | super().__init__() 27 | self.image_key = image_key 28 | self.encoder = Encoder(**ddconfig) 29 | self.decoder = Decoder(**ddconfig) 30 | self.loss = instantiate_from_config(lossconfig) 31 | assert ddconfig["double_z"] 32 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 33 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 34 | self.embed_dim = embed_dim 35 | self.input_dim = input_dim 36 | self.test = test 37 | self.test_args = test_args 38 | self.logdir = logdir 39 | if colorize_nlabels is not None: 40 | assert type(colorize_nlabels)==int 41 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 42 | if monitor is not None: 43 | self.monitor = monitor 44 | if ckpt_path is not None: 45 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 46 | if self.test: 47 | 
self.init_test() 48 | 49 | def init_test(self,): 50 | self.test = True 51 | save_dir = os.path.join(self.logdir, "test") 52 | if 'ckpt' in self.test_args: 53 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 54 | self.root = os.path.join(save_dir, ckpt_name) 55 | else: 56 | self.root = save_dir 57 | if 'test_subdir' in self.test_args: 58 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 59 | 60 | self.root_zs = os.path.join(self.root, "zs") 61 | self.root_dec = os.path.join(self.root, "reconstructions") 62 | self.root_inputs = os.path.join(self.root, "inputs") 63 | os.makedirs(self.root, exist_ok=True) 64 | 65 | if self.test_args.save_z: 66 | os.makedirs(self.root_zs, exist_ok=True) 67 | if self.test_args.save_reconstruction: 68 | os.makedirs(self.root_dec, exist_ok=True) 69 | if self.test_args.save_input: 70 | os.makedirs(self.root_inputs, exist_ok=True) 71 | assert(self.test_args is not None) 72 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) #1500 # 12000/8 73 | self.count = 0 74 | self.eval_metrics = {} 75 | self.decodes = [] 76 | self.save_decode_samples = 2048 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu") 80 | try: 81 | self._cur_epoch = sd['epoch'] 82 | sd = sd["state_dict"] 83 | except: 84 | self._cur_epoch = 'null' 85 | keys = list(sd.keys()) 86 | for k in keys: 87 | for ik in ignore_keys: 88 | if k.startswith(ik): 89 | print("Deleting key {} from state_dict.".format(k)) 90 | del sd[k] 91 | self.load_state_dict(sd, strict=False) 92 | # self.load_state_dict(sd, strict=True) 93 | print(f"Restored from {path}") 94 | 95 | def encode(self, x, **kwargs): 96 | 97 | h = self.encoder(x) 98 | moments = self.quant_conv(h) 99 | posterior = DiagonalGaussianDistribution(moments) 100 | return posterior 101 | 102 | def decode(self, z, **kwargs): 103 | z = self.post_quant_conv(z) 104 | dec = self.decoder(z) 105 | return dec 106 | 107 | def forward(self, input, sample_posterior=True): 108 | posterior = self.encode(input) 109 | if sample_posterior: 110 | z = posterior.sample() 111 | else: 112 | z = posterior.mode() 113 | dec = self.decode(z) 114 | return dec, posterior 115 | 116 | def get_input(self, batch, k): 117 | x = batch[k] 118 | # if len(x.shape) == 3: 119 | # x = x[..., None] 120 | # if x.dim() == 4: 121 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 122 | if x.dim() == 5 and self.input_dim == 4: 123 | b,c,t,h,w = x.shape 124 | self.b = b 125 | self.t = t 126 | x = rearrange(x, 'b c t h w -> (b t) c h w') 127 | 128 | return x 129 | 130 | def training_step(self, batch, batch_idx, optimizer_idx): 131 | inputs = self.get_input(batch, self.image_key) 132 | reconstructions, posterior = self(inputs) 133 | 134 | if optimizer_idx == 0: 135 | # train encoder+decoder+logvar 136 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 137 | last_layer=self.get_last_layer(), split="train") 138 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 139 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 140 | return aeloss 141 | 142 | if optimizer_idx == 1: 143 | # train the discriminator 144 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 145 | last_layer=self.get_last_layer(), split="train") 146 | 147 | self.log("discloss", discloss, prog_bar=True, 
logger=True, on_step=True, on_epoch=True) 148 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 149 | return discloss 150 | 151 | def validation_step(self, batch, batch_idx): 152 | inputs = self.get_input(batch, self.image_key) 153 | reconstructions, posterior = self(inputs) 154 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 155 | last_layer=self.get_last_layer(), split="val") 156 | 157 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 158 | last_layer=self.get_last_layer(), split="val") 159 | 160 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 161 | self.log_dict(log_dict_ae) 162 | self.log_dict(log_dict_disc) 163 | return self.log_dict 164 | 165 | def configure_optimizers(self): 166 | lr = self.learning_rate 167 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 168 | list(self.decoder.parameters())+ 169 | list(self.quant_conv.parameters())+ 170 | list(self.post_quant_conv.parameters()), 171 | lr=lr, betas=(0.5, 0.9)) 172 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 173 | lr=lr, betas=(0.5, 0.9)) 174 | return [opt_ae, opt_disc], [] 175 | 176 | def get_last_layer(self): 177 | return self.decoder.conv_out.weight 178 | 179 | @torch.no_grad() 180 | def log_images(self, batch, only_inputs=False, **kwargs): 181 | log = dict() 182 | x = self.get_input(batch, self.image_key) 183 | x = x.to(self.device) 184 | if not only_inputs: 185 | xrec, posterior = self(x) 186 | if x.shape[1] > 3: 187 | # colorize with random projection 188 | assert xrec.shape[1] > 3 189 | x = self.to_rgb(x) 190 | xrec = self.to_rgb(xrec) 191 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 192 | log["reconstructions"] = xrec 193 | log["inputs"] = x 194 | return log 195 | 196 | def to_rgb(self, x): 197 | assert self.image_key == "segmentation" 198 | if not hasattr(self, "colorize"): 199 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 200 | x = F.conv2d(x, weight=self.colorize) 201 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
202 | return x 203 | -------------------------------------------------------------------------------- /scripts/stable_lora/scripts/lora_webui.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | 4 | from safetensors.torch import load_file 5 | from types import SimpleNamespace 6 | from safetensors import safe_open 7 | from einops import rearrange 8 | import gradio as gr 9 | import os 10 | import json 11 | 12 | from modules import images, script_callbacks 13 | from modules.shared import opts, state, cmd_opts 14 | from stable_lora.stable_utils.lora_processor import StableLoraProcessor 15 | from t2v_helpers.extensions_utils import Text2VideoExtension 16 | 17 | EXTENSION_TITLE = "Stable LoRA" 18 | EXTENSION_NAME = EXTENSION_TITLE.replace(' ', '_').lower() 19 | 20 | gr_inputs_list = [ 21 | "lora_file_selection", 22 | "lora_alpha", 23 | "refresh_button", 24 | "use_bias", 25 | "use_linear", 26 | "use_conv", 27 | "use_emb", 28 | "use_time", 29 | "use_multiplier" 30 | ] 31 | 32 | gr_inputs_dict = {v: v for v in gr_inputs_list} 33 | GradioInputsIds = SimpleNamespace(**gr_inputs_dict) 34 | 35 | class StableLoraScript(Text2VideoExtension, StableLoraProcessor): 36 | 37 | def __init__(self): 38 | StableLoraProcessor.__init__(self) 39 | Text2VideoExtension.__init__(self, EXTENSION_NAME, EXTENSION_TITLE) 40 | self.device = 'cuda' 41 | self.dtype = torch.float16 42 | 43 | def title(self): 44 | return EXTENSION_TITLE 45 | 46 | def refresh_models(self, *args): 47 | paths_with_metadata, lora_names = self.get_lora_files() 48 | self.lora_files = paths_with_metadata.copy() 49 | 50 | return gr.Dropdown.update(value=[], choices=lora_names) 51 | 52 | def ui(self): 53 | paths_with_metadata, lora_names = self.get_lora_files() 54 | self.lora_files = paths_with_metadata.copy() 55 | REPOSITORY_LINK = "https://github.com/ExponentialML/Text-To-Video-Finetuning" 56 | 57 | with gr.Accordion(label=EXTENSION_TITLE, open=False) as stable_lora_section: 58 | with gr.Blocks(analytics_enabled=False): 59 | with gr.Row(): 60 | with gr.Column(): 61 | gr.HTML("

Load a Trained LoRA File.

") 62 | gr.HTML( 63 | """ 64 |

65 | Only Stable LoRA files are supported. 66 |

67 | """ 68 | ) 69 | gr.HTML(f""" 70 | 71 | To train a Stable LoRA file, use the finetune repository by clicking here. 72 | """ 73 | ) 74 | gr.HTML(f" Place your LoRA files in {cmd_opts.lora_dir}") 75 | lora_files_selection = gr.Dropdown( 76 | label="Available Models", 77 | elem_id=GradioInputsIds.lora_file_selection, 78 | choices=lora_names, 79 | value=[], 80 | multiselect=True, 81 | ) 82 | lora_alpha = gr.Slider( 83 | minimum=0, 84 | maximum=1, 85 | value=1, 86 | step=0.05, 87 | elem_id=GradioInputsIds.lora_alpha, 88 | label="LoRA Weight" 89 | ) 90 | refresh_button = gr.Button( 91 | value="Refresh Models", 92 | elem_id=GradioInputsIds.refresh_button 93 | ) 94 | refresh_button.click( 95 | self.refresh_models, 96 | lora_files_selection, 97 | lora_files_selection 98 | ) 99 | with gr.Accordion(label="Advanced Settings", open=False, visible=False): 100 | with gr.Column(): 101 | use_bias = gr.Checkbox(label="Enable Bias", elem_id=GradioInputsIds.use_bias, value=lambda: True) 102 | use_linear = gr.Checkbox(label="Enable Linears", elem_id=GradioInputsIds.use_linear, value=lambda: True) 103 | use_conv = gr.Checkbox(label="Enable Convolutions", elem_id=GradioInputsIds.use_conv, value=lambda: True) 104 | use_emb = gr.Checkbox(label="Enable Embeddings", elem_id=GradioInputsIds.use_emb, value=lambda: True) 105 | use_time = gr.Checkbox(label="Enable Time", elem_id=GradioInputsIds.use_time, value=lambda: True) 106 | with gr.Column(): 107 | use_multiplier = gr.Number( 108 | label="Alpha Multiplier", 109 | elem_id=GradioInputsIds.use_multiplier, 110 | value=1, 111 | ) 112 | 113 | 114 | return self.return_ui_inputs( 115 | return_args=[ 116 | lora_files_selection, 117 | lora_alpha, 118 | use_bias, 119 | use_linear, 120 | use_conv, 121 | use_emb, 122 | use_multiplier, 123 | use_time 124 | ] 125 | ) 126 | 127 | @torch.no_grad() 128 | def process( 129 | self, 130 | p, 131 | lora_files_selection, 132 | lora_alpha, 133 | use_bias, 134 | use_linear, 135 | use_conv, 136 | use_emb, 137 | use_multiplier, 138 | use_time 139 | ): 140 | 141 | # Get the list of LoRA files based off of filepath. 142 | lora_file_names = [x for x in lora_files_selection if x != "None"] 143 | 144 | if len(self.lora_files) <= 0: 145 | paths_with_metadata, lora_names = self.get_lora_files() 146 | self.lora_files = paths_with_metadata.copy() 147 | 148 | lora_files = self.get_loras_to_process(lora_file_names) 149 | 150 | # Load multiple LoRAs 151 | lora_files_list = [] 152 | 153 | # Load our advanced options in a list 154 | advanced_options = [ 155 | use_bias, 156 | use_linear, 157 | use_conv, 158 | use_emb, 159 | use_multiplier, 160 | use_time 161 | ] 162 | 163 | # Save the previous alpha value so we can re-run the LoRA with new values. 164 | alpha_changed = self.handle_alpha_change(lora_alpha, p.sd_model) 165 | 166 | # If an advanced option changes, re-run with new options 167 | options_changed = self.handle_options_change(advanced_options, p.sd_model) 168 | 169 | # Check if we changed our LoRA models we are loading 170 | lora_changed = self.previous_lora_file_names != lora_file_names 171 | 172 | first_lora_init = not self.is_lora_loaded(p.sd_model) 173 | 174 | # If the LoRA is still loaded, unload it. 
175 | unload_args = [p.sd_model, None, use_bias, use_time, use_conv, use_emb, use_linear, None] 176 | self.handle_lora_start(lora_files, p.sd_model, unload_args) 177 | 178 | can_use_lora = self.can_use_lora(p.sd_model) 179 | 180 | lora_params_changed = any([alpha_changed, lora_changed, options_changed]) 181 | 182 | # Process LoRA 183 | if can_use_lora or lora_params_changed: 184 | 185 | if len(lora_files) == 0: return 186 | 187 | for i, model in enumerate([p.sd_model, p.clip_encoder.model.transformer]): 188 | lora_alpha = (lora_alpha * use_multiplier) / len(lora_files) 189 | 190 | lora_files_list = self.load_loras_from_list(lora_files) 191 | 192 | args = [model, lora_files_list, use_bias, use_time, use_conv, use_emb, use_linear, lora_alpha] 193 | 194 | if lora_params_changed and not first_lora_init : 195 | if i == 0: 196 | self.log("Resetting weights to reflect changed options.") 197 | 198 | undo_args = args.copy() 199 | undo_args[1], undo_args[-1] = self.undo_merge_preprocess() 200 | 201 | self.process_lora(*undo_args, undo_merge=True) 202 | 203 | self.process_lora(*args, undo_merge=False) 204 | 205 | self.handle_after_lora_load( 206 | p.sd_model, 207 | lora_files, 208 | lora_file_names, 209 | advanced_options, 210 | lora_alpha 211 | ) 212 | 213 | if len(lora_files) > 0 and not all([can_use_lora, lora_params_changed]): 214 | self.log(f"Using loaded LoRAs: {', '.join(lora_file_names)}") 215 | 216 | StableLoraScriptInstance = StableLoraScript() 217 | -------------------------------------------------------------------------------- /scripts/stable_lora/stable_utils/lora_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | 5 | from safetensors.torch import load_file 6 | from safetensors import safe_open 7 | from modules.shared import opts, cmd_opts, state 8 | 9 | class StableLoraProcessor: 10 | def __init__(self): 11 | self.lora_loaded = 'lora_loaded' 12 | self.previous_lora_alpha = 1 13 | self.current_sd_checkpoint = "" 14 | self.previous_lora_file_names = [] 15 | self.previous_advanced_options = [] 16 | self.lora_files = [] 17 | 18 | def get_lora_files(self): 19 | paths_with_metadata = [] 20 | paths = glob.glob(os.path.join(cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) 21 | lora_names = [] 22 | 23 | for lora_path in paths: 24 | with safe_open(lora_path, 'pt') as lora_file: 25 | metadata = lora_file.metadata() 26 | if metadata is not None and 'stable_lora_text_to_video' in metadata.keys(): 27 | metadata['path'] = lora_path 28 | metadata['lora_name'] = os.path.splitext(os.path.basename(lora_path))[0] 29 | paths_with_metadata.append(metadata) 30 | 31 | if len(paths_with_metadata) > 0: 32 | lora_names = [x['lora_name'] for x in paths_with_metadata] 33 | 34 | return paths_with_metadata, lora_names 35 | 36 | def key_name_match(self, value, key, name): 37 | return value in key and name == key.split(f".{value}")[0] 38 | 39 | def is_lora_match(self, key, name): 40 | return self.key_name_match('lora_A', key, name) 41 | 42 | def is_bias_match(self, key, name): 43 | return self.key_name_match("bias", key, name) 44 | 45 | def lora_rank(self, weight): return min(weight.shape) 46 | 47 | def get_lora_alpha(self, alpha): 48 | return alpha 49 | 50 | def process_lora_weight(self, weight, lora_weight, alpha, undo_merge=False): 51 | new_weight = weight.detach().clone() 52 | 53 | if not undo_merge: 54 | new_weight += lora_weight.to(weight.device, weight.dtype) * alpha 55 | else: 56 | new_weight -= 
lora_weight.to(weight.device, weight.dtype) * alpha 57 | 58 | return torch.nn.Parameter(new_weight.to(weight.device, weight.dtype)) 59 | 60 | def lora_linear_forward( 61 | self, 62 | weight, 63 | lora_A, 64 | lora_B, 65 | alpha, 66 | undo_merge=False, 67 | *args 68 | ): 69 | l_alpha = self.get_lora_alpha(alpha) 70 | lora_weight = (lora_B @ lora_A) 71 | 72 | return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge) 73 | 74 | def lora_conv_forward( 75 | self, 76 | weight, 77 | lora_A, 78 | lora_B, 79 | alpha, 80 | undo_merge=False, 81 | is_temporal=False, 82 | *args 83 | ): 84 | l_alpha = self.get_lora_alpha(alpha) 85 | view_shape = weight.shape 86 | 87 | if is_temporal: 88 | i, o, k = weight.shape[:3] 89 | view_shape = (i, o, k, k, 1) 90 | 91 | lora_weight = (lora_B @ lora_A).view(view_shape) 92 | 93 | if is_temporal: 94 | lora_weight = torch.mean(lora_weight, dim=-2, keepdim=True) 95 | 96 | return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge) 97 | 98 | def lora_emb_forward(self, lora_A, lora_B, alpha, undo_merge=False, *args): 99 | l_alpha = self.get_lora_alpha(alpha) 100 | 101 | return (lora_B @ lora_A).transpose(0, 1) * l_alpha 102 | 103 | def is_lora_loaded(self, sd_model): 104 | return hasattr(sd_model, self.lora_loaded) 105 | 106 | def get_loras_to_process(self, lora_files): 107 | lora_files_to_load = [] 108 | 109 | for file_name in lora_files: 110 | if len(self.lora_files) > 0: 111 | for f in self.lora_files: 112 | if file_name == f['lora_name']: 113 | lora_files_to_load.append(f['path']) 114 | 115 | return lora_files_to_load 116 | 117 | def handle_lora_load( 118 | self, 119 | sd_model, 120 | lora_files_list, 121 | set_lora_loaded=False, 122 | unload_args=[] 123 | ): 124 | if not hasattr(sd_model, self.lora_loaded) and set_lora_loaded: 125 | setattr(sd_model, self.lora_loaded, True) 126 | 127 | if self.is_lora_loaded(sd_model) and not set_lora_loaded: 128 | unload_args[1], unload_args[-1] = self.undo_merge_preprocess() 129 | self.process_lora(*unload_args, undo_merge=True) 130 | delattr(sd_model, self.lora_loaded) 131 | 132 | def handle_alpha_change(self, lora_alpha, model): 133 | return (lora_alpha != self.previous_lora_alpha) \ 134 | and self.is_lora_loaded(model) 135 | 136 | def handle_options_change(self, options, model): 137 | return (options != self.previous_advanced_options) \ 138 | and self.is_lora_loaded(model) 139 | 140 | def handle_lora_start(self, lora_files, model, unload_args): 141 | if len(lora_files) == 0 and self.is_lora_loaded(model): 142 | self.handle_lora_load( 143 | model, 144 | lora_files, 145 | set_lora_loaded=False, 146 | unload_args=unload_args 147 | ) 148 | 149 | self.log(f"Unloaded previously loaded LoRA files") 150 | return 151 | 152 | def can_use_lora(self, model): 153 | return not self.is_lora_loaded(model) 154 | 155 | def load_loras_from_list(self, lora_files): 156 | lora_files_list = [] 157 | 158 | for lora_file in lora_files: 159 | LORA_FILE = lora_file.split('/')[-1] 160 | LORA_DIR = cmd_opts.lora_dir 161 | LORA_PATH = f"{LORA_DIR}/{LORA_FILE}" 162 | 163 | lora_model_text_path = f"{LORA_DIR}/text_{LORA_FILE}" 164 | lora_text_exists = os.path.exists(lora_model_text_path) 165 | 166 | is_safetensors = LORA_PATH.endswith('.safetensors') 167 | load_method = load_file if is_safetensors else torch.load 168 | 169 | lora_model = load_method(LORA_PATH) 170 | 171 | lora_files_list.append(lora_model) 172 | 173 | return lora_files_list 174 | 175 | def handle_after_lora_load( 176 | self, 177 | model, 178 | 
lora_files, 179 | lora_file_names, 180 | advanced_options, 181 | lora_alpha 182 | ): 183 | lora_summary = [] 184 | self.handle_lora_load(model, lora_files, set_lora_loaded=True) 185 | self.previous_lora_file_names = lora_file_names 186 | self.previous_advanced_options = advanced_options 187 | self.previous_lora_alpha = lora_alpha 188 | 189 | for lora_file_name in lora_file_names: 190 | if self.is_lora_loaded(model): 191 | lora_summary.append(f"{lora_file_name.split('.')[0]}") 192 | 193 | if len(lora_summary) > 0: 194 | self.log(f"Using {model.__class__.__name__} LoRA(s):", *lora_summary) 195 | 196 | def undo_merge_preprocess(self): 197 | previous_lora_files_list = self.get_loras_to_process(self.previous_lora_file_names) 198 | previous_lora_files = self.load_loras_from_list(previous_lora_files_list) 199 | 200 | return previous_lora_files, self.previous_lora_alpha 201 | 202 | @torch.autocast('cuda') 203 | def process_lora( 204 | self, 205 | model, 206 | lora_files_list, 207 | use_bias, 208 | use_time, 209 | use_conv, 210 | use_emb, 211 | use_linear, 212 | lora_alpha, 213 | undo_merge=False 214 | ): 215 | for n, m in model.named_modules(): 216 | for lora_model in lora_files_list: 217 | for k, v in lora_model.items(): 218 | 219 | # If there is bias in the LoRA, add it here. 220 | if self.is_bias_match(k, n) and use_bias: 221 | if m.bias is None: 222 | m.bias = torch.nn.Parameter(v.to(self.device, dtype=self.dtype)) 223 | else: 224 | m.bias.weight = v.to(self.device, dtype=self.dtype) 225 | 226 | if self.is_lora_match(k, n): 227 | lora_A = lora_model[f"{n}.lora_A"].to(self.device, dtype=self.dtype) 228 | lora_B = lora_model[f"{n}.lora_B"].to(self.device, dtype=self.dtype) 229 | 230 | forward_args = [m.weight, lora_A, lora_B, lora_alpha] 231 | 232 | if isinstance(m, torch.nn.Linear) and use_linear: 233 | if 'proj' in n: 234 | forward_args[1], forward_args[2] = map(lambda l: l.squeeze(-1), (lora_A, lora_B)) 235 | 236 | m.weight = self.lora_linear_forward(*forward_args, undo_merge=undo_merge) 237 | 238 | if isinstance(m, torch.nn.Conv2d) and use_conv: 239 | m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=False) 240 | 241 | if isinstance(m, torch.nn.Conv3d) and use_conv and use_time: 242 | m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=True) 243 | 244 | if isinstance(m, torch.nn.Embedding) and use_emb: 245 | embedding_weight = self.lora_emb_forward(lora_A, lora_B, lora_alpha, undo_merge=undo_merge) 246 | new_embedding_weight = torch.nn.Embedding.from_pretrained(embedding_weight) 247 | -------------------------------------------------------------------------------- /scripts/videocrafter/sample_text2video_adapter.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import datetime, time 3 | from omegaconf import OmegaConf 4 | 5 | import torch 6 | from decord import VideoReader, cpu 7 | import torchvision 8 | from pytorch_lightning import seed_everything 9 | 10 | from lvdm.samplers.ddim import DDIMSampler 11 | from lvdm.utils.common_utils import instantiate_from_config 12 | from lvdm.utils.saving_utils import tensor_to_mp4 13 | 14 | 15 | def get_filelist(data_dir, ext='*'): 16 | file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext)) 17 | file_list.sort() 18 | return file_list 19 | 20 | def load_model_checkpoint(model, ckpt, adapter_ckpt=None): 21 | print('>>> Loading checkpoints ...') 22 | if adapter_ckpt: 23 | ## main model 24 | state_dict = 
torch.load(ckpt, map_location="cpu") 25 | if "state_dict" in list(state_dict.keys()): 26 | state_dict = state_dict["state_dict"] 27 | model.load_state_dict(state_dict, strict=False) 28 | print('@model checkpoint loaded.') 29 | ## adapter 30 | state_dict = torch.load(adapter_ckpt, map_location="cpu") 31 | if "state_dict" in list(state_dict.keys()): 32 | state_dict = state_dict["state_dict"] 33 | model.adapter.load_state_dict(state_dict, strict=True) 34 | print('@adapter checkpoint loaded.') 35 | else: 36 | state_dict = torch.load(ckpt, map_location="cpu") 37 | if "state_dict" in list(state_dict.keys()): 38 | state_dict = state_dict["state_dict"] 39 | model.load_state_dict(state_dict, strict=True) 40 | print('@model checkpoint loaded.') 41 | return model 42 | 43 | def load_prompts(prompt_file): 44 | f = open(prompt_file, 'r') 45 | prompt_list = [] 46 | for idx, line in enumerate(f.readlines()): 47 | l = line.strip() 48 | if len(l) != 0: 49 | prompt_list.append(l) 50 | f.close() 51 | return prompt_list 52 | 53 | def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16): 54 | vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0]) 55 | max_frames = len(vidreader) 56 | temp_stride = max_frames // video_frames if frame_stride == -1 else frame_stride 57 | if temp_stride * (video_frames-1) >= max_frames: 58 | print(f'Warning: default frame stride is used because the input video clip {max_frames} is not long enough.') 59 | temp_stride = max_frames // video_frames 60 | frame_indices = [temp_stride*i for i in range(video_frames)] 61 | frames = vidreader.get_batch(frame_indices) 62 | 63 | ## [t,h,w,c] -> [c,t,h,w] 64 | frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 65 | frame_tensor = (frame_tensor / 255. - 0.5) * 2 66 | return frame_tensor 67 | 68 | 69 | def save_results(prompt, samples, inputs, filename, realdir, fakedir, fps=10): 70 | ## save prompt 71 | prompt = prompt[0] if isinstance(prompt, list) else prompt 72 | path = os.path.join(realdir, "%s.txt"%filename) 73 | with open(path, 'w') as f: 74 | f.write(f'{prompt}') 75 | f.close() 76 | 77 | ## save video 78 | videos = [inputs, samples] 79 | savedirs = [realdir, fakedir] 80 | for idx, video in enumerate(videos): 81 | if video is None: 82 | continue 83 | # b,c,t,h,w 84 | video = video.detach().cpu() 85 | video = torch.clamp(video.float(), -1., 1.) 
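        # The block below turns the clamped [-1, 1] tensor into an H.264 clip:
        # each timestep is tiled into a frame grid with make_grid, rescaled to
        # [0, 255] uint8, and permuted to the [T, H, W, C] layout that
        # torchvision.io.write_video expects.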
86 | n = video.shape[0] 87 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 88 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n)) for framesheet in video] #[3, 1*h, n*w] 89 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 90 | grid = (grid + 1.0) / 2.0 91 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 92 | path = os.path.join(savedirs[idx], "%s.mp4"%filename) 93 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 94 | 95 | 96 | def adapter_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \ 97 | unconditional_guidance_scale=1.0, unconditional_guidance_scale_temporal=None, **kwargs): 98 | ddim_sampler = DDIMSampler(model) 99 | 100 | batch_size = noise_shape[0] 101 | ## get condition embeddings (support single prompt only) 102 | if isinstance(prompts, str): 103 | prompts = [prompts] 104 | cond = model.get_learned_conditioning(prompts) 105 | if unconditional_guidance_scale != 1.0: 106 | prompts = batch_size * [""] 107 | uc = model.get_learned_conditioning(prompts) 108 | else: 109 | uc = None 110 | 111 | ## adapter features: process in 2D manner 112 | b, c, t, h, w = videos.shape 113 | extra_cond = model.get_batch_depth(videos, (h,w)) 114 | features_adapter = model.get_adapter_features(extra_cond) 115 | 116 | batch_variants = [] 117 | for _ in range(n_samples): 118 | if ddim_sampler is not None: 119 | samples, _ = ddim_sampler.sample(S=ddim_steps, 120 | conditioning=cond, 121 | batch_size=noise_shape[0], 122 | shape=noise_shape[1:], 123 | verbose=False, 124 | unconditional_guidance_scale=unconditional_guidance_scale, 125 | unconditional_conditioning=uc, 126 | eta=ddim_eta, 127 | temporal_length=noise_shape[2], 128 | conditional_guidance_scale_temporal=unconditional_guidance_scale_temporal, 129 | features_adapter=features_adapter, 130 | **kwargs 131 | ) 132 | ## reconstruct from latent to pixel space 133 | batch_images = model.decode_first_stage(samples, decode_bs=1, return_cpu=False) 134 | batch_variants.append(batch_images) 135 | ## variants, batch, c, t, h, w 136 | batch_variants = torch.stack(batch_variants) 137 | return batch_variants.permute(1, 0, 2, 3, 4, 5), extra_cond 138 | 139 | 140 | def run_inference(args, gpu_idx): 141 | ## model config 142 | config = OmegaConf.load(args.base) 143 | model_config = config.pop("model", OmegaConf.create()) 144 | model = instantiate_from_config(model_config) 145 | model = model.cuda(gpu_idx) 146 | assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" 147 | model = load_model_checkpoint(model, args.ckpt_path, args.adapter_ckpt) 148 | model.eval() 149 | 150 | ## run over data 151 | assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" 
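    # Sampling runs in the VAE latent space at 1/8 of the pixel resolution,
    # so e.g. a 512x512 request is denoised on a 64x64 latent grid below.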
152 | ## latent noise shape 153 | h, w = args.height // 8, args.width // 8 154 | channels = model.channels 155 | frames = model.temporal_length 156 | noise_shape = [args.bs, channels, args.num_frames, h, w] 157 | 158 | ## inference 159 | start = time.time() 160 | prompt = args.prompt 161 | video = load_video(args.video, args.frame_stride, video_size=(args.height, args.width), video_frames=args.num_frames) 162 | video = video.unsqueeze(0).to("cuda") 163 | with torch.no_grad(): 164 | batch_samples, batch_conds = adapter_guided_synthesis(model, prompt, video, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \ 165 | args.unconditional_guidance_scale, args.unconditional_guidance_scale_temporal) 166 | batch_samples = batch_samples[0] 167 | os.makedirs(args.savedir, exist_ok=True) 168 | filename = f"{args.prompt}_seed{args.seed}" 169 | filename = filename.replace("/", "_slash_") if "/" in filename else filename 170 | filename = filename.replace(" ", "_") if " " in filename else filename 171 | tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=os.path.join(args.savedir, f'{filename}_depth.mp4'), fps=10) 172 | tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=os.path.join(args.savedir, f'{filename}_sample.mp4'), fps=10) 173 | 174 | print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds") 175 | 176 | 177 | def get_parser(): 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument("--savedir", type=str, default=None, help="results saving path") 180 | parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") 181 | parser.add_argument("--adapter_ckpt", type=str, default=None, help="adapter checkpoint path") 182 | parser.add_argument("--base", type=str, help="config (yaml) path") 183 | parser.add_argument("--prompt", type=str, default=None, help="prompt string") 184 | parser.add_argument("--video", type=str, default=None, help="video path") 185 | parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) 186 | parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) 187 | parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) 188 | parser.add_argument("--bs", type=int, default=1, help="batch size for inference") 189 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") 190 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") 191 | parser.add_argument("--frame_stride", type=int, default=-1, help="frame extracting from input video") 192 | parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance") 193 | parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance") 194 | parser.add_argument("--seed", type=int, default=2023, help="seed for seed_everything") 195 | parser.add_argument("--num_frames", type=int, default=16, help="number of input frames") 196 | return parser 197 | 198 | 199 | if __name__ == '__main__': 200 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 201 | print("@CoVideoGen cond-Inference: %s"%now) 202 | parser = get_parser() 203 | args = parser.parse_args() 204 | 205 | seed_everything(args.seed) 206 | rank = 0 207 | run_inference(args, rank) -------------------------------------------------------------------------------- 
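Note: sample_text2video_adapter.py is a standalone CLI. It loads the base VideoCrafter checkpoint plus an adapter checkpoint, extracts depth maps from the reference --video, and uses them to guide DDIM sampling of --prompt, writing a _depth.mp4 and a _sample.mp4 into --savedir. A typical invocation might look like the sketch below (all paths and the prompt are placeholders; only the flags come from get_parser above):

python scripts/videocrafter/sample_text2video_adapter.py \
    --base path/to/model_config.yaml \
    --ckpt_path path/to/base_model.ckpt \
    --adapter_ckpt path/to/depth_adapter.ckpt \
    --prompt "a corgi running on the beach" \
    --video path/to/reference.mp4 \
    --savedir results/adapter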
/scripts/videocrafter/lvdm/utils/saving_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import imageio 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import os 9 | import sys 10 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 11 | import torch 12 | import torchvision 13 | from torchvision.utils import make_grid 14 | from torch import Tensor 15 | from torchvision.transforms.functional import to_tensor 16 | 17 | 18 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None): 19 | """ 20 | video: torch.Tensor, b,c,t,h,w, 0-1 21 | if -1~1, enable rescale=True 22 | """ 23 | n = video.shape[0] 24 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 25 | nrow = int(np.sqrt(n)) if nrow is None else nrow 26 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w] 27 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] 28 | grid = torch.clamp(grid.float(), -1., 1.) 29 | if rescale: 30 | grid = (grid + 1.0) / 2.0 31 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 32 | #print(f'Save video to {savepath}') 33 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 34 | 35 | # ---------------------------------------------------------------------------------------------- 36 | def savenp2sheet(imgs, savepath, nrow=None): 37 | """ save multiple imgs (in numpy array type) to a img sheet. 38 | img sheet is one row. 39 | 40 | imgs: 41 | np array of size [N, H, W, 3] or List[array] with array size = [H,W,3] 42 | """ 43 | if imgs.ndim == 4: 44 | img_list = [imgs[i] for i in range(imgs.shape[0])] 45 | imgs = img_list 46 | 47 | imgs_new = [] 48 | for i, img in enumerate(imgs): 49 | if img.ndim == 3 and img.shape[0] == 3: 50 | img = np.transpose(img,(1,2,0)) 51 | 52 | assert(img.ndim == 3 and img.shape[-1] == 3), img.shape # h,w,3 53 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 54 | imgs_new.append(img) 55 | n = len(imgs) 56 | if nrow is not None: 57 | n_cols = nrow 58 | else: 59 | n_cols=int(n**0.5) 60 | n_rows=int(np.ceil(n/n_cols)) 61 | print(n_cols) 62 | print(n_rows) 63 | 64 | imgsheet = cv2.vconcat([cv2.hconcat(imgs_new[i*n_cols:(i+1)*n_cols]) for i in range(n_rows)]) 65 | cv2.imwrite(savepath, imgsheet) 66 | print(f'saved in {savepath}') 67 | 68 | # ---------------------------------------------------------------------------------------------- 69 | def save_np_to_img(img, path, norm=True): 70 | if norm: 71 | img = (img + 1) / 2 * 255 72 | img = img.astype(np.uint8) 73 | image = Image.fromarray(img) 74 | image.save(path, q=95) 75 | 76 | # ---------------------------------------------------------------------------------------------- 77 | def npz_to_imgsheet_5d(data_path, res_dir, nrow=None,): 78 | if isinstance(data_path, str): 79 | imgs = np.load(data_path)['arr_0'] # NTHWC 80 | elif isinstance(data_path, np.ndarray): 81 | imgs = data_path 82 | else: 83 | raise Exception 84 | 85 | if os.path.isdir(res_dir): 86 | res_path = os.path.join(res_dir, f'samples.jpg') 87 | else: 88 | assert(res_dir.endswith('.jpg')) 89 | res_path = res_dir 90 | imgs = np.concatenate([imgs[i] for i in range(imgs.shape[0])], axis=0) 91 | savenp2sheet(imgs, res_path, nrow=nrow) 92 | 93 | # ---------------------------------------------------------------------------------------------- 94 | def 
npz_to_imgsheet_4d(data_path, res_path, nrow=None,): 95 | if isinstance(data_path, str): 96 | imgs = np.load(data_path)['arr_0'] # NHWC 97 | elif isinstance(data_path, np.ndarray): 98 | imgs = data_path 99 | else: 100 | raise Exception 101 | print(imgs.shape) 102 | savenp2sheet(imgs, res_path, nrow=nrow) 103 | 104 | 105 | # ---------------------------------------------------------------------------------------------- 106 | def tensor_to_imgsheet(tensor, save_path): 107 | """ 108 | save a batch of videos in one image sheet with shape of [batch_size * num_frames]. 109 | data: [b,c,t,h,w] 110 | """ 111 | assert(tensor.dim() == 5) 112 | b,c,t,h,w = tensor.shape 113 | imgs = [tensor[bi,:,ti, :, :] for bi in range(b) for ti in range(t)] 114 | torchvision.utils.save_image(imgs, save_path, normalize=True, nrow=t) 115 | 116 | 117 | # ---------------------------------------------------------------------------------------------- 118 | def npz_to_frames(data_path, res_dir, norm, num_frames=None, num_samples=None): 119 | start = time.time() 120 | arr = np.load(data_path) 121 | imgs = arr['arr_0'] # [N, T, H, W, 3] 122 | print('original data shape: ', imgs.shape) 123 | 124 | if num_samples is not None: 125 | imgs = imgs[:num_samples, :, :, :, :] 126 | print('after sample selection: ', imgs.shape) 127 | 128 | if num_frames is not None: 129 | imgs = imgs[:, :num_frames, :, :, :] 130 | print('after frame selection: ', imgs.shape) 131 | 132 | for vid in tqdm(range(imgs.shape[0]), desc='Video'): 133 | video_dir = os.path.join(res_dir, f'video{vid:04d}') 134 | os.makedirs(video_dir, exist_ok=True) 135 | for fid in range(imgs.shape[1]): 136 | frame = imgs[vid, fid, :, :, :] #HW3 137 | save_np_to_img(frame, os.path.join(video_dir, f'frame{fid:04d}.jpg'), norm=norm) 138 | print('Finish') 139 | print(f'Total time = {time.time()- start}') 140 | 141 | # ---------------------------------------------------------------------------------------------- 142 | def npz_to_gifs(data_path, res_dir, duration=0.2, start_idx=0, num_videos=None, mode='gif'): 143 | os.makedirs(res_dir, exist_ok=True) 144 | if isinstance(data_path, str): 145 | imgs = np.load(data_path)['arr_0'] # NTHWC 146 | elif isinstance(data_path, np.ndarray): 147 | imgs = data_path 148 | else: 149 | raise Exception 150 | 151 | for i in range(imgs.shape[0]): 152 | frames = [imgs[i,j,:,:,:] for j in range(imgs[i].shape[0])] # [(h,w,3)] 153 | if mode == 'gif': 154 | imageio.mimwrite(os.path.join(res_dir, f'samples_{start_idx+i}.gif'), frames, format='GIF', duration=duration) 155 | elif mode == 'mp4': 156 | frames = [torch.from_numpy(frame) for frame in frames] 157 | frames = torch.stack(frames, dim=0).to(torch.uint8) # [T, H, W, C] 158 | torchvision.io.write_video(os.path.join(res_dir, f'samples_{start_idx+i}.mp4'), 159 | frames, fps=0.5, video_codec='h264', options={'crf': '10'}) 160 | if i+ 1 == num_videos: 161 | break 162 | 163 | # ---------------------------------------------------------------------------------------------- 164 | def fill_with_black_squares(video, desired_len: int) -> Tensor: 165 | if len(video) >= desired_len: 166 | return video 167 | 168 | return torch.cat([ 169 | video, 170 | torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1), 171 | ], dim=0) 172 | 173 | # ---------------------------------------------------------------------------------------------- 174 | def load_num_videos(data_path, num_videos): 175 | # data_path can be either data_path of np array 176 | if isinstance(data_path, str): 177 | videos = 
np.load(data_path)['arr_0'] # NTHWC 178 | elif isinstance(data_path, np.ndarray): 179 | videos = data_path 180 | else: 181 | raise Exception 182 | 183 | if num_videos is not None: 184 | videos = videos[:num_videos, :, :, :, :] 185 | return videos 186 | 187 | # ---------------------------------------------------------------------------------------------- 188 | def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True): 189 | if isinstance(data_path, str): 190 | videos = load_num_videos(data_path, num_videos) 191 | elif isinstance(data_path, np.ndarray): 192 | videos = data_path 193 | else: 194 | raise Exception 195 | n,t,h,w,c = videos.shape 196 | 197 | videos_th = [] 198 | for i in range(n): 199 | video = videos[i, :,:,:,:] 200 | images = [video[j, :,:,:] for j in range(t)] 201 | images = [to_tensor(img) for img in images] 202 | video = torch.stack(images) 203 | videos_th.append(video) 204 | 205 | if num_frames is None: 206 | num_frames = videos.shape[1] 207 | if verbose: 208 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW 209 | else: 210 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 211 | 212 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 213 | if nrow is None: 214 | nrow = int(np.ceil(np.sqrt(n))) 215 | if verbose: 216 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] 217 | else: 218 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 219 | 220 | if os.path.dirname(out_path) != "": 221 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 222 | frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] 223 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) 224 | 225 | # ---------------------------------------------------------------------------------------------- 226 | def npz_to_gif_grid(data_path, out_path, n_cols=None, num_videos=20): 227 | arr = np.load(data_path) 228 | imgs = arr['arr_0'] # [N, T, H, W, 3] 229 | imgs = imgs[:num_videos] 230 | n, t, h, w, c = imgs.shape 231 | assert(n == num_videos) 232 | n_cols = n_cols if n_cols else imgs.shape[0] 233 | n_rows = np.ceil(imgs.shape[0] / n_cols).astype(np.int8) 234 | H, W = h * n_rows, w * n_cols 235 | grid = np.zeros((t, H, W, c), dtype=np.uint8) 236 | 237 | for i in range(n_rows): 238 | for j in range(n_cols): 239 | if i*n_cols+j < imgs.shape[0]: 240 | grid[:, i*h:(i+1)*h, j*w:(j+1)*w, :] = imgs[i*n_cols+j, :, :, :, :] 241 | 242 | videos = [grid[i] for i in range(grid.shape[0])] # grid: TH'W'C 243 | imageio.mimwrite(out_path, videos, format='GIF', duration=0.5,palettesize=256) 244 | 245 | 246 | # ---------------------------------------------------------------------------------------------- 247 | def torch_to_video_grid(videos, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True): 248 | """ 249 | videos: -1 ~ 1, torch.Tensor, BCTHW 250 | """ 251 | n,t,h,w,c = videos.shape 252 | videos_th = [videos[i, ...] 
for i in range(n)] 253 | if verbose: 254 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW 255 | else: 256 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 257 | 258 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 259 | if nrow is None: 260 | nrow = int(np.ceil(np.sqrt(n))) 261 | if verbose: 262 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] 263 | else: 264 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 265 | 266 | if os.path.dirname(out_path) != "": 267 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 268 | frame_grids = ((torch.stack(frame_grids) + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] 269 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) 270 | -------------------------------------------------------------------------------- /scripts/samplers/ddim/gaussian_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modelscope.t2v_model import _i 3 | from t2v_helpers.general_utils import reconstruct_conds 4 | 5 | class GaussianDiffusion(object): 6 | r""" Diffusion Model for DDIM. 7 | "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon. 8 | See https://arxiv.org/abs/2010.02502 9 | """ 10 | 11 | def __init__(self, 12 | model, 13 | betas, 14 | mean_type='eps', 15 | var_type='learned_range', 16 | loss_type='mse', 17 | epsilon=1e-12, 18 | rescale_timesteps=False, 19 | **kwargs): 20 | 21 | # check input 22 | self.check_input_vars(betas, mean_type, var_type, loss_type) 23 | 24 | self.model = model 25 | self.betas = betas 26 | self.num_timesteps = len(betas) 27 | self.mean_type = mean_type 28 | self.var_type = var_type 29 | self.loss_type = loss_type 30 | self.epsilon = epsilon 31 | self.rescale_timesteps = rescale_timesteps 32 | 33 | # alphas 34 | alphas = 1 - self.betas 35 | self.alphas_cumprod = torch.cumprod(alphas, dim=0) 36 | self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) 37 | self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:],alphas.new_zeros([1])]) 38 | 39 | # q(x_t | x_{t-1}) 40 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 41 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) 42 | self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) 43 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) 44 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) 45 | 46 | # q(x_{t-1} | x_t, x_0) 47 | self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 48 | self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) 49 | self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 50 | self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) 51 | 52 | def check_input_vars(self, betas, mean_type, var_type, loss_type): 53 | mean_types = ['x0', 'x_{t-1}', 'eps'] 54 | var_types = ['learned', 'learned_range', 'fixed_large', 'fixed_small'] 55 | loss_types = ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] 56 | 57 | if not isinstance(betas, torch.DoubleTensor): 58 | betas = torch.tensor(betas, 
dtype=torch.float64) 59 | 60 | assert min(betas) > 0 and max(betas) <= 1 61 | assert mean_type in mean_types 62 | assert var_type in var_types 63 | assert loss_type in loss_types 64 | 65 | def validate_model_kwargs(self, model_kwargs): 66 | """ 67 | Use the original implementation of passing model kwargs to the model. 68 | eg: model_kwargs=[{'y':c_i}, {'y':uc_i,}] 69 | """ 70 | if len(model_kwargs) > 0: 71 | assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 72 | 73 | def get_time_steps(self, ddim_timesteps, batch_size=1, step=None): 74 | b = batch_size 75 | 76 | # Get thhe full timestep range 77 | arange_steps = (1 + torch.arange(0, self.num_timesteps, ddim_timesteps)) 78 | steps = arange_steps.clamp(0, self.num_timesteps - 1) 79 | timesteps = steps.flip(0).to(self.model.device) 80 | 81 | if step is not None: 82 | # Get the current timestep during a sample loop 83 | timesteps = torch.full((b, ), timesteps[step], dtype=torch.long, device=self.model.device) 84 | 85 | return timesteps 86 | 87 | def add_noise(self, xt, noise, t): 88 | noisy_sample = self.sqrt_alphas_cumprod[t.cpu()].to(self.model.device) * \ 89 | xt + noise * self.sqrt_one_minus_alphas_cumprod[t.cpu()].to(self.model.device) 90 | 91 | return noisy_sample 92 | 93 | def get_dim(self, y_out): 94 | is_fixed = self.var_type.startswith('fixed') 95 | return y_out.size(1) if is_fixed else y_out.size(1) // 2 96 | 97 | def fixed_small_variance(self, xt, t): 98 | var = _i(self.posterior_variance, t, xt) 99 | log_var = _i(self.posterior_log_variance_clipped, t, xt) 100 | 101 | return var, log_var 102 | 103 | def mean_x0(self, xt, t, x_out): 104 | x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( 105 | self.sqrt_recipm1_alphas_cumprod, t, xt) * x_out 106 | mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) 107 | 108 | return x0, mu 109 | 110 | def restrict_range_x0(self, percentile, x0, clamp=False): 111 | if not clamp: 112 | assert percentile > 0 and percentile <= 1 # e.g., 0.995 113 | s = torch.quantile(x0.flatten(1).abs(), percentile,dim=1) 114 | s.clamp_(1.0).view(-1, 1, 1, 1) 115 | 116 | x0 = torch.min(s, torch.max(-s, x0)) / s 117 | else: 118 | x0 = x0.clamp(-clamp, clamp) 119 | 120 | return x0 121 | 122 | def is_unconditional(self, guide_scale): 123 | return guide_scale is None or guide_scale == 1 124 | 125 | def do_classifier_guidance(self, y_out, u_out, guidance_scale): 126 | """ 127 | y_out: Condition 128 | u_out: Unconditional 129 | """ 130 | dim = self.get_dim(y_out) 131 | a = u_out[:, :dim] 132 | b = guidance_scale * (y_out[:, :dim] - u_out[:, :dim]) 133 | c = y_out[:, dim:] 134 | out = torch.cat([a + b, c], dim=1) 135 | 136 | return out 137 | 138 | def p_mean_variance(self, 139 | xt, 140 | t, 141 | model_kwargs={}, 142 | clamp=None, 143 | percentile=None, 144 | guide_scale=None, 145 | conditioning=None, 146 | unconditional_conditioning=None, 147 | only_x0=True, 148 | **kwargs): 149 | r"""Distribution of p(x_{t-1} | x_t).""" 150 | 151 | # predict distribution 152 | if self.is_unconditional(guide_scale): 153 | out = self.model(xt, self._scale_timesteps(t), conditioning) 154 | else: 155 | # classifier-free guidance 156 | if model_kwargs != {}: 157 | self.validate_model_kwargs(model_kwargs) 158 | conditioning = model_kwargs[0] 159 | unconditional_conditioning = model_kwargs[1] 160 | 161 | y_out = self.model(xt, self._scale_timesteps(t), conditioning) 162 | u_out = self.model(xt, self._scale_timesteps(t), unconditional_conditioning) 163 | 164 | out = self.do_classifier_guidance(y_out, u_out, guide_scale) 
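            # do_classifier_guidance combines the two passes as
            # u_out + guide_scale * (y_out - u_out) on the noise-prediction
            # channels; any extra learned-variance channels are taken from the
            # conditional output unchanged.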
165 | 166 | # compute variance 167 | if self.var_type == 'fixed_small': 168 | var, log_var = self.fixed_small_variance(xt, t) 169 | 170 | # compute mean and x0 171 | if self.mean_type == 'eps': 172 | x0, mu = self.mean_x0(xt, t, out) 173 | 174 | # restrict the range of x0 175 | if percentile is not None: 176 | x0 = self.restrict_range_x0(percentile, x0) 177 | elif clamp is not None: 178 | x0 = self.restrict_range_x0(percentile, x0, clamp=True) 179 | 180 | if only_x0: 181 | return x0 182 | else: 183 | return mu, var, log_var, x0 184 | 185 | def q_posterior_mean_variance(self, x0, xt, t): 186 | r"""Distribution of q(x_{t-1} | x_t, x_0). 187 | """ 188 | mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( 189 | self.posterior_mean_coef2, t, xt) * xt 190 | var = _i(self.posterior_variance, t, xt) 191 | log_var = _i(self.posterior_log_variance_clipped, t, xt) 192 | return mu, var, log_var 193 | 194 | def _scale_timesteps(self, t): 195 | if self.rescale_timesteps: 196 | return t.float() * 1000.0 / self.num_timesteps 197 | return t 198 | 199 | def get_eps(self, xt, x0, t, alpha, condition_fn, model_kwargs={}): 200 | # x0 -> eps 201 | eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( 202 | self.sqrt_recipm1_alphas_cumprod, t, xt) 203 | 204 | if condition_fn is not None: 205 | eps = eps - (1 - alpha).sqrt() * condition_fn( 206 | xt, self._scale_timesteps(t), **model_kwargs) 207 | # eps -> x0 208 | x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( 209 | self.sqrt_recipm1_alphas_cumprod, t, xt) * eps 210 | 211 | return eps, x0 212 | 213 | @torch.no_grad() 214 | def sample(self, 215 | x_T=None, 216 | S=5, 217 | shape=None, 218 | conditioning=None, 219 | unconditional_conditioning=None, 220 | model_kwargs={}, 221 | clamp=None, 222 | percentile=None, 223 | condition_fn=None, 224 | unconditional_guidance_scale=None, 225 | eta=0.0, 226 | callback=None, 227 | mask=None, 228 | **kwargs): 229 | r"""Sample from p(x_{t-1} | x_t) using DDIM. 230 | - condition_fn: for classifier-based guidance (guided-diffusion). 231 | - guide_scale: for classifier-free guidance (glide/dalle-2). 
232 | """ 233 | 234 | # Shape must exist to sample 235 | if shape is None and x_T is None: 236 | assert "Shape must exists to sample from noise" 237 | 238 | # Assign variables for sampling 239 | steps = S 240 | stride = self.num_timesteps // steps 241 | guide_scale = unconditional_guidance_scale 242 | original_latents = None 243 | 244 | if x_T is None: 245 | xt = torch.randn(shape, device=self.model.device) 246 | else: 247 | xt = x_T.clone() 248 | original_latents = xt 249 | 250 | timesteps = self.get_time_steps(stride, xt.shape[0]) 251 | 252 | for step in range(0, steps): 253 | c, uc = reconstruct_conds(conditioning, unconditional_conditioning, step) 254 | t = self.get_time_steps(stride, xt.shape[0], step=step) 255 | 256 | # predict distribution of p(x_{t-1} | x_t) 257 | x0 = self.p_mean_variance( 258 | xt, 259 | t, 260 | model_kwargs, 261 | clamp, 262 | percentile, 263 | guide_scale, 264 | conditioning=c, 265 | unconditional_conditioning=uc, 266 | **kwargs 267 | ) 268 | 269 | alphas = _i(self.alphas_cumprod, t, xt) 270 | alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) 271 | 272 | eps, x0 = self.get_eps(xt, x0, t, alphas, condition_fn) 273 | 274 | a = (1 - alphas_prev) / (1 - alphas) 275 | b = (1 - alphas / alphas_prev) 276 | sigmas = eta * torch.sqrt(a * b) 277 | 278 | # random sample 279 | noise = torch.randn_like(xt) 280 | direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps 281 | mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) 282 | xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise 283 | xt = xt_1 284 | 285 | if hasattr(self, 'inpaint_masking') and mask is not None: 286 | add_noise_args = { 287 | "xt":xt, 288 | "noise": torch.randn_like(xt), 289 | "t": timesteps[(step - 1) + 1] 290 | } 291 | self.inpaint_masking(xt, step, steps, mask, self.add_noise, add_noise_args) 292 | 293 | if callback is not None: 294 | callback(step) 295 | 296 | return xt 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /scripts/videocrafter/sample_text2video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import yaml, math 5 | from tqdm import trange 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | import torch.distributed as dist 10 | from pytorch_lightning import seed_everything 11 | 12 | from videocrafter.lvdm.samplers.ddim import DDIMSampler 13 | from videocrafter.lvdm.utils.common_utils import str2bool 14 | from videocrafter.lvdm.utils.dist_utils import setup_dist, gather_data 15 | from videocrafter.lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d 16 | from videocrafter.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np 17 | 18 | 19 | # ------------------------------------------------------------------------------------------ 20 | def get_parser(): 21 | parser = argparse.ArgumentParser() 22 | # basic args 23 | parser.add_argument("--ckpt_path", type=str, help="model checkpoint path") 24 | parser.add_argument("--config_path", type=str, help="model config path (a yaml file)") 25 | parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).") 26 | parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/") 27 | # device args 28 | parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for 
multi-gpu case)", default=False) 29 | parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0) 30 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0) 31 | # sampling args 32 | parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2) 33 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1) 34 | parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1) 35 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"]) 36 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50) 37 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0) 38 | parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale") 39 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)") 40 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",) 41 | parser.add_argument("--num_frames", type=int, default=16, help="number of input frames") 42 | # lora args 43 | parser.add_argument("--lora_path", type=str, help="lora checkpoint path") 44 | parser.add_argument("--inject_lora", action='store_true', default=False, help="",) 45 | parser.add_argument("--lora_scale", type=float, default=None, help="scale for lora weight") 46 | parser.add_argument("--lora_trigger_word", type=str, default="", help="",) 47 | # saving args 48 | parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files", choices=["True", "true", "False", "false"]) 49 | parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",) 50 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",) 51 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",) 52 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",) 53 | return parser 54 | 55 | # ------------------------------------------------------------------------------------------ 56 | def sample_denoising_batch(model, noise_shape, condition, *args, 57 | sample_type="ddim", sampler=None, 58 | ddim_steps=None, eta=None, 59 | unconditional_guidance_scale=1.0, uc=None, 60 | denoising_progress=False, 61 | **kwargs, 62 | ): 63 | 64 | if sample_type == "ddpm": 65 | samples = model.p_sample_loop(cond=condition, shape=noise_shape, 66 | return_intermediates=False, 67 | verbose=denoising_progress, 68 | **kwargs, 69 | ) 70 | elif sample_type == "ddim": 71 | assert(sampler is not None) 72 | assert(ddim_steps is not None) 73 | assert(eta is not None) 74 | ddim_sampler = sampler 75 | samples, _ = ddim_sampler.sample(S=ddim_steps, 76 | conditioning=condition, 77 | batch_size=noise_shape[0], 78 | shape=noise_shape[1:], 79 | verbose=denoising_progress, 80 | unconditional_guidance_scale=unconditional_guidance_scale, 81 | unconditional_conditioning=uc, 82 | eta=eta, 83 | **kwargs, 84 | ) 85 | else: 86 | raise ValueError 87 | return samples 88 | 89 | 90 | # 
------------------------------------------------------------------------------------------ 91 | @torch.no_grad() 92 | def sample_text2video(model, prompt, n_prompt, n_samples, batch_size, 93 | sample_type="ddim", sampler=None, 94 | ddim_steps=50, eta=1.0, cfg_scale=7.5, 95 | decode_frame_bs=1, 96 | ddp=False, all_gather=True, 97 | batch_progress=True, show_denoising_progress=False, 98 | num_frames=None, 99 | ): 100 | # get cond vector 101 | assert(model.cond_stage_model is not None) 102 | cond_embd = get_conditions(prompt, model, batch_size) 103 | uncond_embd = get_conditions(n_prompt, model, batch_size) if cfg_scale != 1.0 else None 104 | 105 | # sample batches 106 | all_videos = [] 107 | n_iter = math.ceil(n_samples / batch_size) 108 | iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter) 109 | for _ in iterator: 110 | noise_shape = make_model_input_shape(model, batch_size, T=num_frames) 111 | samples_latent = sample_denoising_batch(model, noise_shape, cond_embd, 112 | sample_type=sample_type, 113 | sampler=sampler, 114 | ddim_steps=ddim_steps, 115 | eta=eta, 116 | unconditional_guidance_scale=cfg_scale, 117 | uc=uncond_embd, 118 | denoising_progress=show_denoising_progress, 119 | ) 120 | samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False) 121 | 122 | # gather samples from multiple gpus 123 | if ddp and all_gather: 124 | data_list = gather_data(samples, return_np=False) 125 | all_videos.extend([torch_to_np(data) for data in data_list]) 126 | else: 127 | all_videos.append(torch_to_np(samples)) 128 | 129 | all_videos = np.concatenate(all_videos, axis=0) 130 | assert(all_videos.shape[0] >= n_samples) 131 | return all_videos 132 | 133 | 134 | # ------------------------------------------------------------------------------------------ 135 | def save_results(videos, save_dir, 136 | save_name="results", save_fps=8, save_mp4=True, 137 | save_npz=False, save_mp4_sheet=False, save_jpg=False 138 | ): 139 | if save_mp4: 140 | save_subdir = os.path.join(save_dir, "videos") 141 | os.makedirs(save_subdir, exist_ok=True) 142 | for i in range(videos.shape[0]): 143 | npz_to_video_grid(videos[i:i+1,...], 144 | os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4"), 145 | fps=save_fps) 146 | print(f'Successfully saved videos in {save_subdir}') 147 | 148 | if save_npz: 149 | save_path = os.path.join(save_dir, f"{save_name}.npz") 150 | np.savez(save_path, videos) 151 | print(f'Successfully saved npz in {save_path}') 152 | 153 | if save_mp4_sheet: 154 | save_path = os.path.join(save_dir, f"{save_name}.mp4") 155 | npz_to_video_grid(videos, save_path, fps=save_fps) 156 | print(f'Successfully saved mp4 sheet in {save_path}') 157 | 158 | if save_jpg: 159 | save_path = os.path.join(save_dir, f"{save_name}.jpg") 160 | npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1]) 161 | print(f'Successfully saved jpg sheet in {save_path}') 162 | 163 | 164 | # ------------------------------------------------------------------------------------------ 165 | def main(): 166 | """ 167 | text-to-video generation 168 | """ 169 | parser = get_parser() 170 | opt, unknown = parser.parse_known_args() 171 | os.makedirs(opt.save_dir, exist_ok=True) 172 | 173 | # set device 174 | if opt.ddp: 175 | setup_dist(opt.local_rank) 176 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size()) 177 | gpu_id = None 178 | else: 179 | gpu_id = opt.gpu_id 180 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" 181 | 182 | # set random seed 183 | if 
opt.seed is not None: 184 | if opt.ddp: 185 | seed = opt.local_rank + opt.seed 186 | else: 187 | seed = opt.seed 188 | seed_everything(seed) 189 | 190 | # dump args 191 | fpath = os.path.join(opt.save_dir, "sampling_args.yaml") 192 | with open(fpath, 'w') as f: 193 | yaml.dump(vars(opt), f, default_flow_style=False) 194 | 195 | # load & merge config 196 | config = OmegaConf.load(opt.config_path) 197 | cli = OmegaConf.from_dotlist(unknown) 198 | config = OmegaConf.merge(config, cli) 199 | print("config: \n", config) 200 | 201 | # get model & sampler 202 | model, _, _ = load_model(config, opt.ckpt_path, 203 | inject_lora=opt.inject_lora, 204 | lora_scale=opt.lora_scale, 205 | lora_path=opt.lora_path 206 | ) 207 | ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None 208 | 209 | # prepare prompt 210 | if opt.prompt.endswith(".txt"): 211 | opt.prompt_file = opt.prompt 212 | opt.prompt = None 213 | else: 214 | opt.prompt_file = None 215 | 216 | if opt.prompt_file is not None: 217 | f = open(opt.prompt_file, 'r') 218 | prompts, line_idx = [], [] 219 | for idx, line in enumerate(f.readlines()): 220 | l = line.strip() 221 | if len(l) != 0: 222 | prompts.append(l) 223 | line_idx.append(idx) 224 | f.close() 225 | cmd = f"cp {opt.prompt_file} {opt.save_dir}" 226 | os.system(cmd) 227 | else: 228 | prompts = [opt.prompt] 229 | line_idx = [None] 230 | 231 | if opt.inject_lora: 232 | assert(opt.lora_trigger_word != '') 233 | prompts = [p + opt.lora_trigger_word for p in prompts] 234 | 235 | # go 236 | start = time.time() 237 | for prompt in prompts: 238 | # sample 239 | samples = sample_text2video(model, prompt, opt.n_samples, opt.batch_size, 240 | sample_type=opt.sample_type, sampler=ddim_sampler, 241 | ddim_steps=opt.ddim_steps, eta=opt.eta, 242 | cfg_scale=opt.cfg_scale, 243 | decode_frame_bs=opt.decode_frame_bs, 244 | ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress, 245 | num_frames=opt.num_frames, 246 | ) 247 | # save 248 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp): 249 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 250 | save_name = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 251 | if opt.seed is not None: 252 | save_name = save_name + f"_seed{seed:05d}" 253 | save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps) 254 | print("Finish sampling!") 255 | print(f"Run time = {(time.time() - start):.2f} seconds") 256 | 257 | if opt.ddp: 258 | dist.destroy_process_group() 259 | 260 | 261 | # ------------------------------------------------------------------------------------------ 262 | if __name__ == "__main__": 263 | main() -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/models/modules/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | from einops import repeat 8 | import torch.nn.functional as F 9 | 10 | from videocrafter.lvdm.utils.common_utils import instantiate_from_config 11 | 12 | 13 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 14 | if schedule == "linear": 15 | betas = ( 16 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 17 | ) 18 | elif schedule == "cosine": 19 | timesteps = ( 20 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 21 | 
) 22 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 23 | alphas = torch.cos(alphas).pow(2) 24 | alphas = alphas / alphas[0] 25 | betas = 1 - alphas[1:] / alphas[:-1] 26 | betas = np.clip(betas, a_min=0, a_max=0.999) 27 | elif schedule == "sqrt_linear": 28 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 29 | elif schedule == "sqrt": 30 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 31 | else: 32 | raise ValueError(f"schedule '{schedule}' unknown.") 33 | return betas.numpy() 34 | 35 | 36 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 37 | if ddim_discr_method == 'uniform': 38 | c = num_ddpm_timesteps // num_ddim_timesteps 39 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 40 | elif ddim_discr_method == 'quad': 41 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 42 | else: 43 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 44 | 45 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 46 | steps_out = ddim_timesteps + 1 47 | if verbose: 48 | print(f'Selected timesteps for ddim sampler: {steps_out}') 49 | return steps_out 50 | 51 | 52 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 53 | # select alphas for computing the variance schedule 54 | alphas = alphacums[ddim_timesteps] 55 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 56 | 57 | # according the the formula provided in https://arxiv.org/abs/2010.02502 58 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 59 | if verbose: 60 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 61 | print(f'For the chosen value of eta, which is {eta}, ' 62 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 63 | return sigmas, alphas, alphas_prev 64 | 65 | 66 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 67 | """ 68 | Create a beta schedule that discretizes the given alpha_t_bar function, 69 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 70 | :param num_diffusion_timesteps: the number of betas to produce. 71 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 72 | produces the cumulative product of (1-beta) up to that 73 | part of the diffusion process. 74 | :param max_beta: the maximum beta to use; use values lower than 1 to 75 | prevent singularities. 76 | """ 77 | betas = [] 78 | for i in range(num_diffusion_timesteps): 79 | t1 = i / num_diffusion_timesteps 80 | t2 = (i + 1) / num_diffusion_timesteps 81 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 82 | return np.array(betas) 83 | 84 | 85 | def extract_into_tensor(a, t, x_shape): 86 | b, *_ = t.shape 87 | out = a.gather(-1, t) 88 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 89 | 90 | 91 | def checkpoint(func, inputs, params, flag): 92 | """ 93 | Evaluate a function without caching intermediate activations, allowing for 94 | reduced memory at the expense of extra compute in the backward pass. 95 | :param func: the function to evaluate. 96 | :param inputs: the argument sequence to pass to `func`. 97 | :param params: a sequence of parameters `func` depends on but does not 98 | explicitly take as arguments. 
99 | :param flag: if False, disable gradient checkpointing. 100 | """ 101 | if flag: 102 | args = tuple(inputs) + tuple(params) 103 | return CheckpointFunction.apply(func, len(inputs), *args) 104 | else: 105 | return func(*inputs) 106 | 107 | 108 | class CheckpointFunction(torch.autograd.Function): 109 | @staticmethod 110 | @torch.cuda.amp.custom_fwd 111 | def forward(ctx, run_function, length, *args): 112 | ctx.run_function = run_function 113 | ctx.input_tensors = list(args[:length]) 114 | ctx.input_params = list(args[length:]) 115 | 116 | with torch.no_grad(): 117 | output_tensors = ctx.run_function(*ctx.input_tensors) 118 | return output_tensors 119 | 120 | @staticmethod 121 | @torch.cuda.amp.custom_bwd 122 | def backward(ctx, *output_grads): 123 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 124 | with torch.enable_grad(): 125 | # Fixes a bug where the first op in run_function modifies the 126 | # Tensor storage in place, which is not allowed for detach()'d 127 | # Tensors. 128 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 129 | output_tensors = ctx.run_function(*shallow_copies) 130 | input_grads = torch.autograd.grad( 131 | output_tensors, 132 | ctx.input_tensors + ctx.input_params, 133 | output_grads, 134 | allow_unused=True, 135 | ) 136 | del ctx.input_tensors 137 | del ctx.input_params 138 | del output_tensors 139 | return (None, None) + input_grads 140 | 141 | 142 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 143 | """ 144 | Create sinusoidal timestep embeddings. 145 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 146 | These may be fractional. 147 | :param dim: the dimension of the output. 148 | :param max_period: controls the minimum frequency of the embeddings. 149 | :return: an [N x dim] Tensor of positional embeddings. 150 | """ 151 | if not repeat_only: 152 | half = dim // 2 153 | freqs = torch.exp( 154 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 155 | ).to(device=timesteps.device) 156 | args = timesteps[:, None].float() * freqs[None] 157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 158 | if dim % 2: 159 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 160 | else: 161 | embedding = repeat(timesteps, 'b -> b d', d=dim) 162 | return embedding 163 | 164 | 165 | def zero_module(module): 166 | """ 167 | Zero out the parameters of a module and return it. 168 | """ 169 | for p in module.parameters(): 170 | p.detach().zero_() 171 | return module 172 | 173 | 174 | def scale_module(module, scale): 175 | """ 176 | Scale the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().mul_(scale) 180 | return module 181 | 182 | 183 | def mean_flat(tensor): 184 | """ 185 | Take the mean over all non-batch dimensions. 186 | """ 187 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 188 | 189 | 190 | def normalization(channels): 191 | """ 192 | Make a standard normalization layer. 193 | :param channels: number of input channels. 194 | :return: an nn.Module for normalization. 
195 | """ 196 | return GroupNorm32(32, channels) 197 | 198 | def Normalize(in_channels): 199 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 200 | 201 | def identity(*args, **kwargs): 202 | return nn.Identity() 203 | 204 | class Normalization(nn.Module): 205 | def __init__(self, output_size, eps=1e-5, norm_type='gn'): 206 | super(Normalization, self).__init__() 207 | # epsilon to avoid dividing by 0 208 | self.eps = eps 209 | self.norm_type = norm_type 210 | 211 | if self.norm_type in ['bn', 'in']: 212 | self.register_buffer('stored_mean', torch.zeros(output_size)) 213 | self.register_buffer('stored_var', torch.ones(output_size)) 214 | 215 | def forward(self, x): 216 | if self.norm_type == 'bn': 217 | out = F.batch_norm(x, self.stored_mean, self.stored_var, None, 218 | None, 219 | self.training, 0.1, self.eps) 220 | elif self.norm_type == 'in': 221 | out = F.instance_norm(x, self.stored_mean, self.stored_var, 222 | None, None, 223 | self.training, 0.1, self.eps) 224 | elif self.norm_type == 'gn': 225 | out = F.group_norm(x, 32) 226 | elif self.norm_type == 'nonorm': 227 | out = x 228 | return out 229 | 230 | 231 | class CCNormalization(nn.Module): 232 | def __init__(self, embed_dim, feature_dim, *args, **kwargs): 233 | super(CCNormalization, self).__init__() 234 | 235 | self.embed_dim = embed_dim 236 | self.feature_dim = feature_dim 237 | self.norm = Normalization(feature_dim, *args, **kwargs) 238 | 239 | self.gain = nn.Linear(self.embed_dim, self.feature_dim) 240 | self.bias = nn.Linear(self.embed_dim, self.feature_dim) 241 | 242 | def forward(self, x, y): 243 | shape = [1] * (x.dim() - 2) 244 | gain = (1 + self.gain(y)).view(y.size(0), -1, *shape) 245 | bias = self.bias(y).view(y.size(0), -1, *shape) 246 | return self.norm(x) * gain + bias 247 | 248 | 249 | def nonlinearity(type='silu'): 250 | if type == 'silu': 251 | return nn.SiLU() 252 | elif type == 'leaky_relu': 253 | return nn.LeakyReLU() 254 | 255 | 256 | class GEGLU(nn.Module): 257 | def __init__(self, dim_in, dim_out): 258 | super().__init__() 259 | self.proj = nn.Linear(dim_in, dim_out * 2) 260 | 261 | def forward(self, x): 262 | x, gate = self.proj(x).chunk(2, dim=-1) 263 | return x * F.gelu(gate) 264 | 265 | 266 | class SiLU(nn.Module): 267 | def forward(self, x): 268 | return x * torch.sigmoid(x) 269 | 270 | 271 | class GroupNorm32(nn.GroupNorm): 272 | def forward(self, x): 273 | return super().forward(x.float()).type(x.dtype) 274 | 275 | 276 | def conv_nd(dims, *args, **kwargs): 277 | """ 278 | Create a 1D, 2D, or 3D convolution module. 279 | """ 280 | if dims == 1: 281 | return nn.Conv1d(*args, **kwargs) 282 | elif dims == 2: 283 | return nn.Conv2d(*args, **kwargs) 284 | elif dims == 3: 285 | return nn.Conv3d(*args, **kwargs) 286 | raise ValueError(f"unsupported dimensions: {dims}") 287 | 288 | 289 | def linear(*args, **kwargs): 290 | """ 291 | Create a linear module. 292 | """ 293 | return nn.Linear(*args, **kwargs) 294 | 295 | 296 | def avg_pool_nd(dims, *args, **kwargs): 297 | """ 298 | Create a 1D, 2D, or 3D average pooling module. 
299 | """ 300 | if dims == 1: 301 | return nn.AvgPool1d(*args, **kwargs) 302 | elif dims == 2: 303 | return nn.AvgPool2d(*args, **kwargs) 304 | elif dims == 3: 305 | return nn.AvgPool3d(*args, **kwargs) 306 | raise ValueError(f"unsupported dimensions: {dims}") 307 | 308 | 309 | class HybridConditioner(nn.Module): 310 | 311 | def __init__(self, c_concat_config, c_crossattn_config): 312 | super().__init__() 313 | self.concat_conditioner = instantiate_from_config(c_concat_config) 314 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 315 | 316 | def forward(self, c_concat, c_crossattn): 317 | c_concat = self.concat_conditioner(c_concat) 318 | c_crossattn = self.crossattn_conditioner(c_crossattn) 319 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 320 | 321 | def noise_like(shape, device, repeat=False, noise_gen=None): 322 | assert noise_gen is not None 323 | repeat_noise = lambda: torch.randn((1, *shape[1:]), generator=noise_gen).repeat(shape[0], *((1,) * (len(shape) - 1))).to(device) 324 | noise = lambda: torch.randn(shape, generator=noise_gen).to(device) 325 | return repeat_noise() if repeat else noise() 326 | 327 | def init_(tensor): 328 | dim = tensor.shape[-1] 329 | std = 1 / math.sqrt(dim) 330 | tensor.uniform_(-std, std) 331 | return tensor 332 | 333 | 334 | def exists(val): 335 | return val is not None 336 | 337 | 338 | def uniq(arr): 339 | return{el: True for el in arr}.keys() 340 | 341 | 342 | def default(val, d): 343 | if exists(val): 344 | return val 345 | return d() if isfunction(d) else d 346 | 347 | 348 | -------------------------------------------------------------------------------- /scripts/modelscope/process_modelscope.py: -------------------------------------------------------------------------------- 1 | # Function calls referenced from https://github.com/modelscope/modelscope/tree/master/modelscope/pipelines/multi_modal 2 | 3 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 4 | # Read LICENSE for usage terms. 
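# process_modelscope() below drives the ModelScope text2video pipeline for this
# extension: it frees the webui's Stable Diffusion checkpoint from VRAM,
# (re)builds the TextToVideoSynthesis pipeline, applies any Stable LoRA
# settings, optionally prepares vid2vid or inpainting latents, then samples
# each batch and writes the resulting frames (and, if enabled, the generation
# settings) to the output folder.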
5 | 6 | from base64 import b64encode 7 | from tqdm import tqdm 8 | from PIL import Image 9 | from modelscope.t2v_pipeline import TextToVideoSynthesis, tensor2vid 10 | from t2v_helpers.key_frames import T2VAnimKeys # TODO: move to deforum_tools 11 | from pathlib import Path 12 | import numpy as np 13 | import torch 14 | import cv2 15 | import gc 16 | import modules.paths as ph 17 | from types import SimpleNamespace 18 | from t2v_helpers.general_utils import get_t2v_version, get_model_location 19 | import time, math 20 | from t2v_helpers.video_audio_utils import ffmpeg_stitch_video, get_quick_vid_info, vid2frames, duplicate_pngs_from_folder, clean_folder_name 21 | from t2v_helpers.args import get_outdir, process_args 22 | import t2v_helpers.args as t2v_helpers_args 23 | from modules import shared, sd_hijack, lowvram 24 | from modules.shared import opts, state 25 | from modules import devices 26 | from stable_lora.scripts.lora_webui import gr_inputs_list, StableLoraScriptInstance 27 | import os 28 | 29 | pipe = None 30 | 31 | def setup_pipeline(model_name): 32 | return TextToVideoSynthesis(get_model_location(model_name)) 33 | 34 | def process_modelscope(args_dict, extra_args=None): 35 | args, video_args = process_args(args_dict) 36 | 37 | global pipe 38 | print(f"\033[4;33m text2video extension for auto1111 webui\033[0m") 39 | print(f"Git commit: {get_t2v_version()}") 40 | init_timestring = time.strftime('%Y%m%d%H%M%S') 41 | outdir_current = os.path.join(get_outdir(), f"{init_timestring}") 42 | 43 | max_vids_to_pack = opts.data.get("modelscope_deforum_show_n_videos") if opts.data is not None and opts.data.get("modelscope_deforum_show_n_videos") is not None else -1 44 | cpu_vae = opts.data.get("modelscope_deforum_vae_settings") if opts.data is not None and opts.data.get("modelscope_deforum_vae_settings") is not None else 'GPU (half precision)' 45 | if shared.sd_model is not None: 46 | sd_hijack.model_hijack.undo_hijack(shared.sd_model) 47 | try: 48 | lowvram.send_everything_to_cpu() 49 | except Exception as e: 50 | pass 51 | # the following command actually frees the GPU vram from the sd.model, no need to do del shared.sd_model 22-05-23 52 | shared.sd_model = None 53 | gc.collect() 54 | devices.torch_gc() 55 | 56 | print('Starting text2video') 57 | print('Pipeline setup') 58 | 59 | # optionally store pipe in global between runs 60 | # also refresh the model if the user selected a newer one 61 | # if args.model is none (e.g. an API call, the model stays as the previous one) 62 | if pipe is None and args.model is None: # one more API call hack, falling back to if never used TODO: figure out how to permastore the model name the best way 63 | args.model = "" 64 | print(f"WARNING: received an API call with an empty model name, defaulting to {args.model} at {get_model_location(args.model)}") 65 | if pipe is None or pipe is not None and args.model is not None and get_model_location(args.model) != pipe.model_dir: 66 | pipe = setup_pipeline(args.model) 67 | 68 | #TODO Wrap this in a list so that we can process this for future extensions. 
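    # Stable LoRA hook: read this extension's UI values from extra_args and
    # let it apply the selected LoRA weights to the freshly built pipeline.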
69 | stable_lora_processor = StableLoraScriptInstance 70 | stable_lora_args = stable_lora_processor.process_extension_args(all_args=extra_args) 71 | stable_lora_processor.process(pipe, *stable_lora_args) 72 | 73 | pipe.keep_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None' 74 | 75 | device = devices.get_optimal_device() 76 | print('device', device) 77 | 78 | mask = None 79 | 80 | if args.do_vid2vid: 81 | if args.vid2vid_frames is None and args.vid2vid_frames_path == "": 82 | raise FileNotFoundError("Please upload a video :()") 83 | 84 | # Overrides 85 | if args.vid2vid_frames is not None: 86 | vid2vid_frames_path = args.vid2vid_frames.name 87 | 88 | print("got a request to *vid2vid* an existing video.") 89 | 90 | in_vid_fps, _, _ = get_quick_vid_info(vid2vid_frames_path) 91 | folder_name = clean_folder_name(Path(vid2vid_frames_path).stem) 92 | outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-vid2vid', folder_name) 93 | i = 1 94 | while os.path.exists(outdir_no_tmp): 95 | outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-vid2vid', folder_name + '_' + str(i)) 96 | i += 1 97 | 98 | outdir_v2v = os.path.join(outdir_no_tmp, 'tmp_input_frames') 99 | os.makedirs(outdir_v2v, exist_ok=True) 100 | 101 | vid2frames(video_path=vid2vid_frames_path, video_in_frame_path=outdir_v2v, overwrite=True, extract_from_frame=args.vid2vid_startFrame, extract_to_frame=args.vid2vid_startFrame + args.frames, 102 | numeric_files_output=True, out_img_format='png') 103 | 104 | temp_convert_raw_png_path = os.path.join(outdir_v2v, "tmp_vid2vid_folder") 105 | duplicate_pngs_from_folder(outdir_v2v, temp_convert_raw_png_path, None, folder_name) 106 | 107 | videogen = [] 108 | for f in os.listdir(temp_convert_raw_png_path): 109 | # double check for old _depth_ files, not really needed probably but keeping it for now 110 | if '_depth_' not in f: 111 | videogen.append(f) 112 | 113 | videogen.sort(key=lambda x: int(x.split('.')[0])) 114 | 115 | images = [] 116 | for file in tqdm(videogen, desc="Loading frames"): 117 | image = Image.open(os.path.join(temp_convert_raw_png_path, file)) 118 | image = image.resize((args.width, args.height), Image.ANTIALIAS) 119 | array = np.array(image) 120 | images += [array] 121 | 122 | # print(images) 123 | 124 | images = np.stack(images) # f h w c 125 | batches = 1 126 | n_images = np.tile(images[np.newaxis, ...], (batches, 1, 1, 1, 1)) # n f h w c 127 | bcfhw = n_images.transpose(0, 4, 1, 2, 3) 128 | # convert to 0-1 float 129 | bcfhw = bcfhw.astype(np.float32) / 255 130 | bfchw = bcfhw.transpose(0, 2, 1, 3, 4) # b c f h w 131 | 132 | print(f"Converted the frames to tensor {bfchw.shape}") 133 | 134 | vd_out = torch.from_numpy(bcfhw).to("cuda") 135 | 136 | # should be -1,1, not 0,1 137 | vd_out = 2 * vd_out - 1 138 | 139 | # latents should have shape num_sample, 4, max_frames, latent_h,latent_w 140 | print("Computing latents") 141 | latents = pipe.compute_latents(vd_out).to(device) 142 | 143 | skip_steps = int(math.floor(args.steps * max(0, min(1 - args.strength, 1)))) 144 | else: 145 | latents = None 146 | args.strength = 1 147 | skip_steps = 0 148 | 149 | print('Working in txt2vid mode' if not args.do_vid2vid else 'Working in vid2vid mode') 150 | 151 | # Start the batch count loop 152 | pbar = tqdm(range(args.batch_count), leave=False) 153 | if args.batch_count == 1: 154 | pbar.disable = True 155 | 156 | vids_to_pack = [] 157 | 158 | state.job_count = 
args.batch_count 159 | 160 | for batch in pbar: 161 | state.job_no = batch 162 | if state.skipped: 163 | state.skipped = False 164 | 165 | if state.interrupted: 166 | break 167 | 168 | shared.state.job = f"Batch {batch + 1} out of {args.batch_count}" 169 | # TODO: move to a separate function 170 | if args.inpainting_frames > 0 and hasattr(args.inpainting_image, "name"): 171 | keys = T2VAnimKeys(SimpleNamespace(**{'max_frames': args.frames, 'inpainting_weights': args.inpainting_weights}), args.seed, args.inpainting_frames) 172 | images = [] 173 | print("Received an image for inpainting", args.inpainting_image.name) 174 | for i in range(args.frames): 175 | image = Image.open(args.inpainting_image.name).convert("RGB") 176 | image = image.resize((args.width, args.height), Image.ANTIALIAS) 177 | array = np.array(image) 178 | images += [array] 179 | 180 | images = np.stack(images) # f h w c 181 | batches = 1 182 | n_images = np.tile(images[np.newaxis, ...], (batches, 1, 1, 1, 1)) # n f h w c 183 | bcfhw = n_images.transpose(0, 4, 1, 2, 3) 184 | # convert to 0-1 float 185 | bcfhw = bcfhw.astype(np.float32) / 255 186 | bfchw = bcfhw.transpose(0, 2, 1, 3, 4) # b c f h w 187 | 188 | print(f"Converted the frames to tensor {bfchw.shape}") 189 | 190 | vd_out = torch.from_numpy(bcfhw).to("cuda") 191 | 192 | # should be -1,1, not 0,1 193 | vd_out = 2 * vd_out - 1 194 | 195 | # latents should have shape num_sample, 4, max_frames, latent_h,latent_w 196 | # but right now they have shape num_sample=1,4, 1 (only used 1 img), latent_h, latent_w 197 | print("Computing latents") 198 | image_latents = pipe.compute_latents(vd_out).numpy() 199 | # padding_width = [(0, 0), (0, 0), (0, frames-inpainting_frames), (0, 0), (0, 0)] 200 | # padded_latents = np.pad(image_latents, pad_width=padding_width, mode='constant', constant_values=0) 201 | 202 | latent_h = args.height // 8 203 | latent_w = args.width // 8 204 | latent_noise = np.random.normal(size=(1, 4, args.frames, latent_h, latent_w)) 205 | mask = np.ones(shape=(1, 4, args.frames, latent_h, latent_w)) 206 | 207 | mask_weights = [keys.inpainting_weights_series[frame_idx] for frame_idx in range(args.frames)] 208 | 209 | for i in range(args.frames): 210 | v = mask_weights[i] 211 | mask[:, :, i, :, :] = v 212 | 213 | masked_latents = image_latents * (1 - mask) + latent_noise * mask 214 | 215 | latents = torch.tensor(masked_latents).to(device) 216 | 217 | mask = torch.tensor(mask).to(device) 218 | 219 | args.strength = 1 220 | 221 | samples, _, infotext = pipe.infer(args.prompt, args.n_prompt, args.steps, args.frames, args.seed + batch if args.seed != -1 else -1, args.cfg_scale, 222 | args.width, args.height, args.eta, cpu_vae, device, latents, strength=args.strength, skip_steps=skip_steps, mask=mask, is_vid2vid=args.do_vid2vid, sampler=args.sampler) 223 | 224 | 225 | if batch > 0: 226 | outdir_current = os.path.join(get_outdir(), f"{init_timestring}_{batch}") 227 | print(f'text2video finished, saving frames to {outdir_current}') 228 | 229 | # just deleted the folder so we need to make it again 230 | os.makedirs(outdir_current, exist_ok=True) 231 | for i in range(len(samples)): 232 | cv2.imwrite(outdir_current + os.path.sep + 233 | f"{i:06}.png", samples[i]) 234 | 235 | # save settings to a file 236 | if opts.data is not None and opts.data.get("modelscope_save_info_to_file"): 237 | 238 | args_file = os.path.join(outdir_current,'args.txt') 239 | with open(args_file, 'w', encoding='utf-8') as f: 240 | print(f'saving args to {args_file}') 241 | f.write(infotext) 242 | 243 
| # TODO: add params to the GUI 244 | if not video_args.skip_video_creation: 245 | metadata = None 246 | if opts.data is not None and opts.data.get("modelscope_save_metadata") is not None: 247 | if opts.data.get("modelscope_save_metadata"): 248 | metadata = infotext 249 | else: 250 | metadata = infotext 251 | ffmpeg_stitch_video(ffmpeg_location=video_args.ffmpeg_location, fps=video_args.fps, outmp4_path=outdir_current + os.path.sep + f"vid.mp4", imgs_path=os.path.join(outdir_current, 252 | "%06d.png"), 253 | stitch_from_frame=0, stitch_to_frame=-1, add_soundtrack=video_args.add_soundtrack, 254 | audio_path=vid2vid_frames_path if video_args.add_soundtrack == 'Init Video' else video_args.soundtrack_path, crf=video_args.ffmpeg_crf, preset=video_args.ffmpeg_preset, metadata=metadata) 255 | print(f't2v complete, result saved at {outdir_current}') 256 | 257 | mp4 = open(outdir_current + os.path.sep + f"vid.mp4", 'rb').read() 258 | dataurl = "data:video/mp4;base64," + b64encode(mp4).decode() 259 | 260 | if max_vids_to_pack == -1 or len(vids_to_pack) < max_vids_to_pack: 261 | vids_to_pack.append((dataurl, infotext)) 262 | t2v_helpers_args.i1_store_t2v = f'
<p style="font-weight:bold;margin-bottom:0.75em">text2video extension for auto1111 — version 1.2b</p>' 263 | for dataurl, infotext in vids_to_pack: 264 | t2v_helpers_args.i1_store_t2v += f'<video controls loop><source src="{dataurl}" type="video/mp4"></video><br>{infotext}
' 265 | pbar.close() 266 | return [v for v, _ in vids_to_pack] 267 | -------------------------------------------------------------------------------- /scripts/t2v_helpers/video_audio_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 by Artem Khrapov (kabachuha) 2 | # Read LICENSE for usage terms. 3 | 4 | import time, math 5 | import subprocess 6 | import os, shutil 7 | import cv2 8 | from modules.shared import state 9 | from pkg_resources import resource_filename 10 | import requests 11 | from mutagen.mp4 import MP4 12 | 13 | def get_frame_name(path): 14 | name = os.path.basename(path) 15 | name = os.path.splitext(name)[0] 16 | return name 17 | 18 | def vid2frames(video_path, video_in_frame_path, n=1, overwrite=True, extract_from_frame=0, extract_to_frame=-1, out_img_format='jpg', numeric_files_output = False): 19 | if (extract_to_frame <= extract_from_frame) and extract_to_frame != -1: 20 | raise RuntimeError('Error: extract_to_frame can not be higher than extract_from_frame') 21 | 22 | if n < 1: n = 1 #HACK Gradio interface does not currently allow min/max in gr.Number(...) 23 | 24 | # check vid path using a function and only enter if we get True 25 | if is_vid_path_valid(video_path): 26 | 27 | name = get_frame_name(video_path) 28 | 29 | vidcap = cv2.VideoCapture(video_path) 30 | video_fps = vidcap.get(cv2.CAP_PROP_FPS) 31 | 32 | input_content = [] 33 | if os.path.exists(video_in_frame_path) : 34 | input_content = os.listdir(video_in_frame_path) 35 | 36 | # check if existing frame is the same video, if not we need to erase it and repopulate 37 | if len(input_content) > 0: 38 | #get the name of the existing frame 39 | content_name = get_frame_name(input_content[0]) 40 | if not content_name.startswith(name): 41 | overwrite = True 42 | 43 | # grab the frame count to check against existing directory len 44 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 45 | 46 | # raise error if the user wants to skip more frames than exist 47 | if n >= frame_count : 48 | raise RuntimeError('Skipping more frames than input video contains. extract_nth_frames larger than input frames') 49 | 50 | expected_frame_count = math.ceil(frame_count / n) 51 | # Check to see if the frame count is matches the number of files in path 52 | if overwrite or expected_frame_count != len(input_content): 53 | shutil.rmtree(video_in_frame_path) 54 | os.makedirs(video_in_frame_path, exist_ok=True) # just deleted the folder so we need to make it again 55 | input_content = os.listdir(video_in_frame_path) 56 | 57 | print(f"Trying to extract frames from video with input FPS of {video_fps}. 
Please wait patiently.") 58 | if len(input_content) == 0: 59 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, extract_from_frame) # Set the starting frame 60 | success,image = vidcap.read() 61 | count = extract_from_frame 62 | t=1 63 | success = True 64 | while success: 65 | if state.interrupted: 66 | return 67 | if (count <= extract_to_frame or extract_to_frame == -1) and count % n == 0: 68 | if numeric_files_output == True: 69 | cv2.imwrite(video_in_frame_path + os.path.sep + f"{t:09}.{out_img_format}" , image) # save frame as file 70 | else: 71 | cv2.imwrite(video_in_frame_path + os.path.sep + name + f"{t:09}.{out_img_format}" , image) # save frame as file 72 | t += 1 73 | success,image = vidcap.read() 74 | count += 1 75 | print(f"Successfully extracted {count} frames from video.") 76 | else: 77 | print("Frames already unpacked") 78 | vidcap.release() 79 | return video_fps 80 | 81 | def is_vid_path_valid(video_path): 82 | # make sure file format is supported! 83 | file_formats = ["mov", "mpeg", "mp4", "m4v", "avi", "mpg", "webm"] 84 | extension = video_path.rsplit('.', 1)[-1].lower() 85 | # vid path is actually a URL, check it 86 | if video_path.startswith('http://') or video_path.startswith('https://'): 87 | response = requests.head(video_path, allow_redirects=True) 88 | if response.status_code == 404: 89 | raise ConnectionError("Video URL is not valid. Response status code: {}".format(response.status_code)) 90 | elif response.status_code == 302: 91 | response = requests.head(response.headers['location'], allow_redirects=True) 92 | if response.status_code != 200: 93 | raise ConnectionError("Video URL is not valid. Response status code: {}".format(response.status_code)) 94 | if extension not in file_formats: 95 | raise ValueError("Video file format '{}' not supported. Supported formats are: {}".format(extension, file_formats)) 96 | else: 97 | if not os.path.exists(video_path): 98 | raise RuntimeError("Video path does not exist.") 99 | if extension not in file_formats: 100 | raise ValueError("Video file format '{}' not supported. Supported formats are: {}".format(extension, file_formats)) 101 | return True 102 | 103 | 104 | def clean_folder_name(string): 105 | illegal_chars = "/\\<>:\"|?*.,\" " 106 | translation_table = str.maketrans(illegal_chars, "_"*len(illegal_chars)) 107 | return string.translate(translation_table) 108 | 109 | def find_ffmpeg_binary(): 110 | try: 111 | import google.colab 112 | return 'ffmpeg' 113 | except: 114 | pass 115 | for package in ['imageio_ffmpeg', 'imageio-ffmpeg']: 116 | try: 117 | package_path = resource_filename(package, 'binaries') 118 | files = [os.path.join(package_path, f) for f in os.listdir( 119 | package_path) if f.startswith("ffmpeg-")] 120 | files.sort(key=lambda x: os.path.getmtime(x), reverse=True) 121 | return files[0] if files else 'ffmpeg' 122 | except: 123 | return 'ffmpeg' 124 | 125 | # Stitch images to a h264 mp4 video using ffmpeg 126 | def ffmpeg_stitch_video(ffmpeg_location=None, fps=None, outmp4_path=None, stitch_from_frame=0, stitch_to_frame=None, imgs_path=None, add_soundtrack=None, audio_path=None, crf=17, preset='veryslow', metadata=None): 127 | start_time = time.time() 128 | 129 | print(f"Got a request to stitch frames to video using FFmpeg.\nFrames:\n{imgs_path}\nTo Video:\n{outmp4_path}") 130 | msg_to_print = f"Stitching *video*..." 
131 | print(msg_to_print) 132 | if stitch_to_frame == -1: 133 | stitch_to_frame = 999999999 134 | try: 135 | cmd = [ 136 | ffmpeg_location, 137 | '-y', 138 | '-vcodec', 'png', 139 | '-r', str(float(fps)), 140 | '-start_number', str(stitch_from_frame), 141 | '-i', imgs_path, 142 | '-frames:v', str(stitch_to_frame), 143 | '-c:v', 'libx264', 144 | '-vf', 145 | f'fps={float(fps)}', 146 | '-pix_fmt', 'yuv420p', 147 | '-crf', str(crf), 148 | '-preset', preset, 149 | '-pattern_type', 'sequence', 150 | ] 151 | 152 | cmd.append(outmp4_path) 153 | 154 | process = subprocess.Popen( 155 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 156 | stdout, stderr = process.communicate() 157 | except FileNotFoundError: 158 | print("\r" + " " * len(msg_to_print), end="", flush=True) 159 | print(f"\r{msg_to_print}", flush=True) 160 | raise FileNotFoundError( 161 | "FFmpeg not found. Please make sure you have a working ffmpeg path under 'ffmpeg_location' parameter.") 162 | except Exception as e: 163 | print("\r" + " " * len(msg_to_print), end="", flush=True) 164 | print(f"\r{msg_to_print}", flush=True) 165 | raise Exception( 166 | f'Error stitching frames to video. Actual runtime error:{e}') 167 | 168 | if add_soundtrack != 'None': 169 | audio_add_start_time = time.time() 170 | try: 171 | cmd = [ 172 | ffmpeg_location, 173 | '-i', 174 | outmp4_path, 175 | '-i', 176 | audio_path, 177 | '-map', '0:v', 178 | '-map', '1:a', 179 | '-c:v', 'copy', 180 | '-shortest', 181 | ] 182 | 183 | cmd.append(outmp4_path+'.temp.mp4') 184 | 185 | process = subprocess.Popen( 186 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 187 | stdout, stderr = process.communicate() 188 | if process.returncode != 0: 189 | print("\r" + " " * len(msg_to_print), end="", flush=True) 190 | print(f"\r{msg_to_print}", flush=True) 191 | raise RuntimeError(stderr) 192 | os.replace(outmp4_path+'.temp.mp4', outmp4_path) 193 | print("\r" + " " * len(msg_to_print), end="", flush=True) 194 | print(f"\r{msg_to_print}", flush=True) 195 | print(f"\rFFmpeg Video+Audio stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) 196 | except Exception as e: 197 | print("\r" + " " * len(msg_to_print), end="", flush=True) 198 | print(f"\r{msg_to_print}", flush=True) 199 | print(f'\rError adding audio to video. 
Actual error: {e}', flush=True) 200 | print(f"FFMPEG Video (sorry, no audio) stitching \033[33mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) 201 | else: 202 | print("\r" + " " * len(msg_to_print), end="", flush=True) 203 | print(f"\r{msg_to_print}", flush=True) 204 | 205 | # adding metadata 206 | if metadata is not None: 207 | print('Writing metadata') 208 | video = MP4(outmp4_path) 209 | video["\xa9cmt"] = metadata 210 | video.save() 211 | 212 | print(f"\rVideo stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) 213 | 214 | # quick-retreive frame count, FPS and H/W dimensions of a video (local or URL-based) 215 | def get_quick_vid_info(vid_path): 216 | vidcap = cv2.VideoCapture(vid_path) 217 | video_fps = vidcap.get(cv2.CAP_PROP_FPS) 218 | video_frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 219 | video_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) 220 | video_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 221 | vidcap.release() 222 | if video_fps.is_integer(): 223 | video_fps = int(video_fps) 224 | 225 | return video_fps, video_frame_count, (video_width, video_height) 226 | 227 | # This function usually gets a filename, and converts it to a legal linux/windows *folder* name 228 | def clean_folder_name(string): 229 | illegal_chars = "/\\<>:\"|?*.,\" " 230 | translation_table = str.maketrans(illegal_chars, "_"*len(illegal_chars)) 231 | return string.translate(translation_table) 232 | 233 | # used in src/rife/inference_video.py and more, soon 234 | def duplicate_pngs_from_folder(from_folder, to_folder, img_batch_id, orig_vid_name): 235 | import cv2 236 | #TODO: don't copy-paste at all if the input is a video (now it copy-pastes, and if input is deforum run is also converts to make sure no errors rise cuz of 24-32 bit depth differences) 237 | temp_convert_raw_png_path = os.path.join(from_folder, to_folder) 238 | if not os.path.exists(temp_convert_raw_png_path): 239 | os.makedirs(temp_convert_raw_png_path) 240 | 241 | frames_handled = 0 242 | for f in os.listdir(from_folder): 243 | if ('png' in f or 'jpg' in f) and '-' not in f and '_depth_' not in f and ((img_batch_id is not None and f.startswith(img_batch_id) or img_batch_id is None)): 244 | frames_handled +=1 245 | original_img_path = os.path.join(from_folder, f) 246 | if orig_vid_name is not None: 247 | shutil.copy(original_img_path, temp_convert_raw_png_path) 248 | else: 249 | image = cv2.imread(original_img_path) 250 | new_path = os.path.join(temp_convert_raw_png_path, f) 251 | cv2.imwrite(new_path, image, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 252 | return frames_handled 253 | 254 | def add_soundtrack(ffmpeg_location=None, fps=None, outmp4_path=None, stitch_from_frame=0, stitch_to_frame=None, imgs_path=None, add_soundtrack=None, audio_path=None, crf=17, preset='veryslow', metadata=None): 255 | if add_soundtrack is None: 256 | return 257 | msg_to_print = f"Adding soundtrack to *video*..." 
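# Roughly equivalent command line for the audio mux below (illustrative paths):
#   ffmpeg -i outputs/.../vid.mp4 -i soundtrack.mp3 -map 0:v -map 1:a -c:v copy -shortest \
#          outputs/.../vid.mp4.temp.mp4
# The temp file then replaces the original mp4 via os.replace().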
258 | start_time = time.time() 259 | try: 260 | cmd = [ 261 | ffmpeg_location, 262 | '-i', 263 | outmp4_path, 264 | '-i', 265 | audio_path, 266 | '-map', '0:v', 267 | '-map', '1:a', 268 | '-c:v', 'copy', 269 | '-shortest', 270 | outmp4_path+'.temp.mp4' 271 | ] 272 | process = subprocess.Popen( 273 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 274 | stdout, stderr = process.communicate() 275 | if process.returncode != 0: 276 | print("\r" + " " * len(msg_to_print), end="", flush=True) 277 | print(f"\r{msg_to_print}", flush=True) 278 | raise RuntimeError(stderr) 279 | os.replace(outmp4_path+'.temp.mp4', outmp4_path) 280 | print("\r" + " " * len(msg_to_print), end="", flush=True) 281 | print(f"\r{msg_to_print}", flush=True) 282 | print(f"\rFFmpeg Audio stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) 283 | except Exception as e: 284 | print("\r" + " " * len(msg_to_print), end="", flush=True) 285 | print(f"\r{msg_to_print}", flush=True) 286 | print(f'\rError adding audio to video. Actual error: {e}', flush=True) 287 | print(f"FFMPEG Video (sorry, no audio) stitching \033[33mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) 288 | -------------------------------------------------------------------------------- /scripts/videocrafter/lvdm/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from modules.shared import state 8 | from modules.sd_samplers_common import InterruptedException 9 | 10 | from videocrafter.lvdm.models.modules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 11 | 12 | 13 | class DDIMSampler(object): 14 | def __init__(self, model, schedule="linear", **kwargs): 15 | super().__init__() 16 | self.model = model 17 | self.ddpm_num_timesteps = model.num_timesteps 18 | self.schedule = schedule 19 | self.counter = 0 20 | self.noise_gen = torch.Generator(device='cpu') 21 | 22 | def register_buffer(self, name, attr): 23 | if type(attr) == torch.Tensor: 24 | if attr.device != torch.device("cuda"): 25 | attr = attr.to(torch.device("cuda")) 26 | setattr(self, name, attr) 27 | 28 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 29 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 30 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 31 | alphas_cumprod = self.model.alphas_cumprod 32 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 33 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | schedule_verbose=False, 77 | x_T=None, 78 | log_every_t=100, 79 | unconditional_guidance_scale=1., 80 | unconditional_conditioning=None, 81 | postprocess_fn=None, 82 | sample_noise=None, 83 | cond_fn=None, 84 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 85 | **kwargs 86 | ): 87 | 88 | # check condition bs 89 | if conditioning is not None: 90 | if isinstance(conditioning, dict): 91 | try: 92 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 93 | if cbs != batch_size: 94 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 95 | except: 96 | # cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 97 | pass 98 | else: 99 | if conditioning.shape[0] != batch_size: 100 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 101 | 102 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose) 103 | 104 | # make shape 105 | if len(shape) == 3: 106 | C, H, W = shape 107 | size = (batch_size, C, H, W) 108 | elif len(shape) == 4: 109 | C, T, H, W = shape 110 | size = (batch_size, C, T, H, W) 111 | 112 | samples, intermediates = self.ddim_sampling(conditioning, size, 113 | callback=callback, 114 | img_callback=img_callback, 115 | quantize_denoised=quantize_x0, 116 | mask=mask, x0=x0, 117 | ddim_use_original_steps=False, 118 | noise_dropout=noise_dropout, 119 | temperature=temperature, 120 | score_corrector=score_corrector, 121 | corrector_kwargs=corrector_kwargs, 122 | x_T=x_T, 123 | log_every_t=log_every_t, 124 | unconditional_guidance_scale=unconditional_guidance_scale, 125 | unconditional_conditioning=unconditional_conditioning, 126 | postprocess_fn=postprocess_fn, 127 | sample_noise=sample_noise, 128 | cond_fn=cond_fn, 129 | verbose=verbose, 130 | **kwargs 131 | ) 132 | return samples, intermediates 133 | 134 | @torch.no_grad() 135 | def ddim_sampling(self, cond, shape, 136 | x_T=None, ddim_use_original_steps=False, 137 | callback=None, timesteps=None, quantize_denoised=False, 138 | mask=None, x0=None, img_callback=None, log_every_t=100, 139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 140 | unconditional_guidance_scale=1., unconditional_conditioning=None, 141 | postprocess_fn=None,sample_noise=None,cond_fn=None, 142 | uc_type=None, verbose=True, **kwargs, 143 | ): 144 | 145 | device = 
self.model.betas.device 146 | 147 | b = shape[0] 148 | if x_T is None: 149 | img = torch.randn(shape, device=device) 150 | else: 151 | img = x_T 152 | 153 | if timesteps is None: 154 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 155 | elif timesteps is not None and not ddim_use_original_steps: 156 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 157 | timesteps = self.ddim_timesteps[:subset_end] 158 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 161 | if verbose: 162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 163 | else: 164 | iterator = time_range 165 | 166 | state.sampling_steps = total_steps 167 | 168 | for i, step in enumerate(iterator): 169 | state.sampling_step = i 170 | if state.interrupted: 171 | raise InterruptedException 172 | 173 | index = total_steps - i - 1 174 | ts = torch.full((b,), step, device=device, dtype=torch.long) 175 | 176 | if postprocess_fn is not None: 177 | img = postprocess_fn(img, ts) 178 | 179 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 180 | quantize_denoised=quantize_denoised, temperature=temperature, 181 | noise_dropout=noise_dropout, score_corrector=score_corrector, 182 | corrector_kwargs=corrector_kwargs, 183 | unconditional_guidance_scale=unconditional_guidance_scale, 184 | unconditional_conditioning=unconditional_conditioning, 185 | sample_noise=sample_noise,cond_fn=cond_fn,uc_type=uc_type, **kwargs,) 186 | img, pred_x0 = outs 187 | 188 | if mask is not None: 189 | # use mask to blend x_known_t-1 & x_sample_t-1 190 | assert x0 is not None 191 | x0 = x0.to(img.device) 192 | mask = mask.to(img.device) 193 | t = torch.tensor([step-1]*x0.shape[0], dtype=torch.long, device=img.device) 194 | img_known = self.model.q_sample(x0, t) 195 | img = img_known * mask + (1. 
- mask) * img 196 | 197 | if callback: callback(i) 198 | if img_callback: img_callback(pred_x0, i) 199 | 200 | if index % log_every_t == 0 or index == total_steps - 1: 201 | intermediates['x_inter'].append(img) 202 | intermediates['pred_x0'].append(pred_x0) 203 | if state.skipped: 204 | break 205 | 206 | return img, intermediates 207 | 208 | @torch.no_grad() 209 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 210 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 211 | unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None, 212 | cond_fn=None, uc_type=None, 213 | **kwargs, 214 | ): 215 | b, *_, device = *x.shape, x.device 216 | if x.dim() == 5: 217 | is_video = True 218 | else: 219 | is_video = False 220 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 221 | e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 222 | else: 223 | # with unconditional condition 224 | if isinstance(c, torch.Tensor): 225 | e_t = self.model.apply_model(x, t, c, **kwargs) 226 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 227 | elif isinstance(c, dict): 228 | e_t = self.model.apply_model(x, t, c, **kwargs) 229 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 230 | else: 231 | raise NotImplementedError 232 | # text cfg 233 | if uc_type is None: 234 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 235 | else: 236 | if uc_type == 'cfg_original': 237 | e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond) 238 | elif uc_type == 'cfg_ours': 239 | e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t) 240 | else: 241 | raise NotImplementedError 242 | 243 | if score_corrector is not None: 244 | assert self.model.parameterization == "eps" 245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 246 | 247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 250 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 251 | # select parameters corresponding to the currently considered timestep 252 | 253 | if is_video: 254 | size = (b, 1, 1, 1, 1) 255 | else: 256 | size = (b, 1, 1, 1) 257 | a_t = torch.full(size, alphas[index], device=device) 258 | a_prev = torch.full(size, alphas_prev[index], device=device) 259 | sigma_t = torch.full(size, sigmas[index], device=device) 260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 261 | 262 | # current prediction for x_0 263 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 264 | # print(f't={t}, pred_x0, min={torch.min(pred_x0)}, max={torch.max(pred_x0)}',file=f) 265 | if quantize_denoised: 266 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 267 | # direction pointing to x_t 268 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 269 | 270 | if sample_noise is None: 271 | noise = sigma_t * noise_like(x.shape, device, repeat_noise, self.noise_gen) * temperature 272 | if noise_dropout > 0.: 273 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 274 | else: 275 | noise = sigma_t * sample_noise * temperature 276 | 277 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 278 | 279 | return x_prev, pred_x0 280 | --------------------------------------------------------------------------------
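Note on the sampler above: ignoring the temperature and noise-dropout options, p_sample_ddim implements the standard DDIM update. In the notation of the code (a_t = ddim_alphas[index], a_prev = ddim_alphas_prev[index], sigma_t = ddim_sigmas[index], e_t the guidance-corrected noise prediction):

\hat{x}_0 = \frac{x_t - \sqrt{1 - a_t}\, e_t}{\sqrt{a_t}}, \qquad
x_{t-1} = \sqrt{a_{\text{prev}}}\, \hat{x}_0 + \sqrt{1 - a_{\text{prev}} - \sigma_t^2}\, e_t + \sigma_t z, \quad z \sim \mathcal{N}(0, I),

with z produced by noise_like from the sampler's CPU generator (self.noise_gen); when eta = 0 the sigmas vanish and the step is deterministic.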