├── .gitignore
├── LICENSE
├── README.md
├── artist
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── samplers.cpython-38.pyc
│   │   │   ├── tokenizers.cpython-38.pyc
│   │   │   └── transforms.cpython-38.pyc
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── samplers.py
│   │   ├── tokenizers.py
│   │   └── transforms.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── clip.cpython-38.pyc
│   │   │   └── midas.cpython-38.pyc
│   │   ├── clip.py
│   │   └── midas.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── degration.cpython-38.pyc
│   │   │   ├── diffusion.cpython-38.pyc
│   │   │   ├── distributed.cpython-38.pyc
│   │   │   ├── dpm_solver.cpython-38.pyc
│   │   │   ├── losses.cpython-38.pyc
│   │   │   ├── random_mask.cpython-38.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   ├── degration.py
│   │   ├── diffusion.py
│   │   ├── distributed.py
│   │   ├── dpm_solver.py
│   │   ├── losses.py
│   │   ├── random_mask.py
│   │   └── utils.py
│   └── optim
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── adafactor.cpython-38.pyc
│       │   └── lr_scheduler.cpython-38.pyc
│       ├── adafactor.py
│       └── lr_scheduler.py
├── configs
│   ├── base.yaml
│   ├── exp01_vidcomposer_full.yaml
│   ├── exp02_motion_transfer.yaml
│   ├── exp02_motion_transfer_vs_style.yaml
│   ├── exp03_sketch2video_style.yaml
│   ├── exp04_sketch2video_wo_style.yaml
│   ├── exp05_text_depths_wo_style.yaml
│   ├── exp06_text_depths_vs_style.yaml
│   └── exp10_vidcomposer_no_watermark_full.yaml
├── demo_video
│   ├── blackswan.mp4
│   ├── captions_list.txt
│   ├── landscape_painting.png
│   ├── moon_on_water.jpg
│   ├── motion_transfer.mp4
│   ├── qibaishi_01.png
│   ├── src_single_sketch.png
│   ├── style
│   │   ├── Bingxueqiyuan.jpeg
│   │   ├── fangao_01.jpeg
│   │   ├── fangao_02.jpeg
│   │   └── fangao_03.jpeg
│   ├── sunflower.png
│   ├── sunflower_sketch.png
│   ├── tall_buildings.png
│   ├── tennis.mp4
│   ├── video_10000178.mp4
│   ├── video_5360763.mp4
│   ├── video_8800.mp4
│   └── wash_painting.png
├── environment.yaml
├── gen_sketch.py
├── model_weights
│   └── readme.md
├── run_bash.sh
├── run_net.py
├── source
│   ├── fig01.jpg
│   ├── fig02_framwork.jpg
│   ├── fig03_image-to-video.jpg
│   ├── fig04_hand-crafted-motions.jpg
│   ├── fig05_video-inpainting.jpg
│   ├── fig06_sketch-to-video.jpg
│   └── results
│       ├── exp02_motion_transfer-S00009.gif
│       ├── exp02_motion_transfer-S09999-0.gif
│       ├── exp02_motion_transfer-S09999.gif
│       ├── exp03_sketch2video_style-S09999.gif
│       ├── exp04_sketch2video_wo_style-S00144-1.gif
│       ├── exp04_sketch2video_wo_style-S00144-2.gif
│       ├── exp04_sketch2video_wo_style-S00144.gif
│       ├── exp05_text_depths_wo_style-S09999-0.gif
│       ├── exp05_text_depths_wo_style-S09999-1.gif
│       ├── exp05_text_depths_wo_style-S09999-2.gif
│       ├── exp06_text_depths_vs_style-S09999-0.gif
│       ├── exp06_text_depths_vs_style-S09999-1.gif
│       ├── exp06_text_depths_vs_style-S09999-2.gif
│       └── exp06_text_depths_vs_style-S09999-3.gif
├── tools
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── annotator
│   │   ├── .DS_Store
│   │   ├── __pycache__
│   │   │   └── util.cpython-38.pyc
│   │   ├── canny
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   │       └── __init__.cpython-38.pyc
│   │   ├── histogram
│   │   │   ├── __init__.py
│   │   │   └── palette.py
│   │   ├── sketch
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── pidinet.cpython-38.pyc
│   │   │   │   └── sketch_simplification.cpython-38.pyc
│   │   │   ├── pidinet.py
│   │   │   └── sketch_simplification.py
│   │   └── util.py
│   └── videocomposer
│       ├── __pycache__
│       │   ├── autoencoder.cpython-38.pyc
│       │   ├── config.cpython-38.pyc
│       │   ├── datasets.cpython-38.pyc
│       │   ├── inference_multi.cpython-38.pyc
│       │   ├── inference_single.cpython-38.pyc
│       │   ├── mha_flash.cpython-38.pyc
│       │   └── unet_sd.cpython-38.pyc
│       ├── autoencoder.py
│       ├── config.py
│       ├── datasets.py
│       ├── inference_multi.py
│       ├── inference_single.py
│       ├── mha_flash.py
│       └── unet_sd.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── config.cpython-38.pyc
    │   ├── distributed.cpython-38.pyc
    │   └── logging.cpython-38.pyc
    ├── config.py
    ├── distributed.py
    └── logging.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pkl
2 | *.pt
3 | *.mov
4 | *.pth
5 | *.json
6 | *.mov
7 | *.npz
8 | *.npy
9 | *.obj
10 | *.onnx
11 | *.tar
12 | *.bin
13 | cache*
14 | .DS_Store
15 | *DS_Store
16 | outputs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 DAMO Vision Intelligence Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VideoComposer
2 |
3 | Official repo for [VideoComposer: Compositional Video Synthesis with Motion Controllability](https://arxiv.org/pdf/2306.02018.pdf)
4 |
5 | Please see [Project Page](https://videocomposer.github.io/) for more examples.
6 |
7 | We are searching for talented, motivated, and imaginative researchers to join our team. If you are interested, please don't hesitate to send us your resume via email at yingya.zyy@alibaba-inc.com.
8 |
9 | 
10 |
11 |
12 | VideoComposer is a controllable video diffusion model that allows users to flexibly control both the spatial and temporal patterns of a synthesized video through various forms of guidance, such as a text description, a sketch sequence, a reference video, or even simple handcrafted motions and hand drawings.
13 |
14 |
15 | ## 🔥News!!!
16 |
17 | - __[2023.10]__ We release a high-quality I2VGen-XL model; please refer to the [Webpage](https://i2vgen-xl.github.io)
18 | - __[2023.08]__ We release the Gradio UI on [ModelScope](https://modelscope.cn/studios/damo/VideoComposer-Demo/summary)
19 | - __[2023.07]__ We release the pretrained model without watermark; please refer to the [ModelCard](https://modelscope.cn/models/damo/VideoComposer/files)
20 |
21 |
22 |
23 | ## TODO
24 | - [x] Release our technical papers and webpage.
25 | - [x] Release code and pretrained model.
26 | - [x] Release Gradio UI on [ModelScope](https://modelscope.cn/studios/damo/VideoComposer-Demo/summary) and Hugging Face.
27 | - [x] Release pretrained model that can generate 8s videos without watermark on [ModelScope](https://modelscope.cn/models/damo/VideoComposer/files)
28 |
29 |
30 | ## Method
31 |
32 | 
33 |
34 |
35 | ## Running by Yourself
36 |
37 | ### 1. Installation
38 |
39 | Requirements:
40 | - Python==3.8
41 | - ffmpeg (for motion vector extraction)
42 | - torch==1.12.0+cu113
43 | - torchvision==0.13.0+cu113
44 | - open-clip-torch==2.0.2
45 | - transformers==4.18.0
46 | - flash-attn==0.2
47 | - xformers==0.0.13
48 | - motion-vector-extractor==1.0.6 (for motion vector extraction)
49 |
50 | You can also create the same environment as ours with the following command:
51 | ```
52 | conda env create -f environment.yaml
53 | ```
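After activating the environment, a quick sanity check of the core dependencies can save debugging time later (an optional snippet, not part of the repo; the expected versions come from the requirement list above):

```python
# Optional environment check for the requirements listed above.
import torch
import torchvision

print(torch.__version__)          # expected: 1.12.0+cu113
print(torchvision.__version__)    # expected: 0.13.0+cu113
print(torch.cuda.is_available())  # should be True for GPU inference
```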
54 |
55 | ### 2. Download model weights
56 |
57 | Download all the [model weights](https://www.modelscope.cn/models/damo/VideoComposer/summary) via the following command:
58 |
59 | ```
60 | !pip install modelscope
61 | from modelscope.hub.snapshot_download import snapshot_download
62 | model_dir = snapshot_download('damo/VideoComposer', cache_dir='model_weights/', revision='v1.0.0')
63 | ```
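Note that `!pip install modelscope` is a notebook-style shell command, while the remaining lines are Python. Outside a notebook, you can install `modelscope` in your shell first and then run the same download call from a short script (a minimal sketch reusing the call above; the script name is just an example):

```python
# download_weights.py -- illustrative helper, not part of the repo.
# Install the dependency first in your shell:  pip install modelscope
from modelscope.hub.snapshot_download import snapshot_download

# Same arguments as in the snippet above.
model_dir = snapshot_download('damo/VideoComposer', cache_dir='model_weights/', revision='v1.0.0')
print('Weights downloaded to:', model_dir)
```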
64 |
65 | Next, place these models in the `model_weights` folder following the file structure shown below.
66 |
67 |
68 | ```
69 | |--model_weights/
70 | | |--non_ema_228000.pth
71 | | |--midas_v3_dpt_large.pth
72 | | |--open_clip_pytorch_model.bin
73 | | |--sketch_simplification_gan.pth
74 | | |--table5_pidinet.pth
75 | | |--v2-1_512-ema-pruned.ckpt
76 | ```
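Before running inference, you can optionally verify that all expected checkpoints are in place (an illustrative check, not part of the repo; the file names come from the layout above):

```python
# check_weights.py -- illustrative helper, not part of the repo.
import os

expected = [
    'non_ema_228000.pth',
    'midas_v3_dpt_large.pth',
    'open_clip_pytorch_model.bin',
    'sketch_simplification_gan.pth',
    'table5_pidinet.pth',
    'v2-1_512-ema-pruned.ckpt',
]
missing = [f for f in expected if not os.path.exists(os.path.join('model_weights', f))]
print('All model weights found.' if not missing else f'Missing files: {missing}')
```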
77 |
78 | You can also download some of them from their original projects:
79 | - "midas_v3_dpt_large.pth" in [MiDaS](https://github.com/isl-org/MiDaS)
80 | - "open_clip_pytorch_model.bin" in [Open Clip](https://github.com/mlfoundations/open_clip)
81 | - "sketch_simplification_gan.pth" and "table5_pidinet.pth" in [Pidinet](https://github.com/zhuoinoulu/pidinet)
82 | - "v2-1_512-ema-pruned.ckpt" in [Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.ckpt).
83 |
84 | For convenience, we provide a download link in this repo.
85 |
86 |
87 | ### 3. Running
88 |
89 | In this project, we provide two implementations that can help you better understand our method.
90 |
91 |
92 | #### 3.1 Inference with Customized Inputs
93 |
94 | You can run the code with the following command:
95 |
96 | ```
97 | python run_net.py\
98 | --cfg configs/exp02_motion_transfer.yaml\
99 | --seed 9999\
100 | --input_video "demo_video/motion_transfer.mp4"\
101 | --image_path "demo_video/moon_on_water.jpg"\
102 | --input_text_desc "A beautiful big moon on the water at night"
103 | ```
104 | The results are saved in the `outputs/exp02_motion_transfer-S09999` folder:
105 |
106 | 
107 | 
108 |
109 |
110 | If you notice a significant color shift in the generated video, you can use the style condition to adjust the color distribution with the following command.
111 |
112 |
113 | ```
114 | python run_net.py\
115 | --cfg configs/exp02_motion_transfer_vs_style.yaml\
116 | --seed 9999\
117 | --input_video "demo_video/motion_transfer.mp4"\
118 | --image_path "demo_video/moon_on_water.jpg"\
119 | --style_image "demo_video/moon_on_water.jpg"\
120 | --input_text_desc "A beautiful big moon on the water at night"
121 | ```
122 |
123 |
124 | ```
125 | python run_net.py\
126 | --cfg configs/exp03_sketch2video_style.yaml\
127 | --seed 8888\
128 | --sketch_path "demo_video/src_single_sketch.png"\
129 | --style_image "demo_video/qibaishi_01.png"\
130 | --input_text_desc "Red-backed Shrike lanius collurio"
131 | ```
132 | 
133 |
134 |
135 |
136 | ```
137 | python run_net.py\
138 | --cfg configs/exp04_sketch2video_wo_style.yaml\
139 | --seed 144\
140 | --sketch_path "demo_video/src_single_sketch.png"\
141 | --input_text_desc "A Red-backed Shrike lanius collurio is on the branch"
142 | ```
143 | 
144 | 
145 |
146 |
147 |
148 | ```
149 | python run_net.py\
150 | --cfg configs/exp05_text_depths_wo_style.yaml\
151 | --seed 9999\
152 | --input_video demo_video/video_8800.mp4\
153 | --input_text_desc "A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish"
154 | ```
155 | 
156 | 
157 |
158 | ```
159 | python run_net.py\
160 | --cfg configs/exp06_text_depths_vs_style.yaml\
161 | --seed 9999\
162 | --input_video demo_video/video_8800.mp4\
163 | --style_image "demo_video/qibaishi_01.png"\
164 | --input_text_desc "A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish"
165 | ```
166 |
167 | 
168 | 
169 |
170 |
171 | #### 3.2 Inference on a Video
172 |
173 | You can just run the code with the following command:
174 | ```
175 | python run_net.py \
176 | --cfg configs/exp01_vidcomposer_full.yaml \
177 | --input_video "demo_video/blackswan.mp4" \
178 | --input_text_desc "A black swan swam in the water" \
179 | --seed 9999
180 | ```
181 |
182 | This command extracts the different conditions of the input video, e.g., depth, sketch, and motion vectors, for the subsequent video generation, and saves them in the `outputs` folder. The task list is predefined in `inference_multi.py`.
183 |
184 |
185 |
186 | In addition to the above use cases, you can explore further possibilities with this code and model. Note that because the diffusion model generates diverse samples, trying different seeds may produce better results.
187 |
188 | We hope you enjoy using it! 😀
189 |
190 |
191 |
192 | ## BibTeX
193 |
194 | If this repo is useful to you, please cite our technical paper.
195 | ```bibtex
196 | @article{2023videocomposer,
197 | title={VideoComposer: Compositional Video Synthesis with Motion Controllability},
198 | author={Wang, Xiang* and Yuan, Hangjie* and Zhang, Shiwei* and Chen, Dayou* and Wang, Jiuniu and Zhang, Yingya and Shen, Yujun and Zhao, Deli and Zhou, Jingren},
199 | journal={arXiv preprint arXiv:2306.02018},
200 | year={2023}
201 | }
202 | ```
203 |
204 |
205 | ## Acknowledgement
206 |
207 | We would like to express our gratitude for the contributions of several previous works to the development of VideoComposer. This includes, but is not limited to [Composer](https://arxiv.org/abs/2302.09778), [ModelScopeT2V](https://modelscope.cn/models/damo/text-to-video-synthesis/summary), [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [OpenCLIP](https://github.com/mlfoundations/open_clip), [WebVid-10M](https://m-bain.github.io/webvid-dataset/), [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/), [Pidinet](https://github.com/zhuoinoulu/pidinet) and [MiDaS](https://github.com/isl-org/MiDaS). We are committed to building upon these foundations in a way that respects their original contributions.
208 |
209 |
210 | ## Disclaimer
211 |
212 | This open-source model is trained on the [WebVid-10M](https://m-bain.github.io/webvid-dataset/) and [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/) datasets and is intended for RESEARCH/NON-COMMERCIAL USE ONLY. We have also trained more powerful models using internal video data, which can be used in the future.
213 |
--------------------------------------------------------------------------------
/artist/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/.DS_Store
--------------------------------------------------------------------------------
/artist/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import torch
4 | import torch.distributed as dist
5 | import oss2 as oss
6 |
7 | __all__ = ['DOWNLOAD_TO_CACHE']
8 |
9 |
10 | def DOWNLOAD_TO_CACHE(oss_key,
11 | file_or_dirname=None,
12 | cache_dir=osp.join('/'.join(osp.abspath(__file__).split('/')[:-2]), 'model_weights')):
13 | r"""Download OSS [file or folder] to the cache folder.
14 | Only the 0th process on each node will run the downloading.
15 | Barrier all processes until the downloading is completed.
16 | """
17 | # source and target paths
18 | base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))
19 |
20 | return base_path
21 |
--------------------------------------------------------------------------------
/artist/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .samplers import *
2 | from .tokenizers import *
3 | from .transforms import *
4 |
--------------------------------------------------------------------------------
/artist/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/data/__pycache__/samplers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/samplers.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/data/__pycache__/tokenizers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/tokenizers.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/data/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/data/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/artist/data/samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import os.path as osp
4 | from torch.utils.data.sampler import Sampler
5 |
6 | from ..ops.distributed import get_rank, get_world_size, shared_random_seed
7 | from ..ops.utils import ceil_divide, read
8 |
9 | __all__ = ['BatchSampler','GroupSampler','ImgGroupSampler']
10 |
11 | class BatchSampler(Sampler):
12 | r"""An infinite batch sampler.
13 | """
14 | def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=False, seed=None):
15 | self.dataset_size = dataset_size
16 | self.batch_size = batch_size
17 | self.num_replicas = num_replicas or get_world_size()
18 | self.rank = rank or get_rank()
19 | self.shuffle = shuffle
20 | self.seed = seed or shared_random_seed()
21 | self.rng = np.random.default_rng(self.seed + self.rank)
22 | self.batches_per_rank = ceil_divide(dataset_size, self.num_replicas * self.batch_size)
23 | self.samples_per_rank = self.batches_per_rank * self.batch_size
24 |
25 | # rank indices
26 | indices = self.rng.permutation(self.samples_per_rank) if shuffle else np.arange(self.samples_per_rank)
27 | indices = indices * self.num_replicas + self.rank
28 | indices = indices[indices < dataset_size]
29 | self.indices = indices
30 |
31 | def __iter__(self):
32 | start = 0
33 | while True:
34 | batch = [self.indices[i % len(self.indices)] for i in range(start, start + self.batch_size)]
35 | if self.shuffle and (start + self.batch_size) > len(self.indices):
36 | self.rng.shuffle(self.indices)
37 | start = (start + self.batch_size) % len(self.indices)
38 | yield batch
39 |
40 | class GroupSampler(Sampler):
41 |
42 | def __init__(self, group_file, batch_size, alpha=0.7, update_interval=5000, seed=8888):
43 | self.group_file = group_file
44 | self.group_folder = osp.join(osp.dirname(group_file), 'groups')
45 | self.batch_size = batch_size
46 | self.alpha = alpha
47 | self.update_interval = update_interval
48 | self.seed = seed
49 | self.rng = np.random.default_rng(seed)
50 |
51 | def __iter__(self):
52 | while True:
53 | # keep groups up-to-date
54 | self.update_groups()
55 |
56 | # collect items
57 | items = self.sample()
58 | while len(items) < self.batch_size:
59 | items += self.sample()
60 |
61 | # sample a batch
62 | batch = self.rng.choice(items, self.batch_size, replace=False if len(items) >= self.batch_size else True)
63 | yield [u.strip().split(',') for u in batch]
64 |
65 | def update_groups(self):
66 | if not hasattr(self, '_step'):
67 | self._step = 0
68 | if self._step % self.update_interval == 0:
69 | self.groups = json.loads(read(self.group_file))
70 | self._step += 1
71 |
72 | def sample(self):
73 | scales = np.array([float(next(iter(u)).split(':')[-1]) for u in self.groups])
74 | p = scales ** self.alpha / (scales ** self.alpha).sum()
75 | group = self.rng.choice(self.groups, p=p)
76 | list_file = osp.join(self.group_folder, self.rng.choice(next(iter(group.values()))))
77 | return read(list_file).strip().split('\n')
78 |
79 | class ImgGroupSampler(Sampler):
80 |
81 | def __init__(self, group_file, batch_size, alpha=0.7, update_interval=5000, seed=8888):
82 | self.group_file = group_file
83 | self.group_folder = osp.join(osp.dirname(group_file), 'groups')
84 | self.batch_size = batch_size
85 | self.alpha = alpha
86 | self.update_interval = update_interval
87 | self.seed = seed
88 | self.rng = np.random.default_rng(seed)
89 |
90 | def __iter__(self):
91 | while True:
92 | # keep groups up-to-date
93 | self.update_groups()
94 |
95 | # collect items
96 | items = self.sample()
97 | while len(items) < self.batch_size:
98 | items += self.sample()
99 |
100 | # sample a batch
101 | batch = self.rng.choice(items, self.batch_size, replace=False if len(items) >= self.batch_size else True)
102 | yield [u.strip().split(',', 1) for u in batch]
103 |
104 | def update_groups(self):
105 | if not hasattr(self, '_step'):
106 | self._step = 0
107 | if self._step % self.update_interval == 0:
108 | self.groups = json.loads(read(self.group_file))
109 |
110 | self._step += 1
111 |
112 | def sample(self):
113 | scales = np.array([float(next(iter(u)).split(':')[-1]) for u in self.groups])
114 | p = scales ** self.alpha / (scales ** self.alpha).sum()
115 | group = self.rng.choice(self.groups, p=p)
116 | list_file = osp.join(self.group_folder, self.rng.choice(next(iter(group.values()))))
117 | return read(list_file).strip().split('\n')
--------------------------------------------------------------------------------
/artist/data/tokenizers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gzip
3 | import html
4 | import ftfy
5 | import regex as re
6 | import torch
7 | from tokenizers import CharBPETokenizer, BertWordPieceTokenizer
8 | from functools import lru_cache
9 |
10 | __all__ = ['CLIPTokenizer']
11 |
12 | #-------------------------------- CLIPTokenizer --------------------------------#
13 |
14 | @lru_cache()
15 | def default_bpe():
16 | root = os.path.realpath(__file__)
17 | root = '/'.join(root.split('/')[:-1])
18 | return os.path.join(root, 'bpe_simple_vocab_16e6.txt.gz')
19 |
20 | @lru_cache()
21 | def bytes_to_unicode():
22 | """
23 | Returns list of utf-8 byte and a corresponding list of unicode strings.
24 | The reversible bpe codes work on unicode strings.
25 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
26 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
27 | This is a signficant percentage of your normal, say, 32K bpe vocab.
28 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
29 | And avoids mapping to whitespace/control characters the bpe code barfs on.
30 | """
31 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
32 | cs = bs[:]
33 | n = 0
34 | for b in range(2**8):
35 | if b not in bs:
36 | bs.append(b)
37 | cs.append(2**8+n)
38 | n += 1
39 | cs = [chr(n) for n in cs]
40 | return dict(zip(bs, cs))
41 |
42 | def get_pairs(word):
43 | """Return set of symbol pairs in a word.
44 | Word is represented as tuple of symbols (symbols being variable-length strings).
45 | """
46 | pairs = set()
47 | prev_char = word[0]
48 | for char in word[1:]:
49 | pairs.add((prev_char, char))
50 | prev_char = char
51 | return pairs
52 |
53 | def basic_clean(text):
54 | text = ftfy.fix_text(text)
55 | text = html.unescape(html.unescape(text))
56 | return text.strip()
57 |
58 | def whitespace_clean(text):
59 | text = re.sub(r'\s+', ' ', text)
60 | text = text.strip()
61 | return text
62 |
63 | class SimpleTokenizer(object):
64 |
65 | def __init__(self, bpe_path: str = default_bpe()):
66 | self.byte_encoder = bytes_to_unicode()
67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
69 | merges = merges[1:49152-256-2+1]
70 | merges = [tuple(merge.split()) for merge in merges]
71 | vocab = list(bytes_to_unicode().values())
72 | vocab = vocab + [v+'' for v in vocab]
73 | for merge in merges:
74 | vocab.append(''.join(merge))
75 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
76 | self.encoder = dict(zip(vocab, range(len(vocab))))
77 | self.decoder = {v: k for k, v in self.encoder.items()}
78 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
85 | word = tuple(token[:-1]) + ( token[-1] + '',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
89 | return token+''
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | try:
100 | j = word.index(first, i)
101 | new_word.extend(word[i:j])
102 | i = j
103 | except:
104 | new_word.extend(word[i:])
105 | break
106 |
107 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
108 | new_word.append(first+second)
109 | i += 2
110 | else:
111 | new_word.append(word[i])
112 | i += 1
113 | new_word = tuple(new_word)
114 | word = new_word
115 | if len(word) == 1:
116 | break
117 | else:
118 | pairs = get_pairs(word)
119 | word = ' '.join(word)
120 | self.cache[token] = word
121 | return word
122 |
123 | def encode(self, text):
124 | bpe_tokens = []
125 | text = whitespace_clean(basic_clean(text)).lower()
126 | for token in re.findall(self.pat, text):
127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129 | return bpe_tokens
130 |
131 | def decode(self, tokens):
132 | text = ''.join([self.decoder[token] for token in tokens])
133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
134 | return text
135 |
136 | class CLIPTokenizer(object):
137 |
138 | def __init__(self, length=77):
139 | self.length = length
140 |
141 | # init tokenizer
142 | self.tokenizer = SimpleTokenizer(bpe_path=default_bpe())
143 | self.sos_token = self.tokenizer.encoder['<|startoftext|>']
144 | self.eos_token = self.tokenizer.encoder['<|endoftext|>']
145 | self.vocab_size = len(self.tokenizer.encoder)
146 |
147 | def __call__(self, sequence):
148 | if isinstance(sequence, str):
149 | return torch.LongTensor(self._tokenizer(sequence))
150 | elif isinstance(sequence, list):
151 | return torch.LongTensor([self._tokenizer(u) for u in sequence])
152 | else:
153 | raise TypeError(f'Expected the "sequence" to be a string or a list, but got {type(sequence)}')
154 |
155 | def _tokenizer(self, text):
156 | tokens = self.tokenizer.encode(text)[:self.length - 2]
157 | tokens = [self.sos_token] + tokens + [self.eos_token]
158 | tokens = tokens + [0] * (self.length - len(tokens))
159 | return tokens
160 |
--------------------------------------------------------------------------------
/artist/data/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import random
4 | import math
5 | import numpy as np
6 | from PIL import Image, ImageFilter
7 |
8 | __all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\
9 | 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"]
10 |
11 | # class Compose(object):
12 |
13 | # def __init__(self, transforms):
14 | # self.transforms = transforms
15 |
16 | # def __call__(self, rgb):
17 | # for t in self.transforms:
18 | # rgb = t(rgb)
19 | # return rgb
20 | class Compose(object):
21 |
22 | def __init__(self, transforms):
23 | self.transforms = transforms
24 |
25 | def __getitem__(self, index):
26 | if isinstance(index, slice):
27 | return Compose(self.transforms[index])
28 | else:
29 | return self.transforms[index]
30 |
31 | def __len__(self):
32 | return len(self.transforms)
33 |
34 | def __call__(self, rgb):
35 | for t in self.transforms:
36 | rgb = t(rgb)
37 | return rgb
38 |
39 | class Resize(object):
40 |
41 | def __init__(self, size=256):
42 | if isinstance(size, int):
43 | size = (size, size)
44 | self.size = size
45 |
46 | def __call__(self, rgb):
47 |
48 | rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
49 | return rgb
50 |
51 | class Rescale(object):
52 |
53 | def __init__(self, size=256, interpolation=Image.BILINEAR):
54 | self.size = size
55 | self.interpolation = interpolation
56 |
57 | def __call__(self, rgb):
58 | w, h = rgb[0].size
59 | scale = self.size / min(w, h)
60 | out_w, out_h = int(round(w * scale)), int(round(h * scale))
61 | rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
62 | return rgb
63 |
64 | class CenterCrop(object):
65 |
66 | def __init__(self, size=224):
67 | self.size = size
68 |
69 | def __call__(self, rgb):
70 | w, h = rgb[0].size
71 | assert min(w, h) >= self.size
72 | x1 = (w - self.size) // 2
73 | y1 = (h - self.size) // 2
74 | rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
75 | return rgb
76 |
77 | class ResizeRandomCrop(object):
78 |
79 | def __init__(self, size=256, size_short=292):
80 | self.size = size
81 | # self.min_area = min_area
82 | self.size_short = size_short
83 |
84 | def __call__(self, rgb):
85 |
86 | # consistent crop between rgb and m
87 | while min(rgb[0].size) >= 2 * self.size_short:
88 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
89 | scale = self.size_short / min(rgb[0].size)
90 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
91 | out_w = self.size
92 | out_h = self.size
93 | w, h = rgb[0].size # (518, 292)
94 | x1 = random.randint(0, w - out_w)
95 | y1 = random.randint(0, h - out_h)
96 |
97 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
98 | # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
99 | # # center crop
100 | # x1 = (img[0].width - self.size) // 2
101 | # y1 = (img[0].height - self.size) // 2
102 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
103 | return rgb
104 |
105 |
106 |
107 | class ExtractResizeRandomCrop(object):
108 |
109 | def __init__(self, size=256, size_short=292):
110 | self.size = size
111 | # self.min_area = min_area
112 | self.size_short = size_short
113 |
114 | def __call__(self, rgb):
115 |
116 | # consistent crop between rgb and m
117 | while min(rgb[0].size) >= 2 * self.size_short:
118 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
119 | scale = self.size_short / min(rgb[0].size)
120 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
121 | out_w = self.size
122 | out_h = self.size
123 | w, h = rgb[0].size # (518, 292)
124 | x1 = random.randint(0, w - out_w)
125 | y1 = random.randint(0, h - out_h)
126 |
127 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
128 | # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
129 | # # center crop
130 | # x1 = (img[0].width - self.size) // 2
131 | # y1 = (img[0].height - self.size) // 2
132 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
133 | wh = [x1, y1, x1 + out_w, y1 + out_h]
134 | return rgb, wh
135 |
136 |
137 |
138 |
139 | class ExtractResizeAssignCrop(object):
140 |
141 | def __init__(self, size=256, size_short=292):
142 | self.size = size
143 | # self.min_area = min_area
144 | self.size_short = size_short
145 |
146 | def __call__(self, rgb, wh):
147 |
148 | # consistent crop between rgb and m
149 | while min(rgb[0].size) >= 2 * self.size_short:
150 | rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
151 | scale = self.size_short / min(rgb[0].size)
152 | rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
153 | # out_w = self.size
154 | # out_h = self.size
155 | # w, h = rgb[0].size # (518, 292)
156 | # x1 = random.randint(0, w - out_w)
157 | # y1 = random.randint(0, h - out_h)
158 |
159 | rgb = [u.crop(wh) for u in rgb]
160 | rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
161 | # # center crop
162 | # x1 = (img[0].width - self.size) // 2
163 | # y1 = (img[0].height - self.size) // 2
164 | # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
165 | # wh = [x1, y1, x1 + out_w, y1 + out_h]
166 | return rgb
167 |
168 | class CenterCropV2(object):
169 | def __init__(self, size):
170 | self.size = size
171 |
172 | def __call__(self, img):
173 | # fast resize
174 | while min(img[0].size) >= 2 * self.size:
175 | img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img]
176 | scale = self.size / min(img[0].size)
177 | img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img]
178 |
179 | # center crop
180 | x1 = (img[0].width - self.size) // 2
181 | y1 = (img[0].height - self.size) // 2
182 | img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
183 | return img
184 |
185 | class RandomCrop(object):
186 |
187 | def __init__(self, size=224, min_area=0.4):
188 | self.size = size
189 | self.min_area = min_area
190 |
191 | def __call__(self, rgb):
192 |
193 | # consistent crop between rgb and m
194 | w, h = rgb[0].size
195 | area = w * h
196 | out_w, out_h = float('inf'), float('inf')
197 | while out_w > w or out_h > h:
198 | target_area = random.uniform(self.min_area, 1.0) * area
199 | aspect_ratio = random.uniform(3. / 4., 4. / 3.)
200 | out_w = int(round(math.sqrt(target_area * aspect_ratio)))
201 | out_h = int(round(math.sqrt(target_area / aspect_ratio)))
202 | x1 = random.randint(0, w - out_w)
203 | y1 = random.randint(0, h - out_h)
204 |
205 | rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
206 | rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
207 |
208 | return rgb
209 |
210 | class RandomCropV2(object):
211 |
212 | def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
213 | if isinstance(size, (tuple, list)):
214 | self.size = size
215 | else:
216 | self.size = (size, size)
217 | self.min_area = min_area
218 | self.ratio = ratio
219 |
220 | def _get_params(self, img):
221 | width, height = img.size
222 | area = height * width
223 |
224 | for _ in range(10):
225 | target_area = random.uniform(self.min_area, 1.0) * area
226 | log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
227 | aspect_ratio = math.exp(random.uniform(*log_ratio))
228 |
229 | w = int(round(math.sqrt(target_area * aspect_ratio)))
230 | h = int(round(math.sqrt(target_area / aspect_ratio)))
231 |
232 | if 0 < w <= width and 0 < h <= height:
233 | i = random.randint(0, height - h)
234 | j = random.randint(0, width - w)
235 | return i, j, h, w
236 |
237 | # Fallback to central crop
238 | in_ratio = float(width) / float(height)
239 | if (in_ratio < min(self.ratio)):
240 | w = width
241 | h = int(round(w / min(self.ratio)))
242 | elif (in_ratio > max(self.ratio)):
243 | h = height
244 | w = int(round(h * max(self.ratio)))
245 | else: # whole image
246 | w = width
247 | h = height
248 | i = (height - h) // 2
249 | j = (width - w) // 2
250 | return i, j, h, w
251 |
252 | def __call__(self, rgb):
253 | i, j, h, w = self._get_params(rgb[0])
254 | rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
255 | return rgb
256 |
257 | class RandomHFlip(object):
258 |
259 | def __init__(self, p=0.5):
260 | self.p = p
261 |
262 | def __call__(self, rgb):
263 | if random.random() < self.p:
264 | rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
265 | return rgb
266 |
267 | class GaussianBlur(object):
268 |
269 | def __init__(self, sigmas=[0.1, 2.0], p=0.5):
270 | self.sigmas = sigmas
271 | self.p = p
272 |
273 | def __call__(self, rgb):
274 | if random.random() < self.p:
275 | sigma = random.uniform(*self.sigmas)
276 | rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb]
277 | return rgb
278 |
279 | class ColorJitter(object):
280 |
281 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5):
282 | self.brightness = brightness
283 | self.contrast = contrast
284 | self.saturation = saturation
285 | self.hue = hue
286 | self.p = p
287 |
288 | def __call__(self, rgb):
289 | if random.random() < self.p:
290 | brightness, contrast, saturation, hue = self._random_params()
291 | transforms = [
292 | lambda f: F.adjust_brightness(f, brightness),
293 | lambda f: F.adjust_contrast(f, contrast),
294 | lambda f: F.adjust_saturation(f, saturation),
295 | lambda f: F.adjust_hue(f, hue)]
296 | random.shuffle(transforms)
297 | for t in transforms:
298 | rgb = [t(u) for u in rgb]
299 |
300 | return rgb
301 |
302 | def _random_params(self):
303 | brightness = random.uniform(
304 | max(0, 1 - self.brightness), 1 + self.brightness)
305 | contrast = random.uniform(
306 | max(0, 1 - self.contrast), 1 + self.contrast)
307 | saturation = random.uniform(
308 | max(0, 1 - self.saturation), 1 + self.saturation)
309 | hue = random.uniform(-self.hue, self.hue)
310 | return brightness, contrast, saturation, hue
311 |
312 | class RandomGray(object):
313 |
314 | def __init__(self, p=0.2):
315 | self.p = p
316 |
317 | def __call__(self, rgb):
318 | if random.random() < self.p:
319 | rgb = [u.convert('L').convert('RGB') for u in rgb]
320 | return rgb
321 |
322 | class ToTensor(object):
323 |
324 | def __call__(self, rgb):
325 | rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
326 | return rgb
327 |
328 | class Normalize(object):
329 |
330 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
331 | self.mean = mean
332 | self.std = std
333 |
334 | def __call__(self, rgb):
335 | rgb = rgb.clone()
336 | rgb.clamp_(0, 1)
337 | if not isinstance(self.mean, torch.Tensor):
338 | self.mean = rgb.new_tensor(self.mean).view(-1)
339 | if not isinstance(self.std, torch.Tensor):
340 | self.std = rgb.new_tensor(self.std).view(-1)
341 | rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1))
342 | return rgb
343 |
344 |
--------------------------------------------------------------------------------
/artist/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | from .midas import *
3 |
4 |
--------------------------------------------------------------------------------
/artist/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/models/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/models/__pycache__/midas.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/models/__pycache__/midas.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/models/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | import artist.ops as ops # for using differentiable all_gather
7 | from artist import DOWNLOAD_TO_CACHE
8 |
9 | __all__ = ['CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', 'clip_vit_l_14_336px', 'clip_vit_h_16']
10 |
11 | def to_fp16(m):
12 | if isinstance(m, (nn.Linear, nn.Conv2d)):
13 | m.weight.data = m.weight.data.half()
14 | if m.bias is not None:
15 | m.bias.data = m.bias.data.half()
16 | elif hasattr(m, 'head'):
17 | p = getattr(m, 'head')
18 | p.data = p.data.half()
19 |
20 | class QuickGELU(nn.Module):
21 |
22 | def forward(self, x):
23 | return x * torch.sigmoid(1.702 * x)
24 |
25 | class LayerNorm(nn.LayerNorm):
26 | r"""Subclass of nn.LayerNorm to handle fp16.
27 | """
28 | def forward(self, x):
29 | return super(LayerNorm, self).forward(x.float()).type_as(x)
30 |
31 | class SelfAttention(nn.Module):
32 |
33 | def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
34 | assert dim % num_heads == 0
35 | super(SelfAttention, self).__init__()
36 | self.dim = dim
37 | self.num_heads = num_heads
38 | self.head_dim = dim // num_heads
39 | self.scale = 1.0 / math.sqrt(self.head_dim)
40 |
41 | # layers
42 | self.to_qkv = nn.Linear(dim, dim * 3)
43 | self.attn_dropout = nn.Dropout(attn_dropout)
44 | self.proj = nn.Linear(dim, dim)
45 | self.proj_dropout = nn.Dropout(proj_dropout)
46 |
47 | def forward(self, x, mask=None):
48 | r"""x: [B, L, C].
49 | mask: [*, L, L].
50 | """
51 | b, l, c, n = *x.size(), self.num_heads
52 |
53 | # compute query, key, and value
54 | q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
55 | q = q.reshape(l, b * n, -1).transpose(0, 1)
56 | k = k.reshape(l, b * n, -1).transpose(0, 1)
57 | v = v.reshape(l, b * n, -1).transpose(0, 1)
58 |
59 | # compute attention
60 | attn = self.scale * torch.bmm(q, k.transpose(1, 2))
61 | if mask is not None:
62 | attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
63 | attn = F.softmax(attn.float(), dim=-1).type_as(attn)
64 | attn = self.attn_dropout(attn)
65 |
66 | # gather context
67 | x = torch.bmm(attn, v)
68 | x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)
69 |
70 | # output
71 | x = self.proj(x)
72 | x = self.proj_dropout(x)
73 | return x
74 |
75 | class AttentionBlock(nn.Module):
76 |
77 | def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
78 | super(AttentionBlock, self).__init__()
79 | self.dim = dim
80 | self.num_heads = num_heads
81 |
82 | # layers
83 | self.norm1 = LayerNorm(dim)
84 | self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
85 | self.norm2 = LayerNorm(dim)
86 | self.mlp = nn.Sequential(
87 | nn.Linear(dim, dim * 4),
88 | QuickGELU(),
89 | nn.Linear(dim * 4, dim),
90 | nn.Dropout(proj_dropout))
91 |
92 | def forward(self, x, mask=None):
93 | x = x + self.attn(self.norm1(x), mask)
94 | x = x + self.mlp(self.norm2(x))
95 | return x
96 |
97 | class VisionTransformer(nn.Module):
98 |
99 | def __init__(self,
100 | image_size=224,
101 | patch_size=16,
102 | dim=768,
103 | out_dim=512,
104 | num_heads=12,
105 | num_layers=12,
106 | attn_dropout=0.0,
107 | proj_dropout=0.0,
108 | embedding_dropout=0.0):
109 | assert image_size % patch_size == 0
110 | super(VisionTransformer, self).__init__()
111 | self.image_size = image_size
112 | self.patch_size = patch_size
113 | self.dim = dim
114 | self.out_dim = out_dim
115 | self.num_heads = num_heads
116 | self.num_layers = num_layers
117 | self.num_patches = (image_size // patch_size) ** 2
118 |
119 | # embeddings
120 | gain = 1.0 / math.sqrt(dim)
121 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
122 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
123 | self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + 1, dim))
124 | self.dropout = nn.Dropout(embedding_dropout)
125 |
126 | # transformer
127 | self.pre_norm = LayerNorm(dim)
128 | self.transformer = nn.Sequential(*[
129 | AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
130 | for _ in range(num_layers)])
131 | self.post_norm = LayerNorm(dim)
132 |
133 | # head
134 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
135 |
136 | def forward(self, x):
137 | b, dtype = x.size(0), self.head.dtype
138 | x = x.type(dtype)
139 |
140 | # patch-embedding
141 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
142 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], dim=1)
143 | x = self.dropout(x + self.pos_embedding.type(dtype))
144 | x = self.pre_norm(x)
145 |
146 | # transformer
147 | x = self.transformer(x)
148 |
149 | # head
150 | x = self.post_norm(x)
151 | x = torch.mm(x[:, 0, :], self.head)
152 | return x
153 |
154 | def fp16(self):
155 | return self.apply(to_fp16)
156 |
157 | class TextTransformer(nn.Module):
158 |
159 | def __init__(self,
160 | vocab_size,
161 | text_len,
162 | dim=512,
163 | out_dim=512,
164 | num_heads=8,
165 | num_layers=12,
166 | attn_dropout=0.0,
167 | proj_dropout=0.0,
168 | embedding_dropout=0.0):
169 | super(TextTransformer, self).__init__()
170 | self.vocab_size = vocab_size
171 | self.text_len = text_len
172 | self.dim = dim
173 | self.out_dim = out_dim
174 | self.num_heads = num_heads
175 | self.num_layers = num_layers
176 |
177 | # embeddings
178 | self.token_embedding = nn.Embedding(vocab_size, dim)
179 | self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
180 | self.dropout = nn.Dropout(embedding_dropout)
181 |
182 | # transformer
183 | self.transformer = nn.ModuleList([
184 | AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
185 | for _ in range(num_layers)])
186 | self.norm = LayerNorm(dim)
187 |
188 | # head
189 | gain = 1.0 / math.sqrt(dim)
190 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
191 |
192 | # causal attention mask
193 | self.register_buffer('attn_mask', torch.tril(torch.ones(1, text_len, text_len)))
194 |
195 | def forward(self, x):
196 | eot, dtype = x.argmax(dim=-1), self.head.dtype
197 |
198 | # embeddings
199 | x = self.dropout(self.token_embedding(x).type(dtype) + self.pos_embedding.type(dtype))
200 |
201 | # transformer
202 | for block in self.transformer:
203 | x = block(x, self.attn_mask)
204 |
205 | # head
206 | x = self.norm(x)
207 | x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
208 | return x
209 |
210 | def fp16(self):
211 | return self.apply(to_fp16)
212 |
213 | class CLIP(nn.Module):
214 |
215 | def __init__(self,
216 | embed_dim=512,
217 | image_size=224,
218 | patch_size=16,
219 | vision_dim=768,
220 | vision_heads=12,
221 | vision_layers=12,
222 | vocab_size=49408,
223 | text_len=77,
224 | text_dim=512,
225 | text_heads=8,
226 | text_layers=12,
227 | attn_dropout=0.0,
228 | proj_dropout=0.0,
229 | embedding_dropout=0.0):
230 | super(CLIP, self).__init__()
231 | self.embed_dim = embed_dim
232 | self.image_size = image_size
233 | self.patch_size = patch_size
234 | self.vision_dim = vision_dim
235 | self.vision_heads = vision_heads
236 | self.vision_layers = vision_layers
237 | self.vocab_size = vocab_size
238 | self.text_len = text_len
239 | self.text_dim = text_dim
240 | self.text_heads = text_heads
241 | self.text_layers = text_layers
242 |
243 | # models
244 | self.visual = VisionTransformer(
245 | image_size=image_size,
246 | patch_size=patch_size,
247 | dim=vision_dim,
248 | out_dim=embed_dim,
249 | num_heads=vision_heads,
250 | num_layers=vision_layers,
251 | attn_dropout=attn_dropout,
252 | proj_dropout=proj_dropout,
253 | embedding_dropout=embedding_dropout)
254 | self.textual = TextTransformer(
255 | vocab_size=vocab_size,
256 | text_len=text_len,
257 | dim=text_dim,
258 | out_dim=embed_dim,
259 | num_heads=text_heads,
260 | num_layers=text_layers,
261 | attn_dropout=attn_dropout,
262 | proj_dropout=proj_dropout,
263 | embedding_dropout=embedding_dropout)
264 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
265 |
266 | def forward(self, imgs, txt_tokens):
267 | r"""imgs: [B, C, H, W] of torch.float32.
268 | txt_tokens: [B, T] of torch.long.
269 | """
270 | xi = self.visual(imgs)
271 | xt = self.textual(txt_tokens)
272 |
273 | # normalize features
274 | xi = F.normalize(xi, p=2, dim=1)
275 | xt = F.normalize(xt, p=2, dim=1)
276 |
277 | # gather features from all ranks
278 | full_xi = ops.diff_all_gather(xi)
279 | full_xt = ops.diff_all_gather(xt)
280 |
281 | # logits
282 | scale = self.log_scale.exp()
283 | logits_i2t = scale * torch.mm(xi, full_xt.t())
284 | logits_t2i = scale * torch.mm(xt, full_xi.t())
285 |
286 | # labels
287 | labels = torch.arange(
288 | len(xi) * ops.get_rank(),
289 | len(xi) * (ops.get_rank() + 1),
290 | dtype=torch.long,
291 | device=xi.device)
292 | return logits_i2t, logits_t2i, labels
293 |
294 | def init_weights(self):
295 | # embeddings
296 | nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
297 | nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
298 |
299 | # attentions
300 | for modality in ['visual', 'textual']:
301 | dim = self.vision_dim if modality == 'visual' else self.text_dim
302 | transformer = getattr(self, modality).transformer
303 | proj_gain = (1.0 / math.sqrt(dim)) * (1.0 / math.sqrt(2 * len(transformer)))
304 | attn_gain = 1.0 / math.sqrt(dim)
305 | mlp_gain = 1.0 / math.sqrt(2.0 * dim)
306 | for block in transformer:
307 | nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
308 | nn.init.normal_(block.attn.proj.weight, std=proj_gain)
309 | nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
310 | nn.init.normal_(block.mlp[2].weight, std=proj_gain)
311 |
312 | def param_groups(self):
313 | groups = [
314 | {'params': [p for n, p in self.named_parameters() if 'norm' in n or n.endswith('bias')], 'weight_decay': 0.0},
315 | {'params': [p for n, p in self.named_parameters() if not ('norm' in n or n.endswith('bias'))]}]
316 | return groups
317 |
318 | def fp16(self):
319 | return self.apply(to_fp16)
320 |
321 | def _clip(name, pretrained=False, **kwargs):
322 | model = CLIP(**kwargs)
323 | if pretrained:
324 | model.load_state_dict(torch.load(DOWNLOAD_TO_CACHE(f'models/clip/{name}.pth'), map_location='cpu'))
325 | return model
326 |
327 | def clip_vit_b_32(pretrained=False, **kwargs):
328 | cfg = dict(
329 | embed_dim=512,
330 | image_size=224,
331 | patch_size=32,
332 | vision_dim=768,
333 | vision_heads=12,
334 | vision_layers=12,
335 | vocab_size=49408,
336 | text_len=77,
337 | text_dim=512,
338 | text_heads=8,
339 | text_layers=12)
340 | cfg.update(**kwargs)
341 | return _clip('openai-clip-vit-base-32', pretrained, **cfg)
342 |
343 | def clip_vit_b_16(pretrained=False, **kwargs):
344 | cfg = dict(
345 | embed_dim=512,
346 | image_size=224,
347 | patch_size=16,
348 | vision_dim=768,
349 | vision_heads=12,
350 | vision_layers=12,
351 | vocab_size=49408,
352 | text_len=77,
353 | text_dim=512,
354 | text_heads=8,
355 | text_layers=12)
356 | cfg.update(**kwargs)
357 | return _clip('openai-clip-vit-base-16', pretrained, **cfg)
358 |
359 | def clip_vit_l_14(pretrained=False, **kwargs):
360 | cfg = dict(
361 | embed_dim=768,
362 | image_size=224,
363 | patch_size=14,
364 | vision_dim=1024,
365 | vision_heads=16,
366 | vision_layers=24,
367 | vocab_size=49408,
368 | text_len=77,
369 | text_dim=768,
370 | text_heads=12,
371 | text_layers=12)
372 | cfg.update(**kwargs)
373 | return _clip('openai-clip-vit-large-14', pretrained, **cfg)
374 |
375 | def clip_vit_l_14_336px(pretrained=False, **kwargs):
376 | cfg = dict(
377 | embed_dim=768,
378 | image_size=336,
379 | patch_size=14,
380 | vision_dim=1024,
381 | vision_heads=16,
382 | vision_layers=24,
383 | vocab_size=49408,
384 | text_len=77,
385 | text_dim=768,
386 | text_heads=12,
387 | text_layers=12)
388 | cfg.update(**kwargs)
389 | return _clip('openai-clip-vit-large-14-336px', pretrained, **cfg)
390 |
391 | def clip_vit_h_16(pretrained=False, **kwargs):
392 | assert not pretrained, 'pretrained model for openai-clip-vit-huge-16 is not available!'
393 | cfg = dict(
394 | embed_dim=1024,
395 | image_size=256,
396 | patch_size=16,
397 | vision_dim=1280,
398 | vision_heads=16,
399 | vision_layers=32,
400 | vocab_size=49408,
401 | text_len=77,
402 | text_dim=1024,
403 | text_heads=16,
404 | text_layers=24)
405 | cfg.update(**kwargs)
406 | return _clip('openai-clip-vit-huge-16', pretrained, **cfg)
407 |
--------------------------------------------------------------------------------
/artist/models/midas.py:
--------------------------------------------------------------------------------
1 | r"""A much cleaner re-implementation of ``https://github.com/isl-org/MiDaS''.
2 | Image augmentation: T.Compose([
3 | Resize(
4 | keep_aspect_ratio=True,
5 | ensure_multiple_of=32,
6 | interpolation=cv2.INTER_CUBIC),
7 | T.ToTensor(),
8 | T.Normalize(
9 | mean=[0.5, 0.5, 0.5],
10 | std=[0.5, 0.5, 0.5])]).
11 | Fast inference:
12 | model = model.to(memory_format=torch.channels_last).half()
13 | input = input.to(memory_format=torch.channels_last).half()
14 | output = model(input)
15 | """
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import math
20 |
21 | from artist import DOWNLOAD_TO_CACHE
22 |
23 | __all__ = ['MiDaS', 'midas_v3']
24 |
25 | class SelfAttention(nn.Module):
26 |
27 | def __init__(self, dim, num_heads):
28 | assert dim % num_heads == 0
29 | super(SelfAttention, self).__init__()
30 | self.dim = dim
31 | self.num_heads = num_heads
32 | self.head_dim = dim // num_heads
33 | self.scale = 1.0 / math.sqrt(self.head_dim)
34 |
35 | # layers
36 | self.to_qkv = nn.Linear(dim, dim * 3)
37 | self.proj = nn.Linear(dim, dim)
38 |
39 | def forward(self, x):
40 | b, l, c, n, d = *x.size(), self.num_heads, self.head_dim
41 |
42 | # compute query, key, value
43 | q, k, v = self.to_qkv(x).view(b, l, n * 3, d).chunk(3, dim=2)
44 |
45 | # compute attention
46 | attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
47 | attn = F.softmax(attn.float(), dim=-1).type_as(attn)
48 |
49 | # gather context
50 | x = torch.einsum('bnij,bjnc->binc', attn, v)
51 | x = x.reshape(b, l, c)
52 |
53 | # output
54 | x = self.proj(x)
55 | return x
56 |
57 | class AttentionBlock(nn.Module):
58 |
59 | def __init__(self, dim, num_heads):
60 | super(AttentionBlock, self).__init__()
61 | self.dim = dim
62 | self.num_heads = num_heads
63 |
64 | # layers
65 | self.norm1 = nn.LayerNorm(dim)
66 | self.attn = SelfAttention(dim, num_heads)
67 | self.norm2 = nn.LayerNorm(dim)
68 | self.mlp = nn.Sequential(
69 | nn.Linear(dim, dim * 4),
70 | nn.GELU(),
71 | nn.Linear(dim * 4, dim))
72 |
73 | def forward(self, x):
74 | x = x + self.attn(self.norm1(x))
75 | x = x + self.mlp(self.norm2(x))
76 | return x
77 |
78 | class VisionTransformer(nn.Module):
79 |
80 | def __init__(self,
81 | image_size=384,
82 | patch_size=16,
83 | dim=1024,
84 | out_dim=1000,
85 | num_heads=16,
86 | num_layers=24):
87 | assert image_size % patch_size == 0
88 | super(VisionTransformer, self).__init__()
89 | self.image_size = image_size
90 | self.patch_size = patch_size
91 | self.dim = dim
92 | self.out_dim = out_dim
93 | self.num_heads = num_heads
94 | self.num_layers = num_layers
95 | self.num_patches = (image_size // patch_size) ** 2
96 |
97 | # embeddings
98 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
99 | self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
100 | self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
101 |
102 | # blocks
103 | self.blocks = nn.Sequential(*[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
104 | self.norm = nn.LayerNorm(dim)
105 |
106 | # head
107 | self.head = nn.Linear(dim, out_dim)
108 |
109 | def forward(self, x):
110 | b = x.size(0)
111 |
112 | # embeddings
113 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
114 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
115 | x = x + self.pos_embedding
116 |
117 | # blocks
118 | x = self.blocks(x)
119 | x = self.norm(x)
120 |
121 | # head
122 | x = self.head(x)
123 | return x
124 |
125 | class ResidualBlock(nn.Module):
126 |
127 | def __init__(self, dim):
128 | super(ResidualBlock, self).__init__()
129 | self.dim = dim
130 |
131 | # layers
132 | self.residual = nn.Sequential(
133 | nn.ReLU(inplace=False), # NOTE: avoid modifying the input
134 | nn.Conv2d(dim, dim, 3, padding=1),
135 | nn.ReLU(inplace=True),
136 | nn.Conv2d(dim, dim, 3, padding=1))
137 |
138 | def forward(self, x):
139 | return x + self.residual(x)
140 |
141 | class FusionBlock(nn.Module):
142 |
143 | def __init__(self, dim):
144 | super(FusionBlock, self).__init__()
145 | self.dim = dim
146 |
147 | # layers
148 | self.layer1 = ResidualBlock(dim)
149 | self.layer2 = ResidualBlock(dim)
150 | self.conv_out = nn.Conv2d(dim, dim, 1)
151 |
152 | def forward(self, *xs):
153 | assert len(xs) in (1, 2), 'invalid number of inputs'
154 | if len(xs) == 1:
155 | x = self.layer2(xs[0])
156 | else:
157 | x = self.layer2(xs[0] + self.layer1(xs[1]))
158 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
159 | x = self.conv_out(x)
160 | return x
161 |
162 | class MiDaS(nn.Module):
163 | r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''.
164 | Monocular depth estimation using dense prediction transformers.
165 | """
166 | def __init__(self,
167 | image_size=384,
168 | patch_size=16,
169 | dim=1024,
170 | neck_dims=[256, 512, 1024, 1024],
171 | fusion_dim=256,
172 | num_heads=16,
173 | num_layers=24):
174 | assert image_size % patch_size == 0
175 | assert num_layers % 4 == 0
176 | super(MiDaS, self).__init__()
177 | self.image_size = image_size
178 | self.patch_size = patch_size
179 | self.dim = dim
180 | self.neck_dims = neck_dims
181 | self.fusion_dim = fusion_dim
182 | self.num_heads = num_heads
183 | self.num_layers = num_layers
184 | self.num_patches = (image_size // patch_size) ** 2
185 |
186 | # embeddings
187 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
188 | self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
189 | self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
190 |
191 | # blocks
192 | stride = num_layers // 4
193 | self.blocks = nn.Sequential(*[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
194 | self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)]
195 |
196 | # stage1 (4x)
197 | self.fc1 = nn.Sequential(
198 | nn.Linear(dim * 2, dim),
199 | nn.GELU())
200 | self.conv1 = nn.Sequential(
201 | nn.Conv2d(dim, neck_dims[0], 1),
202 | nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4),
203 | nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False))
204 | self.fusion1 = FusionBlock(fusion_dim)
205 |
206 | # stage2 (8x)
207 | self.fc2 = nn.Sequential(
208 | nn.Linear(dim * 2, dim),
209 | nn.GELU())
210 | self.conv2 = nn.Sequential(
211 | nn.Conv2d(dim, neck_dims[1], 1),
212 | nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2),
213 | nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False))
214 | self.fusion2 = FusionBlock(fusion_dim)
215 |
216 | # stage3 (16x)
217 | self.fc3 = nn.Sequential(
218 | nn.Linear(dim * 2, dim),
219 | nn.GELU())
220 | self.conv3 = nn.Sequential(
221 | nn.Conv2d(dim, neck_dims[2], 1),
222 | nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False))
223 | self.fusion3 = FusionBlock(fusion_dim)
224 |
225 | # stage4 (32x)
226 | self.fc4 = nn.Sequential(
227 | nn.Linear(dim * 2, dim),
228 | nn.GELU())
229 | self.conv4 = nn.Sequential(
230 | nn.Conv2d(dim, neck_dims[3], 1),
231 | nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1),
232 | nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False))
233 | self.fusion4 = FusionBlock(fusion_dim)
234 |
235 | # head
236 | self.head = nn.Sequential(
237 | nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1),
238 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
239 | nn.Conv2d(fusion_dim // 2, 32, 3, padding=1),
240 | nn.ReLU(inplace=True),
241 | nn.ConvTranspose2d(32, 1, 1),
242 | nn.ReLU(inplace=True))
243 |
244 | def forward(self, x):
245 | b, c, h, w, p = *x.size(), self.patch_size
246 | assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})'
247 | hp, wp, grid = h // p, w // p, self.image_size // p
248 |
249 | # embeddings
250 | pos_embedding = torch.cat([
251 | self.pos_embedding[:, :1],
252 | F.interpolate(
253 | self.pos_embedding[:, 1:].reshape(1, grid, grid, -1).permute(0, 3, 1, 2),
254 | size=(hp, wp),
255 | mode='bilinear',
256 | align_corners=False).permute(0, 2, 3, 1).reshape(1, hp * wp, -1)], dim=1)
257 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
258 | x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
259 | x = x + pos_embedding
260 |
261 | # stage1
262 | x = self.blocks[self.slices[0]](x)
263 | x1 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
264 | x1 = self.fc1(x1).permute(0, 2, 1).unflatten(2, (hp, wp))
265 | x1 = self.conv1(x1)
266 |
267 | # stage2
268 | x = self.blocks[self.slices[1]](x)
269 | x2 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
270 | x2 = self.fc2(x2).permute(0, 2, 1).unflatten(2, (hp, wp))
271 | x2 = self.conv2(x2)
272 |
273 | # stage3
274 | x = self.blocks[self.slices[2]](x)
275 | x3 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
276 | x3 = self.fc3(x3).permute(0, 2, 1).unflatten(2, (hp, wp))
277 | x3 = self.conv3(x3)
278 |
279 | # stage4
280 | x = self.blocks[self.slices[3]](x)
281 | x4 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
282 | x4 = self.fc4(x4).permute(0, 2, 1).unflatten(2, (hp, wp))
283 | x4 = self.conv4(x4)
284 |
285 | # fusion
286 | x4 = self.fusion4(x4)
287 | x3 = self.fusion3(x4, x3)
288 | x2 = self.fusion2(x3, x2)
289 | x1 = self.fusion1(x2, x1)
290 |
291 | # head
292 | x = self.head(x1)
293 | return x
294 |
295 | def midas_v3(pretrained=False, **kwargs):
296 | cfg = dict(
297 | image_size=384,
298 | patch_size=16,
299 | dim=1024,
300 | neck_dims=[256, 512, 1024, 1024],
301 | fusion_dim=256,
302 | num_heads=16,
303 | num_layers=24)
304 | cfg.update(**kwargs)
305 | model = MiDaS(**cfg)
306 | if pretrained:
307 | # model.load_state_dict(torch.load(DOWNLOAD_TO_CACHE('experiments/models/midas/midas_v3_dpt_large.pth'), map_location='cpu'))
308 | model.load_state_dict(torch.load("./model_weights/midas_v3_dpt_large.pth", map_location='cpu'))
309 | return model
310 |
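A minimal usage sketch for `midas_v3`, following the fast-inference recipe from the module docstring. The random tensor is only a stand-in for a preprocessed frame, and `pretrained=True` would additionally require the checkpoint at `./model_weights/midas_v3_dpt_large.pth`; everything else is an assumption for illustration.

```python
# Illustrative only: run the MiDaS re-implementation on one RGB frame.
# Assumes a CUDA device and an input already resized to 384x384 and
# normalized with mean=0.5, std=0.5 (see the module docstring).
import torch
from artist.models.midas import midas_v3

model = midas_v3(pretrained=False).eval().requires_grad_(False).cuda()
model = model.to(memory_format=torch.channels_last).half()

frame = torch.randn(1, 3, 384, 384)  # stand-in for a preprocessed frame
frame = frame.cuda().to(memory_format=torch.channels_last).half()

with torch.no_grad():
    depth = model(frame)             # (1, 1, 384, 384) relative depth map
print(depth.shape)
```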
--------------------------------------------------------------------------------
/artist/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .distributed import *
3 | from .diffusion import *
4 | from .losses import *
5 | from .degration import *
6 | from .random_mask import *
--------------------------------------------------------------------------------
/artist/ops/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/degration.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/degration.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/distributed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/distributed.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/dpm_solver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/dpm_solver.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/random_mask.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/random_mask.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/ops/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/ops/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.distributed as dist
  4 | import functools, logging
5 | import pickle
6 | import numpy as np
7 | from collections import OrderedDict
8 | from torch.autograd import Function
9 |
10 | __all__ = ['is_dist_initialized',
11 | 'get_world_size',
12 | 'get_rank',
13 | 'new_group',
14 | 'destroy_process_group',
15 | 'barrier',
16 | 'broadcast',
17 | 'all_reduce',
18 | 'reduce',
19 | 'gather',
20 | 'all_gather',
21 | 'reduce_dict',
22 | 'get_global_gloo_group',
23 | 'generalized_all_gather',
24 | 'generalized_gather',
25 | 'scatter',
26 | 'reduce_scatter',
27 | 'send',
28 | 'recv',
29 | 'isend',
30 | 'irecv',
31 | 'shared_random_seed',
32 | 'diff_all_gather',
33 | 'diff_all_reduce',
34 | 'diff_scatter',
35 | 'diff_copy',
36 | 'spherical_kmeans',
37 | 'sinkhorn']
38 |
39 | #-------------------------------- Distributed operations --------------------------------#
40 |
41 | def is_dist_initialized():
42 | return dist.is_available() and dist.is_initialized()
43 |
44 | def get_world_size(group=None):
45 | return dist.get_world_size(group) if is_dist_initialized() else 1
46 |
47 | def get_rank(group=None):
48 | return dist.get_rank(group) if is_dist_initialized() else 0
49 |
50 | def new_group(ranks=None, **kwargs):
51 | if is_dist_initialized():
52 | return dist.new_group(ranks, **kwargs)
53 | return None
54 |
55 | def destroy_process_group():
56 | if is_dist_initialized():
57 | dist.destroy_process_group()
58 |
59 | def barrier(group=None, **kwargs):
60 | if get_world_size(group) > 1:
61 | dist.barrier(group, **kwargs)
62 |
63 | def broadcast(tensor, src, group=None, **kwargs):
64 | if get_world_size(group) > 1:
65 | return dist.broadcast(tensor, src, group, **kwargs)
66 |
67 | def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
68 | if get_world_size(group) > 1:
69 | return dist.all_reduce(tensor, op, group, **kwargs)
70 |
71 | def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
72 | if get_world_size(group) > 1:
73 | return dist.reduce(tensor, dst, op, group, **kwargs)
74 |
75 | def gather(tensor, dst=0, group=None, **kwargs):
76 | rank = get_rank() # global rank
77 | world_size = get_world_size(group)
78 | if world_size == 1:
79 | return [tensor]
80 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
81 | dist.gather(tensor, tensor_list, dst, group, **kwargs)
82 | return tensor_list
83 |
84 | def all_gather(tensor, uniform_size=True, group=None, **kwargs):
85 | world_size = get_world_size(group)
86 | if world_size == 1:
87 | return [tensor]
88 | assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
89 |
90 | if uniform_size:
91 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
92 | dist.all_gather(tensor_list, tensor, group, **kwargs)
93 | return tensor_list
94 | else:
95 | # collect tensor shapes across GPUs
96 | shape = tuple(tensor.shape)
97 | shape_list = generalized_all_gather(shape, group)
98 |
99 | # flatten the tensor
100 | tensor = tensor.reshape(-1)
101 | size = int(np.prod(shape))
102 | size_list = [int(np.prod(u)) for u in shape_list]
103 | max_size = max(size_list)
104 |
105 | # pad to maximum size
106 | if size != max_size:
107 | padding = tensor.new_zeros(max_size - size)
108 | tensor = torch.cat([tensor, padding], dim=0)
109 |
110 | # all_gather
111 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
112 | dist.all_gather(tensor_list, tensor, group, **kwargs)
113 |
114 | # reshape tensors
115 | tensor_list = [t[:n].view(s) for t, n, s in zip(
116 | tensor_list, size_list, shape_list)]
117 | return tensor_list
118 |
119 | @torch.no_grad()
120 | def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
121 | assert reduction in ['mean', 'sum']
122 | world_size = get_world_size(group)
123 | if world_size == 1:
124 | return input_dict
125 |
126 | # ensure that the orders of keys are consistent across processes
127 | if isinstance(input_dict, OrderedDict):
128 |         keys = list(input_dict.keys())
129 | else:
130 | keys = sorted(input_dict.keys())
131 | vals = [input_dict[key] for key in keys]
132 | vals = torch.stack(vals, dim=0)
133 | dist.reduce(vals, dst=0, group=group, **kwargs)
134 | if dist.get_rank(group) == 0 and reduction == 'mean':
135 | vals /= world_size
136 | dist.broadcast(vals, src=0, group=group, **kwargs)
137 | reduced_dict = type(input_dict)([
138 | (key, val) for key, val in zip(keys, vals)])
139 | return reduced_dict
140 |
141 | @functools.lru_cache()
142 | def get_global_gloo_group():
143 | backend = dist.get_backend()
144 | assert backend in ['gloo', 'nccl']
145 | if backend == 'nccl':
146 | return dist.new_group(backend='gloo')
147 | else:
148 | return dist.group.WORLD
149 |
150 | def _serialize_to_tensor(data, group):
151 | backend = dist.get_backend(group)
152 | assert backend in ['gloo', 'nccl']
153 | device = torch.device('cpu' if backend == 'gloo' else 'cuda')
154 |
155 | buffer = pickle.dumps(data)
156 | if len(buffer) > 1024 ** 3:
157 | logger = logging.getLogger(__name__)
158 | logger.warning(
159 |             'Rank {} trying to all-gather {:.2f} GB of data on device '
160 | '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
161 | storage = torch.ByteStorage.from_buffer(buffer)
162 | tensor = torch.ByteTensor(storage).to(device=device)
163 | return tensor
164 |
165 | def _pad_to_largest_tensor(tensor, group):
166 | world_size = dist.get_world_size(group=group)
167 | assert world_size >= 1, \
168 |         'gather/all_gather must be called from ranks within ' \
169 |         'the given group!'
170 | local_size = torch.tensor(
171 | [tensor.numel()], dtype=torch.int64, device=tensor.device)
172 | size_list = [torch.zeros(
173 | [1], dtype=torch.int64, device=tensor.device)
174 | for _ in range(world_size)]
175 |
176 | # gather tensors and compute the maximum size
177 | dist.all_gather(size_list, local_size, group=group)
178 | size_list = [int(size.item()) for size in size_list]
179 | max_size = max(size_list)
180 |
181 | # pad tensors to the same size
182 | if local_size != max_size:
183 | padding = torch.zeros(
184 | (max_size - local_size, ),
185 | dtype=torch.uint8, device=tensor.device)
186 | tensor = torch.cat((tensor, padding), dim=0)
187 | return size_list, tensor
188 |
189 | def generalized_all_gather(data, group=None):
190 | if get_world_size(group) == 1:
191 | return [data]
192 | if group is None:
193 | group = get_global_gloo_group()
194 |
195 | tensor = _serialize_to_tensor(data, group)
196 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
197 | max_size = max(size_list)
198 |
199 | # receiving tensors from all ranks
200 | tensor_list = [torch.empty(
201 | (max_size, ), dtype=torch.uint8, device=tensor.device)
202 | for _ in size_list]
203 | dist.all_gather(tensor_list, tensor, group=group)
204 |
205 | data_list = []
206 | for size, tensor in zip(size_list, tensor_list):
207 | buffer = tensor.cpu().numpy().tobytes()[:size]
208 | data_list.append(pickle.loads(buffer))
209 | return data_list
210 |
211 | def generalized_gather(data, dst=0, group=None):
212 | world_size = get_world_size(group)
213 | if world_size == 1:
214 | return [data]
215 | if group is None:
216 | group = get_global_gloo_group()
217 | rank = dist.get_rank() # global rank
218 |
219 | tensor = _serialize_to_tensor(data, group)
220 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
221 |
222 | # receiving tensors from all ranks to dst
223 | if rank == dst:
224 | max_size = max(size_list)
225 | tensor_list = [torch.empty(
226 | (max_size, ), dtype=torch.uint8, device=tensor.device)
227 | for _ in size_list]
228 | dist.gather(tensor, tensor_list, dst=dst, group=group)
229 |
230 | data_list = []
231 | for size, tensor in zip(size_list, tensor_list):
232 | buffer = tensor.cpu().numpy().tobytes()[:size]
233 | data_list.append(pickle.loads(buffer))
234 | return data_list
235 | else:
236 | dist.gather(tensor, [], dst=dst, group=group)
237 | return []
238 |
239 | def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
240 | r"""NOTE: only supports CPU tensor communication.
241 | """
242 | if get_world_size(group) > 1:
243 | return dist.scatter(data, scatter_list, src, group, **kwargs)
244 |
245 | def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
246 | if get_world_size(group) > 1:
247 | return dist.reduce_scatter(output, input_list, op, group, **kwargs)
248 |
249 | def send(tensor, dst, group=None, **kwargs):
250 | if get_world_size(group) > 1:
251 | assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
252 | return dist.send(tensor, dst, group, **kwargs)
253 |
254 | def recv(tensor, src=None, group=None, **kwargs):
255 | if get_world_size(group) > 1:
256 | assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
257 | return dist.recv(tensor, src, group, **kwargs)
258 |
259 | def isend(tensor, dst, group=None, **kwargs):
260 | if get_world_size(group) > 1:
261 | assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
262 | return dist.isend(tensor, dst, group, **kwargs)
263 |
264 | def irecv(tensor, src=None, group=None, **kwargs):
265 | if get_world_size(group) > 1:
266 | assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
267 | return dist.irecv(tensor, src, group, **kwargs)
268 |
269 | def shared_random_seed(group=None):
270 | seed = np.random.randint(2 ** 31)
271 | all_seeds = generalized_all_gather(seed, group)
272 | return all_seeds[0]
273 |
274 | #-------------------------------- Differentiable operations --------------------------------#
275 |
276 | def _all_gather(x):
277 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
278 | return x
279 | rank = dist.get_rank()
280 | world_size = dist.get_world_size()
281 | tensors = [torch.empty_like(x) for _ in range(world_size)]
282 | tensors[rank] = x
283 | dist.all_gather(tensors, x)
284 | return torch.cat(tensors, dim=0).contiguous()
285 |
286 | def _all_reduce(x):
287 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
288 | return x
289 | dist.all_reduce(x)
290 | return x
291 |
292 | def _split(x):
293 | if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
294 | return x
295 | rank = dist.get_rank()
296 | world_size = dist.get_world_size()
297 | return x.chunk(world_size, dim=0)[rank].contiguous()
298 |
299 | class DiffAllGather(Function):
300 | r"""Differentiable all-gather.
301 | """
302 | @staticmethod
303 | def symbolic(graph, input):
304 | return _all_gather(input)
305 |
306 | @staticmethod
307 | def forward(ctx, input):
308 | return _all_gather(input)
309 |
310 | @staticmethod
311 | def backward(ctx, grad_output):
312 | return _split(grad_output)
313 |
314 | class DiffAllReduce(Function):
315 |     r"""Differentiable all-reduce.
316 | """
317 | @staticmethod
318 | def symbolic(graph, input):
319 | return _all_reduce(input)
320 |
321 | @staticmethod
322 | def forward(ctx, input):
323 | return _all_reduce(input)
324 |
325 | @staticmethod
326 | def backward(ctx, grad_output):
327 | return grad_output
328 |
329 | class DiffScatter(Function):
330 | r"""Differentiable scatter.
331 | """
332 | @staticmethod
333 | def symbolic(graph, input):
334 | return _split(input)
335 |
336 | @staticmethod
337 |     def forward(ctx, input):
338 | return _split(input)
339 |
340 | @staticmethod
341 | def backward(ctx, grad_output):
342 | return _all_gather(grad_output)
343 |
344 | class DiffCopy(Function):
345 | r"""Differentiable copy that reduces all gradients during backward.
346 | """
347 | @staticmethod
348 | def symbolic(graph, input):
349 | return input
350 |
351 | @staticmethod
352 | def forward(ctx, input):
353 | return input
354 |
355 | @staticmethod
356 | def backward(ctx, grad_output):
357 | return _all_reduce(grad_output)
358 |
359 | diff_all_gather = DiffAllGather.apply
360 | diff_all_reduce = DiffAllReduce.apply
361 | diff_scatter = DiffScatter.apply
362 | diff_copy = DiffCopy.apply
363 |
364 | #-------------------------------- Distributed algorithms --------------------------------#
365 |
366 | @torch.no_grad()
367 | def spherical_kmeans(feats, num_clusters, num_iters=10):
368 | k, n, c = num_clusters, *feats.size()
369 | ones = feats.new_ones(n, dtype=torch.long)
370 |
371 | # distributed settings
372 | rank = get_rank()
373 | world_size = get_world_size()
374 |
375 | # init clusters
376 | rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
377 | clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
378 |
379 | # variables
380 | new_clusters = feats.new_zeros(k, c)
381 | counts = feats.new_zeros(k, dtype=torch.long)
382 |
383 | # iterative Expectation-Maximization
384 | for step in range(num_iters + 1):
385 | # Expectation step
386 | simmat = torch.mm(feats, clusters.t())
387 | scores, assigns = simmat.max(dim=1)
388 | if step == num_iters:
389 | break
390 |
391 | # Maximization step
392 | new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
393 | all_reduce(new_clusters)
394 |
395 | counts.zero_()
396 | counts.index_add_(0, assigns, ones)
397 | all_reduce(counts)
398 |
399 | mask = (counts > 0)
400 | clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
401 | clusters = F.normalize(clusters, p=2, dim=1)
402 | return clusters, assigns, scores
403 |
404 | @torch.no_grad()
405 | def sinkhorn(Q, eps=0.5, num_iters=3):
406 | # normalize Q
407 | Q = torch.exp(Q / eps).t()
408 | sum_Q = Q.sum()
409 | all_reduce(sum_Q)
410 | Q /= sum_Q
411 |
412 | # variables
413 | n, m = Q.size()
414 | u = Q.new_zeros(n)
415 | r = Q.new_ones(n) / n
416 | c = Q.new_ones(m) / (m * get_world_size())
417 |
418 | # iterative update
419 | cur_sum = Q.sum(dim=1)
420 | all_reduce(cur_sum)
421 | for i in range(num_iters):
422 | u = cur_sum
423 | Q *= (r / u).unsqueeze(1)
424 | Q *= (c / Q.sum(dim=0)).unsqueeze(0)
425 | cur_sum = Q.sum(dim=1)
426 | all_reduce(cur_sum)
427 | return (Q / Q.sum(dim=0, keepdim=True)).t().float()
428 |
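These wrappers all fall back to single-process no-ops when `torch.distributed` has not been initialized, so the same code path can run on one GPU or under a multi-process launcher. A small sketch on made-up values:

```python
# Illustrative only: the helpers degrade gracefully without an initialized
# process group (world_size == 1, rank == 0).
import torch
from artist.ops.distributed import (
    get_rank, get_world_size, all_reduce, reduce_dict, generalized_all_gather)

print(get_rank(), get_world_size())        # 0 1 when run as a plain script

loss = torch.tensor(1.5)
all_reduce(loss)                           # no-op with a single process

# averages a dict of scalar tensors across ranks (identity here)
stats = reduce_dict({'loss': loss, 'grad_norm': torch.tensor(0.3)})

# gathers arbitrary picklable objects from every rank
samples = generalized_all_gather({'rank': get_rank(), 'num_samples': 128})
print(stats, samples)
```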
--------------------------------------------------------------------------------
/artist/ops/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
5 |
6 | def kl_divergence(mu1, logvar1, mu2, logvar2):
7 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2))
8 |
9 | def standard_normal_cdf(x):
10 | r"""A fast approximation of the cumulative distribution function of the standard normal.
11 | """
12 | return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
13 |
14 | def discretized_gaussian_log_likelihood(x0, mean, log_scale):
15 | assert x0.shape == mean.shape == log_scale.shape
16 | cx = x0 - mean
17 | inv_stdv = torch.exp(-log_scale)
18 | cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
19 | cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
20 | log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
21 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
22 | cdf_delta = cdf_plus - cdf_min
23 | log_probs = torch.where(
24 | x0 < -0.999,
25 | log_cdf_plus,
26 | torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))))
27 | assert log_probs.shape == x0.shape
28 | return log_probs
29 |
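A quick shape and behaviour check on dummy tensors (all values made up): pixels are assumed to lie in [-1, 1], and `log_scale` is the log standard deviation of the discretized Gaussian.

```python
# Illustrative only: evaluate both losses element-wise on dummy data.
import torch
from artist.ops.losses import kl_divergence, discretized_gaussian_log_likelihood

x0 = torch.rand(2, 3, 8, 8) * 2 - 1      # images scaled to [-1, 1]
mean = torch.zeros_like(x0)
log_scale = torch.full_like(x0, -2.0)    # log std of the predicted Gaussian

# KL between two diagonal Gaussians, in nats, element-wise
kl = kl_divergence(mean, 2 * log_scale, x0, torch.zeros_like(x0))

# log-likelihood under a Gaussian discretized into 1/255-wide bins
ll = discretized_gaussian_log_likelihood(x0, mean=mean, log_scale=log_scale)
print(kl.shape, ll.shape)                # both torch.Size([2, 3, 8, 8])
```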
--------------------------------------------------------------------------------
/artist/ops/random_mask.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | __all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']
5 |
6 | def make_irregular_mask(w, h, max_angle=4, max_length=200, max_width=100, min_strokes=1, max_strokes=5, mode='line'):
7 | # initialize mask
8 | assert mode in ['line', 'circle', 'square']
9 | mask = np.zeros((h, w), np.float32)
10 |
11 | # draw strokes
12 | num_strokes = np.random.randint(min_strokes, max_strokes + 1)
13 | for i in range(num_strokes):
14 | x1 = np.random.randint(w)
15 | y1 = np.random.randint(h)
16 | for j in range(1 + np.random.randint(5)):
17 | angle = 0.01 + np.random.randint(max_angle)
18 | if i % 2 == 0:
19 | angle = 2 * 3.1415926 - angle
20 | length = 10 + np.random.randint(max_length)
21 | radius = 5 + np.random.randint(max_width)
22 | x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
23 | y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
24 | if mode == 'line':
25 | cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
26 | elif mode == 'circle':
27 | cv2.circle(mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
28 | elif mode == 'square':
29 | radius = radius // 2
30 | mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
31 | x1, y1 = x2, y2
32 | return mask
33 |
34 | def make_rectangle_mask(w, h, margin=10, min_size=30, max_size=150, min_strokes=1, max_strokes=4):
35 | # initialize mask
36 | mask = np.zeros((h, w), np.float32)
37 |
38 | # draw rectangles
39 | num_strokes = np.random.randint(min_strokes, max_strokes + 1)
40 | for i in range(num_strokes):
41 | box_w = np.random.randint(min_size, max_size)
42 | box_h = np.random.randint(min_size, max_size)
43 | x1 = np.random.randint(margin, w - margin - box_w + 1)
44 | y1 = np.random.randint(margin, h - margin - box_h + 1)
45 | mask[y1:y1 + box_h, x1:x1 + box_w] = 1
46 | return mask
47 |
48 | def make_uncrop(w, h):
49 | # initialize mask
50 | mask = np.zeros((h, w), np.float32)
51 |
52 | # randomly halve the image
53 | side = np.random.choice([0, 1, 2, 3])
54 | if side == 0:
55 | mask[:h // 2, :] = 1
56 | elif side == 1:
57 | mask[h // 2:, :] = 1
58 | elif side == 2:
59 | mask[:, :w // 2] = 1
60 | elif side == 3:
61 | mask[:, w // 2:] = 1
62 | return mask
63 |
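A short sketch that draws one mask of each type and writes it out for inspection (the output file names are arbitrary):

```python
# Illustrative only: sample each mask type once and save as 8-bit PNGs.
import cv2
import numpy as np
from artist.ops.random_mask import (
    make_irregular_mask, make_rectangle_mask, make_uncrop)

np.random.seed(0)
w, h = 256, 256
masks = {
    'irregular': make_irregular_mask(w, h, mode='line'),
    'rectangle': make_rectangle_mask(w, h),
    'uncrop': make_uncrop(w, h),
}
for name, mask in masks.items():
    # each mask is a float32 array of shape (h, w) with values in {0, 1}
    cv2.imwrite(f'{name}_mask.png', (mask * 255).astype(np.uint8))
```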
--------------------------------------------------------------------------------
/artist/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import *
2 | from .adafactor import *
--------------------------------------------------------------------------------
/artist/optim/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/optim/__pycache__/adafactor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/adafactor.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/optim/__pycache__/lr_scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/artist/optim/__pycache__/lr_scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/artist/optim/adafactor.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 | __all__ = ['Adafactor']
7 |
8 | class Adafactor(Optimizer):
9 | """
10 |     AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam. Original fairseq code:
11 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
12 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
13 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
14 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
15 | `relative_step=False`.
16 | Arguments:
17 | params (`Iterable[nn.parameter.Parameter]`):
18 | Iterable of parameters to optimize or dictionaries defining parameter groups.
19 | lr (`float`, *optional*):
20 | The external learning rate.
21 | eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
22 | Regularization constants for square gradient and parameter scale respectively
23 |         clip_threshold (`float`, *optional*, defaults to 1.0):
24 | Threshold of root mean square of final gradient update
25 | decay_rate (`float`, *optional*, defaults to -0.8):
26 | Coefficient used to compute running averages of square
27 | beta1 (`float`, *optional*):
28 | Coefficient used for computing running averages of gradient
29 | weight_decay (`float`, *optional*, defaults to 0):
30 | Weight decay (L2 penalty)
31 | scale_parameter (`bool`, *optional*, defaults to `True`):
32 | If True, learning rate is scaled by root mean square
33 | relative_step (`bool`, *optional*, defaults to `True`):
34 | If True, time-dependent learning rate is computed instead of external learning rate
35 | warmup_init (`bool`, *optional*, defaults to `False`):
36 | Time-dependent learning rate computation depends on whether warm-up initialization is being used
37 |     This implementation handles low-precision (FP16, bfloat) values, but it has not been thoroughly tested.
38 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
39 | - Training without LR warmup or clip_threshold is not recommended.
40 | - use scheduled LR warm-up to fixed LR
41 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
42 | - Disable relative updates
43 | - Use scale_parameter=False
44 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
45 | Example:
46 | ```python
47 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
48 | ```
49 | Others reported the following combination to work well:
50 | ```python
51 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
52 | ```
53 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
54 | scheduler as following:
55 | ```python
56 | from transformers.optimization import Adafactor, AdafactorSchedule
57 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
58 | lr_scheduler = AdafactorSchedule(optimizer)
59 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
60 | ```
61 | Usage:
62 | ```python
63 | # replace AdamW with Adafactor
64 | optimizer = Adafactor(
65 | model.parameters(),
66 | lr=1e-3,
67 | eps=(1e-30, 1e-3),
68 | clip_threshold=1.0,
69 | decay_rate=-0.8,
70 | beta1=None,
71 | weight_decay=0.0,
72 | relative_step=False,
73 | scale_parameter=False,
74 | warmup_init=False,
75 | )
76 | ```"""
77 |
78 | def __init__(
79 | self,
80 | params,
81 | lr=None,
82 | eps=(1e-30, 1e-3),
83 | clip_threshold=1.0,
84 | decay_rate=-0.8,
85 | beta1=None,
86 | weight_decay=0.0,
87 | scale_parameter=True,
88 | relative_step=True,
89 | warmup_init=False,
90 | ):
91 | r"""require_version("torch>=1.5.0") # add_ with alpha
92 | """
93 | if lr is not None and relative_step:
94 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
95 | if warmup_init and not relative_step:
96 | raise ValueError("`warmup_init=True` requires `relative_step=True`")
97 |
98 | defaults = dict(
99 | lr=lr,
100 | eps=eps,
101 | clip_threshold=clip_threshold,
102 | decay_rate=decay_rate,
103 | beta1=beta1,
104 | weight_decay=weight_decay,
105 | scale_parameter=scale_parameter,
106 | relative_step=relative_step,
107 | warmup_init=warmup_init,
108 | )
109 | super().__init__(params, defaults)
110 |
111 | @staticmethod
112 | def _get_lr(param_group, param_state):
113 | rel_step_sz = param_group["lr"]
114 | if param_group["relative_step"]:
115 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
116 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
117 | param_scale = 1.0
118 | if param_group["scale_parameter"]:
119 | param_scale = max(param_group["eps"][1], param_state["RMS"])
120 | return param_scale * rel_step_sz
121 |
122 | @staticmethod
123 | def _get_options(param_group, param_shape):
124 | factored = len(param_shape) >= 2
125 | use_first_moment = param_group["beta1"] is not None
126 | return factored, use_first_moment
127 |
128 | @staticmethod
129 | def _rms(tensor):
130 | return tensor.norm(2) / (tensor.numel() ** 0.5)
131 |
132 | @staticmethod
133 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
134 | # copy from fairseq's adafactor implementation:
135 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
136 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
137 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
138 | return torch.mul(r_factor, c_factor)
139 |
140 | def step(self, closure=None):
141 | """
142 | Performs a single optimization step
143 | Arguments:
144 | closure (callable, optional): A closure that reevaluates the model
145 | and returns the loss.
146 | """
147 | loss = None
148 | if closure is not None:
149 | loss = closure()
150 |
151 | for group in self.param_groups:
152 | for p in group["params"]:
153 | if p.grad is None:
154 | continue
155 | grad = p.grad.data
156 | if grad.dtype in {torch.float16, torch.bfloat16}:
157 | grad = grad.float()
158 | if grad.is_sparse:
159 | raise RuntimeError("Adafactor does not support sparse gradients.")
160 |
161 | state = self.state[p]
162 | grad_shape = grad.shape
163 |
164 | factored, use_first_moment = self._get_options(group, grad_shape)
165 | # State Initialization
166 | if len(state) == 0:
167 | state["step"] = 0
168 |
169 | if use_first_moment:
170 | # Exponential moving average of gradient values
171 | state["exp_avg"] = torch.zeros_like(grad)
172 | if factored:
173 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
174 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
175 | else:
176 | state["exp_avg_sq"] = torch.zeros_like(grad)
177 |
178 | state["RMS"] = 0
179 | else:
180 | if use_first_moment:
181 | state["exp_avg"] = state["exp_avg"].to(grad)
182 | if factored:
183 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
184 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
185 | else:
186 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
187 |
188 | p_data_fp32 = p.data
189 | if p.data.dtype in {torch.float16, torch.bfloat16}:
190 | p_data_fp32 = p_data_fp32.float()
191 |
192 | state["step"] += 1
193 | state["RMS"] = self._rms(p_data_fp32)
194 | lr = self._get_lr(group, state)
195 |
196 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
197 | update = (grad**2) + group["eps"][0]
198 | if factored:
199 | exp_avg_sq_row = state["exp_avg_sq_row"]
200 | exp_avg_sq_col = state["exp_avg_sq_col"]
201 |
202 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
203 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
204 |
205 | # Approximation of exponential moving average of square of gradient
206 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
207 | update.mul_(grad)
208 | else:
209 | exp_avg_sq = state["exp_avg_sq"]
210 |
211 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
212 | update = exp_avg_sq.rsqrt().mul_(grad)
213 |
214 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
215 | update.mul_(lr)
216 |
217 | if use_first_moment:
218 | exp_avg = state["exp_avg"]
219 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
220 | update = exp_avg
221 |
222 | if group["weight_decay"] != 0:
223 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
224 |
225 | p_data_fp32.add_(-update)
226 |
227 | if p.data.dtype in {torch.float16, torch.bfloat16}:
228 | p.data.copy_(p_data_fp32)
229 |
230 | return loss
231 |
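The class docstring above already lists the recommended hyper-parameter combinations; the sketch below only runs one `step()` on a toy model to show the optimizer state layout, i.e. factored row/column second moments for matrix parameters and a dense `exp_avg_sq` for vectors. The model and data here are made up.

```python
# Illustrative only: one Adafactor step with an external learning rate.
import torch
import torch.nn as nn
from artist.optim.adafactor import Adafactor

model = nn.Linear(16, 4)
opt = Adafactor(model.parameters(), lr=1e-3,
                relative_step=False, scale_parameter=False, warmup_init=False)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()

# the 4x16 weight uses factored statistics; the 1-D bias gets a full tensor
print(sorted(opt.state[model.weight]))  # ['RMS', 'exp_avg_sq_col', 'exp_avg_sq_row', 'step']
print(sorted(opt.state[model.bias]))    # ['RMS', 'exp_avg_sq', 'step']
```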
--------------------------------------------------------------------------------
/artist/optim/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 | __all__ = ['AnnealingLR']
5 |
6 | class AnnealingLR(_LRScheduler):
7 |
8 | def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1):
9 | assert decay_mode in ['linear', 'cosine', 'none']
10 | self.optimizer = optimizer
11 | self.base_lr = base_lr
12 | self.warmup_steps = warmup_steps
13 | self.total_steps = total_steps
14 | self.decay_mode = decay_mode
15 | self.min_lr = min_lr
16 | self.current_step = last_step + 1
17 | self.step(self.current_step)
18 |
19 | def get_lr(self):
20 | if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
21 | return self.base_lr * self.current_step / self.warmup_steps
22 | else:
23 | ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
24 | ratio = min(1.0, max(0.0, ratio))
25 | if self.decay_mode == 'linear':
26 | return self.base_lr * (1 - ratio)
27 | elif self.decay_mode == 'cosine':
28 | return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
29 | else:
30 | return self.base_lr
31 |
32 | def step(self, current_step=None):
33 | if current_step is None:
34 | current_step = self.current_step + 1
35 | self.current_step = current_step
36 | new_lr = max(self.min_lr, self.get_lr())
37 | if isinstance(self.optimizer, list):
38 | for o in self.optimizer:
39 | for group in o.param_groups:
40 | group['lr'] = new_lr
41 | else:
42 | for group in self.optimizer.param_groups:
43 | group['lr'] = new_lr
44 |
45 | def state_dict(self):
46 | return {
47 | 'base_lr': self.base_lr,
48 | 'warmup_steps': self.warmup_steps,
49 | 'total_steps': self.total_steps,
50 | 'decay_mode': self.decay_mode,
51 | 'current_step': self.current_step}
52 |
53 | def load_state_dict(self, state_dict):
54 | self.base_lr = state_dict['base_lr']
55 | self.warmup_steps = state_dict['warmup_steps']
56 | self.total_steps = state_dict['total_steps']
57 | self.decay_mode = state_dict['decay_mode']
58 | self.current_step = state_dict['current_step']
59 |
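A sketch of the schedule `AnnealingLR` produces when driving an ordinary optimizer (the hyper-parameters are arbitrary): linear warm-up over `warmup_steps`, then cosine decay towards `min_lr`.

```python
# Illustrative only: inspect the warm-up + cosine schedule of AnnealingLR.
import torch
import torch.nn as nn
from artist.optim.lr_scheduler import AnnealingLR

opt = torch.optim.SGD(nn.Linear(4, 4).parameters(), lr=0.0)  # lr is overwritten
sched = AnnealingLR(opt, base_lr=1e-4, warmup_steps=10,
                    total_steps=100, decay_mode='cosine', min_lr=1e-6)

lrs = []
for step in range(1, 101):
    # ... forward / backward / opt.step() would happen here ...
    sched.step(step)
    lrs.append(opt.param_groups[0]['lr'])

print(lrs[0], lrs[9], lrs[-1])  # ramps to 1e-4 by step 10, decays to 1e-6
```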
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | ENABLE: true
2 | DATASET: webvid10m
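The experiment YAMLs that follow override these shared defaults; the repository's own config handling is not reproduced in this listing. As a rough sketch of the overlay idea only (not the project's loader), using PyYAML from `environment.yaml`:

```python
# Illustrative only: a minimal base-plus-experiment YAML overlay. This is NOT
# the repository's config loader, just a sketch of the pattern the files suggest.
import yaml

def load_config(exp_yaml, base_yaml='configs/base.yaml'):
    with open(base_yaml) as f:
        cfg = yaml.safe_load(f) or {}
    with open(exp_yaml) as f:
        cfg.update(yaml.safe_load(f) or {})  # experiment values win over base
    return cfg

cfg = load_config('configs/exp02_motion_transfer.yaml')
print(cfg['network_name'], cfg['guidances'])
```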
--------------------------------------------------------------------------------
/configs/exp01_vidcomposer_full.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: MULTI_TASK
2 | ENABLE: true
3 | DATASET: webvid10m
4 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
5 | batch_sizes: {
6 | "1": 1,
7 | "4": 1,
8 | "8": 1,
9 | "16": 1,
10 | }
11 | vit_image_size: 224
12 | network_name: UNetSD_temporal
13 | resume: true
14 | resume_step: 228000
15 | num_workers: 1
16 | mvs_visual: False
17 | chunk_size: 1
18 | resume_checkpoint: "model_weights/non_ema_228000.pth"
19 | log_dir: 'outputs'
20 | num_steps: 1
21 |
--------------------------------------------------------------------------------
/configs/exp02_motion_transfer.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: True # NOTE: must be enabled for this task
3 | ENABLE: true
4 | DATASET: webvid10m
5 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
6 | guidances: ['y', 'local_image', 'motion'] # NOTE: the guidances enabled for this task
7 | batch_sizes: {
8 | "1": 1,
9 | "4": 1,
10 | "8": 1,
11 | "16": 1,
12 | }
13 | vit_image_size: 224
14 | network_name: UNetSD_temporal
15 | resume: true
16 | resume_step: 228000
17 | seed: 182
18 | num_workers: 0
19 | mvs_visual: False
20 | chunk_size: 1
21 | resume_checkpoint: "model_weights/non_ema_228000.pth"
22 | log_dir: 'outputs'
23 | num_steps: 1
24 |
--------------------------------------------------------------------------------
/configs/exp02_motion_transfer_vs_style.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: True # NOTE: must be enabled for this task
3 | read_style: True
4 | ENABLE: true
5 | DATASET: webvid10m
6 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
7 | guidances: ['y', 'local_image', 'image', 'motion'] # NOTE: the guidances enabled for this task
8 | batch_sizes: {
9 | "1": 1,
10 | "4": 1,
11 | "8": 1,
12 | "16": 1,
13 | }
14 | vit_image_size: 224
15 | network_name: UNetSD_temporal
16 | resume: true
17 | resume_step: 228000
18 | seed: 182
19 | num_workers: 0
20 | mvs_visual: False
21 | chunk_size: 1
22 | resume_checkpoint: "model_weights/non_ema_228000.pth"
23 | log_dir: 'outputs'
24 | num_steps: 1
25 |
--------------------------------------------------------------------------------
/configs/exp03_sketch2video_style.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: False # NOTE: enable only if image conditioning is needed
3 | read_style: True
4 | read_sketch: True
5 | save_origin_video: False
6 | ENABLE: true
7 | DATASET: webvid10m
8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
9 | guidances: ['y', 'image', 'single_sketch'] # NOTE: the guidances enabled for this task
10 | batch_sizes: {
11 | "1": 1,
12 | "4": 1,
13 | "8": 1,
14 | "16": 1,
15 | }
16 | vit_image_size: 224
17 | network_name: UNetSD_temporal
18 | resume: true
19 | resume_step: 228000
20 | seed: 182
21 | num_workers: 0
22 | mvs_visual: False
23 | chunk_size: 1
24 | resume_checkpoint: "model_weights/non_ema_228000.pth"
25 | log_dir: 'outputs'
26 | num_steps: 1
27 |
--------------------------------------------------------------------------------
/configs/exp04_sketch2video_wo_style.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: False # NOTE: enable only if image conditioning is needed
3 | read_style: False
4 | read_sketch: True
5 | save_origin_video: False
6 | ENABLE: true
7 | DATASET: webvid10m
8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
9 | guidances: ['y', 'single_sketch'] # NOTE: the guidances enabled for this task
10 | batch_sizes: {
11 | "1": 1,
12 | "4": 1,
13 | "8": 1,
14 | "16": 1,
15 | }
16 | vit_image_size: 224
17 | network_name: UNetSD_temporal
18 | resume: true
19 | resume_step: 228000
20 | seed: 182
21 | num_workers: 0
22 | mvs_visual: False
23 | chunk_size: 1
24 | resume_checkpoint: "model_weights/non_ema_228000.pth"
25 | log_dir: 'outputs'
26 | num_steps: 1
27 |
--------------------------------------------------------------------------------
/configs/exp05_text_depths_wo_style.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: False # NOTE: enable only if image conditioning is needed
3 | read_style: False
4 | read_sketch: False
5 | save_origin_video: True
6 | ENABLE: true
7 | DATASET: webvid10m
8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
9 | guidances: ['y', 'depth'] # NOTE: the guidances enabled for this task
10 | batch_sizes: {
11 | "1": 1,
12 | "4": 1,
13 | "8": 1,
14 | "16": 1,
15 | }
16 | vit_image_size: 224
17 | network_name: UNetSD_temporal
18 | resume: true
19 | resume_step: 228000
20 | seed: 182
21 | num_workers: 0
22 | mvs_visual: False
23 | chunk_size: 1
24 | resume_checkpoint: "model_weights/non_ema_228000.pth"
25 | log_dir: 'outputs'
26 | num_steps: 1
27 |
28 |
--------------------------------------------------------------------------------
/configs/exp06_text_depths_vs_style.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: SINGLE_TASK
2 | read_image: False
3 | read_style: True
4 | read_sketch: False
5 | save_origin_video: True
6 | ENABLE: true
7 | DATASET: webvid10m
8 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
9 | guidances: ['y', 'image', 'depth'] # NOTE: the guidances enabled for this task
10 | batch_sizes: {
11 | "1": 1,
12 | "4": 1,
13 | "8": 1,
14 | "16": 1,
15 | }
16 | vit_image_size: 224
17 | network_name: UNetSD_temporal
18 | resume: true
19 | resume_step: 228000
20 | seed: 182
21 | num_workers: 0
22 | mvs_visual: False
23 | chunk_size: 1
24 | resume_checkpoint: "model_weights/non_ema_228000.pth"
25 | log_dir: 'outputs'
26 | num_steps: 1
27 |
--------------------------------------------------------------------------------
/configs/exp10_vidcomposer_no_watermark_full.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: VideoComposer_Inference
2 | ENABLE: true
3 | DATASET: webvid10m
4 | video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
5 | batch_sizes: {
6 | "1": 1,
7 | "4": 1,
8 | "8": 1,
9 | "16": 1,
10 | }
11 | vit_image_size: 224
12 | network_name: UNetSD_temporal
13 | resume: true
14 | resume_step: 141000
15 | seed: 14
16 | num_workers: 1
17 | mvs_visual: True
18 | chunk_size: 1
19 | resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
20 | log_dir: 'outputs'
21 | num_steps: 1
--------------------------------------------------------------------------------
/demo_video/blackswan.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/blackswan.mp4
--------------------------------------------------------------------------------
/demo_video/captions_list.txt:
--------------------------------------------------------------------------------
1 | video_10000178, 一架客机正在迪拜机场的停机坪上滑行, an airliner is taxiing on the tarmac at Dubai Airport
2 | video_5360763, 鸽子坐在房子附近的街道上, Pigeon sitting on the street near the house
3 | video_8800, 暹罗斗鱼(Betta splendens)在小玻璃碗和五彩石头中游泳,微距视频|||Siamese Fighting Fish (Betta splendens) swimming in a small glass bowl with multicolored piece of stone, Macro Video
4 | # A colorful and beautiful fish swimming in a small glass bowl with multicolored piece of stone, Macro Video
5 | # A glittering and translucent fish swimming in a small glass bowl with multicolored piece of stone, like a glass fish
6 |
--------------------------------------------------------------------------------
/demo_video/landscape_painting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/landscape_painting.png
--------------------------------------------------------------------------------
/demo_video/moon_on_water.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/moon_on_water.jpg
--------------------------------------------------------------------------------
/demo_video/motion_transfer.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/motion_transfer.mp4
--------------------------------------------------------------------------------
/demo_video/qibaishi_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/qibaishi_01.png
--------------------------------------------------------------------------------
/demo_video/src_single_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/src_single_sketch.png
--------------------------------------------------------------------------------
/demo_video/style/Bingxueqiyuan.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/Bingxueqiyuan.jpeg
--------------------------------------------------------------------------------
/demo_video/style/fangao_01.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_01.jpeg
--------------------------------------------------------------------------------
/demo_video/style/fangao_02.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_02.jpeg
--------------------------------------------------------------------------------
/demo_video/style/fangao_03.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/style/fangao_03.jpeg
--------------------------------------------------------------------------------
/demo_video/sunflower.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/sunflower.png
--------------------------------------------------------------------------------
/demo_video/sunflower_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/sunflower_sketch.png
--------------------------------------------------------------------------------
/demo_video/tall_buildings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/tall_buildings.png
--------------------------------------------------------------------------------
/demo_video/tennis.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/tennis.mp4
--------------------------------------------------------------------------------
/demo_video/video_10000178.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_10000178.mp4
--------------------------------------------------------------------------------
/demo_video/video_5360763.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_5360763.mp4
--------------------------------------------------------------------------------
/demo_video/video_8800.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/video_8800.mp4
--------------------------------------------------------------------------------
/demo_video/wash_painting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/demo_video/wash_painting.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: VideoComposer
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.01.10=h06a4308_0
8 | - certifi=2022.12.7=py38h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.4.2=h6a678d5_6
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - ncurses=6.4=h6a678d5_0
15 | - openssl=1.1.1t=h7f8727e_0
16 | - pip=23.0.1=py38h06a4308_0
17 | - python=3.8.16=h7a1cb2a_3
18 | - readline=8.2=h5eee18b_0
19 | - setuptools=65.6.3=py38h06a4308_0
20 | - sqlite=3.41.1=h5eee18b_0
21 | - tk=8.6.12=h1ccaba5_0
22 | - wheel=0.38.4=py38h06a4308_0
23 | - xz=5.2.10=h5eee18b_1
24 | - zlib=1.2.13=h5eee18b_0
25 | - pip:
26 | - absl-py==1.4.0
27 | - aiohttp==3.8.4
28 | - aiosignal==1.3.1
29 | - aliyun-python-sdk-core==2.13.36
30 | - aliyun-python-sdk-kms==2.16.0
31 | - asttokens==2.2.1
32 | - async-timeout==4.0.2
33 | - attrs==22.2.0
34 | - backcall==0.2.0
35 | - cachetools==5.3.0
36 | - cffi==1.15.1
37 | - chardet==5.1.0
38 | - charset-normalizer==3.1.0
39 | - clean-fid==0.1.35
40 | - click==8.1.3
41 | - cmake==3.26.0
42 | - crcmod==1.7
43 | - cryptography==39.0.2
44 | - decorator==5.1.1
45 | - decord==0.6.0
46 | - easydict==1.10
47 | - einops==0.6.0
48 | - executing==1.2.0
49 | - fairscale==0.4.6
50 | - filelock==3.10.2
51 | - flash-attn==0.2.0
52 | - frozenlist==1.3.3
53 | - fsspec==2023.3.0
54 | - ftfy==6.1.1
55 | - future==0.18.3
56 | - google-auth==2.16.2
57 | - google-auth-oauthlib==0.4.6
58 | - grpcio==1.51.3
59 | - huggingface-hub==0.13.3
60 | - idna==3.4
61 | - imageio==2.15.0
62 | - importlib-metadata==6.1.0
63 | - ipdb==0.13.13
64 | - ipython==8.11.0
65 | - jedi==0.18.2
66 | - jmespath==0.10.0
67 | - joblib==1.2.0
68 | - lazy-loader==0.2
69 | - markdown==3.4.3
70 | - markupsafe==2.1.2
71 | - matplotlib-inline==0.1.6
72 | - motion-vector-extractor==1.0.6
73 | - multidict==6.0.4
74 | - mypy-extensions==1.0.0
75 | - networkx==3.1
76 | - numpy==1.24.2
77 | - oauthlib==3.2.2
78 | - open-clip-torch==2.0.2
79 | - openai-clip==1.0.1
80 | - opencv-python==4.5.5.64
81 | - opencv-python-headless==4.7.0.68
82 | - oss2==2.17.0
83 | - packaging==23.0
84 | - parso==0.8.3
85 | - pexpect==4.8.0
86 | - pickleshare==0.7.5
87 | - pillow==9.4.0
88 | - pkgconfig==1.5.5
89 | - prompt-toolkit==3.0.38
90 | - protobuf==4.22.1
91 | - ptyprocess==0.7.0
92 | - pure-eval==0.2.2
93 | - pyasn1==0.4.8
94 | - pyasn1-modules==0.2.8
95 | - pycparser==2.21
96 | - pycryptodome==3.17
97 | - pydeprecate==0.3.1
98 | - pygments==2.14.0
99 | - pynvml==11.5.0
100 | - pyre-extensions==0.0.23
101 | - pytorch-lightning==1.4.2
102 | - pywavelets==1.4.1
103 | - pyyaml==6.0
104 | - regex==2023.3.22
105 | - requests==2.28.2
106 | - requests-oauthlib==1.3.1
107 | - rotary-embedding-torch==0.2.1
108 | - rsa==4.9
109 | - sacremoses==0.0.53
110 | - scikit-image==0.20.0
111 | - scikit-learn==1.2.2
112 | - scikit-video==1.1.11
113 | - scipy==1.9.1
114 | - simplejson==3.18.4
115 | - six==1.16.0
116 | - sklearn==0.0.post4
117 | - stack-data==0.6.2
118 | - tensorboard==2.12.0
119 | - tensorboard-data-server==0.7.0
120 | - tensorboard-plugin-wit==1.8.1
121 | - threadpoolctl==3.1.0
122 | - tifffile==2023.4.12
123 | - tokenizers==0.12.1
124 | - tomli==2.0.1
125 | - torch==1.12.0+cu113
126 | - torchaudio==0.12.0+cu113
127 | - torchmetrics==0.6.0
128 | - torchvision==0.13.0+cu113
129 | - tqdm==4.65.0
130 | - traitlets==5.9.0
131 | - transformers==4.18.0
132 | - triton==2.0.0.dev20221120
133 | - typing-extensions==4.5.0
134 | - typing-inspect==0.8.0
135 | - urllib3==1.26.15
136 | - wcwidth==0.2.6
137 | - werkzeug==2.2.3
138 | - xformers==0.0.13
139 | - yarl==1.8.2
140 | - zipp==3.15.0
141 |
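
A quick post-install sanity check (a minimal sketch, not part of the repository) to confirm the pinned CUDA build of PyTorch from this file is active:

```python
# Run after `conda env create -f environment.yaml` and `conda activate VideoComposer`.
import torch

print(torch.__version__)          # expected 1.12.0+cu113, as pinned above
print(torch.version.cuda)         # expected 11.3
print(torch.cuda.is_available())  # should be True on a CUDA-capable machine
```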
--------------------------------------------------------------------------------
/gen_sketch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import random
5 | from PIL import Image
6 | from io import BytesIO
7 | from functools import partial
8 | import torchvision.transforms as T
9 | import torchvision.transforms.functional as TF
10 | from torchvision.transforms.functional import InterpolationMode
11 |
12 | import artist.data as data
13 | import artist.ops as ops
14 | from tools.annotator.sketch import pidinet_bsd, sketch_simplification_gan
15 |
16 |
17 | def random_resize(img, size):
18 | img = [TF.resize(u, size, interpolation=random.choice([
19 | InterpolationMode.BILINEAR,
20 | InterpolationMode.BICUBIC,
21 | InterpolationMode.LANCZOS])) for u in img]
22 | return img
23 |
24 |
25 | def gen_sketch(image_path, gpu=0, misc_size=384):
26 |
27 | sketch_mean = [0.485, 0.456, 0.406]
28 | sketch_std = [0.229, 0.224, 0.225]
29 |
30 | pidinet = pidinet_bsd(pretrained=True, vanilla_cnn=True).eval().requires_grad_(False).to(gpu)
31 | cleaner = sketch_simplification_gan(pretrained=True).eval().requires_grad_(False).to(gpu)
32 | pidi_mean = torch.tensor(sketch_mean).view(1, -1, 1, 1).to(gpu)
33 | pidi_std = torch.tensor(sketch_std).view(1, -1, 1, 1).to(gpu)
34 |
35 | misc_transforms = data.Compose([
36 | T.Lambda(partial(random_resize, size=misc_size)),
37 | data.CenterCropV2(misc_size),
38 | data.ToTensor()])
39 |
40 | image = Image.open(open(image_path, mode='rb')).convert('RGB')
41 | image = misc_transforms([image])
42 | image = image.to(gpu)
43 |
44 | sketch = pidinet(image.sub(pidi_mean).div_(pidi_std))
45 | sketch = 1.0 - cleaner(1.0 - sketch)
46 | sketch = sketch.cpu()
47 |
48 | sketch = sketch[0][0]
49 | sketch = (sketch.numpy()*255).astype('uint8')
50 | file_name = os.path.basename(image_path)
51 | save_pth = 'source/inputs/' + file_name.replace('.', '_sketch.')
52 | cv2.imwrite(save_pth, sketch)
53 |
54 |
55 | gen_sketch(image_path='demo_video/sunflower.png')
56 |
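
The script above chains a PiDiNet edge detector with the sketch-simplification GAN and writes the result into source/inputs/, which it does not create. A minimal variation of the hard-coded call at the bottom (a sketch; it assumes the corresponding checkpoints are already under model_weights/) for another demo image:

```python
# Replace the final gen_sketch(...) call with something like this.
import os

os.makedirs('source/inputs', exist_ok=True)  # the script writes here but never creates the folder
gen_sketch(image_path='demo_video/moon_on_water.jpg', gpu=0, misc_size=384)
```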
--------------------------------------------------------------------------------
/model_weights/readme.md:
--------------------------------------------------------------------------------
1 | ### 1. Installation
2 |
3 | Please download the model weights listed below and place them in this directory.
4 |
5 | ```
6 | |--model_weights/
7 | | |--non_ema_228000.pth
8 | | |--midas_v3_dpt_large.pth
9 | | |--open_clip_pytorch_model.bin
10 | | |--sketch_simplification_gan.pth
11 | | |--table5_pidinet.pth
12 | | |--v2-1_512-ema-pruned.ckpt
13 | ```
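
A small check (a minimal sketch, not part of the release) that every expected checkpoint is in place before running inference:

```python
# Hypothetical helper: verify the files listed above exist under model_weights/.
import os

expected = [
    'non_ema_228000.pth',
    'midas_v3_dpt_large.pth',
    'open_clip_pytorch_model.bin',
    'sketch_simplification_gan.pth',
    'table5_pidinet.pth',
    'v2-1_512-ema-pruned.ckpt',
]
missing = [f for f in expected if not os.path.isfile(os.path.join('model_weights', f))]
print('all weights present' if not missing else f'missing: {missing}')
```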
--------------------------------------------------------------------------------
/run_bash.sh:
--------------------------------------------------------------------------------
1 | # Exp01, inference different conditions from a video
2 | python run_net.py\
3 | --cfg configs/exp01_vidcomposer_full.yaml\
4 | --seed 9999\
5 | --input_video "demo_video/blackswan.mp4"\
6 | --input_text_desc "A black swan swam in the water"
7 |
8 |
9 | # Exp02, Motion Transfer from a video to a Single Image
10 | python run_net.py\
11 | --cfg configs/exp02_motion_transfer.yaml\
12 | --seed 9999\
13 | --input_video "demo_video/motion_transfer.mp4"\
14 | --image_path "demo_video/sunflower.png"\
15 | --input_text_desc "A sunflower in a field of flowers"
16 |
17 |
18 | python run_net.py\
19 | --cfg configs/exp02_motion_transfer_vs_style.yaml\
20 | --seed 9999\
21 | --input_video "demo_video/motion_transfer.mp4"\
22 | --image_path "demo_video/moon_on_water.jpg"\
23 | --style_image "demo_video/moon_on_water.jpg"\
24 | --input_text_desc "A beautiful big silver moon on the water"
25 |
26 |
27 | python run_net.py\
28 | --cfg configs/exp02_motion_transfer_vs_style.yaml\
29 | --seed 888\
30 | --input_video "demo_video/motion_transfer.mp4"\
31 | --image_path "demo_video/style/fangao_01.jpeg"\
32 | --style_image "demo_video/style/fangao_01.jpeg"\
33 | --input_text_desc "Beneath Van Gogh's Starry Sky"
34 |
35 |
36 | # Exp03, Single Sketch to videos with style
37 | python run_net.py\
38 | --cfg configs/exp03_sketch2video_style.yaml\
39 | --seed 8888\
40 | --sketch_path "demo_video/src_single_sketch.png"\
41 | --style_image "demo_video/style/qibaishi_01.png"\
42 | --input_text_desc "Red-backed Shrike lanius collurio"
43 |
44 | # Exp04, Single Sketch to videos without style input
45 | python run_net.py\
46 | --cfg configs/exp04_sketch2video_wo_style.yaml\
47 | --seed 144\
48 | --sketch_path "demo_video/src_single_sketch.png"\
49 | --input_text_desc "A little bird is standing on a branch"
50 |
51 |
52 | # Exp05, Depth to video without style
53 | python run_net.py\
54 | --cfg configs/exp05_text_depths_wo_style.yaml\
55 | --seed 9999\
56 | --input_video demo_video/tennis.mp4\
57 | --input_text_desc "Ironman is fighting against the enemy, big fire in the background, photorealistic"
58 |
59 |
60 | # Exp06, Depth to video with style
61 | python run_net.py\
62 | --cfg configs/exp06_text_depths_vs_style.yaml\
63 | --seed 9999\
64 | --input_video demo_video/tennis.mp4\
65 | --style_image "demo_video/style/fangao_01.jpeg"\
66 | --input_text_desc "Van Gogh played tennis under the stars"
67 |
68 |
69 | # Exp07, Text and image to video without style
70 | python run_net.py\
71 | --cfg configs/exp07_text_image_wo_style.yaml\
72 | --seed 9999\
73 | --input_video demo_video/blackswan.mp4\
74 | --input_text_desc "Van Gogh played tennis under the stars"
75 |
76 |
77 | # To run any of the commands above on a specific GPU, prefix it with CUDA_VISIBLE_DEVICES=0
78 |
--------------------------------------------------------------------------------
/run_net.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import os.path as osp
4 | # sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
5 | import logging
6 | import numpy as np
7 | import copy
8 | import random
9 | import json
10 | import math
11 | import itertools
12 |
13 |
14 | # logger = logging.get_logger(__name__)
15 |
16 | from utils.config import Config
17 | from tools.videocomposer.inference_multi import inference_multi
18 | from tools.videocomposer.inference_single import inference_single
19 |
20 |
21 | def main():
22 | """
23 | Main function to spawn the train and test process.
24 | """
25 | cfg = Config(load=True)
26 | if hasattr(cfg, "TASK_TYPE") and cfg.TASK_TYPE == "MULTI_TASK":
27 | logging.info("TASK TYPE: %s " % cfg.TASK_TYPE)
28 | inference_multi(cfg.cfg_dict)
29 | elif hasattr(cfg, "TASK_TYPE") and cfg.TASK_TYPE == "SINGLE_TASK":
30 | logging.info("TASK TYPE: %s " % cfg.TASK_TYPE)
31 | inference_single(cfg.cfg_dict)
32 | else:
33 | logging.info('Unsupported task %s' % (cfg.TASK_TYPE))
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
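
run_net.py only parses the config and dispatches on cfg.TASK_TYPE, so it is normally driven from the command line as in run_bash.sh. A programmatic equivalent of one of those commands (a sketch; the flags are copied verbatim from run_bash.sh):

```python
# Hypothetical wrapper around the CLI used throughout run_bash.sh.
import subprocess

subprocess.run([
    'python', 'run_net.py',
    '--cfg', 'configs/exp02_motion_transfer.yaml',
    '--seed', '9999',
    '--input_video', 'demo_video/motion_transfer.mp4',
    '--image_path', 'demo_video/sunflower.png',
    '--input_text_desc', 'A sunflower in a field of flowers',
], check=True)
```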
--------------------------------------------------------------------------------
/source/fig01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig01.jpg
--------------------------------------------------------------------------------
/source/fig02_framwork.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig02_framwork.jpg
--------------------------------------------------------------------------------
/source/fig03_image-to-video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig03_image-to-video.jpg
--------------------------------------------------------------------------------
/source/fig04_hand-crafted-motions.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig04_hand-crafted-motions.jpg
--------------------------------------------------------------------------------
/source/fig05_video-inpainting.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig05_video-inpainting.jpg
--------------------------------------------------------------------------------
/source/fig06_sketch-to-video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/fig06_sketch-to-video.jpg
--------------------------------------------------------------------------------
/source/results/exp02_motion_transfer-S00009.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S00009.gif
--------------------------------------------------------------------------------
/source/results/exp02_motion_transfer-S09999-0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S09999-0.gif
--------------------------------------------------------------------------------
/source/results/exp02_motion_transfer-S09999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp02_motion_transfer-S09999.gif
--------------------------------------------------------------------------------
/source/results/exp03_sketch2video_style-S09999.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp03_sketch2video_style-S09999.gif
--------------------------------------------------------------------------------
/source/results/exp04_sketch2video_wo_style-S00144-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144-1.gif
--------------------------------------------------------------------------------
/source/results/exp04_sketch2video_wo_style-S00144-2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144-2.gif
--------------------------------------------------------------------------------
/source/results/exp04_sketch2video_wo_style-S00144.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp04_sketch2video_wo_style-S00144.gif
--------------------------------------------------------------------------------
/source/results/exp05_text_depths_wo_style-S09999-0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-0.gif
--------------------------------------------------------------------------------
/source/results/exp05_text_depths_wo_style-S09999-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-1.gif
--------------------------------------------------------------------------------
/source/results/exp05_text_depths_wo_style-S09999-2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp05_text_depths_wo_style-S09999-2.gif
--------------------------------------------------------------------------------
/source/results/exp06_text_depths_vs_style-S09999-0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-0.gif
--------------------------------------------------------------------------------
/source/results/exp06_text_depths_vs_style-S09999-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-1.gif
--------------------------------------------------------------------------------
/source/results/exp06_text_depths_vs_style-S09999-2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-2.gif
--------------------------------------------------------------------------------
/source/results/exp06_text_depths_vs_style-S09999-3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/source/results/exp06_text_depths_vs_style-S09999-3.gif
--------------------------------------------------------------------------------
/tools/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/.DS_Store
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/__init__.py
--------------------------------------------------------------------------------
/tools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/.DS_Store
--------------------------------------------------------------------------------
/tools/annotator/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/canny/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from tools.annotator.util import HWC3
5 | # import gradio as gr
6 |
7 | class CannyDetector:
8 | def __call__(self, img, low_threshold = None, high_threshold = None, random_threshold = True):
9 |
10 | ### Convert to numpy
11 | if isinstance(img, torch.Tensor): # (h, w, c)
12 | img = img.cpu().numpy()
13 | img_np = cv2.convertScaleAbs((img * 255.))
14 | elif isinstance(img, np.ndarray): # (h, w, c)
15 | img_np = img # we assume values are in the range from 0 to 255.
16 | else:
17 | assert False
18 |
19 | ### Select the threshold
20 | if (low_threshold is None) and (high_threshold is None):
21 | median_intensity = np.median(img_np)
22 | if random_threshold is False:
23 | low_threshold = int(max(0, (1 - 0.33) * median_intensity))
24 | high_threshold = int(min(255, (1 + 0.33) * median_intensity))
25 | else:
26 | random_canny = np.random.uniform(0.1, 0.4)
27 | # Might try other values
28 | low_threshold = int(max(0, (1 - random_canny) * median_intensity))
29 | high_threshold = 2 * low_threshold
30 |
31 | ### Detect canny edge
32 | canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
33 | ### Convert to 3 channels
34 | # canny_edge = HWC3(canny_edge)
35 |
36 | canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim = -1).float().cuda() / 255.0
37 | # canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0)
38 | # canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone()
39 | # return cv2.Canny(img, low_threshold, high_threshold)
40 | return canny_condition
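
When no thresholds are given, CannyDetector derives them from the median pixel intensity (either the fixed 0.67/1.33 band or a randomized band) and returns an (H, W, 1) float tensor on the GPU scaled to [0, 1]. A minimal usage sketch, assuming a CUDA device and an image OpenCV can read:

```python
# Hypothetical usage; the detector moves its output to the GPU, so CUDA is required.
import cv2
from tools.annotator.canny import CannyDetector

img = cv2.imread('demo_video/sunflower.png')           # uint8, HWC, values 0-255
edges = CannyDetector()(img, random_threshold=False)   # deterministic median-based thresholds
print(edges.shape, edges.min().item(), edges.max().item())  # (H, W, 1), values in [0, 1]
```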
--------------------------------------------------------------------------------
/tools/annotator/canny/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/canny/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/histogram/__init__.py:
--------------------------------------------------------------------------------
1 | from .palette import *
--------------------------------------------------------------------------------
/tools/annotator/histogram/palette.py:
--------------------------------------------------------------------------------
1 | r"""Modified from ``https://github.com/sergeyk/rayleigh''.
2 | """
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb
7 | from skimage.io import imsave
8 | from sklearn.metrics import euclidean_distances
9 |
10 | __all__ = ['Palette']
11 |
12 | def rgb2hex(rgb):
13 | return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
14 |
15 | def hex2rgb(hex):
16 | rgb = hex.strip('#')
17 | fn = lambda u: round(int(u, 16) / 255.0, 5)
18 | return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
19 |
20 | class Palette(object):
21 | r"""Create a color palette (codebook) in the form of a 2D grid of colors.
22 | Further, the rightmost column has num_hues gradations from black to white.
23 |
24 | Parameters:
25 | num_hues: number of colors with full lightness and saturation, in the middle.
26 | num_sat: number of rows above middle row that show the same hues with decreasing saturation.
27 | """
28 | def __init__(self, num_hues=11, num_sat=5, num_light=4):
29 | n = num_sat + 2 * num_light
30 |
31 | # hues
32 | if num_hues == 8:
33 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1))
34 | elif num_hues == 9:
35 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1))
36 | elif num_hues == 10:
37 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1))
38 | elif num_hues == 11:
39 | hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1))
40 | else:
41 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
42 |
43 | # saturations
44 | sats = np.hstack((
45 | np.linspace(0, 1, num_sat + 2)[1:-1],
46 | 1,
47 | [1] * num_light,
48 | [0.4] * (num_light - 1)))
49 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
50 |
51 | # lights
52 | lights = np.hstack((
53 | [1] * num_sat,
54 | 1,
55 | np.linspace(1, 0.2, num_light + 2)[1:-1],
56 | np.linspace(1, 0.2, num_light + 2)[1:-2]))
57 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
58 |
59 | # colors
60 | rgb = hsv2rgb(np.dstack([hues, sats, lights]))
61 | gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
62 | self.thumbnail = np.hstack([rgb, gray])
63 |
64 | # flatten
65 | rgb = rgb.T.reshape(3, -1).T
66 | gray = gray.T.reshape(3, -1).T
67 | self.rgb = np.vstack((rgb, gray))
68 | self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
69 | self.hex = [rgb2hex(u) for u in self.rgb]
70 | self.lab_dists = euclidean_distances(self.lab, squared=True)
71 |
72 | def histogram(self, rgb_img, sigma=20):
73 | # compute histogram
74 | lab = rgb2lab(rgb_img).reshape((-1, 3))
75 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
76 | hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
77 |
78 | # smooth histogram
79 | if sigma > 0:
80 | weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
81 | weight = weight / weight.sum(1)[:, np.newaxis]
82 | hist = (weight * hist).sum(1)
83 | hist[hist < 1e-5] = 0
84 | return hist
85 |
86 | def get_palette_image(self, hist, percentile=90, width=200, height=50):
87 | # curate histogram
88 | ind = np.argsort(-hist)
89 | ind = ind[hist[ind] > np.percentile(hist, percentile)]
90 | hist = hist[ind] / hist[ind].sum()
91 |
92 | # draw palette
93 | nums = np.array(hist * width, dtype=int)
94 | array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)])
95 | array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
96 | if array.shape[1] < width:
97 | array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
98 | return array
99 |
100 | def quantize_image(self, rgb_img):
101 | lab = rgb2lab(rgb_img).reshape((-1, 3))
102 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
103 | quantized_lab = self.lab[min_ind]
104 | img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
105 | return img
106 |
107 | def export(self, dirname):
108 | if not osp.exists(dirname):
109 | os.makedirs(dirname)
110 |
111 | # save thumbnail
112 | imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
113 |
114 | # save html
115 | with open(osp.join(dirname, 'palette.html'), 'w') as f:
116 | html = '''
117 |
126 | '''
127 | for row in self.thumbnail:
128 | for col in row:
129 | html += '<span style="background-color: {}"></span>\n'.format(rgb2hex(col))
130 | html += '<br />\n'
131 | f.write(html)
132 |
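
Palette builds a fixed Lab-space color codebook and soft-assigns image pixels to it. A short usage sketch (assuming an RGB image as a float array in [0, 1]):

```python
# Hypothetical usage of the color palette / histogram utilities.
from skimage.io import imread
from skimage.util import img_as_float
from tools.annotator.histogram import Palette

palette = Palette(num_hues=11, num_sat=5, num_light=4)

img = img_as_float(imread('demo_video/sunflower.png'))[..., :3]  # RGB in [0, 1]
hist = palette.histogram(img, sigma=20)      # smoothed histogram over the codebook
swatch = palette.get_palette_image(hist)     # 50x200 strip of the dominant colors
quantized = palette.quantize_image(img)      # image snapped to the nearest palette colors
print(hist.shape, swatch.shape, quantized.shape)
```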
--------------------------------------------------------------------------------
/tools/annotator/sketch/__init__.py:
--------------------------------------------------------------------------------
1 | from .pidinet import *
2 | from .sketch_simplification import *
--------------------------------------------------------------------------------
/tools/annotator/sketch/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/sketch/__pycache__/pidinet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/pidinet.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/sketch/__pycache__/sketch_simplification.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/annotator/sketch/__pycache__/sketch_simplification.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/sketch/sketch_simplification.py:
--------------------------------------------------------------------------------
1 | r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
2 | """
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 | # from canvas import DOWNLOAD_TO_CACHE
9 | from artist import DOWNLOAD_TO_CACHE
10 |
11 | __all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse',
12 | 'sketch_to_pencil_v1', 'sketch_to_pencil_v2']
13 |
14 | class SketchSimplification(nn.Module):
15 | r"""NOTE:
16 | 1. Input image should have only one gray channel.
17 | 2. Input image size should be divisible by 8.
18 | 3. Sketch strokes in the input/output image are dark, while the background is light.
19 | """
20 | def __init__(self, mean, std):
21 | assert isinstance(mean, float) and isinstance(std, float)
22 | super(SketchSimplification, self).__init__()
23 | self.mean = mean
24 | self.std = std
25 |
26 | # layers
27 | self.layers = nn.Sequential(
28 | nn.Conv2d(1, 48, 5, 2, 2),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(48, 128, 3, 1, 1),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(128, 128, 3, 1, 1),
33 | nn.ReLU(inplace=True),
34 | nn.Conv2d(128, 128, 3, 2, 1),
35 | nn.ReLU(inplace=True),
36 | nn.Conv2d(128, 256, 3, 1, 1),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(256, 256, 3, 1, 1),
39 | nn.ReLU(inplace=True),
40 | nn.Conv2d(256, 256, 3, 2, 1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(256, 512, 3, 1, 1),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(512, 1024, 3, 1, 1),
45 | nn.ReLU(inplace=True),
46 | nn.Conv2d(1024, 1024, 3, 1, 1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(1024, 1024, 3, 1, 1),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(1024, 1024, 3, 1, 1),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(1024, 512, 3, 1, 1),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(512, 256, 3, 1, 1),
55 | nn.ReLU(inplace=True),
56 | nn.ConvTranspose2d(256, 256, 4, 2, 1),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(256, 256, 3, 1, 1),
59 | nn.ReLU(inplace=True),
60 | nn.Conv2d(256, 128, 3, 1, 1),
61 | nn.ReLU(inplace=True),
62 | nn.ConvTranspose2d(128, 128, 4, 2, 1),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(128, 128, 3, 1, 1),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(128, 48, 3, 1, 1),
67 | nn.ReLU(inplace=True),
68 | nn.ConvTranspose2d(48, 48, 4, 2, 1),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(48, 24, 3, 1, 1),
71 | nn.ReLU(inplace=True),
72 | nn.Conv2d(24, 1, 3, 1, 1),
73 | nn.Sigmoid())
74 |
75 | def forward(self, x):
76 | r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.
77 | """
78 | x = (x - self.mean) / self.std
79 | return self.layers(x)
80 |
81 | def sketch_simplification_gan(pretrained=False):
82 | model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797)
83 | if pretrained:
84 | model.load_state_dict(torch.load(
85 | DOWNLOAD_TO_CACHE('./model_weights/sketch_simplification_gan.pth'),
86 | map_location='cpu'))
87 | return model
88 |
89 | def sketch_simplification_mse(pretrained=False):
90 | model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507)
91 | if pretrained:
92 | model.load_state_dict(torch.load(
93 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'),
94 | map_location='cpu'))
95 | return model
96 |
97 | def sketch_to_pencil_v1(pretrained=False):
98 | model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048)
99 | if pretrained:
100 | model.load_state_dict(torch.load(
101 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'),
102 | map_location='cpu'))
103 | return model
104 |
105 | def sketch_to_pencil_v2(pretrained=False):
106 | model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571)
107 | if pretrained:
108 | model.load_state_dict(torch.load(
109 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'),
110 | map_location='cpu'))
111 | return model
112 |
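
Per the docstring above, the network expects a single gray channel in [0, 1], a side length divisible by 8, and dark strokes on a light background; the mean/std normalization happens inside forward. A minimal usage sketch (assuming the GAN checkpoint from model_weights/ is in place):

```python
# Hypothetical usage of the sketch simplification network.
import torch

model = sketch_simplification_gan(pretrained=True).eval().requires_grad_(False)

# Stand-in for a rough edge map: [B, 1, H, W] in [0, 1], H and W divisible by 8.
rough = torch.rand(1, 1, 384, 384)

with torch.no_grad():
    clean = model(rough)   # same layout as the input: [B, 1, H, W] in [0, 1]
print(clean.shape)
```

Note that gen_sketch.py above applies it as 1.0 - cleaner(1.0 - sketch), because the PiDiNet edge map marks strokes with high values, the opposite of what this network expects.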
--------------------------------------------------------------------------------
/tools/annotator/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 |
5 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
6 |
7 | def HWC3(x):
8 | assert x.dtype == np.uint8
9 | if x.ndim == 2:
10 | x = x[:, :, None]
11 | assert x.ndim == 3
12 | H, W, C = x.shape
13 | assert C == 1 or C == 3 or C == 4
14 | if C == 3:
15 | return x
16 | if C == 1:
17 | return np.concatenate([x, x, x], axis=2)
18 | if C == 4:
19 | color = x[:, :, 0:3].astype(np.float32)
20 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
21 | y = color * alpha + 255.0 * (1.0 - alpha)
22 | y = y.clip(0, 255).astype(np.uint8)
23 | return y
24 |
25 |
26 | def resize_image(input_image, resolution):
27 | H, W, C = input_image.shape
28 | H = float(H)
29 | W = float(W)
30 | k = float(resolution) / min(H, W)
31 | H *= k
32 | W *= k
33 | H = int(np.round(H / 64.0)) * 64
34 | W = int(np.round(W / 64.0)) * 64
35 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
36 | return img
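
resize_image scales the shorter side to resolution and then snaps both sides to the nearest multiple of 64. A worked example of the rounding (values only, no image I/O):

```python
# For a 720x1280 frame at resolution=512: k = 512/720, so 720 -> 512 and 1280 -> 896.
import numpy as np

H, W, resolution = 720.0, 1280.0, 512
k = resolution / min(H, W)
print(int(np.round(H * k / 64.0)) * 64)  # 512
print(int(np.round(W * k / 64.0)) * 64)  # 896
```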
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/inference_multi.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/inference_multi.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/inference_single.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/inference_single.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/mha_flash.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/mha_flash.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/__pycache__/unet_sd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/tools/videocomposer/__pycache__/unet_sd.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/videocomposer/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | __all__ = ['AutoencoderKL']
7 |
8 | def nonlinearity(x):
9 | # swish
10 | return x*torch.sigmoid(x)
11 |
12 | def Normalize(in_channels, num_groups=32):
13 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
14 |
15 | class DiagonalGaussianDistribution(object):
16 | def __init__(self, parameters, deterministic=False):
17 | self.parameters = parameters
18 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
19 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
20 | self.deterministic = deterministic
21 | self.std = torch.exp(0.5 * self.logvar)
22 | self.var = torch.exp(self.logvar)
23 | if self.deterministic:
24 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
25 |
26 | def sample(self):
27 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
28 | return x
29 |
30 | def kl(self, other=None):
31 | if self.deterministic:
32 | return torch.Tensor([0.])
33 | else:
34 | if other is None:
35 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
36 | + self.var - 1.0 - self.logvar,
37 | dim=[1, 2, 3])
38 | else:
39 | return 0.5 * torch.sum(
40 | torch.pow(self.mean - other.mean, 2) / other.var
41 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
42 | dim=[1, 2, 3])
43 |
44 | def nll(self, sample, dims=[1,2,3]):
45 | if self.deterministic:
46 | return torch.Tensor([0.])
47 | logtwopi = np.log(2.0 * np.pi)
48 | return 0.5 * torch.sum(
49 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
50 | dim=dims)
51 |
52 | def mode(self):
53 | return self.mean
54 |
55 | class Downsample(nn.Module):
56 | def __init__(self, in_channels, with_conv):
57 | super().__init__()
58 | self.with_conv = with_conv
59 | if self.with_conv:
60 | # no asymmetric padding in torch conv, must do it ourselves
61 | self.conv = torch.nn.Conv2d(in_channels,
62 | in_channels,
63 | kernel_size=3,
64 | stride=2,
65 | padding=0)
66 |
67 | def forward(self, x):
68 | if self.with_conv:
69 | pad = (0,1,0,1)
70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71 | x = self.conv(x)
72 | else:
73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74 | return x
75 |
76 | class ResnetBlock(nn.Module):
77 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
78 | dropout, temb_channels=512):
79 | super().__init__()
80 | self.in_channels = in_channels
81 | out_channels = in_channels if out_channels is None else out_channels
82 | self.out_channels = out_channels
83 | self.use_conv_shortcut = conv_shortcut
84 |
85 | self.norm1 = Normalize(in_channels)
86 | self.conv1 = torch.nn.Conv2d(in_channels,
87 | out_channels,
88 | kernel_size=3,
89 | stride=1,
90 | padding=1)
91 | if temb_channels > 0:
92 | self.temb_proj = torch.nn.Linear(temb_channels,
93 | out_channels)
94 | self.norm2 = Normalize(out_channels)
95 | self.dropout = torch.nn.Dropout(dropout)
96 | self.conv2 = torch.nn.Conv2d(out_channels,
97 | out_channels,
98 | kernel_size=3,
99 | stride=1,
100 | padding=1)
101 | if self.in_channels != self.out_channels:
102 | if self.use_conv_shortcut:
103 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
104 | out_channels,
105 | kernel_size=3,
106 | stride=1,
107 | padding=1)
108 | else:
109 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
110 | out_channels,
111 | kernel_size=1,
112 | stride=1,
113 | padding=0)
114 |
115 | def forward(self, x, temb):
116 | h = x
117 | h = self.norm1(h)
118 | h = nonlinearity(h)
119 | h = self.conv1(h)
120 |
121 | if temb is not None:
122 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
123 |
124 | h = self.norm2(h)
125 | h = nonlinearity(h)
126 | h = self.dropout(h)
127 | h = self.conv2(h)
128 |
129 | if self.in_channels != self.out_channels:
130 | if self.use_conv_shortcut:
131 | x = self.conv_shortcut(x)
132 | else:
133 | x = self.nin_shortcut(x)
134 |
135 | return x+h
136 |
137 |
138 | class AttnBlock(nn.Module):
139 | def __init__(self, in_channels):
140 | super().__init__()
141 | self.in_channels = in_channels
142 |
143 | self.norm = Normalize(in_channels)
144 | self.q = torch.nn.Conv2d(in_channels,
145 | in_channels,
146 | kernel_size=1,
147 | stride=1,
148 | padding=0)
149 | self.k = torch.nn.Conv2d(in_channels,
150 | in_channels,
151 | kernel_size=1,
152 | stride=1,
153 | padding=0)
154 | self.v = torch.nn.Conv2d(in_channels,
155 | in_channels,
156 | kernel_size=1,
157 | stride=1,
158 | padding=0)
159 | self.proj_out = torch.nn.Conv2d(in_channels,
160 | in_channels,
161 | kernel_size=1,
162 | stride=1,
163 | padding=0)
164 |
165 | def forward(self, x):
166 | h_ = x
167 | h_ = self.norm(h_)
168 | q = self.q(h_)
169 | k = self.k(h_)
170 | v = self.v(h_)
171 |
172 | # compute attention
173 | b,c,h,w = q.shape
174 | q = q.reshape(b,c,h*w)
175 | q = q.permute(0,2,1) # b,hw,c
176 | k = k.reshape(b,c,h*w) # b,c,hw
177 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
178 | w_ = w_ * (int(c)**(-0.5))
179 | w_ = torch.nn.functional.softmax(w_, dim=2)
180 |
181 | # attend to values
182 | v = v.reshape(b,c,h*w)
183 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
184 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
185 | h_ = h_.reshape(b,c,h,w)
186 |
187 | h_ = self.proj_out(h_)
188 |
189 | return x+h_
190 |
191 | class AttnBlock(nn.Module):
192 | def __init__(self, in_channels):
193 | super().__init__()
194 | self.in_channels = in_channels
195 |
196 | self.norm = Normalize(in_channels)
197 | self.q = torch.nn.Conv2d(in_channels,
198 | in_channels,
199 | kernel_size=1,
200 | stride=1,
201 | padding=0)
202 | self.k = torch.nn.Conv2d(in_channels,
203 | in_channels,
204 | kernel_size=1,
205 | stride=1,
206 | padding=0)
207 | self.v = torch.nn.Conv2d(in_channels,
208 | in_channels,
209 | kernel_size=1,
210 | stride=1,
211 | padding=0)
212 | self.proj_out = torch.nn.Conv2d(in_channels,
213 | in_channels,
214 | kernel_size=1,
215 | stride=1,
216 | padding=0)
217 |
218 | def forward(self, x):
219 | h_ = x
220 | h_ = self.norm(h_)
221 | q = self.q(h_)
222 | k = self.k(h_)
223 | v = self.v(h_)
224 |
225 | # compute attention
226 | b,c,h,w = q.shape
227 | q = q.reshape(b,c,h*w)
228 | q = q.permute(0,2,1) # b,hw,c
229 | k = k.reshape(b,c,h*w) # b,c,hw
230 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
231 | w_ = w_ * (int(c)**(-0.5))
232 | w_ = torch.nn.functional.softmax(w_, dim=2)
233 |
234 | # attend to values
235 | v = v.reshape(b,c,h*w)
236 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
237 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
238 | h_ = h_.reshape(b,c,h,w)
239 |
240 | h_ = self.proj_out(h_)
241 |
242 | return x+h_
243 |
244 | class Upsample(nn.Module):
245 | def __init__(self, in_channels, with_conv):
246 | super().__init__()
247 | self.with_conv = with_conv
248 | if self.with_conv:
249 | self.conv = torch.nn.Conv2d(in_channels,
250 | in_channels,
251 | kernel_size=3,
252 | stride=1,
253 | padding=1)
254 |
255 | def forward(self, x):
256 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
257 | if self.with_conv:
258 | x = self.conv(x)
259 | return x
260 |
261 |
262 | class Downsample(nn.Module):
263 | def __init__(self, in_channels, with_conv):
264 | super().__init__()
265 | self.with_conv = with_conv
266 | if self.with_conv:
267 | # no asymmetric padding in torch conv, must do it ourselves
268 | self.conv = torch.nn.Conv2d(in_channels,
269 | in_channels,
270 | kernel_size=3,
271 | stride=2,
272 | padding=0)
273 |
274 | def forward(self, x):
275 | if self.with_conv:
276 | pad = (0,1,0,1)
277 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
278 | x = self.conv(x)
279 | else:
280 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
281 | return x
282 |
283 | class Encoder(nn.Module):
284 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
285 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
286 | resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
287 | **ignore_kwargs):
288 | super().__init__()
289 | if use_linear_attn: attn_type = "linear"
290 | self.ch = ch
291 | self.temb_ch = 0
292 | self.num_resolutions = len(ch_mult)
293 | self.num_res_blocks = num_res_blocks
294 | self.resolution = resolution
295 | self.in_channels = in_channels
296 |
297 | # downsampling
298 | self.conv_in = torch.nn.Conv2d(in_channels,
299 | self.ch,
300 | kernel_size=3,
301 | stride=1,
302 | padding=1)
303 |
304 | curr_res = resolution
305 | in_ch_mult = (1,)+tuple(ch_mult)
306 | self.in_ch_mult = in_ch_mult
307 | self.down = nn.ModuleList()
308 | for i_level in range(self.num_resolutions):
309 | block = nn.ModuleList()
310 | attn = nn.ModuleList()
311 | block_in = ch*in_ch_mult[i_level]
312 | block_out = ch*ch_mult[i_level]
313 | for i_block in range(self.num_res_blocks):
314 | block.append(ResnetBlock(in_channels=block_in,
315 | out_channels=block_out,
316 | temb_channels=self.temb_ch,
317 | dropout=dropout))
318 | block_in = block_out
319 | if curr_res in attn_resolutions:
320 | attn.append(AttnBlock(block_in))
321 | down = nn.Module()
322 | down.block = block
323 | down.attn = attn
324 | if i_level != self.num_resolutions-1:
325 | down.downsample = Downsample(block_in, resamp_with_conv)
326 | curr_res = curr_res // 2
327 | self.down.append(down)
328 |
329 | # middle
330 | self.mid = nn.Module()
331 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
332 | out_channels=block_in,
333 | temb_channels=self.temb_ch,
334 | dropout=dropout)
335 | self.mid.attn_1 = AttnBlock(block_in)
336 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
337 | out_channels=block_in,
338 | temb_channels=self.temb_ch,
339 | dropout=dropout)
340 |
341 | # end
342 | self.norm_out = Normalize(block_in)
343 | self.conv_out = torch.nn.Conv2d(block_in,
344 | 2*z_channels if double_z else z_channels,
345 | kernel_size=3,
346 | stride=1,
347 | padding=1)
348 |
349 | def forward(self, x):
350 | # timestep embedding
351 | temb = None
352 |
353 | # downsampling
354 | hs = [self.conv_in(x)]
355 | for i_level in range(self.num_resolutions):
356 | for i_block in range(self.num_res_blocks):
357 | h = self.down[i_level].block[i_block](hs[-1], temb)
358 | if len(self.down[i_level].attn) > 0:
359 | h = self.down[i_level].attn[i_block](h)
360 | hs.append(h)
361 | if i_level != self.num_resolutions-1:
362 | hs.append(self.down[i_level].downsample(hs[-1]))
363 |
364 | # middle
365 | h = hs[-1]
366 | h = self.mid.block_1(h, temb)
367 | h = self.mid.attn_1(h)
368 | h = self.mid.block_2(h, temb)
369 |
370 | # end
371 | h = self.norm_out(h)
372 | h = nonlinearity(h)
373 | h = self.conv_out(h)
374 | return h
375 |
376 |
377 | class Decoder(nn.Module):
378 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
379 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
380 | resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
381 | attn_type="vanilla", **ignorekwargs):
382 | super().__init__()
383 | if use_linear_attn: attn_type = "linear"
384 | self.ch = ch
385 | self.temb_ch = 0
386 | self.num_resolutions = len(ch_mult)
387 | self.num_res_blocks = num_res_blocks
388 | self.resolution = resolution
389 | self.in_channels = in_channels
390 | self.give_pre_end = give_pre_end
391 | self.tanh_out = tanh_out
392 |
393 | # compute in_ch_mult, block_in and curr_res at lowest res
394 | in_ch_mult = (1,)+tuple(ch_mult)
395 | block_in = ch*ch_mult[self.num_resolutions-1]
396 | curr_res = resolution // 2**(self.num_resolutions-1)
397 | self.z_shape = (1,z_channels,curr_res,curr_res)
398 | print("Working with z of shape {} = {} dimensions.".format(
399 | self.z_shape, np.prod(self.z_shape)))
400 |
401 | # z to block_in
402 | self.conv_in = torch.nn.Conv2d(z_channels,
403 | block_in,
404 | kernel_size=3,
405 | stride=1,
406 | padding=1)
407 |
408 | # middle
409 | self.mid = nn.Module()
410 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
411 | out_channels=block_in,
412 | temb_channels=self.temb_ch,
413 | dropout=dropout)
414 | self.mid.attn_1 = AttnBlock(block_in)
415 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
416 | out_channels=block_in,
417 | temb_channels=self.temb_ch,
418 | dropout=dropout)
419 |
420 | # upsampling
421 | self.up = nn.ModuleList()
422 | for i_level in reversed(range(self.num_resolutions)):
423 | block = nn.ModuleList()
424 | attn = nn.ModuleList()
425 | block_out = ch*ch_mult[i_level]
426 | for i_block in range(self.num_res_blocks+1):
427 | block.append(ResnetBlock(in_channels=block_in,
428 | out_channels=block_out,
429 | temb_channels=self.temb_ch,
430 | dropout=dropout))
431 | block_in = block_out
432 | if curr_res in attn_resolutions:
433 | attn.append(AttnBlock(block_in))
434 | up = nn.Module()
435 | up.block = block
436 | up.attn = attn
437 | if i_level != 0:
438 | up.upsample = Upsample(block_in, resamp_with_conv)
439 | curr_res = curr_res * 2
440 | self.up.insert(0, up) # prepend to get consistent order
441 |
442 | # end
443 | self.norm_out = Normalize(block_in)
444 | self.conv_out = torch.nn.Conv2d(block_in,
445 | out_ch,
446 | kernel_size=3,
447 | stride=1,
448 | padding=1)
449 |
450 | def forward(self, z):
451 | #assert z.shape[1:] == self.z_shape[1:]
452 | self.last_z_shape = z.shape
453 |
454 | # timestep embedding
455 | temb = None
456 |
457 | # z to block_in
458 | h = self.conv_in(z)
459 |
460 | # middle
461 | h = self.mid.block_1(h, temb)
462 | h = self.mid.attn_1(h)
463 | h = self.mid.block_2(h, temb)
464 |
465 | # upsampling
466 | for i_level in reversed(range(self.num_resolutions)):
467 | for i_block in range(self.num_res_blocks+1):
468 | h = self.up[i_level].block[i_block](h, temb)
469 | if len(self.up[i_level].attn) > 0:
470 | h = self.up[i_level].attn[i_block](h)
471 | if i_level != 0:
472 | h = self.up[i_level].upsample(h)
473 |
474 | # end
475 | if self.give_pre_end:
476 | return h
477 |
478 | h = self.norm_out(h)
479 | h = nonlinearity(h)
480 | h = self.conv_out(h)
481 | if self.tanh_out:
482 | h = torch.tanh(h)
483 | return h
484 |
485 |
486 | class AutoencoderKL(nn.Module):
487 | def __init__(self,
488 | ddconfig,
489 | embed_dim,
490 | ckpt_path=None,
491 | ignore_keys=[],
492 | image_key="image",
493 | colorize_nlabels=None,
494 | monitor=None,
495 | ema_decay=None,
496 | learn_logvar=False
497 | ):
498 | super().__init__()
499 | self.learn_logvar = learn_logvar
500 | self.image_key = image_key
501 | self.encoder = Encoder(**ddconfig)
502 | self.decoder = Decoder(**ddconfig)
503 | assert ddconfig["double_z"]
504 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
505 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
506 | self.embed_dim = embed_dim
507 | if colorize_nlabels is not None:
508 | assert type(colorize_nlabels)==int
509 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
510 | if monitor is not None:
511 | self.monitor = monitor
512 | self.use_ema = ema_decay is not None
513 | if ckpt_path is not None:
514 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
515 |
516 | def init_from_ckpt(self, path, ignore_keys=list()):
517 | sd = torch.load(path, map_location="cpu")["state_dict"]
518 | keys = list(sd.keys())
519 | for key in keys:
520 | print(key, sd[key].shape)
521 | import collections
522 | sd_new = collections.OrderedDict()
523 | for k in keys:
524 | if k.find('first_stage_model') >= 0:
525 | k_new = k.split('first_stage_model.')[-1]
526 | sd_new[k_new] = sd[k]
527 | self.load_state_dict(sd_new, strict=True)
528 | print(f"Restored from {path}")
529 |
530 | def init_from_ckpt2(self, path, ignore_keys=list()):
531 | sd = torch.load(path, map_location="cpu")["state_dict"]
532 | keys = list(sd.keys())
533 |
534 |
535 | for k in keys:
536 | for ik in ignore_keys:
537 | if k.startswith(ik):
538 | print("Deleting key {} from state_dict.".format(k))
539 | del sd[k]
540 | self.load_state_dict(sd, strict=False)
541 | print(f"Restored from {path}")
542 |
543 | def on_train_batch_end(self, *args, **kwargs):
544 | if self.use_ema:
545 | self.model_ema(self)
546 |
547 | def encode(self, x):
548 | h = self.encoder(x)
549 | moments = self.quant_conv(h)
550 | posterior = DiagonalGaussianDistribution(moments)
551 | return posterior
552 |
553 | def decode(self, z):
554 | z = self.post_quant_conv(z)
555 | dec = self.decoder(z)
556 | return dec
557 |
558 | def forward(self, input, sample_posterior=True):
559 | posterior = self.encode(input)
560 | if sample_posterior:
561 | z = posterior.sample()
562 | else:
563 | z = posterior.mode()
564 | dec = self.decode(z)
565 | return dec, posterior
566 |
567 | def get_input(self, batch, k):
568 | x = batch[k]
569 | if len(x.shape) == 3:
570 | x = x[..., None]
571 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
572 | return x
573 |
574 | def get_last_layer(self):
575 | return self.decoder.conv_out.weight
576 |
577 | @torch.no_grad()
578 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
579 | log = dict()
580 | x = self.get_input(batch, self.image_key)
581 | x = x.to(self.device)
582 | if not only_inputs:
583 | xrec, posterior = self(x)
584 | if x.shape[1] > 3:
585 | # colorize with random projection
586 | assert xrec.shape[1] > 3
587 | x = self.to_rgb(x)
588 | xrec = self.to_rgb(xrec)
589 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
590 | log["reconstructions"] = xrec
591 | if log_ema or self.use_ema:
592 | with self.ema_scope():
593 | xrec_ema, posterior_ema = self(x)
594 | if x.shape[1] > 3:
595 | # colorize with random projection
596 | assert xrec_ema.shape[1] > 3
597 | xrec_ema = self.to_rgb(xrec_ema)
598 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
599 | log["reconstructions_ema"] = xrec_ema
600 | log["inputs"] = x
601 | return log
602 |
603 | def to_rgb(self, x):
604 | assert self.image_key == "segmentation"
605 | if not hasattr(self, "colorize"):
606 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
607 | x = F.conv2d(x, weight=self.colorize)
608 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
609 | return x
610 |
611 |
612 | class IdentityFirstStage(torch.nn.Module):
613 | def __init__(self, *args, vq_interface=False, **kwargs):
614 | self.vq_interface = vq_interface
615 | super().__init__()
616 |
617 | def encode(self, x, *args, **kwargs):
618 | return x
619 |
620 | def decode(self, x, *args, **kwargs):
621 | return x
622 |
623 | def quantize(self, x, *args, **kwargs):
624 | if self.vq_interface:
625 | return x, None, [None, None, None]
626 | return x
627 |
628 | def forward(self, x, *args, **kwargs):
629 | return x
630 |
631 |
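
AutoencoderKL is the usual latent-diffusion first-stage VAE; init_from_ckpt keeps only the keys under first_stage_model., so the weights can be loaded straight out of a full Stable Diffusion checkpoint such as v2-1_512-ema-pruned.ckpt. A minimal round-trip sketch; the ddconfig values below are the standard SD first-stage settings and are assumptions, not read from this file:

```python
# Hypothetical instantiation with commonly used first-stage hyperparameters (assumed values).
import torch

ddconfig = dict(
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
vae = AutoencoderKL(ddconfig, embed_dim=4).eval()

x = torch.randn(1, 3, 256, 256)        # random stand-in for one RGB frame
with torch.no_grad():
    posterior = vae.encode(x)          # DiagonalGaussianDistribution
    z = posterior.mode()               # (1, 4, 32, 32) latent
    recon = vae.decode(z)              # (1, 3, 256, 256)
print(z.shape, recon.shape)
```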
--------------------------------------------------------------------------------
/tools/videocomposer/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import os.path as osp
4 | from datetime import datetime
5 | from easydict import EasyDict
6 | import os
7 |
8 | cfg = EasyDict(__name__='Config: VideoComposer')
9 |
10 | pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
11 | gpus_per_machine = torch.cuda.device_count()
12 | world_size = pmi_world_size * gpus_per_machine
13 |
14 | cfg.video_compositions = ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
15 |
16 | # dataset
17 | cfg.root_dir = 'webvid10m/'
18 |
19 | cfg.alpha = 0.7
20 |
21 | cfg.misc_size = 384
22 | cfg.depth_std = 20.0
23 | cfg.depth_clamp = 10.0
24 | cfg.hist_sigma = 10.0
25 |
26 | #
27 | cfg.use_image_dataset = False
28 | cfg.alpha_img = 0.7
29 |
30 | cfg.resolution = 256
31 | cfg.mean = [0.5, 0.5, 0.5]
32 | cfg.std = [0.5, 0.5, 0.5]
33 |
34 | # sketch
35 | cfg.sketch_mean = [0.485, 0.456, 0.406]
36 | cfg.sketch_std = [0.229, 0.224, 0.225]
37 |
38 | # dataloader
39 | cfg.max_words = 1000
40 |
41 | cfg.frame_lens = [
42 | 16,
43 | 16,
44 | 16,
45 | 16,
46 | ]
47 | cfg.feature_framerates = [
48 | 4,
49 | ]
50 | cfg.feature_framerate = 4
51 | cfg.batch_sizes = {
52 | str(1):1,
53 | str(4):1,
54 | str(8):1,
55 | str(16):1,
56 | }
57 |
58 | cfg.chunk_size=64
59 | cfg.num_workers = 8
60 | cfg.prefetch_factor = 2
61 | cfg.seed = 8888
62 |
63 | # diffusion
64 | cfg.num_timesteps = 1000
65 | cfg.mean_type = 'eps'
66 | cfg.var_type = 'fixed_small' # NOTE: to stabilize training and avoid NaN
67 | cfg.loss_type = 'mse'
68 | cfg.ddim_timesteps = 50 # official: 250
69 | cfg.ddim_eta = 0.0
70 | cfg.clamp = 1.0
71 | cfg.share_noise = False
72 | cfg.use_div_loss = False
73 |
74 | # classifier-free guidance
75 | cfg.p_zero = 0.9
76 | cfg.guide_scale = 6.0
77 |
78 | # stable diffusion
79 | cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt'
80 |
81 |
82 | # clip vision encoder
83 | cfg.vit_image_size = 336
84 | cfg.vit_patch_size = 14
85 | cfg.vit_dim = 1024
86 | cfg.vit_out_dim = 768
87 | cfg.vit_heads = 16
88 | cfg.vit_layers = 24
89 | cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
90 | cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
91 | cfg.clip_checkpoint = 'open_clip_pytorch_model.bin'
92 | cfg.mvs_visual = False
93 | # unet
94 | cfg.unet_in_dim = 4
95 | cfg.unet_concat_dim = 8
96 | cfg.unet_y_dim = cfg.vit_out_dim
97 | cfg.unet_context_dim = 1024
98 | cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4
99 | cfg.unet_dim = 320
100 | #cfg.unet_dim_mult = [1, 2, 3, 5]
101 | cfg.unet_dim_mult = [1, 2, 4, 4]
102 | cfg.unet_res_blocks = 2
103 | cfg.unet_num_heads = 8
104 | cfg.unet_head_dim = 64
105 | cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4]
106 | cfg.unet_dropout = 0.1
107 | cfg.misc_dropout = 0.5
108 | cfg.p_all_zero = 0.1
109 | cfg.p_all_keep = 0.1
110 | cfg.temporal_conv = False
111 | cfg.temporal_attn_times = 1
112 | cfg.temporal_attention = True
113 |
114 | cfg.use_fps_condition = False
115 | cfg.use_sim_mask = False
116 |
117 | ## Default: load 2D pretrained weights
118 | cfg.pretrained = False
119 | cfg.fix_weight = False
120 |
121 | ## Default resume
122 | #
123 | cfg.resume = True
124 | cfg.resume_step = 148000
125 | cfg.resume_check_dir = '.'
126 | cfg.resume_checkpoint = os.path.join(cfg.resume_check_dir,f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth')
127 | #
128 | cfg.resume_optimizer = False
129 | if cfg.resume_optimizer:
130 | cfg.resume_optimizer = os.path.join(cfg.resume_check_dir,f'optimizer_step_{cfg.resume_step}.pt')
131 |
132 |
133 | # acceleration
134 | cfg.use_ema = True
135 | # for debug, no ema
136 | if world_size<2:
137 | cfg.use_ema = False
138 | cfg.load_from = None
139 |
140 | cfg.use_checkpoint = True
141 | cfg.use_sharded_ddp = False
142 | cfg.use_fsdp = False
143 | cfg.use_fp16 = True
144 |
145 | # training
146 | cfg.ema_decay = 0.9999
147 | cfg.viz_interval = 1000
148 | cfg.save_ckp_interval = 1000
149 |
150 | # logging
151 | cfg.log_interval = 100
152 | composition_strings = '_'.join(cfg.video_compositions)
153 | ### Default log_dir
154 | cfg.log_dir = f'outputs/'
155 |
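Note: cfg.batch_sizes above is keyed by the stringified clip length drawn from cfg.frame_lens. A minimal sketch (not from this repo) of how such a table is typically consumed; `cfg` stands for any object exposing the fields defined above:

    import random

    def pick_frame_len_and_batch_size(cfg):
        # sample one of the configured clip lengths, e.g. 16 frames
        frame_len = random.choice(cfg.frame_lens)
        # the per-GPU batch size is looked up by the stringified frame length
        return frame_len, cfg.batch_sizes[str(frame_len)]
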
--------------------------------------------------------------------------------
/tools/videocomposer/datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | import time
4 | import torch.nn.functional as F
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms as T
7 | from collections import defaultdict
8 | import re
9 | import pickle
10 | import json
11 | import random
12 | import numpy as np
13 | from io import BytesIO
14 | from PIL import Image
15 | import artist.ops as ops
16 | import cv2
17 | from skimage.color import rgb2lab, lab2rgb
18 | import datetime
19 | # ADD
20 | import os
21 | from mvextractor.videocap import VideoCap
22 | import subprocess
23 | import binascii
24 | from ipdb import set_trace
25 | import imageio
26 |
27 | import utils.logging as logging
28 | logger = logging.get_logger(__name__)
29 |
30 | def pre_caption(caption, max_words):
31 | caption = re.sub(
32 | r"([,.'!?\"()*#:;~])",
33 | '',
34 | caption.lower(),
35 |     ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
36 |
37 | caption = re.sub(
38 | r"\s{2,}",
39 | ' ',
40 | caption,
41 | )
42 | caption = caption.rstrip('\n')
43 | caption = caption.strip(' ')
44 |
45 | #truncate caption
46 | caption_words = caption.split(' ')
47 | if len(caption_words) > max_words:
48 | caption = ' '.join(caption_words[:max_words])
49 |
50 | return caption
51 |
52 | def random_resize(img, size):
53 |     return T.functional.resize(img, size, interpolation=random.choice([
54 |         T.InterpolationMode.BILINEAR,
55 |         T.InterpolationMode.BICUBIC,
56 |         T.InterpolationMode.LANCZOS]))
57 |
58 | def rand_name(length=16, suffix=''):
59 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
60 | if suffix:
61 | if not suffix.startswith('.'):
62 | suffix = '.' + suffix
63 | name += suffix
64 | return name
65 |
66 | def draw_motion_vectors(frame, motion_vectors):
67 | if len(motion_vectors) > 0:
68 | num_mvs = np.shape(motion_vectors)[0]
69 | for mv in np.split(motion_vectors, num_mvs):
70 | start_pt = (mv[0, 3], mv[0, 4])
71 | end_pt = (mv[0, 5], mv[0, 6])
72 | cv2.arrowedLine(frame, start_pt, end_pt, (0, 0, 255), 1, cv2.LINE_AA, 0, 0.1)
73 | # cv2.arrowedLine(frame, start_pt, end_pt, (0, 0, 255), 2, cv2.LINE_AA, 0, 0.2)
74 | return frame
75 |
76 | def extract_motion_vectors(input_video,fps=4, dump=False, verbose=False, visual_mv=False):
77 |
78 | if dump:
79 |         now = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
80 | for child in ["frames", "motion_vectors"]:
81 | os.makedirs(os.path.join(f"out-{now}", child), exist_ok=True)
82 | temp = rand_name()
83 | # tmp_video = f'{temp}_{input_video}'
84 | tmp_video = os.path.join(input_video.split("/")[0], f'{temp}' +input_video.split("/")[-1])
85 | videocapture = cv2.VideoCapture(input_video)
86 | frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
87 |     fps_video = videocapture.get(cv2.CAP_PROP_FPS)
88 |     # check whether sampling at the requested fps yields at least 16 frames
89 |     if frames_num / fps_video * fps > 16:
90 |         fps = max(fps, 1)
91 |     else:
92 |         fps = int(16 / (frames_num / fps_video)) + 1  # raise the sampling fps so that ~16 frames are available
93 | ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}'
94 |
95 | if os.path.exists(tmp_video):
96 | os.remove(tmp_video)
97 |
98 | subprocess.run(args=ffmpeg_cmd,shell=True,timeout=120)
99 |
100 | cap = VideoCap()
101 | # open the video file
102 | ret = cap.open(tmp_video)
103 | if not ret:
104 | raise RuntimeError(f"Could not open {tmp_video}")
105 |
106 | step = 0
107 | times = []
108 |
109 | frame_types = []
110 | frames = []
111 | mvs = []
112 | mvs_visual = []
113 | # continuously read and display video frames and motion vectors
114 | while True:
115 | if verbose:
116 | print("Frame: ", step, end=" ")
117 |
118 | tstart = time.perf_counter()
119 |
120 | # read next video frame and corresponding motion vectors
121 | ret, frame, motion_vectors, frame_type, timestamp = cap.read()
122 |
123 | tend = time.perf_counter()
124 | telapsed = tend - tstart
125 | times.append(telapsed)
126 |
127 | # if there is an error reading the frame
128 | if not ret:
129 | if verbose:
130 | print("No frame read. Stopping.")
131 | break
132 |
133 | frame_save = np.zeros(frame.copy().shape, dtype=np.uint8) # *255
134 | if visual_mv:
135 | frame_save = draw_motion_vectors(frame_save, motion_vectors)
136 |
137 | # store motion vectors, frames, etc. in output directory
138 | dump = False
139 | if frame.shape[1] >= frame.shape[0]:
140 | w_half = (frame.shape[1] - frame.shape[0])//2
141 | if dump:
142 | cv2.imwrite(os.path.join(f"./mv_visual/", f"frame-{step}.jpg"), frame_save[:,w_half:-w_half])
143 | mvs_visual.append(frame_save[:,w_half:-w_half])
144 | else:
145 | h_half = (frame.shape[0] - frame.shape[1])//2
146 | if dump:
147 | cv2.imwrite(os.path.join(f"./mv_visual/", f"frame-{step}.jpg"), frame_save[h_half:-h_half,:])
148 | mvs_visual.append(frame_save[h_half:-h_half,:])
149 |
150 | h,w = frame.shape[:2]
151 |         mv = np.zeros((h,w,2))  # dense per-pixel motion field for this frame
152 |         position = motion_vectors[:,5:7].clip((0,0),(w-1,h-1))  # destination (x, y) of each motion vector, clipped to the frame
153 |         mv[position[:,1],position[:,0]] = motion_vectors[:,0:1]*motion_vectors[:,7:9]/motion_vectors[:, 9:]  # source-signed (motion_x, motion_y) scaled by 1/motion_scale
154 |
155 | step += 1
156 | frame_types.append(frame_type)
157 | frames.append(frame)
158 | mvs.append(mv)
159 | # mvs_visual.append(frame_save)
160 | if verbose:
161 | print("average dt: ", np.mean(times))
162 | cap.release()
163 |
164 | if os.path.exists(tmp_video):
165 | os.remove(tmp_video)
166 |
167 | return frame_types,frames,mvs, mvs_visual
168 |
169 |
170 | class VideoDataset(Dataset):
171 | def __init__(self,
172 | cfg,
173 | tokenizer=None,
174 | max_words=30,
175 | feature_framerate=1,
176 | max_frames=16,
177 | image_resolution=224,
178 | transforms=None,
179 | mv_transforms = None,
180 | misc_transforms = None,
181 | vit_transforms=None,
182 | vit_image_size = 336,
183 | misc_size = 384):
184 |
185 | self.cfg = cfg
186 |
187 | self.tokenizer = tokenizer
188 | self.max_words = max_words
189 | self.feature_framerate = feature_framerate
190 | self.max_frames = max_frames
191 | self.image_resolution = image_resolution
192 | self.transforms = transforms
193 | self.vit_transforms = vit_transforms
194 | self.vit_image_size = vit_image_size
195 | self.misc_transforms = misc_transforms
196 | self.misc_size = misc_size
197 |
198 | self.mv_transforms = mv_transforms
199 |
200 | self.video_cap_pairs = [[self.cfg.input_video, self.cfg.input_text_desc]]
201 | self.Vit_image_random_resize = T.Resize((vit_image_size, vit_image_size))
202 |
203 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
204 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
205 |
206 | def __len__(self):
207 | return len(self.video_cap_pairs)
208 |
209 | def __getitem__(self, index):
210 |
211 | video_key, cap_txt = self.video_cap_pairs[index]
212 |
213 | total_frames = None
214 |
215 | feature_framerate = self.feature_framerate
216 | if os.path.exists(video_key):
217 | try:
218 | ref_frame, vit_image, video_data, misc_data, mv_data = self._get_video_traindata(video_key, feature_framerate, total_frames, self.cfg.mvs_visual)
219 | except Exception as e:
220 | print('{} get frames failed... with error: {}'.format(video_key, e), flush=True)
221 |
222 | ref_frame = torch.zeros(3, self.vit_image_size, self.vit_image_size)
223 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size)
224 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution)
225 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size)
226 |
227 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution)
228 | else:
229 | print("The video path does not exist or no video dir provided!")
230 | ref_frame = torch.zeros(3, self.vit_image_size, self.vit_image_size)
231 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size)
232 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution)
233 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size)
234 |
235 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution)
236 |
237 |
238 | # inpainting mask
239 | p = random.random()
240 | if p < 0.7:
241 | mask = ops.make_irregular_mask(512, 512)
242 | elif p < 0.9:
243 | mask = ops.make_rectangle_mask(512, 512)
244 | else:
245 | mask = ops.make_uncrop(512, 512)
246 | mask = torch.from_numpy(cv2.resize(mask, (self.misc_size,self.misc_size), interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()
247 |
248 | mask = mask.unsqueeze(0).repeat_interleave(repeats=self.max_frames,dim=0)
249 |
250 |
251 | return ref_frame, cap_txt, video_data, misc_data, feature_framerate, mask, mv_data
252 |
253 | def _get_video_traindata(self, video_key, feature_framerate, total_frames, visual_mv):
254 |
255 | # folder_name = "cache_temp/"
256 | # filename = folder_name + osp.basename(video_key)
257 | # if not os.path.exists(folder_name):
258 | # os.makedirs(folder_name, exist_ok=True)
259 | # oss_path = osp.join(self.root_dir, video_key)
260 | # bucket, oss_key = ops.parse_oss_url(oss_path)
261 | # ops.get_object_to_file(bucket, oss_key, filename)
262 | filename = video_key
263 | for _ in range(5):
264 | try:
265 | frame_types,frames,mvs, mvs_visual = extract_motion_vectors(input_video=filename,fps=feature_framerate, visual_mv=visual_mv)
266 | # os.remove(filename)
267 | break
268 | except Exception as e:
269 | print('{} read video frames and motion vectors failed with error: {}'.format(video_key, e), flush=True)
270 |
271 | total_frames = len(frame_types)
272 | start_indexs = np.where((np.array(frame_types)=='I') & (total_frames - np.arange(total_frames) >= self.max_frames))[0]
273 | start_index = np.random.choice(start_indexs)
274 | indices = np.arange(start_index, start_index+self.max_frames)
275 |
276 |         # note: frames are decoded in BGR order; reverse the channels to get RGB
277 | frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
278 | mvs = [torch.from_numpy(mvs[i].transpose((2,0,1))) for i in indices]
279 | mvs = torch.stack(mvs)
280 | # set_trace()
281 | # if mvs_visual != None:
282 | if visual_mv:
283 | # images = [(mvs_visual[i][:,:,::-1]*255).astype('uint8') for i in indices]
284 | images = [(mvs_visual[i][:,:,::-1]).astype('uint8') for i in indices]
285 | # images = [mvs_visual[i] for i in indices]
286 | # images = [(image.numpy()*255).astype('uint8') for image in images]
287 | path = self.cfg.log_dir + "/visual_mv/" + video_key.split("/")[-1] + ".gif"
288 | if not os.path.exists(self.cfg.log_dir + "/visual_mv/"):
289 | os.makedirs(self.cfg.log_dir + "/visual_mv/", exist_ok=True)
290 | print("save motion vectors visualization to :", path)
291 | imageio.mimwrite(path, images, fps=8)
292 |
293 | # mvs_visual = [torch.from_numpy(mvs_visual[i].transpose((2,0,1))) for i in indices]
294 | # mvs_visual = torch.stack(mvs_visual)
295 | # mvs_visual = self.mv_transforms(mvs_visual)
296 |
297 | have_frames = len(frames)>0
298 |         middle_index = int(len(frames)/2)
299 |         if have_frames:
300 |             ref_frame = frames[middle_index]
301 | vit_image = self.vit_transforms(ref_frame)
302 | misc_imgs_np = self.misc_transforms[:2](frames)
303 | misc_imgs = self.misc_transforms[2:](misc_imgs_np)
304 | frames = self.transforms(frames)
305 | mvs = self.mv_transforms(mvs)
306 | else:
307 | # ref_frame = Image.fromarray(np.zeros((3, self.image_resolution, self.image_resolution)))
308 | vit_image = torch.zeros(3,self.vit_image_size,self.vit_image_size)
309 |
310 | video_data = torch.zeros(self.max_frames, 3, self.image_resolution, self.image_resolution)
311 | mv_data = torch.zeros(self.max_frames, 2, self.image_resolution, self.image_resolution)
312 | misc_data = torch.zeros(self.max_frames, 3, self.misc_size, self.misc_size)
313 | if have_frames:
314 | video_data[:len(frames), ...] = frames # [[XX...],[...], ..., [0,0...], [], ...]
315 | misc_data[:len(frames), ...] = misc_imgs
316 | mv_data[:len(frames), ...] = mvs
317 |
318 |
319 | ref_frame = vit_image
320 |
321 | del frames
322 | del misc_imgs
323 | del mvs
324 |
325 | return ref_frame, vit_image, video_data, misc_data, mv_data
326 |
327 |
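Usage sketch (not part of this file): instantiating VideoDataset and inspecting one sample. The SimpleNamespace cfg, the import path and the missing video path are assumptions for illustration; a missing file exercises the zero-tensor fallback in __getitem__, so no transforms are needed, and the module's own dependencies (artist, utils, mvextractor, imageio, ipdb) must be importable:

    from types import SimpleNamespace
    from torch.utils.data import DataLoader
    from tools.videocomposer.datasets import VideoDataset  # assumed import path

    cfg = SimpleNamespace(
        input_video='demo_video/does_not_exist.mp4',  # hypothetical missing file -> zero tensors
        input_text_desc='a test caption',
        mvs_visual=False,
        log_dir='outputs/',
    )
    dataset = VideoDataset(cfg, max_frames=16, image_resolution=256,
                           vit_image_size=336, misc_size=384)
    loader = DataLoader(dataset, batch_size=1, num_workers=0)
    ref_frame, cap_txt, video_data, misc_data, fps, mask, mv_data = next(iter(loader))
    # video_data: [1, 16, 3, 256, 256]   misc_data: [1, 16, 3, 384, 384]
    # mask:       [1, 16, 1, 384, 384]   mv_data:   [1, 16, 2, 256, 256]
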
--------------------------------------------------------------------------------
/tools/videocomposer/mha_flash.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.cuda.amp as amp
4 | import torch.nn.functional as F
5 | import math
6 | import os
7 | import time
8 | import numpy as np
9 | import random
10 |
11 | from flash_attn.flash_attention import FlashAttention
12 |
13 | class FlashAttentionBlock(nn.Module):
14 |
15 | def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
16 | # consider head_dim first, then num_heads
17 | num_heads = dim // head_dim if head_dim else num_heads
18 | head_dim = dim // num_heads
19 | assert num_heads * head_dim == dim
20 | super(FlashAttentionBlock, self).__init__()
21 | self.dim = dim
22 | self.context_dim = context_dim
23 | self.num_heads = num_heads
24 | self.head_dim = head_dim
25 | self.scale = math.pow(head_dim, -0.25)
26 |
27 | # layers
28 | self.norm = nn.GroupNorm(32, dim)
29 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
30 | if context_dim is not None:
31 | self.context_kv = nn.Linear(context_dim, dim * 2)
32 | self.proj = nn.Conv2d(dim, dim, 1)
33 |
34 | if self.head_dim <= 128 and (self.head_dim % 8) == 0:
35 | new_scale = math.pow(head_dim, -0.5)
36 | self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
37 |
38 | # zero out the last layer params
39 | nn.init.zeros_(self.proj.weight)
40 | # self.apply(self._init_weight)
41 |
42 |
43 | def _init_weight(self, module):
44 | if isinstance(module, nn.Linear):
45 | module.weight.data.normal_(mean=0.0, std=0.15)
46 | if module.bias is not None:
47 | module.bias.data.zero_()
48 | elif isinstance(module, nn.Conv2d):
49 | module.weight.data.normal_(mean=0.0, std=0.15)
50 | if module.bias is not None:
51 | module.bias.data.zero_()
52 |
53 | def forward(self, x, context=None):
54 | r"""x: [B, C, H, W].
55 | context: [B, L, C] or None.
56 | """
57 | identity = x
58 | b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
59 |
60 | # compute query, key, value
61 | x = self.norm(x)
62 | q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
63 | if context is not None:
64 | ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
65 | k = torch.cat([ck, k], dim=-1)
66 | v = torch.cat([cv, v], dim=-1)
67 | cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
68 | q = torch.cat([q, cq], dim=-1)
69 |
70 | qkv = torch.cat([q,k,v], dim=1)
71 | origin_dtype = qkv.dtype
72 | qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
73 | out, _ = self.flash_attn(qkv)
74 |         out = out.to(origin_dtype)
75 |
76 | if context is not None:
77 | out = out[:, :-4, :, :]
78 | out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
79 |
80 | # output
81 | x = self.proj(out)
82 | return x + identity
83 |
84 | if __name__ == '__main__':
85 | batch_size = 8
86 | flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
87 |
88 | x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
89 | context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
90 | # context = None
91 | flash_net.eval()
92 |
93 | with amp.autocast(enabled=True):
94 | # warm up
95 | for i in range(5):
96 | y = flash_net(x, context)
97 | torch.cuda.synchronize()
98 | s1 = time.time()
99 | for i in range(10):
100 | y = flash_net(x, context)
101 | torch.cuda.synchronize()
102 | s2 = time.time()
103 |
104 | print(f'Average cost time {(s2-s1)*1000/10} ms')
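
Illustrative helpers (not part of this file) that mirror how FlashAttentionBlock.__init__ resolves the head count and gates the FlashAttention kernel on the head dimension:

    def resolve_heads(dim, num_heads=None, head_dim=None):
        # same resolution order as __init__: head_dim takes precedence over num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        return num_heads, head_dim

    def can_use_flash_attn(head_dim):
        # the kernel is only instantiated for head dims <= 128 that are multiples of 8
        return head_dim <= 128 and head_dim % 8 == 0

    assert resolve_heads(1280, head_dim=64) == (20, 64)
    assert can_use_flash_attn(64) and not can_use_flash_attn(160)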
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/distributed.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/videocomposer/490ed21b5c41ecbb12fd024de4308ce68b1e0f64/utils/__pycache__/logging.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import json
4 | import copy
5 | import argparse
6 |
7 | import utils.logging as logging
8 | logger = logging.get_logger(__name__)
9 |
10 | class Config(object):
11 | def __init__(self, load=True, cfg_dict=None, cfg_level=None):
12 | self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
13 | if load:
14 | self.args = self._parse_args()
15 | logger.info("Loading config from {}.".format(self.args.cfg_file))
16 | self.need_initialization = True
17 | cfg_base = self._initialize_cfg()
18 | cfg_dict = self._load_yaml(self.args)
19 | cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
20 | cfg_dict = self._update_from_args(cfg_dict)
21 | self.cfg_dict = cfg_dict
22 | self._update_dict(cfg_dict)
23 |
24 | def _parse_args(self):
25 | parser = argparse.ArgumentParser(
26 |             description="Argparser for configuring the VideoComposer codebase"
27 | )
28 | parser.add_argument(
29 | "--cfg",
30 | dest="cfg_file",
31 | help="Path to the configuration file",
32 | default='configs/exp01_vidcomposer_full.yaml'
33 | )
34 | parser.add_argument(
35 | "--init_method",
36 | help="Initialization method, includes TCP or shared file-system",
37 | default="tcp://localhost:9999",
38 | type=str,
39 | )
40 | parser.add_argument(
41 | "--seed",
42 | type=int,
43 | default=8888,
44 |             help="Random seed; worth varying for different videos"
45 | )
46 | parser.add_argument(
47 | '--debug',
48 | action='store_true',
49 | default=False,
50 |             help='Enable debug mode'
51 | )
52 | parser.add_argument(
53 | '--input_video',
54 | default='demo_video/video_8800.mp4',
55 | help='input video for full task, or motion vector of input videos',
56 | type=str,
57 |         )
58 | parser.add_argument(
59 | '--image_path',
60 | default='',
61 | help='Single Image Input',
62 | type=str
63 | )
64 | parser.add_argument(
65 | '--sketch_path',
66 | default='',
67 | help='Single Sketch Input',
68 | type=str
69 | )
70 | parser.add_argument(
71 | '--style_image',
72 |             help='Single Style Image Input',
73 | type=str
74 | )
75 | parser.add_argument(
76 | '--input_text_desc',
77 | default='A colorful and beautiful fish swimming in a small glass bowl with multicolored piece of stone, Macro Video',
78 | type=str,
79 |         )
80 | parser.add_argument(
81 | "opts",
82 | help="other configurations",
83 | default=None,
84 | nargs=argparse.REMAINDER
85 | )
86 | return parser.parse_args()
87 |
88 | def _path_join(self, path_list):
89 | path = ""
90 | for p in path_list:
91 | path+= p + '/'
92 | return path[:-1]
93 |
94 | def _update_from_args(self, cfg_dict):
95 | args = self.args
96 | for var in vars(args):
97 | cfg_dict[var] = getattr(args, var)
98 | return cfg_dict
99 |
100 | def _initialize_cfg(self):
101 | if self.need_initialization:
102 | self.need_initialization = False
103 | if os.path.exists('./configs/base.yaml'):
104 | with open("./configs/base.yaml", 'r') as f:
105 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
106 | else:
107 | with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
108 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
109 | return cfg
110 |
111 | def _load_yaml(self, args, file_name=""):
112 | assert args.cfg_file is not None
113 | if not file_name == "": # reading from base file
114 | with open(file_name, 'r') as f:
115 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
116 | else:
117 | if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
118 | args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
119 | try:
120 | with open(args.cfg_file, 'r') as f:
121 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
122 | file_name = args.cfg_file
123 | except:
124 | args.cfg_file = os.path.realpath(__file__).split('/')[-3] + "/" + args.cfg_file
125 | with open(args.cfg_file, 'r') as f:
126 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
127 | file_name = args.cfg_file
128 |
129 | if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
130 | # return cfg if the base file is being accessed
131 | return cfg
132 |
133 | if "_BASE" in cfg.keys():
134 | if cfg["_BASE"][1] == '.':
135 | prev_count = cfg["_BASE"].count('..')
136 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
137 | else:
138 | cfg_base_file = cfg["_BASE"].replace(
139 | "./",
140 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
141 | )
142 | cfg_base = self._load_yaml(args, cfg_base_file)
143 | cfg = self._merge_cfg_from_base(cfg_base, cfg)
144 | else:
145 | if "_BASE_RUN" in cfg.keys():
146 | if cfg["_BASE_RUN"][1] == '.':
147 | prev_count = cfg["_BASE_RUN"].count('..')
148 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
149 | else:
150 | cfg_base_file = cfg["_BASE_RUN"].replace(
151 | "./",
152 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
153 | )
154 | cfg_base = self._load_yaml(args, cfg_base_file)
155 | cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
156 | if "_BASE_MODEL" in cfg.keys():
157 | if cfg["_BASE_MODEL"][1] == '.':
158 | prev_count = cfg["_BASE_MODEL"].count('..')
159 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
160 | else:
161 | cfg_base_file = cfg["_BASE_MODEL"].replace(
162 | "./",
163 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
164 | )
165 | cfg_base = self._load_yaml(args, cfg_base_file)
166 | cfg = self._merge_cfg_from_base(cfg_base, cfg)
167 | cfg = self._merge_cfg_from_command(args, cfg)
168 | return cfg
169 |
170 | def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
171 | for k,v in cfg_new.items():
172 | if k in cfg_base.keys():
173 | if isinstance(v, dict):
174 | self._merge_cfg_from_base(cfg_base[k], v)
175 | else:
176 | cfg_base[k] = v
177 | else:
178 | if "BASE" not in k or preserve_base:
179 | cfg_base[k] = v
180 | return cfg_base
181 |
182 | def _merge_cfg_from_command(self, args, cfg):
183 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
184 | args.opts, len(args.opts)
185 | )
186 | keys = args.opts[0::2]
187 | vals = args.opts[1::2]
188 |
189 | # maximum supported depth 3
190 | for idx, key in enumerate(keys):
191 | key_split = key.split('.')
192 | assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
193 | len(key_split)
194 | )
195 |             assert key_split[0] in cfg.keys(), 'Non-existent key: {}.'.format(
196 | key_split[0]
197 | )
198 | if len(key_split) == 2:
199 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
200 | key
201 | )
202 | elif len(key_split) == 3:
203 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
204 | key
205 | )
206 |                 assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
207 | key
208 | )
209 | elif len(key_split) == 4:
210 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
211 | key
212 | )
213 |                 assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
214 | key
215 | )
216 |                 assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existent key: {}.'.format(
217 | key
218 | )
219 | if len(key_split) == 1:
220 | cfg[key_split[0]] = vals[idx]
221 | elif len(key_split) == 2:
222 | cfg[key_split[0]][key_split[1]] = vals[idx]
223 | elif len(key_split) == 3:
224 | cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
225 | elif len(key_split) == 4:
226 | cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
227 | return cfg
228 |
229 | def _update_dict(self, cfg_dict):
230 | def recur(key, elem):
231 | if type(elem) is dict:
232 | return key, Config(load=False, cfg_dict=elem, cfg_level=key)
233 | else:
234 | if type(elem) is str and elem[1:3]=="e-":
235 | elem = float(elem)
236 | return key, elem
237 | dic = dict(recur(k, v) for k, v in cfg_dict.items())
238 | self.__dict__.update(dic)
239 |
240 | def get_args(self):
241 | return self.args
242 |
243 | def __repr__(self):
244 | return "{}\n".format(self.dump())
245 |
246 | def dump(self):
247 | return json.dumps(self.cfg_dict, indent=2)
248 |
249 | def deep_copy(self):
250 | return copy.deepcopy(self)
251 |
252 | if __name__ == '__main__':
253 | # debug
254 | cfg = Config(load=True)
255 | print(cfg.DATA)
256 |
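A minimal sketch of the deep-merge semantics implemented by _merge_cfg_from_base above; the example dictionaries are made up, and note that the base dict is updated in place:

    from utils.config import Config

    base = {'solver': {'lr': 1e-4, 'steps': 1000}, 'seed': 8888}
    override = {'solver': {'lr': 5e-5}, 'guide_scale': 9.0}

    cfg = Config(load=False)  # bare instance; no CLI parsing or YAML loading
    merged = cfg._merge_cfg_from_base(base, override)
    # merged == {'solver': {'lr': 5e-05, 'steps': 1000}, 'seed': 8888, 'guide_scale': 9.0}
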
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Distributed helpers."""
5 |
6 | import functools
7 | import logging
8 | import pickle
9 | import torch
10 | import torch.distributed as dist
11 |
12 | _LOCAL_PROCESS_GROUP = None
13 |
14 |
15 | def all_gather(tensors):
16 | """
17 | All gathers the provided tensors from all processes across machines.
18 | Args:
19 | tensors (list): tensors to perform all gather across all processes in
20 | all machines.
21 | """
22 |
23 | gather_list = []
24 | output_tensor = []
25 | world_size = dist.get_world_size()
26 | for tensor in tensors:
27 | tensor_placeholder = [
28 | torch.ones_like(tensor) for _ in range(world_size)
29 | ]
30 | dist.all_gather(tensor_placeholder, tensor, async_op=False)
31 | gather_list.append(tensor_placeholder)
32 | for gathered_tensor in gather_list:
33 | output_tensor.append(torch.cat(gathered_tensor, dim=0))
34 | return output_tensor
35 |
36 |
37 | def all_reduce(tensors, average=True):
38 | """
39 | All reduce the provided tensors from all processes across machines.
40 | Args:
41 | tensors (list): tensors to perform all reduce across all processes in
42 | all machines.
43 | average (bool): scales the reduced tensor by the number of overall
44 | processes across all machines.
45 | """
46 |
47 | for tensor in tensors:
48 | dist.all_reduce(tensor, async_op=False)
49 | if average:
50 | world_size = dist.get_world_size()
51 | for tensor in tensors:
52 | tensor.mul_(1.0 / world_size)
53 | return tensors
54 |
55 |
56 | def init_process_group(
57 | local_rank,
58 | local_world_size,
59 | shard_id,
60 | num_shards,
61 | init_method,
62 | dist_backend="nccl",
63 | ):
64 | """
65 | Initializes the default process group.
66 | Args:
67 | local_rank (int): the rank on the current local machine.
68 | local_world_size (int): the world size (number of processes running) on
69 | the current local machine.
70 | shard_id (int): the shard index (machine rank) of the current machine.
71 | num_shards (int): number of shards for distributed training.
72 | init_method (string): supporting three different methods for
73 | initializing process groups:
74 | "file": use shared file system to initialize the groups across
75 | different processes.
76 |             "tcp": use tcp address to initialize the groups across different processes.
77 | dist_backend (string): backend to use for distributed training. Options
78 | includes gloo, mpi and nccl, the details can be found here:
79 | https://pytorch.org/docs/stable/distributed.html
80 | """
81 | # Sets the GPU to use.
82 | torch.cuda.set_device(local_rank)
83 | # Initialize the process group.
84 | proc_rank = local_rank + shard_id * local_world_size
85 | world_size = local_world_size * num_shards
86 | dist.init_process_group(
87 | backend=dist_backend,
88 | init_method=init_method,
89 | world_size=world_size,
90 | rank=proc_rank,
91 | )
92 |
93 |
94 | def is_master_proc(num_gpus=8):
95 | """
96 | Determines if the current process is the master process.
97 | """
98 | if torch.distributed.is_initialized():
99 | return dist.get_rank() % num_gpus == 0
100 | else:
101 | return True
102 |
103 |
104 | def get_world_size():
105 | """
106 | Get the size of the world.
107 | """
108 | if not dist.is_available():
109 | return 1
110 | if not dist.is_initialized():
111 | return 1
112 | return dist.get_world_size()
113 |
114 |
115 | def get_rank():
116 | """
117 | Get the rank of the current process.
118 | """
119 | if not dist.is_available():
120 | return 0
121 | if not dist.is_initialized():
122 | return 0
123 | return dist.get_rank()
124 |
125 |
126 | def synchronize():
127 | """
128 | Helper function to synchronize (barrier) among all processes when
129 | using distributed training
130 | """
131 | if not dist.is_available():
132 | return
133 | if not dist.is_initialized():
134 | return
135 | world_size = dist.get_world_size()
136 | if world_size == 1:
137 | return
138 | dist.barrier()
139 |
140 |
141 | @functools.lru_cache()
142 | def _get_global_gloo_group():
143 | """
144 | Return a process group based on gloo backend, containing all the ranks
145 | The result is cached.
146 | Returns:
147 | (group): pytorch dist group.
148 | """
149 | if dist.get_backend() == "nccl":
150 | return dist.new_group(backend="gloo")
151 | else:
152 | return dist.group.WORLD
153 |
154 |
155 | def _serialize_to_tensor(data, group):
156 | """
157 |     Serialize arbitrary picklable data to a ByteTensor. Note that only the
158 |     `gloo` and `nccl` backends are supported.
159 | Args:
160 | data (data): data to be serialized.
161 | group (group): pytorch dist group.
162 | Returns:
163 | tensor (ByteTensor): tensor that serialized.
164 | """
165 |
166 | backend = dist.get_backend(group)
167 | assert backend in ["gloo", "nccl"]
168 | device = torch.device("cpu" if backend == "gloo" else "cuda")
169 |
170 | buffer = pickle.dumps(data)
171 | if len(buffer) > 1024 ** 3:
172 | logger = logging.getLogger(__name__)
173 | logger.warning(
174 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
175 | get_rank(), len(buffer) / (1024 ** 3), device
176 | )
177 | )
178 | storage = torch.ByteStorage.from_buffer(buffer)
179 | tensor = torch.ByteTensor(storage).to(device=device)
180 | return tensor
181 |
182 |
183 | def _pad_to_largest_tensor(tensor, group):
184 | """
185 | Padding all the tensors from different GPUs to the largest ones.
186 | Args:
187 | tensor (tensor): tensor to pad.
188 | group (group): pytorch dist group.
189 | Returns:
190 | list[int]: size of the tensor, on each rank
191 | Tensor: padded tensor that has the max size
192 | """
193 | world_size = dist.get_world_size(group=group)
194 | assert (
195 | world_size >= 1
196 | ), "comm.gather/all_gather must be called from ranks within the given group!"
197 | local_size = torch.tensor(
198 | [tensor.numel()], dtype=torch.int64, device=tensor.device
199 | )
200 | size_list = [
201 | torch.zeros([1], dtype=torch.int64, device=tensor.device)
202 | for _ in range(world_size)
203 | ]
204 | dist.all_gather(size_list, local_size, group=group)
205 | size_list = [int(size.item()) for size in size_list]
206 |
207 | max_size = max(size_list)
208 |
209 | # we pad the tensor because torch all_gather does not support
210 | # gathering tensors of different shapes
211 | if local_size != max_size:
212 | padding = torch.zeros(
213 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device
214 | )
215 | tensor = torch.cat((tensor, padding), dim=0)
216 | return size_list, tensor
217 |
218 |
219 | def all_gather_unaligned(data, group=None):
220 | """
221 | Run all_gather on arbitrary picklable data (not necessarily tensors).
222 |
223 | Args:
224 | data: any picklable object
225 | group: a torch process group. By default, will use a group which
226 | contains all ranks on gloo backend.
227 |
228 | Returns:
229 | list[data]: list of data gathered from each rank
230 | """
231 | if get_world_size() == 1:
232 | return [data]
233 | if group is None:
234 | group = _get_global_gloo_group()
235 | if dist.get_world_size(group) == 1:
236 | return [data]
237 |
238 | tensor = _serialize_to_tensor(data, group)
239 |
240 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
241 | max_size = max(size_list)
242 |
243 | # receiving Tensor from all ranks
244 | tensor_list = [
245 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
246 | for _ in size_list
247 | ]
248 | dist.all_gather(tensor_list, tensor, group=group)
249 |
250 | data_list = []
251 | for size, tensor in zip(size_list, tensor_list):
252 | buffer = tensor.cpu().numpy().tobytes()[:size]
253 | data_list.append(pickle.loads(buffer))
254 |
255 | return data_list
256 |
257 |
258 | def init_distributed_training(cfg):
259 | """
260 | Initialize variables needed for distributed training.
261 | """
262 | if cfg.NUM_GPUS <= 1:
263 | return
264 | num_gpus_per_machine = cfg.NUM_GPUS
265 | num_machines = dist.get_world_size() // num_gpus_per_machine
266 | for i in range(num_machines):
267 | ranks_on_i = list(
268 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
269 | )
270 | pg = dist.new_group(ranks_on_i)
271 | if i == cfg.SHARD_ID:
272 | global _LOCAL_PROCESS_GROUP
273 | _LOCAL_PROCESS_GROUP = pg
274 |
275 |
276 | def get_local_size() -> int:
277 | """
278 | Returns:
279 | The size of the per-machine process group,
280 | i.e. the number of processes per machine.
281 | """
282 | if not dist.is_available():
283 | return 1
284 | if not dist.is_initialized():
285 | return 1
286 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
287 |
288 |
289 | def get_local_rank() -> int:
290 | """
291 | Returns:
292 | The rank of the current process within the local (per-machine) process group.
293 | """
294 | if not dist.is_available():
295 | return 0
296 | if not dist.is_initialized():
297 | return 0
298 | assert _LOCAL_PROCESS_GROUP is not None
299 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
300 |
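A minimal single-process sketch of the helpers above, using the gloo backend with world_size=1 (with more ranks the same calls reduce / gather across processes); the TCP port is arbitrary:

    import torch
    import torch.distributed as dist
    import utils.distributed as du

    if not dist.is_initialized():
        dist.init_process_group(backend='gloo', init_method='tcp://localhost:29500',
                                rank=0, world_size=1)

    (reduced,) = du.all_reduce([torch.tensor([2.0, 4.0])], average=True)
    gathered = du.all_gather_unaligned({'rank': du.get_rank(), 'loss': 0.5})
    print(reduced, gathered)  # tensor([2., 4.]) [{'rank': 0, 'loss': 0.5}]
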
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Logging."""
5 |
6 | import builtins
7 | import decimal
8 | import functools
9 | import logging
10 | import os
11 | import sys
12 | import simplejson
13 | # from fvcore.common.file_io import PathManager
14 |
15 | import utils.distributed as du
16 |
17 |
18 | def _suppress_print():
19 | """
20 | Suppresses printing from the current process.
21 | """
22 |
23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
24 | pass
25 |
26 | builtins.print = print_pass
27 |
28 |
29 | # @functools.lru_cache(maxsize=None)
30 | # def _cached_log_stream(filename):
31 | # return PathManager.open(filename, "a")
32 |
33 |
34 | def setup_logging(cfg, log_file):
35 | """
36 | Sets up the logging for multiple processes. Only enable the logging for the
37 | master process, and suppress logging for the non-master processes.
38 | """
39 | if du.is_master_proc():
40 | # Enable logging for the master process.
41 | logging.root.handlers = []
42 | else:
43 | # Suppress logging for non-master processes.
44 | _suppress_print()
45 |
46 | logger = logging.getLogger()
47 | logger.setLevel(logging.INFO)
48 | logger.propagate = False
49 | plain_formatter = logging.Formatter(
50 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
51 | datefmt="%m/%d %H:%M:%S",
52 | )
53 |
54 | if du.is_master_proc():
55 | ch = logging.StreamHandler(stream=sys.stdout)
56 | ch.setLevel(logging.DEBUG)
57 | ch.setFormatter(plain_formatter)
58 | logger.addHandler(ch)
59 |
60 | if log_file is not None and du.is_master_proc(du.get_world_size()):
61 | filename = os.path.join(cfg.OUTPUT_DIR, log_file)
62 | fh = logging.FileHandler(filename)
63 | fh.setLevel(logging.DEBUG)
64 | fh.setFormatter(plain_formatter)
65 | logger.addHandler(fh)
66 |
67 |
68 | def get_logger(name):
69 | """
70 | Retrieve the logger with the specified name or, if name is None, return a
71 | logger which is the root logger of the hierarchy.
72 | Args:
73 | name (string): name of the logger.
74 | """
75 | return logging.getLogger(name)
76 |
77 |
78 | def log_json_stats(stats):
79 | """
80 | Logs json stats.
81 | Args:
82 | stats (dict): a dictionary of statistical information to log.
83 | """
84 | stats = {
85 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
86 | for k, v in stats.items()
87 | }
88 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
89 | logger = get_logger(__name__)
90 | logger.info("{:s}".format(json_stats))
91 |
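A minimal console-only usage sketch of the helpers above; cfg is only read when log_file is given, so None suffices here:

    import utils.logging as logging_utils

    logging_utils.setup_logging(cfg=None, log_file=None)
    logger = logging_utils.get_logger(__name__)
    logger.info('logging initialized')
    logging_utils.log_json_stats({'step': 100, 'loss': 0.1234})
    # prints a line like: [..][INFO] utils.logging: ...: {"loss": 0.123400, "step": 100}
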
--------------------------------------------------------------------------------