├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── THIRD-PARTY ├── callbacks ├── common.py └── instruct_p2p_video.py ├── configs ├── instruct_v2v.yaml └── instruct_v2v_inference.yaml ├── data ├── airbrush-painting.mp4 ├── airplane-and-contrail.mp4 ├── audi-snow-trail.mp4 ├── car-turn.mp4 ├── cat-in-the-sun.mp4 ├── dirt-road-driving.mp4 ├── drift-turn.mp4 ├── earth-full-view.mp4 ├── eiffel-flyover.mp4 ├── ferris-wheel-timelapse.mp4 ├── gold-fish.mp4 ├── ice-hockey.mp4 ├── miami-surf.mp4 ├── raindrops.mp4 ├── red-roses-sunny-day.mp4 └── swans.mp4 ├── dataset ├── loveu_tgve_dataset.py ├── loveu_tgve_edit_prompt_dict.json ├── single_video_dataset.py └── videoP2P.py ├── figures ├── data_pipe.png ├── synthetic_sample │ ├── synthetic_video_106_0.gif │ ├── synthetic_video_116_0.gif │ ├── synthetic_video_141_0.gif │ ├── synthetic_video_18_0.gif │ ├── synthetic_video_192_0.gif │ ├── synthetic_video_197_0.gif │ ├── synthetic_video_1_0.gif │ ├── synthetic_video_24_0.gif │ ├── synthetic_video_81_0.gif │ └── synthetic_video_92_0.gif ├── teaser.png └── videos │ ├── TGVE_video_edit.mp4 │ ├── airbrush-painting_object.gif │ ├── airplane-and-contrail_background.gif │ ├── audi-snow-trail_background.gif │ ├── cat-in-the-sun_background.gif │ ├── cat-in-the-sun_object.gif │ ├── dirt-road-driving_style.gif │ ├── drift-turn_style.gif │ ├── earth-full-view_background.gif │ ├── eiffel-flyover_background.gif │ ├── ferris-wheel-timelapse_background.gif │ ├── gold-fish_style.gif │ ├── ice-hockey_object.gif │ ├── miami-surf_background.gif │ ├── raindrops_style.gif │ ├── red-roses-sunny-day_background.gif │ ├── red-roses-sunny-day_style.gif │ ├── swans_background.gif │ └── swans_object.gif ├── gradio_demo.py ├── insv2v_run_loveu_tgve.py ├── main.py ├── misc_utils ├── clip_similarity.py ├── flow_utils.py ├── image_utils.py ├── model_utils.py ├── ptp_utils.py ├── train_utils.py └── video_ptp_utils.py ├── modules ├── damo_text_to_video │ ├── configuration.json │ ├── text_model.py │ └── unet_sd.py ├── kl_autoencoder │ └── autoencoder.py ├── openclip │ └── modules.py ├── video_unet_temporal │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── unet.py │ └── unet_blocks.py └── vqvae │ ├── autoencoder.py │ └── model.py ├── pl_trainer ├── diffusion.py ├── inference │ ├── inference.py │ └── inference_damo.py └── instruct_p2p_video.py ├── requirements.txt ├── video_edit.ipynb └── video_prompt_to_prompt.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | dummy_dataset 131 | wandb 132 | experiments 133 | results 134 | .DS_Store 135 | pretrained_models 136 | debug.pkl 137 | ablation_results 138 | ablation_output* 139 | video_ptp 140 | *.pt 141 | *.pth 142 | webvid_edit_prompt 143 | evaluation_results 144 | tmp 145 | vptp_results 146 | repolint_result.txt -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 
13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 
10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This is the code release for the ICLR 2024 paper [Consistent Video-to-Video Transfer Using Synthetic Dataset](https://arxiv.org/abs/2311.00213). 2 | 3 | ![teaser](figures/teaser.png) 4 | 5 | ## Quick Links 6 | * [Installation](#installation) 7 | * [Video Editing](#video-editing) 🔥 8 | * [Synthetic Video Prompt-to-Prompt Dataset](#synthetic-video-prompt-to-prompt-dataset) 9 | * [Training](#training) 10 | * [Create Synthetic Video Dataset](#create-synthetic-video-dataset) 11 | 12 | ## Updates 13 | * 2024/02/13: The official synthetic data and model will not be released due to Amazon policy, but we provide a third-party reproduction of the synthetic data and model weights. Please refer to this [GitHub repo](https://github.com/cplusx/INSV2V-3rd-pty-reprod). 14 | * 2023/11/29: We have updated the paper with more comparisons to recent baseline methods and updated the [comparison video](#visual-comparison-to-other-methods). Gradio demo code is uploaded. 15 | 16 | ## Installation 17 | ```bash 18 | git clone https://github.com/amazon-science/instruct-video-to-video.git 19 | pip install -r requirements.txt 20 | ``` 21 | NOTE: The code is tested with PyTorch 2.1.0 (CUDA 11.8) and the corresponding xformers version. Any PyTorch version > 2.0 should work, but please install the matching xformers version. 22 | ## Video Editing 23 | The official model weights will not be released (see the Updates section above); please use the third-party reproduction instead. 24 | 25 | Download the [InsV2V model weights](https://github.com/cplusx/INSV2V-3rd-pty-reprod) and change the checkpoint path in the following notebook. 26 | 27 | ✨🚀 This [notebook](video_edit.ipynb) provides sample code for text-based video editing. 28 | 29 | ### Download LOVEU Dataset for Testing 30 | Please follow the instructions in the [LOVEU Dataset](https://sites.google.com/view/loveucvpr23/track4) to download the dataset. Use the following [script](insv2v_run_loveu_tgve.py) to run editing on the LOVEU dataset (`--with_optical_flow` enables motion compensation): 31 | ```bash 32 | python insv2v_run_loveu_tgve.py \ 33 |     --config configs/instruct_v2v.yaml \ 34 |     --ckpt-path [PATH TO THE CHECKPOINT] \ 35 |     --data-dir [PATH TO THE LOVEU DATASET] \ 36 |     --with_optical_flow \ 37 |     --text-cfg 7.5 10 \ 38 |     --video-cfg 1.2 1.5 \ 39 |     --image-size 256 384 40 | ``` 41 | Note: you may need to try different combinations of image resolution and video/text classifier-free guidance scales to find the best editing results. 42 | 43 | Example results of editing the LOVEU-TGVE dataset: 44 |
*(Example GIFs of the edited LOVEU-TGVE clips are available under `figures/videos/`.)*
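For reference, the downloaded LOVEU-TGVE videos and their edit prompts are read by `dataset/loveu_tgve_dataset.py`. Below is a minimal loading sketch; the dataset root path and the video name are placeholders, not values shipped with this repo:

```python
from dataset.loveu_tgve_dataset import LoveuTgveVideoDataset

# Root folder containing LOVEU-TGVE-2023_Dataset.csv and the *_480p video folders
# (placeholder path -- point it at wherever you extracted the dataset).
dataset = LoveuTgveVideoDataset(root_dir='loveu-tgve-2023', image_size=(480, 480))

item = dataset['gold-fish']           # index by video name or by integer position
frames = item['frames']               # (num_frames, 3, H, W) tensor, normalized to [-1, 1]
print(item['original'])               # source caption
print(item['style'], item['object'])  # edit captions for the style / object tracks
print('fps:', item['fps'])
```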
70 | ## Synthetic Video Prompt-to-Prompt Dataset 71 | 72 | Generation pipeline of the synthetic video dataset: 73 | ![generation pipeline](figures/data_pipe.png) 74 | 75 | Examples of the synthetic video dataset: 76 | *(Example input/edited GIF pairs are available under `figures/synthetic_sample/`.)*
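The synthetic clips are consumed at training time by `dataset/videoP2P.py`, which expects each sample folder to contain an `image/` directory, a `metadata.jsonl` file, and a `prompt.json` file. Below is a minimal loading sketch; the arguments mirror the `data.train` block of `configs/instruct_v2v.yaml`, and the root directory is an assumption about where you unpacked the downloaded data:

```python
from dataset.videoP2P import VideoPromptToPromptMotionAug

# Same arguments as the training config (root path is a placeholder).
dataset = VideoPromptToPromptMotionAug(
    root_dirs=['video_ptp/raw_generated'],
    num_frames=16,
    zoom_ratio=0.2,
    max_zoom=1.25,
    translation_ratio=0.7,
    translation_range=(0, 0.2),
)

sample = dataset[0]
print(sample['edit_prompt'])
# Both tensors have shape (num_frames, 3, H, W) and are scaled to [-1, 1].
print(sample['input_video'].shape, sample['edited_video'].shape)
```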
98 | 99 | ## Training 100 | 101 | ### Download Foundational Models 102 | [Download](https://drive.google.com/file/d/1R9sWsnGZUa5P8IB5DDfD9eU-T9SQLsFw/view?usp=sharing) the foundational models and place them in the `pretrained_models` folder. 103 | 104 | ### Download Synthetic Video Dataset 105 | [See the download link in the third-party reproduction](https://github.com/cplusx/INSV2V-3rd-pty-reprod) 106 | 107 | ### Train the Model 108 | Put the synthetic video dataset in the `video_ptp` folder. 109 | 110 | Run the following command to train the model: 111 | ```bash 112 | python main.py --config configs/instruct_v2v.yaml -r # add -r to resume training if the training is interrupted 113 | ``` 114 | 115 | ## Create Synthetic Video Dataset 116 | If you want to create your own synthetic video dataset, please follow the instructions below: 117 | * Download the ModelScope VAE, UNet and text encoder weights from [here](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis/tree/main) 118 | * Replace the model paths in the [`video_prompt_to_prompt.py`](video_prompt_to_prompt.py) file 119 | ``` 120 | vae_ckpt = 'VAE_PATH' 121 | unet_ckpt = 'UNet_PATH' 122 | text_model_ckpt = 'Text_MODEL_PATH' 123 | ``` 124 | * Download the edit prompt file from [Instruct Pix2Pix](https://github.com/timothybrooks/instruct-pix2pix) (the file is `gpt-generated-prompts.jsonl`) and change the file path in `video_prompt_to_prompt.py` accordingly. Alternatively, download the WebVid edit prompt file proposed in our paper from [To be released](). 125 | * Run the following command to generate the synthetic video dataset: 126 | ```bash 127 | python video_prompt_to_prompt.py \ 128 |     --start [START INDEX] \ 129 |     --end [END INDEX] \ 130 |     --prompt_source [ip2p or webvid] \ 131 |     --num_sample_each_prompt [NUM SAMPLES FOR EACH PROMPT] 132 | ``` 133 | 134 | ## Visual Comparison to Other Methods 135 | 136 | https://github.com/amazon-science/instruct-video-to-video/assets/20940184/d3619652-dd75-41a0-92b4-345bbf57de40 137 | 138 | 139 | Links to the baselines used in the video: 140 | 141 | | [Tune-A-Video](https://github.com/showlab/Tune-A-Video) | [ControlVideo](https://github.com/thu-ml/controlvideo) | [Vid2Vid Zero](https://github.com/baaivision/vid2vid-zero) | [Video P2P](https://github.com/ShaoTengLiu/Video-P2P) | 142 | 143 | | [TokenFlow](https://github.com/omerbt/TokenFlow) | [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) | [Pix2Video](https://github.com/duyguceylan/pix2video) | 144 | 145 | ## Credit 146 | The code was implemented by [Jiaxin Cheng](https://github.com/cplusx) during his internship at the AWS Shanghai Lablet.
147 | ## References 148 | Part of the code and the foundational models are adapted from the following works: 149 | * [Instruct Pix2Pix](https://github.com/timothybrooks/instruct-pix2pix) 150 | * [AnimateDiff](https://github.com/guoyww/animatediff/) 151 | -------------------------------------------------------------------------------- /callbacks/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import wandb 5 | import cv2 6 | 7 | def unnorm(x): 8 | '''convert from range [-1, 1] to [0, 1]''' 9 | return (x+1) / 2 10 | 11 | def clip_image(x, min=0., max=1.): 12 | return torch.clamp(x, min=min, max=max) 13 | 14 | def format_dtype_and_shape(x): 15 | if isinstance(x, torch.Tensor): 16 | if len(x.shape) == 3 and x.shape[0] == 3: 17 | x = x.permute(1, 2, 0) 18 | if len(x.shape) == 4 and x.shape[1] == 3: 19 | x = x.permute(0, 2, 3, 1) 20 | x = x.detach().cpu().numpy() 21 | return x 22 | 23 | def tensor2image(x): 24 | x = x.float() # handle bf16 25 | '''convert 4D (b, dim, h, w) pytorch tensor to wandb Image class''' 26 | grid_img = torchvision.utils.make_grid( 27 | x, nrow=4 28 | ).permute(1, 2, 0).detach().cpu().numpy() 29 | img = wandb.Image( 30 | grid_img 31 | ) 32 | return img 33 | 34 | def save_figure(image, save_path): 35 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 36 | if image.min() < 0: 37 | image = clip_image(unnorm(image)) 38 | image = format_dtype_and_shape(image) 39 | image = (image * 255).astype(np.uint8) 40 | cv2.imwrite(save_path, image[..., ::-1]) 41 | 42 | def save_sampling_history(image, save_path): 43 | if image.min() < 0: 44 | image = clip_image(unnorm(image)) 45 | grid_img = torchvision.utils.make_grid(image, nrow=4) 46 | save_figure(grid_img, save_path) -------------------------------------------------------------------------------- /callbacks/instruct_p2p_video.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers import WandbLogger 2 | from pytorch_lightning.callbacks import Callback 3 | from .common import tensor2image, clip_image, unnorm 4 | from einops import rearrange 5 | 6 | def frame_dim_to_batch_dim(x): 7 | return rearrange(x, 'b f c h w -> (b f) c h w') 8 | 9 | class InstructP2PLogger(Callback): 10 | def __init__( 11 | self, 12 | wandb_logger: WandbLogger=None, 13 | max_num_images: int=16, 14 | ) -> None: 15 | super().__init__() 16 | self.wandb_logger = wandb_logger 17 | self.max_num_images = max_num_images 18 | 19 | def on_train_batch_end( 20 | self, trainer, pl_module, 21 | outputs, batch, batch_idx 22 | ): 23 | # record images in first batch 24 | if batch_idx == 0: 25 | input_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 26 | batch['input_video'][:self.max_num_images] 27 | )))) 28 | edited_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 29 | batch['edited_video'][:self.max_num_images] 30 | )))) 31 | pred_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 32 | outputs['pred'][:self.max_num_images] 33 | )))) 34 | self.wandb_logger.experiment.log({ 35 | 'train/input_image': input_image, 36 | 'train/edited_image': edited_image, 37 | 'train/pred': pred_image, 38 | }) 39 | 40 | def on_validation_batch_end( 41 | self, trainer, pl_module, 42 | outputs, batch, batch_idx 43 | ): 44 | """Called when the validation batch ends.""" 45 | if batch_idx == 0: 46 | input_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 47 | 
batch['input_video'][:self.max_num_images] 48 | )))) 49 | edited_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 50 | batch['edited_video'][:self.max_num_images] 51 | )))) 52 | pred_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm( 53 | outputs['pred'][:self.max_num_images] 54 | )))) 55 | self.wandb_logger.experiment.log({ 56 | 'val/input_image': input_image, 57 | 'val/edited_image': edited_image, 58 | 'val/pred': pred_image, 59 | }) 60 | -------------------------------------------------------------------------------- /configs/instruct_v2v.yaml: -------------------------------------------------------------------------------- 1 | expt_dir: experiments 2 | expt_name: instruct_v2v 3 | trainer_args: 4 | max_epochs: 10 5 | accelerator: "gpu" 6 | devices: [0,1,2,3] 7 | limit_train_batches: 2048 8 | limit_val_batches: 1 9 | # strategy: "ddp" 10 | strategy: "deepspeed_stage_2" 11 | accumulate_grad_batches: 256 12 | check_val_every_n_epoch: 5 13 | diffusion: 14 | target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal 15 | params: 16 | beta_schedule_args: 17 | beta_schedule: scaled_linear 18 | num_train_timesteps: 1000 19 | beta_start: 0.00085 20 | beta_end: 0.012 21 | clip_sample: false 22 | thresholding: false 23 | prediction_type: epsilon 24 | loss_fn: l2 25 | optim_args: 26 | lr: 1e-5 27 | unet_init_weights: 28 | - pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin 29 | - pretrained_models/Motion_Module/mm_sd_v15.ckpt 30 | vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt 31 | text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt 32 | scale_factor: 0.18215 33 | guidance_scale: 5 # not used 34 | ddim_sampling_steps: 20 35 | text_cfg: 7.5 36 | img_cfg: 1.2 37 | cond_image_dropout: 0.1 38 | prompt_type: edit_prompt 39 | unet: 40 | target: modules.video_unet_temporal.unet.UNet3DConditionModel 41 | params: 42 | in_channels: 8 43 | out_channels: 4 44 | act_fn: silu 45 | attention_head_dim: 8 46 | block_out_channels: 47 | - 320 48 | - 640 49 | - 1280 50 | - 1280 51 | cross_attention_dim: 768 52 | down_block_types: 53 | - CrossAttnDownBlock3D 54 | - CrossAttnDownBlock3D 55 | - CrossAttnDownBlock3D 56 | - DownBlock3D 57 | up_block_types: 58 | - UpBlock3D 59 | - CrossAttnUpBlock3D 60 | - CrossAttnUpBlock3D 61 | - CrossAttnUpBlock3D 62 | downsample_padding: 1 63 | layers_per_block: 2 64 | mid_block_scale_factor: 1 65 | norm_eps: 1e-05 66 | norm_num_groups: 32 67 | sample_size: 64 68 | use_motion_module: true 69 | motion_module_resolutions: 70 | - 1 71 | - 2 72 | - 4 73 | - 8 74 | motion_module_mid_block: false 75 | motion_module_decoder_only: false 76 | motion_module_type: Vanilla 77 | motion_module_kwargs: 78 | num_attention_heads: 8 79 | num_transformer_block: 1 80 | attention_block_types: 81 | - Temporal_Self 82 | - Temporal_Self 83 | temporal_position_encoding: true 84 | temporal_position_encoding_max_len: 32 85 | temporal_attention_dim_div: 1 86 | vae: 87 | target: modules.kl_autoencoder.autoencoder.AutoencoderKL 88 | params: 89 | embed_dim: 4 90 | ddconfig: 91 | double_z: true 92 | z_channels: 4 93 | resolution: 256 94 | in_channels: 3 95 | out_ch: 3 96 | ch: 128 97 | ch_mult: 98 | - 1 99 | - 2 100 | - 4 101 | - 4 102 | num_res_blocks: 2 103 | attn_resolutions: [] 104 | dropout: 0.0 105 | lossconfig: 106 | target: torch.nn.Identity 107 | text_model: 108 | target: modules.openclip.modules.FrozenCLIPEmbedder 109 | params: 110 | freeze: true 111 | data: 112 | batch_size: 1 113 | val_batch_size: 1 114 | train: 115 | 
target: dataset.videoP2P.VideoPromptToPromptMotionAug 116 | params: 117 | root_dirs: 118 | - video_ptp/raw_generated 119 | - video_ptp/raw_generated_webvid 120 | num_frames: 16 121 | zoom_ratio: 0.2 122 | max_zoom: 1.25 123 | translation_ratio: 0.7 124 | translation_range: [0, 0.2] 125 | val: 126 | target: dataset.videoP2P.VideoPromptToPromptMotionAug 127 | params: 128 | root_dirs: 129 | - video_ptp/raw_generated 130 | num_frames: 16 131 | zoom_ratio: 0.2 132 | max_zoom: 1.25 133 | translation_ratio: 0.7 134 | translation_range: [0, 0.2] 135 | callbacks: 136 | - target: pytorch_lightning.callbacks.ModelCheckpoint 137 | params: 138 | dirpath: "${expt_dir}/${expt_name}" 139 | filename: "{epoch:04d}" 140 | monitor: epoch 141 | mode: max 142 | save_top_k: 5 143 | save_last: true 144 | - target: callbacks.instruct_p2p_video.InstructP2PLogger 145 | params: 146 | max_num_images: 1 147 | require_wandb: true -------------------------------------------------------------------------------- /configs/instruct_v2v_inference.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal 3 | params: 4 | beta_schedule_args: 5 | beta_schedule: scaled_linear 6 | num_train_timesteps: 1000 7 | beta_start: 0.00085 8 | beta_end: 0.012 9 | clip_sample: false 10 | thresholding: false 11 | prediction_type: epsilon 12 | loss_fn: l2 13 | optim_args: 14 | lr: 1e-5 15 | scale_factor: 0.18215 16 | guidance_scale: 5 # not used 17 | ddim_sampling_steps: 20 18 | text_cfg: 7.5 19 | img_cfg: 1.2 20 | cond_image_dropout: 0.1 21 | prompt_type: edit_prompt 22 | unet: 23 | target: modules.video_unet_temporal.unet.UNet3DConditionModel 24 | params: 25 | in_channels: 8 26 | out_channels: 4 27 | act_fn: silu 28 | attention_head_dim: 8 29 | block_out_channels: 30 | - 320 31 | - 640 32 | - 1280 33 | - 1280 34 | cross_attention_dim: 768 35 | down_block_types: 36 | - CrossAttnDownBlock3D 37 | - CrossAttnDownBlock3D 38 | - CrossAttnDownBlock3D 39 | - DownBlock3D 40 | up_block_types: 41 | - UpBlock3D 42 | - CrossAttnUpBlock3D 43 | - CrossAttnUpBlock3D 44 | - CrossAttnUpBlock3D 45 | downsample_padding: 1 46 | layers_per_block: 2 47 | mid_block_scale_factor: 1 48 | norm_eps: 1e-05 49 | norm_num_groups: 32 50 | sample_size: 64 51 | use_motion_module: true 52 | motion_module_resolutions: 53 | - 1 54 | - 2 55 | - 4 56 | - 8 57 | motion_module_mid_block: false 58 | motion_module_decoder_only: false 59 | motion_module_type: Vanilla 60 | motion_module_kwargs: 61 | num_attention_heads: 8 62 | num_transformer_block: 1 63 | attention_block_types: 64 | - Temporal_Self 65 | - Temporal_Self 66 | temporal_position_encoding: true 67 | temporal_position_encoding_max_len: 32 68 | temporal_attention_dim_div: 1 69 | vae: 70 | target: modules.kl_autoencoder.autoencoder.AutoencoderKL 71 | params: 72 | embed_dim: 4 73 | ddconfig: 74 | double_z: true 75 | z_channels: 4 76 | resolution: 256 77 | in_channels: 3 78 | out_ch: 3 79 | ch: 128 80 | ch_mult: 81 | - 1 82 | - 2 83 | - 4 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | text_model: 91 | target: modules.openclip.modules.FrozenCLIPEmbedder 92 | params: 93 | freeze: true -------------------------------------------------------------------------------- /data/airbrush-painting.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/airbrush-painting.mp4 -------------------------------------------------------------------------------- /data/airplane-and-contrail.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/airplane-and-contrail.mp4 -------------------------------------------------------------------------------- /data/audi-snow-trail.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/audi-snow-trail.mp4 -------------------------------------------------------------------------------- /data/car-turn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/car-turn.mp4 -------------------------------------------------------------------------------- /data/cat-in-the-sun.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/cat-in-the-sun.mp4 -------------------------------------------------------------------------------- /data/dirt-road-driving.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/dirt-road-driving.mp4 -------------------------------------------------------------------------------- /data/drift-turn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/drift-turn.mp4 -------------------------------------------------------------------------------- /data/earth-full-view.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/earth-full-view.mp4 -------------------------------------------------------------------------------- /data/eiffel-flyover.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/eiffel-flyover.mp4 -------------------------------------------------------------------------------- /data/ferris-wheel-timelapse.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/ferris-wheel-timelapse.mp4 -------------------------------------------------------------------------------- /data/gold-fish.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/gold-fish.mp4 -------------------------------------------------------------------------------- /data/ice-hockey.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/ice-hockey.mp4 -------------------------------------------------------------------------------- /data/miami-surf.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/miami-surf.mp4 -------------------------------------------------------------------------------- /data/raindrops.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/raindrops.mp4 -------------------------------------------------------------------------------- /data/red-roses-sunny-day.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/red-roses-sunny-day.mp4 -------------------------------------------------------------------------------- /data/swans.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/swans.mp4 -------------------------------------------------------------------------------- /dataset/loveu_tgve_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import cv2 3 | import os 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | class LoveuTgveVideoDataset(Dataset): 10 | def __init__(self, root_dir, image_size=(480, 480)): 11 | self.root_dir = root_dir 12 | self.image_size = image_size 13 | self.transform = transforms.Compose([ 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize to [-1,1] for each channel 16 | ]) 17 | 18 | csv_file = os.path.join(root_dir, 'LOVEU-TGVE-2023_Dataset.csv') 19 | self.data = {} 20 | with open(csv_file, 'r') as file: 21 | reader = csv.reader(file) 22 | next(reader, None) # skip the headers 23 | for row in reader: 24 | if len(row[0]) == 0: 25 | continue 26 | if row[0].endswith('Videos:'): 27 | dataset_type = row[0].split(' ')[0] 28 | if dataset_type == 'DAVIS': 29 | self.source_folder = dataset_type + '_480p/480p_videos' 30 | else: 31 | self.source_folder = dataset_type.lower() + '_480p/480p_videos' 32 | elif len(row) > 1: 33 | video_name = row[0] 34 | self.data[video_name] = { 35 | 'video_name': video_name, 36 | 'original': row[1], 37 | 'style': row[2], 38 | 'object': row[3], 39 | 'background': row[4], 40 | 'multiple': row[5], 41 | 'source_folder': self.source_folder, 42 | } 43 | 44 | def load_frames(self, video_name, source_folder): 45 | video_path = os.path.join(self.root_dir, source_folder, f'{video_name}.mp4') 46 | cap = cv2.VideoCapture(video_path) 47 | frames = [] 48 | while cap.isOpened(): 49 | ret, frame = cap.read() 50 | if ret: 51 | frame = cv2.resize(frame, self.image_size) 52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 53 | frame = self.transform(frame) # convert to PyTorch tensor and normalize 54 | frames.append(frame) 55 | else: 56 | break 57 | cap.release() 58 | return torch.stack(frames, dim=0) 
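    # NOTE: __getitem__ (further below) decodes frames on demand via load_frames/load_fps,
    # resizes them to `image_size`, normalizes them to [-1, 1], and returns them together
    # with the source fps and the edit captions (style / object / background / multiple).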
59 | 60 | def load_fps(self, video_name, source_folder): 61 | video_path = os.path.join(self.root_dir, source_folder, f'{video_name}.mp4') 62 | cap = cv2.VideoCapture(video_path) 63 | fps = cap.get(cv2.CAP_PROP_FPS) 64 | cap.release() 65 | return fps 66 | 67 | def __len__(self): 68 | return len(self.data) 69 | 70 | def __getitem__(self, idx): 71 | if isinstance(idx, str): 72 | video_name = idx 73 | else: 74 | video_name = list(self.data.keys())[idx] 75 | 76 | # Load frames and fps when the dataset is called 77 | source_folder = self.data[video_name]['source_folder'] 78 | frames = self.load_frames(video_name, source_folder) 79 | fps = self.load_fps(video_name, source_folder) 80 | item = self.data[video_name].copy() 81 | item['frames'] = frames 82 | item['fps'] = fps 83 | 84 | return item -------------------------------------------------------------------------------- /dataset/single_video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from einops import rearrange 4 | import cv2 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms import functional as F 8 | import omegaconf 9 | 10 | class SingleVideoDataset(Dataset): 11 | def __init__( 12 | self, 13 | video_file, 14 | video_description, 15 | sampling_fps=24, 16 | frame_gap=0, 17 | num_frames=2, 18 | output_size=(512, 512), 19 | mode='train' 20 | ): 21 | self.video_file = video_file 22 | self.video_id = os.path.splitext(os.path.basename(video_file))[0] 23 | self.sampling_fps = sampling_fps 24 | self.frame_gap = frame_gap 25 | self.description = video_description 26 | self.output_size = output_size 27 | self.mode = mode 28 | self.num_frames = num_frames 29 | 30 | cap = cv2.VideoCapture(video_file) 31 | video_fps = round(cap.get(cv2.CAP_PROP_FPS)) 32 | num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 33 | cap.release() 34 | 35 | if self.sampling_fps is not None: 36 | if isinstance(self.sampling_fps, int): 37 | sampling_fps = self.sampling_fps 38 | elif isinstance(self.sampling_fps, list) or isinstance(self.sampling_fps, omegaconf.listconfig.ListConfig): 39 | sampling_fps = random.choice(self.sampling_fps) 40 | assert isinstance(sampling_fps, int) 41 | else: 42 | raise ValueError(f'sampling_fps should be int or list of int, got {self.sampling_fps}') 43 | sampling_fps = int(min(sampling_fps, video_fps)) 44 | frame_gap = max(0, int(video_fps / sampling_fps)) 45 | else: 46 | sampling_fps = video_fps // (1 + self.frame_gap) 47 | frame_gap = self.frame_gap 48 | 49 | self.frame_gap = frame_gap 50 | self.sampling_fps = sampling_fps 51 | 52 | cap.release() 53 | 54 | # print(num_frames, frame_gap, self.num_frames) 55 | self.num_frames = min(self.num_frames, num_frames // frame_gap) # this is number of sampling frames 56 | self.total_possible_starting_frames = max(0, num_frames - (frame_gap * (self.num_frames - 1))) 57 | 58 | def __len__(self): 59 | return self.total_possible_starting_frames 60 | 61 | def __getitem__(self, index): 62 | video_file = self.video_file 63 | 64 | cap = cv2.VideoCapture(video_file) 65 | video_fps = round(cap.get(cv2.CAP_PROP_FPS)) 66 | num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 67 | 68 | sampling_fps = self.sampling_fps 69 | frame_gap = self.frame_gap 70 | 71 | frames = [] 72 | if num_frames > 1 + frame_gap: 73 | first_frame_index = index 74 | for i in range(self.num_frames): 75 | cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_index + i * frame_gap) 76 | ret, frame = cap.read() 77 | 78 | # Convert BGR 
to RGB and resize frame 79 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if frame is not None else None 80 | 81 | if frame is not None: 82 | height, width, _ = frame.shape 83 | aspect_ratio = width / height 84 | target_width = int(self.output_size[0] * aspect_ratio) 85 | frame = rearrange(torch.from_numpy(frame), 'h w c -> c h w') 86 | frame = F.resize(frame, (self.output_size[1], target_width), antialias=True) 87 | 88 | if target_width > self.output_size[1]: 89 | margin = (target_width - self.output_size[1]) // 2 90 | frame = F.crop(frame, 0, margin, self.output_size[1], self.output_size[0]) 91 | else: 92 | margin = (self.output_size[1] - target_width) // 2 93 | frame = F.pad(frame, (margin, 0), 0, 'constant') 94 | frame = (frame / 127.5) - 1.0 95 | 96 | frames.append(frame) 97 | else: 98 | for i in range(self.num_frames): 99 | frames.append(None) 100 | 101 | cap.release() 102 | 103 | # Stack frames 104 | frames = [frame for frame in frames if frame is not None] 105 | while len(frames) < self.num_frames: 106 | frames.append(frames[-1]) 107 | frames = frames[:self.num_frames] 108 | frames = torch.stack(frames, dim=0) 109 | 110 | caption = self.description 111 | 112 | res = { 113 | 'frames': frames, 114 | 'video_id': self.video_id, 115 | 'text': caption, 116 | 'fps': torch.tensor(sampling_fps), 117 | } 118 | return res -------------------------------------------------------------------------------- /dataset/videoP2P.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import jsonlines 4 | import cv2 5 | import torch 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | 9 | class VideoPromptToPrompt(Dataset): 10 | def __init__(self, root_dirs, num_frames=8): 11 | if isinstance(root_dirs, str): 12 | root_dirs = [root_dirs] 13 | self.root_dirs = root_dirs 14 | self.image_folders = [] 15 | for root_dir in self.root_dirs: 16 | image_folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))] 17 | self.num_frames = num_frames 18 | 19 | for f in image_folders: 20 | if os.path.exists(os.path.join(root_dir, f, 'image')) and os.path.exists(os.path.join(root_dir, f, 'metadata.jsonl')) and os.path.exists(os.path.join(root_dir, f, 'prompt.json')): 21 | self.image_folders.append( 22 | os.path.join(root_dir, f) 23 | ) 24 | 25 | def __len__(self): 26 | return len(self.image_folders) 27 | 28 | def numpy_to_tensor(self, image): 29 | image = torch.from_numpy(image.transpose((0, 3, 1, 2))).to(torch.float32) 30 | return image * 2 - 1 31 | 32 | def __getitem__(self, idx): 33 | folder = self.image_folders[idx] 34 | with jsonlines.open(os.path.join(folder, 'metadata.jsonl')) as reader: 35 | seeds = [obj['seed'] for obj in reader if (obj['sim_dir'] > 0.2 and obj['sim_0'] > 0.2 and obj['sim_1'] > 0.2 and obj['sim_image'] > 0.5)] 36 | seed = np.random.choice(seeds) 37 | 38 | # Load prompt.json 39 | with open(os.path.join(folder, 'prompt.json'), 'r') as f: 40 | prompt = json.load(f) 41 | 42 | start_idx = np.random.randint(0, 16 - self.num_frames) 43 | end_idx = start_idx + self.num_frames 44 | 45 | input_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_0_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)]) 46 | edited_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_1_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)]) 47 | 48 | return { 49 | 'input_video': self.numpy_to_tensor(input_images), 50 | 'edited_video': 
self.numpy_to_tensor(edited_images), 51 | 'input_prompt': prompt['input'], 52 | 'output_prompt': prompt['output'], 53 | 'edit_prompt': prompt['edit'] 54 | } 55 | 56 | @staticmethod 57 | def load_image(image_path): 58 | img = cv2.imread(image_path) 59 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 60 | img = img / 255.0 # Normalization 61 | return img 62 | 63 | 64 | class VideoPromptToPromptMotionAug(VideoPromptToPrompt): 65 | def __init__(self, *args, zoom_ratio=0.2, max_zoom=1.2, translation_ratio=0.3, translation_range=(0, 0.2), **kwargs): 66 | super().__init__(*args, **kwargs) 67 | self.zoom_ratio = zoom_ratio 68 | self.max_zoom = max_zoom 69 | self.translation_ratio = translation_ratio 70 | self.translation_range = translation_range 71 | 72 | def translation_crop(self, delta_h, delta_w, images): 73 | def center_crop(img, center_x, center_y, h, w): 74 | x_start = int(center_x - w / 2) 75 | x_end = int(x_start + w) 76 | y_start = int(center_y - h / 2) 77 | y_end = int(y_start + h) 78 | img = img[y_start: y_end, x_start: x_end] 79 | return img 80 | 81 | H, W = images.shape[1:3] 82 | crop_H = H - abs(delta_h) 83 | crop_W = W - abs(delta_w) 84 | 85 | if delta_h > 0: 86 | h_start = (H - delta_h) // 2 87 | h_end = h_start + delta_h 88 | else: 89 | h_end = H - (H + delta_h) // 2 90 | h_start = h_end + delta_h 91 | 92 | if delta_w > 0: 93 | w_start = (W - delta_w) // 2 94 | w_end = w_start + delta_w 95 | else: 96 | w_end = W - (W + delta_w) // 2 97 | w_start = w_end + delta_w 98 | 99 | center_xs = np.linspace(w_start, w_end, self.num_frames) 100 | center_ys = np.linspace(h_start, h_end, self.num_frames) 101 | 102 | if delta_h < 0: 103 | center_ys = center_ys[::-1] 104 | if delta_w < 0: 105 | center_xs = center_xs[::-1] 106 | 107 | images = np.stack([center_crop(img, center_x, center_y, crop_H, crop_W) for img, center_x, center_y in zip(images, center_xs, center_ys)], axis=0) 108 | images = np.stack([cv2.resize(img, (W, H), interpolation=cv2.INTER_CUBIC) for img in images], axis=0) 109 | return images 110 | 111 | def zoom_aug(self, images, final_scale=1., zoom_in_or_out='in'): 112 | def zoom_in_with_scale(img, scale): 113 | H, W = img.shape[:2] 114 | img = cv2.resize(img, (int(W * scale), int(H * scale)), interpolation=cv2.INTER_CUBIC) 115 | # center crop to H, W 116 | img = img[(img.shape[0] - H) // 2: (img.shape[0] - H) // 2 + H, (img.shape[1] - W) // 2: (img.shape[1] - W) // 2 + W] 117 | return img 118 | if final_scale <= 1.02 : 119 | return images 120 | if zoom_in_or_out == 'in': 121 | scales = np.linspace(1., final_scale, self.num_frames) 122 | images = np.array([zoom_in_with_scale(img, scale) for img, scale in zip(images, scales)]) 123 | elif zoom_in_or_out == 'out': 124 | scales = np.linspace(final_scale, 1., self.num_frames) 125 | images = np.array([zoom_in_with_scale(img, scale) for img, scale in zip(images, scales)]) 126 | return images 127 | 128 | def motion_augmentation(self, input_images, edited_images): 129 | H, W = input_images.shape[1:3] 130 | 131 | # translation augmentation 132 | if np.random.random() < self.translation_ratio: 133 | delta_h = np.random.uniform(self.translation_range[0], self.translation_range[1]) * H * np.random.choice([-1, 1]) 134 | delta_w = np.random.uniform(self.translation_range[0], self.translation_range[1]) * W * np.random.choice([-1, 1]) 135 | # print(delta_h, delta_w) 136 | input_images = self.translation_crop(delta_h, delta_w, input_images) 137 | edited_images = self.translation_crop(delta_h, delta_w, edited_images) 138 | 139 | # zoom augmentation 140 | 
if np.random.random() < self.zoom_ratio: 141 | final_scale = np.random.uniform(1., self.max_zoom) 142 | zoom_in_or_out = np.random.choice(['in', 'out']) 143 | # print(final_scale, zoom_in_or_out) 144 | input_images = self.zoom_aug(input_images, final_scale, zoom_in_or_out) 145 | edited_images = self.zoom_aug(edited_images, final_scale, zoom_in_or_out) 146 | 147 | return input_images, edited_images 148 | 149 | def __getitem__(self, idx): 150 | folder = self.image_folders[idx] 151 | with jsonlines.open(os.path.join(folder, 'metadata.jsonl')) as reader: 152 | seeds = [obj['seed'] for obj in reader if (obj['sim_dir'] > 0.2 and obj['sim_0'] > 0.2 and obj['sim_1'] > 0.2 and obj['sim_image'] > 0.5)] 153 | seed = np.random.choice(seeds) 154 | 155 | # Load prompt.json 156 | with open(os.path.join(folder, 'prompt.json'), 'r') as f: 157 | prompt = json.load(f) 158 | 159 | start_idx = np.random.randint(0, 16 - self.num_frames + 1) 160 | end_idx = start_idx + self.num_frames 161 | 162 | input_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_0_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)]) 163 | edited_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_1_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)]) 164 | 165 | input_images, edited_images = self.motion_augmentation(input_images, edited_images) 166 | 167 | return { 168 | 'input_video': self.numpy_to_tensor(input_images), 169 | 'edited_video': self.numpy_to_tensor(edited_images), 170 | 'input_prompt': prompt['input'], 171 | 'output_prompt': prompt['output'], 172 | 'edit_prompt': prompt['edit'] 173 | } -------------------------------------------------------------------------------- /figures/data_pipe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/data_pipe.png -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_106_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_106_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_116_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_116_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_141_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_141_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_18_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_18_0.gif 
-------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_192_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_192_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_197_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_197_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_1_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_1_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_24_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_24_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_81_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_81_0.gif -------------------------------------------------------------------------------- /figures/synthetic_sample/synthetic_video_92_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_92_0.gif -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/teaser.png -------------------------------------------------------------------------------- /figures/videos/TGVE_video_edit.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/TGVE_video_edit.mp4 -------------------------------------------------------------------------------- /figures/videos/airbrush-painting_object.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/airbrush-painting_object.gif -------------------------------------------------------------------------------- /figures/videos/airplane-and-contrail_background.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/airplane-and-contrail_background.gif -------------------------------------------------------------------------------- /figures/videos/audi-snow-trail_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/audi-snow-trail_background.gif -------------------------------------------------------------------------------- /figures/videos/cat-in-the-sun_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/cat-in-the-sun_background.gif -------------------------------------------------------------------------------- /figures/videos/cat-in-the-sun_object.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/cat-in-the-sun_object.gif -------------------------------------------------------------------------------- /figures/videos/dirt-road-driving_style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/dirt-road-driving_style.gif -------------------------------------------------------------------------------- /figures/videos/drift-turn_style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/drift-turn_style.gif -------------------------------------------------------------------------------- /figures/videos/earth-full-view_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/earth-full-view_background.gif -------------------------------------------------------------------------------- /figures/videos/eiffel-flyover_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/eiffel-flyover_background.gif -------------------------------------------------------------------------------- /figures/videos/ferris-wheel-timelapse_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/ferris-wheel-timelapse_background.gif -------------------------------------------------------------------------------- /figures/videos/gold-fish_style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/gold-fish_style.gif -------------------------------------------------------------------------------- 
/figures/videos/ice-hockey_object.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/ice-hockey_object.gif -------------------------------------------------------------------------------- /figures/videos/miami-surf_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/miami-surf_background.gif -------------------------------------------------------------------------------- /figures/videos/raindrops_style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/raindrops_style.gif -------------------------------------------------------------------------------- /figures/videos/red-roses-sunny-day_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/red-roses-sunny-day_background.gif -------------------------------------------------------------------------------- /figures/videos/red-roses-sunny-day_style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/red-roses-sunny-day_style.gif -------------------------------------------------------------------------------- /figures/videos/swans_background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/swans_background.gif -------------------------------------------------------------------------------- /figures/videos/swans_object.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/swans_object.gif -------------------------------------------------------------------------------- /gradio_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | import numpy as np 4 | from misc_utils.image_utils import save_tensor_to_gif 5 | from misc_utils.train_utils import unit_test_create_model 6 | from pl_trainer.inference.inference import InferenceIP2PVideoOpticalFlow 7 | from dataset.single_video_dataset import SingleVideoDataset 8 | import torch 9 | 10 | NEGATIVE_PROMPT = 'worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting' 11 | # The order is: [video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, sample_rate, num_frames, start_frame] 12 | CARTURN = [ 'data/car-turn.mp4', 'Change the car to a red Porsche and make the background beach.', NEGATIVE_PROMPT, 10, 1.1, 512, 10, 28, 20, ] 13 | AIRPLANE = ['data/airplane-and-contrail.mp4', "add Taj Mahal in 
the image", NEGATIVE_PROMPT, 10, 1.2, 512, 30, 28, 0] 14 | AUDI = ['data/audi-snow-trail.mp4', "make the car drive in desert trail.", NEGATIVE_PROMPT, 10, 1.5, 512, 3, 28, 0] 15 | CATINSUN_BKG = ['data/cat-in-the-sun.mp4', "change the background to a beach.", NEGATIVE_PROMPT, 7.5, 1.3, 512, 6, 28, 0] 16 | DIRTROAD = ['data/dirt-road-driving.mp4', 'add dust cloud effect.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 17 | EARTH = ['data/earth-full-view.mp4', 'add a fireworks display in the background..', NEGATIVE_PROMPT, 7.5, 1.2, 512, 30, 28, 0] 18 | EIFFELTOWER = ['data/eiffel-flyover.mp4', 'add a large fireworks display.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 19 | FERRIS = ['data/ferris-wheel-timelapse.mp4', 'Add a sunset in the background.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 20 | GOLDFISH = ['data/gold-fish.mp4', 'make the style impressionist', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 21 | ICEHOCKEY = ['data/ice-hockey.mp4', 'make the players to cartoon characters.', NEGATIVE_PROMPT, 10, 1.5, 512, 6, 28, 0] 22 | MIAMISURF = ['data/miami-surf.mp4', 'change the background to wave pool.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 23 | RAINDROP = ['data/raindrops.mp4', 'Make the style expressionism.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 24 | REDROSE_BKG = ['data/red-roses-sunny-day.mp4', 'make background to moonlight.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 25 | REDROSE_STY = ['data/red-roses-sunny-day.mp4', 'Make the style origami.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0] 26 | SWAN_OBJ = ['data/swans.mp4', 'change swans to pink flamingos.', NEGATIVE_PROMPT, 7.5, 1.2, 512, 6, 28, 0] 27 | 28 | class VideoTransfer: 29 | def __init__(self, config_path, model_ckpt, device='cuda'): 30 | self.config_path = config_path 31 | self.model_ckpt = model_ckpt 32 | self.device = device 33 | self.diffusion_model = None 34 | self.pipe = None 35 | 36 | def _init_pipe(self): 37 | diffusion_model = unit_test_create_model(self.config_path, device=self.device) 38 | ckpt = torch.load(self.model_ckpt, map_location='cpu') 39 | diffusion_model.load_state_dict(ckpt, strict=False) 40 | self.diffusion_model = diffusion_model 41 | self.pipe = InferenceIP2PVideoOpticalFlow( 42 | unet = diffusion_model.unet, 43 | num_ddim_steps=20, 44 | scheduler='ddpm' 45 | ) 46 | 47 | def get_batch(self, video_path, video_sample_rate, num_frames, image_size, start_frame=0): 48 | dataset = SingleVideoDataset( 49 | video_file=video_path, 50 | video_description='', 51 | sampling_fps=video_sample_rate, 52 | num_frames=num_frames, 53 | output_size=(image_size, image_size) 54 | ) 55 | batch = dataset[start_frame] 56 | batch = {k: v.to(self.device)[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()} 57 | return batch 58 | 59 | @staticmethod 60 | def split_batch(cond, frames_in_batch=16, num_ref_frames=4): 61 | frames_in_following_batch = frames_in_batch - num_ref_frames 62 | conds = [cond[:, :frames_in_batch]] 63 | frame_ptr = frames_in_batch 64 | num_ref_frames_each_batch = [] 65 | 66 | while frame_ptr < cond.shape[1]: 67 | remaining_frames = cond.shape[1] - frame_ptr 68 | if remaining_frames < frames_in_batch: 69 | frames_in_following_batch = remaining_frames 70 | else: 71 | frames_in_following_batch = frames_in_batch - num_ref_frames 72 | this_ref_frames = frames_in_batch - frames_in_following_batch 73 | conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch]) 74 | frame_ptr += frames_in_following_batch 75 | num_ref_frames_each_batch.append(this_ref_frames) 76 | 77 | return conds, 
num_ref_frames_each_batch 78 | 79 | def get_splitted_batch_and_conds(self, batch, edit_prompt, negative_prompt): 80 | diffusion_model = self.diffusion_model 81 | cond = [diffusion_model.encode_image_to_latent(frames) / 0.18215 for frames in batch['frames'].chunk(16, dim=1)] # Encode frames in chunks of 16 to avoid running the VAE out of memory; reduce the chunk size on smaller GPUs 82 | cond = torch.cat(cond, dim=1) 83 | text_cond = diffusion_model.encode_text([edit_prompt]) 84 | text_uncond = diffusion_model.encode_text([negative_prompt]) 85 | conds, num_ref_frames_each_batch = self.split_batch(cond, frames_in_batch=16, num_ref_frames=4) 86 | splitted_frames, _ = self.split_batch(batch['frames'], frames_in_batch=16, num_ref_frames=4) 87 | return conds, num_ref_frames_each_batch, text_cond, text_uncond, splitted_frames 88 | 89 | def transfer_video(self, video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, video_sample_rate, num_frames, start_frame): 90 | # TODO: support seed 91 | video_name = os.path.basename(video_path).split('.')[0] 92 | output_file_id = f'{video_name}_{text_cfg}_{video_cfg}_{resolution}_{video_sample_rate}_{num_frames}' 93 | if self.pipe is None: 94 | self._init_pipe() 95 | 96 | batch = self.get_batch(video_path, video_sample_rate, num_frames, resolution, start_frame) 97 | conds, num_ref_frames_each_batch, text_cond, text_uncond, splitted_frames = self.get_splitted_batch_and_conds(batch, edit_prompt, negative_prompt) 98 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16): 99 | # First video clip 100 | cond1 = conds[0] 101 | latent_pred_list = [] 102 | init_latent = torch.randn_like(cond1) 103 | latent_pred = self.pipe( 104 | latent = init_latent, 105 | text_cond = text_cond, 106 | text_uncond = text_uncond, 107 | img_cond = cond1, 108 | text_cfg = text_cfg, 109 | img_cfg = video_cfg, 110 | )['latent'] 111 | latent_pred_list.append(latent_pred) 112 | 113 | 114 | # Subsequent video clips 115 | for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip( 116 | conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch 117 | ): 118 | init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1) 119 | cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1) 120 | 121 | # additional kwargs for using motion compensation 122 | ref_images = prev_frame[:, -num_ref_frames_:] 123 | query_images = curr_frame 124 | additional_kwargs = { 125 | 'ref_images': ref_images, 126 | 'query_images': query_images, 127 | } 128 | 129 | latent_pred = self.pipe.second_clip_forward( 130 | latent = init_latent, 131 | text_cond = text_cond, 132 | text_uncond = text_uncond, 133 | img_cond = cond_, 134 | latent_ref = latent_pred[:, -num_ref_frames_:], 135 | noise_correct_step = 0.6, 136 | text_cfg = text_cfg, 137 | img_cfg = video_cfg, 138 | **additional_kwargs, 139 | )['latent'] 140 | latent_pred_list.append(latent_pred[:, num_ref_frames_:]) 141 | 142 | # Save GIF 143 | original_images = batch['frames'].cpu() 144 | latent_pred = torch.cat(latent_pred_list, dim=1) 145 | image_pred = self.diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1) 146 | transferred_images = image_pred.float().cpu() 147 | save_tensor_to_gif(original_images, f'gradio_cache/{output_file_id}_original.gif', fps=5) 148 | save_tensor_to_gif(transferred_images, f'gradio_cache/{output_file_id}.gif', fps=5) 149 | return f'gradio_cache/{output_file_id}_original.gif', f'gradio_cache/{output_file_id}.gif' 150 | 151 |
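# Usage sketch (illustrative, not part of the original script): the class above can be
# driven without the Gradio UI, assuming the same inference config and an `insv2v.pth`
# checkpoint in the working directory, e.g.
#   vt = VideoTransfer('configs/instruct_v2v_inference.yaml', 'insv2v.pth', device='cuda')
#   original_gif, edited_gif = vt.transfer_video(
#       'data/car-turn.mp4', 'Change the car to a red Porsche and make the background beach.',
#       NEGATIVE_PROMPT, text_cfg=10, video_cfg=1.1, resolution=512,
#       video_sample_rate=10, num_frames=28, start_frame=20,
#   )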
video_transfer = VideoTransfer( 152 | config_path = 'configs/instruct_v2v_inference.yaml', 153 | model_ckpt = 'insv2v.pth', 154 | device = 'cuda', 155 | ) 156 | 157 | def transfer_video(video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, video_sample_rate, num_frames, start_frame): 158 | transferred_video_path = video_transfer.transfer_video( 159 | video_path = video_path, 160 | edit_prompt = edit_prompt, 161 | negative_prompt = negative_prompt, 162 | text_cfg = float(text_cfg), 163 | video_cfg = float(video_cfg), 164 | resolution = int(resolution), 165 | video_sample_rate = int(video_sample_rate), 166 | num_frames = int(num_frames), 167 | start_frame = int(start_frame), 168 | ) 169 | return transferred_video_path # a gif image 170 | 171 | with gr.Blocks() as demo: 172 | with gr.Row(): 173 | video_source = gr.Video(label="Upload Video", interactive=True, width=384) 174 | video_input = gr.Image(type='filepath', width=384, label='Original Video') 175 | video_output = gr.Image(type='filepath', width=384, label='Edited Video') 176 | 177 | with gr.Row(): 178 | with gr.Column(scale=3): 179 | edit_prompt = gr.Textbox(label="Edit Prompt") 180 | negative_prompt = gr.Textbox(label="Negative Prompt") 181 | with gr.Column(scale=1): 182 | submit_btn = gr.Button(label="Transfer") 183 | 184 | with gr.Row(): 185 | with gr.Row(): 186 | text_cfg = gr.Textbox(label="Text classifier-free guidance", value=7.5) 187 | video_cfg = gr.Textbox(label="Video classifier-free guidance", value=1.2) 188 | resolution = gr.Textbox(label="Resolution", value=384) 189 | sample_rate = gr.Textbox(label="Video Sample Rate", value=5) 190 | num_frames = gr.Textbox(label="Number of frames", value=28) 191 | start_frame = gr.Textbox(label="Start frame index", value=0) 192 | 193 | gr.Examples( 194 | examples=[ 195 | EARTH, 196 | AUDI, 197 | DIRTROAD, 198 | CATINSUN_BKG, 199 | EIFFELTOWER, 200 | FERRIS, 201 | GOLDFISH, 202 | CARTURN, 203 | ICEHOCKEY, 204 | MIAMISURF, 205 | RAINDROP, 206 | REDROSE_BKG, 207 | REDROSE_STY, 208 | SWAN_OBJ, 209 | AIRPLANE, 210 | ], 211 | inputs=[ 212 | video_source, 213 | edit_prompt, 214 | negative_prompt, 215 | text_cfg, 216 | video_cfg, 217 | resolution, 218 | sample_rate, 219 | num_frames, 220 | start_frame, 221 | ], 222 | ) 223 | 224 | submit_btn.click( 225 | transfer_video, 226 | inputs=[ 227 | video_source, 228 | edit_prompt, 229 | negative_prompt, 230 | text_cfg, 231 | video_cfg, 232 | resolution, 233 | sample_rate, 234 | num_frames, 235 | start_frame, 236 | ], 237 | outputs=[ 238 | video_input, 239 | video_output 240 | ], 241 | ) 242 | 243 | demo.queue(concurrency_count=1).launch(share=True) 244 | -------------------------------------------------------------------------------- /insv2v_run_loveu_tgve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from dataset.loveu_tgve_dataset import LoveuTgveVideoDataset 4 | from matplotlib import pyplot as plt 5 | from misc_utils.train_utils import unit_test_create_model 6 | from itertools import product 7 | from pl_trainer.inference.inference import InferenceIP2PVideo, InferenceIP2PVideoOpticalFlow 8 | import json 9 | import argparse 10 | from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images 11 | 12 | def split_batch(cond, frames_in_batch=16, num_ref_frames=4): 13 | frames_in_following_batch = frames_in_batch - num_ref_frames 14 | conds = [cond[:, :frames_in_batch]] 15 | frame_ptr = frames_in_batch 16 | num_ref_frames_each_batch = [] 17 | 18 | 
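# Worked example (added for illustration): with 32 frames, frames_in_batch=16 and
# num_ref_frames=4, the returned clips contain 16, 12 and 4 new frames and
# num_ref_frames_each_batch becomes [4, 12]; each later clip is filled back up to
# 16 frames with reference frames carried over from the previous clip.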
while frame_ptr < cond.shape[1]: 19 | remaining_frames = cond.shape[1] - frame_ptr 20 | if remaining_frames < frames_in_batch: 21 | frames_in_following_batch = remaining_frames 22 | else: 23 | frames_in_following_batch = frames_in_batch - num_ref_frames 24 | this_ref_frames = frames_in_batch - frames_in_following_batch 25 | conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch]) 26 | frame_ptr += frames_in_following_batch 27 | num_ref_frames_each_batch.append(this_ref_frames) 28 | 29 | return conds, num_ref_frames_each_batch 30 | 31 | parser = argparse.ArgumentParser(description='Run InsV2V instruction-based video editing on the LOVEU-TGVE dataset') 32 | 33 | # Add arguments 34 | parser.add_argument('--text-cfg', nargs='+', type=float, default=[7.5], help='Text classifier-free guidance scale(s)') 35 | parser.add_argument('--video-cfg', nargs='+', type=float, default=[1.8], help='Video (image-condition) classifier-free guidance scale(s)') 36 | parser.add_argument('--num-frames', nargs='+', type=int, default=[32], help='Number of frames') 37 | parser.add_argument('--image-size', nargs='+', type=int, default=[384], help='Image size') 38 | parser.add_argument('--prompt-source', type=str, default='edit', help="Prompt source: 'edit' (edit instructions) or 'original' (TGVE captions)") 39 | parser.add_argument('--ckpt-path', type=str, help='Path to checkpoint') 40 | parser.add_argument('--config-path', type=str, default='configs/instruct_v2v.yaml', help='Path to config file') 41 | parser.add_argument('--data-dir', type=str, default='loveu-tgve-2023', help='Path to LOVEU dataset') 42 | parser.add_argument('--with_optical_flow', action='store_true', help='Use motion compensation') 43 | 44 | # Parse arguments 45 | args = parser.parse_args() 46 | 47 | TEXT_CFGS = args.text_cfg 48 | VIDEO_CFGS = args.video_cfg 49 | NUM_FRAMES = args.num_frames 50 | IMAGE_SIZE = args.image_size 51 | 52 | PROMPT_SOURCE = args.prompt_source 53 | DATA_ROOT = args.data_dir 54 | 55 | config_path = args.config_path 56 | ckpt_path = args.ckpt_path 57 | 58 | diffusion_model = unit_test_create_model(config_path) 59 | 60 | ckpt = torch.load(ckpt_path, map_location='cpu') 61 | ckpt = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()} 62 | diffusion_model.load_state_dict(ckpt, strict=False) 63 | 64 | if args.with_optical_flow: 65 | inf_pipe = InferenceIP2PVideoOpticalFlow( 66 | unet = diffusion_model.unet, 67 | num_ddim_steps=20, 68 | scheduler='ddpm' 69 | ) 70 | else: 71 | inf_pipe = InferenceIP2PVideo( 72 | unet = diffusion_model.unet, 73 | num_ddim_steps=20, 74 | scheduler='ddpm' 75 | ) 76 | 77 | frames_in_batch = 16 78 | num_ref_frames = 4 79 | 80 | edit_prompt_file = 'dataset/loveu_tgve_edit_prompt_dict.json' 81 | edit_prompt_dict = json.load(open(edit_prompt_file, 'r')) 82 | 83 | for VIDEO_ID, text_cfg, video_cfg, num_frames, image_size in product(range(len(edit_prompt_dict)), TEXT_CFGS, VIDEO_CFGS, NUM_FRAMES, IMAGE_SIZE): 84 | dataset = LoveuTgveVideoDataset( 85 | root_dir=DATA_ROOT, 86 | image_size=(image_size, image_size), 87 | ) 88 | 89 | batch = dataset[VIDEO_ID] 90 | video_name = batch['video_name'] 91 | num_video_frames = len(batch['frames']) 92 | if num_video_frames > num_frames: 93 | frame_skip = num_video_frames // num_frames 94 | else: 95 | frame_skip = 1 96 | batch = {k: v[::frame_skip].cuda()[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()} 97 | 98 | cond = diffusion_model.encode_image_to_latent(batch['frames']) / 0.18215 99 | text_uncond = diffusion_model.encode_text(['']) 100 | 101 | for prompt_key in ['style', 'object', 'background', 'multiple']: 102 | final_prompt = batch[prompt_key] 103 | if
PROMPT_SOURCE == 'edit': 104 | prompt = edit_prompt_dict[video_name]['edit_' + prompt_key] 105 | out_folder = f'v2v_results/edit_prompt/loveu_tgve_{image_size}/gif/VID_{VIDEO_ID}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}' 106 | image_output_dir = f'v2v_results/edit_prompt/loveu_tgve_{image_size}/images_{num_frames}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}/{video_name}/{prompt_key}' 107 | elif PROMPT_SOURCE == 'original': 108 | prompt = batch[prompt_key] 109 | out_folder = f'v2v_results/original_prompt/loveu_tgve_{image_size}/gif/VID_{VIDEO_ID}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}' 110 | image_output_dir = f'v2v_results/original_prompt/loveu_tgve_{image_size}/images_{num_frames}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}/{video_name}/{prompt_key}' 111 | 112 | text = '_'.join(final_prompt.split(' ')) 113 | output_path = f'{out_folder}/{prompt_key}_{num_frames}_{text}.gif' 114 | if os.path.exists(output_path): 115 | print(f'File {output_path} exists, skip') 116 | continue 117 | 118 | text_cond = diffusion_model.encode_text(prompt) 119 | conds, num_ref_frames_each_batch = split_batch(cond, frames_in_batch=frames_in_batch, num_ref_frames=num_ref_frames) 120 | splitted_frames, _ = split_batch(batch['frames'], frames_in_batch=frames_in_batch, num_ref_frames=num_ref_frames) 121 | 122 | # First video clip 123 | cond1 = conds[0] 124 | latent_pred_list = [] 125 | init_latent = torch.randn_like(cond1) 126 | latent_pred = inf_pipe( 127 | latent = init_latent, 128 | text_cond = text_cond, 129 | text_uncond = text_uncond, 130 | img_cond = cond1, 131 | text_cfg = text_cfg, 132 | img_cfg = video_cfg, 133 | )['latent'] 134 | latent_pred_list.append(latent_pred) 135 | 136 | 137 | # Subsequent video clips 138 | for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch): 139 | init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1) 140 | cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1) 141 | if args.with_optical_flow: 142 | ref_images = prev_frame[:, -num_ref_frames_:] 143 | query_images = curr_frame 144 | additional_kwargs = { 145 | 'ref_images': ref_images, 146 | 'query_images': query_images, 147 | } 148 | else: 149 | additional_kwargs = {} 150 | latent_pred = inf_pipe.second_clip_forward( 151 | latent = init_latent, 152 | text_cond = text_cond, 153 | text_uncond = text_uncond, 154 | img_cond = cond_, 155 | latent_ref = latent_pred[:, -num_ref_frames_:], 156 | noise_correct_step = 0.5, 157 | text_cfg = text_cfg, 158 | img_cfg = video_cfg, 159 | **additional_kwargs, 160 | )['latent'] 161 | latent_pred_list.append(latent_pred[:, num_ref_frames_:]) 162 | 163 | # Save GIF 164 | latent_pred = torch.cat(latent_pred_list, dim=1) 165 | image_pred = diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1) 166 | 167 | original_images = batch['frames'].cpu() 168 | transferred_images = image_pred.float().cpu() 169 | concat_images = torch.cat([original_images, transferred_images], dim=4) 170 | 171 | save_tensor_to_gif(concat_images, output_path, fps=5) 172 | save_tensor_to_images(transferred_images, image_output_dir) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from omegaconf import OmegaConf 5 | from pytorch_lightning import Trainer, seed_everything 6 | from 
misc_utils.train_utils import get_models, get_DDPM, get_logger, get_callbacks, get_dataset 7 | 8 | if __name__ == '__main__': 9 | # seed_everything(42) 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | '-c', '--config', type=str, 13 | default='config/train.json') 14 | parser.add_argument( 15 | '-r', '--resume', action="store_true" 16 | ) 17 | parser.add_argument( 18 | '-n', '--nnode', type=int, default=1 19 | ) 20 | parser.add_argument( 21 | '--ckpt', type=str, default=None 22 | ) 23 | parser.add_argument( 24 | '--manual_load', action="store_true" 25 | ) 26 | 27 | ''' parser configs ''' 28 | args_raw = parser.parse_args() 29 | args = OmegaConf.load(args_raw.config) 30 | args.update(vars(args_raw)) 31 | expt_path = os.path.join(args.expt_dir, args.expt_name) 32 | os.makedirs(expt_path, exist_ok=True) 33 | 34 | '''1. create denoising model''' 35 | models = get_models(args) 36 | 37 | diffusion_configs = args.diffusion 38 | ddpm_model = get_DDPM( 39 | diffusion_configs=diffusion_configs, 40 | log_args=args, 41 | **models 42 | ) 43 | 44 | '''2. dataset and dataloader''' 45 | train_loader, val_loader, train_set, val_set = get_dataset(args) 46 | 47 | '''3. create callbacks''' 48 | wandb_logger = get_logger(args) 49 | callbacks = get_callbacks(args, wandb_logger) 50 | 51 | '''4. trainer''' 52 | trainer_args = { 53 | "max_epochs": 100, 54 | "accelerator": "gpu", 55 | "devices": [0], 56 | "limit_val_batches": 1, 57 | "strategy": "ddp", 58 | "check_val_every_n_epoch": 1, 59 | "num_nodes": args.nnode 60 | # "benchmark" :True 61 | } 62 | config_trainer_args = args.trainer_args if args.get('trainer_args') is not None else {} 63 | trainer_args.update(config_trainer_args) 64 | print(f'Training args are {trainer_args}') 65 | trainer = Trainer( 66 | logger = wandb_logger, 67 | callbacks = callbacks, 68 | **trainer_args 69 | ) 70 | '''5. 
start training''' 71 | if args['resume']: 72 | print('INFO: Try to resume from checkpoint') 73 | if args['ckpt'] is not None: 74 | ckpt_path = args['ckpt'] 75 | else: 76 | ckpt_path = os.path.join(expt_path, 'last.ckpt') 77 | if os.path.exists(ckpt_path): 78 | print(f'INFO: Found checkpoint {ckpt_path}') 79 | if args['manual_load']: 80 | print('INFO: Manually load checkpoint') 81 | ckpt = torch.load(ckpt_path, map_location='cpu') 82 | ddpm_model.load_state_dict(ckpt['state_dict']) 83 | ckpt_path = None # do not need to load checkpoint in Trainer 84 | else: 85 | ckpt_path = None 86 | else: 87 | ckpt_path = None 88 | trainer.fit( 89 | ddpm_model, train_loader, val_loader, 90 | ckpt_path=ckpt_path 91 | ) 92 | -------------------------------------------------------------------------------- /misc_utils/clip_similarity.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/timothybrooks/instruct-pix2pix/blob/main/metrics/clip_similarity.py 2 | 3 | import clip 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | 10 | class ClipSimilarity(nn.Module): 11 | def __init__(self, name: str = "ViT-L/14"): 12 | super().__init__() 13 | assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip 14 | self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224) 15 | 16 | self.model, _ = clip.load(name, device="cpu", download_root="./") 17 | self.model.eval().requires_grad_(False) 18 | 19 | self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073))) 20 | self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711))) 21 | 22 | def encode_text(self, text: list[str]) -> torch.Tensor: 23 | text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device) 24 | text_features = self.model.encode_text(text) 25 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 26 | return text_features 27 | 28 | def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1]. 
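# Resize to the CLIP input resolution, normalize with CLIP's channel-wise mean/std,
# then encode and L2-normalize the features (mirroring encode_text above).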
29 | image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False) 30 | image = image - rearrange(self.mean, "c -> 1 c 1 1") 31 | image = image / rearrange(self.std, "c -> 1 c 1 1") 32 | image_features = self.model.encode_image(image) 33 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 34 | return image_features 35 | 36 | def forward( 37 | self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str] 38 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 39 | image_features_0 = self.encode_image(image_0) 40 | image_features_1 = self.encode_image(image_1) 41 | text_features_0 = self.encode_text(text_0) 42 | text_features_1 = self.encode_text(text_1) 43 | sim_0 = F.cosine_similarity(image_features_0, text_features_0) 44 | sim_1 = F.cosine_similarity(image_features_1, text_features_1) 45 | sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0) 46 | sim_image = F.cosine_similarity(image_features_0, image_features_1) 47 | return sim_0, sim_1, sim_direction, sim_image -------------------------------------------------------------------------------- /misc_utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Usage: 3 | 4 | from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample 5 | image = load_image_as_tensor('hamburger_pic.jpeg', image_size) 6 | flow_estimator = RAFTFlow() 7 | res = generate_sample( 8 | image, 9 | flow_estimator, 10 | distortion_scale=distortion_scale, 11 | ) 12 | f1 = res['input'][None] 13 | f2 = res['target'][None] 14 | flow = res['flow'][None] 15 | f1_warp = warp_image(f1, flow) 16 | show_image(f1_warp[0]) 17 | show_image(f2[0]) 18 | ''' 19 | import torch 20 | import torch.nn.functional as F 21 | import torchvision.transforms.functional as TF 22 | from torchvision.models.optical_flow import raft_large, Raft_Large_Weights 23 | import numpy as np 24 | 25 | def warp_image(image, flow, mode='bilinear'): 26 | """ Warp an image using optical flow. 27 | Args: 28 | image (torch.Tensor): Input image tensor with shape (N, C, H, W). 29 | flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W). 30 | Returns: 31 | warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W). 32 | """ 33 | # check shape 34 | if len(image.shape) == 3: 35 | image = image.unsqueeze(0) 36 | if len(flow.shape) == 3: 37 | flow = flow.unsqueeze(0) 38 | if image.device != flow.device: 39 | flow = flow.to(image.device) 40 | assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.' 41 | assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.' 
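# Note: this is backward warping -- each output pixel p samples the input image at
# p + flow(p), so the flow should be defined on the output (target) grid and point
# back into the source image being warped.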
42 | # Generate a grid of sampling points 43 | grid = torch.tensor( 44 | np.array(np.meshgrid(range(image.shape[3]), range(image.shape[2]), indexing='xy')), 45 | dtype=torch.float32, device=image.device 46 | )[None] 47 | grid = grid.permute(0, 2, 3, 1).repeat(image.shape[0], 1, 1, 1) # (N, H, W, 2) 48 | grid += flow.permute(0, 2, 3, 1) # add optical flow to grid 49 | 50 | # Normalize grid to [-1, 1] 51 | grid[:, :, :, 0] = 2 * (grid[:, :, :, 0] / (image.shape[3] - 1) - 0.5) 52 | grid[:, :, :, 1] = 2 * (grid[:, :, :, 1] / (image.shape[2] - 1) - 0.5) 53 | 54 | # Sample input image using the grid 55 | warped_image = F.grid_sample(image, grid, mode=mode, align_corners=True) 56 | 57 | return warped_image 58 | 59 | def resize_flow(flow, size): 60 | """ 61 | Resize optical flow tensor to a new size. 62 | 63 | Args: 64 | flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W). 65 | size (tuple[int, int]): Target size as a tuple (H, W). 66 | 67 | Returns: 68 | flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W). 69 | """ 70 | # Unpack the target size 71 | H, W = size 72 | 73 | # Compute the scaling factors 74 | h, w = flow.shape[2:] 75 | scale_x = W / w 76 | scale_y = H / h 77 | 78 | # Scale the optical flow by the resizing factors 79 | flow_scaled = flow.clone() 80 | flow_scaled[:, 0] *= scale_x 81 | flow_scaled[:, 1] *= scale_y 82 | 83 | # Resize the optical flow to the new size (H, W) 84 | flow_resized = F.interpolate(flow_scaled, size=(H, W), mode='bilinear', align_corners=False) 85 | 86 | return flow_resized 87 | 88 | def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor: 89 | """ 90 | Check the consistency of two optical flows. 91 | flow1: (B, 2, H, W) 92 | flow2: (B, 2, H, W) 93 | if want the output to be forward flow, then flow1 is the forward flow and flow2 is the backward flow 94 | return: (H, W) 95 | """ 96 | device = flow1.device 97 | height, width = flow1.shape[2:] 98 | 99 | kernel_x = torch.tensor([[0.5, 0, -0.5]]).unsqueeze(0).unsqueeze(0).to(device) 100 | kernel_y = torch.tensor([[0.5], [0], [-0.5]]).unsqueeze(0).unsqueeze(0).to(device) 101 | grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1)) 102 | grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0)) 103 | 104 | motion_edge = (grad_x * grad_x + grad_y * grad_y).sum(dim=1).squeeze(0) 105 | 106 | ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy') 107 | bx, by = ax + flow1[:, 0], ay + flow1[:, 1] 108 | 109 | x1, y1 = torch.floor(bx).long(), torch.floor(by).long() 110 | x2, y2 = x1 + 1, y1 + 1 111 | x1 = torch.clamp(x1, 0, width - 1) 112 | x2 = torch.clamp(x2, 0, width - 1) 113 | y1 = torch.clamp(y1, 0, height - 1) 114 | y2 = torch.clamp(y2, 0, height - 1) 115 | 116 | alpha_x, alpha_y = bx - x1.float(), by - y1.float() 117 | 118 | a = (1.0 - alpha_x) * flow2[:, 0, y1, x1] + alpha_x * flow2[:, 0, y1, x2] 119 | b = (1.0 - alpha_x) * flow2[:, 0, y2, x1] + alpha_x * flow2[:, 0, y2, x2] 120 | u = (1.0 - alpha_y) * a + alpha_y * b 121 | 122 | a = (1.0 - alpha_x) * flow2[:, 1, y1, x1] + alpha_x * flow2[:, 1, y1, x2] 123 | b = (1.0 - alpha_x) * flow2[:, 1, y2, x1] + alpha_x * flow2[:, 1, y2, x2] 124 | v = (1.0 - alpha_y) * a + alpha_y * b 125 | 126 | cx, cy = bx + u, by + v 127 | u2, v2 = flow1[:, 0], flow1[:, 1] 128 | 129 | reliable = ((((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5)) & (motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 
0.002)).float() 130 | 131 | return reliable # (B, 1, H, W) 132 | 133 | 134 | class RAFTFlow(torch.nn.Module): 135 | ''' 136 | # Instantiate the RAFTFlow class 137 | raft_flow = RAFTFlow(device='cuda') 138 | 139 | # Load a pair of image frames as PyTorch tensors 140 | img1 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32) 141 | img2 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32) 142 | 143 | # Compute optical flow between the two frames 144 | (optional) image_size = (256, 256) or None 145 | flow = raft_flow.compute_flow(img1, img2, image_size) # flow will be computed at the original image size if image_size is None 146 | # this flow can be used to warp the second image to the first image 147 | 148 | # Warp the second image using the flow 149 | warped_img = warp_image(img2, flow) 150 | ''' 151 | def __init__(self, *args): 152 | """ 153 | Args: 154 | device (str): Device to run the model on ("cpu" or "cuda"). 155 | """ 156 | super().__init__(*args) 157 | weights = Raft_Large_Weights.DEFAULT 158 | self.model = raft_large(weights=weights, progress=False) 159 | self.model_transform = weights.transforms() 160 | 161 | def forward(self, img1, img2, img_size=None): 162 | """ 163 | Compute optical flow between two frames using RAFT model. 164 | 165 | Args: 166 | img1 (torch.Tensor): First frame tensor with shape (B, C, H, W). 167 | img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W). 168 | img_size (tuple): Size of the input images to be processed. 169 | 170 | Returns: 171 | flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W). 172 | """ 173 | original_size = img1.shape[2:] 174 | # Preprocess the input frames 175 | if img_size is not None: 176 | img1 = TF.resize(img1, size=img_size, antialias=False) 177 | img2 = TF.resize(img2, size=img_size, antialias=False) 178 | 179 | img1, img2 = self.model_transform(img1, img2) 180 | 181 | # Compute the optical flow using the RAFT model 182 | with torch.no_grad(): 183 | list_of_flows = self.model(img1, img2) 184 | flow = list_of_flows[-1] 185 | 186 | if img_size is not None: 187 | flow = resize_flow(flow, original_size) 188 | 189 | return flow 190 | -------------------------------------------------------------------------------- /misc_utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import numpy as np 5 | import cv2 6 | import imageio 7 | from PIL import Image 8 | import textwrap 9 | 10 | def find_nearest_Nx(size, N=32): 11 | return int(np.ceil(size / N) * N) 12 | 13 | def load_image_as_tensor(image_path, image_size): 14 | if isinstance(image_size, int): 15 | image_size = (image_size, image_size) 16 | image = cv2.imread(image_path)[..., ::-1] 17 | try: 18 | image = cv2.resize(image, image_size) 19 | except Exception as e: 20 | print(e) 21 | print(image_path) 22 | 23 | image = torch.from_numpy(np.array(image).transpose(2, 0, 1)) / 255. 
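# Returns a (3, H, W) float tensor in RGB order with values in [0, 1].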
24 | return image 25 | 26 | def show_image(image): 27 | if len(image.shape) == 4: 28 | image = image[0] 29 | plt.imshow(image.permute(1, 2, 0).detach().cpu().numpy()) 30 | plt.show() 31 | 32 | def extract_video(video_path, save_dir, sampling_fps, skip_frames=0): 33 | os.makedirs(save_dir, exist_ok=True) 34 | cap = cv2.VideoCapture(video_path) 35 | frame_skip = int(cap.get(cv2.CAP_PROP_FPS) / sampling_fps) 36 | frame_count = 0 37 | save_count = 0 38 | while True: 39 | ret, frame = cap.read() 40 | if not ret: 41 | break 42 | if frame_count < skip_frames: # skip the first N frames 43 | frame_count += 1 44 | continue 45 | if (frame_count - skip_frames) % frame_skip == 0: 46 | # Save the frame as an image file if it doesn't already exist 47 | save_path = os.path.join(save_dir, f"frame{save_count:04d}.jpg") 48 | save_count += 1 49 | if not os.path.exists(save_path): 50 | cv2.imwrite(save_path, frame) 51 | frame_count += 1 52 | cap.release() 53 | cv2.destroyAllWindows() 54 | 55 | def concatenate_frames_to_video(frame_dir, video_path, fps): 56 | os.makedirs(os.path.dirname(video_path), exist_ok=True) 57 | # Get the list of frame file names in the directory 58 | frame_files = [f for f in os.listdir(frame_dir) if f.startswith("frame")] 59 | # Sort the frame file names in ascending order 60 | frame_files.sort() 61 | # Load the first frame to get the frame size 62 | frame = cv2.imread(os.path.join(frame_dir, frame_files[0])) 63 | height, width, _ = frame.shape 64 | # Initialize the video writer 65 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 66 | out = cv2.VideoWriter(video_path, fourcc, fps, (width, height)) 67 | # Loop through the frame files and add them to the video 68 | for frame_file in frame_files: 69 | frame_path = os.path.join(frame_dir, frame_file) 70 | frame = cv2.imread(frame_path) 71 | out.write(frame) 72 | # Release the video writer 73 | out.release() 74 | 75 | def cumulative_histogram(hist): 76 | cum_hist = hist.copy() 77 | for i in range(1, len(hist)): 78 | cum_hist[i] = cum_hist[i - 1] + hist[i] 79 | return cum_hist 80 | 81 | def histogram_matching(src_img, ref_img): 82 | src_img = (src_img * 255).astype(np.uint8) 83 | ref_img = (ref_img * 255).astype(np.uint8) 84 | src_img_yuv = cv2.cvtColor(src_img, cv2.COLOR_RGB2YUV) 85 | ref_img_yuv = cv2.cvtColor(ref_img, cv2.COLOR_RGB2YUV) 86 | 87 | matched_img = np.zeros_like(src_img_yuv) 88 | for channel in range(src_img_yuv.shape[2]): 89 | src_hist, _ = np.histogram(src_img_yuv[:, :, channel].ravel(), 256, (0, 256)) 90 | ref_hist, _ = np.histogram(ref_img_yuv[:, :, channel].ravel(), 256, (0, 256)) 91 | 92 | src_cum_hist = cumulative_histogram(src_hist) 93 | ref_cum_hist = cumulative_histogram(ref_hist) 94 | 95 | lut = np.zeros(256, dtype=np.uint8) 96 | j = 0 97 | for i in range(256): 98 | while ref_cum_hist[j] < src_cum_hist[i] and j < 255: 99 | j += 1 100 | lut[i] = j 101 | 102 | matched_img[:, :, channel] = cv2.LUT(src_img_yuv[:, :, channel], lut) 103 | 104 | matched_img = cv2.cvtColor(matched_img, cv2.COLOR_YUV2RGB) 105 | matched_img = matched_img.astype(np.float32) / 255 106 | return matched_img 107 | 108 | def canny_image_batch(image_batch, low_threshold=100, high_threshold=200): 109 | if isinstance(image_batch, torch.Tensor): 110 | # [-1, 1] tensor -> [0, 255] numpy array 111 | is_torch = True 112 | device = image_batch.device 113 | image_batch = (image_batch + 1) * 127.5 114 | image_batch = image_batch.permute(0, 2, 3, 1).detach().cpu().numpy() 115 | image_batch = image_batch.astype(np.uint8) 116 | image_batch = 
np.array([cv2.Canny(image, low_threshold, high_threshold) for image in image_batch]) 117 | image_batch = image_batch[:, :, :, None] 118 | image_batch = np.concatenate([image_batch, image_batch, image_batch], axis=3) 119 | 120 | if is_torch: 121 | # [0, 255] numpy array -> [-1, 1] tensor 122 | image_batch = torch.from_numpy(image_batch).permute(0, 3, 1, 2).float() / 255. 123 | image_batch = image_batch.to(device) 124 | return image_batch 125 | 126 | 127 | def images_to_gif(images, filename, fps): 128 | os.makedirs(os.path.dirname(filename), exist_ok=True) 129 | # Normalize to 0-255 and convert to uint8 130 | images = [(img * 255).astype(np.uint8) if img.dtype == np.float32 else img for img in images] 131 | images = [Image.fromarray(img) for img in images] 132 | imageio.mimsave(filename, images, duration=1 / fps) 133 | 134 | def load_gif(image_path): 135 | import imageio 136 | gif = imageio.get_reader(image_path) 137 | np_images = np.array([frame[..., :3] for frame in gif]) 138 | return np_images 139 | 140 | def add_text_to_frame(frame, text, font_scale=1, thickness=2, color=(0, 0, 0), bg_color=(255, 255, 255), max_width=30): 141 | """ 142 | Add text to a frame. 143 | """ 144 | # Make a copy of the frame 145 | frame_with_text = np.copy(frame) 146 | # Choose font 147 | font = cv2.FONT_HERSHEY_SIMPLEX 148 | # Split text into lines if it's too long 149 | lines = textwrap.wrap(text, width=max_width) 150 | # Get total text height 151 | total_text_height = len(lines) * (thickness * font_scale + 10) + 60 * font_scale 152 | # Create an image filled with the background color, having enough space for the text 153 | text_bg_img = np.full((int(total_text_height), frame.shape[1], 3), bg_color, dtype=np.uint8) 154 | # Put each line on the text background image 155 | y = 0 156 | for line in lines: 157 | text_size, _ = cv2.getTextSize(line, font, font_scale, thickness) 158 | text_x = (text_bg_img.shape[1] - text_size[0]) // 2 159 | y += text_size[1] + 10 160 | cv2.putText(text_bg_img, line, (text_x, y), font, font_scale, color, thickness) 161 | # Append the text background image to the frame 162 | frame_with_text = np.vstack((frame_with_text, text_bg_img)) 163 | 164 | return frame_with_text 165 | 166 | def add_text_to_gif(numpy_images, text, **kwargs): 167 | """ 168 | Add text to each frame of a gif. 169 | """ 170 | # Iterate over frames and add text to each frame 171 | frames_with_text = [] 172 | for frame in numpy_images: 173 | frame_with_text = add_text_to_frame(frame, text, **kwargs) 174 | frames_with_text.append(frame_with_text) 175 | 176 | # Convert the list of frames to a numpy array 177 | numpy_images_with_text = np.array(frames_with_text) 178 | 179 | return numpy_images_with_text 180 | 181 | def pad_images_to_same_height(images): 182 | """ 183 | Pad images to the same height. 184 | """ 185 | # Find the maximum height 186 | max_height = max(img.shape[0] for img in images) 187 | 188 | # Pad each image to the maximum height 189 | padded_images = [] 190 | for img in images: 191 | pad_height = max_height - img.shape[0] 192 | padded_img = cv2.copyMakeBorder(img, 0, pad_height, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255]) 193 | padded_images.append(padded_img) 194 | 195 | return padded_images 196 | 197 | def concatenate_gifs(gifs): 198 | """ 199 | Concatenate gifs. 
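Frames are joined side by side (axis=1) after padding each to the tallest frame's
height, and the output is truncated to the shortest input gif's frame count.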
200 | """ 201 | # Ensure that all gifs have the same number of frames 202 | min_num_frames = min(gif.shape[0] for gif in gifs) 203 | gifs = [gif[:min_num_frames] for gif in gifs] 204 | 205 | # Concatenate each frame 206 | concatenated_gifs = [] 207 | for i in range(min_num_frames): 208 | # Get the i-th frame from each gif 209 | frames = [gif[i] for gif in gifs] 210 | 211 | # Pad the frames to the same height 212 | padded_frames = pad_images_to_same_height(frames) 213 | 214 | # Concatenate the padded frames 215 | concatenated_frame = np.concatenate(padded_frames, axis=1) 216 | 217 | concatenated_gifs.append(concatenated_frame) 218 | 219 | return np.array(concatenated_gifs) 220 | 221 | def stack_gifs(gifs): 222 | '''vertically stack gifs''' 223 | min_num_frames = min(gif.shape[0] for gif in gifs) 224 | stacked_gifs = [] 225 | 226 | for i in range(min_num_frames): 227 | frames = [gif[i] for gif in gifs] 228 | stacked_frame = np.concatenate(frames, axis=0) 229 | stacked_gifs.append(stacked_frame) 230 | 231 | return np.array(stacked_gifs) 232 | 233 | def save_tensor_to_gif(images, filename, fps): 234 | images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5 235 | images_to_gif(images, filename, fps) 236 | 237 | def save_tensor_to_images(images, output_dir): 238 | images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5 239 | os.makedirs(output_dir, exist_ok=True) 240 | for i in range(images.shape[0]): 241 | plt.imsave(f'{output_dir}/{i:03d}.jpg', images[i]) -------------------------------------------------------------------------------- /misc_utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import numpy as np 4 | from inspect import isfunction 5 | 6 | def instantiate_from_config(config): 7 | if not "target" in config: 8 | raise KeyError("Expected key `target` to instantiate.") 9 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 10 | 11 | 12 | def get_obj_from_str(string, reload=False): 13 | module, cls = string.rsplit(".", 1) 14 | if reload: 15 | module_imp = importlib.import_module(module) 16 | importlib.reload(module_imp) 17 | return getattr(importlib.import_module(module, package=None), cls) 18 | 19 | def exists(x): 20 | return x is not None 21 | 22 | def default(val, d): 23 | if exists(val): 24 | return val 25 | return d() if isfunction(d) else d 26 | 27 | def noise_like(shape, device, repeat=False): 28 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 29 | noise = lambda: torch.randn(shape, device=device) 30 | return repeat_noise() if repeat else noise() 31 | 32 | def extract_into_tensor(a, t, x_shape): 33 | b, *_ = t.shape 34 | out = a.gather(-1, t) 35 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 36 | 37 | def right_pad_dims_to(x, t): 38 | padding_dims = x.ndim - t.ndim 39 | if padding_dims <= 0: 40 | return t 41 | return t.view(*t.shape, *((1,) * padding_dims)) 42 | 43 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 44 | if schedule == "linear" or schedule == "scaled_linear": 45 | betas = ( 46 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 47 | ) 48 | 49 | elif schedule == "cosine": 50 | timesteps = ( 51 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 52 | ) 53 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 54 | alphas 
= torch.cos(alphas).pow(2) 55 | alphas = alphas / alphas[0] 56 | betas = 1 - alphas[1:] / alphas[:-1] 57 | betas = np.clip(betas, a_min=0, a_max=0.999) 58 | 59 | elif schedule == "sqrt_linear": 60 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 61 | elif schedule == "sqrt": 62 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 63 | else: 64 | raise ValueError(f"schedule '{schedule}' unknown.") 65 | return betas.numpy() 66 | 67 | 68 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 69 | if ddim_discr_method == 'uniform': 70 | c = num_ddpm_timesteps // num_ddim_timesteps 71 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 72 | elif ddim_discr_method == 'quad': 73 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 74 | else: 75 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 76 | 77 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 78 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 79 | steps_out = ddim_timesteps + 1 80 | if verbose: 81 | print(f'Selected timesteps for ddim sampler: {steps_out}') 82 | return steps_out 83 | 84 | 85 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 86 | # select alphas for computing the variance schedule 87 | alphas = alphacums[ddim_timesteps] 88 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 89 | 90 | # according the the formula provided in https://arxiv.org/abs/2010.02502 91 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 92 | if verbose: 93 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 94 | print(f'For the chosen value of eta, which is {eta}, ' 95 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 96 | return sigmas, alphas, alphas_prev 97 | 98 | 99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 100 | """ 101 | Create a beta schedule that discretizes the given alpha_t_bar function, 102 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 103 | :param num_diffusion_timesteps: the number of betas to produce. 104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 105 | produces the cumulative product of (1-beta) up to that 106 | part of the diffusion process. 107 | :param max_beta: the maximum beta to use; use values lower than 1 to 108 | prevent singularities. 
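Example (for illustration): the cosine schedule of Nichol & Dhariwal (2021)
corresponds to alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2.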
109 | """ 110 | betas = [] 111 | for i in range(num_diffusion_timesteps): 112 | t1 = i / num_diffusion_timesteps 113 | t2 = (i + 1) / num_diffusion_timesteps 114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 115 | return np.array(betas) -------------------------------------------------------------------------------- /misc_utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import numpy as np 4 | 5 | @dataclass 6 | class Edit: 7 | old: str 8 | new: str 9 | weight: float = 1.0 10 | 11 | 12 | @dataclass 13 | class Insert: 14 | text: str 15 | weight: float = 1.0 16 | 17 | @property 18 | def old(self): 19 | return "" 20 | 21 | @property 22 | def new(self): 23 | return self.text 24 | 25 | 26 | @dataclass 27 | class Delete: 28 | text: str 29 | weight: float = 1.0 30 | 31 | @property 32 | def old(self): 33 | return self.text 34 | 35 | @property 36 | def new(self): 37 | return "" 38 | 39 | 40 | @dataclass 41 | class Text: 42 | text: str 43 | weight: float = 1.0 44 | 45 | @property 46 | def old(self): 47 | return self.text 48 | 49 | @property 50 | def new(self): 51 | return self.text 52 | 53 | @torch.inference_mode() 54 | def get_text_embedding(prompt, tokenizer, text_encoder): 55 | text_input_ids = tokenizer( 56 | prompt, 57 | padding="max_length", 58 | truncation=True, 59 | max_length=tokenizer.model_max_length, 60 | return_tensors="pt", 61 | ).input_ids 62 | text_embeddings = text_encoder(text_input_ids.to(text_encoder.device))[0] 63 | return text_embeddings 64 | 65 | @torch.inference_mode() 66 | def encode_text(text_pieces, tokenizer, text_encoder): 67 | n_old_tokens = 0 68 | n_new_tokens = 0 69 | new_id_to_old_id = [] 70 | weights = [] 71 | for piece in text_pieces: 72 | old, new = piece.old, piece.new 73 | old_tokens = tokenizer.tokenize(old) 74 | new_tokens = tokenizer.tokenize(new) 75 | if len(old_tokens) == 0 and len(new_tokens) == 0: 76 | continue 77 | elif old == new: 78 | n_old_tokens += len(old_tokens) 79 | n_new_tokens += len(new_tokens) 80 | new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens)) 81 | elif len(old_tokens) == 0: 82 | # insert 83 | new_id_to_old_id.extend([-1] * len(new_tokens)) 84 | n_new_tokens += len(new_tokens) 85 | elif len(new_tokens) == 0: 86 | # delete 87 | n_old_tokens += len(old_tokens) 88 | else: 89 | # replace 90 | n_old_tokens += len(old_tokens) 91 | n_new_tokens += len(new_tokens) 92 | start = n_old_tokens - len(old_tokens) 93 | end = n_old_tokens 94 | ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int) 95 | new_id_to_old_id.extend(list(ids)) 96 | weights.extend([piece.weight] * len(new_tokens)) 97 | 98 | old_prompt = " ".join([piece.old for piece in text_pieces]) 99 | new_prompt = " ".join([piece.new for piece in text_pieces]) 100 | old_text_input_ids = tokenizer( 101 | old_prompt, 102 | padding="max_length", 103 | truncation=True, 104 | max_length=tokenizer.model_max_length, 105 | return_tensors="pt", 106 | ).input_ids 107 | new_text_input_ids = tokenizer( 108 | new_prompt, 109 | padding="max_length", 110 | truncation=True, 111 | max_length=tokenizer.model_max_length, 112 | return_tensors="pt", 113 | ).input_ids 114 | 115 | old_text_embeddings = text_encoder(old_text_input_ids.to(text_encoder.device))[0] 116 | new_text_embeddings = text_encoder(new_text_input_ids.to(text_encoder.device))[0] 117 | value = new_text_embeddings.clone() # batch (1), seq, dim 118 | key = 
new_text_embeddings.clone() 119 | 120 | for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)): 121 | if 0 <= j < old_text_embeddings.shape[1]: 122 | key[0, i] = old_text_embeddings[0, j] 123 | value[0, i] *= weight 124 | return key, value 125 | 126 | @torch.inference_mode() 127 | def get_text_embedding_openclip(prompt, text_encoder, device='cuda'): 128 | import open_clip 129 | text_input_ids = open_clip.tokenize(prompt) 130 | text_embeddings = text_encoder(text_input_ids.to(device)) 131 | return text_embeddings 132 | 133 | @torch.inference_mode() 134 | def encode_text_openclip(text_pieces, text_encoder, device='cuda'): 135 | import open_clip 136 | n_old_tokens = 0 137 | n_new_tokens = 0 138 | new_id_to_old_id = [] 139 | weights = [] 140 | for piece in text_pieces: 141 | old, new = piece.old, piece.new 142 | old_tokens = open_clip.tokenize(old) 143 | new_tokens = open_clip.tokenize(new) 144 | if len(old_tokens) == 0 and len(new_tokens) == 0: 145 | continue 146 | elif old == new: 147 | n_old_tokens += len(old_tokens) 148 | n_new_tokens += len(new_tokens) 149 | new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens)) 150 | elif len(old_tokens) == 0: 151 | # insert 152 | new_id_to_old_id.extend([-1] * len(new_tokens)) 153 | n_new_tokens += len(new_tokens) 154 | elif len(new_tokens) == 0: 155 | # delete 156 | n_old_tokens += len(old_tokens) 157 | else: 158 | # replace 159 | n_old_tokens += len(old_tokens) 160 | n_new_tokens += len(new_tokens) 161 | start = n_old_tokens - len(old_tokens) 162 | end = n_old_tokens 163 | ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int) 164 | new_id_to_old_id.extend(list(ids)) 165 | weights.extend([piece.weight] * len(new_tokens)) 166 | 167 | old_prompt = " ".join([piece.old for piece in text_pieces]) 168 | new_prompt = " ".join([piece.new for piece in text_pieces]) 169 | old_text_input_ids = open_clip.tokenize(old_prompt) 170 | new_text_input_ids = open_clip.tokenize(new_prompt) 171 | 172 | old_text_embeddings = text_encoder(old_text_input_ids.to(device)) 173 | new_text_embeddings = text_encoder(new_text_input_ids.to(device)) 174 | value = new_text_embeddings.clone() # batch (1), seq, dim 175 | key = new_text_embeddings.clone() 176 | 177 | for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)): 178 | if 0 <= j < old_text_embeddings.shape[1]: 179 | key[0, i] = old_text_embeddings[0, j] 180 | value[0, i] *= weight 181 | return key, value -------------------------------------------------------------------------------- /misc_utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from pytorch_lightning.loggers import WandbLogger 4 | from misc_utils.model_utils import instantiate_from_config, get_obj_from_str 5 | 6 | def get_models(args): 7 | unet = instantiate_from_config(args.unet) 8 | model_dict = { 9 | 'unet': unet, 10 | } 11 | 12 | if args.get('vae'): 13 | vae = instantiate_from_config(args.vae) 14 | model_dict['vae'] = vae 15 | 16 | if args.get('text_model'): 17 | text_model = instantiate_from_config(args.text_model) 18 | model_dict['text_model'] = text_model 19 | 20 | if args.get('ctrlnet'): 21 | ctrlnet = instantiate_from_config(args.ctrlnet) 22 | model_dict['ctrlnet'] = ctrlnet 23 | 24 | return model_dict 25 | 26 | def get_DDPM(diffusion_configs, log_args={}, **models): 27 | diffusion_model_class = diffusion_configs['target'] 28 | diffusion_args = diffusion_configs['params'] 29 | DDPM_model = 
get_obj_from_str(diffusion_model_class) 30 | ddpm_model = DDPM_model( 31 | log_args=log_args, 32 | **models, 33 | **diffusion_args 34 | ) 35 | return ddpm_model 36 | 37 | 38 | def get_logger(args): 39 | wandb_logger = WandbLogger( 40 | project=args["expt_name"], 41 | ) 42 | return wandb_logger 43 | 44 | def get_callbacks(args, wandb_logger): 45 | callbacks = [] 46 | for callback in args['callbacks']: 47 | if callback.get('require_wandb', False): 48 | # we need to pass wandb logger to the callback 49 | callback_obj = get_obj_from_str(callback.target) 50 | callbacks.append( 51 | callback_obj(wandb_logger=wandb_logger, **callback.params) 52 | ) 53 | else: 54 | callbacks.append( 55 | instantiate_from_config(callback) 56 | ) 57 | return callbacks 58 | 59 | def get_dataset(args): 60 | from torch.utils.data import DataLoader 61 | data_args = args['data'] 62 | train_set = instantiate_from_config(data_args['train']) 63 | val_set = instantiate_from_config(data_args['val']) 64 | train_loader = DataLoader( 65 | train_set, batch_size=data_args['batch_size'], shuffle=True, 66 | num_workers=4*len(args['trainer_args']['devices']), pin_memory=True 67 | ) 68 | val_loader = DataLoader( 69 | val_set, batch_size=data_args['val_batch_size'], 70 | num_workers=len(args['trainer_args']['devices']), pin_memory=True 71 | ) 72 | return train_loader, val_loader, train_set, val_set 73 | 74 | def unit_test_create_model(config_path, device=None): 75 | device = device or ('cuda' if torch.cuda.is_available() else 'cpu') 76 | conf = OmegaConf.load(config_path) 77 | models = get_models(conf) 78 | ddpm = get_DDPM(conf['diffusion'], log_args=conf, **models) 79 | ddpm = ddpm.to(device) 80 | return ddpm 81 | 82 | def unit_test_create_dataset(config_path, split='train'): 83 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 84 | conf = OmegaConf.load(config_path) 85 | train_loader, val_loader, train_set, val_set = get_dataset(conf) 86 | if split == 'train': 87 | batch = next(iter(train_loader)) 88 | else: 89 | batch = next(iter(val_loader)) 90 | for k, v in batch.items(): 91 | if isinstance(v, torch.Tensor): 92 | batch[k] = v.to(device) 93 | return batch 94 | 95 | def unit_test_training_step(config_path): 96 | ddpm = unit_test_create_model(config_path) 97 | batch = unit_test_create_dataset(config_path) 98 | res = ddpm.training_step(batch, 0) 99 | return res 100 | 101 | def unit_test_val_step(config_path): 102 | ddpm = unit_test_create_model(config_path) 103 | batch = unit_test_create_dataset(config_path, split='val') 104 | res = ddpm.validation_step(batch, 0) 105 | return res 106 | 107 | NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black
and white, black and white filter, colorless" 108 | -------------------------------------------------------------------------------- /misc_utils/video_ptp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.damo_text_to_video.unet_sd import UNetSD 3 | from misc_utils.train_utils import instantiate_from_config 4 | from omegaconf import OmegaConf 5 | from modules.damo_text_to_video.text_model import FrozenOpenCLIPEmbedder 6 | from typing import List, Tuple, Union 7 | from diff_match_patch import diff_match_patch 8 | import difflib 9 | from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete 10 | 11 | 12 | def get_models_of_damo_model( 13 | unet_config: str, 14 | unet_ckpt: str, 15 | vae_config: str, 16 | vae_ckpt: str, 17 | text_model_ckpt: str, 18 | ): 19 | vae_conf = OmegaConf.load(vae_config) 20 | unet_conf = OmegaConf.load(unet_config) 21 | 22 | vae = instantiate_from_config(vae_conf.vae) 23 | vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu')) 24 | vae = vae.half().cuda() 25 | 26 | unet = UNetSD(**unet_conf.model.model_cfg) 27 | unet.load_state_dict(torch.load(unet_ckpt, map_location='cpu')) 28 | unet = unet.half().cuda() 29 | 30 | text_model = FrozenOpenCLIPEmbedder(version=text_model_ckpt, layer='penultimate') 31 | text_model = text_model.half().cuda() 32 | 33 | return vae, unet, text_model 34 | 35 | def compute_diff_old(old_sentence: str, new_sentence: str) -> List[Tuple[Union[Text, Edit, Insert, Delete], str, str]]: 36 | dmp = diff_match_patch() 37 | diff = dmp.diff_main(old_sentence, new_sentence) 38 | dmp.diff_cleanupSemantic(diff) 39 | 40 | result = [] 41 | i = 0 42 | while i < len(diff): 43 | op, data = diff[i] 44 | if op == 0: # Equal 45 | # result.append((Text, data, data)) 46 | result.append(Text(text=data)) 47 | elif op == -1: # Delete 48 | if i + 1 < len(diff) and diff[i + 1][0] == 1: # If next operation is Insert 49 | result.append(Edit(old=data, new=diff[i + 1][1])) # Append as Edit operation 50 | i += 1 # Skip next operation because we've handled it here 51 | else: 52 | result.append(Delete(text=data)) 53 | elif op == 1: # Insert 54 | if i == 0 or diff[i - 1][0] != -1: # If previous operation wasn't Delete 55 | result.append(Insert(text=data)) 56 | i += 1 57 | 58 | return result 59 | 60 | def compute_diff(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]: 61 | differ = difflib.Differ() 62 | diff = list(differ.compare(old_sentence.split(), new_sentence.split())) 63 | 64 | result = [] 65 | i = 0 66 | while i < len(diff): 67 | if diff[i][0] == ' ': # Equal 68 | equal_words = [diff[i][2:]] 69 | while i + 1 < len(diff) and diff[i + 1][0] == ' ': 70 | i += 1 71 | equal_words.append(diff[i][2:]) 72 | result.append(Text(text=' '.join(equal_words))) 73 | elif diff[i][0] == '-': # Delete 74 | deleted_words = [diff[i][2:]] 75 | while i + 1 < len(diff) and diff[i + 1][0] == '-': 76 | i += 1 77 | deleted_words.append(diff[i][2:]) 78 | result.append(Delete(text=' '.join(deleted_words))) 79 | elif diff[i][0] == '+': # Insert 80 | inserted_words = [diff[i][2:]] 81 | while i + 1 < len(diff) and diff[i + 1][0] == '+': 82 | i += 1 83 | inserted_words.append(diff[i][2:]) 84 | result.append(Insert(text=' '.join(inserted_words))) 85 | i += 1 86 | 87 | # Post-process to merge adjacent inserts and deletes into edits 88 | i = 0 89 | while i < len(result) - 1: 90 | if isinstance(result[i], Delete) and isinstance(result[i+1], 
Insert): 91 | result[i:i+2] = [Edit(old=result[i].text, new=result[i+1].text)] 92 | elif isinstance(result[i], Insert) and isinstance(result[i+1], Delete): 93 | result[i:i+2] = [Edit(old=result[i+1].text, new=result[i].text)] 94 | else: 95 | i += 1 96 | 97 | return result -------------------------------------------------------------------------------- /modules/damo_text_to_video/configuration.json: -------------------------------------------------------------------------------- 1 | { "framework": "pytorch", 2 | "task": "text-to-video-synthesis", 3 | "model": { 4 | "type": "latent-text-to-video-synthesis", 5 | "model_args": { 6 | "ckpt_clip": "open_clip_pytorch_model.bin", 7 | "ckpt_unet": "text2video_pytorch_model.pth", 8 | "ckpt_autoencoder": "VQGAN_autoencoder.pth", 9 | "max_frames": 16, 10 | "tiny_gpu": 1 11 | }, 12 | "model_cfg": { 13 | "in_dim": 4, 14 | "dim": 320, 15 | "y_dim": 768, 16 | "context_dim": 1024, 17 | "out_dim": 4, 18 | "dim_mult": [1, 2, 4, 4], 19 | "num_heads": 8, 20 | "head_dim": 64, 21 | "num_res_blocks": 2, 22 | "attn_scales": [1, 0.5, 0.25], 23 | "dropout": 0.1, 24 | "temporal_attention": "True", 25 | "use_checkpoint": "True" 26 | } 27 | }, 28 | "pipeline": { 29 | "type": "latent-text-to-video-synthesis" 30 | } 31 | } -------------------------------------------------------------------------------- /modules/damo_text_to_video/text_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open_clip 3 | 4 | class FrozenOpenCLIPEmbedder(torch.nn.Module): 5 | """ 6 | Uses the OpenCLIP transformer encoder for text 7 | """ 8 | LAYERS = ['last', 'penultimate'] 9 | 10 | def __init__(self, 11 | arch='ViT-H-14', 12 | version='open_clip_pytorch_model.bin', 13 | device='cuda', 14 | max_length=77, 15 | freeze=True, 16 | layer='last'): 17 | super().__init__() 18 | assert layer in self.LAYERS 19 | model, _, _ = open_clip.create_model_and_transforms( 20 | arch, device=torch.device('cpu'), pretrained=version) 21 | del model.visual 22 | self.model = model 23 | 24 | self.device = device 25 | self.max_length = max_length 26 | if freeze: 27 | self.freeze() 28 | self.layer = layer 29 | if self.layer == 'last': 30 | self.layer_idx = 0 31 | elif self.layer == 'penultimate': 32 | self.layer_idx = 1 33 | else: 34 | raise NotImplementedError() 35 | 36 | def freeze(self): 37 | self.model = self.model.eval() 38 | for param in self.parameters(): 39 | param.requires_grad = False 40 | 41 | def forward(self, text): 42 | tokens = open_clip.tokenize(text) 43 | z = self.encode_with_transformer(tokens.to(self.device)) 44 | return z 45 | 46 | def encode_with_transformer(self, text): 47 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 48 | x = x + self.model.positional_embedding 49 | x = x.permute(1, 0, 2) # NLD -> LND 50 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 51 | x = x.permute(1, 0, 2) # LND -> NLD 52 | x = self.model.ln_final(x) 53 | return x 54 | 55 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 56 | for i, r in enumerate(self.model.transformer.resblocks): 57 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 58 | break 59 | x = r(x, attn_mask=attn_mask) 60 | return x 61 | 62 | def encode(self, text): 63 | return self(text) -------------------------------------------------------------------------------- /modules/kl_autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from modules.vqvae.model import Encoder, Decoder 7 | 8 | from misc_utils.model_utils import instantiate_from_config 9 | 10 | class DiagonalGaussianDistribution(object): 11 | def __init__(self, parameters, deterministic=False): 12 | self.parameters = parameters 13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 15 | self.deterministic = deterministic 16 | self.std = torch.exp(0.5 * self.logvar) 17 | self.var = torch.exp(self.logvar) 18 | if self.deterministic: 19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 20 | 21 | def sample(self): 22 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 23 | return x 24 | 25 | def kl(self, other=None): 26 | if self.deterministic: 27 | return torch.Tensor([0.]) 28 | else: 29 | if other is None: 30 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 31 | + self.var - 1.0 - self.logvar, 32 | dim=[1, 2, 3]) 33 | else: 34 | return 0.5 * torch.sum( 35 | torch.pow(self.mean - other.mean, 2) / other.var 36 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 37 | dim=[1, 2, 3]) 38 | 39 | def nll(self, sample, dims=[1,2,3]): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | logtwopi = np.log(2.0 * np.pi) 43 | return 0.5 * torch.sum( 44 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 45 | dim=dims) 46 | 47 | def mode(self): 48 | return self.mean 49 | 50 | class AutoencoderKL(pl.LightningModule): 51 | def __init__(self, 52 | ddconfig, 53 | lossconfig, 54 | embed_dim, 55 | ckpt_path=None, 56 | ignore_keys=[], 57 | image_key="image", 58 | colorize_nlabels=None, 59 | monitor=None, 60 | ): 61 | super().__init__() 62 | self.image_key = image_key 63 | self.encoder = Encoder(**ddconfig) 64 | self.decoder = Decoder(**ddconfig) 65 | self.loss = instantiate_from_config(lossconfig) 66 | assert ddconfig["double_z"] 67 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 68 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 69 | self.embed_dim = embed_dim 70 | if colorize_nlabels is not None: 71 | assert type(colorize_nlabels)==int 72 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 73 | if monitor is not None: 74 | self.monitor = monitor 75 | if ckpt_path is not None: 76 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu")["state_dict"] 80 | keys = list(sd.keys()) 81 | for k in keys: 82 | for ik in ignore_keys: 83 | if k.startswith(ik): 84 | print("Deleting key {} from state_dict.".format(k)) 85 | del sd[k] 86 | self.load_state_dict(sd, strict=False) 87 | print(f"Restored from {path}") 88 | 89 | def encode(self, x): 90 | h = self.encoder(x) 91 | moments = self.quant_conv(h) 92 | posterior = DiagonalGaussianDistribution(moments) 93 | # TODO check if need to put sample into DDIM_ldm class 94 | enc = posterior.sample() 95 | return enc #posterior 96 | 97 | def decode(self, z): 98 | z = self.post_quant_conv(z) 99 | dec = self.decoder(z) 100 | return dec 101 | 102 | def forward(self, input, sample_posterior=True): 103 | posterior = self.encode(input) 104 | if sample_posterior: 105 | z = posterior.sample() 106 | else: 107 | z = posterior.mode() 108 | dec = self.decode(z) 109 | return dec, posterior 110 | 111 | def 
get_input(self, batch, k): 112 | x = batch[k] 113 | if len(x.shape) == 3: 114 | x = x[..., None] 115 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 116 | return x 117 | 118 | def training_step(self, batch, batch_idx, optimizer_idx): 119 | inputs = self.get_input(batch, self.image_key) 120 | reconstructions, posterior = self(inputs) 121 | 122 | if optimizer_idx == 0: 123 | # train encoder+decoder+logvar 124 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 125 | last_layer=self.get_last_layer(), split="train") 126 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return aeloss 129 | 130 | if optimizer_idx == 1: 131 | # train the discriminator 132 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 133 | last_layer=self.get_last_layer(), split="train") 134 | 135 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 136 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 137 | return discloss 138 | 139 | def validation_step(self, batch, batch_idx): 140 | inputs = self.get_input(batch, self.image_key) 141 | reconstructions, posterior = self(inputs) 142 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 143 | last_layer=self.get_last_layer(), split="val") 144 | 145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 146 | last_layer=self.get_last_layer(), split="val") 147 | 148 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 149 | self.log_dict(log_dict_ae) 150 | self.log_dict(log_dict_disc) 151 | return self.log_dict 152 | 153 | def configure_optimizers(self): 154 | lr = self.learning_rate 155 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 156 | list(self.decoder.parameters())+ 157 | list(self.quant_conv.parameters())+ 158 | list(self.post_quant_conv.parameters()), 159 | lr=lr, betas=(0.5, 0.9)) 160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 161 | lr=lr, betas=(0.5, 0.9)) 162 | return [opt_ae, opt_disc], [] 163 | 164 | def get_last_layer(self): 165 | return self.decoder.conv_out.weight 166 | 167 | @torch.no_grad() 168 | def log_images(self, batch, only_inputs=False, **kwargs): 169 | log = dict() 170 | x = self.get_input(batch, self.image_key) 171 | x = x.to(self.device) 172 | if not only_inputs: 173 | xrec, posterior = self(x) 174 | if x.shape[1] > 3: 175 | # colorize with random projection 176 | assert xrec.shape[1] > 3 177 | x = self.to_rgb(x) 178 | xrec = self.to_rgb(xrec) 179 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 180 | log["reconstructions"] = xrec 181 | log["inputs"] = x 182 | return log 183 | 184 | def to_rgb(self, x): 185 | assert self.image_key == "segmentation" 186 | if not hasattr(self, "colorize"): 187 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 188 | x = F.conv2d(x, weight=self.colorize) 189 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
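# the random 3-channel projection above is min-max rescaled to [-1, 1] before being returned (used by log_images for segmentation inputs)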
190 | return x -------------------------------------------------------------------------------- /modules/openclip/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | 8 | import open_clip 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
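# (77 is the typical value, matching the max_length=77 defaults of the CLIP-based encoders below)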
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): 134 | state_dict.pop("transformer.text_model.embeddings.position_ids") # it seems that this is removed from the model in recent transformers versions 135 | return super().load_state_dict(state_dict, strict) 136 | 137 | 138 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 139 | """ 140 | Uses the OpenCLIP transformer encoder for text 141 | """ 142 | LAYERS = [ 143 | #"pooled", 144 | "last", 145 | "penultimate" 146 | ] 147 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 148 | freeze=True, layer="last"): 149 | super().__init__() 150 | assert layer in self.LAYERS 151 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 152 | del model.visual 153 | self.model = model 154 | 155 | self.device = device 156 | self.max_length = max_length 157 | if freeze: 158 | self.freeze() 159 | self.layer = layer 160 | if self.layer == "last": 161 | self.layer_idx = 0 162 | elif self.layer == "penultimate": 163 | self.layer_idx = 1 164 | else: 165 | raise NotImplementedError() 166 | 167 | def freeze(self): 168 | self.model = self.model.eval() 169 | for 
param in self.parameters(): 170 | param.requires_grad = False 171 | 172 | def forward(self, text): 173 | tokens = open_clip.tokenize(text) 174 | z = self.encode_with_transformer(tokens.to(self.device)) 175 | return z 176 | 177 | def encode_with_transformer(self, text): 178 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 179 | x = x + self.model.positional_embedding 180 | x = x.permute(1, 0, 2) # NLD -> LND 181 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 182 | x = x.permute(1, 0, 2) # LND -> NLD 183 | x = self.model.ln_final(x) 184 | return x 185 | 186 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 187 | for i, r in enumerate(self.model.transformer.resblocks): 188 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 189 | break 190 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 191 | x = checkpoint(r, x, attn_mask) 192 | else: 193 | x = r(x, attn_mask=attn_mask) 194 | return x 195 | 196 | def encode(self, text): 197 | return self(text) 198 | 199 | 200 | class FrozenCLIPT5Encoder(AbstractEncoder): 201 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 202 | clip_max_length=77, t5_max_length=77): 203 | super().__init__() 204 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 205 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 206 | # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 207 | # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 208 | 209 | def encode(self, text): 210 | return self(text) 211 | 212 | def forward(self, text): 213 | clip_z = self.clip_encoder.encode(text) 214 | t5_z = self.t5_encoder.encode(text) 215 | return [clip_z, t5_z] 216 | 217 | 218 | -------------------------------------------------------------------------------- /modules/video_unet_temporal/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch import einsum 10 | from misc_utils.model_utils import default, exists 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers.models.modeling_utils import ModelMixin 14 | from diffusers.utils import BaseOutput 15 | from diffusers.utils.import_utils import is_xformers_available 16 | from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm 17 | 18 | from einops import rearrange, repeat 19 | 20 | 21 | @dataclass 22 | class Transformer3DModelOutput(BaseOutput): 23 | sample: torch.FloatTensor 24 | 25 | 26 | if is_xformers_available(): 27 | import xformers 28 | import xformers.ops 29 | else: 30 | xformers = None 31 | 32 | 33 | class Transformer3DModel(ModelMixin, ConfigMixin): 34 | @register_to_config 35 | def __init__( 36 | self, 37 | num_attention_heads: int = 16, 38 | attention_head_dim: int = 88, 39 | in_channels: Optional[int] = None, 40 | num_layers: int = 1, 41 | dropout: float = 0.0, 42 | norm_num_groups: int = 32, 43 | cross_attention_dim: Optional[int] = None, 44 | attention_bias: bool = False, 45 | activation_fn: str = 
"geglu", 46 | num_embeds_ada_norm: Optional[int] = None, 47 | use_linear_projection: bool = False, 48 | only_cross_attention: bool = False, 49 | upcast_attention: bool = False, 50 | ): 51 | super().__init__() 52 | self.use_linear_projection = use_linear_projection 53 | self.num_attention_heads = num_attention_heads 54 | self.attention_head_dim = attention_head_dim 55 | inner_dim = num_attention_heads * attention_head_dim 56 | 57 | # Define input layers 58 | self.in_channels = in_channels 59 | 60 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 65 | 66 | # Define transformers blocks 67 | self.transformer_blocks = nn.ModuleList( 68 | [ 69 | BasicTransformerBlock( 70 | inner_dim, 71 | num_attention_heads, 72 | attention_head_dim, 73 | dropout=dropout, 74 | cross_attention_dim=cross_attention_dim, 75 | activation_fn=activation_fn, 76 | num_embeds_ada_norm=num_embeds_ada_norm, 77 | attention_bias=attention_bias, 78 | only_cross_attention=only_cross_attention, 79 | upcast_attention=upcast_attention, 80 | ) 81 | for d in range(num_layers) 82 | ] 83 | ) 84 | 85 | # 4. Define output layers 86 | if use_linear_projection: 87 | self.proj_out = nn.Linear(in_channels, inner_dim) 88 | else: 89 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 90 | 91 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 92 | # Input 93 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 94 | video_length = hidden_states.shape[2] 95 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 96 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 97 | 98 | batch, channel, height, weight = hidden_states.shape 99 | residual = hidden_states 100 | 101 | hidden_states = self.norm(hidden_states) 102 | if not self.use_linear_projection: 103 | hidden_states = self.proj_in(hidden_states) 104 | inner_dim = hidden_states.shape[1] 105 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 106 | else: 107 | inner_dim = hidden_states.shape[1] 108 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 109 | hidden_states = self.proj_in(hidden_states) 110 | 111 | # Blocks 112 | for block in self.transformer_blocks: 113 | hidden_states = block( 114 | hidden_states, 115 | encoder_hidden_states=encoder_hidden_states, 116 | timestep=timestep, 117 | video_length=video_length 118 | ) 119 | 120 | # Output 121 | if not self.use_linear_projection: 122 | hidden_states = ( 123 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 124 | ) 125 | hidden_states = self.proj_out(hidden_states) 126 | else: 127 | hidden_states = self.proj_out(hidden_states) 128 | hidden_states = ( 129 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 130 | ) 131 | 132 | output = hidden_states + residual 133 | 134 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 135 | if not return_dict: 136 | return (output,) 137 | 138 | return Transformer3DModelOutput(sample=output) 139 | 140 | 141 | class BasicTransformerBlock(nn.Module): 142 | def __init__( 143 | 
self, 144 | dim: int, 145 | num_attention_heads: int, 146 | attention_head_dim: int, 147 | dropout=0.0, 148 | cross_attention_dim: Optional[int] = None, 149 | activation_fn: str = "geglu", 150 | num_embeds_ada_norm: Optional[int] = None, 151 | attention_bias: bool = False, 152 | only_cross_attention: bool = False, 153 | upcast_attention: bool = False, 154 | ): 155 | super().__init__() 156 | self.only_cross_attention = only_cross_attention 157 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 158 | 159 | # SC-Attn 160 | self.attn1 = Attention( 161 | query_dim=dim, 162 | heads=num_attention_heads, 163 | dim_head=attention_head_dim, 164 | dropout=dropout, 165 | bias=attention_bias, 166 | upcast_attention=upcast_attention, 167 | ) 168 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 169 | 170 | # Cross-Attn 171 | if cross_attention_dim is not None: 172 | self.attn2 = Attention( 173 | query_dim=dim, 174 | cross_attention_dim=cross_attention_dim, 175 | heads=num_attention_heads, 176 | dim_head=attention_head_dim, 177 | dropout=dropout, 178 | bias=attention_bias, 179 | upcast_attention=upcast_attention, 180 | ) 181 | else: 182 | self.attn2 = None 183 | 184 | if cross_attention_dim is not None: 185 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 186 | else: 187 | self.norm2 = None 188 | 189 | # Feed-forward 190 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 191 | self.norm3 = nn.LayerNorm(dim) 192 | 193 | # Temp-Attn 194 | # self.attn_temp = Attention( 195 | # query_dim=dim, 196 | # heads=num_attention_heads, 197 | # dim_head=attention_head_dim, 198 | # dropout=dropout, 199 | # bias=attention_bias, 200 | # upcast_attention=upcast_attention, 201 | # ) 202 | # nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 203 | # self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 204 | 205 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op): 206 | if not is_xformers_available(): 207 | print("Here is how to install it") 208 | raise ModuleNotFoundError( 209 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 210 | " xformers", 211 | name="xformers", 212 | ) 213 | elif not torch.cuda.is_available(): 214 | raise ValueError( 215 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" 216 | " available for GPU " 217 | ) 218 | else: 219 | try: 220 | # Make sure we can run the memory efficient attention 221 | _ = xformers.ops.memory_efficient_attention( 222 | torch.randn((1, 2, 40), device="cuda"), 223 | torch.randn((1, 2, 40), device="cuda"), 224 | torch.randn((1, 2, 40), device="cuda"), 225 | ) 226 | except Exception as e: 227 | raise e 228 | self.attn1.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) 229 | if self.attn2 is not None: 230 | self.attn2.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) 231 | # self.attn_temp.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) 232 | 233 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 234 | # SparseCausal-Attention 235 | norm_hidden_states = ( 236 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 237 | ) 238 | 239 | if self.only_cross_attention: 240 | hidden_states = ( 241 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 242 | ) 243 | else: 244 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 245 | 246 | if self.attn2 is not None: 247 | # Cross-Attention 248 | norm_hidden_states = ( 249 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 250 | ) 251 | hidden_states = ( 252 | self.attn2( 253 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 254 | ) 255 | + hidden_states 256 | ) 257 | 258 | # Feed-forward 259 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 260 | 261 | # Temporal-Attention 262 | # d = hidden_states.shape[1] 263 | # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 264 | # norm_hidden_states = ( 265 | # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 266 | # ) 267 | # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 268 | # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 269 | 270 | return hidden_states 271 | -------------------------------------------------------------------------------- /modules/video_unet_temporal/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class Upsample3D(nn.Module): 22 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 23 | super().__init__() 24 | self.channels = channels 25 | self.out_channels = out_channels or channels 26 | self.use_conv = use_conv 27 | self.use_conv_transpose = use_conv_transpose 28 | self.name = name 29 | 30 | conv = None 31 | if use_conv_transpose: 32 | raise NotImplementedError 33 | elif use_conv: 34 | conv = 
InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 35 | 36 | if name == "conv": 37 | self.conv = conv 38 | else: 39 | self.Conv2d_0 = conv 40 | 41 | def forward(self, hidden_states, output_size=None): 42 | assert hidden_states.shape[1] == self.channels 43 | 44 | if self.use_conv_transpose: 45 | raise NotImplementedError 46 | 47 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 48 | dtype = hidden_states.dtype 49 | if dtype == torch.bfloat16: 50 | hidden_states = hidden_states.to(torch.float32) 51 | 52 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 53 | if hidden_states.shape[0] >= 64: 54 | hidden_states = hidden_states.contiguous() 55 | 56 | # if `output_size` is passed we force the interpolation output 57 | # size and do not make use of `scale_factor=2` 58 | if output_size is None: 59 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 60 | else: 61 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 62 | 63 | # If the input is bfloat16, we cast back to bfloat16 64 | if dtype == torch.bfloat16: 65 | hidden_states = hidden_states.to(dtype) 66 | 67 | if self.use_conv: 68 | if self.name == "conv": 69 | hidden_states = self.conv(hidden_states) 70 | else: 71 | hidden_states = self.Conv2d_0(hidden_states) 72 | 73 | return hidden_states 74 | 75 | 76 | class Downsample3D(nn.Module): 77 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 78 | super().__init__() 79 | self.channels = channels 80 | self.out_channels = out_channels or channels 81 | self.use_conv = use_conv 82 | self.padding = padding 83 | stride = 2 84 | self.name = name 85 | 86 | if use_conv: 87 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 88 | else: 89 | raise NotImplementedError 90 | 91 | if name == "conv": 92 | self.Conv2d_0 = conv 93 | self.conv = conv 94 | elif name == "Conv2d_0": 95 | self.conv = conv 96 | else: 97 | self.conv = conv 98 | 99 | def forward(self, hidden_states): 100 | assert hidden_states.shape[1] == self.channels 101 | if self.use_conv and self.padding == 0: 102 | raise NotImplementedError 103 | 104 | assert hidden_states.shape[1] == self.channels 105 | hidden_states = self.conv(hidden_states) 106 | 107 | return hidden_states 108 | 109 | 110 | class ResnetBlock3D(nn.Module): 111 | def __init__( 112 | self, 113 | *, 114 | in_channels, 115 | out_channels=None, 116 | conv_shortcut=False, 117 | dropout=0.0, 118 | temb_channels=512, 119 | groups=32, 120 | groups_out=None, 121 | pre_norm=True, 122 | eps=1e-6, 123 | non_linearity="swish", 124 | time_embedding_norm="default", 125 | output_scale_factor=1.0, 126 | use_in_shortcut=None, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 142 | 143 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 144 | 145 | if temb_channels is not None: 146 | if 
self.time_embedding_norm == "default": 147 | time_emb_proj_out_channels = out_channels 148 | elif self.time_embedding_norm == "scale_shift": 149 | time_emb_proj_out_channels = out_channels * 2 150 | else: 151 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 152 | 153 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 154 | else: 155 | self.time_emb_proj = None 156 | 157 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 158 | self.dropout = torch.nn.Dropout(dropout) 159 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 160 | 161 | if non_linearity == "swish": 162 | self.nonlinearity = lambda x: F.silu(x) 163 | elif non_linearity == "mish": 164 | self.nonlinearity = Mish() 165 | elif non_linearity == "silu": 166 | self.nonlinearity = nn.SiLU() 167 | 168 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 169 | 170 | self.conv_shortcut = None 171 | if self.use_in_shortcut: 172 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 173 | 174 | def forward(self, input_tensor, temb): 175 | hidden_states = input_tensor 176 | 177 | hidden_states = self.norm1(hidden_states) 178 | hidden_states = self.nonlinearity(hidden_states) 179 | 180 | hidden_states = self.conv1(hidden_states) 181 | 182 | if temb is not None: 183 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 184 | 185 | if temb is not None and self.time_embedding_norm == "default": 186 | hidden_states = hidden_states + temb 187 | 188 | hidden_states = self.norm2(hidden_states) 189 | 190 | if temb is not None and self.time_embedding_norm == "scale_shift": 191 | scale, shift = torch.chunk(temb, 2, dim=1) 192 | hidden_states = hidden_states * (1 + scale) + shift 193 | 194 | hidden_states = self.nonlinearity(hidden_states) 195 | 196 | hidden_states = self.dropout(hidden_states) 197 | hidden_states = self.conv2(hidden_states) 198 | 199 | if self.conv_shortcut is not None: 200 | input_tensor = self.conv_shortcut(input_tensor) 201 | 202 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 203 | 204 | return output_tensor 205 | 206 | 207 | class Mish(torch.nn.Module): 208 | def forward(self, hidden_states): 209 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /modules/vqvae/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | import torch.nn.functional as F 5 | from contextlib import contextmanager 6 | 7 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 8 | 9 | from .model import Encoder, Decoder 10 | 11 | from misc_utils.model_utils import instantiate_from_config 12 | 13 | 14 | class VQModel(pl.LightningModule): 15 | def __init__(self, 16 | ddconfig, 17 | lossconfig, 18 | n_embed, 19 | embed_dim, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | image_key="image", 23 | colorize_nlabels=None, 24 | monitor=None, 25 | batch_resize_range=None, 26 | scheduler_config=None, 27 | lr_g_factor=1.0, 28 | remap=None, 29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 30 | use_ema=False 31 | ): 32 | super().__init__() 33 | self.embed_dim = embed_dim 34 | 
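# n_embed is the codebook size and embed_dim the dimensionality of each codebook entry, both forwarded to VectorQuantizer below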
self.n_embed = n_embed 35 | self.image_key = image_key 36 | self.encoder = Encoder(**ddconfig) 37 | self.decoder = Decoder(**ddconfig) 38 | self.loss = instantiate_from_config(lossconfig) 39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 40 | remap=remap, 41 | sane_index_shape=sane_index_shape) 42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 44 | if colorize_nlabels is not None: 45 | assert type(colorize_nlabels)==int 46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 47 | if monitor is not None: 48 | self.monitor = monitor 49 | self.batch_resize_range = batch_resize_range 50 | if self.batch_resize_range is not None: 51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") 52 | 53 | self.use_ema = use_ema 54 | if self.use_ema: 55 | self.model_ema = LitEma(self) 56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 57 | 58 | if ckpt_path is not None: 59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 60 | self.scheduler_config = scheduler_config 61 | self.lr_g_factor = lr_g_factor 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu")["state_dict"] 80 | keys = list(sd.keys()) 81 | for k in keys: 82 | for ik in ignore_keys: 83 | if k.startswith(ik): 84 | print("Deleting key {} from state_dict.".format(k)) 85 | del sd[k] 86 | missing, unexpected = self.load_state_dict(sd, strict=False) 87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 88 | if len(missing) > 0: 89 | print(f"Missing Keys: {missing}") 90 | print(f"Unexpected Keys: {unexpected}") 91 | 92 | def on_train_batch_end(self, *args, **kwargs): 93 | if self.use_ema: 94 | self.model_ema(self) 95 | 96 | def encode(self, x): 97 | h = self.encoder(x) 98 | h = self.quant_conv(h) 99 | quant, emb_loss, info = self.quantize(h) 100 | return quant, emb_loss, info 101 | 102 | def encode_to_prequant(self, x): 103 | h = self.encoder(x) 104 | h = self.quant_conv(h) 105 | return h 106 | 107 | def decode(self, quant): 108 | quant = self.post_quant_conv(quant) 109 | dec = self.decoder(quant) 110 | return dec 111 | 112 | def decode_code(self, code_b): 113 | quant_b = self.quantize.embed_code(code_b) 114 | dec = self.decode(quant_b) 115 | return dec 116 | 117 | def forward(self, input, return_pred_indices=False): 118 | quant, diff, (_,_,ind) = self.encode(input) 119 | dec = self.decode(quant) 120 | if return_pred_indices: 121 | return dec, diff, ind 122 | return dec, diff 123 | 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 129 | if self.batch_resize_range is not None: 130 | lower_size = self.batch_resize_range[0] 131 | upper_size = self.batch_resize_range[1] 132 | if self.global_step <= 4: 133 | # do the first few batches with max size to avoid later oom 134 | 
new_resize = upper_size 135 | else: 136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) 137 | if new_resize != x.shape[2]: 138 | x = F.interpolate(x, size=new_resize, mode="bicubic") 139 | x = x.detach() 140 | return x 141 | 142 | def training_step(self, batch, batch_idx, optimizer_idx): 143 | # https://github.com/pytorch/pytorch/issues/37142 144 | # try not to fool the heuristics 145 | x = self.get_input(batch, self.image_key) 146 | xrec, qloss, ind = self(x, return_pred_indices=True) 147 | 148 | if optimizer_idx == 0: 149 | # autoencode 150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 151 | last_layer=self.get_last_layer(), split="train", 152 | predicted_indices=ind) 153 | 154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 155 | return aeloss 156 | 157 | if optimizer_idx == 1: 158 | # discriminator 159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 160 | last_layer=self.get_last_layer(), split="train") 161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 162 | return discloss 163 | 164 | def validation_step(self, batch, batch_idx): 165 | log_dict = self._validation_step(batch, batch_idx) 166 | with self.ema_scope(): 167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") 168 | return log_dict 169 | 170 | def _validation_step(self, batch, batch_idx, suffix=""): 171 | x = self.get_input(batch, self.image_key) 172 | xrec, qloss, ind = self(x, return_pred_indices=True) 173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, 174 | self.global_step, 175 | last_layer=self.get_last_layer(), 176 | split="val"+suffix, 177 | predicted_indices=ind 178 | ) 179 | 180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, 181 | self.global_step, 182 | last_layer=self.get_last_layer(), 183 | split="val"+suffix, 184 | predicted_indices=ind 185 | ) 186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] 187 | self.log(f"val{suffix}/rec_loss", rec_loss, 188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 189 | self.log(f"val{suffix}/aeloss", aeloss, 190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 191 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 192 | del log_dict_ae[f"val{suffix}/rec_loss"] 193 | self.log_dict(log_dict_ae) 194 | self.log_dict(log_dict_disc) 195 | return self.log_dict 196 | 197 | def configure_optimizers(self): 198 | lr_d = self.learning_rate 199 | lr_g = self.lr_g_factor*self.learning_rate 200 | print("lr_d", lr_d) 201 | print("lr_g", lr_g) 202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 203 | list(self.decoder.parameters())+ 204 | list(self.quantize.parameters())+ 205 | list(self.quant_conv.parameters())+ 206 | list(self.post_quant_conv.parameters()), 207 | lr=lr_g, betas=(0.5, 0.9)) 208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 209 | lr=lr_d, betas=(0.5, 0.9)) 210 | 211 | if self.scheduler_config is not None: 212 | scheduler = instantiate_from_config(self.scheduler_config) 213 | 214 | print("Setting up LambdaLR scheduler...") 215 | scheduler = [ 216 | { 217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 218 | 'interval': 'step', 219 | 'frequency': 1 220 | }, 221 | { 222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 223 | 'interval': 'step', 224 | 'frequency': 1 225 | }, 226 | ] 227 | return [opt_ae, opt_disc], scheduler 228 | return 
[opt_ae, opt_disc], [] 229 | 230 | def get_last_layer(self): 231 | return self.decoder.conv_out.weight 232 | 233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): 234 | log = dict() 235 | x = self.get_input(batch, self.image_key) 236 | x = x.to(self.device) 237 | if only_inputs: 238 | log["inputs"] = x 239 | return log 240 | xrec, _ = self(x) 241 | if x.shape[1] > 3: 242 | # colorize with random projection 243 | assert xrec.shape[1] > 3 244 | x = self.to_rgb(x) 245 | xrec = self.to_rgb(xrec) 246 | log["inputs"] = x 247 | log["reconstructions"] = xrec 248 | if plot_ema: 249 | with self.ema_scope(): 250 | xrec_ema, _ = self(x) 251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) 252 | log["reconstructions_ema"] = xrec_ema 253 | return log 254 | 255 | def to_rgb(self, x): 256 | assert self.image_key == "segmentation" 257 | if not hasattr(self, "colorize"): 258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 259 | x = F.conv2d(x, weight=self.colorize) 260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 261 | return x 262 | 263 | 264 | class VQModelInterface(VQModel): 265 | def __init__(self, embed_dim, *args, **kwargs): 266 | super().__init__(embed_dim=embed_dim, *args, **kwargs) 267 | self.embed_dim = embed_dim 268 | 269 | def encode(self, x): 270 | h = self.encoder(x) 271 | h = self.quant_conv(h) 272 | return h 273 | 274 | def decode(self, h, force_not_quantize=False): 275 | # also go through quantization layer 276 | if not force_not_quantize: 277 | quant, emb_loss, info = self.quantize(h) 278 | else: 279 | quant = h 280 | quant = self.post_quant_conv(quant) 281 | dec = self.decoder(quant) 282 | return dec 283 | 284 | -------------------------------------------------------------------------------- /pl_trainer/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pytorch_lightning as pl 4 | from misc_utils.model_utils import default, instantiate_from_config 5 | from diffusers import DDPMScheduler 6 | 7 | def mean_flat(tensor): 8 | """ 9 | Take the mean over all non-batch dimensions. 
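For example, a per-pixel loss of shape (B, C, H, W) is reduced to a per-sample loss of shape (B,):
>>> mean_flat(torch.ones(2, 3, 4, 4))
tensor([1., 1.])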
10 | """ 11 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 12 | 13 | class DDPM(pl.LightningModule): 14 | def __init__( 15 | self, 16 | unet, 17 | beta_schedule_args={ 18 | 'beta_start': 0.00085, 19 | 'beta_end': 0.0012, 20 | 'num_train_timesteps': 1000, 21 | 'beta_schedule': 'scaled_linear', 22 | 'clip_sample': False, 23 | 'thresholding': False, 24 | }, 25 | prediction_type='epsilon', 26 | loss_fn='l2', 27 | optim_args={}, 28 | **kwargs 29 | ): 30 | ''' 31 | denoising_fn: a denoising model such as UNet 32 | beta_schedule_args: a dictionary which contains 33 | the configurations of the beta schedule 34 | ''' 35 | super().__init__(**kwargs) 36 | self.unet = unet 37 | self.prediction_type = prediction_type 38 | beta_schedule_args.update({'prediction_type': prediction_type}) 39 | self.set_beta_schedule(beta_schedule_args) 40 | self.num_timesteps = beta_schedule_args['num_train_timesteps'] 41 | self.optim_args = optim_args 42 | self.loss = loss_fn 43 | if loss_fn == 'l2' or loss_fn == 'mse': 44 | self.loss_fn = nn.MSELoss(reduction='none') 45 | elif loss_fn == 'l1' or loss_fn == 'mae': 46 | self.loss_fn = nn.L1Loss(reduction='none') 47 | elif isinstance(loss_fn, dict): 48 | self.loss_fn = instantiate_from_config(loss_fn) 49 | else: 50 | raise NotImplementedError 51 | 52 | def set_beta_schedule(self, beta_schedule_args): 53 | self.beta_schedule_args = beta_schedule_args 54 | self.scheduler = DDPMScheduler(**beta_schedule_args) 55 | 56 | @torch.no_grad() 57 | def add_noise(self, x, t, noise=None): 58 | noise = default(noise, torch.randn_like(x)) 59 | return self.scheduler.add_noise(x, noise, t) 60 | 61 | def predict_x_0_from_x_t(self, model_output: torch.Tensor, t: torch.LongTensor, x_t: torch.Tensor): 62 | ''' recover x_0 from predicted noise. Reverse of Eq(4) in DDPM paper 63 | \hat(x_0) = 1 / sqrt[\bar(a)]*x_t - sqrt[(1-\bar(a)) / \bar(a)]*noise''' 64 | # return self.scheduler.step(model_output, int(t), x_t).pred_original_sample 65 | if self.prediction_type == 'sample': 66 | return model_output 67 | # for training target == epsilon 68 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype) 69 | sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t]).flatten() 70 | sqrt_recipm1_alphas_cumprod = torch.sqrt(1. 
/ alphas_cumprod[t] - 1.).flatten() 71 | while len(sqrt_recip_alphas_cumprod.shape) < len(x_t.shape): 72 | sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod.unsqueeze(-1) 73 | sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod.unsqueeze(-1) 74 | return sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * model_output 75 | 76 | def predict_x_tm1_from_x_t(self, model_output, t, x_t): 77 | '''predict x_{t-1} from x_t and model_output''' 78 | return self.scheduler.step(model_output, t, x_t).prev_sample 79 | 80 | class DDPMTraining(DDPM): 81 | def __init__( 82 | self, 83 | unet, 84 | beta_schedule_args, 85 | prediction_type='epsilon', 86 | loss_fn='l2', 87 | optim_args={ 88 | 'lr': 1e-3, 89 | 'weight_decay': 5e-4 90 | }, 91 | log_args={}, # for record all arguments with self.save_hyperparameters 92 | ddim_sampling_steps=20, 93 | guidance_scale=5., 94 | **kwargs 95 | ): 96 | super().__init__( 97 | unet=unet, 98 | beta_schedule_args=beta_schedule_args, 99 | prediction_type=prediction_type, 100 | loss_fn=loss_fn, 101 | optim_args=optim_args, 102 | **kwargs) 103 | self.log_args = log_args 104 | self.call_save_hyperparameters() 105 | 106 | self.ddim_sampling_steps = ddim_sampling_steps 107 | self.guidance_scale = guidance_scale 108 | 109 | def call_save_hyperparameters(self): 110 | '''write in a separate function so that the inherit class can overwrite it''' 111 | self.save_hyperparameters(ignore=['unet']) 112 | 113 | def process_batch(self, x_0, mode): 114 | assert mode in ['train', 'val', 'test'] 115 | b, *_ = x_0.shape 116 | noise = torch.randn_like(x_0) 117 | if mode == 'train': 118 | t = torch.randint(0, self.num_timesteps, (b,), device=x_0.device).long() 119 | x_t = self.add_noise(x_0, t, noise=noise) 120 | else: 121 | t = torch.full((b,), self.num_timesteps-1, device=x_0.device, dtype=torch.long) 122 | x_t = self.add_noise(x_0, t, noise=noise) 123 | 124 | model_kwargs = {} 125 | '''the order of return is 126 | 1) model input, 127 | 2) model pred target, 128 | 3) model time condition 129 | 4) raw image before adding noise 130 | 5) model_kwargs 131 | ''' 132 | if self.prediction_type == 'epsilon': 133 | return { 134 | 'model_input': x_t, 135 | 'model_target': noise, 136 | 't': t, 137 | 'model_kwargs': model_kwargs 138 | } 139 | else: 140 | return { 141 | 'model_input': x_t, 142 | 'model_target': x_0, 143 | 't': t, 144 | 'model_kwargs': model_kwargs 145 | } 146 | 147 | def forward(self, x): 148 | return self.validation_step(x, 0) 149 | 150 | def get_loss(self, pred, target, t): 151 | loss_raw = self.loss_fn(pred, target) 152 | loss_flat = mean_flat(loss_raw) 153 | 154 | loss = loss_flat 155 | loss = loss.mean() 156 | 157 | return loss 158 | 159 | def training_step(self, batch, batch_idx): 160 | self.clip_denoised = False 161 | processed_batch = self.process_batch(batch, mode='train') 162 | x_t = processed_batch['model_input'] 163 | y = processed_batch['model_target'] 164 | t = processed_batch['t'] 165 | model_kwargs = processed_batch['model_kwargs'] 166 | pred = self.unet(x_t, t, **model_kwargs) 167 | loss = self.get_loss(pred, y, t) 168 | x_0_hat = self.predict_x_0_from_x_t(pred, t, x_t) 169 | 170 | self.log(f'train_loss', loss) 171 | return { 172 | 'loss': loss, 173 | 'model_input': x_t, 174 | 'model_output': pred, 175 | 'x_0_hat': x_0_hat 176 | } 177 | 178 | @torch.no_grad() 179 | def validation_step(self, batch, batch_idx): 180 | from diffusers import DDIMScheduler 181 | scheduler = DDIMScheduler(**self.beta_schedule_args) 182 | 
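# validation samples with a DDIM scheduler built from the same beta schedule used for training, visiting only self.ddim_sampling_steps timesteps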
scheduler.set_timesteps(self.ddim_sampling_steps) 183 | processed_batch = self.process_batch(batch, mode='val') 184 | x_t = torch.randn_like(processed_batch['model_input']) 185 | x_hist = [] 186 | timesteps = scheduler.timesteps 187 | for i, t in enumerate(timesteps): 188 | t_ = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long) 189 | model_output = self.unet(x_t, t_, **processed_batch['model_kwargs']) 190 | x_hist.append( 191 | self.predict_x_0_from_x_t(model_output, t_, x_t) 192 | ) 193 | x_t = scheduler.step(model_output, t, x_t).prev_sample 194 | 195 | return { 196 | 'x_pred': x_t, 197 | 'x_hist': torch.stack(x_hist, dim=1), 198 | } 199 | 200 | def test_step(self, batch, batch_idx): 201 | '''Test is usually not used in a sampling problem''' 202 | return self.validation_step(batch, batch_idx) 203 | 204 | 205 | def configure_optimizers(self): 206 | optimizer = torch.optim.Adam(self.parameters(), **self.optim_args) 207 | return optimizer 208 | 209 | class DDPMLDMTraining(DDPMTraining): 210 | def __init__( 211 | self, *args, 212 | vae, 213 | unet_init_weights=None, 214 | vae_init_weights=None, 215 | scale_factor=0.18215, 216 | **kwargs 217 | ): 218 | super().__init__(*args, **kwargs) 219 | self.vae = vae 220 | self.scale_factor = scale_factor 221 | self.initialize_unet(unet_init_weights) 222 | self.initialize_vqvae(vae_init_weights) 223 | 224 | def initialize_unet(self, unet_init_weights): 225 | if unet_init_weights is not None: 226 | print(f'INFO: initialize denoising UNet from {unet_init_weights}') 227 | sd = torch.load(unet_init_weights, map_location='cpu') 228 | self.unet.load_state_dict(sd) 229 | 230 | def initialize_vqvae(self, vqvae_init_weights): 231 | if vqvae_init_weights is not None: 232 | print(f'INFO: initialize VQVAE from {vqvae_init_weights}') 233 | sd = torch.load(vqvae_init_weights, map_location='cpu') 234 | self.vae.load_state_dict(sd) 235 | for param in self.vae.parameters(): 236 | param.requires_grad = False 237 | 238 | def call_save_hyperparameters(self): 239 | '''write in a separate function so that the inherit class can overwrite it''' 240 | self.save_hyperparameters(ignore=['unet', 'vae']) 241 | 242 | @torch.no_grad() 243 | def encode_image_to_latent(self, x): 244 | return self.vae.encode(x) * self.scale_factor 245 | 246 | @torch.no_grad() 247 | def decode_latent_to_image(self, x): 248 | x = x / self.scale_factor 249 | return self.vae.decode(x) 250 | 251 | def process_batch(self, x_0, mode): 252 | x_0 = self.encode_image_to_latent(x_0) 253 | res = super().process_batch(x_0, mode) 254 | return res 255 | 256 | def training_step(self, batch, batch_idx): 257 | res_dict = super().training_step(batch, batch_idx) 258 | res_dict['x_0_hat'] = self.decode_latent_to_image(res_dict['x_0_hat']) 259 | return res_dict 260 | 261 | class DDIMLDMTextTraining(DDPMLDMTraining): 262 | def __init__( 263 | self, *args, 264 | text_model, 265 | text_model_init_weights=None, 266 | **kwargs 267 | ): 268 | super().__init__( 269 | *args, **kwargs 270 | ) 271 | self.text_model = text_model 272 | self.initialize_text_model(text_model_init_weights) 273 | 274 | def initialize_text_model(self, text_model_init_weights): 275 | if text_model_init_weights is not None: 276 | print(f'INFO: initialize text model from {text_model_init_weights}') 277 | sd = torch.load(text_model_init_weights, map_location='cpu') 278 | self.text_model.load_state_dict(sd) 279 | for param in self.text_model.parameters(): 280 | param.requires_grad = False 281 | 282 | def call_save_hyperparameters(self): 283 | 
'''write in a separate function so that the inherit class can overwrite it''' 284 | self.save_hyperparameters(ignore=['unet', 'vae', 'text_model']) 285 | 286 | @torch.no_grad() 287 | def encode_text(self, x): 288 | if isinstance(x, tuple): 289 | x = list(x) 290 | return self.text_model.encode(x) 291 | 292 | def process_batch(self, batch, mode): 293 | x_0 = batch['image'] 294 | text = batch['text'] 295 | processed_batch = super().process_batch(x_0, mode) 296 | processed_batch['model_kwargs'].update({ 297 | 'context': {'text': self.encode_text([text])} 298 | }) 299 | return processed_batch 300 | 301 | def sampling(self, image_shape=(1, 4, 64, 64), text='', negative_text=None): 302 | ''' 303 | Usage: 304 | sampled = self.sampling(text='a cat on the tree', negative_text='') 305 | 306 | x = sampled['x_pred'][0].permute(1, 2, 0).detach().cpu().numpy() 307 | x = x / 2 + 0.5 308 | plt.imshow(x) 309 | 310 | y = sampled['x_hist'][0, 10].permute(1, 2, 0).detach().cpu().numpy() 311 | y = y / 2 + 0.5 312 | plt.imshow(y) 313 | ''' 314 | from diffusers import DDIMScheduler 315 | scheduler = DDIMScheduler(**self.beta_schedule_args) 316 | scheduler.set_timesteps(self.ddim_sampling_steps) 317 | x_t = torch.randn(*image_shape, device=self.device) 318 | 319 | do_cfg = self.guidance_scale > 1. and negative_text is not None 320 | 321 | if do_cfg: 322 | context = {'text': self.encode_text([text, negative_text])} 323 | else: 324 | context = {'text': self.encode_text([text])} 325 | x_hist = [] 326 | timesteps = scheduler.timesteps 327 | for i, t in enumerate(timesteps): 328 | if do_cfg: 329 | model_input = torch.cat([x_t]*2) 330 | else: 331 | model_input = x_t 332 | t_ = torch.full((model_input.shape[0],), t, device=x_t.device, dtype=torch.long) 333 | model_output = self.unet(model_input, t_, context) 334 | 335 | if do_cfg: 336 | model_output_positive, model_output_negative = model_output.chunk(2) 337 | model_output = model_output_negative + self.guidance_scale * (model_output_positive - model_output_negative) 338 | x_hist.append( 339 | self.decode_latent_to_image(self.predict_x_0_from_x_t(model_output, t_[:x_t.shape[0]], x_t)) 340 | ) 341 | x_t = scheduler.step(model_output, t, x_t).prev_sample 342 | 343 | return { 344 | 'x_pred': self.decode_latent_to_image(x_t), 345 | 'x_hist': torch.stack(x_hist, dim=1), 346 | } 347 | -------------------------------------------------------------------------------- /pl_trainer/instruct_p2p_video.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Use pretrained instruct pix2pix model but add additional channels for reference modification 3 | ''' 4 | 5 | import torch 6 | from .diffusion import DDIMLDMTextTraining 7 | from einops import rearrange 8 | 9 | class InstructP2PVideoTrainer(DDIMLDMTextTraining): 10 | def __init__( 11 | self, *args, 12 | cond_image_dropout=0.1, 13 | prompt_type='output_prompt', 14 | text_cfg=7.5, 15 | img_cfg=1.2, 16 | **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.cond_image_dropout = cond_image_dropout 20 | 21 | assert prompt_type in ['output_prompt', 'edit_prompt', 'mixed_prompt'] 22 | self.prompt_type = prompt_type 23 | 24 | self.text_cfg = text_cfg 25 | self.img_cfg = img_cfg 26 | 27 | self.unet.enable_xformers_memory_efficient_attention() 28 | self.unet.enable_gradient_checkpointing() 29 | 30 | def encode_text(self, text): 31 | with torch.cuda.amp.autocast(dtype=torch.float16): 32 | encoded_text = super().encode_text(text) 33 | return encoded_text 34 | 35 | def 
encode_image_to_latent(self, image):
36 |         with torch.cuda.amp.autocast(dtype=torch.float16):
37 |             latent = super().encode_image_to_latent(image)
38 |         return latent  # NOTE: overridden by the frame-aware version defined below
39 | 
40 |     @torch.cuda.amp.autocast(dtype=torch.float16)
41 |     @torch.no_grad()
42 |     def get_prompt(self, batch, mode):
43 |         if mode == 'train':
44 |             if self.prompt_type == 'output_prompt':
45 |                 prompt = batch['output_prompt']
46 |             elif self.prompt_type == 'edit_prompt':
47 |                 prompt = batch['edit_prompt']
48 |             elif self.prompt_type == 'mixed_prompt':
49 |                 if torch.rand(1).item() > 0.5:  # 50/50 split between the two prompt types
50 |                     prompt = batch['output_prompt']
51 |                 else:
52 |                     prompt = batch['edit_prompt']
53 |         else:
54 |             prompt = batch['output_prompt']
55 |         return self.encode_text(prompt)
56 | 
57 |     @torch.cuda.amp.autocast(dtype=torch.float16)
58 |     @torch.no_grad()
59 |     def encode_image_to_latent(self, image):
60 |         b, f, c, h, w = image.shape
61 |         image = rearrange(image, 'b f c h w -> (b f) c h w')
62 |         latent = super().encode_image_to_latent(image)
63 |         latent = rearrange(latent, '(b f) c h w -> b f c h w', b=b)
64 |         return latent
65 | 
66 |     @torch.cuda.amp.autocast(dtype=torch.float16)
67 |     @torch.no_grad()
68 |     def decode_latent_to_image(self, latent):
69 |         b, f, c, h, w = latent.shape
70 |         latent = rearrange(latent, 'b f c h w -> (b f) c h w')
71 | 
72 |         image = []
73 |         for latent_ in latent:  # decode frame by frame to limit VAE memory usage
74 |             image_ = super().decode_latent_to_image(latent_[None])
75 |             image.append(image_)
76 |         image = torch.cat(image, dim=0)
77 |         # image = super().decode_latent_to_image(latent)
78 |         image = rearrange(image, '(b f) c h w -> b f c h w', b=b)
79 |         return image
80 | 
81 |     @torch.no_grad()
82 |     def get_cond_image(self, batch, mode):
83 |         cond_image = batch['input_video']
84 | 
85 |         # InstructPix2Pix conditions on unscaled VAE latents, so undo the scale factor applied in encode_image_to_latent
86 |         cond_image = self.encode_image_to_latent(cond_image) / self.scale_factor
87 |         if mode == 'train':
88 |             if torch.rand(1).item() < self.cond_image_dropout:  # occasionally drop the condition for classifier-free guidance
89 |                 cond_image = torch.zeros_like(cond_image)
90 |         return cond_image
91 | 
92 |     @torch.no_grad()
93 |     def get_diffused_image(self, batch, mode):
94 |         x = batch['edited_video']
95 |         b, *_ = x.shape
96 |         x = self.encode_image_to_latent(x)
97 |         eps = torch.randn_like(x)
98 | 
99 |         if mode == 'train':
100 |             t = torch.randint(0, self.num_timesteps, (b,), device=x.device).long()
101 |         else:
102 |             t = torch.full((b,), self.num_timesteps-1, device=x.device, dtype=torch.long)
103 |         x_t = self.add_noise(x, t, eps)
104 | 
105 |         if self.prediction_type == 'epsilon':
106 |             return x_t, eps, t
107 |         else:
108 |             return x_t, x, t
109 | 
110 |     @torch.no_grad()
111 |     def process_batch(self, batch, mode):
112 |         cond_image = self.get_cond_image(batch, mode)
113 |         diffused_image, target, t = self.get_diffused_image(batch, mode)
114 |         prompt = self.get_prompt(batch, mode)
115 | 
116 |         model_kwargs = {
117 |             'encoder_hidden_states': prompt
118 |         }
119 | 
120 |         return {
121 |             'diffused_input': diffused_image,
122 |             'condition': cond_image,
123 |             'target': target,
124 |             't': t,
125 |             'model_kwargs': model_kwargs,
126 |         }
127 | 
128 |     def training_step(self, batch, batch_idx):
129 |         processed_batch = self.process_batch(batch, mode='train')
130 |         diffused_input = processed_batch['diffused_input']
131 |         condition = processed_batch['condition']
132 |         target = processed_batch['target']
133 |         t = processed_batch['t']
134 |         model_kwargs = processed_batch['model_kwargs']
135 | 
136 |         model_input = torch.cat([diffused_input, condition], dim=2)  # concat noisy and condition latents along the channel dim; layout is b, f, c, h, w
137 |         model_input = rearrange(model_input, 'b f c h w -> b c f h w')
138 | 
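        # The video UNet expects channels-first temporal layout (b, c, f, h, w); its output is the
        # noise / x_0 prediction for the latent, rearranged back to (b, f, c, h, w) below so it can
        # be compared with `target` and decoded into frames for logging.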
139 | pred = self.unet(model_input, t, **model_kwargs).sample 140 | pred = rearrange(pred, 'b c f h w -> b f c h w') 141 | 142 | loss = self.get_loss(pred, target, t) 143 | self.log('train_loss', loss, sync_dist=True) 144 | 145 | latent_pred = self.predict_x_0_from_x_t(pred, t, diffused_input) 146 | image_pred = self.decode_latent_to_image(latent_pred) 147 | 148 | res_dict = { 149 | 'loss': loss, 150 | 'pred': image_pred, 151 | } 152 | return res_dict 153 | 154 | @torch.no_grad() 155 | @torch.cuda.amp.autocast(dtype=torch.bfloat16) 156 | def validation_step(self, batch, batch_idx): 157 | from .inference.inference import InferenceIP2PVideo 158 | inf_pipe = InferenceIP2PVideo( 159 | self.unet, 160 | beta_start=self.scheduler.config.beta_start, 161 | beta_end=self.scheduler.config.beta_end, 162 | beta_schedule=self.scheduler.config.beta_schedule, 163 | num_ddim_steps=20 164 | ) 165 | 166 | processed_batch = self.process_batch(batch, mode='val') 167 | diffused_input = torch.randn_like(processed_batch['diffused_input']) 168 | 169 | condition = processed_batch['condition'] 170 | img_cond = condition[:, :, :4] 171 | 172 | res = inf_pipe( 173 | latent = diffused_input, 174 | text_cond = processed_batch['model_kwargs']['encoder_hidden_states'], 175 | text_uncond = self.encode_text(['']), 176 | img_cond = img_cond, 177 | text_cfg = self.text_cfg, 178 | img_cfg = self.img_cfg, 179 | ) 180 | 181 | latent_pred = res['latent'] 182 | image_pred = self.decode_latent_to_image(latent_pred) 183 | res_dict = { 184 | 'pred': image_pred, 185 | } 186 | return res_dict 187 | 188 | def configure_optimizers(self): 189 | # optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.optim_args['lr']) 190 | import bitsandbytes as bnb 191 | params = [] 192 | for name, p in self.unet.named_parameters(): 193 | if ('transformer_in' in name) or ('temp_' in name): 194 | # p.requires_grad = True 195 | params.append(p) 196 | else: 197 | pass 198 | # p.requires_grad = False 199 | optimizer = bnb.optim.Adam8bit(params, lr=self.optim_args['lr'], betas=(0.9, 0.999)) 200 | return optimizer 201 | 202 | def initialize_unet(self, unet_init_weights): 203 | if unet_init_weights is not None: 204 | print(f'INFO: initialize denoising UNet from {unet_init_weights}') 205 | sd = torch.load(unet_init_weights, map_location='cpu') 206 | model_sd = self.unet.state_dict() 207 | # fit input conv size 208 | for k in model_sd.keys(): 209 | if k in sd.keys(): 210 | pass 211 | else: 212 | # handling temporal layers 213 | if (('temp_' in k) or ('transformer_in' in k)) and 'proj_out' in k: 214 | # print(f'INFO: initialize {k} from {model_sd[k].shape} to zeros') 215 | sd[k] = torch.zeros_like(model_sd[k]) 216 | else: 217 | # print(f'INFO: initialize {k} from {model_sd[k].shape} to random') 218 | sd[k] = model_sd[k] 219 | self.unet.load_state_dict(sd) 220 | 221 | class InstructP2PVideoTrainerTemporal(InstructP2PVideoTrainer): 222 | def initialize_unet(self, unet_init_weights): 223 | if unet_init_weights is not None: 224 | print(f'INFO: initialize denoising UNet from {unet_init_weights}') 225 | sd_init_weights, motion_module_init_weights = unet_init_weights 226 | sd = torch.load(sd_init_weights, map_location='cpu') 227 | motion_sd = torch.load(motion_module_init_weights, map_location='cpu') 228 | assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}' 229 | sd.update(motion_sd) 230 | for k, v in self.unet.state_dict().items(): 231 | if 
'pos_encoder.pe' in k: 232 | sd[k] = v # the size of pe may change 233 | self.unet.load_state_dict(sd) 234 | 235 | def configure_optimizers(self): 236 | import bitsandbytes as bnb 237 | motion_params = [] 238 | remaining_params = [] 239 | for name, p in self.unet.named_parameters(): 240 | if ('motion' in name): 241 | motion_params.append(p) 242 | else: 243 | remaining_params.append(p) 244 | optimizer = bnb.optim.Adam8bit([ 245 | {'params': motion_params, 'lr': self.optim_args['lr']}, 246 | ], betas=(0.9, 0.999)) 247 | return optimizer -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | einops 3 | wandb 4 | transformers 5 | diffusers 6 | pytorch-lightning 7 | opencv-python 8 | opencv-contrib-python 9 | omegaconf 10 | open-clip-torch 11 | jsonlines 12 | albumentations 13 | diff-match-patch 14 | git+https://github.com/openai/CLIP.git 15 | deepspeed -------------------------------------------------------------------------------- /video_edit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from misc_utils.train_utils import unit_test_create_model\n", 10 | "from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images\n", 11 | "config_path = 'configs/instruct_v2v_inference.yaml'\n", 12 | "diffusion_model = unit_test_create_model(config_path)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "ckpt = torch.load('insv2v.pth', map_location='cpu')\n", 23 | "diffusion_model.load_state_dict(ckpt, strict=False)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# edit params\n", 33 | "EDIT_PROMPT = 'make the car red Porsche and drive alone beach'\n", 34 | "VIDEO_CFG = 1.2\n", 35 | "TEXT_CFG = 7.5\n", 36 | "LONG_VID_SAMPLING_CORRECTION_STEP = 0.5\n", 37 | "\n", 38 | "# video params\n", 39 | "VIDEO_PATH = 'data/car-turn.mp4'\n", 40 | "IMGSIZE = 384\n", 41 | "NUM_FRAMES = 32\n", 42 | "VIDEO_SAMPLE_RATE = 10\n", 43 | "\n", 44 | "# sampling params\n", 45 | "FRAMES_IN_BATCH = 16\n", 46 | "NUM_REF_FRAMES = 4\n", 47 | "USE_MOTION_COMPENSATION = True" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "from pl_trainer.inference.inference import InferenceIP2PVideo, InferenceIP2PVideoOpticalFlow\n", 57 | "if USE_MOTION_COMPENSATION:\n", 58 | " inf_pipe = InferenceIP2PVideoOpticalFlow(\n", 59 | " unet = diffusion_model.unet,\n", 60 | " num_ddim_steps=20,\n", 61 | " scheduler='ddpm'\n", 62 | " )\n", 63 | "else:\n", 64 | " inf_pipe = InferenceIP2PVideo(\n", 65 | " unet = diffusion_model.unet,\n", 66 | " num_ddim_steps=20,\n", 67 | " scheduler='ddpm'\n", 68 | " )" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "from dataset.single_video_dataset import SingleVideoDataset\n", 78 | "dataset = SingleVideoDataset(\n", 79 | " video_file=VIDEO_PATH,\n", 80 | " video_description='',\n", 81 | " sampling_fps=VIDEO_SAMPLE_RATE,\n", 82 | " num_frames=NUM_FRAMES,\n", 83 | " output_size=(IMGSIZE, IMGSIZE)\n", 84 
| ")\n", 85 | "batch = dataset[20] # start from 20th frame\n", 86 | "batch = {k: v.cuda()[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()}" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "def split_batch(cond, frames_in_batch=16, num_ref_frames=4):\n", 96 | " frames_in_following_batch = frames_in_batch - num_ref_frames\n", 97 | " conds = [cond[:, :frames_in_batch]]\n", 98 | " frame_ptr = frames_in_batch\n", 99 | " num_ref_frames_each_batch = []\n", 100 | "\n", 101 | " while frame_ptr < cond.shape[1]:\n", 102 | " remaining_frames = cond.shape[1] - frame_ptr\n", 103 | " if remaining_frames < frames_in_batch:\n", 104 | " frames_in_following_batch = remaining_frames\n", 105 | " else:\n", 106 | " frames_in_following_batch = frames_in_batch - num_ref_frames\n", 107 | " this_ref_frames = frames_in_batch - frames_in_following_batch\n", 108 | " conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch])\n", 109 | " frame_ptr += frames_in_following_batch\n", 110 | " num_ref_frames_each_batch.append(this_ref_frames)\n", 111 | "\n", 112 | " return conds, num_ref_frames_each_batch" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "cond = [diffusion_model.encode_image_to_latent(frames) / 0.18215 for frames in batch['frames'].chunk(16, dim=1)] # when encoding, chunk the frames to avoid oom in vae, you can reduce the 16 if you have a smaller gpu\n", 122 | "cond = torch.cat(cond, dim=1)\n", 123 | "text_cond = diffusion_model.encode_text([EDIT_PROMPT])\n", 124 | "text_uncond = diffusion_model.encode_text([''])\n", 125 | "conds, num_ref_frames_each_batch = split_batch(cond, frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)\n", 126 | "splitted_frames, _ = split_batch(batch['frames'], frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# First video clip\n", 136 | "cond1 = conds[0]\n", 137 | "latent_pred_list = []\n", 138 | "init_latent = torch.randn_like(cond1)\n", 139 | "latent_pred = inf_pipe(\n", 140 | " latent = init_latent,\n", 141 | " text_cond = text_cond,\n", 142 | " text_uncond = text_uncond,\n", 143 | " img_cond = cond1,\n", 144 | " text_cfg = TEXT_CFG,\n", 145 | " img_cfg = VIDEO_CFG,\n", 146 | ")['latent']\n", 147 | "latent_pred_list.append(latent_pred)\n", 148 | "\n", 149 | "\n", 150 | "# Subsequent video clips\n", 151 | "for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(\n", 152 | " conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch\n", 153 | "):\n", 154 | " init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1)\n", 155 | " cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1)\n", 156 | " if USE_MOTION_COMPENSATION:\n", 157 | " ref_images = prev_frame[:, -num_ref_frames_:]\n", 158 | " query_images = curr_frame\n", 159 | " additional_kwargs = {\n", 160 | " 'ref_images': ref_images,\n", 161 | " 'query_images': query_images,\n", 162 | " }\n", 163 | " else:\n", 164 | " additional_kwargs = {}\n", 165 | " latent_pred = inf_pipe.second_clip_forward(\n", 166 | " latent = init_latent, \n", 167 | " text_cond = text_cond,\n", 168 | " text_uncond = text_uncond,\n", 169 | " img_cond = cond_,\n", 170 | " 
latent_ref = latent_pred[:, -num_ref_frames_:],\n", 171 | " noise_correct_step = LONG_VID_SAMPLING_CORRECTION_STEP,\n", 172 | " text_cfg = TEXT_CFG,\n", 173 | " img_cfg = VIDEO_CFG,\n", 174 | " **additional_kwargs,\n", 175 | " )['latent']\n", 176 | " latent_pred_list.append(latent_pred[:, num_ref_frames_:])\n", 177 | "\n", 178 | "# Save GIF\n", 179 | "latent_pred = torch.cat(latent_pred_list, dim=1)\n", 180 | "image_pred = diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "original_images = batch['frames'].cpu()\n", 190 | "transferred_images = image_pred.float().cpu()\n", 191 | "concat_images = torch.cat([original_images, transferred_images], dim=4)\n", 192 | "\n", 193 | "save_tensor_to_gif(concat_images, 'results/video_edit.gif', fps=5)\n", 194 | "save_tensor_to_images(transferred_images, 'results/video_edit_images')" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# visualize the gif\n", 204 | "from IPython.display import Image\n", 205 | "Image(filename='results/video_edit.gif')" 206 | ] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "pytorch", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.10.10" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 2 230 | } 231 | -------------------------------------------------------------------------------- /video_prompt_to_prompt.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import jsonlines 4 | import json 5 | from misc_utils.video_ptp_utils import get_models_of_damo_model 6 | import torch 7 | import numpy as np 8 | from einops import rearrange 9 | from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete 10 | from misc_utils.video_ptp_utils import compute_diff 11 | from pl_trainer.inference.inference_damo import InferenceDAMO_PTP_v2 12 | import cv2 13 | from misc_utils.image_utils import images_to_gif 14 | from misc_utils.clip_similarity import ClipSimilarity 15 | 16 | def save_images_to_folder(source_images, target_images, folder, seed): 17 | os.makedirs(folder, exist_ok=True) 18 | for i, (src_image, tgt_image) in enumerate(zip(source_images, target_images)): 19 | src_img = cv2.cvtColor((src_image*255).astype(np.uint8), cv2.COLOR_RGB2BGR) 20 | tgt_img = cv2.cvtColor((tgt_image*255).astype(np.uint8), cv2.COLOR_RGB2BGR) 21 | cv2.imwrite(os.path.join(folder, f'{seed}_0_{i:04d}.jpg'), src_img) 22 | cv2.imwrite(os.path.join(folder, f'{seed}_1_{i:04d}.jpg'), tgt_img) 23 | 24 | def save_images_to_gif(source_images, target_images, folder, seed): 25 | os.makedirs(folder, exist_ok=True) 26 | # images_to_gif(source_images, os.path.join(folder, f'{seed}_0.gif'), fps=5) 27 | # images_to_gif(target_images, os.path.join(folder, f'{seed}_1.gif'), fps=5) 28 | concat_image = np.concatenate([source_images, target_images], axis=2) 29 | images_to_gif(concat_image, os.path.join(folder, f'{seed}_concat.gif'), fps=5) 30 | 31 | def 
append_dict_to_jsonl(file_path: str, dict_obj: dict) -> None: 32 | with open(file_path, 'a') as f: 33 | f.write(json.dumps(dict_obj)) 34 | f.write("\n") # Write a new line at the end to prepare for the next JSON object 35 | 36 | # %% 37 | def torch_to_numpy(x): 38 | return (x.float().squeeze().detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5).clip(0, 1) 39 | def str_to_prompt(prompt): 40 | input_text = prompt['input'].strip('.') 41 | output_text = prompt['output'].strip('.') 42 | return compute_diff(input_text, output_text) 43 | 44 | def get_ptp_prompt(prompt, edit_weight=1.): 45 | PROMPT = str_to_prompt(prompt) 46 | for p in PROMPT: 47 | if isinstance(p, (Edit, Insert)): 48 | p.weight = edit_weight 49 | print(PROMPT) 50 | source_prompt = ' '.join(x.old for x in PROMPT) 51 | target_prompt = ' '.join(x.new for x in PROMPT) 52 | context_uncond = get_text_embedding_openclip("", text_encoder, text_model.device) 53 | old_context = get_text_embedding_openclip(source_prompt, text_encoder, text_model.device) 54 | context = get_text_embedding_openclip(target_prompt, text_encoder, text_model.device) 55 | key, value = encode_text_openclip(PROMPT, text_encoder, text_model.device) 56 | return { 57 | 'source_prompt': source_prompt, 58 | 'target_prompt': target_prompt, 59 | 'context_uncond': context_uncond, 60 | 'old_context': old_context, 61 | 'context': context, 62 | 'key': key, 63 | 'value': value, 64 | } 65 | def process_one_sample(prompt, seed=None, guidance_scale=9, num_ddim_steps=30, scheduler='ddim', sa_end_time=0.4, ca_end_time=0.8, edit_weight=1.): 66 | prompt_dict = get_ptp_prompt(prompt, edit_weight=edit_weight) 67 | 68 | if seed is None: 69 | seed = np.random.randint(0, 1000000) 70 | torch.random.manual_seed(seed) 71 | latent = torch.randn(1, 4, 16, 32, 32).cuda() 72 | inf_pipe = InferenceDAMO_PTP_v2(unet=unet, guidance_scale=guidance_scale, num_ddim_steps=num_ddim_steps, scheduler=scheduler) 73 | 74 | with torch.cuda.amp.autocast(dtype=torch.float16): 75 | res = inf_pipe( 76 | latent=latent, 77 | context = prompt_dict['context'], 78 | old_context = prompt_dict['old_context'], 79 | old_to_new_context=[prompt_dict['key'], prompt_dict['value']], 80 | uncond_context = prompt_dict['context_uncond'], 81 | sa_end_time=sa_end_time, 82 | ca_end_time=ca_end_time, 83 | ) 84 | pred_latent = res['latent'] 85 | pred_latent_old = res['latent_old'] 86 | 87 | latents = rearrange(pred_latent, 'b d f h w -> f b d h w') 88 | latents_old = rearrange(pred_latent_old, 'b d f h w -> f b d h w') 89 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16): 90 | pred_image = [vae.decode(latent / 0.18215) for latent in latents] 91 | pred_image_old = [vae.decode(latent / 0.18215) for latent in latents_old] 92 | pred_images = torch.stack(pred_image, dim=0) 93 | pred_images_old = torch.stack(pred_image_old, dim=0) 94 | 95 | old_prompt_images_np = torch_to_numpy(pred_images_old) 96 | new_prompt_images_np = torch_to_numpy(pred_images) 97 | 98 | return old_prompt_images_np, new_prompt_images_np, seed, prompt_dict 99 | 100 | 101 | if __name__ == '__main__': 102 | import argparse 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument('--start', type=int, default=0) 105 | parser.add_argument('--end', type=int, default=10) 106 | parser.add_argument('--prompt_source', type=str, default='ip2p') 107 | parser.add_argument('--num_sample_each_prompt', '-n', type=int, default=1) 108 | args = parser.parse_args() 109 | 110 | if args.prompt_source == 'ip2p': 111 | prompt_meta_file = 
'data/gpt-generated-prompts.jsonl' 112 | with open(prompt_meta_file, 'r') as f: 113 | prompts = [(i, json.loads(line)) for i, line in enumerate(f.readlines())] 114 | output_dir = 'video_ptp/raw_generated' 115 | elif args.prompt_source == 'webvid': 116 | root_dir = 'webvid_edit_prompt' 117 | files = os.listdir(root_dir) 118 | files = sorted(files, key=lambda x: int(x.split('.')[0])) 119 | prompts = [] 120 | for file in files: 121 | with open(os.path.join(root_dir, file), 'r') as f: 122 | prompt_idx = int(file.split('.')[0]) 123 | prompts.append((prompt_idx, json.load(f))) 124 | output_dir = 'video_ptp/raw_generated_webvid' 125 | else: 126 | raise ValueError(f'Unknown prompt source: {args.prompt_source}') 127 | 128 | 129 | # %% 130 | vae_config = 'configs/instruct_v2v.yaml' 131 | unet_config = 'modules/damo_text_to_video/configuration.json' 132 | vae_ckpt = 'VAE_PATH' 133 | unet_ckpt = 'UNet_PATH' 134 | text_model_ckpt = 'Text_MODEL_PATH' 135 | 136 | vae, unet, text_model = get_models_of_damo_model( 137 | unet_config=unet_config, 138 | unet_ckpt=unet_ckpt, 139 | vae_config=vae_config, 140 | vae_ckpt=vae_ckpt, 141 | text_model_ckpt=text_model_ckpt, 142 | ) 143 | text_encoder = text_model.encode_with_transformer 144 | clip_sim = ClipSimilarity(name='ViT-L/14').cuda() 145 | # %% 146 | 147 | for prompt_idx, prompt in prompts: 148 | if prompt_idx < args.start: 149 | continue 150 | if prompt_idx >= args.end: 151 | break 152 | print(prompt_idx, prompt) 153 | output_folder_idx = f'{prompt_idx:07d}' 154 | 155 | prompt_json_file = os.path.join(output_dir, output_folder_idx, 'prompt.json') 156 | os.makedirs(os.path.dirname(prompt_json_file), exist_ok=True) 157 | with open(prompt_json_file, 'w') as f: 158 | json.dump(prompt, f) 159 | 160 | meta_file_path = os.path.join(output_dir, output_folder_idx, 'metadata.jsonl') 161 | if os.path.exists(meta_file_path): 162 | with jsonlines.open(meta_file_path, 'r') as f: 163 | used_seed = [int(line['seed']) for line in f] 164 | num_existing_samples = len(used_seed) 165 | else: 166 | num_existing_samples = 0 167 | used_seed = [] 168 | print(f'num_existing_samples: {num_existing_samples}, used_seed: {used_seed}') 169 | 170 | for _ in range(num_existing_samples, args.num_sample_each_prompt): 171 | # generate a random configure 172 | seed = np.random.randint(0, 1000000) 173 | while seed in used_seed: 174 | seed = np.random.randint(0, 1000000) 175 | used_seed.append(seed) 176 | 177 | rng = np.random.RandomState(seed=seed) 178 | guidance_scale = rng.randint(5, 13) 179 | sa_end_time = float('{:.2f}'.format(rng.choice(np.linspace(0.3, 0.45, 4)))) 180 | ca_end_time = float('{:.2f}'.format(rng.choice(np.linspace(0.6, 0.85, 6)))) 181 | edit_weight = rng.randint(1, 6) 182 | generate_config = { 183 | 'seed': seed, 184 | 'guidance_scale': guidance_scale, 185 | 'sa_end_time': sa_end_time, 186 | 'ca_end_time': ca_end_time, 187 | 'edit_weight': edit_weight, 188 | } 189 | print(generate_config) 190 | 191 | # %% 192 | source_prompt_images, target_prompt_images, seed, prompt_dict = process_one_sample( 193 | prompt, 194 | **generate_config, 195 | ) 196 | 197 | source_prompt_images_torch = torch.from_numpy(source_prompt_images.transpose(0, 3, 1, 2)).cuda() 198 | target_prompt_images_torch = torch.from_numpy(target_prompt_images.transpose(0, 3, 1, 2)).cuda() 199 | 200 | with torch.no_grad(): 201 | sim_0, sim_1, sim_dir, sim_image = clip_sim(source_prompt_images_torch, target_prompt_images_torch, [prompt_dict['source_prompt']], [prompt_dict['target_prompt']]) 202 | 203 | simi_dict = { 204 
| 'sim_0': sim_0.mean().item(), 205 | 'sim_1': sim_1.mean().item(), 206 | 'sim_dir': sim_dir.mean().item(), 207 | 'sim_image': sim_image.mean().item(), 208 | } 209 | 210 | generate_config.update(simi_dict) 211 | 212 | output_folder_img = os.path.join(output_dir, output_folder_idx, 'image') 213 | output_folder_gif = os.path.join(output_dir, output_folder_idx, 'gif') 214 | 215 | if ( 216 | sim_0.mean().item() > 0.2 and sim_1.mean().item() > 0.2 and sim_dir.mean().item() > 0.2 and sim_image.mean().item() > 0.5 217 | ): 218 | save_images_to_folder(source_prompt_images, target_prompt_images, output_folder_img, seed) 219 | save_images_to_gif(source_prompt_images, target_prompt_images, output_folder_gif, seed) 220 | 221 | append_dict_to_jsonl(meta_file_path, generate_config) 222 | 223 | 224 | --------------------------------------------------------------------------------
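A note on the text_cfg / img_cfg arguments that appear in InstructP2PVideoTrainer.validation_step and in video_edit.ipynb: InstructPix2Pix-style models apply two separate classifier-free guidance scales, one for the text instruction and one for the conditioning video. The exact combination used by InferenceIP2PVideo is defined in pl_trainer/inference/inference.py; the sketch below is only a minimal illustration of the standard InstructPix2Pix two-way guidance formula, with hypothetical tensor names (eps_full, eps_img_only, eps_uncond) standing in for the three UNet predictions.

import torch

def combine_ip2p_guidance(
    eps_full: torch.Tensor,      # UNet prediction with text + video conditioning
    eps_img_only: torch.Tensor,  # UNet prediction with video conditioning only (empty text)
    eps_uncond: torch.Tensor,    # UNet prediction with neither condition
    text_cfg: float = 7.5,
    img_cfg: float = 1.2,
) -> torch.Tensor:
    # Two-way classifier-free guidance as in InstructPix2Pix: the image scale pulls the
    # result toward the input video, the text scale pushes it toward the edit instruction.
    return (
        eps_uncond
        + img_cfg * (eps_img_only - eps_uncond)
        + text_cfg * (eps_full - eps_img_only)
    )

With the defaults used above (text_cfg=7.5, img_cfg=1.2), raising img_cfg keeps the result closer to the source video, while raising text_cfg strengthens the edit.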