├── .gitignore ├── LICENSE ├── README.md ├── artist ├── .DS_Store ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── samplers.cpython-38.pyc │ │ ├── tokenizers.cpython-38.pyc │ │ └── transforms.cpython-38.pyc │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── samplers.py │ ├── tokenizers.py │ └── transforms.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── clip.cpython-38.pyc │ │ └── midas.cpython-38.pyc │ ├── clip.py │ └── midas.py ├── ops │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── degration.cpython-38.pyc │ │ ├── diffusion.cpython-38.pyc │ │ ├── distributed.cpython-38.pyc │ │ ├── dpm_solver.cpython-38.pyc │ │ ├── losses.cpython-38.pyc │ │ ├── random_mask.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── degration.py │ ├── diffusion.py │ ├── distributed.py │ ├── dpm_solver.py │ ├── losses.py │ ├── random_mask.py │ └── utils.py └── optim │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── adafactor.cpython-38.pyc │ └── lr_scheduler.cpython-38.pyc │ ├── adafactor.py │ └── lr_scheduler.py ├── configs ├── base.yaml ├── exp01_vidcomposer_full.yaml ├── exp02_motion_transfer.yaml ├── exp02_motion_transfer_vs_style.yaml ├── exp03_sketch2video_style.yaml ├── exp04_sketch2video_wo_style.yaml ├── exp05_text_depths_wo_style.yaml ├── exp06_text_depths_vs_style.yaml └── exp10_vidcomposer_no_watermark_full.yaml ├── demo_video ├── blackswan.mp4 ├── captions_list.txt ├── landscape_painting.png ├── moon_on_water.jpg ├── motion_transfer.mp4 ├── qibaishi_01.png ├── src_single_sketch.png ├── style │ ├── Bingxueqiyuan.jpeg │ ├── fangao_01.jpeg │ ├── fangao_02.jpeg │ └── fangao_03.jpeg ├── sunflower.png ├── sunflower_sketch.png ├── tall_buildings.png ├── tennis.mp4 ├── video_10000178.mp4 ├── video_5360763.mp4 ├── video_8800.mp4 └── wash_painting.png ├── environment.yaml ├── gen_sketch.py ├── model_weights └── readme.md ├── run_bash.sh ├── run_net.py ├── source ├── fig01.jpg ├── fig02_framwork.jpg ├── fig03_image-to-video.jpg ├── fig04_hand-crafted-motions.jpg ├── fig05_video-inpainting.jpg ├── fig06_sketch-to-video.jpg └── results │ ├── exp02_motion_transfer-S00009.gif │ ├── exp02_motion_transfer-S09999-0.gif │ ├── exp02_motion_transfer-S09999.gif │ ├── exp03_sketch2video_style-S09999.gif │ ├── exp04_sketch2video_wo_style-S00144-1.gif │ ├── exp04_sketch2video_wo_style-S00144-2.gif │ ├── exp04_sketch2video_wo_style-S00144.gif │ ├── exp05_text_depths_wo_style-S09999-0.gif │ ├── exp05_text_depths_wo_style-S09999-1.gif │ ├── exp05_text_depths_wo_style-S09999-2.gif │ ├── exp06_text_depths_vs_style-S09999-0.gif │ ├── exp06_text_depths_vs_style-S09999-1.gif │ ├── exp06_text_depths_vs_style-S09999-2.gif │ └── exp06_text_depths_vs_style-S09999-3.gif ├── tools ├── .DS_Store ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── annotator │ ├── .DS_Store │ ├── __pycache__ │ │ └── util.cpython-38.pyc │ ├── canny │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ └── __init__.cpython-38.pyc │ ├── histogram │ │ ├── __init__.py │ │ └── palette.py │ ├── sketch │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── pidinet.cpython-38.pyc │ │ │ └── sketch_simplification.cpython-38.pyc │ │ ├── pidinet.py │ │ └── sketch_simplification.py │ └── util.py └── videocomposer │ ├── __pycache__ │ ├── autoencoder.cpython-38.pyc │ ├── config.cpython-38.pyc │ ├── datasets.cpython-38.pyc │ ├── inference_multi.cpython-38.pyc │ ├── 
inference_single.cpython-38.pyc │ ├── mha_flash.cpython-38.pyc │ └── unet_sd.cpython-38.pyc │ ├── autoencoder.py │ ├── config.py │ ├── datasets.py │ ├── inference_multi.py │ ├── inference_single.py │ ├── mha_flash.py │ └── unet_sd.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── config.cpython-38.pyc ├── distributed.cpython-38.pyc └── logging.cpython-38.pyc ├── config.py ├── distributed.py └── logging.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.pt 3 | *.mov 4 | *.pth 5 | *.json 6 | *.mov 7 | *.npz 8 | *.npy 9 | *.boj 10 | *.onnx 11 | *.tar 12 | *.bin 13 | cache* 14 | .DS_Store 15 | *DS_Store 16 | outputs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DAMO Vision Intelligence Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VideoComposer 2 | 3 | Official repo for [VideoComposer: Compositional Video Synthesis with Motion Controllability](https://arxiv.org/pdf/2306.02018.pdf) 4 | 5 | Please see [Project Page](https://videocomposer.github.io/) for more examples. 6 | 7 | We are searching for talented, motivated, and imaginative researchers to join our team. If you are interested, please don't hesitate to send us your resume via email yingya.zyy@alibaba-inc.com 8 | 9 | ![figure1](source/fig01.jpg "figure1") 10 | 11 | 12 | VideoComposer is a controllable video diffusion model, which allows users to flexibly control the spatial and temporal patterns simultaneously within a synthesized video in various forms, such as text description, sketch sequence, reference video, or even simply handcrafted motions and handrawings. 13 | 14 | 15 | ## 🔥News!!! 
16 | 17 | - __[2023.10]__ We release a high-quality I2VGen-XL model, please refer to the [Webpage](https://i2vgen-xl.github.io) 18 | - __[2023.08]__ We release the Gradio UI on [ModelScope](https://modelscope.cn/studios/damo/VideoComposer-Demo/summary) 19 | - __[2023.07]__ We release the pretrained model without watermark, please refer to the [ModelCard](https://modelscope.cn/models/damo/VideoComposer/files) 20 | 21 | 22 | 23 | ## TODO 24 | - [x] Release our technical papers and webpage. 25 | - [x] Release code and pretrained model. 26 | - [x] Release Gradio UI on [ModelScope](https://modelscope.cn/studios/damo/VideoComposer-Demo/summary) and Hugging Face. 27 | - [x] Release pretrained model that can generate 8s videos without watermark on [ModelScope](https://modelscope.cn/models/damo/VideoComposer/files) 28 | 29 | 30 | ## Method 31 | 32 | ![method](source/fig02_framwork.jpg "method") 33 | 34 | 35 | ## Running by Yourself 36 | 37 | ### 1. Installation 38 | 39 | Requirements: 40 | - Python==3.8 41 | - ffmpeg (for motion vector extraction) 42 | - torch==1.12.0+cu113 43 | - torchvision==0.13.0+cu113 44 | - open-clip-torch==2.0.2 45 | - transformers==4.18.0 46 | - flash-attn==0.2 47 | - xformers==0.0.13 48 | - motion-vector-extractor==1.0.6 (for motion vector extraction) 49 | 50 | You also can create the same environment as ours with the following command: 51 | ``` 52 | conda env create -f environment.yaml 53 | ``` 54 | 55 | ### 2. Download model weights 56 | 57 | Download all the [model weights](https://www.modelscope.cn/models/damo/VideoComposer/summary) via the following command: 58 | 59 | ``` 60 | !pip install modelscope 61 | from modelscope.hub.snapshot_download import snapshot_download 62 | model_dir = snapshot_download('damo/VideoComposer', cache_dir='model_weights/', revision='v1.0.0') 63 | ``` 64 | 65 | Next, place these models in the `model_weights` folder following the file structure shown below. 66 | 67 | 68 | ``` 69 | |--model_weights/ 70 | | |--non_ema_228000.pth 71 | | |--midas_v3_dpt_large.pth 72 | | |--open_clip_pytorch_model.bin 73 | | |--sketch_simplification_gan.pth 74 | | |--table5_pidinet.pth 75 | | |--v2-1_512-ema-pruned.ckpt 76 | ``` 77 | 78 | You can also download some of them from their original project: 79 | - "midas_v3_dpt_large.pth" in [MiDaS](https://github.com/isl-org/MiDaS) 80 | - "open_clip_pytorch_model.bin" in [Open Clip](https://github.com/mlfoundations/open_clip) 81 | - "sketch_simplification_gan.pth" and "table5_pidinet.pth" in [Pidinet](https://github.com/zhuoinoulu/pidinet) 82 | - "v2-1_512-ema-pruned.ckpt" in [Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.ckpt). 83 | 84 | For convenience, we provide a download link in this repo. 85 | 86 | 87 | ### 3. Running 88 | 89 | In this project, we provide two implementations that can help you better understand our method. 
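Before launching either of them, it can help to quickly check that the checkpoints from step 2 are actually in place. Below is a minimal sketch of such a check (not part of the original repo; the filenames follow the `model_weights` layout shown above):

```
import os.path as osp

# Checkpoints listed in the file structure above (assumed default location: model_weights/)
required = [
    'non_ema_228000.pth',
    'midas_v3_dpt_large.pth',
    'open_clip_pytorch_model.bin',
    'sketch_simplification_gan.pth',
    'table5_pidinet.pth',
    'v2-1_512-ema-pruned.ckpt',
]

missing = [name for name in required if not osp.exists(osp.join('model_weights', name))]
print('Missing weights:', missing if missing else 'none')
```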
90 | 91 | 92 | #### 3.1 Inference with Customized Inputs 93 | 94 | You can run the code with the following command: 95 | 96 | ``` 97 | python run_net.py\ 98 | --cfg configs/exp02_motion_transfer.yaml\ 99 | --seed 9999\ 100 | --input_video "demo_video/motion_transfer.mp4"\ 101 | --image_path "demo_video/moon_on_water.jpg"\ 102 | --input_text_desc "A beautiful big moon on the water at night" 103 | ``` 104 | The results are saved in the `outputs/exp02_motion_transfer-S09999` folder: 105 | 106 | ![case1](source/results/exp02_motion_transfer-S00009.gif "case2") 107 | ![case2](source/results/exp02_motion_transfer-S09999.gif "case2") 108 | 109 | 110 | In some cases, if you notice a significant change in color difference, you can use the style condition to adjust the color distribution with the following command. This can be helpful in certain cases. 111 | 112 | 113 | ``` 114 | python run_net.py\ 115 | --cfg configs/exp02_motion_transfer_vs_style.yaml\ 116 | --seed 9999\ 117 | --input_video "demo_video/motion_transfer.mp4"\ 118 | --image_path "demo_video/moon_on_water.jpg"\ 119 | --style_image "demo_video/moon_on_water.jpg"\ 120 | --input_text_desc "A beautiful big moon on the water at night" 121 | ``` 122 | 123 | 124 | ``` 125 | python run_net.py\ 126 | --cfg configs/exp03_sketch2video_style.yaml\ 127 | --seed 8888\ 128 | --sketch_path "demo_video/src_single_sketch.png"\ 129 | --style_image "demo_video/style/qibaishi_01.png"\ 130 | --input_text_desc "Red-backed Shrike lanius collurio" 131 | ``` 132 | ![case2](source/results/exp03_sketch2video_style-S09999.gif "case2") 133 | 134 | 135 | 136 | ``` 137 | python run_net.py\ 138 | --cfg configs/exp04_sketch2video_wo_style.yaml\ 139 | --seed 144\ 140 | --sketch_path "demo_video/src_single_sketch.png"\ 141 | --input_text_desc "A Red-backed Shrike lanius collurio is on the branch" 142 | ``` 143 | ![case2](source/results/exp04_sketch2video_wo_style-S00144.gif "case2") 144 | ![case2](source/results/exp04_sketch2video_wo_style-S00144-1.gif "case2") 145 | 146 | 147 | 148 | ``` 149 | python run_net.py\ 150 | --cfg configs/exp05_text_depths_wo_style.yaml\ 151 | --seed 9999\ 152 | --input_video demo_video/video_8800.mp4\ 153 | --input_text_desc "A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish" 154 | ``` 155 | ![case2](source/results/exp05_text_depths_wo_style-S09999-0.gif "case2") 156 | ![case2](source/results/exp05_text_depths_wo_style-S09999-2.gif "case2") 157 | 158 | ``` 159 | python run_net.py\ 160 | --cfg configs/exp06_text_depths_vs_style.yaml\ 161 | --seed 9999\ 162 | --input_video demo_video/video_8800.mp4\ 163 | --style_image "demo_video/style/qibaishi_01.png"\ 164 | --input_text_desc "A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish" 165 | ``` 166 | 167 | ![case2](source/results/exp06_text_depths_vs_style-S09999-0.gif "case2") 168 | ![case2](source/results/exp06_text_depths_vs_style-S09999-1.gif "case2") 169 | 170 | 171 | #### 3.2 Inference on a Video 172 | 173 | You can just run the code with the following command: 174 | ``` 175 | python run_net.py \ 176 | --cfg configs/exp01_vidcomposer_full.yaml \ 177 | --input_video "demo_video/blackswan.mp4" \ 178 | --input_text_desc "A black swan swam in the water" \ 179 | --seed 9999 180 | ``` 181 | 182 | This command will extract the different conditions, e.g., depth, sketch, and motion vectors, of the input video for the following video generation, which are saved in the 
`outputs` folder. The task list are predefined in inference_multi.py. 183 | 184 | 185 | 186 | In addition to the above use cases, you can explore further possibilities with this code and model. Please note that due to the diversity of generated samples by the diffusion model, you can explore different seeds to generate better results. 187 | 188 | We hope you enjoy using it! 😀 189 | 190 | 191 | 192 | ## BibTeX 193 | 194 | If this repo is useful to you, please cite our technical paper. 195 | ```bibtex 196 | @article{2023videocomposer, 197 | title={VideoComposer: Compositional Video Synthesis with Motion Controllability}, 198 | author={Wang, Xiang* and Yuan, Hangjie* and Zhang, Shiwei* and Chen, Dayou* and Wang, Jiuniu, and Zhang, Yingya, and Shen, Yujun, and Zhao, Deli and Zhou, Jingren}, 199 | booktitle={arXiv preprint arXiv:2306.02018}, 200 | year={2023} 201 | } 202 | ``` 203 | 204 | 205 | ## Acknowledgement 206 | 207 | We would like to express our gratitude for the contributions of several previous works to the development of VideoComposer. This includes, but is not limited to [Composer](https://arxiv.org/abs/2302.09778), [ModelScopeT2V](https://modelscope.cn/models/damo/text-to-video-synthesis/summary), [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [OpenCLIP](https://github.com/mlfoundations/open_clip), [WebVid-10M](https://m-bain.github.io/webvid-dataset/), [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/), [Pidinet](https://github.com/zhuoinoulu/pidinet) and [MiDaS](https://github.com/isl-org/MiDaS). We are committed to building upon these foundations in a way that respects their original contributions. 208 | 209 | 210 | ## Disclaimer 211 | 212 | This open-source model is trained on the [WebVid-10M](https://m-bain.github.io/webvid-dataset/) and [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/) datasets and is intended for RESEARCH/NON-COMMERCIAL USE ONLY. We have also trained more powerful models using internal video data, which can be used in the future. 213 | -------------------------------------------------------------------------------- /artist/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/.DS_Store -------------------------------------------------------------------------------- /artist/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import torch.distributed as dist 5 | import oss2 as oss 6 | 7 | __all__ = ['DOWNLOAD_TO_CACHE'] 8 | 9 | 10 | def DOWNLOAD_TO_CACHE(oss_key, 11 | file_or_dirname=None, 12 | cache_dir=osp.join('/'.join(osp.abspath(__file__).split('/')[:-2]), 'model_weights')): 13 | r"""Download OSS [file or folder] to the cache folder. 14 | Only the 0th process on each node will run the downloading. 15 | Barrier all processes until the downloading is completed. 
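Note: in this open-source release the OSS download itself is omitted; the function only resolves and returns the corresponding local path under the `model_weights` cache folder.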
16 | """ 17 | # source and target paths 18 | base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key)) 19 | 20 | return base_path 21 | -------------------------------------------------------------------------------- /artist/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /artist/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .samplers import * 2 | from .tokenizers import * 3 | from .transforms import * 4 | -------------------------------------------------------------------------------- /artist/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /artist/data/__pycache__/samplers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/samplers.cpython-38.pyc -------------------------------------------------------------------------------- /artist/data/__pycache__/tokenizers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/tokenizers.cpython-38.pyc -------------------------------------------------------------------------------- /artist/data/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /artist/data/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /artist/data/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os.path as osp 4 | from torch.utils.data.sampler import Sampler 5 | 6 | from ..ops.distributed import get_rank, get_world_size, shared_random_seed 7 | from ..ops.utils import ceil_divide, read 8 | 9 | __all__ = ['BatchSampler','GroupSampler','ImgGroupSampler'] 10 | 11 | class BatchSampler(Sampler): 12 | r"""An infinite batch sampler. 
13 | """ 14 | def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=False, seed=None): 15 | self.dataset_size = dataset_size 16 | self.batch_size = batch_size 17 | self.num_replicas = num_replicas or get_world_size() 18 | self.rank = rank or get_rank() 19 | self.shuffle = shuffle 20 | self.seed = seed or shared_random_seed() 21 | self.rng = np.random.default_rng(self.seed + self.rank) 22 | self.batches_per_rank = ceil_divide(dataset_size, self.num_replicas * self.batch_size) 23 | self.samples_per_rank = self.batches_per_rank * self.batch_size 24 | 25 | # rank indices 26 | indices = self.rng.permutation(self.samples_per_rank) if shuffle else np.arange(self.samples_per_rank) 27 | indices = indices * self.num_replicas + self.rank 28 | indices = indices[indices < dataset_size] 29 | self.indices = indices 30 | 31 | def __iter__(self): 32 | start = 0 33 | while True: 34 | batch = [self.indices[i % len(self.indices)] for i in range(start, start + self.batch_size)] 35 | if self.shuffle and (start + self.batch_size) > len(self.indices): 36 | self.rng.shuffle(self.indices) 37 | start = (start + self.batch_size) % len(self.indices) 38 | yield batch 39 | 40 | class GroupSampler(Sampler): 41 | 42 | def __init__(self, group_file, batch_size, alpha=0.7, update_interval=5000, seed=8888): 43 | self.group_file = group_file 44 | self.group_folder = osp.join(osp.dirname(group_file), 'groups') 45 | self.batch_size = batch_size 46 | self.alpha = alpha 47 | self.update_interval = update_interval 48 | self.seed = seed 49 | self.rng = np.random.default_rng(seed) 50 | 51 | def __iter__(self): 52 | while True: 53 | # keep groups up-to-date 54 | self.update_groups() 55 | 56 | # collect items 57 | items = self.sample() 58 | while len(items) < self.batch_size: 59 | items += self.sample() 60 | 61 | # sample a batch 62 | batch = self.rng.choice(items, self.batch_size, replace=False if len(items) >= self.batch_size else True) 63 | yield [u.strip().split(',') for u in batch] 64 | 65 | def update_groups(self): 66 | if not hasattr(self, '_step'): 67 | self._step = 0 68 | if self._step % self.update_interval == 0: 69 | self.groups = json.loads(read(self.group_file)) 70 | self._step += 1 71 | 72 | def sample(self): 73 | scales = np.array([float(next(iter(u)).split(':')[-1]) for u in self.groups]) 74 | p = scales ** self.alpha / (scales ** self.alpha).sum() 75 | group = self.rng.choice(self.groups, p=p) 76 | list_file = osp.join(self.group_folder, self.rng.choice(next(iter(group.values())))) 77 | return read(list_file).strip().split('\n') 78 | 79 | class ImgGroupSampler(Sampler): 80 | 81 | def __init__(self, group_file, batch_size, alpha=0.7, update_interval=5000, seed=8888): 82 | self.group_file = group_file 83 | self.group_folder = osp.join(osp.dirname(group_file), 'groups') 84 | self.batch_size = batch_size 85 | self.alpha = alpha 86 | self.update_interval = update_interval 87 | self.seed = seed 88 | self.rng = np.random.default_rng(seed) 89 | 90 | def __iter__(self): 91 | while True: 92 | # keep groups up-to-date 93 | self.update_groups() 94 | 95 | # collect items 96 | items = self.sample() 97 | while len(items) < self.batch_size: 98 | items += self.sample() 99 | 100 | # sample a batch 101 | batch = self.rng.choice(items, self.batch_size, replace=False if len(items) >= self.batch_size else True) 102 | yield [u.strip().split(',', 1) for u in batch] 103 | 104 | def update_groups(self): 105 | if not hasattr(self, '_step'): 106 | self._step = 0 107 | if self._step % self.update_interval == 0: 108 | 
self.groups = json.loads(read(self.group_file)) 109 | 110 | self._step += 1 111 | 112 | def sample(self): 113 | scales = np.array([float(next(iter(u)).split(':')[-1]) for u in self.groups]) 114 | p = scales ** self.alpha / (scales ** self.alpha).sum() 115 | group = self.rng.choice(self.groups, p=p) 116 | list_file = osp.join(self.group_folder, self.rng.choice(next(iter(group.values())))) 117 | return read(list_file).strip().split('\n') -------------------------------------------------------------------------------- /artist/data/tokenizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import html 4 | import ftfy 5 | import regex as re 6 | import torch 7 | from tokenizers import CharBPETokenizer, BertWordPieceTokenizer 8 | from functools import lru_cache 9 | 10 | __all__ = ['CLIPTokenizer'] 11 | 12 | #-------------------------------- CLIPTokenizer --------------------------------# 13 | 14 | @lru_cache() 15 | def default_bpe(): 16 | root = os.path.realpath(__file__) 17 | root = '/'.join(root.split('/')[:-1]) 18 | return os.path.join(root, 'bpe_simple_vocab_16e6.txt.gz') 19 | 20 | @lru_cache() 21 | def bytes_to_unicode(): 22 | """ 23 | Returns list of utf-8 byte and a corresponding list of unicode strings. 24 | The reversible bpe codes work on unicode strings. 25 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 26 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 27 | This is a signficant percentage of your normal, say, 32K bpe vocab. 28 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 29 | And avoids mapping to whitespace/control characters the bpe code barfs on. 30 | """ 31 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 32 | cs = bs[:] 33 | n = 0 34 | for b in range(2**8): 35 | if b not in bs: 36 | bs.append(b) 37 | cs.append(2**8+n) 38 | n += 1 39 | cs = [chr(n) for n in cs] 40 | return dict(zip(bs, cs)) 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings). 
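For example, ('l', 'o', 'w') yields {('l', 'o'), ('o', 'w')}.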
45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | def basic_clean(text): 54 | text = ftfy.fix_text(text) 55 | text = html.unescape(html.unescape(text)) 56 | return text.strip() 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | class SimpleTokenizer(object): 64 | 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | class CLIPTokenizer(object): 137 | 138 | def __init__(self, length=77): 139 | self.length = length 140 | 141 | # init tokenizer 142 | self.tokenizer = SimpleTokenizer(bpe_path=default_bpe()) 143 | self.sos_token = self.tokenizer.encoder['<|startoftext|>'] 144 | self.eos_token = self.tokenizer.encoder['<|endoftext|>'] 145 | self.vocab_size = len(self.tokenizer.encoder) 146 | 147 | def __call__(self, sequence): 148 | if isinstance(sequence, str): 149 | return torch.LongTensor(self._tokenizer(sequence)) 150 | elif isinstance(sequence, list): 151 | return 
torch.LongTensor([self._tokenizer(u) for u in sequence]) 152 | else: 153 | raise TypeError(f'Expected the "sequence" to be a string or a list, but got {type(sequence)}') 154 | 155 | def _tokenizer(self, text): 156 | tokens = self.tokenizer.encode(text)[:self.length - 2] 157 | tokens = [self.sos_token] + tokens + [self.eos_token] 158 | tokens = tokens + [0] * (self.length - len(tokens)) 159 | return tokens 160 | -------------------------------------------------------------------------------- /artist/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import random 4 | import math 5 | import numpy as np 6 | from PIL import Image, ImageFilter 7 | 8 | __all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\ 9 | 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"] 10 | 11 | # class Compose(object): 12 | 13 | # def __init__(self, transforms): 14 | # self.transforms = transforms 15 | 16 | # def __call__(self, rgb): 17 | # for t in self.transforms: 18 | # rgb = t(rgb) 19 | # return rgb 20 | class Compose(object): 21 | 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __getitem__(self, index): 26 | if isinstance(index, slice): 27 | return Compose(self.transforms[index]) 28 | else: 29 | return self.transforms[index] 30 | 31 | def __len__(self): 32 | return len(self.transforms) 33 | 34 | def __call__(self, rgb): 35 | for t in self.transforms: 36 | rgb = t(rgb) 37 | return rgb 38 | 39 | class Resize(object): 40 | 41 | def __init__(self, size=256): 42 | if isinstance(size, int): 43 | size = (size, size) 44 | self.size = size 45 | 46 | def __call__(self, rgb): 47 | 48 | rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] 49 | return rgb 50 | 51 | class Rescale(object): 52 | 53 | def __init__(self, size=256, interpolation=Image.BILINEAR): 54 | self.size = size 55 | self.interpolation = interpolation 56 | 57 | def __call__(self, rgb): 58 | w, h = rgb[0].size 59 | scale = self.size / min(w, h) 60 | out_w, out_h = int(round(w * scale)), int(round(h * scale)) 61 | rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] 62 | return rgb 63 | 64 | class CenterCrop(object): 65 | 66 | def __init__(self, size=224): 67 | self.size = size 68 | 69 | def __call__(self, rgb): 70 | w, h = rgb[0].size 71 | assert min(w, h) >= self.size 72 | x1 = (w - self.size) // 2 73 | y1 = (h - self.size) // 2 74 | rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] 75 | return rgb 76 | 77 | class ResizeRandomCrop(object): 78 | 79 | def __init__(self, size=256, size_short=292): 80 | self.size = size 81 | # self.min_area = min_area 82 | self.size_short = size_short 83 | 84 | def __call__(self, rgb): 85 | 86 | # consistent crop between rgb and m 87 | while min(rgb[0].size) >= 2 * self.size_short: 88 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] 89 | scale = self.size_short / min(rgb[0].size) 90 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] 91 | out_w = self.size 92 | out_h = self.size 93 | w, h = rgb[0].size # (518, 292) 94 | x1 = random.randint(0, w - out_w) 95 | y1 = random.randint(0, h - out_h) 96 | 97 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] 98 | # rgb = [u.resize((self.size, 
self.size), Image.BILINEAR) for u in rgb] 99 | # # center crop 100 | # x1 = (img[0].width - self.size) // 2 101 | # y1 = (img[0].height - self.size) // 2 102 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] 103 | return rgb 104 | 105 | 106 | 107 | class ExtractResizeRandomCrop(object): 108 | 109 | def __init__(self, size=256, size_short=292): 110 | self.size = size 111 | # self.min_area = min_area 112 | self.size_short = size_short 113 | 114 | def __call__(self, rgb): 115 | 116 | # consistent crop between rgb and m 117 | while min(rgb[0].size) >= 2 * self.size_short: 118 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] 119 | scale = self.size_short / min(rgb[0].size) 120 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] 121 | out_w = self.size 122 | out_h = self.size 123 | w, h = rgb[0].size # (518, 292) 124 | x1 = random.randint(0, w - out_w) 125 | y1 = random.randint(0, h - out_h) 126 | 127 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] 128 | # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] 129 | # # center crop 130 | # x1 = (img[0].width - self.size) // 2 131 | # y1 = (img[0].height - self.size) // 2 132 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] 133 | wh = [x1, y1, x1 + out_w, y1 + out_h] 134 | return rgb, wh 135 | 136 | 137 | 138 | 139 | class ExtractResizeAssignCrop(object): 140 | 141 | def __init__(self, size=256, size_short=292): 142 | self.size = size 143 | # self.min_area = min_area 144 | self.size_short = size_short 145 | 146 | def __call__(self, rgb, wh): 147 | 148 | # consistent crop between rgb and m 149 | while min(rgb[0].size) >= 2 * self.size_short: 150 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] 151 | scale = self.size_short / min(rgb[0].size) 152 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] 153 | # out_w = self.size 154 | # out_h = self.size 155 | # w, h = rgb[0].size # (518, 292) 156 | # x1 = random.randint(0, w - out_w) 157 | # y1 = random.randint(0, h - out_h) 158 | 159 | rgb = [u.crop(wh) for u in rgb] 160 | rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] 161 | # # center crop 162 | # x1 = (img[0].width - self.size) // 2 163 | # y1 = (img[0].height - self.size) // 2 164 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] 165 | # wh = [x1, y1, x1 + out_w, y1 + out_h] 166 | return rgb 167 | 168 | class CenterCropV2(object): 169 | def __init__(self, size): 170 | self.size = size 171 | 172 | def __call__(self, img): 173 | # fast resize 174 | while min(img[0].size) >= 2 * self.size: 175 | img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img] 176 | scale = self.size / min(img[0].size) 177 | img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img] 178 | 179 | # center crop 180 | x1 = (img[0].width - self.size) // 2 181 | y1 = (img[0].height - self.size) // 2 182 | img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] 183 | return img 184 | 185 | class RandomCrop(object): 186 | 187 | def __init__(self, size=224, min_area=0.4): 188 | self.size = size 189 | self.min_area = min_area 190 | 191 | def __call__(self, rgb): 192 | 193 | # consistent crop between rgb and m 194 | w, h = rgb[0].size 195 | area = w * h 196 | out_w, out_h = float('inf'), float('inf') 197 | 
while out_w > w or out_h > h: 198 | target_area = random.uniform(self.min_area, 1.0) * area 199 | aspect_ratio = random.uniform(3. / 4., 4. / 3.) 200 | out_w = int(round(math.sqrt(target_area * aspect_ratio))) 201 | out_h = int(round(math.sqrt(target_area / aspect_ratio))) 202 | x1 = random.randint(0, w - out_w) 203 | y1 = random.randint(0, h - out_h) 204 | 205 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] 206 | rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] 207 | 208 | return rgb 209 | 210 | class RandomCropV2(object): 211 | 212 | def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)): 213 | if isinstance(size, (tuple, list)): 214 | self.size = size 215 | else: 216 | self.size = (size, size) 217 | self.min_area = min_area 218 | self.ratio = ratio 219 | 220 | def _get_params(self, img): 221 | width, height = img.size 222 | area = height * width 223 | 224 | for _ in range(10): 225 | target_area = random.uniform(self.min_area, 1.0) * area 226 | log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) 227 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 228 | 229 | w = int(round(math.sqrt(target_area * aspect_ratio))) 230 | h = int(round(math.sqrt(target_area / aspect_ratio))) 231 | 232 | if 0 < w <= width and 0 < h <= height: 233 | i = random.randint(0, height - h) 234 | j = random.randint(0, width - w) 235 | return i, j, h, w 236 | 237 | # Fallback to central crop 238 | in_ratio = float(width) / float(height) 239 | if (in_ratio < min(self.ratio)): 240 | w = width 241 | h = int(round(w / min(self.ratio))) 242 | elif (in_ratio > max(self.ratio)): 243 | h = height 244 | w = int(round(h * max(self.ratio))) 245 | else: # whole image 246 | w = width 247 | h = height 248 | i = (height - h) // 2 249 | j = (width - w) // 2 250 | return i, j, h, w 251 | 252 | def __call__(self, rgb): 253 | i, j, h, w = self._get_params(rgb[0]) 254 | rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] 255 | return rgb 256 | 257 | class RandomHFlip(object): 258 | 259 | def __init__(self, p=0.5): 260 | self.p = p 261 | 262 | def __call__(self, rgb): 263 | if random.random() < self.p: 264 | rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] 265 | return rgb 266 | 267 | class GaussianBlur(object): 268 | 269 | def __init__(self, sigmas=[0.1, 2.0], p=0.5): 270 | self.sigmas = sigmas 271 | self.p = p 272 | 273 | def __call__(self, rgb): 274 | if random.random() < self.p: 275 | sigma = random.uniform(*self.sigmas) 276 | rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb] 277 | return rgb 278 | 279 | class ColorJitter(object): 280 | 281 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5): 282 | self.brightness = brightness 283 | self.contrast = contrast 284 | self.saturation = saturation 285 | self.hue = hue 286 | self.p = p 287 | 288 | def __call__(self, rgb): 289 | if random.random() < self.p: 290 | brightness, contrast, saturation, hue = self._random_params() 291 | transforms = [ 292 | lambda f: F.adjust_brightness(f, brightness), 293 | lambda f: F.adjust_contrast(f, contrast), 294 | lambda f: F.adjust_saturation(f, saturation), 295 | lambda f: F.adjust_hue(f, hue)] 296 | random.shuffle(transforms) 297 | for t in transforms: 298 | rgb = [t(u) for u in rgb] 299 | 300 | return rgb 301 | 302 | def _random_params(self): 303 | brightness = random.uniform( 304 | max(0, 1 - self.brightness), 1 + self.brightness) 305 | contrast = random.uniform( 306 | max(0, 1 - self.contrast), 1 + self.contrast) 307 | 
saturation = random.uniform( 308 | max(0, 1 - self.saturation), 1 + self.saturation) 309 | hue = random.uniform(-self.hue, self.hue) 310 | return brightness, contrast, saturation, hue 311 | 312 | class RandomGray(object): 313 | 314 | def __init__(self, p=0.2): 315 | self.p = p 316 | 317 | def __call__(self, rgb): 318 | if random.random() < self.p: 319 | rgb = [u.convert('L').convert('RGB') for u in rgb] 320 | return rgb 321 | 322 | class ToTensor(object): 323 | 324 | def __call__(self, rgb): 325 | rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) 326 | return rgb 327 | 328 | class Normalize(object): 329 | 330 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 331 | self.mean = mean 332 | self.std = std 333 | 334 | def __call__(self, rgb): 335 | rgb = rgb.clone() 336 | rgb.clamp_(0, 1) 337 | if not isinstance(self.mean, torch.Tensor): 338 | self.mean = rgb.new_tensor(self.mean).view(-1) 339 | if not isinstance(self.std, torch.Tensor): 340 | self.std = rgb.new_tensor(self.std).view(-1) 341 | rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1)) 342 | return rgb 343 | 344 | -------------------------------------------------------------------------------- /artist/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .midas import * 3 | 4 | -------------------------------------------------------------------------------- /artist/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /artist/models/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /artist/models/__pycache__/midas.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/midas.cpython-38.pyc -------------------------------------------------------------------------------- /artist/models/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | import artist.ops as ops # for using differentiable all_gather 7 | from artist import DOWNLOAD_TO_CACHE 8 | 9 | __all__ = ['CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', 'clip_vit_l_14_336px', 'clip_vit_h_16'] 10 | 11 | def to_fp16(m): 12 | if isinstance(m, (nn.Linear, nn.Conv2d)): 13 | m.weight.data = m.weight.data.half() 14 | if m.bias is not None: 15 | m.bias.data = m.bias.data.half() 16 | elif hasattr(m, 'head'): 17 | p = getattr(m, 'head') 18 | p.data = p.data.half() 19 | 20 | class QuickGELU(nn.Module): 21 | 22 | def forward(self, x): 23 | return x * torch.sigmoid(1.702 * x) 24 | 25 | class LayerNorm(nn.LayerNorm): 26 | r"""Subclass of nn.LayerNorm to handle fp16. 
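The normalization is computed in fp32 and the result is cast back to the input dtype.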
27 | """ 28 | def forward(self, x): 29 | return super(LayerNorm, self).forward(x.float()).type_as(x) 30 | 31 | class SelfAttention(nn.Module): 32 | 33 | def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): 34 | assert dim % num_heads == 0 35 | super(SelfAttention, self).__init__() 36 | self.dim = dim 37 | self.num_heads = num_heads 38 | self.head_dim = dim // num_heads 39 | self.scale = 1.0 / math.sqrt(self.head_dim) 40 | 41 | # layers 42 | self.to_qkv = nn.Linear(dim, dim * 3) 43 | self.attn_dropout = nn.Dropout(attn_dropout) 44 | self.proj = nn.Linear(dim, dim) 45 | self.proj_dropout = nn.Dropout(proj_dropout) 46 | 47 | def forward(self, x, mask=None): 48 | r"""x: [B, L, C]. 49 | mask: [*, L, L]. 50 | """ 51 | b, l, c, n = *x.size(), self.num_heads 52 | 53 | # compute query, key, and value 54 | q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) 55 | q = q.reshape(l, b * n, -1).transpose(0, 1) 56 | k = k.reshape(l, b * n, -1).transpose(0, 1) 57 | v = v.reshape(l, b * n, -1).transpose(0, 1) 58 | 59 | # compute attention 60 | attn = self.scale * torch.bmm(q, k.transpose(1, 2)) 61 | if mask is not None: 62 | attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) 63 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 64 | attn = self.attn_dropout(attn) 65 | 66 | # gather context 67 | x = torch.bmm(attn, v) 68 | x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) 69 | 70 | # output 71 | x = self.proj(x) 72 | x = self.proj_dropout(x) 73 | return x 74 | 75 | class AttentionBlock(nn.Module): 76 | 77 | def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): 78 | super(AttentionBlock, self).__init__() 79 | self.dim = dim 80 | self.num_heads = num_heads 81 | 82 | # layers 83 | self.norm1 = LayerNorm(dim) 84 | self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) 85 | self.norm2 = LayerNorm(dim) 86 | self.mlp = nn.Sequential( 87 | nn.Linear(dim, dim * 4), 88 | QuickGELU(), 89 | nn.Linear(dim * 4, dim), 90 | nn.Dropout(proj_dropout)) 91 | 92 | def forward(self, x, mask=None): 93 | x = x + self.attn(self.norm1(x), mask) 94 | x = x + self.mlp(self.norm2(x)) 95 | return x 96 | 97 | class VisionTransformer(nn.Module): 98 | 99 | def __init__(self, 100 | image_size=224, 101 | patch_size=16, 102 | dim=768, 103 | out_dim=512, 104 | num_heads=12, 105 | num_layers=12, 106 | attn_dropout=0.0, 107 | proj_dropout=0.0, 108 | embedding_dropout=0.0): 109 | assert image_size % patch_size == 0 110 | super(VisionTransformer, self).__init__() 111 | self.image_size = image_size 112 | self.patch_size = patch_size 113 | self.dim = dim 114 | self.out_dim = out_dim 115 | self.num_heads = num_heads 116 | self.num_layers = num_layers 117 | self.num_patches = (image_size // patch_size) ** 2 118 | 119 | # embeddings 120 | gain = 1.0 / math.sqrt(dim) 121 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=False) 122 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 123 | self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + 1, dim)) 124 | self.dropout = nn.Dropout(embedding_dropout) 125 | 126 | # transformer 127 | self.pre_norm = LayerNorm(dim) 128 | self.transformer = nn.Sequential(*[ 129 | AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) 130 | for _ in range(num_layers)]) 131 | self.post_norm = LayerNorm(dim) 132 | 133 | # head 134 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) 135 | 136 | def forward(self, x): 137 | b, dtype = x.size(0), self.head.dtype 138 
| x = x.type(dtype) 139 | 140 | # patch-embedding 141 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] 142 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], dim=1) 143 | x = self.dropout(x + self.pos_embedding.type(dtype)) 144 | x = self.pre_norm(x) 145 | 146 | # transformer 147 | x = self.transformer(x) 148 | 149 | # head 150 | x = self.post_norm(x) 151 | x = torch.mm(x[:, 0, :], self.head) 152 | return x 153 | 154 | def fp16(self): 155 | return self.apply(to_fp16) 156 | 157 | class TextTransformer(nn.Module): 158 | 159 | def __init__(self, 160 | vocab_size, 161 | text_len, 162 | dim=512, 163 | out_dim=512, 164 | num_heads=8, 165 | num_layers=12, 166 | attn_dropout=0.0, 167 | proj_dropout=0.0, 168 | embedding_dropout=0.0): 169 | super(TextTransformer, self).__init__() 170 | self.vocab_size = vocab_size 171 | self.text_len = text_len 172 | self.dim = dim 173 | self.out_dim = out_dim 174 | self.num_heads = num_heads 175 | self.num_layers = num_layers 176 | 177 | # embeddings 178 | self.token_embedding = nn.Embedding(vocab_size, dim) 179 | self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) 180 | self.dropout = nn.Dropout(embedding_dropout) 181 | 182 | # transformer 183 | self.transformer = nn.ModuleList([ 184 | AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) 185 | for _ in range(num_layers)]) 186 | self.norm = LayerNorm(dim) 187 | 188 | # head 189 | gain = 1.0 / math.sqrt(dim) 190 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) 191 | 192 | # causal attention mask 193 | self.register_buffer('attn_mask', torch.tril(torch.ones(1, text_len, text_len))) 194 | 195 | def forward(self, x): 196 | eot, dtype = x.argmax(dim=-1), self.head.dtype 197 | 198 | # embeddings 199 | x = self.dropout(self.token_embedding(x).type(dtype) + self.pos_embedding.type(dtype)) 200 | 201 | # transformer 202 | for block in self.transformer: 203 | x = block(x, self.attn_mask) 204 | 205 | # head 206 | x = self.norm(x) 207 | x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) 208 | return x 209 | 210 | def fp16(self): 211 | return self.apply(to_fp16) 212 | 213 | class CLIP(nn.Module): 214 | 215 | def __init__(self, 216 | embed_dim=512, 217 | image_size=224, 218 | patch_size=16, 219 | vision_dim=768, 220 | vision_heads=12, 221 | vision_layers=12, 222 | vocab_size=49408, 223 | text_len=77, 224 | text_dim=512, 225 | text_heads=8, 226 | text_layers=12, 227 | attn_dropout=0.0, 228 | proj_dropout=0.0, 229 | embedding_dropout=0.0): 230 | super(CLIP, self).__init__() 231 | self.embed_dim = embed_dim 232 | self.image_size = image_size 233 | self.patch_size = patch_size 234 | self.vision_dim = vision_dim 235 | self.vision_heads = vision_heads 236 | self.vision_layers = vision_layers 237 | self.vocab_size = vocab_size 238 | self.text_len = text_len 239 | self.text_dim = text_dim 240 | self.text_heads = text_heads 241 | self.text_layers = text_layers 242 | 243 | # models 244 | self.visual = VisionTransformer( 245 | image_size=image_size, 246 | patch_size=patch_size, 247 | dim=vision_dim, 248 | out_dim=embed_dim, 249 | num_heads=vision_heads, 250 | num_layers=vision_layers, 251 | attn_dropout=attn_dropout, 252 | proj_dropout=proj_dropout, 253 | embedding_dropout=embedding_dropout) 254 | self.textual = TextTransformer( 255 | vocab_size=vocab_size, 256 | text_len=text_len, 257 | dim=text_dim, 258 | out_dim=embed_dim, 259 | num_heads=text_heads, 260 | num_layers=text_layers, 261 | attn_dropout=attn_dropout, 262 | proj_dropout=proj_dropout, 263 | 
embedding_dropout=embedding_dropout) 264 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) 265 | 266 | def forward(self, imgs, txt_tokens): 267 | r"""imgs: [B, C, H, W] of torch.float32. 268 | txt_tokens: [B, T] of torch.long. 269 | """ 270 | xi = self.visual(imgs) 271 | xt = self.textual(txt_tokens) 272 | 273 | # normalize features 274 | xi = F.normalize(xi, p=2, dim=1) 275 | xt = F.normalize(xt, p=2, dim=1) 276 | 277 | # gather features from all ranks 278 | full_xi = ops.diff_all_gather(xi) 279 | full_xt = ops.diff_all_gather(xt) 280 | 281 | # logits 282 | scale = self.log_scale.exp() 283 | logits_i2t = scale * torch.mm(xi, full_xt.t()) 284 | logits_t2i = scale * torch.mm(xt, full_xi.t()) 285 | 286 | # labels 287 | labels = torch.arange( 288 | len(xi) * ops.get_rank(), 289 | len(xi) * (ops.get_rank() + 1), 290 | dtype=torch.long, 291 | device=xi.device) 292 | return logits_i2t, logits_t2i, labels 293 | 294 | def init_weights(self): 295 | # embeddings 296 | nn.init.normal_(self.textual.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) 298 | 299 | # attentions 300 | for modality in ['visual', 'textual']: 301 | dim = self.vision_dim if modality == 'visual' else 'textual' 302 | transformer = getattr(self, modality).transformer 303 | proj_gain = (1.0 / math.sqrt(dim)) * (1.0 / math.sqrt(2 * transformer.num_layers)) 304 | attn_gain = 1.0 / math.sqrt(dim) 305 | mlp_gain = 1.0 / math.sqrt(2.0 * dim) 306 | for block in transformer.layers: 307 | nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) 308 | nn.init.normal_(block.attn.proj.weight, std=proj_gain) 309 | nn.init.normal_(block.mlp[0].weight, std=mlp_gain) 310 | nn.init.normal_(block.mlp[2].weight, std=proj_gain) 311 | 312 | def param_groups(self): 313 | groups = [ 314 | {'params': [p for n, p in self.named_parameters() if 'norm' in n or n.endswith('bias')], 'weight_decay': 0.0}, 315 | {'params': [p for n, p in self.named_parameters() if not ('norm' in n or n.endswith('bias'))]}] 316 | return groups 317 | 318 | def fp16(self): 319 | return self.apply(to_fp16) 320 | 321 | def _clip(name, pretrained=False, **kwargs): 322 | model = CLIP(**kwargs) 323 | if pretrained: 324 | model.load_state_dict(torch.load(DOWNLOAD_TO_CACHE(f'models/clip/{name}.pth'), map_location='cpu')) 325 | return model 326 | 327 | def clip_vit_b_32(pretrained=False, **kwargs): 328 | cfg = dict( 329 | embed_dim=512, 330 | image_size=224, 331 | patch_size=32, 332 | vision_dim=768, 333 | vision_heads=12, 334 | vision_layers=12, 335 | vocab_size=49408, 336 | text_len=77, 337 | text_dim=512, 338 | text_heads=8, 339 | text_layers=12) 340 | cfg.update(**kwargs) 341 | return _clip('openai-clip-vit-base-32', pretrained, **cfg) 342 | 343 | def clip_vit_b_16(pretrained=False, **kwargs): 344 | cfg = dict( 345 | embed_dim=512, 346 | image_size=224, 347 | patch_size=32, 348 | vision_dim=768, 349 | vision_heads=12, 350 | vision_layers=12, 351 | vocab_size=49408, 352 | text_len=77, 353 | text_dim=512, 354 | text_heads=8, 355 | text_layers=12) 356 | cfg.update(**kwargs) 357 | return _clip('openai-clip-vit-base-16', pretrained, **cfg) 358 | 359 | def clip_vit_l_14(pretrained=False, **kwargs): 360 | cfg = dict( 361 | embed_dim=768, 362 | image_size=224, 363 | patch_size=14, 364 | vision_dim=1024, 365 | vision_heads=16, 366 | vision_layers=24, 367 | vocab_size=49408, 368 | text_len=77, 369 | text_dim=768, 370 | text_heads=12, 371 | text_layers=12) 372 | cfg.update(**kwargs) 373 | return _clip('openai-clip-vit-large-14', 
pretrained, **cfg) 374 | 375 | def clip_vit_l_14_336px(pretrained=False, **kwargs): 376 | cfg = dict( 377 | embed_dim=768, 378 | image_size=336, 379 | patch_size=14, 380 | vision_dim=1024, 381 | vision_heads=16, 382 | vision_layers=24, 383 | vocab_size=49408, 384 | text_len=77, 385 | text_dim=768, 386 | text_heads=12, 387 | text_layers=12) 388 | cfg.update(**kwargs) 389 | return _clip('openai-clip-vit-large-14-336px', pretrained, **cfg) 390 | 391 | def clip_vit_h_16(pretrained=False, **kwargs): 392 | assert not pretrained, 'pretrained model for openai-clip-vit-huge-16 is not available!' 393 | cfg = dict( 394 | embed_dim=1024, 395 | image_size=256, 396 | patch_size=16, 397 | vision_dim=1280, 398 | vision_heads=16, 399 | vision_layers=32, 400 | vocab_size=49408, 401 | text_len=77, 402 | text_dim=1024, 403 | text_heads=16, 404 | text_layers=24) 405 | cfg.update(**kwargs) 406 | return _clip('openai-clip-vit-huge-16', pretrained, **cfg) 407 | -------------------------------------------------------------------------------- /artist/models/midas.py: -------------------------------------------------------------------------------- 1 | r"""A much cleaner re-implementation of ``https://github.com/isl-org/MiDaS''. 2 | Image augmentation: T.Compose([ 3 | Resize( 4 | keep_aspect_ratio=True, 5 | ensure_multiple_of=32, 6 | interpolation=cv2.INTER_CUBIC), 7 | T.ToTensor(), 8 | T.Normalize( 9 | mean=[0.5, 0.5, 0.5], 10 | std=[0.5, 0.5, 0.5])]). 11 | Fast inference: 12 | model = model.to(memory_format=torch.channels_last).half() 13 | input = input.to(memory_format=torch.channels_last).half() 14 | output = model(input) 15 | """ 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import math 20 | 21 | from artist import DOWNLOAD_TO_CACHE 22 | 23 | __all__ = ['MiDaS', 'midas_v3'] 24 | 25 | class SelfAttention(nn.Module): 26 | 27 | def __init__(self, dim, num_heads): 28 | assert dim % num_heads == 0 29 | super(SelfAttention, self).__init__() 30 | self.dim = dim 31 | self.num_heads = num_heads 32 | self.head_dim = dim // num_heads 33 | self.scale = 1.0 / math.sqrt(self.head_dim) 34 | 35 | # layers 36 | self.to_qkv = nn.Linear(dim, dim * 3) 37 | self.proj = nn.Linear(dim, dim) 38 | 39 | def forward(self, x): 40 | b, l, c, n, d = *x.size(), self.num_heads, self.head_dim 41 | 42 | # compute query, key, value 43 | q, k, v = self.to_qkv(x).view(b, l, n * 3, d).chunk(3, dim=2) 44 | 45 | # compute attention 46 | attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k) 47 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 48 | 49 | # gather context 50 | x = torch.einsum('bnij,bjnc->binc', attn, v) 51 | x = x.reshape(b, l, c) 52 | 53 | # output 54 | x = self.proj(x) 55 | return x 56 | 57 | class AttentionBlock(nn.Module): 58 | 59 | def __init__(self, dim, num_heads): 60 | super(AttentionBlock, self).__init__() 61 | self.dim = dim 62 | self.num_heads = num_heads 63 | 64 | # layers 65 | self.norm1 = nn.LayerNorm(dim) 66 | self.attn = SelfAttention(dim, num_heads) 67 | self.norm2 = nn.LayerNorm(dim) 68 | self.mlp = nn.Sequential( 69 | nn.Linear(dim, dim * 4), 70 | nn.GELU(), 71 | nn.Linear(dim * 4, dim)) 72 | 73 | def forward(self, x): 74 | x = x + self.attn(self.norm1(x)) 75 | x = x + self.mlp(self.norm2(x)) 76 | return x 77 | 78 | class VisionTransformer(nn.Module): 79 | 80 | def __init__(self, 81 | image_size=384, 82 | patch_size=16, 83 | dim=1024, 84 | out_dim=1000, 85 | num_heads=16, 86 | num_layers=24): 87 | assert image_size % patch_size == 0 88 | super(VisionTransformer, 
self).__init__() 89 | self.image_size = image_size 90 | self.patch_size = patch_size 91 | self.dim = dim 92 | self.out_dim = out_dim 93 | self.num_heads = num_heads 94 | self.num_layers = num_layers 95 | self.num_patches = (image_size // patch_size) ** 2 96 | 97 | # embeddings 98 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) 99 | self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim)) 100 | self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02)) 101 | 102 | # blocks 103 | self.blocks = nn.Sequential(*[AttentionBlock(dim, num_heads) for _ in range(num_layers)]) 104 | self.norm = nn.LayerNorm(dim) 105 | 106 | # head 107 | self.head = nn.Linear(dim, out_dim) 108 | 109 | def forward(self, x): 110 | b = x.size(0) 111 | 112 | # embeddings 113 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) 114 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1) 115 | x = x + self.pos_embedding 116 | 117 | # blocks 118 | x = self.blocks(x) 119 | x = self.norm(x) 120 | 121 | # head 122 | x = self.head(x) 123 | return x 124 | 125 | class ResidualBlock(nn.Module): 126 | 127 | def __init__(self, dim): 128 | super(ResidualBlock, self).__init__() 129 | self.dim = dim 130 | 131 | # layers 132 | self.residual = nn.Sequential( 133 | nn.ReLU(inplace=False), # NOTE: avoid modifying the input 134 | nn.Conv2d(dim, dim, 3, padding=1), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(dim, dim, 3, padding=1)) 137 | 138 | def forward(self, x): 139 | return x + self.residual(x) 140 | 141 | class FusionBlock(nn.Module): 142 | 143 | def __init__(self, dim): 144 | super(FusionBlock, self).__init__() 145 | self.dim = dim 146 | 147 | # layers 148 | self.layer1 = ResidualBlock(dim) 149 | self.layer2 = ResidualBlock(dim) 150 | self.conv_out = nn.Conv2d(dim, dim, 1) 151 | 152 | def forward(self, *xs): 153 | assert len(xs) in (1, 2), 'invalid number of inputs' 154 | if len(xs) == 1: 155 | x = self.layer2(xs[0]) 156 | else: 157 | x = self.layer2(xs[0] + self.layer1(xs[1])) 158 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 159 | x = self.conv_out(x) 160 | return x 161 | 162 | class MiDaS(nn.Module): 163 | r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''. 164 | Monocular depth estimation using dense prediction transformers. 
165 | """ 166 | def __init__(self, 167 | image_size=384, 168 | patch_size=16, 169 | dim=1024, 170 | neck_dims=[256, 512, 1024, 1024], 171 | fusion_dim=256, 172 | num_heads=16, 173 | num_layers=24): 174 | assert image_size % patch_size == 0 175 | assert num_layers % 4 == 0 176 | super(MiDaS, self).__init__() 177 | self.image_size = image_size 178 | self.patch_size = patch_size 179 | self.dim = dim 180 | self.neck_dims = neck_dims 181 | self.fusion_dim = fusion_dim 182 | self.num_heads = num_heads 183 | self.num_layers = num_layers 184 | self.num_patches = (image_size // patch_size) ** 2 185 | 186 | # embeddings 187 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) 188 | self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim)) 189 | self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02)) 190 | 191 | # blocks 192 | stride = num_layers // 4 193 | self.blocks = nn.Sequential(*[AttentionBlock(dim, num_heads) for _ in range(num_layers)]) 194 | self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)] 195 | 196 | # stage1 (4x) 197 | self.fc1 = nn.Sequential( 198 | nn.Linear(dim * 2, dim), 199 | nn.GELU()) 200 | self.conv1 = nn.Sequential( 201 | nn.Conv2d(dim, neck_dims[0], 1), 202 | nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4), 203 | nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False)) 204 | self.fusion1 = FusionBlock(fusion_dim) 205 | 206 | # stage2 (8x) 207 | self.fc2 = nn.Sequential( 208 | nn.Linear(dim * 2, dim), 209 | nn.GELU()) 210 | self.conv2 = nn.Sequential( 211 | nn.Conv2d(dim, neck_dims[1], 1), 212 | nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2), 213 | nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False)) 214 | self.fusion2 = FusionBlock(fusion_dim) 215 | 216 | # stage3 (16x) 217 | self.fc3 = nn.Sequential( 218 | nn.Linear(dim * 2, dim), 219 | nn.GELU()) 220 | self.conv3 = nn.Sequential( 221 | nn.Conv2d(dim, neck_dims[2], 1), 222 | nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False)) 223 | self.fusion3 = FusionBlock(fusion_dim) 224 | 225 | # stage4 (32x) 226 | self.fc4 = nn.Sequential( 227 | nn.Linear(dim * 2, dim), 228 | nn.GELU()) 229 | self.conv4 = nn.Sequential( 230 | nn.Conv2d(dim, neck_dims[3], 1), 231 | nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1), 232 | nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False)) 233 | self.fusion4 = FusionBlock(fusion_dim) 234 | 235 | # head 236 | self.head = nn.Sequential( 237 | nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1), 238 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 239 | nn.Conv2d(fusion_dim // 2, 32, 3, padding=1), 240 | nn.ReLU(inplace=True), 241 | nn.ConvTranspose2d(32, 1, 1), 242 | nn.ReLU(inplace=True)) 243 | 244 | def forward(self, x): 245 | b, c, h, w, p = *x.size(), self.patch_size 246 | assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})' 247 | hp, wp, grid = h // p, w // p, self.image_size // p 248 | 249 | # embeddings 250 | pos_embedding = torch.cat([ 251 | self.pos_embedding[:, :1], 252 | F.interpolate( 253 | self.pos_embedding[:, 1:].reshape(1, grid, grid, -1).permute(0, 3, 1, 2), 254 | size=(hp, wp), 255 | mode='bilinear', 256 | align_corners=False).permute(0, 2, 3, 1).reshape(1, hp * wp, -1)], dim=1) 257 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) 258 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1) 259 | x = x + pos_embedding 260 | 261 | # stage1 262 | 
x = self.blocks[self.slices[0]](x) 263 | x1 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1) 264 | x1 = self.fc1(x1).permute(0, 2, 1).unflatten(2, (hp, wp)) 265 | x1 = self.conv1(x1) 266 | 267 | # stage2 268 | x = self.blocks[self.slices[1]](x) 269 | x2 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1) 270 | x2 = self.fc2(x2).permute(0, 2, 1).unflatten(2, (hp, wp)) 271 | x2 = self.conv2(x2) 272 | 273 | # stage3 274 | x = self.blocks[self.slices[2]](x) 275 | x3 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1) 276 | x3 = self.fc3(x3).permute(0, 2, 1).unflatten(2, (hp, wp)) 277 | x3 = self.conv3(x3) 278 | 279 | # stage4 280 | x = self.blocks[self.slices[3]](x) 281 | x4 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1) 282 | x4 = self.fc4(x4).permute(0, 2, 1).unflatten(2, (hp, wp)) 283 | x4 = self.conv4(x4) 284 | 285 | # fusion 286 | x4 = self.fusion4(x4) 287 | x3 = self.fusion3(x4, x3) 288 | x2 = self.fusion2(x3, x2) 289 | x1 = self.fusion1(x2, x1) 290 | 291 | # head 292 | x = self.head(x1) 293 | return x 294 | 295 | def midas_v3(pretrained=False, **kwargs): 296 | cfg = dict( 297 | image_size=384, 298 | patch_size=16, 299 | dim=1024, 300 | neck_dims=[256, 512, 1024, 1024], 301 | fusion_dim=256, 302 | num_heads=16, 303 | num_layers=24) 304 | cfg.update(**kwargs) 305 | model = MiDaS(**cfg) 306 | if pretrained: 307 | # model.load_state_dict(torch.load(DOWNLOAD_TO_CACHE('experiments/models/midas/midas_v3_dpt_large.pth'), map_location='cpu')) 308 | model.load_state_dict(torch.load("./model_weights/midas_v3_dpt_large.pth", map_location='cpu')) 309 | return model 310 | -------------------------------------------------------------------------------- /artist/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .distributed import * 3 | from .diffusion import * 4 | from .losses import * 5 | from .degration import * 6 | from .random_mask import * -------------------------------------------------------------------------------- /artist/ops/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/degration.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/degration.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/distributed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/distributed.cpython-38.pyc -------------------------------------------------------------------------------- 
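# Editor's note: a minimal, illustrative usage sketch for the MiDaS re-implementation in
# artist/models/midas.py above; it is not part of the original repo. The random tensor below
# stands in for an RGB frame resized to a multiple of 32 and normalized with mean=std=0.5
# (see the module docstring), and pretrained=True would require the checkpoint path
# hard-coded in midas_v3 (./model_weights/midas_v3_dpt_large.pth).
import torch
from artist.models.midas import midas_v3

model = midas_v3(pretrained=False).eval()   # set pretrained=True once the DPT-Large weights are downloaded
frame = torch.rand(1, 3, 384, 384)          # placeholder for a normalized RGB frame
with torch.no_grad():
    depth = model(frame)                    # [1, 1, 384, 384] relative (inverse) depth map
print(depth.shape)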
/artist/ops/__pycache__/dpm_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/dpm_solver.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/random_mask.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/random_mask.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /artist/ops/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.distributed as dist 4 | import functools 5 | import pickle 6 | import numpy as np 7 | from collections import OrderedDict 8 | from torch.autograd import Function 9 | 10 | __all__ = ['is_dist_initialized', 11 | 'get_world_size', 12 | 'get_rank', 13 | 'new_group', 14 | 'destroy_process_group', 15 | 'barrier', 16 | 'broadcast', 17 | 'all_reduce', 18 | 'reduce', 19 | 'gather', 20 | 'all_gather', 21 | 'reduce_dict', 22 | 'get_global_gloo_group', 23 | 'generalized_all_gather', 24 | 'generalized_gather', 25 | 'scatter', 26 | 'reduce_scatter', 27 | 'send', 28 | 'recv', 29 | 'isend', 30 | 'irecv', 31 | 'shared_random_seed', 32 | 'diff_all_gather', 33 | 'diff_all_reduce', 34 | 'diff_scatter', 35 | 'diff_copy', 36 | 'spherical_kmeans', 37 | 'sinkhorn'] 38 | 39 | #-------------------------------- Distributed operations --------------------------------# 40 | 41 | def is_dist_initialized(): 42 | return dist.is_available() and dist.is_initialized() 43 | 44 | def get_world_size(group=None): 45 | return dist.get_world_size(group) if is_dist_initialized() else 1 46 | 47 | def get_rank(group=None): 48 | return dist.get_rank(group) if is_dist_initialized() else 0 49 | 50 | def new_group(ranks=None, **kwargs): 51 | if is_dist_initialized(): 52 | return dist.new_group(ranks, **kwargs) 53 | return None 54 | 55 | def destroy_process_group(): 56 | if is_dist_initialized(): 57 | dist.destroy_process_group() 58 | 59 | def barrier(group=None, **kwargs): 60 | if get_world_size(group) > 1: 61 | dist.barrier(group, **kwargs) 62 | 63 | def broadcast(tensor, src, group=None, **kwargs): 64 | if get_world_size(group) > 1: 65 | return dist.broadcast(tensor, src, group, **kwargs) 66 | 67 | def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs): 68 | if get_world_size(group) > 1: 69 | return dist.all_reduce(tensor, op, group, **kwargs) 70 | 71 | def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs): 72 | if 
get_world_size(group) > 1: 73 | return dist.reduce(tensor, dst, op, group, **kwargs) 74 | 75 | def gather(tensor, dst=0, group=None, **kwargs): 76 | rank = get_rank() # global rank 77 | world_size = get_world_size(group) 78 | if world_size == 1: 79 | return [tensor] 80 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None 81 | dist.gather(tensor, tensor_list, dst, group, **kwargs) 82 | return tensor_list 83 | 84 | def all_gather(tensor, uniform_size=True, group=None, **kwargs): 85 | world_size = get_world_size(group) 86 | if world_size == 1: 87 | return [tensor] 88 | assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()' 89 | 90 | if uniform_size: 91 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 92 | dist.all_gather(tensor_list, tensor, group, **kwargs) 93 | return tensor_list 94 | else: 95 | # collect tensor shapes across GPUs 96 | shape = tuple(tensor.shape) 97 | shape_list = generalized_all_gather(shape, group) 98 | 99 | # flatten the tensor 100 | tensor = tensor.reshape(-1) 101 | size = int(np.prod(shape)) 102 | size_list = [int(np.prod(u)) for u in shape_list] 103 | max_size = max(size_list) 104 | 105 | # pad to maximum size 106 | if size != max_size: 107 | padding = tensor.new_zeros(max_size - size) 108 | tensor = torch.cat([tensor, padding], dim=0) 109 | 110 | # all_gather 111 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 112 | dist.all_gather(tensor_list, tensor, group, **kwargs) 113 | 114 | # reshape tensors 115 | tensor_list = [t[:n].view(s) for t, n, s in zip( 116 | tensor_list, size_list, shape_list)] 117 | return tensor_list 118 | 119 | @torch.no_grad() 120 | def reduce_dict(input_dict, group=None, reduction='mean', **kwargs): 121 | assert reduction in ['mean', 'sum'] 122 | world_size = get_world_size(group) 123 | if world_size == 1: 124 | return input_dict 125 | 126 | # ensure that the orders of keys are consistent across processes 127 | if isinstance(input_dict, OrderedDict): 128 | keys = list(input_dict.keys()) 129 | else: 130 | keys = sorted(input_dict.keys()) 131 | vals = [input_dict[key] for key in keys] 132 | vals = torch.stack(vals, dim=0) 133 | dist.reduce(vals, dst=0, group=group, **kwargs) 134 | if dist.get_rank(group) == 0 and reduction == 'mean': 135 | vals /= world_size 136 | dist.broadcast(vals, src=0, group=group, **kwargs) 137 | reduced_dict = type(input_dict)([ 138 | (key, val) for key, val in zip(keys, vals)]) 139 | return reduced_dict 140 | 141 | @functools.lru_cache() 142 | def get_global_gloo_group(): 143 | backend = dist.get_backend() 144 | assert backend in ['gloo', 'nccl'] 145 | if backend == 'nccl': 146 | return dist.new_group(backend='gloo') 147 | else: 148 | return dist.group.WORLD 149 | 150 | def _serialize_to_tensor(data, group): 151 | backend = dist.get_backend(group) 152 | assert backend in ['gloo', 'nccl'] 153 | device = torch.device('cpu' if backend == 'gloo' else 'cuda') 154 | 155 | buffer = pickle.dumps(data) 156 | if len(buffer) > 1024 ** 3: 157 | import logging; logger = logging.getLogger(__name__) 158 | logger.warning( 159 | 'Rank {} trying to all-gather {:.2f} GB of data on device ' 160 | '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device)) 161 | storage = torch.ByteStorage.from_buffer(buffer) 162 | tensor = torch.ByteTensor(storage).to(device=device) 163 | return tensor 164 | 165 | def _pad_to_largest_tensor(tensor, group): 166 | world_size = dist.get_world_size(group=group) 167 | assert world_size >= 1, \ 168 | 
'gather/all_gather must be called from ranks within ' \ 169 | 'the given group!' 170 | local_size = torch.tensor( 171 | [tensor.numel()], dtype=torch.int64, device=tensor.device) 172 | size_list = [torch.zeros( 173 | [1], dtype=torch.int64, device=tensor.device) 174 | for _ in range(world_size)] 175 | 176 | # gather tensors and compute the maximum size 177 | dist.all_gather(size_list, local_size, group=group) 178 | size_list = [int(size.item()) for size in size_list] 179 | max_size = max(size_list) 180 | 181 | # pad tensors to the same size 182 | if local_size != max_size: 183 | padding = torch.zeros( 184 | (max_size - local_size, ), 185 | dtype=torch.uint8, device=tensor.device) 186 | tensor = torch.cat((tensor, padding), dim=0) 187 | return size_list, tensor 188 | 189 | def generalized_all_gather(data, group=None): 190 | if get_world_size(group) == 1: 191 | return [data] 192 | if group is None: 193 | group = get_global_gloo_group() 194 | 195 | tensor = _serialize_to_tensor(data, group) 196 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 197 | max_size = max(size_list) 198 | 199 | # receiving tensors from all ranks 200 | tensor_list = [torch.empty( 201 | (max_size, ), dtype=torch.uint8, device=tensor.device) 202 | for _ in size_list] 203 | dist.all_gather(tensor_list, tensor, group=group) 204 | 205 | data_list = [] 206 | for size, tensor in zip(size_list, tensor_list): 207 | buffer = tensor.cpu().numpy().tobytes()[:size] 208 | data_list.append(pickle.loads(buffer)) 209 | return data_list 210 | 211 | def generalized_gather(data, dst=0, group=None): 212 | world_size = get_world_size(group) 213 | if world_size == 1: 214 | return [data] 215 | if group is None: 216 | group = get_global_gloo_group() 217 | rank = dist.get_rank() # global rank 218 | 219 | tensor = _serialize_to_tensor(data, group) 220 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 221 | 222 | # receiving tensors from all ranks to dst 223 | if rank == dst: 224 | max_size = max(size_list) 225 | tensor_list = [torch.empty( 226 | (max_size, ), dtype=torch.uint8, device=tensor.device) 227 | for _ in size_list] 228 | dist.gather(tensor, tensor_list, dst=dst, group=group) 229 | 230 | data_list = [] 231 | for size, tensor in zip(size_list, tensor_list): 232 | buffer = tensor.cpu().numpy().tobytes()[:size] 233 | data_list.append(pickle.loads(buffer)) 234 | return data_list 235 | else: 236 | dist.gather(tensor, [], dst=dst, group=group) 237 | return [] 238 | 239 | def scatter(data, scatter_list=None, src=0, group=None, **kwargs): 240 | r"""NOTE: only supports CPU tensor communication. 
241 | """ 242 | if get_world_size(group) > 1: 243 | return dist.scatter(data, scatter_list, src, group, **kwargs) 244 | 245 | def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs): 246 | if get_world_size(group) > 1: 247 | return dist.reduce_scatter(output, input_list, op, group, **kwargs) 248 | 249 | def send(tensor, dst, group=None, **kwargs): 250 | if get_world_size(group) > 1: 251 | assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()' 252 | return dist.send(tensor, dst, group, **kwargs) 253 | 254 | def recv(tensor, src=None, group=None, **kwargs): 255 | if get_world_size(group) > 1: 256 | assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()' 257 | return dist.recv(tensor, src, group, **kwargs) 258 | 259 | def isend(tensor, dst, group=None, **kwargs): 260 | if get_world_size(group) > 1: 261 | assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()' 262 | return dist.isend(tensor, dst, group, **kwargs) 263 | 264 | def irecv(tensor, src=None, group=None, **kwargs): 265 | if get_world_size(group) > 1: 266 | assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()' 267 | return dist.irecv(tensor, src, group, **kwargs) 268 | 269 | def shared_random_seed(group=None): 270 | seed = np.random.randint(2 ** 31) 271 | all_seeds = generalized_all_gather(seed, group) 272 | return all_seeds[0] 273 | 274 | #-------------------------------- Differentiable operations --------------------------------# 275 | 276 | def _all_gather(x): 277 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1: 278 | return x 279 | rank = dist.get_rank() 280 | world_size = dist.get_world_size() 281 | tensors = [torch.empty_like(x) for _ in range(world_size)] 282 | tensors[rank] = x 283 | dist.all_gather(tensors, x) 284 | return torch.cat(tensors, dim=0).contiguous() 285 | 286 | def _all_reduce(x): 287 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1: 288 | return x 289 | dist.all_reduce(x) 290 | return x 291 | 292 | def _split(x): 293 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1: 294 | return x 295 | rank = dist.get_rank() 296 | world_size = dist.get_world_size() 297 | return x.chunk(world_size, dim=0)[rank].contiguous() 298 | 299 | class DiffAllGather(Function): 300 | r"""Differentiable all-gather. 301 | """ 302 | @staticmethod 303 | def symbolic(graph, input): 304 | return _all_gather(input) 305 | 306 | @staticmethod 307 | def forward(ctx, input): 308 | return _all_gather(input) 309 | 310 | @staticmethod 311 | def backward(ctx, grad_output): 312 | return _split(grad_output) 313 | 314 | class DiffAllReduce(Function): 315 | r"""Differentiable all-reducd. 316 | """ 317 | @staticmethod 318 | def symbolic(graph, input): 319 | return _all_reduce(input) 320 | 321 | @staticmethod 322 | def forward(ctx, input): 323 | return _all_reduce(input) 324 | 325 | @staticmethod 326 | def backward(ctx, grad_output): 327 | return grad_output 328 | 329 | class DiffScatter(Function): 330 | r"""Differentiable scatter. 331 | """ 332 | @staticmethod 333 | def symbolic(graph, input): 334 | return _split(input) 335 | 336 | @staticmethod 337 | def symbolic(ctx, input): 338 | return _split(input) 339 | 340 | @staticmethod 341 | def backward(ctx, grad_output): 342 | return _all_gather(grad_output) 343 | 344 | class DiffCopy(Function): 345 | r"""Differentiable copy that reduces all gradients during backward. 
346 | """ 347 | @staticmethod 348 | def symbolic(graph, input): 349 | return input 350 | 351 | @staticmethod 352 | def forward(ctx, input): 353 | return input 354 | 355 | @staticmethod 356 | def backward(ctx, grad_output): 357 | return _all_reduce(grad_output) 358 | 359 | diff_all_gather = DiffAllGather.apply 360 | diff_all_reduce = DiffAllReduce.apply 361 | diff_scatter = DiffScatter.apply 362 | diff_copy = DiffCopy.apply 363 | 364 | #-------------------------------- Distributed algorithms --------------------------------# 365 | 366 | @torch.no_grad() 367 | def spherical_kmeans(feats, num_clusters, num_iters=10): 368 | k, n, c = num_clusters, *feats.size() 369 | ones = feats.new_ones(n, dtype=torch.long) 370 | 371 | # distributed settings 372 | rank = get_rank() 373 | world_size = get_world_size() 374 | 375 | # init clusters 376 | rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))] 377 | clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k] 378 | 379 | # variables 380 | new_clusters = feats.new_zeros(k, c) 381 | counts = feats.new_zeros(k, dtype=torch.long) 382 | 383 | # iterative Expectation-Maximization 384 | for step in range(num_iters + 1): 385 | # Expectation step 386 | simmat = torch.mm(feats, clusters.t()) 387 | scores, assigns = simmat.max(dim=1) 388 | if step == num_iters: 389 | break 390 | 391 | # Maximization step 392 | new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats) 393 | all_reduce(new_clusters) 394 | 395 | counts.zero_() 396 | counts.index_add_(0, assigns, ones) 397 | all_reduce(counts) 398 | 399 | mask = (counts > 0) 400 | clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1) 401 | clusters = F.normalize(clusters, p=2, dim=1) 402 | return clusters, assigns, scores 403 | 404 | @torch.no_grad() 405 | def sinkhorn(Q, eps=0.5, num_iters=3): 406 | # normalize Q 407 | Q = torch.exp(Q / eps).t() 408 | sum_Q = Q.sum() 409 | all_reduce(sum_Q) 410 | Q /= sum_Q 411 | 412 | # variables 413 | n, m = Q.size() 414 | u = Q.new_zeros(n) 415 | r = Q.new_ones(n) / n 416 | c = Q.new_ones(m) / (m * get_world_size()) 417 | 418 | # iterative update 419 | cur_sum = Q.sum(dim=1) 420 | all_reduce(cur_sum) 421 | for i in range(num_iters): 422 | u = cur_sum 423 | Q *= (r / u).unsqueeze(1) 424 | Q *= (c / Q.sum(dim=0)).unsqueeze(0) 425 | cur_sum = Q.sum(dim=1) 426 | all_reduce(cur_sum) 427 | return (Q / Q.sum(dim=0, keepdim=True)).t().float() 428 | -------------------------------------------------------------------------------- /artist/ops/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] 5 | 6 | def kl_divergence(mu1, logvar1, mu2, logvar2): 7 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2)) 8 | 9 | def standard_normal_cdf(x): 10 | r"""A fast approximation of the cumulative distribution function of the standard normal. 
11 | """ 12 | return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | 14 | def discretized_gaussian_log_likelihood(x0, mean, log_scale): 15 | assert x0.shape == mean.shape == log_scale.shape 16 | cx = x0 - mean 17 | inv_stdv = torch.exp(-log_scale) 18 | cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) 19 | cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) 20 | log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) 21 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) 22 | cdf_delta = cdf_plus - cdf_min 23 | log_probs = torch.where( 24 | x0 < -0.999, 25 | log_cdf_plus, 26 | torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)))) 27 | assert log_probs.shape == x0.shape 28 | return log_probs 29 | -------------------------------------------------------------------------------- /artist/ops/random_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | __all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop'] 5 | 6 | def make_irregular_mask(w, h, max_angle=4, max_length=200, max_width=100, min_strokes=1, max_strokes=5, mode='line'): 7 | # initialize mask 8 | assert mode in ['line', 'circle', 'square'] 9 | mask = np.zeros((h, w), np.float32) 10 | 11 | # draw strokes 12 | num_strokes = np.random.randint(min_strokes, max_strokes + 1) 13 | for i in range(num_strokes): 14 | x1 = np.random.randint(w) 15 | y1 = np.random.randint(h) 16 | for j in range(1 + np.random.randint(5)): 17 | angle = 0.01 + np.random.randint(max_angle) 18 | if i % 2 == 0: 19 | angle = 2 * 3.1415926 - angle 20 | length = 10 + np.random.randint(max_length) 21 | radius = 5 + np.random.randint(max_width) 22 | x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w) 23 | y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h) 24 | if mode == 'line': 25 | cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius) 26 | elif mode == 'circle': 27 | cv2.circle(mask, (x1, y1), radius=radius, color=1.0, thickness=-1) 28 | elif mode == 'square': 29 | radius = radius // 2 30 | mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1 31 | x1, y1 = x2, y2 32 | return mask 33 | 34 | def make_rectangle_mask(w, h, margin=10, min_size=30, max_size=150, min_strokes=1, max_strokes=4): 35 | # initialize mask 36 | mask = np.zeros((h, w), np.float32) 37 | 38 | # draw rectangles 39 | num_strokes = np.random.randint(min_strokes, max_strokes + 1) 40 | for i in range(num_strokes): 41 | box_w = np.random.randint(min_size, max_size) 42 | box_h = np.random.randint(min_size, max_size) 43 | x1 = np.random.randint(margin, w - margin - box_w + 1) 44 | y1 = np.random.randint(margin, h - margin - box_h + 1) 45 | mask[y1:y1 + box_h, x1:x1 + box_w] = 1 46 | return mask 47 | 48 | def make_uncrop(w, h): 49 | # initialize mask 50 | mask = np.zeros((h, w), np.float32) 51 | 52 | # randomly halve the image 53 | side = np.random.choice([0, 1, 2, 3]) 54 | if side == 0: 55 | mask[:h // 2, :] = 1 56 | elif side == 1: 57 | mask[h // 2:, :] = 1 58 | elif side == 2: 59 | mask[:, :w // 2] = 1 60 | elif side == 3: 61 | mask[:, w // 2:] = 1 62 | return mask 63 | -------------------------------------------------------------------------------- /artist/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import * 2 | from .adafactor import * 
-------------------------------------------------------------------------------- /artist/optim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /artist/optim/__pycache__/adafactor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/adafactor.cpython-38.pyc -------------------------------------------------------------------------------- /artist/optim/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /artist/optim/adafactor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | __all__ = ['Adafactor'] 7 | 8 | class Adafactor(Optimizer): 9 | """ 10 | AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: 11 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py 12 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that 13 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and 14 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and 15 | `relative_step=False`. 16 | Arguments: 17 | params (`Iterable[nn.parameter.Parameter]`): 18 | Iterable of parameters to optimize or dictionaries defining parameter groups. 19 | lr (`float`, *optional*): 20 | The external learning rate. 21 | eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)): 22 | Regularization constants for square gradient and parameter scale respectively 23 | clip_threshold (`float`, *optional*, defaults 1.0): 24 | Threshold of root mean square of final gradient update 25 | decay_rate (`float`, *optional*, defaults to -0.8): 26 | Coefficient used to compute running averages of square 27 | beta1 (`float`, *optional*): 28 | Coefficient used for computing running averages of gradient 29 | weight_decay (`float`, *optional*, defaults to 0): 30 | Weight decay (L2 penalty) 31 | scale_parameter (`bool`, *optional*, defaults to `True`): 32 | If True, learning rate is scaled by root mean square 33 | relative_step (`bool`, *optional*, defaults to `True`): 34 | If True, time-dependent learning rate is computed instead of external learning rate 35 | warmup_init (`bool`, *optional*, defaults to `False`): 36 | Time-dependent learning rate computation depends on whether warm-up initialization is being used 37 | This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. 38 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): 39 | - Training without LR warmup or clip_threshold is not recommended. 
40 | - use scheduled LR warm-up to fixed LR 41 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) 42 | - Disable relative updates 43 | - Use scale_parameter=False 44 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor 45 | Example: 46 | ```python 47 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) 48 | ``` 49 | Others reported the following combination to work well: 50 | ```python 51 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 52 | ``` 53 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] 54 | scheduler as following: 55 | ```python 56 | from transformers.optimization import Adafactor, AdafactorSchedule 57 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 58 | lr_scheduler = AdafactorSchedule(optimizer) 59 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) 60 | ``` 61 | Usage: 62 | ```python 63 | # replace AdamW with Adafactor 64 | optimizer = Adafactor( 65 | model.parameters(), 66 | lr=1e-3, 67 | eps=(1e-30, 1e-3), 68 | clip_threshold=1.0, 69 | decay_rate=-0.8, 70 | beta1=None, 71 | weight_decay=0.0, 72 | relative_step=False, 73 | scale_parameter=False, 74 | warmup_init=False, 75 | ) 76 | ```""" 77 | 78 | def __init__( 79 | self, 80 | params, 81 | lr=None, 82 | eps=(1e-30, 1e-3), 83 | clip_threshold=1.0, 84 | decay_rate=-0.8, 85 | beta1=None, 86 | weight_decay=0.0, 87 | scale_parameter=True, 88 | relative_step=True, 89 | warmup_init=False, 90 | ): 91 | r"""require_version("torch>=1.5.0") # add_ with alpha 92 | """ 93 | if lr is not None and relative_step: 94 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") 95 | if warmup_init and not relative_step: 96 | raise ValueError("`warmup_init=True` requires `relative_step=True`") 97 | 98 | defaults = dict( 99 | lr=lr, 100 | eps=eps, 101 | clip_threshold=clip_threshold, 102 | decay_rate=decay_rate, 103 | beta1=beta1, 104 | weight_decay=weight_decay, 105 | scale_parameter=scale_parameter, 106 | relative_step=relative_step, 107 | warmup_init=warmup_init, 108 | ) 109 | super().__init__(params, defaults) 110 | 111 | @staticmethod 112 | def _get_lr(param_group, param_state): 113 | rel_step_sz = param_group["lr"] 114 | if param_group["relative_step"]: 115 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 116 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) 117 | param_scale = 1.0 118 | if param_group["scale_parameter"]: 119 | param_scale = max(param_group["eps"][1], param_state["RMS"]) 120 | return param_scale * rel_step_sz 121 | 122 | @staticmethod 123 | def _get_options(param_group, param_shape): 124 | factored = len(param_shape) >= 2 125 | use_first_moment = param_group["beta1"] is not None 126 | return factored, use_first_moment 127 | 128 | @staticmethod 129 | def _rms(tensor): 130 | return tensor.norm(2) / (tensor.numel() ** 0.5) 131 | 132 | @staticmethod 133 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 134 | # copy from fairseq's adafactor implementation: 135 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 136 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) 137 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 138 | return 
torch.mul(r_factor, c_factor) 139 | 140 | def step(self, closure=None): 141 | """ 142 | Performs a single optimization step 143 | Arguments: 144 | closure (callable, optional): A closure that reevaluates the model 145 | and returns the loss. 146 | """ 147 | loss = None 148 | if closure is not None: 149 | loss = closure() 150 | 151 | for group in self.param_groups: 152 | for p in group["params"]: 153 | if p.grad is None: 154 | continue 155 | grad = p.grad.data 156 | if grad.dtype in {torch.float16, torch.bfloat16}: 157 | grad = grad.float() 158 | if grad.is_sparse: 159 | raise RuntimeError("Adafactor does not support sparse gradients.") 160 | 161 | state = self.state[p] 162 | grad_shape = grad.shape 163 | 164 | factored, use_first_moment = self._get_options(group, grad_shape) 165 | # State Initialization 166 | if len(state) == 0: 167 | state["step"] = 0 168 | 169 | if use_first_moment: 170 | # Exponential moving average of gradient values 171 | state["exp_avg"] = torch.zeros_like(grad) 172 | if factored: 173 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) 174 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 175 | else: 176 | state["exp_avg_sq"] = torch.zeros_like(grad) 177 | 178 | state["RMS"] = 0 179 | else: 180 | if use_first_moment: 181 | state["exp_avg"] = state["exp_avg"].to(grad) 182 | if factored: 183 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) 184 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) 185 | else: 186 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) 187 | 188 | p_data_fp32 = p.data 189 | if p.data.dtype in {torch.float16, torch.bfloat16}: 190 | p_data_fp32 = p_data_fp32.float() 191 | 192 | state["step"] += 1 193 | state["RMS"] = self._rms(p_data_fp32) 194 | lr = self._get_lr(group, state) 195 | 196 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) 197 | update = (grad**2) + group["eps"][0] 198 | if factored: 199 | exp_avg_sq_row = state["exp_avg_sq_row"] 200 | exp_avg_sq_col = state["exp_avg_sq_col"] 201 | 202 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) 203 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) 204 | 205 | # Approximation of exponential moving average of square of gradient 206 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 207 | update.mul_(grad) 208 | else: 209 | exp_avg_sq = state["exp_avg_sq"] 210 | 211 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) 212 | update = exp_avg_sq.rsqrt().mul_(grad) 213 | 214 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) 215 | update.mul_(lr) 216 | 217 | if use_first_moment: 218 | exp_avg = state["exp_avg"] 219 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) 220 | update = exp_avg 221 | 222 | if group["weight_decay"] != 0: 223 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) 224 | 225 | p_data_fp32.add_(-update) 226 | 227 | if p.data.dtype in {torch.float16, torch.bfloat16}: 228 | p.data.copy_(p_data_fp32) 229 | 230 | return loss 231 | -------------------------------------------------------------------------------- /artist/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | __all__ = ['AnnealingLR'] 5 | 6 | class AnnealingLR(_LRScheduler): 7 | 8 | def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', 
min_lr=0.0, last_step=-1): 9 | assert decay_mode in ['linear', 'cosine', 'none'] 10 | self.optimizer = optimizer 11 | self.base_lr = base_lr 12 | self.warmup_steps = warmup_steps 13 | self.total_steps = total_steps 14 | self.decay_mode = decay_mode 15 | self.min_lr = min_lr 16 | self.current_step = last_step + 1 17 | self.step(self.current_step) 18 | 19 | def get_lr(self): 20 | if self.warmup_steps > 0 and self.current_step <= self.warmup_steps: 21 | return self.base_lr * self.current_step / self.warmup_steps 22 | else: 23 | ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) 24 | ratio = min(1.0, max(0.0, ratio)) 25 | if self.decay_mode == 'linear': 26 | return self.base_lr * (1 - ratio) 27 | elif self.decay_mode == 'cosine': 28 | return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0 29 | else: 30 | return self.base_lr 31 | 32 | def step(self, current_step=None): 33 | if current_step is None: 34 | current_step = self.current_step + 1 35 | self.current_step = current_step 36 | new_lr = max(self.min_lr, self.get_lr()) 37 | if isinstance(self.optimizer, list): 38 | for o in self.optimizer: 39 | for group in o.param_groups: 40 | group['lr'] = new_lr 41 | else: 42 | for group in self.optimizer.param_groups: 43 | group['lr'] = new_lr 44 | 45 | def state_dict(self): 46 | return { 47 | 'base_lr': self.base_lr, 48 | 'warmup_steps': self.warmup_steps, 49 | 'total_steps': self.total_steps, 50 | 'decay_mode': self.decay_mode, 51 | 'current_step': self.current_step} 52 | 53 | def load_state_dict(self, state_dict): 54 | self.base_lr = state_dict['base_lr'] 55 | self.warmup_steps = state_dict['warmup_steps'] 56 | self.total_steps = state_dict['total_steps'] 57 | self.decay_mode = state_dict['decay_mode'] 58 | self.current_step = state_dict['current_step'] 59 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | ENABLE: true 2 | DATASET: webvid10m -------------------------------------------------------------------------------- /configs/exp01_vidcomposer_full.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: MULTI_TASK 2 | ENABLE: true 3 | DATASET: webvid10m 4 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 5 | batch_sizes: { 6 | "1": 1, 7 | "4": 1, 8 | "8": 1, 9 | "16": 1, 10 | } 11 | vit_image_size: 224 12 | network_name: UNetSD_temporal 13 | resume: true 14 | resume_step: 228000 15 | num_workers: 1 16 | mvs_visual: False 17 | chunk_size: 1 18 | resume_checkpoint: "model_weights/non_ema_228000.pth" 19 | log_dir: 'outputs' 20 | num_steps: 1 21 | -------------------------------------------------------------------------------- /configs/exp02_motion_transfer.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: True # You NEED Open It 3 | ENABLE: true 4 | DATASET: webvid10m 5 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 6 | guidances: ['y', 'local_image', 'motion'] # You NEED Open It 7 | batch_sizes: { 8 | "1": 1, 9 | "4": 1, 10 | "8": 1, 11 | "16": 1, 12 | } 13 | vit_image_size: 224 14 | network_name: UNetSD_temporal 15 | resume: true 16 | resume_step: 228000 17 | seed: 182 18 | num_workers: 0 19 | mvs_visual: False 20 | chunk_size: 1 21 | resume_checkpoint: 
"model_weights/non_ema_228000.pth" 22 | log_dir: 'outputs' 23 | num_steps: 1 24 | -------------------------------------------------------------------------------- /configs/exp02_motion_transfer_vs_style.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: True # You NEED Open It 3 | read_style: True 4 | ENABLE: true 5 | DATASET: webvid10m 6 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 7 | guidances: ['y', 'local_image', 'image', 'motion'] # You NEED Open It 8 | batch_sizes: { 9 | "1": 1, 10 | "4": 1, 11 | "8": 1, 12 | "16": 1, 13 | } 14 | vit_image_size: 224 15 | network_name: UNetSD_temporal 16 | resume: true 17 | resume_step: 228000 18 | seed: 182 19 | num_workers: 0 20 | mvs_visual: False 21 | chunk_size: 1 22 | resume_checkpoint: "model_weights/non_ema_228000.pth" 23 | log_dir: 'outputs' 24 | num_steps: 1 25 | -------------------------------------------------------------------------------- /configs/exp03_sketch2video_style.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: False # You NEED Open It 3 | read_style: True 4 | read_sketch: True 5 | save_origin_video: False 6 | ENABLE: true 7 | DATASET: webvid10m 8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 9 | guidances: ['y', 'image', 'single_sketch'] # You NEED Open It 10 | batch_sizes: { 11 | "1": 1, 12 | "4": 1, 13 | "8": 1, 14 | "16": 1, 15 | } 16 | vit_image_size: 224 17 | network_name: UNetSD_temporal 18 | resume: true 19 | resume_step: 228000 20 | seed: 182 21 | num_workers: 0 22 | mvs_visual: False 23 | chunk_size: 1 24 | resume_checkpoint: "model_weights/non_ema_228000.pth" 25 | log_dir: 'outputs' 26 | num_steps: 1 27 | -------------------------------------------------------------------------------- /configs/exp04_sketch2video_wo_style.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: False # You NEED Open It 3 | read_style: False 4 | read_sketch: True 5 | save_origin_video: False 6 | ENABLE: true 7 | DATASET: webvid10m 8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 9 | guidances: ['y', 'single_sketch'] # You NEED Open It 10 | batch_sizes: { 11 | "1": 1, 12 | "4": 1, 13 | "8": 1, 14 | "16": 1, 15 | } 16 | vit_image_size: 224 17 | network_name: UNetSD_temporal 18 | resume: true 19 | resume_step: 228000 20 | seed: 182 21 | num_workers: 0 22 | mvs_visual: False 23 | chunk_size: 1 24 | resume_checkpoint: "model_weights/non_ema_228000.pth" 25 | log_dir: 'outputs' 26 | num_steps: 1 27 | -------------------------------------------------------------------------------- /configs/exp05_text_depths_wo_style.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: False # You NEED Open It 3 | read_style: False 4 | read_sketch: False 5 | save_origin_video: True 6 | ENABLE: true 7 | DATASET: webvid10m 8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 9 | guidances: ['y', 'depth'] # You NEED Open It 10 | batch_sizes: { 11 | "1": 1, 12 | "4": 1, 13 | "8": 1, 14 | "16": 1, 15 | } 16 | vit_image_size: 224 17 | network_name: UNetSD_temporal 18 | resume: true 19 | resume_step: 228000 20 
| seed: 182 21 | num_workers: 0 22 | mvs_visual: False 23 | chunk_size: 1 24 | resume_checkpoint: "model_weights/non_ema_228000.pth" 25 | log_dir: 'outputs' 26 | num_steps: 1 27 | 28 | -------------------------------------------------------------------------------- /configs/exp06_text_depths_vs_style.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: SINGLE_TASK 2 | read_image: False 3 | read_style: True 4 | read_sketch: False 5 | save_origin_video: True 6 | ENABLE: true 7 | DATASET: webvid10m 8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 9 | guidances: ['y', 'image', 'depth'] # You NEED Open It 10 | batch_sizes: { 11 | "1": 1, 12 | "4": 1, 13 | "8": 1, 14 | "16": 1, 15 | } 16 | vit_image_size: 224 17 | network_name: UNetSD_temporal 18 | resume: true 19 | resume_step: 228000 20 | seed: 182 21 | num_workers: 0 22 | mvs_visual: False 23 | chunk_size: 1 24 | resume_checkpoint: "model_weights/non_ema_228000.pth" 25 | log_dir: 'outputs' 26 | num_steps: 1 27 | -------------------------------------------------------------------------------- /configs/exp10_vidcomposer_no_watermark_full.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: VideoComposer_Inference 2 | ENABLE: true 3 | DATASET: webvid10m 4 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 5 | batch_sizes: { 6 | "1": 1, 7 | "4": 1, 8 | "8": 1, 9 | "16": 1, 10 | } 11 | vit_image_size: 224 12 | network_name: UNetSD_temporal 13 | resume: true 14 | resume_step: 141000 15 | seed: 14 16 | num_workers: 1 17 | mvs_visual: True 18 | chunk_size: 1 19 | resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth" 20 | log_dir: 'outputs' 21 | num_steps: 1 -------------------------------------------------------------------------------- /demo_video/blackswan.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/blackswan.mp4 -------------------------------------------------------------------------------- /demo_video/captions_list.txt: -------------------------------------------------------------------------------- 1 | video_10000178, 一架客机正在迪拜机场的停机坪上滑行, an airliner is taxiing on the tarmac at Dubai Airport 2 | video_5360763, 鸽子坐在房子附近的街道上, Pigeon sitting on the street near the house 3 | video_8800, 暹罗斗鱼(Betta splendens)在小玻璃碗和五彩石头中游泳,微距视频|||Siamese Fighting Fish (Betta splendens) swimming in a small glass bowl with multicolored piece of stone, Macro Video 4 | # A colorful and beautiful fish swimming in a small glass bowl with multicolored piece of stone, Macro Video 5 | # A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish 6 | -------------------------------------------------------------------------------- /demo_video/landscape_painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/landscape_painting.png -------------------------------------------------------------------------------- /demo_video/moon_on_water.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/moon_on_water.jpg -------------------------------------------------------------------------------- /demo_video/motion_transfer.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/motion_transfer.mp4 -------------------------------------------------------------------------------- /demo_video/qibaishi_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/qibaishi_01.png -------------------------------------------------------------------------------- /demo_video/src_single_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/src_single_sketch.png -------------------------------------------------------------------------------- /demo_video/style/Bingxueqiyuan.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/Bingxueqiyuan.jpeg -------------------------------------------------------------------------------- /demo_video/style/fangao_01.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_01.jpeg -------------------------------------------------------------------------------- /demo_video/style/fangao_02.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_02.jpeg -------------------------------------------------------------------------------- /demo_video/style/fangao_03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_03.jpeg -------------------------------------------------------------------------------- /demo_video/sunflower.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/sunflower.png -------------------------------------------------------------------------------- /demo_video/sunflower_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/sunflower_sketch.png -------------------------------------------------------------------------------- /demo_video/tall_buildings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/tall_buildings.png -------------------------------------------------------------------------------- /demo_video/tennis.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/tennis.mp4 -------------------------------------------------------------------------------- /demo_video/video_10000178.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_10000178.mp4 -------------------------------------------------------------------------------- /demo_video/video_5360763.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_5360763.mp4 -------------------------------------------------------------------------------- /demo_video/video_8800.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_8800.mp4 -------------------------------------------------------------------------------- /demo_video/wash_painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/wash_painting.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: VideoComposer 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - certifi=2022.12.7=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.2=h6a678d5_6 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1t=h7f8727e_0 16 | - pip=23.0.1=py38h06a4308_0 17 | - python=3.8.16=h7a1cb2a_3 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py38h06a4308_0 20 | - sqlite=3.41.1=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.38.4=py38h06a4308_0 23 | - xz=5.2.10=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - absl-py==1.4.0 27 | - aiohttp==3.8.4 28 | - aiosignal==1.3.1 29 | - aliyun-python-sdk-core==2.13.36 30 | - aliyun-python-sdk-kms==2.16.0 31 | - asttokens==2.2.1 32 | - async-timeout==4.0.2 33 | - attrs==22.2.0 34 | - backcall==0.2.0 35 | - cachetools==5.3.0 36 | - cffi==1.15.1 37 | - chardet==5.1.0 38 | - charset-normalizer==3.1.0 39 | - clean-fid==0.1.35 40 | - click==8.1.3 41 | - cmake==3.26.0 42 | - crcmod==1.7 43 | - cryptography==39.0.2 44 | - decorator==5.1.1 45 | - decord==0.6.0 46 | - easydict==1.10 47 | - einops==0.6.0 48 | - executing==1.2.0 49 | - fairscale==0.4.6 50 | - filelock==3.10.2 51 | - flash-attn==0.2.0 52 | - frozenlist==1.3.3 53 | - fsspec==2023.3.0 54 | - ftfy==6.1.1 55 | - future==0.18.3 56 | - google-auth==2.16.2 57 | - google-auth-oauthlib==0.4.6 58 | - grpcio==1.51.3 59 | - huggingface-hub==0.13.3 60 | - idna==3.4 61 | - imageio==2.15.0 62 | - importlib-metadata==6.1.0 63 | - ipdb==0.13.13 64 | - ipython==8.11.0 65 | - jedi==0.18.2 66 | - jmespath==0.10.0 67 | - joblib==1.2.0 68 | - lazy-loader==0.2 69 | - markdown==3.4.3 70 | - markupsafe==2.1.2 71 | - matplotlib-inline==0.1.6 72 | - 
motion-vector-extractor==1.0.6 73 | - multidict==6.0.4 74 | - mypy-extensions==1.0.0 75 | - networkx==3.1 76 | - numpy==1.24.2 77 | - oauthlib==3.2.2 78 | - open-clip-torch==2.0.2 79 | - openai-clip==1.0.1 80 | - opencv-python==4.5.5.64 81 | - opencv-python-headless==4.7.0.68 82 | - oss2==2.17.0 83 | - packaging==23.0 84 | - parso==0.8.3 85 | - pexpect==4.8.0 86 | - pickleshare==0.7.5 87 | - pillow==9.4.0 88 | - pkgconfig==1.5.5 89 | - prompt-toolkit==3.0.38 90 | - protobuf==4.22.1 91 | - ptyprocess==0.7.0 92 | - pure-eval==0.2.2 93 | - pyasn1==0.4.8 94 | - pyasn1-modules==0.2.8 95 | - pycparser==2.21 96 | - pycryptodome==3.17 97 | - pydeprecate==0.3.1 98 | - pygments==2.14.0 99 | - pynvml==11.5.0 100 | - pyre-extensions==0.0.23 101 | - pytorch-lightning==1.4.2 102 | - pywavelets==1.4.1 103 | - pyyaml==6.0 104 | - regex==2023.3.22 105 | - requests==2.28.2 106 | - requests-oauthlib==1.3.1 107 | - rotary-embedding-torch==0.2.1 108 | - rsa==4.9 109 | - sacremoses==0.0.53 110 | - scikit-image==0.20.0 111 | - scikit-learn==1.2.2 112 | - scikit-video==1.1.11 113 | - scipy==1.9.1 114 | - simplejson==3.18.4 115 | - six==1.16.0 116 | - sklearn==0.0.post4 117 | - stack-data==0.6.2 118 | - tensorboard==2.12.0 119 | - tensorboard-data-server==0.7.0 120 | - tensorboard-plugin-wit==1.8.1 121 | - threadpoolctl==3.1.0 122 | - tifffile==2023.4.12 123 | - tokenizers==0.12.1 124 | - tomli==2.0.1 125 | - torch==1.12.0+cu113 126 | - torchaudio==0.12.0+cu113 127 | - torchmetrics==0.6.0 128 | - torchvision==0.13.0+cu113 129 | - tqdm==4.65.0 130 | - traitlets==5.9.0 131 | - transformers==4.18.0 132 | - triton==2.0.0.dev20221120 133 | - typing-extensions==4.5.0 134 | - typing-inspect==0.8.0 135 | - urllib3==1.26.15 136 | - wcwidth==0.2.6 137 | - werkzeug==2.2.3 138 | - xformers==0.0.13 139 | - yarl==1.8.2 140 | - zipp==3.15.0 141 | -------------------------------------------------------------------------------- /gen_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | from PIL import Image 6 | from io import BytesIO 7 | from functools import partial 8 | import torchvision.transforms as T 9 | import torchvision.transforms.functional as TF 10 | from torchvision.transforms.functional import InterpolationMode 11 | 12 | import artist.data as data 13 | import artist.ops as ops 14 | from tools.annotator.sketch import pidinet_bsd, sketch_simplification_gan 15 | 16 | 17 | def random_resize(img, size): 18 | img = [TF.resize(u, size, interpolation=random.choice([ 19 | InterpolationMode.BILINEAR, 20 | InterpolationMode.BICUBIC, 21 | InterpolationMode.LANCZOS])) for u in img] 22 | return img 23 | 24 | 25 | def gen_sketch(image_path, gpu=0, misc_size=384): 26 | 27 | sketch_mean = [0.485, 0.456, 0.406] 28 | sketch_std = [0.229, 0.224, 0.225] 29 | 30 | pidinet = pidinet_bsd(pretrained=True, vanilla_cnn=True).eval().requires_grad_(False).to(gpu) 31 | cleaner = sketch_simplification_gan(pretrained=True).eval().requires_grad_(False).to(gpu) 32 | pidi_mean = torch.tensor(sketch_mean).view(1, -1, 1, 1).to(gpu) 33 | pidi_std = torch.tensor(sketch_std).view(1, -1, 1, 1).to(gpu) 34 | 35 | misc_transforms = data.Compose([ 36 | T.Lambda(partial(random_resize, size=misc_size)), 37 | data.CenterCropV2(misc_size), 38 | data.ToTensor()]) 39 | 40 | image = Image.open(open(image_path, mode='rb')).convert('RGB') 41 | image = misc_transforms([image]) # 42 | image = image.to(gpu) 43 | 44 | sketch = pidinet(image.sub(pidi_mean).div_(pidi_std)) 
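# the sketch-simplification GAN expects dark strokes on a light background (see SketchSimplification's docstring), so the PiDiNet edge map is inverted on the next line before cleaning and inverted back afterwards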
45 | sketch = 1.0 - cleaner(1.0 - sketch) 46 | sketch = sketch.cpu() 47 | 48 | sketch = sketch[0][0] 49 | sketch = (sketch.numpy()*255).astype('uint8') 50 | file_name = os.path.basename(image_path) 51 | save_pth = 'source/inputs/' + file_name.replace('.', '_sketch.') 52 | cv2.imwrite(save_pth, sketch) 53 | 54 | 55 | gen_sketch(image_path='demo_video/sunflower.png') 56 | -------------------------------------------------------------------------------- /model_weights/readme.md: -------------------------------------------------------------------------------- 1 | ### 1. Installation 2 | 3 | Please download the model and place them here. 4 | 5 | ``` 6 | |--model_weights/ 7 | | |--non_ema_228000.pth 8 | | |--midas_v3_dpt_large.pth 9 | | |--open_clip_pytorch_model.bin 10 | | |--sketch_simplification_gan.pth 11 | | |--table5_pidinet.pth 12 | | |--v2-1_512-ema-pruned.ckpt 13 | ``` -------------------------------------------------------------------------------- /run_bash.sh: -------------------------------------------------------------------------------- 1 | # Exp01, inference different conditions from a video 2 | python run_net.py\ 3 | --cfg configs/exp01_vidcomposer_full.yaml\ 4 | --seed 9999\ 5 | --input_video "demo_video/blackswan.mp4"\ 6 | --input_text_desc "A black swan swam in the water" 7 | 8 | 9 | # Exp02, Motion Transfer from a video to a Single Image 10 | python run_net.py\ 11 | --cfg configs/exp02_motion_transfer.yaml\ 12 | --seed 9999\ 13 | --input_video "demo_video/motion_transfer.mp4"\ 14 | --image_path "demo_video/sunflower.png"\ 15 | --input_text_desc "A sunflower in a field of flowers" 16 | 17 | 18 | python run_net.py\ 19 | --cfg configs/exp02_motion_transfer_vs_style.yaml\ 20 | --seed 9999\ 21 | --input_video "demo_video/motion_transfer.mp4"\ 22 | --image_path "demo_video/moon_on_water.jpg"\ 23 | --style_image "demo_video/moon_on_water.jpg"\ 24 | --input_text_desc "A beautiful big silver moon on the water" 25 | 26 | 27 | python run_net.py\ 28 | --cfg configs/exp02_motion_transfer_vs_style.yaml\ 29 | --seed 888\ 30 | --input_video "demo_video/motion_transfer.mp4"\ 31 | --image_path "demo_video/style/fangao_01.jpeg"\ 32 | --style_image "demo_video/style/fangao_01.jpeg"\ 33 | --input_text_desc "Beneath Van Gogh's Starry Sky" 34 | 35 | 36 | # Exp03, Single Sketch to videos with style 37 | python run_net.py\ 38 | --cfg configs/exp03_sketch2video_style.yaml\ 39 | --seed 8888\ 40 | --sketch_path "demo_video/src_single_sketch.png"\ 41 | --style_image "demo_video/style/qibaishi_01.png"\ 42 | --input_text_desc "Red-backed Shrike lanius collurio" 43 | 44 | # Exp04, Single Sketch to videos without style input 45 | python run_net.py\ 46 | --cfg configs/exp04_sketch2video_wo_style.yaml\ 47 | --seed 144\ 48 | --sketch_path "demo_video/src_single_sketch.png"\ 49 | --input_text_desc "A little bird is standing on a branch" 50 | 51 | 52 | # Exp05, Depth to video without style 53 | python run_net.py\ 54 | --cfg configs/exp05_text_depths_wo_style.yaml\ 55 | --seed 9999\ 56 | --input_video demo_video/tennis.mp4\ 57 | --input_text_desc "Ironman is fighting against the enemy, big fire in the background, photorealistic" 58 | 59 | 60 | # Exp06, Depth to video with style 61 | python run_net.py\ 62 | --cfg configs/exp06_text_depths_vs_style.yaml\ 63 | --seed 9999\ 64 | --input_video demo_video/tennis.mp4\ 65 | --style_image "demo_video/style/fangao_01.jpeg"\ 66 | --input_text_desc "Van Gogh played tennis under the stars" 67 | 68 | 69 | # Exp07, Depth to video without style 70 | python run_net.py\ 71 | --cfg 
configs/exp07_text_image_wo_style.yaml\ 72 | --seed 9999\ 73 | --input_video demo_video/blackswan.mp4\ 74 | --input_text_desc "Van Gogh played tennis under the stars" 75 | 76 | 77 | # If you want to , use CUDA_VISIBLE_DEVICES=0 78 | -------------------------------------------------------------------------------- /run_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import os.path as osp 4 | # sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) 5 | import logging 6 | import numpy as np 7 | import copy 8 | import random 9 | import json 10 | import math 11 | import itertools 12 | 13 | import logging 14 | # logger = logging.get_logger(__name__) 15 | 16 | from utils.config import Config 17 | from tools.videocomposer.inference_multi import inference_multi 18 | from tools.videocomposer.inference_single import inference_single 19 | 20 | 21 | def main(): 22 | """ 23 | Main function to spawn the train and test process. 24 | """ 25 | cfg = Config(load=True) 26 | if hasattr(cfg, "TASK_TYPE") and cfg.TASK_TYPE == "MULTI_TASK": 27 | logging.info("TASK TYPE: %s " % cfg.TASK_TYPE) 28 | inference_multi(cfg.cfg_dict) 29 | elif hasattr(cfg, "TASK_TYPE") and cfg.TASK_TYPE == "SINGLE_TASK": 30 | logging.info("TASK TYPE: %s " % cfg.TASK_TYPE) 31 | inference_single(cfg.cfg_dict) 32 | else: 33 | logging.info('Not suport task %s' % (cfg.TASK_TYPE)) 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /source/fig01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig01.jpg -------------------------------------------------------------------------------- /source/fig02_framwork.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig02_framwork.jpg -------------------------------------------------------------------------------- /source/fig03_image-to-video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig03_image-to-video.jpg -------------------------------------------------------------------------------- /source/fig04_hand-crafted-motions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig04_hand-crafted-motions.jpg -------------------------------------------------------------------------------- /source/fig05_video-inpainting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig05_video-inpainting.jpg -------------------------------------------------------------------------------- /source/fig06_sketch-to-video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig06_sketch-to-video.jpg -------------------------------------------------------------------------------- 
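run_net.py above is a thin dispatcher: it loads a config and routes to inference_multi or inference_single based on TASK_TYPE, so the commands in run_bash.sh can also be driven from Python. The sketch below is illustrative only; it assumes the exp02 config declares TASK_TYPE as SINGLE_TASK and that Config(load=True) reads the same command-line flags that run_bash.sh passes (the config files and utils/config.py are not reproduced in full here).

import sys

from utils.config import Config
from tools.videocomposer.inference_single import inference_single

# emulate the Exp02 motion-transfer command from run_bash.sh;
# Config(load=True) is expected to pick these flags up from sys.argv
sys.argv = [
    "run_net.py",
    "--cfg", "configs/exp02_motion_transfer.yaml",
    "--seed", "9999",
    "--input_video", "demo_video/motion_transfer.mp4",
    "--image_path", "demo_video/sunflower.png",
    "--input_text_desc", "A sunflower in a field of flowers",
]

cfg = Config(load=True)
if getattr(cfg, "TASK_TYPE", None) == "SINGLE_TASK":  # same branch run_net.main() takes
    inference_single(cfg.cfg_dict)

As run_bash.sh hints, setting CUDA_VISIBLE_DEVICES=0 in the environment restricts such a run to a single GPU.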
/source/results/exp02_motion_transfer-S00009.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S00009.gif -------------------------------------------------------------------------------- /source/results/exp02_motion_transfer-S09999-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S09999-0.gif -------------------------------------------------------------------------------- /source/results/exp02_motion_transfer-S09999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S09999.gif -------------------------------------------------------------------------------- /source/results/exp03_sketch2video_style-S09999.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp03_sketch2video_style-S09999.gif -------------------------------------------------------------------------------- /source/results/exp04_sketch2video_wo_style-S00144-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144-1.gif -------------------------------------------------------------------------------- /source/results/exp04_sketch2video_wo_style-S00144-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144-2.gif -------------------------------------------------------------------------------- /source/results/exp04_sketch2video_wo_style-S00144.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144.gif -------------------------------------------------------------------------------- /source/results/exp05_text_depths_wo_style-S09999-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-0.gif -------------------------------------------------------------------------------- /source/results/exp05_text_depths_wo_style-S09999-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-1.gif -------------------------------------------------------------------------------- /source/results/exp05_text_depths_wo_style-S09999-2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-2.gif -------------------------------------------------------------------------------- /source/results/exp06_text_depths_vs_style-S09999-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-0.gif -------------------------------------------------------------------------------- /source/results/exp06_text_depths_vs_style-S09999-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-1.gif -------------------------------------------------------------------------------- /source/results/exp06_text_depths_vs_style-S09999-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-2.gif -------------------------------------------------------------------------------- /source/results/exp06_text_depths_vs_style-S09999-3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-3.gif -------------------------------------------------------------------------------- /tools/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/.DS_Store -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/__init__.py -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/.DS_Store -------------------------------------------------------------------------------- /tools/annotator/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from tools.annotator.util import 
HWC3 5 | # import gradio as gr 6 | 7 | class CannyDetector: 8 | def __call__(self, img, low_threshold = None, high_threshold = None, random_threshold = True): 9 | 10 | ### Convert to numpy 11 | if isinstance(img, torch.Tensor): # (h, w, c) 12 | img = img.cpu().numpy() 13 | img_np = cv2.convertScaleAbs((img * 255.)) 14 | elif isinstance(img, np.ndarray): # (h, w, c) 15 | img_np = img # we assume values are in the range from 0 to 255. 16 | else: 17 | assert False 18 | 19 | ### Select the threshold 20 | if (low_threshold is None) and (high_threshold is None): 21 | median_intensity = np.median(img_np) 22 | if random_threshold is False: 23 | low_threshold = int(max(0, (1 - 0.33) * median_intensity)) 24 | high_threshold = int(min(255, (1 + 0.33) * median_intensity)) 25 | else: 26 | random_canny = np.random.uniform(0.1, 0.4) 27 | # Might try other values 28 | low_threshold = int(max(0, (1 - random_canny) * median_intensity)) 29 | high_threshold = 2 * low_threshold 30 | 31 | ### Detect canny edge 32 | canny_edge = cv2.Canny(img_np, low_threshold, high_threshold) 33 | ### Convert to 3 channels 34 | # canny_edge = HWC3(canny_edge) 35 | 36 | canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim = -1).float().cuda() / 255.0 37 | # canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0) 38 | # canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone() 39 | # return cv2.Canny(img, low_threshold, high_threshold) 40 | return canny_condition -------------------------------------------------------------------------------- /tools/annotator/canny/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/canny/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/histogram/__init__.py: -------------------------------------------------------------------------------- 1 | from .palette import * -------------------------------------------------------------------------------- /tools/annotator/histogram/palette.py: -------------------------------------------------------------------------------- 1 | r"""Modified from ``https://github.com/sergeyk/rayleigh''. 2 | """ 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb 7 | from skimage.io import imsave 8 | from sklearn.metrics import euclidean_distances 9 | 10 | __all__ = ['Palette'] 11 | 12 | def rgb2hex(rgb): 13 | return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb]) 14 | 15 | def hex2rgb(hex): 16 | rgb = hex.strip('#') 17 | fn = lambda u: round(int(u, 16) / 255.0, 5) 18 | return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6]) 19 | 20 | class Palette(object): 21 | r"""Create a color palette (codebook) in the form of a 2D grid of colors. 22 | Further, the rightmost column has num_hues gradations from black to white. 23 | 24 | Parameters: 25 | num_hues: number of colors with full lightness and saturation, in the middle. 26 | num_sat: number of rows above middle row that show the same hues with decreasing saturation. 
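num_light: number of rows below the middle row that show the same hues with decreasing lightness.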
27 | """ 28 | def __init__(self, num_hues=11, num_sat=5, num_light=4): 29 | n = num_sat + 2 * num_light 30 | 31 | # hues 32 | if num_hues == 8: 33 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1)) 34 | elif num_hues == 9: 35 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1)) 36 | elif num_hues == 10: 37 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1)) 38 | elif num_hues == 11: 39 | hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1)) 40 | else: 41 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1)) 42 | 43 | # saturations 44 | sats = np.hstack(( 45 | np.linspace(0, 1, num_sat + 2)[1:-1], 46 | 1, 47 | [1] * num_light, 48 | [0.4] * (num_light - 1))) 49 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues)) 50 | 51 | # lights 52 | lights = np.hstack(( 53 | [1] * num_sat, 54 | 1, 55 | np.linspace(1, 0.2, num_light + 2)[1:-1], 56 | np.linspace(1, 0.2, num_light + 2)[1:-2])) 57 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues)) 58 | 59 | # colors 60 | rgb = hsv2rgb(np.dstack([hues, sats, lights])) 61 | gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3)) 62 | self.thumbnail = np.hstack([rgb, gray]) 63 | 64 | # flatten 65 | rgb = rgb.T.reshape(3, -1).T 66 | gray = gray.T.reshape(3, -1).T 67 | self.rgb = np.vstack((rgb, gray)) 68 | self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze() 69 | self.hex = [rgb2hex(u) for u in self.rgb] 70 | self.lab_dists = euclidean_distances(self.lab, squared=True) 71 | 72 | def histogram(self, rgb_img, sigma=20): 73 | # compute histogram 74 | lab = rgb2lab(rgb_img).reshape((-1, 3)) 75 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1) 76 | hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0] 77 | 78 | # smooth histogram 79 | if sigma > 0: 80 | weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2)) 81 | weight = weight / weight.sum(1)[:, np.newaxis] 82 | hist = (weight * hist).sum(1) 83 | hist[hist < 1e-5] = 0 84 | return hist 85 | 86 | def get_palette_image(self, hist, percentile=90, width=200, height=50): 87 | # curate histogram 88 | ind = np.argsort(-hist) 89 | ind = ind[hist[ind] > np.percentile(hist, percentile)] 90 | hist = hist[ind] / hist[ind].sum() 91 | 92 | # draw palette 93 | nums = np.array(hist * width, dtype=int) 94 | array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)]) 95 | array = np.tile(array[np.newaxis, :, :], (height, 1, 1)) 96 | if array.shape[1] < width: 97 | array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1) 98 | return array 99 | 100 | def quantize_image(self, rgb_img): 101 | lab = rgb2lab(rgb_img).reshape((-1, 3)) 102 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1) 103 | quantized_lab = self.lab[min_ind] 104 | img = lab2rgb(quantized_lab.reshape(rgb_img.shape)) 105 | return img 106 | 107 | def export(self, dirname): 108 | if not osp.exists(dirname): 109 | os.makedirs(dirname) 110 | 111 | # save thumbnail 112 | imsave(osp.join(dirname, 'palette.png'), self.thumbnail) 113 | 114 | # save html 115 | with open(osp.join(dirname, 'palette.html'), 'w') as f: 116 | html = ''' 117 | 126 | ''' 127 | for row in self.thumbnail: 128 | for col in row: 129 | html += '\n'.format(rgb2hex(col)) 130 | html += '
\n' 131 | f.write(html) 132 | -------------------------------------------------------------------------------- /tools/annotator/sketch/__init__.py: -------------------------------------------------------------------------------- 1 | from .pidinet import * 2 | from .sketch_simplification import * -------------------------------------------------------------------------------- /tools/annotator/sketch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/sketch/__pycache__/pidinet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/pidinet.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/sketch/__pycache__/sketch_simplification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/sketch_simplification.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/sketch/sketch_simplification.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''. 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | # from canvas import DOWNLOAD_TO_CACHE 9 | from artist import DOWNLOAD_TO_CACHE 10 | 11 | __all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse', 12 | 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'] 13 | 14 | class SketchSimplification(nn.Module): 15 | r"""NOTE: 16 | 1. Input image should has only one gray channel. 17 | 2. Input image size should be divisible by 8. 18 | 3. Sketch in the input/output image is in dark color while background in light color. 
19 | """ 20 | def __init__(self, mean, std): 21 | assert isinstance(mean, float) and isinstance(std, float) 22 | super(SketchSimplification, self).__init__() 23 | self.mean = mean 24 | self.std = std 25 | 26 | # layers 27 | self.layers = nn.Sequential( 28 | nn.Conv2d(1, 48, 5, 2, 2), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(48, 128, 3, 1, 1), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(128, 128, 3, 1, 1), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(128, 128, 3, 2, 1), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(128, 256, 3, 1, 1), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(256, 256, 3, 1, 1), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(256, 256, 3, 2, 1), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(256, 512, 3, 1, 1), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(512, 1024, 3, 1, 1), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(1024, 1024, 3, 1, 1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(1024, 1024, 3, 1, 1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(1024, 1024, 3, 1, 1), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(1024, 512, 3, 1, 1), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(512, 256, 3, 1, 1), 55 | nn.ReLU(inplace=True), 56 | nn.ConvTranspose2d(256, 256, 4, 2, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(256, 256, 3, 1, 1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(256, 128, 3, 1, 1), 61 | nn.ReLU(inplace=True), 62 | nn.ConvTranspose2d(128, 128, 4, 2, 1), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(128, 128, 3, 1, 1), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(128, 48, 3, 1, 1), 67 | nn.ReLU(inplace=True), 68 | nn.ConvTranspose2d(48, 48, 4, 2, 1), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(48, 24, 3, 1, 1), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(24, 1, 3, 1, 1), 73 | nn.Sigmoid()) 74 | 75 | def forward(self, x): 76 | r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color. 
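Returns a tensor of the same shape, also within [0, 1]; the final Sigmoid layer keeps the simplified sketch in that range.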
77 | """ 78 | x = (x - self.mean) / self.std 79 | return self.layers(x) 80 | 81 | def sketch_simplification_gan(pretrained=False): 82 | model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797) 83 | if pretrained: 84 | model.load_state_dict(torch.load( 85 | DOWNLOAD_TO_CACHE('./model_weights/sketch_simplification_gan.pth'), 86 | map_location='cpu')) 87 | return model 88 | 89 | def sketch_simplification_mse(pretrained=False): 90 | model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507) 91 | if pretrained: 92 | model.load_state_dict(torch.load( 93 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'), 94 | map_location='cpu')) 95 | return model 96 | 97 | def sketch_to_pencil_v1(pretrained=False): 98 | model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048) 99 | if pretrained: 100 | model.load_state_dict(torch.load( 101 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'), 102 | map_location='cpu')) 103 | return model 104 | 105 | def sketch_to_pencil_v2(pretrained=False): 106 | model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571) 107 | if pretrained: 108 | model.load_state_dict(torch.load( 109 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'), 110 | map_location='cpu')) 111 | return model 112 | -------------------------------------------------------------------------------- /tools/annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 6 | 7 | def HWC3(x): 8 | assert x.dtype == np.uint8 9 | if x.ndim == 2: 10 | x = x[:, :, None] 11 | assert x.ndim == 3 12 | H, W, C = x.shape 13 | assert C == 1 or C == 3 or C == 4 14 | if C == 3: 15 | return x 16 | if C == 1: 17 | return np.concatenate([x, x, x], axis=2) 18 | if C == 4: 19 | color = x[:, :, 0:3].astype(np.float32) 20 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 21 | y = color * alpha + 255.0 * (1.0 - alpha) 22 | y = y.clip(0, 255).astype(np.uint8) 23 | return y 24 | 25 | 26 | def resize_image(input_image, resolution): 27 | H, W, C = input_image.shape 28 | H = float(H) 29 | W = float(W) 30 | k = float(resolution) / min(H, W) 31 | H *= k 32 | W *= k 33 | H = int(np.round(H / 64.0)) * 64 34 | W = int(np.round(W / 64.0)) * 64 35 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 36 | return img -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/inference_multi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/inference_multi.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/inference_single.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/inference_single.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/mha_flash.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/mha_flash.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/__pycache__/unet_sd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/unet_sd.cpython-38.pyc -------------------------------------------------------------------------------- /tools/videocomposer/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | __all__ = ['AutoencoderKL'] 7 | 8 | def nonlinearity(x): 9 | # swish 10 | return x*torch.sigmoid(x) 11 | 12 | def Normalize(in_channels, num_groups=32): 13 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 14 | 15 | class DiagonalGaussianDistribution(object): 16 | def __init__(self, parameters, deterministic=False): 17 | self.parameters = parameters 18 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 19 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 20 | self.deterministic = deterministic 21 | self.std = torch.exp(0.5 * self.logvar) 22 | self.var = torch.exp(self.logvar) 23 | if self.deterministic: 24 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 25 | 26 | def sample(self): 27 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 28 | return x 29 | 30 | def kl(self, other=None): 31 | if self.deterministic: 32 | return torch.Tensor([0.]) 33 | else: 34 | if other is None: 35 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 36 | + self.var - 1.0 - self.logvar, 37 | dim=[1, 2, 3]) 38 | else: 39 | return 0.5 * torch.sum( 40 | torch.pow(self.mean - other.mean, 2) / other.var 41 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 42 | dim=[1, 2, 3]) 43 | 44 | def nll(self, sample, dims=[1,2,3]): 45 | if self.deterministic: 46 | return torch.Tensor([0.]) 47 | logtwopi = np.log(2.0 * np.pi) 48 | return 0.5 * torch.sum( 49 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / 
self.var, 50 | dim=dims) 51 | 52 | def mode(self): 53 | return self.mean 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0,1,0,1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | class ResnetBlock(nn.Module): 77 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 78 | dropout, temb_channels=512): 79 | super().__init__() 80 | self.in_channels = in_channels 81 | out_channels = in_channels if out_channels is None else out_channels 82 | self.out_channels = out_channels 83 | self.use_conv_shortcut = conv_shortcut 84 | 85 | self.norm1 = Normalize(in_channels) 86 | self.conv1 = torch.nn.Conv2d(in_channels, 87 | out_channels, 88 | kernel_size=3, 89 | stride=1, 90 | padding=1) 91 | if temb_channels > 0: 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | if temb is not None: 122 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 123 | 124 | h = self.norm2(h) 125 | h = nonlinearity(h) 126 | h = self.dropout(h) 127 | h = self.conv2(h) 128 | 129 | if self.in_channels != self.out_channels: 130 | if self.use_conv_shortcut: 131 | x = self.conv_shortcut(x) 132 | else: 133 | x = self.nin_shortcut(x) 134 | 135 | return x+h 136 | 137 | 138 | class AttnBlock(nn.Module): 139 | def __init__(self, in_channels): 140 | super().__init__() 141 | self.in_channels = in_channels 142 | 143 | self.norm = Normalize(in_channels) 144 | self.q = torch.nn.Conv2d(in_channels, 145 | in_channels, 146 | kernel_size=1, 147 | stride=1, 148 | padding=0) 149 | self.k = torch.nn.Conv2d(in_channels, 150 | in_channels, 151 | kernel_size=1, 152 | stride=1, 153 | padding=0) 154 | self.v = torch.nn.Conv2d(in_channels, 155 | in_channels, 156 | kernel_size=1, 157 | stride=1, 158 | padding=0) 159 | self.proj_out = torch.nn.Conv2d(in_channels, 160 | in_channels, 161 | kernel_size=1, 162 | stride=1, 163 | padding=0) 164 | 165 | def forward(self, x): 166 | h_ = x 167 | h_ = self.norm(h_) 168 | q = self.q(h_) 169 | k = self.k(h_) 170 | v = self.v(h_) 171 | 172 | # compute attention 173 | b,c,h,w = q.shape 174 | q = q.reshape(b,c,h*w) 175 | q = q.permute(0,2,1) # b,hw,c 176 | k = k.reshape(b,c,h*w) # b,c,hw 177 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 178 | w_ = w_ * (int(c)**(-0.5)) 179 | w_ = torch.nn.functional.softmax(w_, dim=2) 180 | 181 | # attend to values 182 | v = 
v.reshape(b,c,h*w) 183 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 184 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 185 | h_ = h_.reshape(b,c,h,w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x+h_ 190 | 191 | class AttnBlock(nn.Module): 192 | def __init__(self, in_channels): 193 | super().__init__() 194 | self.in_channels = in_channels 195 | 196 | self.norm = Normalize(in_channels) 197 | self.q = torch.nn.Conv2d(in_channels, 198 | in_channels, 199 | kernel_size=1, 200 | stride=1, 201 | padding=0) 202 | self.k = torch.nn.Conv2d(in_channels, 203 | in_channels, 204 | kernel_size=1, 205 | stride=1, 206 | padding=0) 207 | self.v = torch.nn.Conv2d(in_channels, 208 | in_channels, 209 | kernel_size=1, 210 | stride=1, 211 | padding=0) 212 | self.proj_out = torch.nn.Conv2d(in_channels, 213 | in_channels, 214 | kernel_size=1, 215 | stride=1, 216 | padding=0) 217 | 218 | def forward(self, x): 219 | h_ = x 220 | h_ = self.norm(h_) 221 | q = self.q(h_) 222 | k = self.k(h_) 223 | v = self.v(h_) 224 | 225 | # compute attention 226 | b,c,h,w = q.shape 227 | q = q.reshape(b,c,h*w) 228 | q = q.permute(0,2,1) # b,hw,c 229 | k = k.reshape(b,c,h*w) # b,c,hw 230 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 231 | w_ = w_ * (int(c)**(-0.5)) 232 | w_ = torch.nn.functional.softmax(w_, dim=2) 233 | 234 | # attend to values 235 | v = v.reshape(b,c,h*w) 236 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 237 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 238 | h_ = h_.reshape(b,c,h,w) 239 | 240 | h_ = self.proj_out(h_) 241 | 242 | return x+h_ 243 | 244 | class Upsample(nn.Module): 245 | def __init__(self, in_channels, with_conv): 246 | super().__init__() 247 | self.with_conv = with_conv 248 | if self.with_conv: 249 | self.conv = torch.nn.Conv2d(in_channels, 250 | in_channels, 251 | kernel_size=3, 252 | stride=1, 253 | padding=1) 254 | 255 | def forward(self, x): 256 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 257 | if self.with_conv: 258 | x = self.conv(x) 259 | return x 260 | 261 | 262 | class Downsample(nn.Module): 263 | def __init__(self, in_channels, with_conv): 264 | super().__init__() 265 | self.with_conv = with_conv 266 | if self.with_conv: 267 | # no asymmetric padding in torch conv, must do it ourselves 268 | self.conv = torch.nn.Conv2d(in_channels, 269 | in_channels, 270 | kernel_size=3, 271 | stride=2, 272 | padding=0) 273 | 274 | def forward(self, x): 275 | if self.with_conv: 276 | pad = (0,1,0,1) 277 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 278 | x = self.conv(x) 279 | else: 280 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 281 | return x 282 | 283 | class Encoder(nn.Module): 284 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 285 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 286 | resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", 287 | **ignore_kwargs): 288 | super().__init__() 289 | if use_linear_attn: attn_type = "linear" 290 | self.ch = ch 291 | self.temb_ch = 0 292 | self.num_resolutions = len(ch_mult) 293 | self.num_res_blocks = num_res_blocks 294 | self.resolution = resolution 295 | self.in_channels = in_channels 296 | 297 | # downsampling 298 | self.conv_in = torch.nn.Conv2d(in_channels, 299 | self.ch, 300 | kernel_size=3, 301 | stride=1, 302 | padding=1) 303 | 304 | curr_res = resolution 305 | in_ch_mult = 
(1,)+tuple(ch_mult) 306 | self.in_ch_mult = in_ch_mult 307 | self.down = nn.ModuleList() 308 | for i_level in range(self.num_resolutions): 309 | block = nn.ModuleList() 310 | attn = nn.ModuleList() 311 | block_in = ch*in_ch_mult[i_level] 312 | block_out = ch*ch_mult[i_level] 313 | for i_block in range(self.num_res_blocks): 314 | block.append(ResnetBlock(in_channels=block_in, 315 | out_channels=block_out, 316 | temb_channels=self.temb_ch, 317 | dropout=dropout)) 318 | block_in = block_out 319 | if curr_res in attn_resolutions: 320 | attn.append(AttnBlock(block_in)) 321 | down = nn.Module() 322 | down.block = block 323 | down.attn = attn 324 | if i_level != self.num_resolutions-1: 325 | down.downsample = Downsample(block_in, resamp_with_conv) 326 | curr_res = curr_res // 2 327 | self.down.append(down) 328 | 329 | # middle 330 | self.mid = nn.Module() 331 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 332 | out_channels=block_in, 333 | temb_channels=self.temb_ch, 334 | dropout=dropout) 335 | self.mid.attn_1 = AttnBlock(block_in) 336 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 337 | out_channels=block_in, 338 | temb_channels=self.temb_ch, 339 | dropout=dropout) 340 | 341 | # end 342 | self.norm_out = Normalize(block_in) 343 | self.conv_out = torch.nn.Conv2d(block_in, 344 | 2*z_channels if double_z else z_channels, 345 | kernel_size=3, 346 | stride=1, 347 | padding=1) 348 | 349 | def forward(self, x): 350 | # timestep embedding 351 | temb = None 352 | 353 | # downsampling 354 | hs = [self.conv_in(x)] 355 | for i_level in range(self.num_resolutions): 356 | for i_block in range(self.num_res_blocks): 357 | h = self.down[i_level].block[i_block](hs[-1], temb) 358 | if len(self.down[i_level].attn) > 0: 359 | h = self.down[i_level].attn[i_block](h) 360 | hs.append(h) 361 | if i_level != self.num_resolutions-1: 362 | hs.append(self.down[i_level].downsample(hs[-1])) 363 | 364 | # middle 365 | h = hs[-1] 366 | h = self.mid.block_1(h, temb) 367 | h = self.mid.attn_1(h) 368 | h = self.mid.block_2(h, temb) 369 | 370 | # end 371 | h = self.norm_out(h) 372 | h = nonlinearity(h) 373 | h = self.conv_out(h) 374 | return h 375 | 376 | 377 | class Decoder(nn.Module): 378 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 379 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 380 | resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, 381 | attn_type="vanilla", **ignorekwargs): 382 | super().__init__() 383 | if use_linear_attn: attn_type = "linear" 384 | self.ch = ch 385 | self.temb_ch = 0 386 | self.num_resolutions = len(ch_mult) 387 | self.num_res_blocks = num_res_blocks 388 | self.resolution = resolution 389 | self.in_channels = in_channels 390 | self.give_pre_end = give_pre_end 391 | self.tanh_out = tanh_out 392 | 393 | # compute in_ch_mult, block_in and curr_res at lowest res 394 | in_ch_mult = (1,)+tuple(ch_mult) 395 | block_in = ch*ch_mult[self.num_resolutions-1] 396 | curr_res = resolution // 2**(self.num_resolutions-1) 397 | self.z_shape = (1,z_channels,curr_res,curr_res) 398 | print("Working with z of shape {} = {} dimensions.".format( 399 | self.z_shape, np.prod(self.z_shape))) 400 | 401 | # z to block_in 402 | self.conv_in = torch.nn.Conv2d(z_channels, 403 | block_in, 404 | kernel_size=3, 405 | stride=1, 406 | padding=1) 407 | 408 | # middle 409 | self.mid = nn.Module() 410 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 411 | out_channels=block_in, 412 | temb_channels=self.temb_ch, 413 | dropout=dropout) 414 | 
self.mid.attn_1 = AttnBlock(block_in) 415 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 416 | out_channels=block_in, 417 | temb_channels=self.temb_ch, 418 | dropout=dropout) 419 | 420 | # upsampling 421 | self.up = nn.ModuleList() 422 | for i_level in reversed(range(self.num_resolutions)): 423 | block = nn.ModuleList() 424 | attn = nn.ModuleList() 425 | block_out = ch*ch_mult[i_level] 426 | for i_block in range(self.num_res_blocks+1): 427 | block.append(ResnetBlock(in_channels=block_in, 428 | out_channels=block_out, 429 | temb_channels=self.temb_ch, 430 | dropout=dropout)) 431 | block_in = block_out 432 | if curr_res in attn_resolutions: 433 | attn.append(AttnBlock(block_in)) 434 | up = nn.Module() 435 | up.block = block 436 | up.attn = attn 437 | if i_level != 0: 438 | up.upsample = Upsample(block_in, resamp_with_conv) 439 | curr_res = curr_res * 2 440 | self.up.insert(0, up) # prepend to get consistent order 441 | 442 | # end 443 | self.norm_out = Normalize(block_in) 444 | self.conv_out = torch.nn.Conv2d(block_in, 445 | out_ch, 446 | kernel_size=3, 447 | stride=1, 448 | padding=1) 449 | 450 | def forward(self, z): 451 | #assert z.shape[1:] == self.z_shape[1:] 452 | self.last_z_shape = z.shape 453 | 454 | # timestep embedding 455 | temb = None 456 | 457 | # z to block_in 458 | h = self.conv_in(z) 459 | 460 | # middle 461 | h = self.mid.block_1(h, temb) 462 | h = self.mid.attn_1(h) 463 | h = self.mid.block_2(h, temb) 464 | 465 | # upsampling 466 | for i_level in reversed(range(self.num_resolutions)): 467 | for i_block in range(self.num_res_blocks+1): 468 | h = self.up[i_level].block[i_block](h, temb) 469 | if len(self.up[i_level].attn) > 0: 470 | h = self.up[i_level].attn[i_block](h) 471 | if i_level != 0: 472 | h = self.up[i_level].upsample(h) 473 | 474 | # end 475 | if self.give_pre_end: 476 | return h 477 | 478 | h = self.norm_out(h) 479 | h = nonlinearity(h) 480 | h = self.conv_out(h) 481 | if self.tanh_out: 482 | h = torch.tanh(h) 483 | return h 484 | 485 | 486 | class AutoencoderKL(nn.Module): 487 | def __init__(self, 488 | ddconfig, 489 | embed_dim, 490 | ckpt_path=None, 491 | ignore_keys=[], 492 | image_key="image", 493 | colorize_nlabels=None, 494 | monitor=None, 495 | ema_decay=None, 496 | learn_logvar=False 497 | ): 498 | super().__init__() 499 | self.learn_logvar = learn_logvar 500 | self.image_key = image_key 501 | self.encoder = Encoder(**ddconfig) 502 | self.decoder = Decoder(**ddconfig) 503 | assert ddconfig["double_z"] 504 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 505 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 506 | self.embed_dim = embed_dim 507 | if colorize_nlabels is not None: 508 | assert type(colorize_nlabels)==int 509 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 510 | if monitor is not None: 511 | self.monitor = monitor 512 | self.use_ema = ema_decay is not None 513 | if ckpt_path is not None: 514 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 515 | 516 | def init_from_ckpt(self, path, ignore_keys=list()): 517 | sd = torch.load(path, map_location="cpu")["state_dict"] 518 | keys = list(sd.keys()) 519 | for key in keys: 520 | print(key, sd[key].shape) 521 | import collections 522 | sd_new = collections.OrderedDict() 523 | for k in keys: 524 | if k.find('first_stage_model') >= 0: 525 | k_new = k.split('first_stage_model.')[-1] 526 | sd_new[k_new] = sd[k] 527 | self.load_state_dict(sd_new, strict=True) 528 | print(f"Restored from {path}") 529 | 530 | 
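# NOTE: init_from_ckpt2 below appears to be unused; it also contains a stray bare first_stage_model expression that would raise a NameError if the method were ever called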
def init_from_ckpt2(self, path, ignore_keys=list()): 531 | sd = torch.load(path, map_location="cpu")["state_dict"] 532 | keys = list(sd.keys()) 533 | 534 | first_stage_model 535 | for k in keys: 536 | for ik in ignore_keys: 537 | if k.startswith(ik): 538 | print("Deleting key {} from state_dict.".format(k)) 539 | del sd[k] 540 | self.load_state_dict(sd, strict=False) 541 | print(f"Restored from {path}") 542 | 543 | def on_train_batch_end(self, *args, **kwargs): 544 | if self.use_ema: 545 | self.model_ema(self) 546 | 547 | def encode(self, x): 548 | h = self.encoder(x) 549 | moments = self.quant_conv(h) 550 | posterior = DiagonalGaussianDistribution(moments) 551 | return posterior 552 | 553 | def decode(self, z): 554 | z = self.post_quant_conv(z) 555 | dec = self.decoder(z) 556 | return dec 557 | 558 | def forward(self, input, sample_posterior=True): 559 | posterior = self.encode(input) 560 | if sample_posterior: 561 | z = posterior.sample() 562 | else: 563 | z = posterior.mode() 564 | dec = self.decode(z) 565 | return dec, posterior 566 | 567 | def get_input(self, batch, k): 568 | x = batch[k] 569 | if len(x.shape) == 3: 570 | x = x[..., None] 571 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 572 | return x 573 | 574 | def get_last_layer(self): 575 | return self.decoder.conv_out.weight 576 | 577 | @torch.no_grad() 578 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 579 | log = dict() 580 | x = self.get_input(batch, self.image_key) 581 | x = x.to(self.device) 582 | if not only_inputs: 583 | xrec, posterior = self(x) 584 | if x.shape[1] > 3: 585 | # colorize with random projection 586 | assert xrec.shape[1] > 3 587 | x = self.to_rgb(x) 588 | xrec = self.to_rgb(xrec) 589 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 590 | log["reconstructions"] = xrec 591 | if log_ema or self.use_ema: 592 | with self.ema_scope(): 593 | xrec_ema, posterior_ema = self(x) 594 | if x.shape[1] > 3: 595 | # colorize with random projection 596 | assert xrec_ema.shape[1] > 3 597 | xrec_ema = self.to_rgb(xrec_ema) 598 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 599 | log["reconstructions_ema"] = xrec_ema 600 | log["inputs"] = x 601 | return log 602 | 603 | def to_rgb(self, x): 604 | assert self.image_key == "segmentation" 605 | if not hasattr(self, "colorize"): 606 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 607 | x = F.conv2d(x, weight=self.colorize) 608 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
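# x has now been rescaled linearly to [-1, 1] (min-max normalization of the random 3-channel projection)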
609 | return x 610 | 611 | 612 | class IdentityFirstStage(torch.nn.Module): 613 | def __init__(self, *args, vq_interface=False, **kwargs): 614 | self.vq_interface = vq_interface 615 | super().__init__() 616 | 617 | def encode(self, x, *args, **kwargs): 618 | return x 619 | 620 | def decode(self, x, *args, **kwargs): 621 | return x 622 | 623 | def quantize(self, x, *args, **kwargs): 624 | if self.vq_interface: 625 | return x, None, [None, None, None] 626 | return x 627 | 628 | def forward(self, x, *args, **kwargs): 629 | return x 630 | 631 | -------------------------------------------------------------------------------- /tools/videocomposer/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os.path as osp 4 | from datetime import datetime 5 | from easydict import EasyDict 6 | import os 7 | 8 | cfg = EasyDict(__name__='Config: VideoComposer') 9 | 10 | pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) 11 | gpus_per_machine = torch.cuda.device_count() 12 | world_size = pmi_world_size * gpus_per_machine 13 | 14 | cfg.video_compositions = ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] 15 | 16 | # dataset 17 | cfg.root_dir = 'webvid10m/' 18 | 19 | cfg.alpha = 0.7 20 | 21 | cfg.misc_size = 384 22 | cfg.depth_std = 20.0 23 | cfg.depth_clamp = 10.0 24 | cfg.hist_sigma = 10.0 25 | 26 | # 27 | cfg.use_image_dataset = False 28 | cfg.alpha_img = 0.7 29 | 30 | cfg.resolution = 256 31 | cfg.mean = [0.5, 0.5, 0.5] 32 | cfg.std = [0.5, 0.5, 0.5] 33 | 34 | # sketch 35 | cfg.sketch_mean = [0.485, 0.456, 0.406] 36 | cfg.sketch_std = [0.229, 0.224, 0.225] 37 | 38 | # dataloader 39 | cfg.max_words = 1000 40 | 41 | cfg.frame_lens = [ 42 | 16, 43 | 16, 44 | 16, 45 | 16, 46 | ] 47 | cfg.feature_framerates = [ 48 | 4, 49 | ] 50 | cfg.feature_framerate = 4 51 | cfg.batch_sizes = { 52 | str(1):1, 53 | str(4):1, 54 | str(8):1, 55 | str(16):1, 56 | } 57 | 58 | cfg.chunk_size=64 59 | cfg.num_workers = 8 60 | cfg.prefetch_factor = 2 61 | cfg.seed = 8888 62 | 63 | # diffusion 64 | cfg.num_timesteps = 1000 65 | cfg.mean_type = 'eps' 66 | cfg.var_type = 'fixed_small' # NOTE: to stabilize training and avoid NaN 67 | cfg.loss_type = 'mse' 68 | cfg.ddim_timesteps = 50 # official: 250 69 | cfg.ddim_eta = 0.0 70 | cfg.clamp = 1.0 71 | cfg.share_noise = False 72 | cfg.use_div_loss = False 73 | 74 | # classifier-free guidance 75 | cfg.p_zero = 0.9 76 | cfg.guide_scale = 6.0 77 | 78 | # stabel diffusion 79 | cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt' 80 | 81 | 82 | # clip vision encoder 83 | cfg.vit_image_size = 336 84 | cfg.vit_patch_size = 14 85 | cfg.vit_dim = 1024 86 | cfg.vit_out_dim = 768 87 | cfg.vit_heads = 16 88 | cfg.vit_layers = 24 89 | cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] 90 | cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] 91 | cfg.clip_checkpoint = 'open_clip_pytorch_model.bin' 92 | cfg.mvs_visual= False 93 | # unet 94 | cfg.unet_in_dim = 4 95 | cfg.unet_concat_dim = 8 96 | cfg.unet_y_dim = cfg.vit_out_dim 97 | cfg.unet_context_dim = 1024 98 | cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4 99 | cfg.unet_dim = 320 100 | #cfg.unet_dim_mult = [1, 2, 3, 5] 101 | cfg.unet_dim_mult = [1, 2, 4, 4] 102 | cfg.unet_res_blocks = 2 103 | cfg.unet_num_heads = 8 104 | cfg.unet_head_dim = 64 105 | cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4] 106 | cfg.unet_dropout = 0.1 107 | cfg.misc_dropout = 0.5 108 | cfg.p_all_zero = 0.1 109 | cfg.p_all_keep = 0.1 110 | 
cfg.temporal_conv = False 111 | cfg.temporal_attn_times = 1 112 | cfg.temporal_attention = True 113 | 114 | cfg.use_fps_condition = False 115 | cfg.use_sim_mask = False 116 | 117 | ## Default: load 2d pretrain 118 | cfg.pretrained = False 119 | cfg.fix_weight = False 120 | 121 | ## Default resume 122 | # 123 | cfg.resume = True 124 | cfg.resume_step = 148000 125 | cfg.resume_check_dir = '.' 126 | cfg.resume_checkpoint = os.path.join(cfg.resume_check_dir,f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth') 127 | # 128 | cfg.resume_optimizer = False 129 | if cfg.resume_optimizer: 130 | cfg.resume_optimizer = os.path.join(cfg.resume_check_dir,f'optimizer_step_{cfg.resume_step}.pt') 131 | 132 | 133 | # acceleration 134 | cfg.use_ema = True 135 | # for debug, no ema 136 | if world_size<2: 137 | cfg.use_ema = False 138 | cfg.load_from = None 139 | 140 | cfg.use_checkpoint = True 141 | cfg.use_sharded_ddp = False 142 | cfg.use_fsdp = False 143 | cfg.use_fp16 = True 144 | 145 | # training 146 | cfg.ema_decay = 0.9999 147 | cfg.viz_interval = 1000 148 | cfg.save_ckp_interval = 1000 149 | 150 | # logging 151 | cfg.log_interval = 100 152 | composition_strings = '_'.join(cfg.video_compositions) 153 | ### Default log_dir 154 | cfg.log_dir = f'outputs/' 155 | -------------------------------------------------------------------------------- /tools/videocomposer/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import torch 3 | import time 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms as T 7 | from collections import defaultdict 8 | import re 9 | import pickle 10 | import json 11 | import random 12 | import numpy as np 13 | from io import BytesIO 14 | from PIL import Image 15 | import artist.ops as ops 16 | import cv2 17 | from skimage.color import rgb2lab, lab2rgb 18 | import datetime 19 | # ADD 20 | import os 21 | from mvextractor.videocap import VideoCap 22 | import subprocess 23 | import binascii 24 | from ipdb import set_trace 25 | import imageio 26 | 27 | import utils.logging as logging 28 | logger = logging.get_logger(__name__) 29 | 30 | def pre_caption(caption, max_words): 31 | caption = re.sub( 32 | r"([,.'!?\"()*#:;~])", 33 | '', 34 | caption.lower(), 35 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 36 | 37 | caption = re.sub( 38 | r"\s{2,}", 39 | ' ', 40 | caption, 41 | ) 42 | caption = caption.rstrip('\n') 43 | caption = caption.strip(' ') 44 | 45 | #truncate caption 46 | caption_words = caption.split(' ') 47 | if len(caption_words) > max_words: 48 | caption = ' '.join(caption_words[:max_words]) 49 | 50 | return caption 51 | 52 | def random_resize(img, size): 53 | return TF.resize(img, size, interpolation=random.choice([ 54 | InterpolationMode.BILINEAR, 55 | InterpolationMode.BICUBIC, 56 | InterpolationMode.LANCZOS])) 57 | 58 | def rand_name(length=16, suffix=''): 59 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 60 | if suffix: 61 | if not suffix.startswith('.'): 62 | suffix = '.' 
+ suffix 63 | name += suffix 64 | return name 65 | 66 | def draw_motion_vectors(frame, motion_vectors): 67 | if len(motion_vectors) > 0: 68 | num_mvs = np.shape(motion_vectors)[0] 69 | for mv in np.split(motion_vectors, num_mvs): 70 | start_pt = (mv[0, 3], mv[0, 4]) 71 | end_pt = (mv[0, 5], mv[0, 6]) 72 | cv2.arrowedLine(frame, start_pt, end_pt, (0, 0, 255), 1, cv2.LINE_AA, 0, 0.1) 73 | # cv2.arrowedLine(frame, start_pt, end_pt, (0, 0, 255), 2, cv2.LINE_AA, 0, 0.2) 74 | return frame 75 | 76 | def extract_motion_vectors(input_video,fps=4, dump=False, verbose=False, visual_mv=False): 77 | 78 | if dump: 79 | now = datetime.now().strftime("%Y-%m-%dT%H:%M:%S") 80 | for child in ["frames", "motion_vectors"]: 81 | os.makedirs(os.path.join(f"out-{now}", child), exist_ok=True) 82 | temp = rand_name() 83 | # tmp_video = f'{temp}_{input_video}' 84 | tmp_video = os.path.join(input_video.split("/")[0], f'{temp}' +input_video.split("/")[-1]) 85 | videocapture = cv2.VideoCapture(input_video) 86 | frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT) 87 | fps_video =videocapture.get(cv2.CAP_PROP_FPS) 88 | # check if enough frames 89 | if frames_num/fps_video*fps>16: # 90 | fps = max(fps, 1) 91 | else: 92 | fps = int(16/(frames_num/fps_video)) + 1 93 | ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}' 94 | 95 | if os.path.exists(tmp_video): 96 | os.remove(tmp_video) 97 | 98 | subprocess.run(args=ffmpeg_cmd,shell=True,timeout=120) 99 | 100 | cap = VideoCap() 101 | # open the video file 102 | ret = cap.open(tmp_video) 103 | if not ret: 104 | raise RuntimeError(f"Could not open {tmp_video}") 105 | 106 | step = 0 107 | times = [] 108 | 109 | frame_types = [] 110 | frames = [] 111 | mvs = [] 112 | mvs_visual = [] 113 | # continuously read and display video frames and motion vectors 114 | while True: 115 | if verbose: 116 | print("Frame: ", step, end=" ") 117 | 118 | tstart = time.perf_counter() 119 | 120 | # read next video frame and corresponding motion vectors 121 | ret, frame, motion_vectors, frame_type, timestamp = cap.read() 122 | 123 | tend = time.perf_counter() 124 | telapsed = tend - tstart 125 | times.append(telapsed) 126 | 127 | # if there is an error reading the frame 128 | if not ret: 129 | if verbose: 130 | print("No frame read. Stopping.") 131 | break 132 | 133 | frame_save = np.zeros(frame.copy().shape, dtype=np.uint8) # *255 134 | if visual_mv: 135 | frame_save = draw_motion_vectors(frame_save, motion_vectors) 136 | 137 | # store motion vectors, frames, etc. 
in output directory 138 | dump = False 139 | if frame.shape[1] >= frame.shape[0]: 140 | w_half = (frame.shape[1] - frame.shape[0])//2 141 | if dump: 142 | cv2.imwrite(os.path.join(f"./mv_visual/", f"frame-{step}.jpg"), frame_save[:,w_half:-w_half]) 143 | mvs_visual.append(frame_save[:,w_half:-w_half]) 144 | else: 145 | h_half = (frame.shape[0] - frame.shape[1])//2 146 | if dump: 147 | cv2.imwrite(os.path.join(f"./mv_visual/", f"frame-{step}.jpg"), frame_save[h_half:-h_half,:]) 148 | mvs_visual.append(frame_save[h_half:-h_half,:]) 149 | 150 | h,w = frame.shape[:2] 151 | mv = np.zeros((h,w,2)) 152 | position = motion_vectors[:,5:7].clip((0,0),(w-1,h-1)) 153 | mv[position[:,1],position[:,0]]=motion_vectors[:,0:1]*motion_vectors[:,7:9]/motion_vectors[:, 9:] 154 | 155 | step += 1 156 | frame_types.append(frame_type) 157 | frames.append(frame) 158 | mvs.append(mv) 159 | # mvs_visual.append(frame_save) 160 | if verbose: 161 | print("average dt: ", np.mean(times)) 162 | cap.release() 163 | 164 | if os.path.exists(tmp_video): 165 | os.remove(tmp_video) 166 | 167 | return frame_types,frames,mvs, mvs_visual 168 | 169 | 170 | class VideoDataset(Dataset): 171 | def __init__(self, 172 | cfg, 173 | tokenizer=None, 174 | max_words=30, 175 | feature_framerate=1, 176 | max_frames=16, 177 | image_resolution=224, 178 | transforms=None, 179 | mv_transforms = None, 180 | misc_transforms = None, 181 | vit_transforms=None, 182 | vit_image_size = 336, 183 | misc_size = 384): 184 | 185 | self.cfg = cfg 186 | 187 | self.tokenizer = tokenizer 188 | self.max_words = max_words 189 | self.feature_framerate = feature_framerate 190 | self.max_frames = max_frames 191 | self.image_resolution = image_resolution 192 | self.transforms = transforms 193 | self.vit_transforms = vit_transforms 194 | self.vit_image_size = vit_image_size 195 | self.misc_transforms = misc_transforms 196 | self.misc_size = misc_size 197 | 198 | self.mv_transforms = mv_transforms 199 | 200 | self.video_cap_pairs = [[self.cfg.input_video, self.cfg.input_text_desc]] 201 | self.Vit_image_random_resize = T.Resize((vit_image_size, vit_image_size)) 202 | 203 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 204 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 205 | 206 | def __len__(self): 207 | return len(self.video_cap_pairs) 208 | 209 | def __getitem__(self, index): 210 | 211 | video_key, cap_txt = self.video_cap_pairs[index] 212 | 213 | total_frames = None 214 | 215 | feature_framerate = self.feature_framerate 216 | if os.path.exists(video_key): 217 | try: 218 | ref_frame, vit_image, video_data, misc_data, mv_data = self._get_video_traindata(video_key, feature_framerate, total_frames, self.cfg.mvs_visual) 219 | except Exception as e: 220 | print('{} get frames failed... 
with error: {}'.format(video_key, e), flush=True) 221 | 222 | ref_frame = torch.zeros(3, self.vit_image_size, self.vit_image_size) 223 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size) 224 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution) 225 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size) 226 | 227 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution) 228 | else: 229 | print("The video path does not exist or no video dir provided!") 230 | ref_frame = torch.zeros(3, self.vit_image_size, self.vit_image_size) 231 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size) 232 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution) 233 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size) 234 | 235 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution) 236 | 237 | 238 | # inpainting mask 239 | p = random.random() 240 | if p < 0.7: 241 | mask = ops.make_irregular_mask(512, 512) 242 | elif p < 0.9: 243 | mask = ops.make_rectangle_mask(512, 512) 244 | else: 245 | mask = ops.make_uncrop(512, 512) 246 | mask = torch.from_numpy(cv2.resize(mask, (self.misc_size,self.misc_size), interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float() 247 | 248 | mask = mask.unsqueeze(0).repeat_interleave(repeats=self.max_frames,dim=0) 249 | 250 | 251 | return ref_frame, cap_txt, video_data, misc_data, feature_framerate, mask, mv_data 252 | 253 | def _get_video_traindata(self, video_key, feature_framerate, total_frames, visual_mv): 254 | 255 | # folder_name = "cache_temp/" 256 | # filename = folder_name + osp.basename(video_key) 257 | # if not os.path.exists(folder_name): 258 | # os.makedirs(folder_name, exist_ok=True) 259 | # oss_path = osp.join(self.root_dir, video_key) 260 | # bucket, oss_key = ops.parse_oss_url(oss_path) 261 | # ops.get_object_to_file(bucket, oss_key, filename) 262 | filename = video_key 263 | for _ in range(5): 264 | try: 265 | frame_types,frames,mvs, mvs_visual = extract_motion_vectors(input_video=filename,fps=feature_framerate, visual_mv=visual_mv) 266 | # os.remove(filename) 267 | break 268 | except Exception as e: 269 | print('{} read video frames and motion vectors failed with error: {}'.format(video_key, e), flush=True) 270 | 271 | total_frames = len(frame_types) 272 | start_indexs = np.where((np.array(frame_types)=='I') & (total_frames - np.arange(total_frames) >= self.max_frames))[0] 273 | start_index = np.random.choice(start_indexs) 274 | indices = np.arange(start_index, start_index+self.max_frames) 275 | 276 | # note frames are in BGR mode, need to trans to RGB mode 277 | frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices] 278 | mvs = [torch.from_numpy(mvs[i].transpose((2,0,1))) for i in indices] 279 | mvs = torch.stack(mvs) 280 | # set_trace() 281 | # if mvs_visual != None: 282 | if visual_mv: 283 | # images = [(mvs_visual[i][:,:,::-1]*255).astype('uint8') for i in indices] 284 | images = [(mvs_visual[i][:,:,::-1]).astype('uint8') for i in indices] 285 | # images = [mvs_visual[i] for i in indices] 286 | # images = [(image.numpy()*255).astype('uint8') for image in images] 287 | path = self.cfg.log_dir + "/visual_mv/" + video_key.split("/")[-1] + ".gif" 288 | if not os.path.exists(self.cfg.log_dir + "/visual_mv/"): 289 | os.makedirs(self.cfg.log_dir + "/visual_mv/", exist_ok=True) 290 | print("save motion vectors visualization to :", path) 291 | 
imageio.mimwrite(path, images, fps=8) 292 | 293 | # mvs_visual = [torch.from_numpy(mvs_visual[i].transpose((2,0,1))) for i in indices] 294 | # mvs_visual = torch.stack(mvs_visual) 295 | # mvs_visual = self.mv_transforms(mvs_visual) 296 | 297 | have_frames = len(frames)>0 298 | middle_indix = int(len(frames)/2) 299 | if have_frames: 300 | ref_frame = frames[middle_indix] 301 | vit_image = self.vit_transforms(ref_frame) 302 | misc_imgs_np = self.misc_transforms[:2](frames) 303 | misc_imgs = self.misc_transforms[2:](misc_imgs_np) 304 | frames = self.transforms(frames) 305 | mvs = self.mv_transforms(mvs) 306 | else: 307 | # ref_frame = Image.fromarray(np.zeros((3, self.image_resolution, self.image_resolution))) 308 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size) 309 | 310 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution) 311 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution) 312 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size) 313 | if have_frames: 314 | video_data[:len(frames), ...] = frames # [[XX...],[...], ..., [0,0...], [], ...] 315 | misc_data[:len(frames), ...] = misc_imgs 316 | mv_data[:len(frames), ...] = mvs 317 | 318 | 319 | ref_frame = vit_image 320 | 321 | del frames 322 | del misc_imgs 323 | del mvs 324 | 325 | return ref_frame, vit_image, video_data, misc_data, mv_data 326 | 327 | -------------------------------------------------------------------------------- /tools/videocomposer/mha_flash.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.cuda.amp as amp 4 | import torch.nn.functional as F 5 | import math 6 | import os 7 | import time 8 | import numpy as np 9 | import random 10 | 11 | from flash_attn.flash_attention import FlashAttention 12 | 13 | class FlashAttentionBlock(nn.Module): 14 | 15 | def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4): 16 | # consider head_dim first, then num_heads 17 | num_heads = dim // head_dim if head_dim else num_heads 18 | head_dim = dim // num_heads 19 | assert num_heads * head_dim == dim 20 | super(FlashAttentionBlock, self).__init__() 21 | self.dim = dim 22 | self.context_dim = context_dim 23 | self.num_heads = num_heads 24 | self.head_dim = head_dim 25 | self.scale = math.pow(head_dim, -0.25) 26 | 27 | # layers 28 | self.norm = nn.GroupNorm(32, dim) 29 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1) 30 | if context_dim is not None: 31 | self.context_kv = nn.Linear(context_dim, dim * 2) 32 | self.proj = nn.Conv2d(dim, dim, 1) 33 | 34 | if self.head_dim <= 128 and (self.head_dim % 8) == 0: 35 | new_scale = math.pow(head_dim, -0.5) 36 | self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0) 37 | 38 | # zero out the last layer params 39 | nn.init.zeros_(self.proj.weight) 40 | # self.apply(self._init_weight) 41 | 42 | 43 | def _init_weight(self, module): 44 | if isinstance(module, nn.Linear): 45 | module.weight.data.normal_(mean=0.0, std=0.15) 46 | if module.bias is not None: 47 | module.bias.data.zero_() 48 | elif isinstance(module, nn.Conv2d): 49 | module.weight.data.normal_(mean=0.0, std=0.15) 50 | if module.bias is not None: 51 | module.bias.data.zero_() 52 | 53 | def forward(self, x, context=None): 54 | r"""x: [B, C, H, W]. 55 | context: [B, L, C] or None. 
56 | """ 57 | identity = x 58 | b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim 59 | 60 | # compute query, key, value 61 | x = self.norm(x) 62 | q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) 63 | if context is not None: 64 | ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) 65 | k = torch.cat([ck, k], dim=-1) 66 | v = torch.cat([cv, v], dim=-1) 67 | cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) 68 | q = torch.cat([q, cq], dim=-1) 69 | 70 | qkv = torch.cat([q,k,v], dim=1) 71 | origin_dtype = qkv.dtype 72 | qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() 73 | out, _ = self.flash_attn(qkv) 74 | out = out.to(origin_dtype) # cast the flash-attention output back to its original dtype 75 | 76 | if context is not None: 77 | out = out[:, :-4, :, :] 78 | out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) 79 | 80 | # output 81 | x = self.proj(out) 82 | return x + identity 83 | 84 | if __name__ == '__main__': 85 | batch_size = 8 86 | flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() 87 | 88 | x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() 89 | context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() 90 | # context = None 91 | flash_net.eval() 92 | 93 | with amp.autocast(enabled=True): 94 | # warm up 95 | for i in range(5): 96 | y = flash_net(x, context) 97 | torch.cuda.synchronize() 98 | s1 = time.time() 99 | for i in range(10): 100 | y = flash_net(x, context) 101 | torch.cuda.synchronize() 102 | s2 = time.time() 103 | 104 | print(f'Average cost time {(s2-s1)*1000/10} ms') -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/distributed.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import
json 4 | import copy 5 | import argparse 6 | 7 | import utils.logging as logging 8 | logger = logging.get_logger(__name__) 9 | 10 | class Config(object): 11 | def __init__(self, load=True, cfg_dict=None, cfg_level=None): 12 | self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "") 13 | if load: 14 | self.args = self._parse_args() 15 | logger.info("Loading config from {}.".format(self.args.cfg_file)) 16 | self.need_initialization = True 17 | cfg_base = self._initialize_cfg() 18 | cfg_dict = self._load_yaml(self.args) 19 | cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) 20 | cfg_dict = self._update_from_args(cfg_dict) 21 | self.cfg_dict = cfg_dict 22 | self._update_dict(cfg_dict) 23 | 24 | def _parse_args(self): 25 | parser = argparse.ArgumentParser( 26 | description="Argparser for configuring [code base name to think of] codebase" 27 | ) 28 | parser.add_argument( 29 | "--cfg", 30 | dest="cfg_file", 31 | help="Path to the configuration file", 32 | default='configs/exp01_vidcomposer_full.yaml' 33 | ) 34 | parser.add_argument( 35 | "--init_method", 36 | help="Initialization method, includes TCP or shared file-system", 37 | default="tcp://localhost:9999", 38 | type=str, 39 | ) 40 | parser.add_argument( 41 | "--seed", 42 | type=int, 43 | default=8888, 44 | help="Need to explore for different videos" 45 | ) 46 | parser.add_argument( 47 | '--debug', 48 | action='store_true', 49 | default=False, 50 | help='Into debug information' 51 | ) 52 | parser.add_argument( 53 | '--input_video', 54 | default='demo_video/video_8800.mp4', 55 | help='input video for full task, or motion vector of input videos', 56 | type=str, 57 | ), 58 | parser.add_argument( 59 | '--image_path', 60 | default='', 61 | help='Single Image Input', 62 | type=str 63 | ) 64 | parser.add_argument( 65 | '--sketch_path', 66 | default='', 67 | help='Single Sketch Input', 68 | type=str 69 | ) 70 | parser.add_argument( 71 | '--style_image', 72 | help='Single Sketch Input', 73 | type=str 74 | ) 75 | parser.add_argument( 76 | '--input_text_desc', 77 | default='A colorful and beautiful fish swimming in a small glass bowl with multicolored piece of stone, Macro Video', 78 | type=str, 79 | ), 80 | parser.add_argument( 81 | "opts", 82 | help="other configurations", 83 | default=None, 84 | nargs=argparse.REMAINDER 85 | ) 86 | return parser.parse_args() 87 | 88 | def _path_join(self, path_list): 89 | path = "" 90 | for p in path_list: 91 | path+= p + '/' 92 | return path[:-1] 93 | 94 | def _update_from_args(self, cfg_dict): 95 | args = self.args 96 | for var in vars(args): 97 | cfg_dict[var] = getattr(args, var) 98 | return cfg_dict 99 | 100 | def _initialize_cfg(self): 101 | if self.need_initialization: 102 | self.need_initialization = False 103 | if os.path.exists('./configs/base.yaml'): 104 | with open("./configs/base.yaml", 'r') as f: 105 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 106 | else: 107 | with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f: 108 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 109 | return cfg 110 | 111 | def _load_yaml(self, args, file_name=""): 112 | assert args.cfg_file is not None 113 | if not file_name == "": # reading from base file 114 | with open(file_name, 'r') as f: 115 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 116 | else: 117 | if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]: 118 | args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./") 119 | try: 120 | with open(args.cfg_file, 'r') as f: 121 | cfg = 
yaml.load(f.read(), Loader=yaml.SafeLoader) 122 | file_name = args.cfg_file 123 | except: 124 | args.cfg_file = os.path.realpath(__file__).split('/')[-3] + "/" + args.cfg_file 125 | with open(args.cfg_file, 'r') as f: 126 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 127 | file_name = args.cfg_file 128 | 129 | if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys(): 130 | # return cfg if the base file is being accessed 131 | return cfg 132 | 133 | if "_BASE" in cfg.keys(): 134 | if cfg["_BASE"][1] == '.': 135 | prev_count = cfg["_BASE"].count('..') 136 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:]) 137 | else: 138 | cfg_base_file = cfg["_BASE"].replace( 139 | "./", 140 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 141 | ) 142 | cfg_base = self._load_yaml(args, cfg_base_file) 143 | cfg = self._merge_cfg_from_base(cfg_base, cfg) 144 | else: 145 | if "_BASE_RUN" in cfg.keys(): 146 | if cfg["_BASE_RUN"][1] == '.': 147 | prev_count = cfg["_BASE_RUN"].count('..') 148 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:]) 149 | else: 150 | cfg_base_file = cfg["_BASE_RUN"].replace( 151 | "./", 152 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 153 | ) 154 | cfg_base = self._load_yaml(args, cfg_base_file) 155 | cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True) 156 | if "_BASE_MODEL" in cfg.keys(): 157 | if cfg["_BASE_MODEL"][1] == '.': 158 | prev_count = cfg["_BASE_MODEL"].count('..') 159 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:]) 160 | else: 161 | cfg_base_file = cfg["_BASE_MODEL"].replace( 162 | "./", 163 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 164 | ) 165 | cfg_base = self._load_yaml(args, cfg_base_file) 166 | cfg = self._merge_cfg_from_base(cfg_base, cfg) 167 | cfg = self._merge_cfg_from_command(args, cfg) 168 | return cfg 169 | 170 | def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): 171 | for k,v in cfg_new.items(): 172 | if k in cfg_base.keys(): 173 | if isinstance(v, dict): 174 | self._merge_cfg_from_base(cfg_base[k], v) 175 | else: 176 | cfg_base[k] = v 177 | else: 178 | if "BASE" not in k or preserve_base: 179 | cfg_base[k] = v 180 | return cfg_base 181 | 182 | def _merge_cfg_from_command(self, args, cfg): 183 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( 184 | args.opts, len(args.opts) 185 | ) 186 | keys = args.opts[0::2] 187 | vals = args.opts[1::2] 188 | 189 | # maximum supported depth 3 190 | for idx, key in enumerate(keys): 191 | key_split = key.split('.') 192 | assert len(key_split) <= 4, 'Key depth error. 
\nMaximum depth: 3\n Get depth: {}'.format( 193 | len(key_split) 194 | ) 195 | assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( 196 | key_split[0] 197 | ) 198 | if len(key_split) == 2: 199 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 200 | key 201 | ) 202 | elif len(key_split) == 3: 203 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 204 | key 205 | ) 206 | assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( 207 | key 208 | ) 209 | elif len(key_split) == 4: 210 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 211 | key 212 | ) 213 | assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( 214 | key 215 | ) 216 | assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format( 217 | key 218 | ) 219 | if len(key_split) == 1: 220 | cfg[key_split[0]] = vals[idx] 221 | elif len(key_split) == 2: 222 | cfg[key_split[0]][key_split[1]] = vals[idx] 223 | elif len(key_split) == 3: 224 | cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] 225 | elif len(key_split) == 4: 226 | cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx] 227 | return cfg 228 | 229 | def _update_dict(self, cfg_dict): 230 | def recur(key, elem): 231 | if type(elem) is dict: 232 | return key, Config(load=False, cfg_dict=elem, cfg_level=key) 233 | else: 234 | if type(elem) is str and elem[1:3]=="e-": 235 | elem = float(elem) 236 | return key, elem 237 | dic = dict(recur(k, v) for k, v in cfg_dict.items()) 238 | self.__dict__.update(dic) 239 | 240 | def get_args(self): 241 | return self.args 242 | 243 | def __repr__(self): 244 | return "{}\n".format(self.dump()) 245 | 246 | def dump(self): 247 | return json.dumps(self.cfg_dict, indent=2) 248 | 249 | def deep_copy(self): 250 | return copy.deepcopy(self) 251 | 252 | if __name__ == '__main__': 253 | # debug 254 | cfg = Config(load=True) 255 | print(cfg.DATA) 256 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Distributed helpers.""" 5 | 6 | import functools 7 | import logging 8 | import pickle 9 | import torch 10 | import torch.distributed as dist 11 | 12 | _LOCAL_PROCESS_GROUP = None 13 | 14 | 15 | def all_gather(tensors): 16 | """ 17 | All gathers the provided tensors from all processes across machines. 18 | Args: 19 | tensors (list): tensors to perform all gather across all processes in 20 | all machines. 21 | """ 22 | 23 | gather_list = [] 24 | output_tensor = [] 25 | world_size = dist.get_world_size() 26 | for tensor in tensors: 27 | tensor_placeholder = [ 28 | torch.ones_like(tensor) for _ in range(world_size) 29 | ] 30 | dist.all_gather(tensor_placeholder, tensor, async_op=False) 31 | gather_list.append(tensor_placeholder) 32 | for gathered_tensor in gather_list: 33 | output_tensor.append(torch.cat(gathered_tensor, dim=0)) 34 | return output_tensor 35 | 36 | 37 | def all_reduce(tensors, average=True): 38 | """ 39 | All reduce the provided tensors from all processes across machines. 40 | Args: 41 | tensors (list): tensors to perform all reduce across all processes in 42 | all machines. 
43 | average (bool): scales the reduced tensor by the number of overall 44 | processes across all machines. 45 | """ 46 | 47 | for tensor in tensors: 48 | dist.all_reduce(tensor, async_op=False) 49 | if average: 50 | world_size = dist.get_world_size() 51 | for tensor in tensors: 52 | tensor.mul_(1.0 / world_size) 53 | return tensors 54 | 55 | 56 | def init_process_group( 57 | local_rank, 58 | local_world_size, 59 | shard_id, 60 | num_shards, 61 | init_method, 62 | dist_backend="nccl", 63 | ): 64 | """ 65 | Initializes the default process group. 66 | Args: 67 | local_rank (int): the rank on the current local machine. 68 | local_world_size (int): the world size (number of processes running) on 69 | the current local machine. 70 | shard_id (int): the shard index (machine rank) of the current machine. 71 | num_shards (int): number of shards for distributed training. 72 | init_method (string): supporting three different methods for 73 | initializing process groups: 74 | "file": use shared file system to initialize the groups across 75 | different processes. 76 | "tcp": use tcp address to initialize the groups across different 77 | dist_backend (string): backend to use for distributed training. Options 78 | includes gloo, mpi and nccl, the details can be found here: 79 | https://pytorch.org/docs/stable/distributed.html 80 | """ 81 | # Sets the GPU to use. 82 | torch.cuda.set_device(local_rank) 83 | # Initialize the process group. 84 | proc_rank = local_rank + shard_id * local_world_size 85 | world_size = local_world_size * num_shards 86 | dist.init_process_group( 87 | backend=dist_backend, 88 | init_method=init_method, 89 | world_size=world_size, 90 | rank=proc_rank, 91 | ) 92 | 93 | 94 | def is_master_proc(num_gpus=8): 95 | """ 96 | Determines if the current process is the master process. 97 | """ 98 | if torch.distributed.is_initialized(): 99 | return dist.get_rank() % num_gpus == 0 100 | else: 101 | return True 102 | 103 | 104 | def get_world_size(): 105 | """ 106 | Get the size of the world. 107 | """ 108 | if not dist.is_available(): 109 | return 1 110 | if not dist.is_initialized(): 111 | return 1 112 | return dist.get_world_size() 113 | 114 | 115 | def get_rank(): 116 | """ 117 | Get the rank of the current process. 118 | """ 119 | if not dist.is_available(): 120 | return 0 121 | if not dist.is_initialized(): 122 | return 0 123 | return dist.get_rank() 124 | 125 | 126 | def synchronize(): 127 | """ 128 | Helper function to synchronize (barrier) among all processes when 129 | using distributed training 130 | """ 131 | if not dist.is_available(): 132 | return 133 | if not dist.is_initialized(): 134 | return 135 | world_size = dist.get_world_size() 136 | if world_size == 1: 137 | return 138 | dist.barrier() 139 | 140 | 141 | @functools.lru_cache() 142 | def _get_global_gloo_group(): 143 | """ 144 | Return a process group based on gloo backend, containing all the ranks 145 | The result is cached. 146 | Returns: 147 | (group): pytorch dist group. 148 | """ 149 | if dist.get_backend() == "nccl": 150 | return dist.new_group(backend="gloo") 151 | else: 152 | return dist.group.WORLD 153 | 154 | 155 | def _serialize_to_tensor(data, group): 156 | """ 157 | Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl` 158 | backend is supported. 159 | Args: 160 | data (data): data to be serialized. 161 | group (group): pytorch dist group. 162 | Returns: 163 | tensor (ByteTensor): tensor that serialized. 
164 | """ 165 | 166 | backend = dist.get_backend(group) 167 | assert backend in ["gloo", "nccl"] 168 | device = torch.device("cpu" if backend == "gloo" else "cuda") 169 | 170 | buffer = pickle.dumps(data) 171 | if len(buffer) > 1024 ** 3: 172 | logger = logging.getLogger(__name__) 173 | logger.warning( 174 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 175 | get_rank(), len(buffer) / (1024 ** 3), device 176 | ) 177 | ) 178 | storage = torch.ByteStorage.from_buffer(buffer) 179 | tensor = torch.ByteTensor(storage).to(device=device) 180 | return tensor 181 | 182 | 183 | def _pad_to_largest_tensor(tensor, group): 184 | """ 185 | Padding all the tensors from different GPUs to the largest ones. 186 | Args: 187 | tensor (tensor): tensor to pad. 188 | group (group): pytorch dist group. 189 | Returns: 190 | list[int]: size of the tensor, on each rank 191 | Tensor: padded tensor that has the max size 192 | """ 193 | world_size = dist.get_world_size(group=group) 194 | assert ( 195 | world_size >= 1 196 | ), "comm.gather/all_gather must be called from ranks within the given group!" 197 | local_size = torch.tensor( 198 | [tensor.numel()], dtype=torch.int64, device=tensor.device 199 | ) 200 | size_list = [ 201 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 202 | for _ in range(world_size) 203 | ] 204 | dist.all_gather(size_list, local_size, group=group) 205 | size_list = [int(size.item()) for size in size_list] 206 | 207 | max_size = max(size_list) 208 | 209 | # we pad the tensor because torch all_gather does not support 210 | # gathering tensors of different shapes 211 | if local_size != max_size: 212 | padding = torch.zeros( 213 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 214 | ) 215 | tensor = torch.cat((tensor, padding), dim=0) 216 | return size_list, tensor 217 | 218 | 219 | def all_gather_unaligned(data, group=None): 220 | """ 221 | Run all_gather on arbitrary picklable data (not necessarily tensors). 222 | 223 | Args: 224 | data: any picklable object 225 | group: a torch process group. By default, will use a group which 226 | contains all ranks on gloo backend. 227 | 228 | Returns: 229 | list[data]: list of data gathered from each rank 230 | """ 231 | if get_world_size() == 1: 232 | return [data] 233 | if group is None: 234 | group = _get_global_gloo_group() 235 | if dist.get_world_size(group) == 1: 236 | return [data] 237 | 238 | tensor = _serialize_to_tensor(data, group) 239 | 240 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 241 | max_size = max(size_list) 242 | 243 | # receiving Tensor from all ranks 244 | tensor_list = [ 245 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 246 | for _ in size_list 247 | ] 248 | dist.all_gather(tensor_list, tensor, group=group) 249 | 250 | data_list = [] 251 | for size, tensor in zip(size_list, tensor_list): 252 | buffer = tensor.cpu().numpy().tobytes()[:size] 253 | data_list.append(pickle.loads(buffer)) 254 | 255 | return data_list 256 | 257 | 258 | def init_distributed_training(cfg): 259 | """ 260 | Initialize variables needed for distributed training. 
261 | """ 262 | if cfg.NUM_GPUS <= 1: 263 | return 264 | num_gpus_per_machine = cfg.NUM_GPUS 265 | num_machines = dist.get_world_size() // num_gpus_per_machine 266 | for i in range(num_machines): 267 | ranks_on_i = list( 268 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) 269 | ) 270 | pg = dist.new_group(ranks_on_i) 271 | if i == cfg.SHARD_ID: 272 | global _LOCAL_PROCESS_GROUP 273 | _LOCAL_PROCESS_GROUP = pg 274 | 275 | 276 | def get_local_size() -> int: 277 | """ 278 | Returns: 279 | The size of the per-machine process group, 280 | i.e. the number of processes per machine. 281 | """ 282 | if not dist.is_available(): 283 | return 1 284 | if not dist.is_initialized(): 285 | return 1 286 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 287 | 288 | 289 | def get_local_rank() -> int: 290 | """ 291 | Returns: 292 | The rank of the current process within the local (per-machine) process group. 293 | """ 294 | if not dist.is_available(): 295 | return 0 296 | if not dist.is_initialized(): 297 | return 0 298 | assert _LOCAL_PROCESS_GROUP is not None 299 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 300 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Logging.""" 5 | 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | # from fvcore.common.file_io import PathManager 14 | 15 | import utils.distributed as du 16 | 17 | 18 | def _suppress_print(): 19 | """ 20 | Suppresses printing from the current process. 21 | """ 22 | 23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 24 | pass 25 | 26 | builtins.print = print_pass 27 | 28 | 29 | # @functools.lru_cache(maxsize=None) 30 | # def _cached_log_stream(filename): 31 | # return PathManager.open(filename, "a") 32 | 33 | 34 | def setup_logging(cfg, log_file): 35 | """ 36 | Sets up the logging for multiple processes. Only enable the logging for the 37 | master process, and suppress logging for the non-master processes. 38 | """ 39 | if du.is_master_proc(): 40 | # Enable logging for the master process. 41 | logging.root.handlers = [] 42 | else: 43 | # Suppress logging for non-master processes. 44 | _suppress_print() 45 | 46 | logger = logging.getLogger() 47 | logger.setLevel(logging.INFO) 48 | logger.propagate = False 49 | plain_formatter = logging.Formatter( 50 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 51 | datefmt="%m/%d %H:%M:%S", 52 | ) 53 | 54 | if du.is_master_proc(): 55 | ch = logging.StreamHandler(stream=sys.stdout) 56 | ch.setLevel(logging.DEBUG) 57 | ch.setFormatter(plain_formatter) 58 | logger.addHandler(ch) 59 | 60 | if log_file is not None and du.is_master_proc(du.get_world_size()): 61 | filename = os.path.join(cfg.OUTPUT_DIR, log_file) 62 | fh = logging.FileHandler(filename) 63 | fh.setLevel(logging.DEBUG) 64 | fh.setFormatter(plain_formatter) 65 | logger.addHandler(fh) 66 | 67 | 68 | def get_logger(name): 69 | """ 70 | Retrieve the logger with the specified name or, if name is None, return a 71 | logger which is the root logger of the hierarchy. 72 | Args: 73 | name (string): name of the logger. 74 | """ 75 | return logging.getLogger(name) 76 | 77 | 78 | def log_json_stats(stats): 79 | """ 80 | Logs json stats. 
81 | Args: 82 | stats (dict): a dictionary of statistical information to log. 83 | """ 84 | stats = { 85 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 86 | for k, v in stats.items() 87 | } 88 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 89 | logger = get_logger(__name__) 90 | logger.info("{:s}".format(json_stats)) 91 | --------------------------------------------------------------------------------
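For orientation, here is a minimal sketch of how the helpers in utils/logging.py and utils/distributed.py above might be wired together at the start of a single-process run. It is an illustration only, not code from the repository; the EasyDict-based cfg, the OUTPUT_DIR value, and the run.log filename are assumptions (setup_logging only requires that cfg expose an OUTPUT_DIR attribute, and falls back to master-only behaviour when torch.distributed is not initialized).

import os
from easydict import EasyDict

import utils.distributed as du
import utils.logging as logging

# Hypothetical config object: setup_logging() only reads cfg.OUTPUT_DIR.
cfg = EasyDict(OUTPUT_DIR='outputs')
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Without torch.distributed initialized, is_master_proc() returns True,
# so both the stdout handler and the file handler are attached to the root logger.
logging.setup_logging(cfg, log_file='run.log')
logger = logging.get_logger(__name__)

logger.info('world_size=%d, rank=%d', du.get_world_size(), du.get_rank())
logging.log_json_stats({'step': 100, 'loss': 0.1234})  # floats are serialized via simplejson/Decimal

In a multi-process launch the same calls apply unchanged; init_process_group() (or init_distributed_training()) would simply be invoked first, after which only the master rank keeps its print/logging output.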