├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── THIRD-PARTY
├── callbacks
│   ├── common.py
│   └── instruct_p2p_video.py
├── configs
│   ├── instruct_v2v.yaml
│   └── instruct_v2v_inference.yaml
├── data
│   ├── airbrush-painting.mp4
│   ├── airplane-and-contrail.mp4
│   ├── audi-snow-trail.mp4
│   ├── car-turn.mp4
│   ├── cat-in-the-sun.mp4
│   ├── dirt-road-driving.mp4
│   ├── drift-turn.mp4
│   ├── earth-full-view.mp4
│   ├── eiffel-flyover.mp4
│   ├── ferris-wheel-timelapse.mp4
│   ├── gold-fish.mp4
│   ├── ice-hockey.mp4
│   ├── miami-surf.mp4
│   ├── raindrops.mp4
│   ├── red-roses-sunny-day.mp4
│   └── swans.mp4
├── dataset
│   ├── loveu_tgve_dataset.py
│   ├── loveu_tgve_edit_prompt_dict.json
│   ├── single_video_dataset.py
│   └── videoP2P.py
├── figures
│   ├── data_pipe.png
│   ├── synthetic_sample
│   │   ├── synthetic_video_106_0.gif
│   │   ├── synthetic_video_116_0.gif
│   │   ├── synthetic_video_141_0.gif
│   │   ├── synthetic_video_18_0.gif
│   │   ├── synthetic_video_192_0.gif
│   │   ├── synthetic_video_197_0.gif
│   │   ├── synthetic_video_1_0.gif
│   │   ├── synthetic_video_24_0.gif
│   │   ├── synthetic_video_81_0.gif
│   │   └── synthetic_video_92_0.gif
│   ├── teaser.png
│   └── videos
│       ├── TGVE_video_edit.mp4
│       ├── airbrush-painting_object.gif
│       ├── airplane-and-contrail_background.gif
│       ├── audi-snow-trail_background.gif
│       ├── cat-in-the-sun_background.gif
│       ├── cat-in-the-sun_object.gif
│       ├── dirt-road-driving_style.gif
│       ├── drift-turn_style.gif
│       ├── earth-full-view_background.gif
│       ├── eiffel-flyover_background.gif
│       ├── ferris-wheel-timelapse_background.gif
│       ├── gold-fish_style.gif
│       ├── ice-hockey_object.gif
│       ├── miami-surf_background.gif
│       ├── raindrops_style.gif
│       ├── red-roses-sunny-day_background.gif
│       ├── red-roses-sunny-day_style.gif
│       ├── swans_background.gif
│       └── swans_object.gif
├── gradio_demo.py
├── insv2v_run_loveu_tgve.py
├── main.py
├── misc_utils
│   ├── clip_similarity.py
│   ├── flow_utils.py
│   ├── image_utils.py
│   ├── model_utils.py
│   ├── ptp_utils.py
│   ├── train_utils.py
│   └── video_ptp_utils.py
├── modules
│   ├── damo_text_to_video
│   │   ├── configuration.json
│   │   ├── text_model.py
│   │   └── unet_sd.py
│   ├── kl_autoencoder
│   │   └── autoencoder.py
│   ├── openclip
│   │   └── modules.py
│   ├── video_unet_temporal
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── resnet.py
│   │   ├── unet.py
│   │   └── unet_blocks.py
│   └── vqvae
│       ├── autoencoder.py
│       └── model.py
├── pl_trainer
│   ├── diffusion.py
│   ├── inference
│   │   ├── inference.py
│   │   └── inference_damo.py
│   └── instruct_p2p_video.py
├── requirements.txt
├── video_edit.ipynb
└── video_prompt_to_prompt.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | dummy_dataset
131 | wandb
132 | experiments
133 | results
134 | .DS_Store
135 | pretrained_models
136 | debug.pkl
137 | ablation_results
138 | ablation_output*
139 | video_ptp
140 | *.pt
141 | *.pth
142 | webvid_edit_prompt
143 | evaluation_results
144 | tmp
145 | vptp_results
146 | repolint_result.txt
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## This is the code release for the ICLR 2024 paper [Consistent Video-to-Video Transfer Using Synthetic Dataset](https://arxiv.org/abs/2311.00213).
2 |
3 | ![teaser](figures/teaser.png)
4 |
5 | ## Quick Links
6 | * [Installation](#installation)
7 | * [Video Editing](#video-editing) 🔥
8 | * [Synthetic Video Prompt-to-Prompt Dataset](#synthetic-video-prompt-to-prompt-dataset)
9 | * [Training](#training)
10 | * [Create Synthetic Video Dataset](#create-synthetic-video-dataset)
11 |
12 | ## Updates
13 | * 2024/02/13: The official synthetic data and model will not be released due to Amazon policy, but a third-party reproduction of the synthetic data and model weights is available. Please refer to this [GitHub repo](https://github.com/cplusx/INSV2V-3rd-pty-reprod).
14 | * 2023/11/29: We have updated the paper with more comparisons to recent baseline methods and updated the [comparison video](#visual-comparison-to-other-methods). The Gradio demo code has been uploaded.
15 |
16 | ## Installation
17 | ```bash
18 | git clone https://github.com/amazon-science/instruct-video-to-video.git
19 | pip install -r requirements.txt
20 | ```
21 | NOTE: The code is tested with PyTorch 2.1.0 (CUDA 11.8) and the corresponding xformers version. Any PyTorch version > 2.0 should work, but please install the matching xformers build.
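
A quick environment sanity check (illustrative only, not part of the repo) to confirm that a >2.0 PyTorch build, its CUDA runtime, and a matching xformers build are importable before running any of the scripts below:

```python
# illustrative check: the versions printed should be a PyTorch > 2.0 build and the
# xformers wheel compiled against it (e.g. torch 2.1.0 + cu118, per the note above)
import torch
import xformers

print('torch', torch.__version__, 'cuda', torch.version.cuda)
print('xformers', xformers.__version__)
assert torch.cuda.is_available(), 'a CUDA-capable GPU is required for editing and training'
```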
22 | ## Video Editing
23 | The official model weights are not released (see the Updates section above); please use the third-party reproduction linked below.
24 |
25 | Download the [InsV2V model weights](https://github.com/cplusx/INSV2V-3rd-pty-reprod) and change the ckpt path in the following notebook.
26 |
27 | ✨🚀 This [notebook](video_edit.ipynb) provides sample code for text-based video editing.
28 |
29 | ### Download LOVEU Dataset for Testing
30 | Please follow the instructions in the [LOVEU Dataset](https://sites.google.com/view/loveucvpr23/track4) to download the dataset. Use the following [script](insv2v_run_loveu_tgve.py) to run editing on the LOVEU dataset:
31 | ```bash
32 | python insv2v_run_loveu_tgve.py \
33 | --config configs/instruct_v2v.yaml \
34 | --ckpt-path [PATH TO THE CHECKPOINT] \
35 | --data-dir [PATH TO THE LOVEU DATASET] \
36 | --with_optical_flow \
37 | --text-cfg 7.5 10 \
38 | --video-cfg 1.2 1.5 \
39 | --image-size 256 384
40 | ```
41 | Note: `--with_optical_flow` enables motion compensation. You may need to try different combinations of image resolution and video/text classifier-free guidance scales to find the best editing results.
42 |
43 | Example results of editing LOVEU-TGVE Dataset:
44 |
45 | *(Side-by-side result GIFs for these examples are included in [`figures/videos/`](figures/videos), e.g. `airbrush-painting_object.gif`, `cat-in-the-sun_background.gif`, `gold-fish_style.gif`, `swans_object.gif`.)*
69 |
70 | ## Synthetic Video Prompt-to-Prompt Dataset
71 |
72 | Generation pipeline of the synthetic video dataset:
73 | ![Synthetic data generation pipeline](figures/data_pipe.png)
74 |
75 | Examples of the synthetic video dataset:
76 |
77 | *(Sample synthetic clip GIFs are included in [`figures/synthetic_sample/`](figures/synthetic_sample), e.g. `synthetic_video_1_0.gif`, `synthetic_video_18_0.gif`, `synthetic_video_106_0.gif`.)*
98 |
99 | ## Training
100 |
101 | ### Download Foundational Models
102 | [Download](https://drive.google.com/file/d/1R9sWsnGZUa5P8IB5DDfD9eU-T9SQLsFw/view?usp=sharing) the foundational models and place them in the `pretrained_models` folder.
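
The filenames the training config expects under `pretrained_models` can be read off `configs/instruct_v2v.yaml`; a small illustrative check:

```python
# verify the foundational weights sit where configs/instruct_v2v.yaml points
# (paths taken from the unet/vae/text_model init-weight entries of that config)
import os

expected = [
    'pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin',
    'pretrained_models/instruct_pix2pix/vqvae.ckpt',
    'pretrained_models/instruct_pix2pix/text.ckpt',
    'pretrained_models/Motion_Module/mm_sd_v15.ckpt',
]
for path in expected:
    print(path, 'OK' if os.path.exists(path) else 'MISSING')
```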
103 |
104 | ### Download Synthetic Video Dataset
105 | [See download link in the third party reproduction](https://github.com/cplusx/INSV2V-3rd-pty-reprod)
106 |
107 | ### Train the Model
108 | Put the synthetic video dataset in the `video_ptp` folder.
109 |
110 | Run the following command to train the model:
111 | ```bash
112 | python main.py --config configs/instruct_v2v.yaml -r # add -r to resume training if the training is interrupted
113 | ```
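
For reference, the effective batch size implied by `configs/instruct_v2v.yaml` (per-GPU `batch_size: 1`, four devices, `accumulate_grad_batches: 256`) works out to 1024 clips per optimizer step:

```python
# effective clips per optimizer step, computed from the values in configs/instruct_v2v.yaml
batch_size, num_devices, accumulate_grad_batches = 1, 4, 256
print(batch_size * num_devices * accumulate_grad_batches)  # 1024
```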
114 |
115 | ## Create Synthetic Video Dataset
116 | If you want to create your own synthetic video dataset, please follow the instructions below:
117 | * Download the modelscope VAE, UNet and text encoder weights from [here](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis/tree/main)
118 | * Replace the model path in the [`video_prompt_to_prompt.py`](video_prompt_to_prompt.py) file
119 | ```
120 | vae_ckpt = 'VAE_PATH'
121 | unet_ckpt = 'UNet_PATH'
122 | text_model_ckpt = 'Text_MODEL_PATH'
123 | ```
124 | * Download the edit prompt file `gpt-generated-prompts.jsonl` from [Instruct Pix2Pix](https://github.com/timothybrooks/instruct-pix2pix) and change the file path in `video_prompt_to_prompt.py` accordingly, or download the WebVid edit prompt file proposed in our paper from [To be released]().
125 | * Run the command to generate the synthetic video dataset:
126 | ```bash
127 | python video_prompt_to_prompt.py \
128 | --start [START INDEX] \
129 | --end [END INDEX] \
130 | --prompt_source [ip2p or webvid] \
131 | --num_sample_each_prompt [NUM SAMPLES FOR EACH PROMPT]
132 | ```
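
Each generated sample is written as a folder containing an `image/` directory, a `metadata.jsonl` file and a `prompt.json` file, which is the layout `dataset/videoP2P.py` expects. A hedged sketch for inspecting one sample (the folder name below is hypothetical):

```python
# inspect one generated sample; folder layout inferred from dataset/videoP2P.py
import json
import os
import jsonlines

sample_dir = 'video_ptp/raw_generated/0'   # hypothetical sample folder

assert os.path.isdir(os.path.join(sample_dir, 'image'))
with open(os.path.join(sample_dir, 'prompt.json')) as f:
    prompt = json.load(f)                  # keys: 'input', 'output', 'edit'
with jsonlines.open(os.path.join(sample_dir, 'metadata.jsonl')) as reader:
    rows = list(reader)                    # per-seed CLIP scores: sim_dir, sim_0, sim_1, sim_image
print(prompt['edit'], len(rows))
```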
133 |
134 | ## Visual Comparison to Other Methods
135 |
136 | https://github.com/amazon-science/instruct-video-to-video/assets/20940184/d3619652-dd75-41a0-92b4-345bbf57de40
137 |
138 |
139 | Links to the baselines used in the video:
140 |
141 | [Tune-A-Video](https://github.com/showlab/Tune-A-Video) | [ControlVideo](https://github.com/thu-ml/controlvideo) | [Vid2Vid Zero](https://github.com/baaivision/vid2vid-zero) | [Video P2P](https://github.com/ShaoTengLiu/Video-P2P) | [TokenFlow](https://github.com/omerbt/TokenFlow) | [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) | [Pix2Video](https://github.com/duyguceylan/pix2video)
144 |
145 | ## Credit
146 | The code was implemented by [Jiaxin Cheng](https://github.com/cplusx) during his internship at the AWS Shanghai Lablet.
147 | ## References
148 | Part of the code and the foundational models are adapted from the following works:
149 | * [Instruct Pix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
150 | * [AnimateDiff](https://github.com/guoyww/animatediff/)
151 |
--------------------------------------------------------------------------------
/callbacks/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import wandb
5 | import cv2
import numpy as np
6 |
7 | def unnorm(x):
8 | '''convert from range [-1, 1] to [0, 1]'''
9 | return (x+1) / 2
10 |
11 | def clip_image(x, min=0., max=1.):
12 | return torch.clamp(x, min=min, max=max)
13 |
14 | def format_dtype_and_shape(x):
15 | if isinstance(x, torch.Tensor):
16 | if len(x.shape) == 3 and x.shape[0] == 3:
17 | x = x.permute(1, 2, 0)
18 | if len(x.shape) == 4 and x.shape[1] == 3:
19 | x = x.permute(0, 2, 3, 1)
20 | x = x.detach().cpu().numpy()
21 | return x
22 |
23 | def tensor2image(x):
24 |     '''convert 4D (b, dim, h, w) pytorch tensor to wandb Image class'''
25 |     x = x.float()  # handle bf16
26 | grid_img = torchvision.utils.make_grid(
27 | x, nrow=4
28 | ).permute(1, 2, 0).detach().cpu().numpy()
29 | img = wandb.Image(
30 | grid_img
31 | )
32 | return img
33 |
34 | def save_figure(image, save_path):
35 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
36 | if image.min() < 0:
37 | image = clip_image(unnorm(image))
38 | image = format_dtype_and_shape(image)
39 | image = (image * 255).astype(np.uint8)
40 | cv2.imwrite(save_path, image[..., ::-1])
41 |
42 | def save_sampling_history(image, save_path):
43 | if image.min() < 0:
44 | image = clip_image(unnorm(image))
45 | grid_img = torchvision.utils.make_grid(image, nrow=4)
46 | save_figure(grid_img, save_path)
--------------------------------------------------------------------------------
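
A minimal usage sketch for the helpers above (tensor values and output path are made up): `save_sampling_history` un-normalizes a `[-1, 1]` batch, tiles it into a 4-column grid and writes it with OpenCV.

```python
# illustrative use of callbacks/common.py
import torch
from callbacks.common import save_sampling_history

fake_images = torch.rand(8, 3, 64, 64) * 2 - 1      # pretend model outputs in [-1, 1]
save_sampling_history(fake_images, 'results/preview.png')
```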
/callbacks/instruct_p2p_video.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.loggers import WandbLogger
2 | from pytorch_lightning.callbacks import Callback
3 | from .common import tensor2image, clip_image, unnorm
4 | from einops import rearrange
5 |
6 | def frame_dim_to_batch_dim(x):
7 | return rearrange(x, 'b f c h w -> (b f) c h w')
8 |
9 | class InstructP2PLogger(Callback):
10 | def __init__(
11 | self,
12 | wandb_logger: WandbLogger=None,
13 | max_num_images: int=16,
14 | ) -> None:
15 | super().__init__()
16 | self.wandb_logger = wandb_logger
17 | self.max_num_images = max_num_images
18 |
19 | def on_train_batch_end(
20 | self, trainer, pl_module,
21 | outputs, batch, batch_idx
22 | ):
23 | # record images in first batch
24 | if batch_idx == 0:
25 | input_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
26 | batch['input_video'][:self.max_num_images]
27 | ))))
28 | edited_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
29 | batch['edited_video'][:self.max_num_images]
30 | ))))
31 | pred_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
32 | outputs['pred'][:self.max_num_images]
33 | ))))
34 | self.wandb_logger.experiment.log({
35 | 'train/input_image': input_image,
36 | 'train/edited_image': edited_image,
37 | 'train/pred': pred_image,
38 | })
39 |
40 | def on_validation_batch_end(
41 | self, trainer, pl_module,
42 | outputs, batch, batch_idx
43 | ):
44 | """Called when the validation batch ends."""
45 | if batch_idx == 0:
46 | input_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
47 | batch['input_video'][:self.max_num_images]
48 | ))))
49 | edited_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
50 | batch['edited_video'][:self.max_num_images]
51 | ))))
52 | pred_image = tensor2image(frame_dim_to_batch_dim(clip_image(unnorm(
53 | outputs['pred'][:self.max_num_images]
54 | ))))
55 | self.wandb_logger.experiment.log({
56 | 'val/input_image': input_image,
57 | 'val/edited_image': edited_image,
58 | 'val/pred': pred_image,
59 | })
60 |
--------------------------------------------------------------------------------
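
A hedged sketch of how this callback is attached to a Lightning trainer; in the repo this wiring is driven by the `callbacks` section of the YAML config, and the project name below is hypothetical.

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from callbacks.instruct_p2p_video import InstructP2PLogger

wandb_logger = WandbLogger(project='instruct_v2v')   # hypothetical W&B project name
logger_cb = InstructP2PLogger(wandb_logger=wandb_logger, max_num_images=1)
trainer = pl.Trainer(logger=wandb_logger, callbacks=[logger_cb], accelerator='gpu', devices=1)
```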
/configs/instruct_v2v.yaml:
--------------------------------------------------------------------------------
1 | expt_dir: experiments
2 | expt_name: instruct_v2v
3 | trainer_args:
4 | max_epochs: 10
5 | accelerator: "gpu"
6 | devices: [0,1,2,3]
7 | limit_train_batches: 2048
8 | limit_val_batches: 1
9 | # strategy: "ddp"
10 | strategy: "deepspeed_stage_2"
11 | accumulate_grad_batches: 256
12 | check_val_every_n_epoch: 5
13 | diffusion:
14 | target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal
15 | params:
16 | beta_schedule_args:
17 | beta_schedule: scaled_linear
18 | num_train_timesteps: 1000
19 | beta_start: 0.00085
20 | beta_end: 0.012
21 | clip_sample: false
22 | thresholding: false
23 | prediction_type: epsilon
24 | loss_fn: l2
25 | optim_args:
26 | lr: 1e-5
27 | unet_init_weights:
28 | - pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin
29 | - pretrained_models/Motion_Module/mm_sd_v15.ckpt
30 | vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt
31 | text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt
32 | scale_factor: 0.18215
33 | guidance_scale: 5 # not used
34 | ddim_sampling_steps: 20
35 | text_cfg: 7.5
36 | img_cfg: 1.2
37 | cond_image_dropout: 0.1
38 | prompt_type: edit_prompt
39 | unet:
40 | target: modules.video_unet_temporal.unet.UNet3DConditionModel
41 | params:
42 | in_channels: 8
43 | out_channels: 4
44 | act_fn: silu
45 | attention_head_dim: 8
46 | block_out_channels:
47 | - 320
48 | - 640
49 | - 1280
50 | - 1280
51 | cross_attention_dim: 768
52 | down_block_types:
53 | - CrossAttnDownBlock3D
54 | - CrossAttnDownBlock3D
55 | - CrossAttnDownBlock3D
56 | - DownBlock3D
57 | up_block_types:
58 | - UpBlock3D
59 | - CrossAttnUpBlock3D
60 | - CrossAttnUpBlock3D
61 | - CrossAttnUpBlock3D
62 | downsample_padding: 1
63 | layers_per_block: 2
64 | mid_block_scale_factor: 1
65 | norm_eps: 1e-05
66 | norm_num_groups: 32
67 | sample_size: 64
68 | use_motion_module: true
69 | motion_module_resolutions:
70 | - 1
71 | - 2
72 | - 4
73 | - 8
74 | motion_module_mid_block: false
75 | motion_module_decoder_only: false
76 | motion_module_type: Vanilla
77 | motion_module_kwargs:
78 | num_attention_heads: 8
79 | num_transformer_block: 1
80 | attention_block_types:
81 | - Temporal_Self
82 | - Temporal_Self
83 | temporal_position_encoding: true
84 | temporal_position_encoding_max_len: 32
85 | temporal_attention_dim_div: 1
86 | vae:
87 | target: modules.kl_autoencoder.autoencoder.AutoencoderKL
88 | params:
89 | embed_dim: 4
90 | ddconfig:
91 | double_z: true
92 | z_channels: 4
93 | resolution: 256
94 | in_channels: 3
95 | out_ch: 3
96 | ch: 128
97 | ch_mult:
98 | - 1
99 | - 2
100 | - 4
101 | - 4
102 | num_res_blocks: 2
103 | attn_resolutions: []
104 | dropout: 0.0
105 | lossconfig:
106 | target: torch.nn.Identity
107 | text_model:
108 | target: modules.openclip.modules.FrozenCLIPEmbedder
109 | params:
110 | freeze: true
111 | data:
112 | batch_size: 1
113 | val_batch_size: 1
114 | train:
115 | target: dataset.videoP2P.VideoPromptToPromptMotionAug
116 | params:
117 | root_dirs:
118 | - video_ptp/raw_generated
119 | - video_ptp/raw_generated_webvid
120 | num_frames: 16
121 | zoom_ratio: 0.2
122 | max_zoom: 1.25
123 | translation_ratio: 0.7
124 | translation_range: [0, 0.2]
125 | val:
126 | target: dataset.videoP2P.VideoPromptToPromptMotionAug
127 | params:
128 | root_dirs:
129 | - video_ptp/raw_generated
130 | num_frames: 16
131 | zoom_ratio: 0.2
132 | max_zoom: 1.25
133 | translation_ratio: 0.7
134 | translation_range: [0, 0.2]
135 | callbacks:
136 | - target: pytorch_lightning.callbacks.ModelCheckpoint
137 | params:
138 | dirpath: "${expt_dir}/${expt_name}"
139 | filename: "{epoch:04d}"
140 | monitor: epoch
141 | mode: max
142 | save_top_k: 5
143 | save_last: true
144 | - target: callbacks.instruct_p2p_video.InstructP2PLogger
145 | params:
146 | max_num_images: 1
147 | require_wandb: true
--------------------------------------------------------------------------------
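
Every `target`/`params` node in this config names a class and its constructor arguments. The snippet below is only the generic pattern such configs are usually resolved with, assuming the repo's own loading utilities (e.g. `unit_test_create_model` in `misc_utils/train_utils.py`, used by `gradio_demo.py`) behave similarly; it is a sketch, not the project's exact helper.

```python
# generic target/params instantiation pattern (assumption: the repo's loader works like this)
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(node):
    module_name, cls_name = node['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get('params', {}))

cfg = OmegaConf.load('configs/instruct_v2v.yaml')
unet = instantiate_from_config(cfg.diffusion.params.unet)   # -> UNet3DConditionModel per the config
```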
/configs/instruct_v2v_inference.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal
3 | params:
4 | beta_schedule_args:
5 | beta_schedule: scaled_linear
6 | num_train_timesteps: 1000
7 | beta_start: 0.00085
8 | beta_end: 0.012
9 | clip_sample: false
10 | thresholding: false
11 | prediction_type: epsilon
12 | loss_fn: l2
13 | optim_args:
14 | lr: 1e-5
15 | scale_factor: 0.18215
16 | guidance_scale: 5 # not used
17 | ddim_sampling_steps: 20
18 | text_cfg: 7.5
19 | img_cfg: 1.2
20 | cond_image_dropout: 0.1
21 | prompt_type: edit_prompt
22 | unet:
23 | target: modules.video_unet_temporal.unet.UNet3DConditionModel
24 | params:
25 | in_channels: 8
26 | out_channels: 4
27 | act_fn: silu
28 | attention_head_dim: 8
29 | block_out_channels:
30 | - 320
31 | - 640
32 | - 1280
33 | - 1280
34 | cross_attention_dim: 768
35 | down_block_types:
36 | - CrossAttnDownBlock3D
37 | - CrossAttnDownBlock3D
38 | - CrossAttnDownBlock3D
39 | - DownBlock3D
40 | up_block_types:
41 | - UpBlock3D
42 | - CrossAttnUpBlock3D
43 | - CrossAttnUpBlock3D
44 | - CrossAttnUpBlock3D
45 | downsample_padding: 1
46 | layers_per_block: 2
47 | mid_block_scale_factor: 1
48 | norm_eps: 1e-05
49 | norm_num_groups: 32
50 | sample_size: 64
51 | use_motion_module: true
52 | motion_module_resolutions:
53 | - 1
54 | - 2
55 | - 4
56 | - 8
57 | motion_module_mid_block: false
58 | motion_module_decoder_only: false
59 | motion_module_type: Vanilla
60 | motion_module_kwargs:
61 | num_attention_heads: 8
62 | num_transformer_block: 1
63 | attention_block_types:
64 | - Temporal_Self
65 | - Temporal_Self
66 | temporal_position_encoding: true
67 | temporal_position_encoding_max_len: 32
68 | temporal_attention_dim_div: 1
69 | vae:
70 | target: modules.kl_autoencoder.autoencoder.AutoencoderKL
71 | params:
72 | embed_dim: 4
73 | ddconfig:
74 | double_z: true
75 | z_channels: 4
76 | resolution: 256
77 | in_channels: 3
78 | out_ch: 3
79 | ch: 128
80 | ch_mult:
81 | - 1
82 | - 2
83 | - 4
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions: []
87 | dropout: 0.0
88 | lossconfig:
89 | target: torch.nn.Identity
90 | text_model:
91 | target: modules.openclip.modules.FrozenCLIPEmbedder
92 | params:
93 | freeze: true
--------------------------------------------------------------------------------
/data/airbrush-painting.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/airbrush-painting.mp4
--------------------------------------------------------------------------------
/data/airplane-and-contrail.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/airplane-and-contrail.mp4
--------------------------------------------------------------------------------
/data/audi-snow-trail.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/audi-snow-trail.mp4
--------------------------------------------------------------------------------
/data/car-turn.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/car-turn.mp4
--------------------------------------------------------------------------------
/data/cat-in-the-sun.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/cat-in-the-sun.mp4
--------------------------------------------------------------------------------
/data/dirt-road-driving.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/dirt-road-driving.mp4
--------------------------------------------------------------------------------
/data/drift-turn.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/drift-turn.mp4
--------------------------------------------------------------------------------
/data/earth-full-view.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/earth-full-view.mp4
--------------------------------------------------------------------------------
/data/eiffel-flyover.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/eiffel-flyover.mp4
--------------------------------------------------------------------------------
/data/ferris-wheel-timelapse.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/ferris-wheel-timelapse.mp4
--------------------------------------------------------------------------------
/data/gold-fish.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/gold-fish.mp4
--------------------------------------------------------------------------------
/data/ice-hockey.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/ice-hockey.mp4
--------------------------------------------------------------------------------
/data/miami-surf.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/miami-surf.mp4
--------------------------------------------------------------------------------
/data/raindrops.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/raindrops.mp4
--------------------------------------------------------------------------------
/data/red-roses-sunny-day.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/red-roses-sunny-day.mp4
--------------------------------------------------------------------------------
/data/swans.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/data/swans.mp4
--------------------------------------------------------------------------------
/dataset/loveu_tgve_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import cv2
3 | import os
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 |
9 | class LoveuTgveVideoDataset(Dataset):
10 | def __init__(self, root_dir, image_size=(480, 480)):
11 | self.root_dir = root_dir
12 | self.image_size = image_size
13 | self.transform = transforms.Compose([
14 | transforms.ToTensor(),
15 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize to [-1,1] for each channel
16 | ])
17 |
18 | csv_file = os.path.join(root_dir, 'LOVEU-TGVE-2023_Dataset.csv')
19 | self.data = {}
20 | with open(csv_file, 'r') as file:
21 | reader = csv.reader(file)
22 | next(reader, None) # skip the headers
23 | for row in reader:
24 | if len(row[0]) == 0:
25 | continue
26 | if row[0].endswith('Videos:'):
27 | dataset_type = row[0].split(' ')[0]
28 | if dataset_type == 'DAVIS':
29 | self.source_folder = dataset_type + '_480p/480p_videos'
30 | else:
31 | self.source_folder = dataset_type.lower() + '_480p/480p_videos'
32 | elif len(row) > 1:
33 | video_name = row[0]
34 | self.data[video_name] = {
35 | 'video_name': video_name,
36 | 'original': row[1],
37 | 'style': row[2],
38 | 'object': row[3],
39 | 'background': row[4],
40 | 'multiple': row[5],
41 | 'source_folder': self.source_folder,
42 | }
43 |
44 | def load_frames(self, video_name, source_folder):
45 | video_path = os.path.join(self.root_dir, source_folder, f'{video_name}.mp4')
46 | cap = cv2.VideoCapture(video_path)
47 | frames = []
48 | while cap.isOpened():
49 | ret, frame = cap.read()
50 | if ret:
51 | frame = cv2.resize(frame, self.image_size)
52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
53 | frame = self.transform(frame) # convert to PyTorch tensor and normalize
54 | frames.append(frame)
55 | else:
56 | break
57 | cap.release()
58 | return torch.stack(frames, dim=0)
59 |
60 | def load_fps(self, video_name, source_folder):
61 | video_path = os.path.join(self.root_dir, source_folder, f'{video_name}.mp4')
62 | cap = cv2.VideoCapture(video_path)
63 | fps = cap.get(cv2.CAP_PROP_FPS)
64 | cap.release()
65 | return fps
66 |
67 | def __len__(self):
68 | return len(self.data)
69 |
70 | def __getitem__(self, idx):
71 | if isinstance(idx, str):
72 | video_name = idx
73 | else:
74 | video_name = list(self.data.keys())[idx]
75 |
76 | # Load frames and fps when the dataset is called
77 | source_folder = self.data[video_name]['source_folder']
78 | frames = self.load_frames(video_name, source_folder)
79 | fps = self.load_fps(video_name, source_folder)
80 | item = self.data[video_name].copy()
81 | item['frames'] = frames
82 | item['fps'] = fps
83 |
84 | return item
--------------------------------------------------------------------------------
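
A minimal usage sketch, assuming the LOVEU-TGVE videos and their CSV have been downloaded to a local `loveu_tgve` folder:

```python
from dataset.loveu_tgve_dataset import LoveuTgveVideoDataset

ds = LoveuTgveVideoDataset(root_dir='loveu_tgve')   # hypothetical download location
item = ds[0]                                        # string indexing, e.g. ds['gold-fish'], also works
print(item['frames'].shape)                         # (num_frames, 3, 480, 480), values in [-1, 1]
print(item['style'], item['background'], item['fps'])
```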
/dataset/single_video_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from einops import rearrange
4 | import cv2
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torchvision.transforms import functional as F
8 | import omegaconf
9 |
10 | class SingleVideoDataset(Dataset):
11 | def __init__(
12 | self,
13 | video_file,
14 | video_description,
15 | sampling_fps=24,
16 | frame_gap=0,
17 | num_frames=2,
18 | output_size=(512, 512),
19 | mode='train'
20 | ):
21 | self.video_file = video_file
22 | self.video_id = os.path.splitext(os.path.basename(video_file))[0]
23 | self.sampling_fps = sampling_fps
24 | self.frame_gap = frame_gap
25 | self.description = video_description
26 | self.output_size = output_size
27 | self.mode = mode
28 | self.num_frames = num_frames
29 |
30 | cap = cv2.VideoCapture(video_file)
31 | video_fps = round(cap.get(cv2.CAP_PROP_FPS))
32 | num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
33 | cap.release()
34 |
35 | if self.sampling_fps is not None:
36 | if isinstance(self.sampling_fps, int):
37 | sampling_fps = self.sampling_fps
38 | elif isinstance(self.sampling_fps, list) or isinstance(self.sampling_fps, omegaconf.listconfig.ListConfig):
39 | sampling_fps = random.choice(self.sampling_fps)
40 | assert isinstance(sampling_fps, int)
41 | else:
42 | raise ValueError(f'sampling_fps should be int or list of int, got {self.sampling_fps}')
43 | sampling_fps = int(min(sampling_fps, video_fps))
44 | frame_gap = max(0, int(video_fps / sampling_fps))
45 | else:
46 | sampling_fps = video_fps // (1 + self.frame_gap)
47 | frame_gap = self.frame_gap
48 |
49 | self.frame_gap = frame_gap
50 | self.sampling_fps = sampling_fps
51 |
52 | cap.release()
53 |
54 | # print(num_frames, frame_gap, self.num_frames)
55 | self.num_frames = min(self.num_frames, num_frames // frame_gap) # this is number of sampling frames
56 | self.total_possible_starting_frames = max(0, num_frames - (frame_gap * (self.num_frames - 1)))
57 |
58 | def __len__(self):
59 | return self.total_possible_starting_frames
60 |
61 | def __getitem__(self, index):
62 | video_file = self.video_file
63 |
64 | cap = cv2.VideoCapture(video_file)
65 | video_fps = round(cap.get(cv2.CAP_PROP_FPS))
66 | num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
67 |
68 | sampling_fps = self.sampling_fps
69 | frame_gap = self.frame_gap
70 |
71 | frames = []
72 | if num_frames > 1 + frame_gap:
73 | first_frame_index = index
74 | for i in range(self.num_frames):
75 | cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_index + i * frame_gap)
76 | ret, frame = cap.read()
77 |
78 | # Convert BGR to RGB and resize frame
79 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if frame is not None else None
80 |
81 | if frame is not None:
82 | height, width, _ = frame.shape
83 | aspect_ratio = width / height
84 | target_width = int(self.output_size[0] * aspect_ratio)
85 | frame = rearrange(torch.from_numpy(frame), 'h w c -> c h w')
86 | frame = F.resize(frame, (self.output_size[1], target_width), antialias=True)
87 |
88 | if target_width > self.output_size[1]:
89 | margin = (target_width - self.output_size[1]) // 2
90 | frame = F.crop(frame, 0, margin, self.output_size[1], self.output_size[0])
91 | else:
92 | margin = (self.output_size[1] - target_width) // 2
93 | frame = F.pad(frame, (margin, 0), 0, 'constant')
94 | frame = (frame / 127.5) - 1.0
95 |
96 | frames.append(frame)
97 | else:
98 | for i in range(self.num_frames):
99 | frames.append(None)
100 |
101 | cap.release()
102 |
103 | # Stack frames
104 | frames = [frame for frame in frames if frame is not None]
105 | while len(frames) < self.num_frames:
106 | frames.append(frames[-1])
107 | frames = frames[:self.num_frames]
108 | frames = torch.stack(frames, dim=0)
109 |
110 | caption = self.description
111 |
112 | res = {
113 | 'frames': frames,
114 | 'video_id': self.video_id,
115 | 'text': caption,
116 | 'fps': torch.tensor(sampling_fps),
117 | }
118 | return res
--------------------------------------------------------------------------------
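
A minimal usage sketch (the caption is made up; the clip ships with the repo under `data/`):

```python
from dataset.single_video_dataset import SingleVideoDataset

ds = SingleVideoDataset(
    video_file='data/car-turn.mp4',
    video_description='a car driving on a mountain road',  # hypothetical caption
    sampling_fps=6,
    num_frames=16,
    output_size=(512, 512),
)
sample = ds[0]
print(sample['frames'].shape, sample['fps'])   # expected: (16, 3, 512, 512) in [-1, 1], tensor(6)
```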
/dataset/videoP2P.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import jsonlines
4 | import cv2
5 | import torch
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 |
9 | class VideoPromptToPrompt(Dataset):
10 | def __init__(self, root_dirs, num_frames=8):
11 | if isinstance(root_dirs, str):
12 | root_dirs = [root_dirs]
13 | self.root_dirs = root_dirs
14 | self.image_folders = []
15 | for root_dir in self.root_dirs:
16 | image_folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))]
17 | self.num_frames = num_frames
18 |
19 | for f in image_folders:
20 | if os.path.exists(os.path.join(root_dir, f, 'image')) and os.path.exists(os.path.join(root_dir, f, 'metadata.jsonl')) and os.path.exists(os.path.join(root_dir, f, 'prompt.json')):
21 | self.image_folders.append(
22 | os.path.join(root_dir, f)
23 | )
24 |
25 | def __len__(self):
26 | return len(self.image_folders)
27 |
28 | def numpy_to_tensor(self, image):
29 | image = torch.from_numpy(image.transpose((0, 3, 1, 2))).to(torch.float32)
30 | return image * 2 - 1
31 |
32 | def __getitem__(self, idx):
33 | folder = self.image_folders[idx]
34 | with jsonlines.open(os.path.join(folder, 'metadata.jsonl')) as reader:
35 | seeds = [obj['seed'] for obj in reader if (obj['sim_dir'] > 0.2 and obj['sim_0'] > 0.2 and obj['sim_1'] > 0.2 and obj['sim_image'] > 0.5)]
36 | seed = np.random.choice(seeds)
37 |
38 | # Load prompt.json
39 | with open(os.path.join(folder, 'prompt.json'), 'r') as f:
40 | prompt = json.load(f)
41 |
42 |         start_idx = np.random.randint(0, 16 - self.num_frames + 1)  # clips have 16 frames; +1 allows sampling the full clip
43 | end_idx = start_idx + self.num_frames
44 |
45 | input_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_0_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)])
46 | edited_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_1_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)])
47 |
48 | return {
49 | 'input_video': self.numpy_to_tensor(input_images),
50 | 'edited_video': self.numpy_to_tensor(edited_images),
51 | 'input_prompt': prompt['input'],
52 | 'output_prompt': prompt['output'],
53 | 'edit_prompt': prompt['edit']
54 | }
55 |
56 | @staticmethod
57 | def load_image(image_path):
58 | img = cv2.imread(image_path)
59 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
60 | img = img / 255.0 # Normalization
61 | return img
62 |
63 |
64 | class VideoPromptToPromptMotionAug(VideoPromptToPrompt):
65 | def __init__(self, *args, zoom_ratio=0.2, max_zoom=1.2, translation_ratio=0.3, translation_range=(0, 0.2), **kwargs):
66 | super().__init__(*args, **kwargs)
67 | self.zoom_ratio = zoom_ratio
68 | self.max_zoom = max_zoom
69 | self.translation_ratio = translation_ratio
70 | self.translation_range = translation_range
71 |
72 | def translation_crop(self, delta_h, delta_w, images):
73 | def center_crop(img, center_x, center_y, h, w):
74 | x_start = int(center_x - w / 2)
75 | x_end = int(x_start + w)
76 | y_start = int(center_y - h / 2)
77 | y_end = int(y_start + h)
78 | img = img[y_start: y_end, x_start: x_end]
79 | return img
80 |
81 | H, W = images.shape[1:3]
82 | crop_H = H - abs(delta_h)
83 | crop_W = W - abs(delta_w)
84 |
85 | if delta_h > 0:
86 | h_start = (H - delta_h) // 2
87 | h_end = h_start + delta_h
88 | else:
89 | h_end = H - (H + delta_h) // 2
90 | h_start = h_end + delta_h
91 |
92 | if delta_w > 0:
93 | w_start = (W - delta_w) // 2
94 | w_end = w_start + delta_w
95 | else:
96 | w_end = W - (W + delta_w) // 2
97 | w_start = w_end + delta_w
98 |
99 | center_xs = np.linspace(w_start, w_end, self.num_frames)
100 | center_ys = np.linspace(h_start, h_end, self.num_frames)
101 |
102 | if delta_h < 0:
103 | center_ys = center_ys[::-1]
104 | if delta_w < 0:
105 | center_xs = center_xs[::-1]
106 |
107 | images = np.stack([center_crop(img, center_x, center_y, crop_H, crop_W) for img, center_x, center_y in zip(images, center_xs, center_ys)], axis=0)
108 | images = np.stack([cv2.resize(img, (W, H), interpolation=cv2.INTER_CUBIC) for img in images], axis=0)
109 | return images
110 |
111 | def zoom_aug(self, images, final_scale=1., zoom_in_or_out='in'):
112 | def zoom_in_with_scale(img, scale):
113 | H, W = img.shape[:2]
114 | img = cv2.resize(img, (int(W * scale), int(H * scale)), interpolation=cv2.INTER_CUBIC)
115 | # center crop to H, W
116 | img = img[(img.shape[0] - H) // 2: (img.shape[0] - H) // 2 + H, (img.shape[1] - W) // 2: (img.shape[1] - W) // 2 + W]
117 | return img
118 | if final_scale <= 1.02 :
119 | return images
120 | if zoom_in_or_out == 'in':
121 | scales = np.linspace(1., final_scale, self.num_frames)
122 | images = np.array([zoom_in_with_scale(img, scale) for img, scale in zip(images, scales)])
123 | elif zoom_in_or_out == 'out':
124 | scales = np.linspace(final_scale, 1., self.num_frames)
125 | images = np.array([zoom_in_with_scale(img, scale) for img, scale in zip(images, scales)])
126 | return images
127 |
128 | def motion_augmentation(self, input_images, edited_images):
129 | H, W = input_images.shape[1:3]
130 |
131 | # translation augmentation
132 | if np.random.random() < self.translation_ratio:
133 | delta_h = np.random.uniform(self.translation_range[0], self.translation_range[1]) * H * np.random.choice([-1, 1])
134 | delta_w = np.random.uniform(self.translation_range[0], self.translation_range[1]) * W * np.random.choice([-1, 1])
135 | # print(delta_h, delta_w)
136 | input_images = self.translation_crop(delta_h, delta_w, input_images)
137 | edited_images = self.translation_crop(delta_h, delta_w, edited_images)
138 |
139 | # zoom augmentation
140 | if np.random.random() < self.zoom_ratio:
141 | final_scale = np.random.uniform(1., self.max_zoom)
142 | zoom_in_or_out = np.random.choice(['in', 'out'])
143 | # print(final_scale, zoom_in_or_out)
144 | input_images = self.zoom_aug(input_images, final_scale, zoom_in_or_out)
145 | edited_images = self.zoom_aug(edited_images, final_scale, zoom_in_or_out)
146 |
147 | return input_images, edited_images
148 |
149 | def __getitem__(self, idx):
150 | folder = self.image_folders[idx]
151 | with jsonlines.open(os.path.join(folder, 'metadata.jsonl')) as reader:
152 | seeds = [obj['seed'] for obj in reader if (obj['sim_dir'] > 0.2 and obj['sim_0'] > 0.2 and obj['sim_1'] > 0.2 and obj['sim_image'] > 0.5)]
153 | seed = np.random.choice(seeds)
154 |
155 | # Load prompt.json
156 | with open(os.path.join(folder, 'prompt.json'), 'r') as f:
157 | prompt = json.load(f)
158 |
159 | start_idx = np.random.randint(0, 16 - self.num_frames + 1)
160 | end_idx = start_idx + self.num_frames
161 |
162 | input_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_0_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)])
163 | edited_images = np.array([self.load_image(os.path.join(folder, 'image', f'{seed}_1_{img_idx:04d}.jpg')) for img_idx in range(start_idx, end_idx)])
164 |
165 | input_images, edited_images = self.motion_augmentation(input_images, edited_images)
166 |
167 | return {
168 | 'input_video': self.numpy_to_tensor(input_images),
169 | 'edited_video': self.numpy_to_tensor(edited_images),
170 | 'input_prompt': prompt['input'],
171 | 'output_prompt': prompt['output'],
172 | 'edit_prompt': prompt['edit']
173 | }
--------------------------------------------------------------------------------
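
A hedged sketch of building the training dataset the way `configs/instruct_v2v.yaml` configures it; the root directories below come from that config and must already contain generated samples.

```python
from torch.utils.data import DataLoader
from dataset.videoP2P import VideoPromptToPromptMotionAug

train_set = VideoPromptToPromptMotionAug(
    root_dirs=['video_ptp/raw_generated', 'video_ptp/raw_generated_webvid'],
    num_frames=16,
    zoom_ratio=0.2, max_zoom=1.25,
    translation_ratio=0.7, translation_range=(0, 0.2),
)
loader = DataLoader(train_set, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch['input_video'].shape, batch['edit_prompt'])   # (1, 16, 3, H, W), list of edit instructions
```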
/figures/data_pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/data_pipe.png
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_106_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_106_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_116_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_116_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_141_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_141_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_18_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_18_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_192_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_192_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_197_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_197_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_1_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_1_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_24_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_24_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_81_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_81_0.gif
--------------------------------------------------------------------------------
/figures/synthetic_sample/synthetic_video_92_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/synthetic_sample/synthetic_video_92_0.gif
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/teaser.png
--------------------------------------------------------------------------------
/figures/videos/TGVE_video_edit.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/TGVE_video_edit.mp4
--------------------------------------------------------------------------------
/figures/videos/airbrush-painting_object.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/airbrush-painting_object.gif
--------------------------------------------------------------------------------
/figures/videos/airplane-and-contrail_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/airplane-and-contrail_background.gif
--------------------------------------------------------------------------------
/figures/videos/audi-snow-trail_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/audi-snow-trail_background.gif
--------------------------------------------------------------------------------
/figures/videos/cat-in-the-sun_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/cat-in-the-sun_background.gif
--------------------------------------------------------------------------------
/figures/videos/cat-in-the-sun_object.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/cat-in-the-sun_object.gif
--------------------------------------------------------------------------------
/figures/videos/dirt-road-driving_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/dirt-road-driving_style.gif
--------------------------------------------------------------------------------
/figures/videos/drift-turn_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/drift-turn_style.gif
--------------------------------------------------------------------------------
/figures/videos/earth-full-view_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/earth-full-view_background.gif
--------------------------------------------------------------------------------
/figures/videos/eiffel-flyover_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/eiffel-flyover_background.gif
--------------------------------------------------------------------------------
/figures/videos/ferris-wheel-timelapse_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/ferris-wheel-timelapse_background.gif
--------------------------------------------------------------------------------
/figures/videos/gold-fish_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/gold-fish_style.gif
--------------------------------------------------------------------------------
/figures/videos/ice-hockey_object.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/ice-hockey_object.gif
--------------------------------------------------------------------------------
/figures/videos/miami-surf_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/miami-surf_background.gif
--------------------------------------------------------------------------------
/figures/videos/raindrops_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/raindrops_style.gif
--------------------------------------------------------------------------------
/figures/videos/red-roses-sunny-day_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/red-roses-sunny-day_background.gif
--------------------------------------------------------------------------------
/figures/videos/red-roses-sunny-day_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/red-roses-sunny-day_style.gif
--------------------------------------------------------------------------------
/figures/videos/swans_background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/swans_background.gif
--------------------------------------------------------------------------------
/figures/videos/swans_object.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/instruct-video-to-video/6a51b4865d74e41797c74fa216017a272b13a524/figures/videos/swans_object.gif
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | import numpy as np
4 | from misc_utils.image_utils import save_tensor_to_gif
5 | from misc_utils.train_utils import unit_test_create_model
6 | from pl_trainer.inference.inference import InferenceIP2PVideoOpticalFlow
7 | from dataset.single_video_dataset import SingleVideoDataset
8 | import torch
9 |
10 | NEGATIVE_PROMPT = 'worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'
11 | # The order is: [video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, sample_rate, num_frames, start_frame]
12 | CARTURN = [ 'data/car-turn.mp4', 'Change the car to a red Porsche and make the background beach.', NEGATIVE_PROMPT, 10, 1.1, 512, 10, 28, 20, ]
13 | AIRPLANE = ['data/airplane-and-contrail.mp4', "add Taj Mahal in the image", NEGATIVE_PROMPT, 10, 1.2, 512, 30, 28, 0]
14 | AUDI = ['data/audi-snow-trail.mp4', "make the car drive in desert trail.", NEGATIVE_PROMPT, 10, 1.5, 512, 3, 28, 0]
15 | CATINSUN_BKG = ['data/cat-in-the-sun.mp4', "change the background to a beach.", NEGATIVE_PROMPT, 7.5, 1.3, 512, 6, 28, 0]
16 | DIRTROAD = ['data/dirt-road-driving.mp4', 'add dust cloud effect.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
17 | EARTH = ['data/earth-full-view.mp4', 'add a fireworks display in the background.', NEGATIVE_PROMPT, 7.5, 1.2, 512, 30, 28, 0]
18 | EIFFELTOWER = ['data/eiffel-flyover.mp4', 'add a large fireworks display.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
19 | FERRIS = ['data/ferris-wheel-timelapse.mp4', 'Add a sunset in the background.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
20 | GOLDFISH = ['data/gold-fish.mp4', 'make the style impressionist', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
21 | ICEHOCKEY = ['data/ice-hockey.mp4', 'make the players to cartoon characters.', NEGATIVE_PROMPT, 10, 1.5, 512, 6, 28, 0]
22 | MIAMISURF = ['data/miami-surf.mp4', 'change the background to wave pool.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
23 | RAINDROP = ['data/raindrops.mp4', 'Make the style expressionism.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
24 | REDROSE_BKG = ['data/red-roses-sunny-day.mp4', 'make background to moonlight.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
25 | REDROSE_STY = ['data/red-roses-sunny-day.mp4', 'Make the style origami.', NEGATIVE_PROMPT, 10, 1.2, 512, 6, 28, 0]
26 | SWAN_OBJ = ['data/swans.mp4', 'change swans to pink flamingos.', NEGATIVE_PROMPT, 7.5, 1.2, 512, 6, 28, 0]
27 |
28 | class VideoTransfer:
29 | def __init__(self, config_path, model_ckpt, device='cuda'):
30 | self.config_path = config_path
31 | self.model_ckpt = model_ckpt
32 | self.device = device
33 | self.diffusion_model = None
34 | self.pipe = None
35 |
36 | def _init_pipe(self):
37 | diffusion_model = unit_test_create_model(self.config_path, device=self.device)
38 | ckpt = torch.load(self.model_ckpt, map_location='cpu')
39 | diffusion_model.load_state_dict(ckpt, strict=False)
40 | self.diffusion_model = diffusion_model
41 | self.pipe = InferenceIP2PVideoOpticalFlow(
42 | unet = diffusion_model.unet,
43 | num_ddim_steps=20,
44 | scheduler='ddpm'
45 | )
46 |
47 | def get_batch(self, video_path, video_sample_rate, num_frames, image_size, start_frame=0):
48 | dataset = SingleVideoDataset(
49 | video_file=video_path,
50 | video_description='',
51 | sampling_fps=video_sample_rate,
52 | num_frames=num_frames,
53 | output_size=(image_size, image_size)
54 | )
55 | batch = dataset[start_frame]
56 | batch = {k: v.to(self.device)[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
57 | return batch
58 |
59 | @staticmethod
60 | def split_batch(cond, frames_in_batch=16, num_ref_frames=4):
61 | frames_in_following_batch = frames_in_batch - num_ref_frames
62 | conds = [cond[:, :frames_in_batch]]
63 | frame_ptr = frames_in_batch
64 | num_ref_frames_each_batch = []
65 |
66 | while frame_ptr < cond.shape[1]:
67 | remaining_frames = cond.shape[1] - frame_ptr
68 | if remaining_frames < frames_in_batch:
69 | frames_in_following_batch = remaining_frames
70 | else:
71 | frames_in_following_batch = frames_in_batch - num_ref_frames
72 | this_ref_frames = frames_in_batch - frames_in_following_batch
73 | conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch])
74 | frame_ptr += frames_in_following_batch
75 | num_ref_frames_each_batch.append(this_ref_frames)
76 |
77 | return conds, num_ref_frames_each_batch
78 |
79 |     def get_splitted_batch_and_conds(self, batch, edit_prompt, negative_prompt):
80 | diffusion_model = self.diffusion_model
81 |         cond = [diffusion_model.encode_image_to_latent(frames) / 0.18215 for frames in batch['frames'].chunk(16, dim=1)]  # chunk the frames during encoding to avoid OOM in the VAE; reduce 16 if you have a smaller GPU
82 | cond = torch.cat(cond, dim=1)
83 |         text_cond = diffusion_model.encode_text([edit_prompt])
84 | text_uncond = diffusion_model.encode_text([negative_prompt])
85 | conds, num_ref_frames_each_batch = self.split_batch(cond, frames_in_batch=16, num_ref_frames=4)
86 | splitted_frames, _ = self.split_batch(batch['frames'], frames_in_batch=16, num_ref_frames=4)
87 | return conds, num_ref_frames_each_batch, text_cond, text_uncond, splitted_frames
88 |
89 | def transfer_video(self, video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, video_sample_rate, num_frames, start_frame):
90 | # TODO, support seed
91 | video_name = os.path.basename(video_path).split('.')[0]
92 | output_file_id = f'{video_name}_{text_cfg}_{video_cfg}_{resolution}_{video_sample_rate}_{num_frames}'
93 | if self.pipe is None:
94 | self._init_pipe()
95 |
96 | batch = self.get_batch(video_path, video_sample_rate, num_frames, resolution, start_frame)
97 | conds, num_ref_frames_each_batch, text_cond, text_uncond, splitted_frames = self.get_splitted_batch_and_conds(batch, edit_prompt, negative_prompt)
98 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
99 | # First video clip
100 | cond1 = conds[0]
101 | latent_pred_list = []
102 | init_latent = torch.randn_like(cond1)
103 | latent_pred = self.pipe(
104 | latent = init_latent,
105 | text_cond = text_cond,
106 | text_uncond = text_uncond,
107 | img_cond = cond1,
108 | text_cfg = text_cfg,
109 | img_cfg = video_cfg,
110 | )['latent']
111 | latent_pred_list.append(latent_pred)
112 |
113 |
114 | # Subsequent video clips
115 | for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(
116 | conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch
117 | ):
118 | init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1)
119 | cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1)
120 |
121 | # additional kwargs for using motion compensation
122 | ref_images = prev_frame[:, -num_ref_frames_:]
123 | query_images = curr_frame
124 | additional_kwargs = {
125 | 'ref_images': ref_images,
126 | 'query_images': query_images,
127 | }
128 |
129 | latent_pred = self.pipe.second_clip_forward(
130 | latent = init_latent,
131 | text_cond = text_cond,
132 | text_uncond = text_uncond,
133 | img_cond = cond_,
134 | latent_ref = latent_pred[:, -num_ref_frames_:],
135 | noise_correct_step = 0.6,
136 | text_cfg = text_cfg,
137 | img_cfg = video_cfg,
138 | **additional_kwargs,
139 | )['latent']
140 | latent_pred_list.append(latent_pred[:, num_ref_frames_:])
141 |
142 | # Save GIF
143 | original_images = batch['frames'].cpu()
144 | latent_pred = torch.cat(latent_pred_list, dim=1)
145 | image_pred = self.diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1)
146 | transferred_images = image_pred.float().cpu()
147 | save_tensor_to_gif(original_images, f'gradio_cache/{output_file_id}_original.gif', fps=5)
148 | save_tensor_to_gif(transferred_images, f'gradio_cache/{output_file_id}.gif', fps=5)
149 | return f'gradio_cache/{output_file_id}_original.gif', f'gradio_cache/{output_file_id}.gif'
150 |
151 | video_transfer = VideoTransfer(
152 | config_path = 'configs/instruct_v2v_inference.yaml',
153 | model_ckpt = 'insv2v.pth',
154 | device = 'cuda',
155 | )
156 |
157 | def transfer_video(video_path, edit_prompt, negative_prompt, text_cfg, video_cfg, resolution, video_sample_rate, num_frames, start_frame):
158 | transferred_video_path = video_transfer.transfer_video(
159 | video_path = video_path,
160 | edit_prompt = edit_prompt,
161 | negative_prompt = negative_prompt,
162 | text_cfg = float(text_cfg),
163 | video_cfg = float(video_cfg),
164 | resolution = int(resolution),
165 | video_sample_rate = int(video_sample_rate),
166 | num_frames = int(num_frames),
167 | start_frame = int(start_frame),
168 | )
169 | return transferred_video_path # a gif image
170 |
171 | with gr.Blocks() as demo:
172 | with gr.Row():
173 | video_source = gr.Video(label="Upload Video", interactive=True, width=384)
174 | video_input = gr.Image(type='filepath', width=384, label='Original Video')
175 | video_output = gr.Image(type='filepath', width=384, label='Edited Video')
176 |
177 | with gr.Row():
178 | with gr.Column(scale=3):
179 | edit_prompt = gr.Textbox(label="Edit Prompt")
180 | negative_prompt = gr.Textbox(label="Negative Prompt")
181 | with gr.Column(scale=1):
182 |             submit_btn = gr.Button("Transfer")
183 |
184 | with gr.Row():
185 | with gr.Row():
186 | text_cfg = gr.Textbox(label="Text classifier-free guidance", value=7.5)
187 | video_cfg = gr.Textbox(label="Video classifier-free guidance", value=1.2)
188 | resolution = gr.Textbox(label="Resolution", value=384)
189 | sample_rate = gr.Textbox(label="Video Sample Rate", value=5)
190 | num_frames = gr.Textbox(label="Number of frames", value=28)
191 | start_frame = gr.Textbox(label="Start frame index", value=0)
192 |
193 | gr.Examples(
194 | examples=[
195 | EARTH,
196 | AUDI,
197 | DIRTROAD,
198 | CATINSUN_BKG,
199 | EIFFELTOWER,
200 | FERRIS,
201 | GOLDFISH,
202 | CARTURN,
203 | ICEHOCKEY,
204 | MIAMISURF,
205 | RAINDROP,
206 | REDROSE_BKG,
207 | REDROSE_STY,
208 | SWAN_OBJ,
209 | AIRPLANE,
210 | ],
211 | inputs=[
212 | video_source,
213 | edit_prompt,
214 | negative_prompt,
215 | text_cfg,
216 | video_cfg,
217 | resolution,
218 | sample_rate,
219 | num_frames,
220 | start_frame,
221 | ],
222 | )
223 |
224 | submit_btn.click(
225 | transfer_video,
226 | inputs=[
227 | video_source,
228 | edit_prompt,
229 | negative_prompt,
230 | text_cfg,
231 | video_cfg,
232 | resolution,
233 | sample_rate,
234 | num_frames,
235 | start_frame,
236 | ],
237 | outputs=[
238 | video_input,
239 | video_output
240 | ],
241 | )
242 |
243 | demo.queue(concurrency_count=1).launch(share=True)
244 |
--------------------------------------------------------------------------------
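Usage note (illustrative sketch, not part of the repository): the long-video handling in gradio_demo.py hinges on split_batch, which cuts the latent sequence into a first full window followed by shorter windows whose leading frames are reused as references. The sketch below replays that chunking on a dummy tensor with the same defaults transfer_video uses (28 frames, frames_in_batch=16, num_ref_frames=4); the tensor shape itself is only illustrative.

import torch

# Minimal sketch of VideoTransfer.split_batch: a 28-frame latent with
# frames_in_batch=16 and num_ref_frames=4 yields one 16-frame window plus
# one 12-frame window; at denoising time the second window is prepended
# with the last 4 frames of the first window as references.
cond = torch.randn(1, 28, 4, 64, 64)  # (batch, frames, latent channels, h, w) -- illustrative shape

conds = [cond[:, :16]]
frame_ptr, num_ref_frames_each_batch = 16, []
while frame_ptr < cond.shape[1]:
    remaining = cond.shape[1] - frame_ptr
    step = remaining if remaining < 16 else 16 - 4
    num_ref_frames_each_batch.append(16 - step)
    conds.append(cond[:, frame_ptr:frame_ptr + step])
    frame_ptr += step

print([c.shape[1] for c in conds])  # [16, 12]
print(num_ref_frames_each_batch)    # [4]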
/insv2v_run_loveu_tgve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from dataset.loveu_tgve_dataset import LoveuTgveVideoDataset
4 | from matplotlib import pyplot as plt
5 | from misc_utils.train_utils import unit_test_create_model
6 | from itertools import product
7 | from pl_trainer.inference.inference import InferenceIP2PVideo, InferenceIP2PVideoOpticalFlow
8 | import json
9 | import argparse
10 | from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
11 |
12 | def split_batch(cond, frames_in_batch=16, num_ref_frames=4):
13 | frames_in_following_batch = frames_in_batch - num_ref_frames
14 | conds = [cond[:, :frames_in_batch]]
15 | frame_ptr = frames_in_batch
16 | num_ref_frames_each_batch = []
17 |
18 | while frame_ptr < cond.shape[1]:
19 | remaining_frames = cond.shape[1] - frame_ptr
20 | if remaining_frames < frames_in_batch:
21 | frames_in_following_batch = remaining_frames
22 | else:
23 | frames_in_following_batch = frames_in_batch - num_ref_frames
24 | this_ref_frames = frames_in_batch - frames_in_following_batch
25 | conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch])
26 | frame_ptr += frames_in_following_batch
27 | num_ref_frames_each_batch.append(this_ref_frames)
28 |
29 | return conds, num_ref_frames_each_batch
30 |
31 | parser = argparse.ArgumentParser(description='Run InsV2V editing on the LOVEU-TGVE dataset')
32 |
33 | # Add arguments
34 | parser.add_argument('--text-cfg', nargs='+', type=float, default=[7.5], help='Text classifier-free guidance scale(s)')
35 | parser.add_argument('--video-cfg', nargs='+', type=float, default=[1.8], help='Video classifier-free guidance scale(s)')
36 | parser.add_argument('--num-frames', nargs='+', type=int, default=[32], help='Number of frames')
37 | parser.add_argument('--image-size', nargs='+', type=int, default=[384], help='Image size')
38 | parser.add_argument('--prompt-source', type=str, default='edit', help='Prompt source')
39 | parser.add_argument('--ckpt-path', type=str, help='Path to checkpoint')
40 | parser.add_argument('--config-path', type=str, default='configs/instruct_v2v.yaml', help='Path to config file')
41 | parser.add_argument('--data-dir', type=str, default='loveu-tgve-2023', help='Path to LOVEU dataset')
42 | parser.add_argument('--with_optical_flow', action='store_true', help='Use motion compensation')
43 |
44 | # Parse arguments
45 | args = parser.parse_args()
46 |
47 | TEXT_CFGS = args.text_cfg
48 | VIDEO_CFGS = args.video_cfg
49 | NUM_FRAMES = args.num_frames
50 | IMAGE_SIZE = args.image_size
51 |
52 | PROMPT_SOURCE = args.prompt_source
53 | DATA_ROOT = args.data_dir
54 |
55 | config_path = args.config_path
56 | ckpt_path = args.ckpt_path
57 |
58 | diffusion_model = unit_test_create_model(config_path)
59 |
60 | ckpt = torch.load(ckpt_path, map_location='cpu')
61 | ckpt = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
62 | diffusion_model.load_state_dict(ckpt, strict=False)
63 |
64 | if args.with_optical_flow:
65 | inf_pipe = InferenceIP2PVideoOpticalFlow(
66 | unet = diffusion_model.unet,
67 | num_ddim_steps=20,
68 | scheduler='ddpm'
69 | )
70 | else:
71 | inf_pipe = InferenceIP2PVideo(
72 | unet = diffusion_model.unet,
73 | num_ddim_steps=20,
74 | scheduler='ddpm'
75 | )
76 |
77 | frames_in_batch = 16
78 | num_ref_frames = 4
79 |
80 | edit_prompt_file = 'dataset/loveu_tgve_edit_prompt_dict.json'
81 | edit_prompt_dict = json.load(open(edit_prompt_file, 'r'))
82 |
83 | for VIDEO_ID, text_cfg, video_cfg, num_frames, image_size in product(range(len(edit_prompt_dict)), TEXT_CFGS, VIDEO_CFGS, NUM_FRAMES, IMAGE_SIZE):
84 | dataset = LoveuTgveVideoDataset(
85 | root_dir=DATA_ROOT,
86 | image_size=(image_size, image_size),
87 | )
88 |
89 | batch = dataset[VIDEO_ID]
90 | video_name = batch['video_name']
91 | num_video_frames = len(batch['frames'])
92 | if num_video_frames > num_frames:
93 | frame_skip = num_video_frames // num_frames
94 | else:
95 | frame_skip = 1
96 | batch = {k: v[::frame_skip].cuda()[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
97 |
98 | cond = diffusion_model.encode_image_to_latent(batch['frames']) / 0.18215
99 | text_uncond = diffusion_model.encode_text([''])
100 |
101 | for prompt_key in ['style', 'object', 'background', 'multiple']:
102 | final_prompt = batch[prompt_key]
103 | if PROMPT_SOURCE == 'edit':
104 | prompt = edit_prompt_dict[video_name]['edit_' + prompt_key]
105 | out_folder = f'v2v_results/edit_prompt/loveu_tgve_{image_size}/gif/VID_{VIDEO_ID}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}'
106 | image_output_dir = f'v2v_results/edit_prompt/loveu_tgve_{image_size}/images_{num_frames}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}/{video_name}/{prompt_key}'
107 | elif PROMPT_SOURCE == 'original':
108 | prompt = batch[prompt_key]
109 | out_folder = f'v2v_results/original_prompt/loveu_tgve_{image_size}/gif/VID_{VIDEO_ID}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}'
110 | image_output_dir = f'v2v_results/original_prompt/loveu_tgve_{image_size}/images_{num_frames}/VIDEO_CFG_{video_cfg}_TEXT_CFG_{text_cfg}/{video_name}/{prompt_key}'
111 |
112 | text = '_'.join(final_prompt.split(' '))
113 | output_path = f'{out_folder}/{prompt_key}_{num_frames}_{text}.gif'
114 | if os.path.exists(output_path):
115 | print(f'File {output_path} exists, skip')
116 | continue
117 |
118 | text_cond = diffusion_model.encode_text(prompt)
119 | conds, num_ref_frames_each_batch = split_batch(cond, frames_in_batch=frames_in_batch, num_ref_frames=num_ref_frames)
120 | splitted_frames, _ = split_batch(batch['frames'], frames_in_batch=frames_in_batch, num_ref_frames=num_ref_frames)
121 |
122 | # First video clip
123 | cond1 = conds[0]
124 | latent_pred_list = []
125 | init_latent = torch.randn_like(cond1)
126 | latent_pred = inf_pipe(
127 | latent = init_latent,
128 | text_cond = text_cond,
129 | text_uncond = text_uncond,
130 | img_cond = cond1,
131 | text_cfg = text_cfg,
132 | img_cfg = video_cfg,
133 | )['latent']
134 | latent_pred_list.append(latent_pred)
135 |
136 |
137 | # Subsequent video clips
138 | for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch):
139 | init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1)
140 | cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1)
141 | if args.with_optical_flow:
142 | ref_images = prev_frame[:, -num_ref_frames_:]
143 | query_images = curr_frame
144 | additional_kwargs = {
145 | 'ref_images': ref_images,
146 | 'query_images': query_images,
147 | }
148 | else:
149 | additional_kwargs = {}
150 | latent_pred = inf_pipe.second_clip_forward(
151 | latent = init_latent,
152 | text_cond = text_cond,
153 | text_uncond = text_uncond,
154 | img_cond = cond_,
155 | latent_ref = latent_pred[:, -num_ref_frames_:],
156 | noise_correct_step = 0.5,
157 | text_cfg = text_cfg,
158 | img_cfg = video_cfg,
159 | **additional_kwargs,
160 | )['latent']
161 | latent_pred_list.append(latent_pred[:, num_ref_frames_:])
162 |
163 | # Save GIF
164 | latent_pred = torch.cat(latent_pred_list, dim=1)
165 | image_pred = diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1)
166 |
167 | original_images = batch['frames'].cpu()
168 | transferred_images = image_pred.float().cpu()
169 | concat_images = torch.cat([original_images, transferred_images], dim=4)
170 |
171 | save_tensor_to_gif(concat_images, output_path, fps=5)
172 | save_tensor_to_images(transferred_images, image_output_dir)
--------------------------------------------------------------------------------
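Note (hedged sketch, not part of the repository): the script above looks up edited prompts as edit_prompt_dict[video_name]['edit_' + prompt_key] for the keys style/object/background/multiple. The snippet below mirrors only that access pattern; the entry shown is made up, and the real prompts live in dataset/loveu_tgve_edit_prompt_dict.json.

# Hypothetical entry illustrating how insv2v_run_loveu_tgve.py indexes the prompt file.
edit_prompt_dict = {
    "swans": {  # illustrative video name, not necessarily a real key in the JSON
        "edit_style": "make the video an oil painting",
        "edit_object": "change the swans to pink flamingos",
        "edit_background": "change the background to a city pond",
        "edit_multiple": "change the swans to flamingos at sunset",
    }
}
for prompt_key in ['style', 'object', 'background', 'multiple']:
    print(edit_prompt_dict["swans"]['edit_' + prompt_key])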
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | from omegaconf import OmegaConf
5 | from pytorch_lightning import Trainer, seed_everything
6 | from misc_utils.train_utils import get_models, get_DDPM, get_logger, get_callbacks, get_dataset
7 |
8 | if __name__ == '__main__':
9 | # seed_everything(42)
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | '-c', '--config', type=str,
13 | default='config/train.json')
14 | parser.add_argument(
15 | '-r', '--resume', action="store_true"
16 | )
17 | parser.add_argument(
18 | '-n', '--nnode', type=int, default=1
19 | )
20 | parser.add_argument(
21 | '--ckpt', type=str, default=None
22 | )
23 | parser.add_argument(
24 | '--manual_load', action="store_true"
25 | )
26 |
27 | ''' parser configs '''
28 | args_raw = parser.parse_args()
29 | args = OmegaConf.load(args_raw.config)
30 | args.update(vars(args_raw))
31 | expt_path = os.path.join(args.expt_dir, args.expt_name)
32 | os.makedirs(expt_path, exist_ok=True)
33 |
34 | '''1. create denoising model'''
35 | models = get_models(args)
36 |
37 | diffusion_configs = args.diffusion
38 | ddpm_model = get_DDPM(
39 | diffusion_configs=diffusion_configs,
40 | log_args=args,
41 | **models
42 | )
43 |
44 | '''2. dataset and dataloader'''
45 | train_loader, val_loader, train_set, val_set = get_dataset(args)
46 |
47 | '''3. create callbacks'''
48 | wandb_logger = get_logger(args)
49 | callbacks = get_callbacks(args, wandb_logger)
50 |
51 | '''4. trainer'''
52 | trainer_args = {
53 | "max_epochs": 100,
54 | "accelerator": "gpu",
55 | "devices": [0],
56 | "limit_val_batches": 1,
57 | "strategy": "ddp",
58 | "check_val_every_n_epoch": 1,
59 | "num_nodes": args.nnode
60 | # "benchmark" :True
61 | }
62 | config_trainer_args = args.trainer_args if args.get('trainer_args') is not None else {}
63 | trainer_args.update(config_trainer_args)
64 | print(f'Training args are {trainer_args}')
65 | trainer = Trainer(
66 | logger = wandb_logger,
67 | callbacks = callbacks,
68 | **trainer_args
69 | )
70 | '''5. start training'''
71 | if args['resume']:
72 | print('INFO: Try to resume from checkpoint')
73 | if args['ckpt'] is not None:
74 | ckpt_path = args['ckpt']
75 | else:
76 | ckpt_path = os.path.join(expt_path, 'last.ckpt')
77 | if os.path.exists(ckpt_path):
78 | print(f'INFO: Found checkpoint {ckpt_path}')
79 | if args['manual_load']:
80 | print('INFO: Manually load checkpoint')
81 | ckpt = torch.load(ckpt_path, map_location='cpu')
82 | ddpm_model.load_state_dict(ckpt['state_dict'])
83 | ckpt_path = None # do not need to load checkpoint in Trainer
84 | else:
85 | ckpt_path = None
86 | else:
87 | ckpt_path = None
88 | trainer.fit(
89 | ddpm_model, train_loader, val_loader,
90 | ckpt_path=ckpt_path
91 | )
92 |
--------------------------------------------------------------------------------
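Note (hedged sketch, not part of the repository): main.py loads the -c file with OmegaConf and hands it to get_models, get_DDPM, get_dataset, get_logger and get_callbacks from misc_utils/train_utils.py, so the config needs roughly the keys shown below. All target paths and parameters here are placeholders; configs/instruct_v2v.yaml is the authoritative training config.

from omegaconf import OmegaConf

# Key layout read by main.py and misc_utils/train_utils.py (values are placeholders).
conf = OmegaConf.create({
    "expt_dir": "experiments",
    "expt_name": "instruct_v2v",                                        # also used as the wandb project name
    "unet": {"target": "path.to.UNet", "params": {}},                   # required by get_models
    "vae": {"target": "path.to.VAE", "params": {}},                     # optional
    "text_model": {"target": "path.to.TextModel", "params": {}},        # optional
    "diffusion": {"target": "path.to.DiffusionWrapper", "params": {}},  # consumed by get_DDPM
    "data": {
        "train": {"target": "path.to.TrainSet", "params": {}},
        "val": {"target": "path.to.ValSet", "params": {}},
        "batch_size": 1,
        "val_batch_size": 1,
    },
    "callbacks": [{"target": "path.to.Callback", "params": {}}],
    "trainer_args": {"devices": [0]},
})
print(OmegaConf.to_yaml(conf))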
/misc_utils/clip_similarity.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/timothybrooks/instruct-pix2pix/blob/main/metrics/clip_similarity.py
2 |
3 | import clip
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 |
10 | class ClipSimilarity(nn.Module):
11 | def __init__(self, name: str = "ViT-L/14"):
12 | super().__init__()
13 | assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip
14 | self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224)
15 |
16 | self.model, _ = clip.load(name, device="cpu", download_root="./")
17 | self.model.eval().requires_grad_(False)
18 |
19 | self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
20 | self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))
21 |
22 | def encode_text(self, text: list[str]) -> torch.Tensor:
23 | text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device)
24 | text_features = self.model.encode_text(text)
25 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
26 | return text_features
27 |
28 | def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1].
29 | image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
30 | image = image - rearrange(self.mean, "c -> 1 c 1 1")
31 | image = image / rearrange(self.std, "c -> 1 c 1 1")
32 | image_features = self.model.encode_image(image)
33 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
34 | return image_features
35 |
36 | def forward(
37 | self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
38 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
39 | image_features_0 = self.encode_image(image_0)
40 | image_features_1 = self.encode_image(image_1)
41 | text_features_0 = self.encode_text(text_0)
42 | text_features_1 = self.encode_text(text_1)
43 | sim_0 = F.cosine_similarity(image_features_0, text_features_0)
44 | sim_1 = F.cosine_similarity(image_features_1, text_features_1)
45 | sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0)
46 | sim_image = F.cosine_similarity(image_features_0, image_features_1)
47 | return sim_0, sim_1, sim_direction, sim_image
--------------------------------------------------------------------------------
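Note (usage sketch, not part of the repository): ClipSimilarity is the metric module borrowed from InstructPix2Pix. The sketch below feeds it random frame batches in [0, 1] purely to show the call signature; it assumes the clip package is installed and will download the ViT-L/14 weights to ./ on first use.

import torch
from misc_utils.clip_similarity import ClipSimilarity

# Directional CLIP similarity between "original" and "edited" frame batches
# and their captions; random tensors stand in for real frames in [0, 1].
clip_sim = ClipSimilarity()

image_0 = torch.rand(2, 3, 256, 256)  # original frames
image_1 = torch.rand(2, 3, 256, 256)  # edited frames
sim_0, sim_1, sim_direction, sim_image = clip_sim(
    image_0, image_1,
    text_0=["a car driving on a snow trail"] * 2,
    text_1=["a car driving on a desert trail"] * 2,
)
print(sim_direction)  # higher means the image change better matches the caption change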
/misc_utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Usage:
3 |
4 | from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample
5 | image = load_image_as_tensor('hamburger_pic.jpeg', image_size)
6 | flow_estimator = RAFTFlow()
7 | res = generate_sample(
8 | image,
9 | flow_estimator,
10 | distortion_scale=distortion_scale,
11 | )
12 | f1 = res['input'][None]
13 | f2 = res['target'][None]
14 | flow = res['flow'][None]
15 | f1_warp = warp_image(f1, flow)
16 | show_image(f1_warp[0])
17 | show_image(f2[0])
18 | '''
19 | import torch
20 | import torch.nn.functional as F
21 | import torchvision.transforms.functional as TF
22 | from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
23 | import numpy as np
24 |
25 | def warp_image(image, flow, mode='bilinear'):
26 | """ Warp an image using optical flow.
27 | Args:
28 | image (torch.Tensor): Input image tensor with shape (N, C, H, W).
29 | flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W).
30 | Returns:
31 | warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W).
32 | """
33 | # check shape
34 | if len(image.shape) == 3:
35 | image = image.unsqueeze(0)
36 | if len(flow.shape) == 3:
37 | flow = flow.unsqueeze(0)
38 | if image.device != flow.device:
39 | flow = flow.to(image.device)
40 | assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.'
41 | assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.'
42 | # Generate a grid of sampling points
43 | grid = torch.tensor(
44 | np.array(np.meshgrid(range(image.shape[3]), range(image.shape[2]), indexing='xy')),
45 | dtype=torch.float32, device=image.device
46 | )[None]
47 | grid = grid.permute(0, 2, 3, 1).repeat(image.shape[0], 1, 1, 1) # (N, H, W, 2)
48 | grid += flow.permute(0, 2, 3, 1) # add optical flow to grid
49 |
50 | # Normalize grid to [-1, 1]
51 | grid[:, :, :, 0] = 2 * (grid[:, :, :, 0] / (image.shape[3] - 1) - 0.5)
52 | grid[:, :, :, 1] = 2 * (grid[:, :, :, 1] / (image.shape[2] - 1) - 0.5)
53 |
54 | # Sample input image using the grid
55 | warped_image = F.grid_sample(image, grid, mode=mode, align_corners=True)
56 |
57 | return warped_image
58 |
59 | def resize_flow(flow, size):
60 | """
61 | Resize optical flow tensor to a new size.
62 |
63 | Args:
64 | flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
65 | size (tuple[int, int]): Target size as a tuple (H, W).
66 |
67 | Returns:
68 | flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W).
69 | """
70 | # Unpack the target size
71 | H, W = size
72 |
73 | # Compute the scaling factors
74 | h, w = flow.shape[2:]
75 | scale_x = W / w
76 | scale_y = H / h
77 |
78 | # Scale the optical flow by the resizing factors
79 | flow_scaled = flow.clone()
80 | flow_scaled[:, 0] *= scale_x
81 | flow_scaled[:, 1] *= scale_y
82 |
83 | # Resize the optical flow to the new size (H, W)
84 | flow_resized = F.interpolate(flow_scaled, size=(H, W), mode='bilinear', align_corners=False)
85 |
86 | return flow_resized
87 |
88 | def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor:
89 | """
90 | Check the consistency of two optical flows.
91 | flow1: (B, 2, H, W)
92 | flow2: (B, 2, H, W)
93 |     to get a reliability mask in the forward direction, pass the forward flow as flow1 and the backward flow as flow2
94 |     return: reliable mask of shape (B, 1, H, W)
95 | """
96 | device = flow1.device
97 | height, width = flow1.shape[2:]
98 |
99 | kernel_x = torch.tensor([[0.5, 0, -0.5]]).unsqueeze(0).unsqueeze(0).to(device)
100 | kernel_y = torch.tensor([[0.5], [0], [-0.5]]).unsqueeze(0).unsqueeze(0).to(device)
101 | grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1))
102 | grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0))
103 |
104 | motion_edge = (grad_x * grad_x + grad_y * grad_y).sum(dim=1).squeeze(0)
105 |
106 | ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy')
107 | bx, by = ax + flow1[:, 0], ay + flow1[:, 1]
108 |
109 | x1, y1 = torch.floor(bx).long(), torch.floor(by).long()
110 | x2, y2 = x1 + 1, y1 + 1
111 | x1 = torch.clamp(x1, 0, width - 1)
112 | x2 = torch.clamp(x2, 0, width - 1)
113 | y1 = torch.clamp(y1, 0, height - 1)
114 | y2 = torch.clamp(y2, 0, height - 1)
115 |
116 | alpha_x, alpha_y = bx - x1.float(), by - y1.float()
117 |
118 | a = (1.0 - alpha_x) * flow2[:, 0, y1, x1] + alpha_x * flow2[:, 0, y1, x2]
119 | b = (1.0 - alpha_x) * flow2[:, 0, y2, x1] + alpha_x * flow2[:, 0, y2, x2]
120 | u = (1.0 - alpha_y) * a + alpha_y * b
121 |
122 | a = (1.0 - alpha_x) * flow2[:, 1, y1, x1] + alpha_x * flow2[:, 1, y1, x2]
123 | b = (1.0 - alpha_x) * flow2[:, 1, y2, x1] + alpha_x * flow2[:, 1, y2, x2]
124 | v = (1.0 - alpha_y) * a + alpha_y * b
125 |
126 | cx, cy = bx + u, by + v
127 | u2, v2 = flow1[:, 0], flow1[:, 1]
128 |
129 | reliable = ((((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5)) & (motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 0.002)).float()
130 |
131 | return reliable # (B, 1, H, W)
132 |
133 |
134 | class RAFTFlow(torch.nn.Module):
135 |     '''
136 |     # Instantiate the RAFTFlow module (no constructor arguments are needed)
137 |     raft_flow = RAFTFlow()
138 | 
139 |     # Load a pair of batched image frames as PyTorch tensors with shape (B, C, H, W)
140 |     img1 = torch.tensor(np.random.rand(1, 3, 720, 1280), dtype=torch.float32)
141 |     img2 = torch.tensor(np.random.rand(1, 3, 720, 1280), dtype=torch.float32)
142 | 
143 |     # Compute optical flow between the two frames; img_size is optional
144 |     # img_size = (256, 256) or None
145 |     flow = raft_flow(img1, img2, img_size) # flow is computed at the original image size if img_size is None
146 |     # this flow can be used to warp the second image to the first image
147 | 
148 |     # Warp the second image using the flow
149 |     warped_img = warp_image(img2, flow)
150 |     '''
151 |     def __init__(self, *args):
152 |         """
153 |         Loads the pretrained RAFT-Large weights and the matching
154 |         preprocessing transform; no extra arguments are required.
155 |         """
156 |         super().__init__(*args)
157 | weights = Raft_Large_Weights.DEFAULT
158 | self.model = raft_large(weights=weights, progress=False)
159 | self.model_transform = weights.transforms()
160 |
161 | def forward(self, img1, img2, img_size=None):
162 | """
163 | Compute optical flow between two frames using RAFT model.
164 |
165 | Args:
166 | img1 (torch.Tensor): First frame tensor with shape (B, C, H, W).
167 | img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W).
168 | img_size (tuple): Size of the input images to be processed.
169 |
170 | Returns:
171 | flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
172 | """
173 | original_size = img1.shape[2:]
174 | # Preprocess the input frames
175 | if img_size is not None:
176 | img1 = TF.resize(img1, size=img_size, antialias=False)
177 | img2 = TF.resize(img2, size=img_size, antialias=False)
178 |
179 | img1, img2 = self.model_transform(img1, img2)
180 |
181 | # Compute the optical flow using the RAFT model
182 | with torch.no_grad():
183 | list_of_flows = self.model(img1, img2)
184 | flow = list_of_flows[-1]
185 |
186 | if img_size is not None:
187 | flow = resize_flow(flow, original_size)
188 |
189 | return flow
190 |
--------------------------------------------------------------------------------
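Note (usage sketch, not part of the repository): a minimal round trip through RAFTFlow, warp_image and check_consistency. Random frames are used only to exercise the shapes, and the resolution is kept divisible by 8 as RAFT expects; on a real clip you would pass consecutive frames in [0, 1].

import torch
from misc_utils.flow_utils import RAFTFlow, warp_image, check_consistency

device = 'cuda' if torch.cuda.is_available() else 'cpu'
flow_estimator = RAFTFlow().to(device).eval()

img1 = torch.rand(1, 3, 256, 256, device=device)
img2 = torch.rand(1, 3, 256, 256, device=device)

flow_fwd = flow_estimator(img1, img2)  # flow that warps img2 towards img1
flow_bwd = flow_estimator(img2, img1)

warped = warp_image(img2, flow_fwd)               # (1, 3, 256, 256)
reliable = check_consistency(flow_fwd, flow_bwd)  # per-pixel consistency mask
print(warped.shape, reliable.shape)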
/misc_utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import numpy as np
5 | import cv2
6 | import imageio
7 | from PIL import Image
8 | import textwrap
9 |
10 | def find_nearest_Nx(size, N=32):
11 | return int(np.ceil(size / N) * N)
12 |
13 | def load_image_as_tensor(image_path, image_size):
14 | if isinstance(image_size, int):
15 | image_size = (image_size, image_size)
16 | image = cv2.imread(image_path)[..., ::-1]
17 | try:
18 | image = cv2.resize(image, image_size)
19 | except Exception as e:
20 | print(e)
21 | print(image_path)
22 |
23 | image = torch.from_numpy(np.array(image).transpose(2, 0, 1)) / 255.
24 | return image
25 |
26 | def show_image(image):
27 | if len(image.shape) == 4:
28 | image = image[0]
29 | plt.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
30 | plt.show()
31 |
32 | def extract_video(video_path, save_dir, sampling_fps, skip_frames=0):
33 | os.makedirs(save_dir, exist_ok=True)
34 | cap = cv2.VideoCapture(video_path)
35 | frame_skip = int(cap.get(cv2.CAP_PROP_FPS) / sampling_fps)
36 | frame_count = 0
37 | save_count = 0
38 | while True:
39 | ret, frame = cap.read()
40 | if not ret:
41 | break
42 | if frame_count < skip_frames: # skip the first N frames
43 | frame_count += 1
44 | continue
45 | if (frame_count - skip_frames) % frame_skip == 0:
46 | # Save the frame as an image file if it doesn't already exist
47 | save_path = os.path.join(save_dir, f"frame{save_count:04d}.jpg")
48 | save_count += 1
49 | if not os.path.exists(save_path):
50 | cv2.imwrite(save_path, frame)
51 | frame_count += 1
52 | cap.release()
53 | cv2.destroyAllWindows()
54 |
55 | def concatenate_frames_to_video(frame_dir, video_path, fps):
56 | os.makedirs(os.path.dirname(video_path), exist_ok=True)
57 | # Get the list of frame file names in the directory
58 | frame_files = [f for f in os.listdir(frame_dir) if f.startswith("frame")]
59 | # Sort the frame file names in ascending order
60 | frame_files.sort()
61 | # Load the first frame to get the frame size
62 | frame = cv2.imread(os.path.join(frame_dir, frame_files[0]))
63 | height, width, _ = frame.shape
64 | # Initialize the video writer
65 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
66 | out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
67 | # Loop through the frame files and add them to the video
68 | for frame_file in frame_files:
69 | frame_path = os.path.join(frame_dir, frame_file)
70 | frame = cv2.imread(frame_path)
71 | out.write(frame)
72 | # Release the video writer
73 | out.release()
74 |
75 | def cumulative_histogram(hist):
76 | cum_hist = hist.copy()
77 | for i in range(1, len(hist)):
78 | cum_hist[i] = cum_hist[i - 1] + hist[i]
79 | return cum_hist
80 |
81 | def histogram_matching(src_img, ref_img):
82 | src_img = (src_img * 255).astype(np.uint8)
83 | ref_img = (ref_img * 255).astype(np.uint8)
84 | src_img_yuv = cv2.cvtColor(src_img, cv2.COLOR_RGB2YUV)
85 | ref_img_yuv = cv2.cvtColor(ref_img, cv2.COLOR_RGB2YUV)
86 |
87 | matched_img = np.zeros_like(src_img_yuv)
88 | for channel in range(src_img_yuv.shape[2]):
89 | src_hist, _ = np.histogram(src_img_yuv[:, :, channel].ravel(), 256, (0, 256))
90 | ref_hist, _ = np.histogram(ref_img_yuv[:, :, channel].ravel(), 256, (0, 256))
91 |
92 | src_cum_hist = cumulative_histogram(src_hist)
93 | ref_cum_hist = cumulative_histogram(ref_hist)
94 |
95 | lut = np.zeros(256, dtype=np.uint8)
96 | j = 0
97 | for i in range(256):
98 | while ref_cum_hist[j] < src_cum_hist[i] and j < 255:
99 | j += 1
100 | lut[i] = j
101 |
102 | matched_img[:, :, channel] = cv2.LUT(src_img_yuv[:, :, channel], lut)
103 |
104 | matched_img = cv2.cvtColor(matched_img, cv2.COLOR_YUV2RGB)
105 | matched_img = matched_img.astype(np.float32) / 255
106 | return matched_img
107 |
108 | def canny_image_batch(image_batch, low_threshold=100, high_threshold=200):
109 |     is_torch = isinstance(image_batch, torch.Tensor)
110 |     if is_torch:
111 |         # [-1, 1] tensor -> [0, 255] numpy array
112 |         device = image_batch.device
113 |         image_batch = (image_batch + 1) * 127.5
114 |         image_batch = image_batch.permute(0, 2, 3, 1).detach().cpu().numpy()
115 |         image_batch = image_batch.astype(np.uint8)
116 |     image_batch = np.array([cv2.Canny(image, low_threshold, high_threshold) for image in image_batch])
117 |     image_batch = image_batch[:, :, :, None]
118 |     image_batch = np.concatenate([image_batch, image_batch, image_batch], axis=3)
119 | 
120 |     if is_torch:
121 |         # [0, 255] numpy array -> [0, 1] tensor on the original device
122 |         image_batch = torch.from_numpy(image_batch).permute(0, 3, 1, 2).float() / 255.
123 |         image_batch = image_batch.to(device)
124 |     return image_batch
125 |
126 |
127 | def images_to_gif(images, filename, fps):
128 | os.makedirs(os.path.dirname(filename), exist_ok=True)
129 | # Normalize to 0-255 and convert to uint8
130 | images = [(img * 255).astype(np.uint8) if img.dtype == np.float32 else img for img in images]
131 | images = [Image.fromarray(img) for img in images]
132 | imageio.mimsave(filename, images, duration=1 / fps)
133 |
134 | def load_gif(image_path):
135 | import imageio
136 | gif = imageio.get_reader(image_path)
137 | np_images = np.array([frame[..., :3] for frame in gif])
138 | return np_images
139 |
140 | def add_text_to_frame(frame, text, font_scale=1, thickness=2, color=(0, 0, 0), bg_color=(255, 255, 255), max_width=30):
141 | """
142 | Add text to a frame.
143 | """
144 | # Make a copy of the frame
145 | frame_with_text = np.copy(frame)
146 | # Choose font
147 | font = cv2.FONT_HERSHEY_SIMPLEX
148 | # Split text into lines if it's too long
149 | lines = textwrap.wrap(text, width=max_width)
150 | # Get total text height
151 | total_text_height = len(lines) * (thickness * font_scale + 10) + 60 * font_scale
152 | # Create an image filled with the background color, having enough space for the text
153 | text_bg_img = np.full((int(total_text_height), frame.shape[1], 3), bg_color, dtype=np.uint8)
154 | # Put each line on the text background image
155 | y = 0
156 | for line in lines:
157 | text_size, _ = cv2.getTextSize(line, font, font_scale, thickness)
158 | text_x = (text_bg_img.shape[1] - text_size[0]) // 2
159 | y += text_size[1] + 10
160 | cv2.putText(text_bg_img, line, (text_x, y), font, font_scale, color, thickness)
161 | # Append the text background image to the frame
162 | frame_with_text = np.vstack((frame_with_text, text_bg_img))
163 |
164 | return frame_with_text
165 |
166 | def add_text_to_gif(numpy_images, text, **kwargs):
167 | """
168 | Add text to each frame of a gif.
169 | """
170 | # Iterate over frames and add text to each frame
171 | frames_with_text = []
172 | for frame in numpy_images:
173 | frame_with_text = add_text_to_frame(frame, text, **kwargs)
174 | frames_with_text.append(frame_with_text)
175 |
176 | # Convert the list of frames to a numpy array
177 | numpy_images_with_text = np.array(frames_with_text)
178 |
179 | return numpy_images_with_text
180 |
181 | def pad_images_to_same_height(images):
182 | """
183 | Pad images to the same height.
184 | """
185 | # Find the maximum height
186 | max_height = max(img.shape[0] for img in images)
187 |
188 | # Pad each image to the maximum height
189 | padded_images = []
190 | for img in images:
191 | pad_height = max_height - img.shape[0]
192 | padded_img = cv2.copyMakeBorder(img, 0, pad_height, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
193 | padded_images.append(padded_img)
194 |
195 | return padded_images
196 |
197 | def concatenate_gifs(gifs):
198 | """
199 | Concatenate gifs.
200 | """
201 | # Ensure that all gifs have the same number of frames
202 | min_num_frames = min(gif.shape[0] for gif in gifs)
203 | gifs = [gif[:min_num_frames] for gif in gifs]
204 |
205 | # Concatenate each frame
206 | concatenated_gifs = []
207 | for i in range(min_num_frames):
208 | # Get the i-th frame from each gif
209 | frames = [gif[i] for gif in gifs]
210 |
211 | # Pad the frames to the same height
212 | padded_frames = pad_images_to_same_height(frames)
213 |
214 | # Concatenate the padded frames
215 | concatenated_frame = np.concatenate(padded_frames, axis=1)
216 |
217 | concatenated_gifs.append(concatenated_frame)
218 |
219 | return np.array(concatenated_gifs)
220 |
221 | def stack_gifs(gifs):
222 | '''vertically stack gifs'''
223 | min_num_frames = min(gif.shape[0] for gif in gifs)
224 | stacked_gifs = []
225 |
226 | for i in range(min_num_frames):
227 | frames = [gif[i] for gif in gifs]
228 | stacked_frame = np.concatenate(frames, axis=0)
229 | stacked_gifs.append(stacked_frame)
230 |
231 | return np.array(stacked_gifs)
232 |
233 | def save_tensor_to_gif(images, filename, fps):
234 | images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5
235 | images_to_gif(images, filename, fps)
236 |
237 | def save_tensor_to_images(images, output_dir):
238 | images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5
239 | os.makedirs(output_dir, exist_ok=True)
240 | for i in range(images.shape[0]):
241 | plt.imsave(f'{output_dir}/{i:03d}.jpg', images[i])
--------------------------------------------------------------------------------
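Note (usage sketch, not part of the repository): the GIF helpers above compose as shown below. The frames are random uint8 arrays purely for illustration; add_text_to_gif stacks a caption strip under each frame, concatenate_gifs puts the clips side by side, and images_to_gif writes the result (the output path is arbitrary).

import numpy as np
from misc_utils.image_utils import images_to_gif, add_text_to_gif, concatenate_gifs

# Two dummy 8-frame clips as uint8 arrays of shape (frames, H, W, 3).
gif_a = (np.random.rand(8, 128, 128, 3) * 255).astype(np.uint8)
gif_b = (np.random.rand(8, 128, 128, 3) * 255).astype(np.uint8)

gif_a = add_text_to_gif(gif_a, "original")
gif_b = add_text_to_gif(gif_b, "edited")
side_by_side = concatenate_gifs([gif_a, gif_b])  # frames concatenated along width
images_to_gif(list(side_by_side), "gradio_cache/side_by_side.gif", fps=5)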
/misc_utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch
3 | import numpy as np
4 | from inspect import isfunction
5 |
6 | def instantiate_from_config(config):
7 |     if "target" not in config:
8 | raise KeyError("Expected key `target` to instantiate.")
9 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
10 |
11 |
12 | def get_obj_from_str(string, reload=False):
13 | module, cls = string.rsplit(".", 1)
14 | if reload:
15 | module_imp = importlib.import_module(module)
16 | importlib.reload(module_imp)
17 | return getattr(importlib.import_module(module, package=None), cls)
18 |
19 | def exists(x):
20 | return x is not None
21 |
22 | def default(val, d):
23 | if exists(val):
24 | return val
25 | return d() if isfunction(d) else d
26 |
27 | def noise_like(shape, device, repeat=False):
28 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
29 | noise = lambda: torch.randn(shape, device=device)
30 | return repeat_noise() if repeat else noise()
31 |
32 | def extract_into_tensor(a, t, x_shape):
33 | b, *_ = t.shape
34 | out = a.gather(-1, t)
35 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
36 |
37 | def right_pad_dims_to(x, t):
38 | padding_dims = x.ndim - t.ndim
39 | if padding_dims <= 0:
40 | return t
41 | return t.view(*t.shape, *((1,) * padding_dims))
42 |
43 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
44 | if schedule == "linear" or schedule == "scaled_linear":
45 | betas = (
46 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
47 | )
48 |
49 | elif schedule == "cosine":
50 | timesteps = (
51 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
52 | )
53 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
54 | alphas = torch.cos(alphas).pow(2)
55 | alphas = alphas / alphas[0]
56 | betas = 1 - alphas[1:] / alphas[:-1]
57 |         betas = betas.clamp(min=0, max=0.999)
58 |
59 | elif schedule == "sqrt_linear":
60 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
61 | elif schedule == "sqrt":
62 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
63 | else:
64 | raise ValueError(f"schedule '{schedule}' unknown.")
65 | return betas.numpy()
66 |
67 |
68 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
69 | if ddim_discr_method == 'uniform':
70 | c = num_ddpm_timesteps // num_ddim_timesteps
71 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
72 | elif ddim_discr_method == 'quad':
73 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
74 | else:
75 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
76 |
77 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
78 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
79 | steps_out = ddim_timesteps + 1
80 | if verbose:
81 | print(f'Selected timesteps for ddim sampler: {steps_out}')
82 | return steps_out
83 |
84 |
85 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
86 | # select alphas for computing the variance schedule
87 | alphas = alphacums[ddim_timesteps]
88 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
89 |
90 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
91 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
92 | if verbose:
93 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
94 | print(f'For the chosen value of eta, which is {eta}, '
95 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
96 | return sigmas, alphas, alphas_prev
97 |
98 |
99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
100 | """
101 | Create a beta schedule that discretizes the given alpha_t_bar function,
102 | which defines the cumulative product of (1-beta) over time from t = [0,1].
103 | :param num_diffusion_timesteps: the number of betas to produce.
104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105 | produces the cumulative product of (1-beta) up to that
106 | part of the diffusion process.
107 | :param max_beta: the maximum beta to use; use values lower than 1 to
108 | prevent singularities.
109 | """
110 | betas = []
111 | for i in range(num_diffusion_timesteps):
112 | t1 = i / num_diffusion_timesteps
113 | t2 = (i + 1) / num_diffusion_timesteps
114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115 | return np.array(betas)
--------------------------------------------------------------------------------
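Note (usage sketch, not part of the repository): instantiate_from_config resolves a dotted target string and forwards params to its constructor, which is how the models, datasets and callbacks named in the configs are built. The toy config below uses a standard torch module rather than one of the repository's own classes.

from misc_utils.model_utils import instantiate_from_config, make_beta_schedule

# Any importable dotted path works as `target`; `params` go straight to the constructor.
config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
conv = instantiate_from_config(config)
print(conv)

# The diffusion wrappers get their noise schedules from make_beta_schedule.
betas = make_beta_schedule("linear", n_timestep=1000)
print(betas.shape, betas[0], betas[-1])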
/misc_utils/ptp_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 | import numpy as np
4 |
5 | @dataclass
6 | class Edit:
7 | old: str
8 | new: str
9 | weight: float = 1.0
10 |
11 |
12 | @dataclass
13 | class Insert:
14 | text: str
15 | weight: float = 1.0
16 |
17 | @property
18 | def old(self):
19 | return ""
20 |
21 | @property
22 | def new(self):
23 | return self.text
24 |
25 |
26 | @dataclass
27 | class Delete:
28 | text: str
29 | weight: float = 1.0
30 |
31 | @property
32 | def old(self):
33 | return self.text
34 |
35 | @property
36 | def new(self):
37 | return ""
38 |
39 |
40 | @dataclass
41 | class Text:
42 | text: str
43 | weight: float = 1.0
44 |
45 | @property
46 | def old(self):
47 | return self.text
48 |
49 | @property
50 | def new(self):
51 | return self.text
52 |
53 | @torch.inference_mode()
54 | def get_text_embedding(prompt, tokenizer, text_encoder):
55 | text_input_ids = tokenizer(
56 | prompt,
57 | padding="max_length",
58 | truncation=True,
59 | max_length=tokenizer.model_max_length,
60 | return_tensors="pt",
61 | ).input_ids
62 | text_embeddings = text_encoder(text_input_ids.to(text_encoder.device))[0]
63 | return text_embeddings
64 |
65 | @torch.inference_mode()
66 | def encode_text(text_pieces, tokenizer, text_encoder):
67 | n_old_tokens = 0
68 | n_new_tokens = 0
69 | new_id_to_old_id = []
70 | weights = []
71 | for piece in text_pieces:
72 | old, new = piece.old, piece.new
73 | old_tokens = tokenizer.tokenize(old)
74 | new_tokens = tokenizer.tokenize(new)
75 | if len(old_tokens) == 0 and len(new_tokens) == 0:
76 | continue
77 | elif old == new:
78 | n_old_tokens += len(old_tokens)
79 | n_new_tokens += len(new_tokens)
80 | new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
81 | elif len(old_tokens) == 0:
82 | # insert
83 | new_id_to_old_id.extend([-1] * len(new_tokens))
84 | n_new_tokens += len(new_tokens)
85 | elif len(new_tokens) == 0:
86 | # delete
87 | n_old_tokens += len(old_tokens)
88 | else:
89 | # replace
90 | n_old_tokens += len(old_tokens)
91 | n_new_tokens += len(new_tokens)
92 | start = n_old_tokens - len(old_tokens)
93 | end = n_old_tokens
94 | ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
95 | new_id_to_old_id.extend(list(ids))
96 | weights.extend([piece.weight] * len(new_tokens))
97 |
98 | old_prompt = " ".join([piece.old for piece in text_pieces])
99 | new_prompt = " ".join([piece.new for piece in text_pieces])
100 | old_text_input_ids = tokenizer(
101 | old_prompt,
102 | padding="max_length",
103 | truncation=True,
104 | max_length=tokenizer.model_max_length,
105 | return_tensors="pt",
106 | ).input_ids
107 | new_text_input_ids = tokenizer(
108 | new_prompt,
109 | padding="max_length",
110 | truncation=True,
111 | max_length=tokenizer.model_max_length,
112 | return_tensors="pt",
113 | ).input_ids
114 |
115 | old_text_embeddings = text_encoder(old_text_input_ids.to(text_encoder.device))[0]
116 | new_text_embeddings = text_encoder(new_text_input_ids.to(text_encoder.device))[0]
117 | value = new_text_embeddings.clone() # batch (1), seq, dim
118 | key = new_text_embeddings.clone()
119 |
120 | for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
121 | if 0 <= j < old_text_embeddings.shape[1]:
122 | key[0, i] = old_text_embeddings[0, j]
123 | value[0, i] *= weight
124 | return key, value
125 |
126 | @torch.inference_mode()
127 | def get_text_embedding_openclip(prompt, text_encoder, device='cuda'):
128 | import open_clip
129 | text_input_ids = open_clip.tokenize(prompt)
130 | text_embeddings = text_encoder(text_input_ids.to(device))
131 | return text_embeddings
132 |
133 | @torch.inference_mode()
134 | def encode_text_openclip(text_pieces, text_encoder, device='cuda'):
135 | import open_clip
136 | n_old_tokens = 0
137 | n_new_tokens = 0
138 | new_id_to_old_id = []
139 | weights = []
140 | for piece in text_pieces:
141 | old, new = piece.old, piece.new
142 | old_tokens = open_clip.tokenize(old)
143 | new_tokens = open_clip.tokenize(new)
144 | if len(old_tokens) == 0 and len(new_tokens) == 0:
145 | continue
146 | elif old == new:
147 | n_old_tokens += len(old_tokens)
148 | n_new_tokens += len(new_tokens)
149 | new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
150 | elif len(old_tokens) == 0:
151 | # insert
152 | new_id_to_old_id.extend([-1] * len(new_tokens))
153 | n_new_tokens += len(new_tokens)
154 | elif len(new_tokens) == 0:
155 | # delete
156 | n_old_tokens += len(old_tokens)
157 | else:
158 | # replace
159 | n_old_tokens += len(old_tokens)
160 | n_new_tokens += len(new_tokens)
161 | start = n_old_tokens - len(old_tokens)
162 | end = n_old_tokens
163 | ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
164 | new_id_to_old_id.extend(list(ids))
165 | weights.extend([piece.weight] * len(new_tokens))
166 |
167 | old_prompt = " ".join([piece.old for piece in text_pieces])
168 | new_prompt = " ".join([piece.new for piece in text_pieces])
169 | old_text_input_ids = open_clip.tokenize(old_prompt)
170 | new_text_input_ids = open_clip.tokenize(new_prompt)
171 |
172 | old_text_embeddings = text_encoder(old_text_input_ids.to(device))
173 | new_text_embeddings = text_encoder(new_text_input_ids.to(device))
174 | value = new_text_embeddings.clone() # batch (1), seq, dim
175 | key = new_text_embeddings.clone()
176 |
177 | for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
178 | if 0 <= j < old_text_embeddings.shape[1]:
179 | key[0, i] = old_text_embeddings[0, j]
180 | value[0, i] *= weight
181 | return key, value
--------------------------------------------------------------------------------
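Note (usage sketch, not part of the repository): the prompt-to-prompt helpers above operate on a list of Text / Edit / Insert / Delete pieces whose .old and .new properties reconstruct the source and edited prompts. The sketch below only builds such a list by hand; encode_text additionally needs a CLIP tokenizer and text encoder, which are omitted here.

from misc_utils.ptp_utils import Text, Edit, Insert

# Describe the edit "a cat in the sun" -> "a dog on the beach in the sun".
pieces = [
    Text(text="a"),
    Edit(old="cat", new="dog"),
    Insert(text="on the beach", weight=1.5),  # up-weight the inserted phrase
    Text(text="in the sun"),
]
print(" ".join(p.old for p in pieces))  # 'a cat  in the sun' (Insert contributes an empty .old)
print(" ".join(p.new for p in pieces))  # 'a dog on the beach in the sun'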
/misc_utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from omegaconf import OmegaConf
3 | from pytorch_lightning.loggers import WandbLogger
4 | from misc_utils.model_utils import instantiate_from_config, get_obj_from_str
5 |
6 | def get_models(args):
7 | unet = instantiate_from_config(args.unet)
8 | model_dict = {
9 | 'unet': unet,
10 | }
11 |
12 | if args.get('vae'):
13 | vae = instantiate_from_config(args.vae)
14 | model_dict['vae'] = vae
15 |
16 | if args.get('text_model'):
17 | text_model = instantiate_from_config(args.text_model)
18 | model_dict['text_model'] = text_model
19 |
20 | if args.get('ctrlnet'):
21 | ctrlnet = instantiate_from_config(args.ctrlnet)
22 | model_dict['ctrlnet'] = ctrlnet
23 |
24 | return model_dict
25 |
26 | def get_DDPM(diffusion_configs, log_args={}, **models):
27 | diffusion_model_class = diffusion_configs['target']
28 | diffusion_args = diffusion_configs['params']
29 | DDPM_model = get_obj_from_str(diffusion_model_class)
30 | ddpm_model = DDPM_model(
31 | log_args=log_args,
32 | **models,
33 | **diffusion_args
34 | )
35 | return ddpm_model
36 |
37 |
38 | def get_logger(args):
39 | wandb_logger = WandbLogger(
40 | project=args["expt_name"],
41 | )
42 | return wandb_logger
43 |
44 | def get_callbacks(args, wandb_logger):
45 | callbacks = []
46 | for callback in args['callbacks']:
47 | if callback.get('require_wandb', False):
48 | # we need to pass wandb logger to the callback
49 | callback_obj = get_obj_from_str(callback.target)
50 | callbacks.append(
51 | callback_obj(wandb_logger=wandb_logger, **callback.params)
52 | )
53 | else:
54 | callbacks.append(
55 | instantiate_from_config(callback)
56 | )
57 | return callbacks
58 |
59 | def get_dataset(args):
60 | from torch.utils.data import DataLoader
61 | data_args = args['data']
62 | train_set = instantiate_from_config(data_args['train'])
63 | val_set = instantiate_from_config(data_args['val'])
64 | train_loader = DataLoader(
65 | train_set, batch_size=data_args['batch_size'], shuffle=True,
66 | num_workers=4*len(args['trainer_args']['devices']), pin_memory=True
67 | )
68 | val_loader = DataLoader(
69 | val_set, batch_size=data_args['val_batch_size'],
70 | num_workers=len(args['trainer_args']['devices']), pin_memory=True
71 | )
72 | return train_loader, val_loader, train_set, val_set
73 |
74 | def unit_test_create_model(config_path, device=None):
75 |     device = device or ('cuda' if torch.cuda.is_available() else 'cpu')  # gradio_demo.py passes an explicit device
76 | conf = OmegaConf.load(config_path)
77 | models = get_models(conf)
78 | ddpm = get_DDPM(conf['diffusion'], log_args=conf, **models)
79 | ddpm = ddpm.to(device)
80 | return ddpm
81 |
82 | def unit_test_create_dataset(config_path, split='train'):
83 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
84 | conf = OmegaConf.load(config_path)
85 | train_loader, val_loader, train_set, val_set = get_dataset(conf)
86 | if split == 'train':
87 | batch = next(iter(train_loader))
88 | else:
89 | batch = next(iter(val_loader))
90 | for k, v in batch.items():
91 | if isinstance(v, torch.Tensor):
92 | batch[k] = v.to(device)
93 | return batch
94 |
95 | def unit_test_training_step(config_path):
96 | ddpm = unit_test_create_model(config_path)
97 | batch = unit_test_create_dataset(config_path)
98 | res = ddpm.training_step(batch, 0)
99 | return res
100 |
101 | def unit_test_val_step(config_path):
102 | ddpm = unit_test_create_model(config_path)
103 | batch = unit_test_create_dataset(config_path, split='val')
104 | res = ddpm.validation_step(batch, 0)
105 | return res
106 |
107 | NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless"
108 |
--------------------------------------------------------------------------------
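Note (usage sketch, not part of the repository): the unit_test_* helpers above are also how the inference scripts construct the diffusion wrapper outside of training. The sketch below mirrors what gradio_demo.py and insv2v_run_loveu_tgve.py do; 'insv2v.pth' is the checkpoint filename gradio_demo.py expects, so adjust the path to wherever your checkpoint lives.

import torch
from misc_utils.train_utils import unit_test_create_model

# Build the diffusion wrapper from the inference config and load the checkpoint
# with strict=False, exactly as the inference entry points do.
diffusion_model = unit_test_create_model('configs/instruct_v2v_inference.yaml')
ckpt = torch.load('insv2v.pth', map_location='cpu')
diffusion_model.load_state_dict(ckpt, strict=False)
print(type(diffusion_model).__name__)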
/misc_utils/video_ptp_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules.damo_text_to_video.unet_sd import UNetSD
3 | from misc_utils.train_utils import instantiate_from_config
4 | from omegaconf import OmegaConf
5 | from modules.damo_text_to_video.text_model import FrozenOpenCLIPEmbedder
6 | from typing import List, Tuple, Union
7 | from diff_match_patch import diff_match_patch
8 | import difflib
9 | from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete
10 |
11 |
12 | def get_models_of_damo_model(
13 | unet_config: str,
14 | unet_ckpt: str,
15 | vae_config: str,
16 | vae_ckpt: str,
17 | text_model_ckpt: str,
18 | ):
19 | vae_conf = OmegaConf.load(vae_config)
20 | unet_conf = OmegaConf.load(unet_config)
21 |
22 | vae = instantiate_from_config(vae_conf.vae)
23 | vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'))
24 | vae = vae.half().cuda()
25 |
26 | unet = UNetSD(**unet_conf.model.model_cfg)
27 | unet.load_state_dict(torch.load(unet_ckpt, map_location='cpu'))
28 | unet = unet.half().cuda()
29 |
30 | text_model = FrozenOpenCLIPEmbedder(version=text_model_ckpt, layer='penultimate')
31 | text_model = text_model.half().cuda()
32 |
33 | return vae, unet, text_model
34 |
35 | def compute_diff_old(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]:
36 | dmp = diff_match_patch()
37 | diff = dmp.diff_main(old_sentence, new_sentence)
38 | dmp.diff_cleanupSemantic(diff)
39 |
40 | result = []
41 | i = 0
42 | while i < len(diff):
43 | op, data = diff[i]
44 | if op == 0: # Equal
45 | # result.append((Text, data, data))
46 | result.append(Text(text=data))
47 | elif op == -1: # Delete
48 | if i + 1 < len(diff) and diff[i + 1][0] == 1: # If next operation is Insert
49 | result.append(Edit(old=data, new=diff[i + 1][1])) # Append as Edit operation
50 | i += 1 # Skip next operation because we've handled it here
51 | else:
52 | result.append(Delete(text=data))
53 | elif op == 1: # Insert
54 | if i == 0 or diff[i - 1][0] != -1: # If previous operation wasn't Delete
55 | result.append(Insert(text=data))
56 | i += 1
57 |
58 | return result
59 |
60 | def compute_diff(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]:
61 | differ = difflib.Differ()
62 | diff = list(differ.compare(old_sentence.split(), new_sentence.split()))
63 |
64 | result = []
65 | i = 0
66 | while i < len(diff):
67 | if diff[i][0] == ' ': # Equal
68 | equal_words = [diff[i][2:]]
69 | while i + 1 < len(diff) and diff[i + 1][0] == ' ':
70 | i += 1
71 | equal_words.append(diff[i][2:])
72 | result.append(Text(text=' '.join(equal_words)))
73 | elif diff[i][0] == '-': # Delete
74 | deleted_words = [diff[i][2:]]
75 | while i + 1 < len(diff) and diff[i + 1][0] == '-':
76 | i += 1
77 | deleted_words.append(diff[i][2:])
78 | result.append(Delete(text=' '.join(deleted_words)))
79 | elif diff[i][0] == '+': # Insert
80 | inserted_words = [diff[i][2:]]
81 | while i + 1 < len(diff) and diff[i + 1][0] == '+':
82 | i += 1
83 | inserted_words.append(diff[i][2:])
84 | result.append(Insert(text=' '.join(inserted_words)))
85 | i += 1
86 |
87 | # Post-process to merge adjacent inserts and deletes into edits
88 | i = 0
89 | while i < len(result) - 1:
90 | if isinstance(result[i], Delete) and isinstance(result[i+1], Insert):
91 | result[i:i+2] = [Edit(old=result[i].text, new=result[i+1].text)]
92 | elif isinstance(result[i], Insert) and isinstance(result[i+1], Delete):
93 | result[i:i+2] = [Edit(old=result[i+1].text, new=result[i].text)]
94 | else:
95 | i += 1
96 |
97 | return result
--------------------------------------------------------------------------------
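Usage sketch (illustrative, not part of the repository): compute_diff turns a source/edited prompt pair into Text/Edit/Insert/Delete operations. The prompts below are made up and the expected output is paraphrased from the logic above.

    from misc_utils.video_ptp_utils import compute_diff

    ops = compute_diff('a cat sitting on the tree', 'a dog sitting on the tree')
    # roughly: [Text(text='a'), Edit(old='cat', new='dog'), Text(text='sitting on the tree')]
    for op in ops:
        print(op)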
/modules/damo_text_to_video/configuration.json:
--------------------------------------------------------------------------------
1 | { "framework": "pytorch",
2 | "task": "text-to-video-synthesis",
3 | "model": {
4 | "type": "latent-text-to-video-synthesis",
5 | "model_args": {
6 | "ckpt_clip": "open_clip_pytorch_model.bin",
7 | "ckpt_unet": "text2video_pytorch_model.pth",
8 | "ckpt_autoencoder": "VQGAN_autoencoder.pth",
9 | "max_frames": 16,
10 | "tiny_gpu": 1
11 | },
12 | "model_cfg": {
13 | "in_dim": 4,
14 | "dim": 320,
15 | "y_dim": 768,
16 | "context_dim": 1024,
17 | "out_dim": 4,
18 | "dim_mult": [1, 2, 4, 4],
19 | "num_heads": 8,
20 | "head_dim": 64,
21 | "num_res_blocks": 2,
22 | "attn_scales": [1, 0.5, 0.25],
23 | "dropout": 0.1,
24 | "temporal_attention": "True",
25 | "use_checkpoint": "True"
26 | }
27 | },
28 | "pipeline": {
29 | "type": "latent-text-to-video-synthesis"
30 | }
31 | }
--------------------------------------------------------------------------------
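Usage sketch (illustrative, not part of the repository): this configuration can be consumed to build the DAMO text-to-video UNet, mirroring get_models_of_damo_model in misc_utils/video_ptp_utils.py. Loading the ckpt_unet weights and moving to GPU are omitted, and note that temporal_attention/use_checkpoint are stored as the string "True" in this file.

    import json
    from modules.damo_text_to_video.unet_sd import UNetSD

    with open('modules/damo_text_to_video/configuration.json') as f:
        model_cfg = json.load(f)['model']['model_cfg']
    unet = UNetSD(**model_cfg)  # weights from ckpt_unet would normally be loaded afterwards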
/modules/damo_text_to_video/text_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import open_clip
3 |
4 | class FrozenOpenCLIPEmbedder(torch.nn.Module):
5 | """
6 | Uses the OpenCLIP transformer encoder for text
7 | """
8 | LAYERS = ['last', 'penultimate']
9 |
10 | def __init__(self,
11 | arch='ViT-H-14',
12 | version='open_clip_pytorch_model.bin',
13 | device='cuda',
14 | max_length=77,
15 | freeze=True,
16 | layer='last'):
17 | super().__init__()
18 | assert layer in self.LAYERS
19 | model, _, _ = open_clip.create_model_and_transforms(
20 | arch, device=torch.device('cpu'), pretrained=version)
21 | del model.visual
22 | self.model = model
23 |
24 | self.device = device
25 | self.max_length = max_length
26 | if freeze:
27 | self.freeze()
28 | self.layer = layer
29 | if self.layer == 'last':
30 | self.layer_idx = 0
31 | elif self.layer == 'penultimate':
32 | self.layer_idx = 1
33 | else:
34 | raise NotImplementedError()
35 |
36 | def freeze(self):
37 | self.model = self.model.eval()
38 | for param in self.parameters():
39 | param.requires_grad = False
40 |
41 | def forward(self, text):
42 | tokens = open_clip.tokenize(text)
43 | z = self.encode_with_transformer(tokens.to(self.device))
44 | return z
45 |
46 | def encode_with_transformer(self, text):
47 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
48 | x = x + self.model.positional_embedding
49 | x = x.permute(1, 0, 2) # NLD -> LND
50 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
51 | x = x.permute(1, 0, 2) # LND -> NLD
52 | x = self.model.ln_final(x)
53 | return x
54 |
55 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
56 | for i, r in enumerate(self.model.transformer.resblocks):
57 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
58 | break
59 | x = r(x, attn_mask=attn_mask)
60 | return x
61 |
62 | def encode(self, text):
63 | return self(text)
--------------------------------------------------------------------------------
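Usage sketch (illustrative, not part of the repository): encoding a prompt with the frozen OpenCLIP text encoder, assuming the ViT-H-14 checkpoint referenced in configuration.json has been downloaded locally and a CUDA device is available.

    from modules.damo_text_to_video.text_model import FrozenOpenCLIPEmbedder

    text_model = FrozenOpenCLIPEmbedder(version='open_clip_pytorch_model.bin', device='cuda', layer='penultimate').cuda()
    emb = text_model.encode(['a cat sitting in the sun'])
    print(emb.shape)  # [1, 77, 1024] for ViT-H-14 (77 token positions, 1024-dim embeddings)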
/modules/kl_autoencoder/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from modules.vqvae.model import Encoder, Decoder
7 |
8 | from misc_utils.model_utils import instantiate_from_config
9 |
10 | class DiagonalGaussianDistribution(object):
11 | def __init__(self, parameters, deterministic=False):
12 | self.parameters = parameters
13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
15 | self.deterministic = deterministic
16 | self.std = torch.exp(0.5 * self.logvar)
17 | self.var = torch.exp(self.logvar)
18 | if self.deterministic:
19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
20 |
21 | def sample(self):
22 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
23 | return x
24 |
25 | def kl(self, other=None):
26 | if self.deterministic:
27 | return torch.Tensor([0.])
28 | else:
29 | if other is None:
30 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
31 | + self.var - 1.0 - self.logvar,
32 | dim=[1, 2, 3])
33 | else:
34 | return 0.5 * torch.sum(
35 | torch.pow(self.mean - other.mean, 2) / other.var
36 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
37 | dim=[1, 2, 3])
38 |
39 | def nll(self, sample, dims=[1,2,3]):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | logtwopi = np.log(2.0 * np.pi)
43 | return 0.5 * torch.sum(
44 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
45 | dim=dims)
46 |
47 | def mode(self):
48 | return self.mean
49 |
50 | class AutoencoderKL(pl.LightningModule):
51 | def __init__(self,
52 | ddconfig,
53 | lossconfig,
54 | embed_dim,
55 | ckpt_path=None,
56 | ignore_keys=[],
57 | image_key="image",
58 | colorize_nlabels=None,
59 | monitor=None,
60 | ):
61 | super().__init__()
62 | self.image_key = image_key
63 | self.encoder = Encoder(**ddconfig)
64 | self.decoder = Decoder(**ddconfig)
65 | self.loss = instantiate_from_config(lossconfig)
66 | assert ddconfig["double_z"]
67 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
68 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
69 | self.embed_dim = embed_dim
70 | if colorize_nlabels is not None:
71 | assert type(colorize_nlabels)==int
72 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
73 | if monitor is not None:
74 | self.monitor = monitor
75 | if ckpt_path is not None:
76 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path}")
88 |
89 | def encode(self, x):
90 | h = self.encoder(x)
91 | moments = self.quant_conv(h)
92 | posterior = DiagonalGaussianDistribution(moments)
93 |         # NOTE: unlike the stock AutoencoderKL, encode() returns a sampled latent tensor (not the posterior) so callers can scale it directly
94 |         enc = posterior.sample()
95 |         return enc
96 |
97 | def decode(self, z):
98 | z = self.post_quant_conv(z)
99 | dec = self.decoder(z)
100 | return dec
101 |
102 |     def forward(self, input, sample_posterior=True):
103 |         # encode() above returns a sampled latent, so rebuild the posterior from the encoder moments here
104 |         posterior = DiagonalGaussianDistribution(self.quant_conv(self.encoder(input)))
105 |         if sample_posterior:
106 |             z = posterior.sample()
107 |         else:
108 |             z = posterior.mode()
109 |         return self.decode(z), posterior
110 |
111 | def get_input(self, batch, k):
112 | x = batch[k]
113 | if len(x.shape) == 3:
114 | x = x[..., None]
115 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
116 | return x
117 |
118 | def training_step(self, batch, batch_idx, optimizer_idx):
119 | inputs = self.get_input(batch, self.image_key)
120 | reconstructions, posterior = self(inputs)
121 |
122 | if optimizer_idx == 0:
123 | # train encoder+decoder+logvar
124 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
125 | last_layer=self.get_last_layer(), split="train")
126 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return aeloss
129 |
130 | if optimizer_idx == 1:
131 | # train the discriminator
132 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
133 | last_layer=self.get_last_layer(), split="train")
134 |
135 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
136 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
137 | return discloss
138 |
139 | def validation_step(self, batch, batch_idx):
140 | inputs = self.get_input(batch, self.image_key)
141 | reconstructions, posterior = self(inputs)
142 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
143 | last_layer=self.get_last_layer(), split="val")
144 |
145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
146 | last_layer=self.get_last_layer(), split="val")
147 |
148 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
149 | self.log_dict(log_dict_ae)
150 | self.log_dict(log_dict_disc)
151 | return self.log_dict
152 |
153 | def configure_optimizers(self):
154 | lr = self.learning_rate
155 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
156 | list(self.decoder.parameters())+
157 | list(self.quant_conv.parameters())+
158 | list(self.post_quant_conv.parameters()),
159 | lr=lr, betas=(0.5, 0.9))
160 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
161 | lr=lr, betas=(0.5, 0.9))
162 | return [opt_ae, opt_disc], []
163 |
164 | def get_last_layer(self):
165 | return self.decoder.conv_out.weight
166 |
167 | @torch.no_grad()
168 | def log_images(self, batch, only_inputs=False, **kwargs):
169 | log = dict()
170 | x = self.get_input(batch, self.image_key)
171 | x = x.to(self.device)
172 | if not only_inputs:
173 | xrec, posterior = self(x)
174 | if x.shape[1] > 3:
175 | # colorize with random projection
176 | assert xrec.shape[1] > 3
177 | x = self.to_rgb(x)
178 | xrec = self.to_rgb(xrec)
179 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
180 | log["reconstructions"] = xrec
181 | log["inputs"] = x
182 | return log
183 |
184 | def to_rgb(self, x):
185 | assert self.image_key == "segmentation"
186 | if not hasattr(self, "colorize"):
187 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
188 | x = F.conv2d(x, weight=self.colorize)
189 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
190 | return x
--------------------------------------------------------------------------------
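Usage sketch (illustrative, not part of the repository): a rough encode/decode roundtrip with AutoencoderKL. The ddconfig below assumes modules/vqvae/model.py follows the standard latent-diffusion Encoder/Decoder signature, the Identity loss stands in for the real training loss (it only works for inference), and instantiate_from_config is assumed to follow the usual target/params convention.

    import torch
    from modules.kl_autoencoder.autoencoder import AutoencoderKL

    # assumed ddconfig, following the standard latent-diffusion f=8 KL-VAE layout
    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
    vae = AutoencoderKL(ddconfig=ddconfig, lossconfig={'target': 'torch.nn.Identity'}, embed_dim=4)
    img = torch.randn(1, 3, 256, 256)
    z = vae.encode(img)     # sampled latent, [1, 4, 32, 32] for this config
    rec = vae.decode(z)     # [1, 3, 256, 256]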
/modules/openclip/modules.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.checkpoint import checkpoint
5 |
6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7 |
8 | import open_clip
9 |
10 |
11 | class AbstractEncoder(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def encode(self, *args, **kwargs):
16 | raise NotImplementedError
17 |
18 |
19 | class IdentityEncoder(AbstractEncoder):
20 |
21 | def encode(self, x):
22 | return x
23 |
24 |
25 | class ClassEmbedder(nn.Module):
26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27 | super().__init__()
28 | self.key = key
29 | self.embedding = nn.Embedding(n_classes, embed_dim)
30 | self.n_classes = n_classes
31 | self.ucg_rate = ucg_rate
32 |
33 | def forward(self, batch, key=None, disable_dropout=False):
34 | if key is None:
35 | key = self.key
36 | # this is for use in crossattn
37 | c = batch[key][:, None]
38 | if self.ucg_rate > 0. and not disable_dropout:
39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41 | c = c.long()
42 | c = self.embedding(c)
43 | return c
44 |
45 | def get_unconditional_conditioning(self, bs, device="cuda"):
46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47 | uc = torch.ones((bs,), device=device) * uc_class
48 | uc = {self.key: uc}
49 | return uc
50 |
51 |
52 | def disabled_train(self, mode=True):
53 | """Overwrite model.train with this function to make sure train/eval mode
54 | does not change anymore."""
55 | return self
56 |
57 |
58 | class FrozenT5Embedder(AbstractEncoder):
59 | """Uses the T5 transformer encoder for text"""
60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61 | super().__init__()
62 | self.tokenizer = T5Tokenizer.from_pretrained(version)
63 | self.transformer = T5EncoderModel.from_pretrained(version)
64 | self.device = device
65 | self.max_length = max_length # TODO: typical value?
66 | if freeze:
67 | self.freeze()
68 |
69 | def freeze(self):
70 | self.transformer = self.transformer.eval()
71 | #self.train = disabled_train
72 | for param in self.parameters():
73 | param.requires_grad = False
74 |
75 | def forward(self, text):
76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78 | tokens = batch_encoding["input_ids"].to(self.device)
79 | outputs = self.transformer(input_ids=tokens)
80 |
81 | z = outputs.last_hidden_state
82 | return z
83 |
84 | def encode(self, text):
85 | return self(text)
86 |
87 |
88 | class FrozenCLIPEmbedder(AbstractEncoder):
89 | """Uses the CLIP transformer encoder for text (from huggingface)"""
90 | LAYERS = [
91 | "last",
92 | "pooled",
93 | "hidden"
94 | ]
95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97 | super().__init__()
98 | assert layer in self.LAYERS
99 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
100 | self.transformer = CLIPTextModel.from_pretrained(version)
101 | self.device = device
102 | self.max_length = max_length
103 | if freeze:
104 | self.freeze()
105 | self.layer = layer
106 | self.layer_idx = layer_idx
107 | if layer == "hidden":
108 | assert layer_idx is not None
109 | assert 0 <= abs(layer_idx) <= 12
110 |
111 | def freeze(self):
112 | self.transformer = self.transformer.eval()
113 | #self.train = disabled_train
114 | for param in self.parameters():
115 | param.requires_grad = False
116 |
117 | def forward(self, text):
118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120 | tokens = batch_encoding["input_ids"].to(self.device)
121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122 | if self.layer == "last":
123 | z = outputs.last_hidden_state
124 | elif self.layer == "pooled":
125 | z = outputs.pooler_output[:, None, :]
126 | else:
127 | z = outputs.hidden_states[self.layer_idx]
128 | return z
129 |
130 | def encode(self, text):
131 | return self(text)
132 |
133 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
134 |         state_dict.pop("transformer.text_model.embeddings.position_ids", None)  # this key was removed from CLIPTextModel in recent transformers versions, so drop it if present
135 | return super().load_state_dict(state_dict, strict)
136 |
137 |
138 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
139 | """
140 | Uses the OpenCLIP transformer encoder for text
141 | """
142 | LAYERS = [
143 | #"pooled",
144 | "last",
145 | "penultimate"
146 | ]
147 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
148 | freeze=True, layer="last"):
149 | super().__init__()
150 | assert layer in self.LAYERS
151 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
152 | del model.visual
153 | self.model = model
154 |
155 | self.device = device
156 | self.max_length = max_length
157 | if freeze:
158 | self.freeze()
159 | self.layer = layer
160 | if self.layer == "last":
161 | self.layer_idx = 0
162 | elif self.layer == "penultimate":
163 | self.layer_idx = 1
164 | else:
165 | raise NotImplementedError()
166 |
167 | def freeze(self):
168 | self.model = self.model.eval()
169 | for param in self.parameters():
170 | param.requires_grad = False
171 |
172 | def forward(self, text):
173 | tokens = open_clip.tokenize(text)
174 | z = self.encode_with_transformer(tokens.to(self.device))
175 | return z
176 |
177 | def encode_with_transformer(self, text):
178 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
179 | x = x + self.model.positional_embedding
180 | x = x.permute(1, 0, 2) # NLD -> LND
181 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
182 | x = x.permute(1, 0, 2) # LND -> NLD
183 | x = self.model.ln_final(x)
184 | return x
185 |
186 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
187 | for i, r in enumerate(self.model.transformer.resblocks):
188 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
189 | break
190 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
191 | x = checkpoint(r, x, attn_mask)
192 | else:
193 | x = r(x, attn_mask=attn_mask)
194 | return x
195 |
196 | def encode(self, text):
197 | return self(text)
198 |
199 |
200 | class FrozenCLIPT5Encoder(AbstractEncoder):
201 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
202 | clip_max_length=77, t5_max_length=77):
203 | super().__init__()
204 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
205 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
206 | # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
207 | # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
208 |
209 | def encode(self, text):
210 | return self(text)
211 |
212 | def forward(self, text):
213 | clip_z = self.clip_encoder.encode(text)
214 | t5_z = self.t5_encoder.encode(text)
215 | return [clip_z, t5_z]
216 |
217 |
218 |
--------------------------------------------------------------------------------
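Usage sketch (illustrative, not part of the repository): the CLIP text encoder used by Stable-Diffusion-style models; the checkpoint is pulled from the Hugging Face Hub on first use.

    from modules.openclip.modules import FrozenCLIPEmbedder

    clip_text = FrozenCLIPEmbedder(version='openai/clip-vit-large-patch14', device='cpu')
    z = clip_text.encode(['an airplane leaving a long contrail'])
    print(z.shape)  # [1, 77, 768]: one embedding per token position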
/modules/video_unet_temporal/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from torch import einsum
10 | from misc_utils.model_utils import default, exists
11 |
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.models.modeling_utils import ModelMixin
14 | from diffusers.utils import BaseOutput
15 | from diffusers.utils.import_utils import is_xformers_available
16 | from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm
17 |
18 | from einops import rearrange, repeat
19 |
20 |
21 | @dataclass
22 | class Transformer3DModelOutput(BaseOutput):
23 | sample: torch.FloatTensor
24 |
25 |
26 | if is_xformers_available():
27 | import xformers
28 | import xformers.ops
29 | else:
30 | xformers = None
31 |
32 |
33 | class Transformer3DModel(ModelMixin, ConfigMixin):
34 | @register_to_config
35 | def __init__(
36 | self,
37 | num_attention_heads: int = 16,
38 | attention_head_dim: int = 88,
39 | in_channels: Optional[int] = None,
40 | num_layers: int = 1,
41 | dropout: float = 0.0,
42 | norm_num_groups: int = 32,
43 | cross_attention_dim: Optional[int] = None,
44 | attention_bias: bool = False,
45 | activation_fn: str = "geglu",
46 | num_embeds_ada_norm: Optional[int] = None,
47 | use_linear_projection: bool = False,
48 | only_cross_attention: bool = False,
49 | upcast_attention: bool = False,
50 | ):
51 | super().__init__()
52 | self.use_linear_projection = use_linear_projection
53 | self.num_attention_heads = num_attention_heads
54 | self.attention_head_dim = attention_head_dim
55 | inner_dim = num_attention_heads * attention_head_dim
56 |
57 | # Define input layers
58 | self.in_channels = in_channels
59 |
60 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
65 |
66 | # Define transformers blocks
67 | self.transformer_blocks = nn.ModuleList(
68 | [
69 | BasicTransformerBlock(
70 | inner_dim,
71 | num_attention_heads,
72 | attention_head_dim,
73 | dropout=dropout,
74 | cross_attention_dim=cross_attention_dim,
75 | activation_fn=activation_fn,
76 | num_embeds_ada_norm=num_embeds_ada_norm,
77 | attention_bias=attention_bias,
78 | only_cross_attention=only_cross_attention,
79 | upcast_attention=upcast_attention,
80 | )
81 | for d in range(num_layers)
82 | ]
83 | )
84 |
85 | # 4. Define output layers
86 | if use_linear_projection:
87 |             self.proj_out = nn.Linear(inner_dim, in_channels)  # project back from inner_dim to in_channels
88 | else:
89 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
90 |
91 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
92 | # Input
93 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
94 | video_length = hidden_states.shape[2]
95 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
96 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
97 |
98 | batch, channel, height, weight = hidden_states.shape
99 | residual = hidden_states
100 |
101 | hidden_states = self.norm(hidden_states)
102 | if not self.use_linear_projection:
103 | hidden_states = self.proj_in(hidden_states)
104 | inner_dim = hidden_states.shape[1]
105 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
106 | else:
107 | inner_dim = hidden_states.shape[1]
108 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
109 | hidden_states = self.proj_in(hidden_states)
110 |
111 | # Blocks
112 | for block in self.transformer_blocks:
113 | hidden_states = block(
114 | hidden_states,
115 | encoder_hidden_states=encoder_hidden_states,
116 | timestep=timestep,
117 | video_length=video_length
118 | )
119 |
120 | # Output
121 | if not self.use_linear_projection:
122 | hidden_states = (
123 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
124 | )
125 | hidden_states = self.proj_out(hidden_states)
126 | else:
127 | hidden_states = self.proj_out(hidden_states)
128 | hidden_states = (
129 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
130 | )
131 |
132 | output = hidden_states + residual
133 |
134 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
135 | if not return_dict:
136 | return (output,)
137 |
138 | return Transformer3DModelOutput(sample=output)
139 |
140 |
141 | class BasicTransformerBlock(nn.Module):
142 | def __init__(
143 | self,
144 | dim: int,
145 | num_attention_heads: int,
146 | attention_head_dim: int,
147 | dropout=0.0,
148 | cross_attention_dim: Optional[int] = None,
149 | activation_fn: str = "geglu",
150 | num_embeds_ada_norm: Optional[int] = None,
151 | attention_bias: bool = False,
152 | only_cross_attention: bool = False,
153 | upcast_attention: bool = False,
154 | ):
155 | super().__init__()
156 | self.only_cross_attention = only_cross_attention
157 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
158 |
159 | # SC-Attn
160 | self.attn1 = Attention(
161 | query_dim=dim,
162 | heads=num_attention_heads,
163 | dim_head=attention_head_dim,
164 | dropout=dropout,
165 | bias=attention_bias,
166 | upcast_attention=upcast_attention,
167 | )
168 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
169 |
170 | # Cross-Attn
171 | if cross_attention_dim is not None:
172 | self.attn2 = Attention(
173 | query_dim=dim,
174 | cross_attention_dim=cross_attention_dim,
175 | heads=num_attention_heads,
176 | dim_head=attention_head_dim,
177 | dropout=dropout,
178 | bias=attention_bias,
179 | upcast_attention=upcast_attention,
180 | )
181 | else:
182 | self.attn2 = None
183 |
184 | if cross_attention_dim is not None:
185 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
186 | else:
187 | self.norm2 = None
188 |
189 | # Feed-forward
190 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
191 | self.norm3 = nn.LayerNorm(dim)
192 |
193 | # Temp-Attn
194 | # self.attn_temp = Attention(
195 | # query_dim=dim,
196 | # heads=num_attention_heads,
197 | # dim_head=attention_head_dim,
198 | # dropout=dropout,
199 | # bias=attention_bias,
200 | # upcast_attention=upcast_attention,
201 | # )
202 | # nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
203 | # self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
204 |
205 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op):
206 | if not is_xformers_available():
207 |             # xformers is not installed; direct the user to the installation instructions
208 | raise ModuleNotFoundError(
209 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
210 | " xformers",
211 | name="xformers",
212 | )
213 | elif not torch.cuda.is_available():
214 | raise ValueError(
215 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
216 | " available for GPU "
217 | )
218 | else:
219 | try:
220 | # Make sure we can run the memory efficient attention
221 | _ = xformers.ops.memory_efficient_attention(
222 | torch.randn((1, 2, 40), device="cuda"),
223 | torch.randn((1, 2, 40), device="cuda"),
224 | torch.randn((1, 2, 40), device="cuda"),
225 | )
226 | except Exception as e:
227 | raise e
228 | self.attn1.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op)
229 | if self.attn2 is not None:
230 | self.attn2.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op)
231 | # self.attn_temp.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op)
232 |
233 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
234 | # SparseCausal-Attention
235 | norm_hidden_states = (
236 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
237 | )
238 |
239 | if self.only_cross_attention:
240 | hidden_states = (
241 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
242 | )
243 | else:
244 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
245 |
246 | if self.attn2 is not None:
247 | # Cross-Attention
248 | norm_hidden_states = (
249 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
250 | )
251 | hidden_states = (
252 | self.attn2(
253 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
254 | )
255 | + hidden_states
256 | )
257 |
258 | # Feed-forward
259 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
260 |
261 | # Temporal-Attention
262 | # d = hidden_states.shape[1]
263 | # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
264 | # norm_hidden_states = (
265 | # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
266 | # )
267 | # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
268 | # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
269 |
270 | return hidden_states
271 |
--------------------------------------------------------------------------------
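Usage sketch (illustrative, not part of the repository): a shape-level check of the 3D transformer block on dummy tensors. The head/channel sizes are illustrative, and a reasonably recent diffusers version providing Attention/FeedForward/AdaLayerNorm is assumed.

    import torch
    from modules.video_unet_temporal.attention import Transformer3DModel

    block = Transformer3DModel(num_attention_heads=8, attention_head_dim=40,
                               in_channels=320, cross_attention_dim=768)
    x = torch.randn(1, 320, 8, 32, 32)     # (batch, channels, frames, height, width)
    ctx = torch.randn(1, 77, 768)          # per-video text embedding, repeated per frame internally
    out = block(x, encoder_hidden_states=ctx).sample
    print(out.shape)                        # same shape as x: [1, 320, 8, 32, 32]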
/modules/video_unet_temporal/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 |
18 | return x
19 |
20 |
21 | class Upsample3D(nn.Module):
22 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23 | super().__init__()
24 | self.channels = channels
25 | self.out_channels = out_channels or channels
26 | self.use_conv = use_conv
27 | self.use_conv_transpose = use_conv_transpose
28 | self.name = name
29 |
30 | conv = None
31 | if use_conv_transpose:
32 | raise NotImplementedError
33 | elif use_conv:
34 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35 |
36 | if name == "conv":
37 | self.conv = conv
38 | else:
39 | self.Conv2d_0 = conv
40 |
41 | def forward(self, hidden_states, output_size=None):
42 | assert hidden_states.shape[1] == self.channels
43 |
44 | if self.use_conv_transpose:
45 | raise NotImplementedError
46 |
47 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
48 | dtype = hidden_states.dtype
49 | if dtype == torch.bfloat16:
50 | hidden_states = hidden_states.to(torch.float32)
51 |
52 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
53 | if hidden_states.shape[0] >= 64:
54 | hidden_states = hidden_states.contiguous()
55 |
56 | # if `output_size` is passed we force the interpolation output
57 | # size and do not make use of `scale_factor=2`
58 | if output_size is None:
59 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
60 | else:
61 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
62 |
63 | # If the input is bfloat16, we cast back to bfloat16
64 | if dtype == torch.bfloat16:
65 | hidden_states = hidden_states.to(dtype)
66 |
67 | if self.use_conv:
68 | if self.name == "conv":
69 | hidden_states = self.conv(hidden_states)
70 | else:
71 | hidden_states = self.Conv2d_0(hidden_states)
72 |
73 | return hidden_states
74 |
75 |
76 | class Downsample3D(nn.Module):
77 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
78 | super().__init__()
79 | self.channels = channels
80 | self.out_channels = out_channels or channels
81 | self.use_conv = use_conv
82 | self.padding = padding
83 | stride = 2
84 | self.name = name
85 |
86 | if use_conv:
87 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
88 | else:
89 | raise NotImplementedError
90 |
91 | if name == "conv":
92 | self.Conv2d_0 = conv
93 | self.conv = conv
94 | elif name == "Conv2d_0":
95 | self.conv = conv
96 | else:
97 | self.conv = conv
98 |
99 | def forward(self, hidden_states):
100 | assert hidden_states.shape[1] == self.channels
101 | if self.use_conv and self.padding == 0:
102 | raise NotImplementedError
103 |
104 | assert hidden_states.shape[1] == self.channels
105 | hidden_states = self.conv(hidden_states)
106 |
107 | return hidden_states
108 |
109 |
110 | class ResnetBlock3D(nn.Module):
111 | def __init__(
112 | self,
113 | *,
114 | in_channels,
115 | out_channels=None,
116 | conv_shortcut=False,
117 | dropout=0.0,
118 | temb_channels=512,
119 | groups=32,
120 | groups_out=None,
121 | pre_norm=True,
122 | eps=1e-6,
123 | non_linearity="swish",
124 | time_embedding_norm="default",
125 | output_scale_factor=1.0,
126 | use_in_shortcut=None,
127 | ):
128 | super().__init__()
129 | self.pre_norm = pre_norm
130 | self.pre_norm = True
131 | self.in_channels = in_channels
132 | out_channels = in_channels if out_channels is None else out_channels
133 | self.out_channels = out_channels
134 | self.use_conv_shortcut = conv_shortcut
135 | self.time_embedding_norm = time_embedding_norm
136 | self.output_scale_factor = output_scale_factor
137 |
138 | if groups_out is None:
139 | groups_out = groups
140 |
141 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142 |
143 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
144 |
145 | if temb_channels is not None:
146 | if self.time_embedding_norm == "default":
147 | time_emb_proj_out_channels = out_channels
148 | elif self.time_embedding_norm == "scale_shift":
149 | time_emb_proj_out_channels = out_channels * 2
150 | else:
151 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
152 |
153 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
154 | else:
155 | self.time_emb_proj = None
156 |
157 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
158 | self.dropout = torch.nn.Dropout(dropout)
159 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
160 |
161 | if non_linearity == "swish":
162 | self.nonlinearity = lambda x: F.silu(x)
163 | elif non_linearity == "mish":
164 | self.nonlinearity = Mish()
165 | elif non_linearity == "silu":
166 | self.nonlinearity = nn.SiLU()
167 |
168 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
169 |
170 | self.conv_shortcut = None
171 | if self.use_in_shortcut:
172 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
173 |
174 | def forward(self, input_tensor, temb):
175 | hidden_states = input_tensor
176 |
177 | hidden_states = self.norm1(hidden_states)
178 | hidden_states = self.nonlinearity(hidden_states)
179 |
180 | hidden_states = self.conv1(hidden_states)
181 |
182 | if temb is not None:
183 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
184 |
185 | if temb is not None and self.time_embedding_norm == "default":
186 | hidden_states = hidden_states + temb
187 |
188 | hidden_states = self.norm2(hidden_states)
189 |
190 | if temb is not None and self.time_embedding_norm == "scale_shift":
191 | scale, shift = torch.chunk(temb, 2, dim=1)
192 | hidden_states = hidden_states * (1 + scale) + shift
193 |
194 | hidden_states = self.nonlinearity(hidden_states)
195 |
196 | hidden_states = self.dropout(hidden_states)
197 | hidden_states = self.conv2(hidden_states)
198 |
199 | if self.conv_shortcut is not None:
200 | input_tensor = self.conv_shortcut(input_tensor)
201 |
202 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
203 |
204 | return output_tensor
205 |
206 |
207 | class Mish(torch.nn.Module):
208 | def forward(self, hidden_states):
209 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
--------------------------------------------------------------------------------
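Usage sketch (illustrative, not part of the repository): how the inflated (frame-wise 2D) convolution and spatial upsampling behave on a 5D video tensor.

    import torch
    from modules.video_unet_temporal.resnet import InflatedConv3d, Upsample3D

    x = torch.randn(2, 4, 16, 32, 32)          # (batch, channels, frames, height, width)
    conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
    y = conv(x)                                 # [2, 8, 16, 32, 32]: the 2D conv is applied frame by frame
    up = Upsample3D(8, use_conv=True)
    z = up(y)                                   # [2, 8, 16, 64, 64]: spatial-only upsampling, frame count kept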
/modules/vqvae/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pytorch_lightning as pl
4 | import torch.nn.functional as F
5 | from contextlib import contextmanager
6 | from packaging import version
7 | from torch.optim.lr_scheduler import LambdaLR
8 |
9 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
10 | from .model import Encoder, Decoder
11 | from misc_utils.model_utils import instantiate_from_config
12 |
13 |
14 | class VQModel(pl.LightningModule):
15 | def __init__(self,
16 | ddconfig,
17 | lossconfig,
18 | n_embed,
19 | embed_dim,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | image_key="image",
23 | colorize_nlabels=None,
24 | monitor=None,
25 | batch_resize_range=None,
26 | scheduler_config=None,
27 | lr_g_factor=1.0,
28 | remap=None,
29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
30 | use_ema=False
31 | ):
32 | super().__init__()
33 | self.embed_dim = embed_dim
34 | self.n_embed = n_embed
35 | self.image_key = image_key
36 | self.encoder = Encoder(**ddconfig)
37 | self.decoder = Decoder(**ddconfig)
38 | self.loss = instantiate_from_config(lossconfig)
39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40 | remap=remap,
41 | sane_index_shape=sane_index_shape)
42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44 | if colorize_nlabels is not None:
45 | assert type(colorize_nlabels)==int
46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47 | if monitor is not None:
48 | self.monitor = monitor
49 | self.batch_resize_range = batch_resize_range
50 | if self.batch_resize_range is not None:
51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52 |
53 | self.use_ema = use_ema
54 | if self.use_ema:
55 |             self.model_ema = LitEma(self)  # NOTE: LitEma is not imported in this file; an EMA helper must be provided when use_ema=True
56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57 |
58 | if ckpt_path is not None:
59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60 | self.scheduler_config = scheduler_config
61 | self.lr_g_factor = lr_g_factor
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | missing, unexpected = self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88 | if len(missing) > 0:
89 | print(f"Missing Keys: {missing}")
90 | print(f"Unexpected Keys: {unexpected}")
91 |
92 | def on_train_batch_end(self, *args, **kwargs):
93 | if self.use_ema:
94 | self.model_ema(self)
95 |
96 | def encode(self, x):
97 | h = self.encoder(x)
98 | h = self.quant_conv(h)
99 | quant, emb_loss, info = self.quantize(h)
100 | return quant, emb_loss, info
101 |
102 | def encode_to_prequant(self, x):
103 | h = self.encoder(x)
104 | h = self.quant_conv(h)
105 | return h
106 |
107 | def decode(self, quant):
108 | quant = self.post_quant_conv(quant)
109 | dec = self.decoder(quant)
110 | return dec
111 |
112 | def decode_code(self, code_b):
113 | quant_b = self.quantize.embed_code(code_b)
114 | dec = self.decode(quant_b)
115 | return dec
116 |
117 | def forward(self, input, return_pred_indices=False):
118 | quant, diff, (_,_,ind) = self.encode(input)
119 | dec = self.decode(quant)
120 | if return_pred_indices:
121 | return dec, diff, ind
122 | return dec, diff
123 |
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129 | if self.batch_resize_range is not None:
130 | lower_size = self.batch_resize_range[0]
131 | upper_size = self.batch_resize_range[1]
132 | if self.global_step <= 4:
133 | # do the first few batches with max size to avoid later oom
134 | new_resize = upper_size
135 | else:
136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137 | if new_resize != x.shape[2]:
138 | x = F.interpolate(x, size=new_resize, mode="bicubic")
139 | x = x.detach()
140 | return x
141 |
142 | def training_step(self, batch, batch_idx, optimizer_idx):
143 | # https://github.com/pytorch/pytorch/issues/37142
144 | # try not to fool the heuristics
145 | x = self.get_input(batch, self.image_key)
146 | xrec, qloss, ind = self(x, return_pred_indices=True)
147 |
148 | if optimizer_idx == 0:
149 | # autoencode
150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151 | last_layer=self.get_last_layer(), split="train",
152 | predicted_indices=ind)
153 |
154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155 | return aeloss
156 |
157 | if optimizer_idx == 1:
158 | # discriminator
159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160 | last_layer=self.get_last_layer(), split="train")
161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162 | return discloss
163 |
164 | def validation_step(self, batch, batch_idx):
165 | log_dict = self._validation_step(batch, batch_idx)
166 | with self.ema_scope():
167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168 | return log_dict
169 |
170 | def _validation_step(self, batch, batch_idx, suffix=""):
171 | x = self.get_input(batch, self.image_key)
172 | xrec, qloss, ind = self(x, return_pred_indices=True)
173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174 | self.global_step,
175 | last_layer=self.get_last_layer(),
176 | split="val"+suffix,
177 | predicted_indices=ind
178 | )
179 |
180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181 | self.global_step,
182 | last_layer=self.get_last_layer(),
183 | split="val"+suffix,
184 | predicted_indices=ind
185 | )
186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187 | self.log(f"val{suffix}/rec_loss", rec_loss,
188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189 | self.log(f"val{suffix}/aeloss", aeloss,
190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191 | if version.parse(pl.__version__) >= version.parse('1.4.0'):
192 | del log_dict_ae[f"val{suffix}/rec_loss"]
193 | self.log_dict(log_dict_ae)
194 | self.log_dict(log_dict_disc)
195 | return self.log_dict
196 |
197 | def configure_optimizers(self):
198 | lr_d = self.learning_rate
199 | lr_g = self.lr_g_factor*self.learning_rate
200 | print("lr_d", lr_d)
201 | print("lr_g", lr_g)
202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203 | list(self.decoder.parameters())+
204 | list(self.quantize.parameters())+
205 | list(self.quant_conv.parameters())+
206 | list(self.post_quant_conv.parameters()),
207 | lr=lr_g, betas=(0.5, 0.9))
208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209 | lr=lr_d, betas=(0.5, 0.9))
210 |
211 | if self.scheduler_config is not None:
212 | scheduler = instantiate_from_config(self.scheduler_config)
213 |
214 | print("Setting up LambdaLR scheduler...")
215 | scheduler = [
216 | {
217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218 | 'interval': 'step',
219 | 'frequency': 1
220 | },
221 | {
222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223 | 'interval': 'step',
224 | 'frequency': 1
225 | },
226 | ]
227 | return [opt_ae, opt_disc], scheduler
228 | return [opt_ae, opt_disc], []
229 |
230 | def get_last_layer(self):
231 | return self.decoder.conv_out.weight
232 |
233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234 | log = dict()
235 | x = self.get_input(batch, self.image_key)
236 | x = x.to(self.device)
237 | if only_inputs:
238 | log["inputs"] = x
239 | return log
240 | xrec, _ = self(x)
241 | if x.shape[1] > 3:
242 | # colorize with random projection
243 | assert xrec.shape[1] > 3
244 | x = self.to_rgb(x)
245 | xrec = self.to_rgb(xrec)
246 | log["inputs"] = x
247 | log["reconstructions"] = xrec
248 | if plot_ema:
249 | with self.ema_scope():
250 | xrec_ema, _ = self(x)
251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252 | log["reconstructions_ema"] = xrec_ema
253 | return log
254 |
255 | def to_rgb(self, x):
256 | assert self.image_key == "segmentation"
257 | if not hasattr(self, "colorize"):
258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259 | x = F.conv2d(x, weight=self.colorize)
260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261 | return x
262 |
263 |
264 | class VQModelInterface(VQModel):
265 | def __init__(self, embed_dim, *args, **kwargs):
266 | super().__init__(embed_dim=embed_dim, *args, **kwargs)
267 | self.embed_dim = embed_dim
268 |
269 | def encode(self, x):
270 | h = self.encoder(x)
271 | h = self.quant_conv(h)
272 | return h
273 |
274 | def decode(self, h, force_not_quantize=False):
275 | # also go through quantization layer
276 | if not force_not_quantize:
277 | quant, emb_loss, info = self.quantize(h)
278 | else:
279 | quant = h
280 | quant = self.post_quant_conv(quant)
281 | dec = self.decoder(quant)
282 | return dec
283 |
284 |
--------------------------------------------------------------------------------
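Usage sketch (illustrative, not part of the repository): using VQModelInterface as a frozen first-stage codec. This requires taming-transformers; the ddconfig assumes the standard VQGAN f=8 layout of modules/vqvae/model.py, and the Identity loss stands in for the real training loss, both assumptions for inference-only use.

    import torch
    from modules.vqvae.autoencoder import VQModelInterface

    ddconfig = dict(double_z=False, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
    codec = VQModelInterface(embed_dim=4, ddconfig=ddconfig, n_embed=8192,
                             lossconfig={'target': 'torch.nn.Identity'})
    img = torch.randn(1, 3, 256, 256)
    h = codec.encode(img)       # pre-quantization latents, [1, 4, 32, 32]
    rec = codec.decode(h)       # quantize, then decode back to [1, 3, 256, 256]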
/pl_trainer/diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import pytorch_lightning as pl
4 | from misc_utils.model_utils import default, instantiate_from_config
5 | from diffusers import DDPMScheduler
6 |
7 | def mean_flat(tensor):
8 | """
9 | Take the mean over all non-batch dimensions.
10 | """
11 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
12 |
13 | class DDPM(pl.LightningModule):
14 | def __init__(
15 | self,
16 | unet,
17 | beta_schedule_args={
18 | 'beta_start': 0.00085,
19 | 'beta_end': 0.0012,
20 | 'num_train_timesteps': 1000,
21 | 'beta_schedule': 'scaled_linear',
22 | 'clip_sample': False,
23 | 'thresholding': False,
24 | },
25 | prediction_type='epsilon',
26 | loss_fn='l2',
27 | optim_args={},
28 | **kwargs
29 | ):
30 |         '''
31 |         unet: a denoising model, e.g. a UNet
32 |         beta_schedule_args: a dictionary containing the
33 |             configuration of the beta (noise) schedule
34 |         '''
35 | super().__init__(**kwargs)
36 | self.unet = unet
37 | self.prediction_type = prediction_type
38 | beta_schedule_args.update({'prediction_type': prediction_type})
39 | self.set_beta_schedule(beta_schedule_args)
40 | self.num_timesteps = beta_schedule_args['num_train_timesteps']
41 | self.optim_args = optim_args
42 | self.loss = loss_fn
43 | if loss_fn == 'l2' or loss_fn == 'mse':
44 | self.loss_fn = nn.MSELoss(reduction='none')
45 | elif loss_fn == 'l1' or loss_fn == 'mae':
46 | self.loss_fn = nn.L1Loss(reduction='none')
47 | elif isinstance(loss_fn, dict):
48 | self.loss_fn = instantiate_from_config(loss_fn)
49 | else:
50 | raise NotImplementedError
51 |
52 | def set_beta_schedule(self, beta_schedule_args):
53 | self.beta_schedule_args = beta_schedule_args
54 | self.scheduler = DDPMScheduler(**beta_schedule_args)
55 |
56 | @torch.no_grad()
57 | def add_noise(self, x, t, noise=None):
58 | noise = default(noise, torch.randn_like(x))
59 | return self.scheduler.add_noise(x, noise, t)
60 |
61 | def predict_x_0_from_x_t(self, model_output: torch.Tensor, t: torch.LongTensor, x_t: torch.Tensor):
62 |         '''Recover x_0 from the model output; reverse of Eq. (4) in the DDPM paper:
63 |         x_0_hat = x_t / sqrt(alpha_bar_t) - sqrt((1 - alpha_bar_t) / alpha_bar_t) * noise'''
64 | # return self.scheduler.step(model_output, int(t), x_t).pred_original_sample
65 | if self.prediction_type == 'sample':
66 | return model_output
67 | # for training target == epsilon
68 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
69 | sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t]).flatten()
70 | sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t] - 1.).flatten()
71 | while len(sqrt_recip_alphas_cumprod.shape) < len(x_t.shape):
72 | sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod.unsqueeze(-1)
73 | sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod.unsqueeze(-1)
74 | return sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * model_output
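        # Worked example (illustrative): with prediction_type='epsilon' and alpha_bar_t = 0.25,
        # sqrt(1/alpha_bar_t) = 2 and sqrt(1/alpha_bar_t - 1) = sqrt(3) ~= 1.732, so
        # x_0_hat = 2 * x_t - 1.732 * eps = x_t / sqrt(alpha_bar_t) - sqrt((1 - alpha_bar_t)/alpha_bar_t) * eps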
75 |
76 | def predict_x_tm1_from_x_t(self, model_output, t, x_t):
77 | '''predict x_{t-1} from x_t and model_output'''
78 | return self.scheduler.step(model_output, t, x_t).prev_sample
79 |
80 | class DDPMTraining(DDPM):
81 | def __init__(
82 | self,
83 | unet,
84 | beta_schedule_args,
85 | prediction_type='epsilon',
86 | loss_fn='l2',
87 | optim_args={
88 | 'lr': 1e-3,
89 | 'weight_decay': 5e-4
90 | },
91 |         log_args={}, # for recording all arguments with self.save_hyperparameters
92 | ddim_sampling_steps=20,
93 | guidance_scale=5.,
94 | **kwargs
95 | ):
96 | super().__init__(
97 | unet=unet,
98 | beta_schedule_args=beta_schedule_args,
99 | prediction_type=prediction_type,
100 | loss_fn=loss_fn,
101 | optim_args=optim_args,
102 | **kwargs)
103 | self.log_args = log_args
104 | self.call_save_hyperparameters()
105 |
106 | self.ddim_sampling_steps = ddim_sampling_steps
107 | self.guidance_scale = guidance_scale
108 |
109 | def call_save_hyperparameters(self):
110 | '''write in a separate function so that the inherit class can overwrite it'''
111 | self.save_hyperparameters(ignore=['unet'])
112 |
113 | def process_batch(self, x_0, mode):
114 | assert mode in ['train', 'val', 'test']
115 | b, *_ = x_0.shape
116 | noise = torch.randn_like(x_0)
117 | if mode == 'train':
118 | t = torch.randint(0, self.num_timesteps, (b,), device=x_0.device).long()
119 | x_t = self.add_noise(x_0, t, noise=noise)
120 | else:
121 | t = torch.full((b,), self.num_timesteps-1, device=x_0.device, dtype=torch.long)
122 | x_t = self.add_noise(x_0, t, noise=noise)
123 |
124 | model_kwargs = {}
125 |         '''returns a dict with:
126 |             'model_input':  the noised sample x_t fed to the denoiser
127 |             'model_target': the prediction target (the noise for epsilon-prediction,
128 |                             otherwise the clean sample x_0)
129 |             't':            the sampled diffusion timestep
130 |             'model_kwargs': extra keyword arguments passed to the denoiser
131 |         '''
132 | if self.prediction_type == 'epsilon':
133 | return {
134 | 'model_input': x_t,
135 | 'model_target': noise,
136 | 't': t,
137 | 'model_kwargs': model_kwargs
138 | }
139 | else:
140 | return {
141 | 'model_input': x_t,
142 | 'model_target': x_0,
143 | 't': t,
144 | 'model_kwargs': model_kwargs
145 | }
146 |
147 | def forward(self, x):
148 | return self.validation_step(x, 0)
149 |
150 | def get_loss(self, pred, target, t):
151 | loss_raw = self.loss_fn(pred, target)
152 | loss_flat = mean_flat(loss_raw)
153 |
154 | loss = loss_flat
155 | loss = loss.mean()
156 |
157 | return loss
158 |
159 | def training_step(self, batch, batch_idx):
160 | self.clip_denoised = False
161 | processed_batch = self.process_batch(batch, mode='train')
162 | x_t = processed_batch['model_input']
163 | y = processed_batch['model_target']
164 | t = processed_batch['t']
165 | model_kwargs = processed_batch['model_kwargs']
166 | pred = self.unet(x_t, t, **model_kwargs)
167 | loss = self.get_loss(pred, y, t)
168 | x_0_hat = self.predict_x_0_from_x_t(pred, t, x_t)
169 |
170 | self.log(f'train_loss', loss)
171 | return {
172 | 'loss': loss,
173 | 'model_input': x_t,
174 | 'model_output': pred,
175 | 'x_0_hat': x_0_hat
176 | }
177 |
178 | @torch.no_grad()
179 | def validation_step(self, batch, batch_idx):
180 | from diffusers import DDIMScheduler
181 | scheduler = DDIMScheduler(**self.beta_schedule_args)
182 | scheduler.set_timesteps(self.ddim_sampling_steps)
183 | processed_batch = self.process_batch(batch, mode='val')
184 | x_t = torch.randn_like(processed_batch['model_input'])
185 | x_hist = []
186 | timesteps = scheduler.timesteps
187 | for i, t in enumerate(timesteps):
188 | t_ = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
189 | model_output = self.unet(x_t, t_, **processed_batch['model_kwargs'])
190 | x_hist.append(
191 | self.predict_x_0_from_x_t(model_output, t_, x_t)
192 | )
193 | x_t = scheduler.step(model_output, t, x_t).prev_sample
194 |
195 | return {
196 | 'x_pred': x_t,
197 | 'x_hist': torch.stack(x_hist, dim=1),
198 | }
199 |
200 | def test_step(self, batch, batch_idx):
201 |         '''Testing is rarely needed for a sampling task; reuse the validation step.'''
202 | return self.validation_step(batch, batch_idx)
203 |
204 |
205 | def configure_optimizers(self):
206 | optimizer = torch.optim.Adam(self.parameters(), **self.optim_args)
207 | return optimizer
208 |
209 | class DDPMLDMTraining(DDPMTraining):
210 | def __init__(
211 | self, *args,
212 | vae,
213 | unet_init_weights=None,
214 | vae_init_weights=None,
215 | scale_factor=0.18215,
216 | **kwargs
217 | ):
218 | super().__init__(*args, **kwargs)
219 | self.vae = vae
220 | self.scale_factor = scale_factor
221 | self.initialize_unet(unet_init_weights)
222 | self.initialize_vqvae(vae_init_weights)
223 |
224 | def initialize_unet(self, unet_init_weights):
225 | if unet_init_weights is not None:
226 | print(f'INFO: initialize denoising UNet from {unet_init_weights}')
227 | sd = torch.load(unet_init_weights, map_location='cpu')
228 | self.unet.load_state_dict(sd)
229 |
230 | def initialize_vqvae(self, vqvae_init_weights):
231 | if vqvae_init_weights is not None:
232 | print(f'INFO: initialize VQVAE from {vqvae_init_weights}')
233 | sd = torch.load(vqvae_init_weights, map_location='cpu')
234 | self.vae.load_state_dict(sd)
235 | for param in self.vae.parameters():
236 | param.requires_grad = False
237 |
238 | def call_save_hyperparameters(self):
239 | '''kept as a separate method so that subclasses can override it'''
240 | self.save_hyperparameters(ignore=['unet', 'vae'])
241 |
242 | @torch.no_grad()
243 | def encode_image_to_latent(self, x):
244 | return self.vae.encode(x) * self.scale_factor
245 |
246 | @torch.no_grad()
247 | def decode_latent_to_image(self, x):
248 | x = x / self.scale_factor
249 | return self.vae.decode(x)
250 |
251 | def process_batch(self, x_0, mode):
252 | x_0 = self.encode_image_to_latent(x_0)
253 | res = super().process_batch(x_0, mode)
254 | return res
255 |
256 | def training_step(self, batch, batch_idx):
257 | res_dict = super().training_step(batch, batch_idx)
258 | res_dict['x_0_hat'] = self.decode_latent_to_image(res_dict['x_0_hat'])
259 | return res_dict
260 |
261 | class DDIMLDMTextTraining(DDPMLDMTraining):
262 | def __init__(
263 | self, *args,
264 | text_model,
265 | text_model_init_weights=None,
266 | **kwargs
267 | ):
268 | super().__init__(
269 | *args, **kwargs
270 | )
271 | self.text_model = text_model
272 | self.initialize_text_model(text_model_init_weights)
273 |
274 | def initialize_text_model(self, text_model_init_weights):
275 | if text_model_init_weights is not None:
276 | print(f'INFO: initialize text model from {text_model_init_weights}')
277 | sd = torch.load(text_model_init_weights, map_location='cpu')
278 | self.text_model.load_state_dict(sd)
279 | for param in self.text_model.parameters():
280 | param.requires_grad = False
281 |
282 | def call_save_hyperparameters(self):
283 | '''kept as a separate method so that subclasses can override it'''
284 | self.save_hyperparameters(ignore=['unet', 'vae', 'text_model'])
285 |
286 | @torch.no_grad()
287 | def encode_text(self, x):
288 | if isinstance(x, tuple):
289 | x = list(x)
290 | return self.text_model.encode(x)
291 |
292 | def process_batch(self, batch, mode):
293 | x_0 = batch['image']
294 | text = batch['text']
295 | processed_batch = super().process_batch(x_0, mode)
296 | processed_batch['model_kwargs'].update({
297 | 'context': {'text': self.encode_text([text])}
298 | })
299 | return processed_batch
300 |
301 | def sampling(self, image_shape=(1, 4, 64, 64), text='', negative_text=None):
302 | '''
303 | Usage:
304 | sampled = self.sampling(text='a cat on the tree', negative_text='')
305 |
306 | x = sampled['x_pred'][0].permute(1, 2, 0).detach().cpu().numpy()
307 | x = x / 2 + 0.5
308 | plt.imshow(x)
309 |
310 | y = sampled['x_hist'][0, 10].permute(1, 2, 0).detach().cpu().numpy()
311 | y = y / 2 + 0.5
312 | plt.imshow(y)
313 | '''
314 | from diffusers import DDIMScheduler
315 | scheduler = DDIMScheduler(**self.beta_schedule_args)
316 | scheduler.set_timesteps(self.ddim_sampling_steps)
317 | x_t = torch.randn(*image_shape, device=self.device)
318 |
319 | do_cfg = self.guidance_scale > 1. and negative_text is not None
320 |
321 | if do_cfg:
322 | context = {'text': self.encode_text([text, negative_text])}
323 | else:
324 | context = {'text': self.encode_text([text])}
325 | x_hist = []
326 | timesteps = scheduler.timesteps
327 | for i, t in enumerate(timesteps):
328 | if do_cfg:
329 | model_input = torch.cat([x_t]*2)
330 | else:
331 | model_input = x_t
332 | t_ = torch.full((model_input.shape[0],), t, device=x_t.device, dtype=torch.long)
333 | model_output = self.unet(model_input, t_, context)
334 |
335 | if do_cfg:
336 | model_output_positive, model_output_negative = model_output.chunk(2)
337 | model_output = model_output_negative + self.guidance_scale * (model_output_positive - model_output_negative)
338 | x_hist.append(
339 | self.decode_latent_to_image(self.predict_x_0_from_x_t(model_output, t_[:x_t.shape[0]], x_t))
340 | )
341 | x_t = scheduler.step(model_output, t, x_t).prev_sample
342 |
343 | return {
344 | 'x_pred': self.decode_latent_to_image(x_t),
345 | 'x_hist': torch.stack(x_hist, dim=1),
346 | }
347 |
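348 | # Reference sketch (not used by the trainers above): `predict_x_0_from_x_t` is called
349 | # throughout this file but defined earlier in it. For an epsilon-prediction model it
350 | # usually reduces to the standard DDPM identity
351 | # x_0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t).
352 | # The helper below only illustrates that identity and assumes a precomputed
353 | # `alphas_cumprod` buffer; the actual implementation in this class may differ.
354 | def _predict_x0_from_eps_sketch(eps, t, x_t, alphas_cumprod):
355 |     # gather a_bar_t per batch element and broadcast over the remaining dims
356 |     a_bar = alphas_cumprod.to(x_t.device)[t].view(-1, *([1] * (x_t.dim() - 1)))
357 |     return (x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()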
--------------------------------------------------------------------------------
/pl_trainer/instruct_p2p_video.py:
--------------------------------------------------------------------------------
1 | '''
2 | Use a pretrained InstructPix2Pix model, with additional input channels for the reference (conditioning) video.
3 | '''
4 |
5 | import torch
6 | from .diffusion import DDIMLDMTextTraining
7 | from einops import rearrange
8 |
9 | class InstructP2PVideoTrainer(DDIMLDMTextTraining):
10 | def __init__(
11 | self, *args,
12 | cond_image_dropout=0.1,
13 | prompt_type='output_prompt',
14 | text_cfg=7.5,
15 | img_cfg=1.2,
16 | **kwargs
17 | ):
18 | super().__init__(*args, **kwargs)
19 | self.cond_image_dropout = cond_image_dropout
20 |
21 | assert prompt_type in ['output_prompt', 'edit_prompt', 'mixed_prompt']
22 | self.prompt_type = prompt_type
23 |
24 | self.text_cfg = text_cfg
25 | self.img_cfg = img_cfg
26 |
27 | self.unet.enable_xformers_memory_efficient_attention()
28 | self.unet.enable_gradient_checkpointing()
29 |
30 | def encode_text(self, text):
31 | with torch.cuda.amp.autocast(dtype=torch.float16):
32 | encoded_text = super().encode_text(text)
33 | return encoded_text
34 |
35 | def encode_image_to_latent(self, image):
36 | with torch.cuda.amp.autocast(dtype=torch.float16):
37 | latent = super().encode_image_to_latent(image)
38 | return latent
39 |
40 | @torch.cuda.amp.autocast(dtype=torch.float16)
41 | @torch.no_grad()
42 | def get_prompt(self, batch, mode):
43 | if mode == 'train':
44 | if self.prompt_type == 'output_prompt':
45 | prompt = batch['output_prompt']
46 | elif self.prompt_type == 'edit_prompt':
47 | prompt = batch['edit_prompt']
48 | elif self.prompt_type == 'mixed_prompt':
49 | if torch.rand(1).item() > 0.5:
50 | prompt = batch['output_prompt']
51 | else:
52 | prompt = batch['edit_prompt']
53 | else:
54 | prompt = batch['output_prompt']
55 | return self.encode_text(prompt)
56 |
57 | @torch.cuda.amp.autocast(dtype=torch.float16)
58 | @torch.no_grad()
59 | def encode_image_to_latent(self, image):
60 | b, f, c, h, w = image.shape
61 | image = rearrange(image, 'b f c h w -> (b f) c h w')
62 | latent = super().encode_image_to_latent(image)
63 | latent = rearrange(latent, '(b f) c h w -> b f c h w', b=b)
64 | return latent
65 |
66 | @torch.cuda.amp.autocast(dtype=torch.float16)
67 | @torch.no_grad()
68 | def decode_latent_to_image(self, latent):
69 | b, f, c, h, w = latent.shape
70 | latent = rearrange(latent, 'b f c h w -> (b f) c h w')
71 |
72 | image = []
73 | for latent_ in latent:
74 | image_ = super().decode_latent_to_image(latent_[None])
75 | image.append(image_)
76 | image = torch.cat(image, dim=0)
77 | # image = super().decode_latent_to_image(latent)
78 | image = rearrange(image, '(b f) c h w -> b f c h w', b=b)
79 | return image
80 |
81 | @torch.no_grad()
82 | def get_cond_image(self, batch, mode):
83 | cond_image = batch['input_video']
84 |
85 | # InstructPix2Pix does not apply the latent scale factor to its conditioning image, so undo the scaling here
86 | cond_image = self.encode_image_to_latent(cond_image) / self.scale_factor
87 | if mode == 'train':
88 | if torch.rand(1).item() < self.cond_image_dropout:
89 | cond_image = torch.zeros_like(cond_image)
90 | return cond_image
91 |
92 | @torch.no_grad()
93 | def get_diffused_image(self, batch, mode):
94 | x = batch['edited_video']
95 | b, *_ = x.shape
96 | x = self.encode_image_to_latent(x)
97 | eps = torch.randn_like(x)
98 |
99 | if mode == 'train':
100 | t = torch.randint(0, self.num_timesteps, (b,), device=x.device).long()
101 | else:
102 | t = torch.full((b,), self.num_timesteps-1, device=x.device, dtype=torch.long)
103 | x_t = self.add_noise(x, t, eps)
104 |
105 | if self.prediction_type == 'epsilon':
106 | return x_t, eps, t
107 | else:
108 | return x_t, x, t
109 |
110 | @torch.no_grad()
111 | def process_batch(self, batch, mode):
112 | cond_image = self.get_cond_image(batch, mode)
113 | diffused_image, target, t = self.get_diffused_image(batch, mode)
114 | prompt = self.get_prompt(batch, mode)
115 |
116 | model_kwargs = {
117 | 'encoder_hidden_states': prompt
118 | }
119 |
120 | return {
121 | 'diffused_input': diffused_image,
122 | 'condition': cond_image,
123 | 'target': target,
124 | 't': t,
125 | 'model_kwargs': model_kwargs,
126 | }
127 |
128 | def training_step(self, batch, batch_idx):
129 | processed_batch = self.process_batch(batch, mode='train')
130 | diffused_input = processed_batch['diffused_input']
131 | condition = processed_batch['condition']
132 | target = processed_batch['target']
133 | t = processed_batch['t']
134 | model_kwargs = processed_batch['model_kwargs']
135 |
136 | model_input = torch.cat([diffused_input, condition], dim=2) # b, f, c, h, w
137 | model_input = rearrange(model_input, 'b f c h w -> b c f h w')
138 |
139 | pred = self.unet(model_input, t, **model_kwargs).sample
140 | pred = rearrange(pred, 'b c f h w -> b f c h w')
141 |
142 | loss = self.get_loss(pred, target, t)
143 | self.log('train_loss', loss, sync_dist=True)
144 |
145 | latent_pred = self.predict_x_0_from_x_t(pred, t, diffused_input)
146 | image_pred = self.decode_latent_to_image(latent_pred)
147 |
148 | res_dict = {
149 | 'loss': loss,
150 | 'pred': image_pred,
151 | }
152 | return res_dict
153 |
154 | @torch.no_grad()
155 | @torch.cuda.amp.autocast(dtype=torch.bfloat16)
156 | def validation_step(self, batch, batch_idx):
157 | from .inference.inference import InferenceIP2PVideo
158 | inf_pipe = InferenceIP2PVideo(
159 | self.unet,
160 | beta_start=self.scheduler.config.beta_start,
161 | beta_end=self.scheduler.config.beta_end,
162 | beta_schedule=self.scheduler.config.beta_schedule,
163 | num_ddim_steps=20
164 | )
165 |
166 | processed_batch = self.process_batch(batch, mode='val')
167 | diffused_input = torch.randn_like(processed_batch['diffused_input'])
168 |
169 | condition = processed_batch['condition']
170 | img_cond = condition[:, :, :4]
171 |
172 | res = inf_pipe(
173 | latent = diffused_input,
174 | text_cond = processed_batch['model_kwargs']['encoder_hidden_states'],
175 | text_uncond = self.encode_text(['']),
176 | img_cond = img_cond,
177 | text_cfg = self.text_cfg,
178 | img_cfg = self.img_cfg,
179 | )
180 |
181 | latent_pred = res['latent']
182 | image_pred = self.decode_latent_to_image(latent_pred)
183 | res_dict = {
184 | 'pred': image_pred,
185 | }
186 | return res_dict
187 |
188 | def configure_optimizers(self):
189 | # optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.optim_args['lr'])
190 | import bitsandbytes as bnb
191 | params = []
192 | for name, p in self.unet.named_parameters():
193 | if ('transformer_in' in name) or ('temp_' in name):
194 | # p.requires_grad = True
195 | params.append(p)
196 | else:
197 | pass
198 | # p.requires_grad = False
199 | optimizer = bnb.optim.Adam8bit(params, lr=self.optim_args['lr'], betas=(0.9, 0.999))
200 | return optimizer
201 |
202 | def initialize_unet(self, unet_init_weights):
203 | if unet_init_weights is not None:
204 | print(f'INFO: initialize denoising UNet from {unet_init_weights}')
205 | sd = torch.load(unet_init_weights, map_location='cpu')
206 | model_sd = self.unet.state_dict()
207 | # fill in weights missing from the pretrained state dict (e.g. the newly added temporal layers)
208 | for k in model_sd.keys():
209 | if k in sd.keys():
210 | pass
211 | else:
212 | # handling temporal layers
213 | if (('temp_' in k) or ('transformer_in' in k)) and 'proj_out' in k:
214 | # print(f'INFO: initialize {k} from {model_sd[k].shape} to zeros')
215 | sd[k] = torch.zeros_like(model_sd[k])
216 | else:
217 | # print(f'INFO: initialize {k} from {model_sd[k].shape} to random')
218 | sd[k] = model_sd[k]
219 | self.unet.load_state_dict(sd)
220 |
221 | class InstructP2PVideoTrainerTemporal(InstructP2PVideoTrainer):
222 | def initialize_unet(self, unet_init_weights):
223 | if unet_init_weights is not None:
224 | print(f'INFO: initialize denoising UNet from {unet_init_weights}')
225 | sd_init_weights, motion_module_init_weights = unet_init_weights
226 | sd = torch.load(sd_init_weights, map_location='cpu')
227 | motion_sd = torch.load(motion_module_init_weights, map_location='cpu')
228 | assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}'
229 | sd.update(motion_sd)
230 | for k, v in self.unet.state_dict().items():
231 | if 'pos_encoder.pe' in k:
232 | sd[k] = v # the size of pe may change
233 | self.unet.load_state_dict(sd)
234 |
235 | def configure_optimizers(self):
236 | import bitsandbytes as bnb
237 | motion_params = []
238 | remaining_params = []
239 | for name, p in self.unet.named_parameters():
240 | if ('motion' in name):
241 | motion_params.append(p)
242 | else:
243 | remaining_params.append(p)
244 | optimizer = bnb.optim.Adam8bit([
245 | {'params': motion_params, 'lr': self.optim_args['lr']},
246 | ], betas=(0.9, 0.999))
247 | return optimizer
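248 | 
249 | # Reference sketch: at validation time the trainers above delegate sampling to
250 | # `InferenceIP2PVideo` (imported from .inference.inference, not shown in this file)
251 | # with separate text and image guidance scales. InstructPix2Pix-style pipelines
252 | # typically combine the unconditional, image-only and fully conditioned noise
253 | # predictions as below; this is an assumption about that pipeline, not a copy of its code.
254 | def _combine_ip2p_cfg_sketch(eps_full, eps_img_only, eps_uncond, text_cfg, img_cfg):
255 |     # eps_full: conditioned on image + text; eps_img_only: image only; eps_uncond: neither
256 |     return (
257 |         eps_uncond
258 |         + img_cfg * (eps_img_only - eps_uncond)
259 |         + text_cfg * (eps_full - eps_img_only)
260 |     )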
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | einops
3 | wandb
4 | transformers
5 | diffusers
6 | pytorch-lightning
7 | opencv-python
8 | opencv-contrib-python
9 | omegaconf
10 | open-clip-torch
11 | jsonlines
12 | albumentations
13 | diff-match-patch
14 | git+https://github.com/openai/CLIP.git
15 | deepspeed
--------------------------------------------------------------------------------
/video_edit.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from misc_utils.train_utils import unit_test_create_model\n",
10 | "from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images\n",
11 | "config_path = 'configs/instruct_v2v_inference.yaml'\n",
12 | "diffusion_model = unit_test_create_model(config_path)"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "import torch\n",
22 | "ckpt = torch.load('insv2v.pth', map_location='cpu')\n",
23 | "diffusion_model.load_state_dict(ckpt, strict=False)"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# edit params\n",
33 | "EDIT_PROMPT = 'make the car a red Porsche and drive along the beach'\n",
34 | "VIDEO_CFG = 1.2\n",
35 | "TEXT_CFG = 7.5\n",
36 | "LONG_VID_SAMPLING_CORRECTION_STEP = 0.5\n",
37 | "\n",
38 | "# video params\n",
39 | "VIDEO_PATH = 'data/car-turn.mp4'\n",
40 | "IMGSIZE = 384\n",
41 | "NUM_FRAMES = 32\n",
42 | "VIDEO_SAMPLE_RATE = 10\n",
43 | "\n",
44 | "# sampling params\n",
45 | "FRAMES_IN_BATCH = 16\n",
46 | "NUM_REF_FRAMES = 4\n",
47 | "USE_MOTION_COMPENSATION = True"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "from pl_trainer.inference.inference import InferenceIP2PVideo, InferenceIP2PVideoOpticalFlow\n",
57 | "if USE_MOTION_COMPENSATION:\n",
58 | " inf_pipe = InferenceIP2PVideoOpticalFlow(\n",
59 | " unet = diffusion_model.unet,\n",
60 | " num_ddim_steps=20,\n",
61 | " scheduler='ddpm'\n",
62 | " )\n",
63 | "else:\n",
64 | " inf_pipe = InferenceIP2PVideo(\n",
65 | " unet = diffusion_model.unet,\n",
66 | " num_ddim_steps=20,\n",
67 | " scheduler='ddpm'\n",
68 | " )"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "from dataset.single_video_dataset import SingleVideoDataset\n",
78 | "dataset = SingleVideoDataset(\n",
79 | " video_file=VIDEO_PATH,\n",
80 | " video_description='',\n",
81 | " sampling_fps=VIDEO_SAMPLE_RATE,\n",
82 | " num_frames=NUM_FRAMES,\n",
83 | " output_size=(IMGSIZE, IMGSIZE)\n",
84 | ")\n",
85 | "batch = dataset[20] # start from 20th frame\n",
86 | "batch = {k: v.cuda()[None] if isinstance(v, torch.Tensor) else v for k, v in batch.items()}"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "def split_batch(cond, frames_in_batch=16, num_ref_frames=4):\n",
96 | " frames_in_following_batch = frames_in_batch - num_ref_frames\n",
97 | " conds = [cond[:, :frames_in_batch]]\n",
98 | " frame_ptr = frames_in_batch\n",
99 | " num_ref_frames_each_batch = []\n",
100 | "\n",
101 | " while frame_ptr < cond.shape[1]:\n",
102 | " remaining_frames = cond.shape[1] - frame_ptr\n",
103 | " if remaining_frames < frames_in_batch:\n",
104 | " frames_in_following_batch = remaining_frames\n",
105 | " else:\n",
106 | " frames_in_following_batch = frames_in_batch - num_ref_frames\n",
107 | " this_ref_frames = frames_in_batch - frames_in_following_batch\n",
108 | " conds.append(cond[:, frame_ptr:frame_ptr+frames_in_following_batch])\n",
109 | " frame_ptr += frames_in_following_batch\n",
110 | " num_ref_frames_each_batch.append(this_ref_frames)\n",
111 | "\n",
112 | " return conds, num_ref_frames_each_batch"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "cond = [diffusion_model.encode_image_to_latent(frames) / 0.18215 for frames in batch['frames'].chunk(16, dim=1)] # encode in chunks of 16 frames to avoid VAE OOM; reduce the chunk size on a smaller GPU\n",
122 | "cond = torch.cat(cond, dim=1)\n",
123 | "text_cond = diffusion_model.encode_text([EDIT_PROMPT])\n",
124 | "text_uncond = diffusion_model.encode_text([''])\n",
125 | "conds, num_ref_frames_each_batch = split_batch(cond, frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)\n",
126 | "splitted_frames, _ = split_batch(batch['frames'], frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "# First video clip\n",
136 | "cond1 = conds[0]\n",
137 | "latent_pred_list = []\n",
138 | "init_latent = torch.randn_like(cond1)\n",
139 | "latent_pred = inf_pipe(\n",
140 | " latent = init_latent,\n",
141 | " text_cond = text_cond,\n",
142 | " text_uncond = text_uncond,\n",
143 | " img_cond = cond1,\n",
144 | " text_cfg = TEXT_CFG,\n",
145 | " img_cfg = VIDEO_CFG,\n",
146 | ")['latent']\n",
147 | "latent_pred_list.append(latent_pred)\n",
148 | "\n",
149 | "\n",
150 | "# Subsequent video clips\n",
151 | "for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(\n",
152 | " conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch\n",
153 | "):\n",
154 | " init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1)\n",
155 | " cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1)\n",
156 | " if USE_MOTION_COMPENSATION:\n",
157 | " ref_images = prev_frame[:, -num_ref_frames_:]\n",
158 | " query_images = curr_frame\n",
159 | " additional_kwargs = {\n",
160 | " 'ref_images': ref_images,\n",
161 | " 'query_images': query_images,\n",
162 | " }\n",
163 | " else:\n",
164 | " additional_kwargs = {}\n",
165 | " latent_pred = inf_pipe.second_clip_forward(\n",
166 | " latent = init_latent, \n",
167 | " text_cond = text_cond,\n",
168 | " text_uncond = text_uncond,\n",
169 | " img_cond = cond_,\n",
170 | " latent_ref = latent_pred[:, -num_ref_frames_:],\n",
171 | " noise_correct_step = LONG_VID_SAMPLING_CORRECTION_STEP,\n",
172 | " text_cfg = TEXT_CFG,\n",
173 | " img_cfg = VIDEO_CFG,\n",
174 | " **additional_kwargs,\n",
175 | " )['latent']\n",
176 | " latent_pred_list.append(latent_pred[:, num_ref_frames_:])\n",
177 | "\n",
178 | "# Save GIF\n",
179 | "latent_pred = torch.cat(latent_pred_list, dim=1)\n",
180 | "image_pred = diffusion_model.decode_latent_to_image(latent_pred).clip(-1, 1)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "original_images = batch['frames'].cpu()\n",
190 | "transferred_images = image_pred.float().cpu()\n",
191 | "concat_images = torch.cat([original_images, transferred_images], dim=4)\n",
192 | "\n",
193 | "save_tensor_to_gif(concat_images, 'results/video_edit.gif', fps=5)\n",
194 | "save_tensor_to_images(transferred_images, 'results/video_edit_images')"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "# visualize the gif\n",
204 | "from IPython.display import Image\n",
205 | "Image(filename='results/video_edit.gif')"
206 | ]
207 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "pytorch",
212 | "language": "python",
213 | "name": "python3"
214 | },
215 | "language_info": {
216 | "codemirror_mode": {
217 | "name": "ipython",
218 | "version": 3
219 | },
220 | "file_extension": ".py",
221 | "mimetype": "text/x-python",
222 | "name": "python",
223 | "nbconvert_exporter": "python",
224 | "pygments_lexer": "ipython3",
225 | "version": "3.10.10"
226 | }
227 | },
228 | "nbformat": 4,
229 | "nbformat_minor": 2
230 | }
231 |
--------------------------------------------------------------------------------
/video_prompt_to_prompt.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import jsonlines
4 | import json
5 | from misc_utils.video_ptp_utils import get_models_of_damo_model
6 | import torch
7 | import numpy as np
8 | from einops import rearrange
9 | from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete
10 | from misc_utils.video_ptp_utils import compute_diff
11 | from pl_trainer.inference.inference_damo import InferenceDAMO_PTP_v2
12 | import cv2
13 | from misc_utils.image_utils import images_to_gif
14 | from misc_utils.clip_similarity import ClipSimilarity
15 |
16 | def save_images_to_folder(source_images, target_images, folder, seed):
17 | os.makedirs(folder, exist_ok=True)
18 | for i, (src_image, tgt_image) in enumerate(zip(source_images, target_images)):
19 | src_img = cv2.cvtColor((src_image*255).astype(np.uint8), cv2.COLOR_RGB2BGR)
20 | tgt_img = cv2.cvtColor((tgt_image*255).astype(np.uint8), cv2.COLOR_RGB2BGR)
21 | cv2.imwrite(os.path.join(folder, f'{seed}_0_{i:04d}.jpg'), src_img)
22 | cv2.imwrite(os.path.join(folder, f'{seed}_1_{i:04d}.jpg'), tgt_img)
23 |
24 | def save_images_to_gif(source_images, target_images, folder, seed):
25 | os.makedirs(folder, exist_ok=True)
26 | # images_to_gif(source_images, os.path.join(folder, f'{seed}_0.gif'), fps=5)
27 | # images_to_gif(target_images, os.path.join(folder, f'{seed}_1.gif'), fps=5)
28 | concat_image = np.concatenate([source_images, target_images], axis=2)
29 | images_to_gif(concat_image, os.path.join(folder, f'{seed}_concat.gif'), fps=5)
30 |
31 | def append_dict_to_jsonl(file_path: str, dict_obj: dict) -> None:
32 | with open(file_path, 'a') as f:
33 | f.write(json.dumps(dict_obj))
34 | f.write("\n") # Write a new line at the end to prepare for the next JSON object
35 |
36 | # %%
37 | def torch_to_numpy(x):
38 | return (x.float().squeeze().detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5).clip(0, 1)
39 | def str_to_prompt(prompt):
40 | input_text = prompt['input'].strip('.')
41 | output_text = prompt['output'].strip('.')
42 | return compute_diff(input_text, output_text)
43 |
44 | def get_ptp_prompt(prompt, edit_weight=1.):
45 | PROMPT = str_to_prompt(prompt)
46 | for p in PROMPT:
47 | if isinstance(p, (Edit, Insert)):
48 | p.weight = edit_weight
49 | print(PROMPT)
50 | source_prompt = ' '.join(x.old for x in PROMPT)
51 | target_prompt = ' '.join(x.new for x in PROMPT)
52 | context_uncond = get_text_embedding_openclip("", text_encoder, text_model.device)
53 | old_context = get_text_embedding_openclip(source_prompt, text_encoder, text_model.device)
54 | context = get_text_embedding_openclip(target_prompt, text_encoder, text_model.device)
55 | key, value = encode_text_openclip(PROMPT, text_encoder, text_model.device)
56 | return {
57 | 'source_prompt': source_prompt,
58 | 'target_prompt': target_prompt,
59 | 'context_uncond': context_uncond,
60 | 'old_context': old_context,
61 | 'context': context,
62 | 'key': key,
63 | 'value': value,
64 | }
65 | def process_one_sample(prompt, seed=None, guidance_scale=9, num_ddim_steps=30, scheduler='ddim', sa_end_time=0.4, ca_end_time=0.8, edit_weight=1.):
66 | prompt_dict = get_ptp_prompt(prompt, edit_weight=edit_weight)
67 |
68 | if seed is None:
69 | seed = np.random.randint(0, 1000000)
70 | torch.random.manual_seed(seed)
71 | latent = torch.randn(1, 4, 16, 32, 32).cuda()
72 | inf_pipe = InferenceDAMO_PTP_v2(unet=unet, guidance_scale=guidance_scale, num_ddim_steps=num_ddim_steps, scheduler=scheduler)
73 |
74 | with torch.cuda.amp.autocast(dtype=torch.float16):
75 | res = inf_pipe(
76 | latent=latent,
77 | context = prompt_dict['context'],
78 | old_context = prompt_dict['old_context'],
79 | old_to_new_context=[prompt_dict['key'], prompt_dict['value']],
80 | uncond_context = prompt_dict['context_uncond'],
81 | sa_end_time=sa_end_time,
82 | ca_end_time=ca_end_time,
83 | )
84 | pred_latent = res['latent']
85 | pred_latent_old = res['latent_old']
86 |
87 | latents = rearrange(pred_latent, 'b d f h w -> f b d h w')
88 | latents_old = rearrange(pred_latent_old, 'b d f h w -> f b d h w')
89 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
90 | pred_image = [vae.decode(latent / 0.18215) for latent in latents]
91 | pred_image_old = [vae.decode(latent / 0.18215) for latent in latents_old]
92 | pred_images = torch.stack(pred_image, dim=0)
93 | pred_images_old = torch.stack(pred_image_old, dim=0)
94 |
95 | old_prompt_images_np = torch_to_numpy(pred_images_old)
96 | new_prompt_images_np = torch_to_numpy(pred_images)
97 |
98 | return old_prompt_images_np, new_prompt_images_np, seed, prompt_dict
99 |
100 |
101 | if __name__ == '__main__':
102 | import argparse
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument('--start', type=int, default=0)
105 | parser.add_argument('--end', type=int, default=10)
106 | parser.add_argument('--prompt_source', type=str, default='ip2p')
107 | parser.add_argument('--num_sample_each_prompt', '-n', type=int, default=1)
108 | args = parser.parse_args()
109 |
110 | if args.prompt_source == 'ip2p':
111 | prompt_meta_file = 'data/gpt-generated-prompts.jsonl'
112 | with open(prompt_meta_file, 'r') as f:
113 | prompts = [(i, json.loads(line)) for i, line in enumerate(f.readlines())]
114 | output_dir = 'video_ptp/raw_generated'
115 | elif args.prompt_source == 'webvid':
116 | root_dir = 'webvid_edit_prompt'
117 | files = os.listdir(root_dir)
118 | files = sorted(files, key=lambda x: int(x.split('.')[0]))
119 | prompts = []
120 | for file in files:
121 | with open(os.path.join(root_dir, file), 'r') as f:
122 | prompt_idx = int(file.split('.')[0])
123 | prompts.append((prompt_idx, json.load(f)))
124 | output_dir = 'video_ptp/raw_generated_webvid'
125 | else:
126 | raise ValueError(f'Unknown prompt source: {args.prompt_source}')
127 |
128 |
129 | # %%
130 | vae_config = 'configs/instruct_v2v.yaml'
131 | unet_config = 'modules/damo_text_to_video/configuration.json'
132 | vae_ckpt = 'VAE_PATH'
133 | unet_ckpt = 'UNet_PATH'
134 | text_model_ckpt = 'Text_MODEL_PATH'
135 |
136 | vae, unet, text_model = get_models_of_damo_model(
137 | unet_config=unet_config,
138 | unet_ckpt=unet_ckpt,
139 | vae_config=vae_config,
140 | vae_ckpt=vae_ckpt,
141 | text_model_ckpt=text_model_ckpt,
142 | )
143 | text_encoder = text_model.encode_with_transformer
144 | clip_sim = ClipSimilarity(name='ViT-L/14').cuda()
145 | # %%
146 |
147 | for prompt_idx, prompt in prompts:
148 | if prompt_idx < args.start:
149 | continue
150 | if prompt_idx >= args.end:
151 | break
152 | print(prompt_idx, prompt)
153 | output_folder_idx = f'{prompt_idx:07d}'
154 |
155 | prompt_json_file = os.path.join(output_dir, output_folder_idx, 'prompt.json')
156 | os.makedirs(os.path.dirname(prompt_json_file), exist_ok=True)
157 | with open(prompt_json_file, 'w') as f:
158 | json.dump(prompt, f)
159 |
160 | meta_file_path = os.path.join(output_dir, output_folder_idx, 'metadata.jsonl')
161 | if os.path.exists(meta_file_path):
162 | with jsonlines.open(meta_file_path, 'r') as f:
163 | used_seed = [int(line['seed']) for line in f]
164 | num_existing_samples = len(used_seed)
165 | else:
166 | num_existing_samples = 0
167 | used_seed = []
168 | print(f'num_existing_samples: {num_existing_samples}, used_seed: {used_seed}')
169 |
170 | for _ in range(num_existing_samples, args.num_sample_each_prompt):
171 | # generate a random configure
172 | seed = np.random.randint(0, 1000000)
173 | while seed in used_seed:
174 | seed = np.random.randint(0, 1000000)
175 | used_seed.append(seed)
176 |
177 | rng = np.random.RandomState(seed=seed)
178 | guidance_scale = rng.randint(5, 13)
179 | sa_end_time = float('{:.2f}'.format(rng.choice(np.linspace(0.3, 0.45, 4))))
180 | ca_end_time = float('{:.2f}'.format(rng.choice(np.linspace(0.6, 0.85, 6))))
181 | edit_weight = rng.randint(1, 6)
182 | generate_config = {
183 | 'seed': seed,
184 | 'guidance_scale': guidance_scale,
185 | 'sa_end_time': sa_end_time,
186 | 'ca_end_time': ca_end_time,
187 | 'edit_weight': edit_weight,
188 | }
189 | print(generate_config)
190 |
191 | # %%
192 | source_prompt_images, target_prompt_images, seed, prompt_dict = process_one_sample(
193 | prompt,
194 | **generate_config,
195 | )
196 |
197 | source_prompt_images_torch = torch.from_numpy(source_prompt_images.transpose(0, 3, 1, 2)).cuda()
198 | target_prompt_images_torch = torch.from_numpy(target_prompt_images.transpose(0, 3, 1, 2)).cuda()
199 |
200 | with torch.no_grad():
201 | sim_0, sim_1, sim_dir, sim_image = clip_sim(source_prompt_images_torch, target_prompt_images_torch, [prompt_dict['source_prompt']], [prompt_dict['target_prompt']])
202 |
203 | simi_dict = {
204 | 'sim_0': sim_0.mean().item(),
205 | 'sim_1': sim_1.mean().item(),
206 | 'sim_dir': sim_dir.mean().item(),
207 | 'sim_image': sim_image.mean().item(),
208 | }
209 |
210 | generate_config.update(simi_dict)
211 |
212 | output_folder_img = os.path.join(output_dir, output_folder_idx, 'image')
213 | output_folder_gif = os.path.join(output_dir, output_folder_idx, 'gif')
214 |
215 | if (
216 | sim_0.mean().item() > 0.2 and sim_1.mean().item() > 0.2 and sim_dir.mean().item() > 0.2 and sim_image.mean().item() > 0.5
217 | ):
218 | save_images_to_folder(source_prompt_images, target_prompt_images, output_folder_img, seed)
219 | save_images_to_gif(source_prompt_images, target_prompt_images, output_folder_gif, seed)
220 |
221 | append_dict_to_jsonl(meta_file_path, generate_config)
222 |
223 |
224 |
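225 | # Reference sketch: `ClipSimilarity` (misc_utils/clip_similarity.py, not shown here)
226 | # produces the caption/image scores used by the threshold filter above. A common choice
227 | # for the directional score `sim_dir`, following InstructPix2Pix-style data filtering,
228 | # compares the CLIP-space edit direction of the images with that of the captions. The
229 | # helper below illustrates that idea on precomputed CLIP embeddings; it is an
230 | # assumption, not the exact metric implemented in this repository.
231 | def directional_similarity_sketch(img_feat_src, img_feat_tgt, txt_feat_src, txt_feat_tgt):
232 |     import torch.nn.functional as F  # local import to keep the sketch self-contained
233 |     # cosine similarity between the image edit direction and the caption edit direction
234 |     return F.cosine_similarity(img_feat_tgt - img_feat_src, txt_feat_tgt - txt_feat_src, dim=-1)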
--------------------------------------------------------------------------------