├── LICENSE
├── README.md
├── assets
│   └── f.png
├── configs
│   ├── base.yaml
│   ├── i2vgen_xl_infer.yaml
│   ├── i2vgen_xl_train.yaml
│   ├── t2v_infer.yaml
│   ├── t2v_train.yaml
│   └── t2v_train_laion.yaml
├── core
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── attention.cpython-38.pyc
│   │   ├── gs.cpython-38.pyc
│   │   ├── models.cpython-38.pyc
│   │   ├── options.cpython-38.pyc
│   │   ├── unet.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── attention.py
│   ├── gs.py
│   ├── models.py
│   ├── options.py
│   ├── provider_objaverse.py
│   ├── unet.py
│   └── utils.py
├── data
│   ├── images
│   │   ├── demo1.png
│   │   └── demo2.png
│   ├── lvis_thres_28.json
│   ├── stable_diffusion_image_key_temporal_attention_x1.json
│   ├── test_images.txt
│   ├── test_prompts.txt
│   ├── text_captions_cap3d.json
│   └── valid_paths_v4_cap_filter_thres_28_catfilter19w.json
├── inference.py
├── install.sh
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── annotator
│   │   ├── canny
│   │   │   └── __init__.py
│   │   ├── histogram
│   │   │   ├── __init__.py
│   │   │   └── palette.py
│   │   ├── sketch
│   │   │   ├── __init__.py
│   │   │   ├── pidinet.py
│   │   │   └── sketch_simplification.py
│   │   └── util.py
│   ├── basic_funcs
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── pretrain_functions.cpython-310.pyc
│   │   │   └── pretrain_functions.cpython-38.pyc
│   │   └── pretrain_functions.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── image_dataset.cpython-310.pyc
│   │   │   ├── image_dataset.cpython-38.pyc
│   │   │   ├── video_dataset.cpython-310.pyc
│   │   │   ├── video_dataset.cpython-38.pyc
│   │   │   ├── video_i2v_dataset.cpython-310.pyc
│   │   │   └── video_i2v_dataset.cpython-38.pyc
│   │   ├── image_dataset.py
│   │   ├── laion_dataset.py
│   │   ├── video_dataset.py
│   │   └── video_i2v_dataset.py
│   ├── hooks
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── visual_train_it2v_video.cpython-310.pyc
│   │   │   └── visual_train_it2v_video.cpython-38.pyc
│   │   ├── visual_train_it2v_video.py
│   │   └── visual_train_t2v.py
│   ├── inferences
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── inference_i2vgen_entrance.cpython-310.pyc
│   │   │   ├── inference_i2vgen_entrance.cpython-38.pyc
│   │   │   ├── inference_text2video_entrance.cpython-310.pyc
│   │   │   └── inference_text2video_entrance.cpython-38.pyc
│   │   ├── inference_i2vgen_entrance.py
│   │   └── inference_text2video_entrance.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── autoencoder.cpython-310.pyc
│   │   │   ├── autoencoder.cpython-38.pyc
│   │   │   ├── clip_embedder.cpython-310.pyc
│   │   │   ├── clip_embedder.cpython-38.pyc
│   │   │   ├── config.cpython-310.pyc
│   │   │   └── config.cpython-38.pyc
│   │   ├── autoencoder.py
│   │   ├── clip_embedder.py
│   │   ├── config.py
│   │   ├── diffusions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── diffusion_ddim.cpython-310.pyc
│   │   │   │   ├── diffusion_ddim.cpython-38.pyc
│   │   │   │   ├── losses.cpython-310.pyc
│   │   │   │   ├── losses.cpython-38.pyc
│   │   │   │   ├── schedules.cpython-310.pyc
│   │   │   │   └── schedules.cpython-38.pyc
│   │   │   ├── diffusion_ddim.py
│   │   │   ├── losses.py
│   │   │   └── schedules.py
│   │   └── unet
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-310.pyc
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── depthwise_attn.cpython-310.pyc
│   │       │   ├── depthwise_attn.cpython-38.pyc
│   │       │   ├── depthwise_net.cpython-310.pyc
│   │       │   ├── depthwise_net.cpython-38.pyc
│   │       │   ├── depthwise_utils.cpython-310.pyc
│   │       │   ├── depthwise_utils.cpython-38.pyc
│   │       │   ├── unet_i2vgen.cpython-310.pyc
│   │       │   ├── unet_i2vgen.cpython-38.pyc
│   │       │   ├── unet_t2v.cpython-310.pyc
│   │       │   ├── unet_t2v.cpython-38.pyc
│   │       │   ├── util.cpython-310.pyc
│   │       │   └── util.cpython-38.pyc
│   │       ├── mha_flash.py
│   │       ├── unet_i2vgen.py
│   │       ├── unet_t2v.py
│   │       └── util.py
│   └── train
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── train_i2v_enterance.cpython-310.pyc
│       │   ├── train_i2v_enterance.cpython-38.pyc
│       │   ├── train_t2v_enterance.cpython-310.pyc
│       │   └── train_t2v_enterance.cpython-38.pyc
│       ├── prev_t2v.py
│       ├── train_i2v_enterance.py
│       └── train_t2v_enterance.py
├── train_net.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-310.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── assign_cfg.cpython-310.pyc
    │   ├── assign_cfg.cpython-38.pyc
    │   ├── camera_utils.cpython-310.pyc
    │   ├── camera_utils.cpython-38.pyc
    │   ├── config.cpython-310.pyc
    │   ├── config.cpython-38.pyc
    │   ├── distributed.cpython-310.pyc
    │   ├── distributed.cpython-38.pyc
    │   ├── logging.cpython-310.pyc
    │   ├── logging.cpython-38.pyc
    │   ├── multi_port.cpython-310.pyc
    │   ├── multi_port.cpython-38.pyc
    │   ├── registry.cpython-310.pyc
    │   ├── registry.cpython-38.pyc
    │   ├── registry_class.cpython-310.pyc
    │   ├── registry_class.cpython-38.pyc
    │   ├── seed.cpython-310.pyc
    │   ├── seed.cpython-38.pyc
    │   ├── transforms.cpython-310.pyc
    │   ├── transforms.cpython-38.pyc
    │   ├── util.cpython-310.pyc
    │   ├── util.cpython-38.pyc
    │   ├── video_op.cpython-310.pyc
    │   └── video_op.cpython-38.pyc
    ├── assign_cfg.py
    ├── camera_utils.py
    ├── config.py
    ├── distributed.py
    ├── logging.py
    ├── multi_port.py
    ├── optim
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── __init__.cpython-38.pyc
    │   │   ├── adafactor.cpython-310.pyc
    │   │   ├── adafactor.cpython-38.pyc
    │   │   ├── lr_scheduler.cpython-310.pyc
    │   │   └── lr_scheduler.cpython-38.pyc
    │   ├── adafactor.py
    │   └── lr_scheduler.py
    ├── recenter_i2v.py
    ├── registry.py
    ├── registry_class.py
    ├── seed.py
    ├── transforms.py
    ├── util.py
    └── video_op.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Alibaba
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## VideoMV: Consistent Multi-View Generation Based on Large Video Generative Model.
2 |
3 | [Qi Zuo\*](https://scholar.google.com/citations?view_op=list_works&hl=en&user=UDnHe2IAAAAJ),
4 | [Xiaodong Gu\*](https://scholar.google.com.hk/citations?user=aJPO514AAAAJ&hl=zh-CN&oi=ao),
5 | [Lingteng Qiu](https://lingtengqiu.github.io/),
6 | [Yuan Dong](mailto:dy283090@alibaba-inc.com),
7 | [Zhengyi Zhao](mailto:bushe.zzy@alibaba-inc.com),
8 | [Weihao Yuan](https://weihao-yuan.com/),
9 | [Rui Peng](https://prstrive.github.io/),
10 | [Siyu Zhu](https://sites.google.com/site/zhusiyucs/home/),
11 | [Zilong Dong](https://scholar.google.com/citations?user=GHOQKCwAAAAJ&hl=zh-CN&oi=ao),
12 | [Liefeng Bo](https://research.cs.washington.edu/istc/lfb/),
13 | [Qixing Huang](https://www.cs.utexas.edu/~huangqx/)
14 |
15 | https://github.com/alibaba/VideoMV/assets/58206232/3a78e28d-bda4-4d4c-a2ae-994d0320a301
16 |
17 | ## [Project page](https://aigc3d.github.io/VideoMV) | [Paper](https://arxiv.org/abs/2403.12010) | [YouTube](https://www.youtube.com/watch?v=zxjX5p0p0Ks) | [3D Rendering Dataset](https://aigc3d.github.io/gobjaverse)
18 |
19 | ## TODO :triangular_flag_on_post:
20 | - [ ] Release GS, NeuS, and NeRF reconstruction code.
21 | - [x] News: Released the text-to-mv (G-Objaverse + LAION) training code and pretrained model (2024.04.22). Check the Inference & Training guidelines.
22 |
23 | Generated Multi-View Images using prompts from DreamFusion420:
24 |
25 | https://github.com/alibaba/VideoMV/assets/58206232/3a4e84e9-a4b2-4ecc-a3e8-7a898e6c3f1a
26 |
27 |
28 | - [x] Release the training code.
29 | - [x] Release multi-view inference code and pretrained weights (G-Objaverse).
30 |
31 | ## Architecture
32 |
33 | 
34 |
35 | ## Install
36 |
37 | - System requirement: Ubuntu 20.04
38 | - Tested GPUs: A100
39 |
40 | Install the requirements using the following script.
41 |
42 | ```bash
43 | git clone https://github.com/alibaba/VideoMV.git
44 | conda create -n VideoMV python=3.8
45 | conda activate VideoMV
46 | cd VideoMV && bash install.sh
47 | ```
48 |
49 | ## Inference
50 |
51 | ```bash
52 | # Download our pretrained models
53 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/pretrained_models.zip
54 | unzip pretrained_models.zip
55 | # text-to-mv sampling
56 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/t2v_infer.yaml
57 | # text-to-mv sampling using the pretrained model trained on LAION + G-Objaverse
58 | wget oss://virutalbuy-public/share/aigc3d/videomv_laion/non_ema_00365000.pth
59 | # modify [test_model] in the config to point to the location of [non_ema_00365000.pth]
60 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/t2v_infer.yaml
61 |
62 |
63 | # image-to-mv sampling
64 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/i2vgen_xl_infer.yaml
65 |
66 | # To test raw prompts: add the prompts to ./data/test_prompts.txt
67 |
68 | # To test raw images: use Background-Remover (https://www.remove.bg/) to extract the foreground of each image,
69 | # then place all the images in /path/to/your_dir
70 | # Then run
71 | python -m utils.recenter_i2v /path/to/your_dir
72 | # The recentered results will be saved in ./data/images
73 | # Add the test image paths to ./data/test_images.txt
74 | # Then run
75 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/i2vgen_xl_infer.yaml
76 | ```
77 |
78 | ## Training
79 |
80 | ```bash
81 | # Download our dataset (G-Objaverse) following the instructions at
82 | # https://github.com/modelscope/richdreamer/tree/main/dataset/gobjaverse
83 | # Modify vid_dataset.data_dir_list to point to your downloaded data root
84 | # in ./configs/t2v_train.yaml and ./configs/i2vgen_xl_train.yaml
85 |
86 | # Text-to-mv finetuning
87 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train.yaml
88 | # Text-to-mv finetuning using both LAION and G-Objaverse.
89 | # (Note: we use 24 A100 GPUs when training on both datasets. If your compute resources are not sufficient, do not try it!)
90 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train_laion.yaml
91 |
92 | # Text-to-mv Feed-forward reconstruction finetuning.
93 | # Set UNet.use_lgm_refine to 'True' in ./configs/t2v_train.yaml. Then
94 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train.yaml
95 |
96 |
97 | # Image-to-mv finetuning
98 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/i2vgen_xl_train.yaml
99 | # Image-to-mv Feed-forward reconstruction finetuning.
100 | # Set UNet.use_lgm_refine to 'True' in ./configs/i2vgen_xl_train.yaml. Then
101 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/i2vgen_xl_train.yaml
102 | ```
103 |
104 | ## Tips
105 |
106 | - You will observe a sudden convergence in Text-to-MV finetuning (~5 min).
107 |
108 | - You will not observe a sudden convergence in Image-to-MV finetuning. It usually takes half a day to reach an initial convergence.
109 |
110 | - Remove the background of the test image using [Background-Remover](https://www.remove.bg/) instead of rembg for better results. Segmentation-mask artifacts will degrade the quality of the multi-view generation results.
111 |
112 | ## Future Works
113 |
114 | - Dense View Large Reconstruction Model.
115 |
116 | - More general and higher-quality Text-to-MV using a better video diffusion model (e.g., HiGen) and novel finetuning techniques.
117 |
118 | ## Acknowledgement
119 |
120 | This work is built on many amazing research works and open-source projects:
121 |
122 | - [VGen](https://github.com/ali-vilab/VGen)
123 | - [LGM](https://github.com/3DTopia/LGM)
124 | - [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer)
125 | - [GaussianSplatting](https://github.com/graphdeco-inria/gaussian-splatting)
126 |
127 | Thanks for their excellent work and great contributions to the 3D generation area.
128 |
129 | We would like to express our special gratitude to [Jiaxiang Tang](https://github.com/ashawkey) and [Yuan Liu](https://github.com/liuyuan-pal) for the valuable discussions on LGM and SyncDreamer.
130 |
131 |
132 | ## Citation
133 |
134 | ```
135 | @misc{zuo2024videomv,
136 | title={VideoMV: Consistent Multi-View Generation Based on Large Video Generative Model},
137 | author={Qi Zuo and Xiaodong Gu and Lingteng Qiu and Yuan Dong and Zhengyi Zhao and Weihao Yuan and Rui Peng and Siyu Zhu and Zilong Dong and Liefeng Bo and Qixing Huang},
138 | year={2024},
139 | eprint={2403.12010},
140 | archivePrefix={arXiv},
141 | primaryClass={cs.CV}
142 | }
143 | ```
144 |
--------------------------------------------------------------------------------
/assets/f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/assets/f.png
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | ENABLE: true
2 | DATASET: webvid10m
--------------------------------------------------------------------------------
/configs/i2vgen_xl_infer.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: inference_i2vgen_entrance
2 | use_fp16: True
3 | guide_scale: 6.0
4 | use_fp16: True
5 | chunk_size: 2
6 | decoder_bs: 2
7 | max_frames: 24
8 | target_fps: 8 # FPS Conditions, not the encoding fps
9 | scale: 8
10 | seed: 9999
11 | round: 4
12 | batch_size: 1
13 | use_zero_infer: True
14 |
15 | # For important input
16 | vldm_cfg: configs/i2vgen_xl_train.yaml
17 | test_list_path: data/test_images.txt
18 | test_model: ./pretrained_models/i2v_00882000.pth
19 | log_dir: "workspace/visualization/i2v"
20 |
21 | UNet: {
22 | 'type': 'UNetSD_I2VGen',
23 | 'in_dim': 4,
24 | 'y_dim': 1024,
25 | 'upper_len': 128,
26 | 'context_dim': 1024,
27 | 'concat_dim': 4,
28 | 'out_dim': 4,
29 | 'dim_mult': [1, 2, 4, 4],
30 | 'num_heads': 8,
31 | 'default_fps': 8,
32 | 'head_dim': 64,
33 | 'num_res_blocks': 2,
34 | 'dropout': 0.1,
35 | 'temporal_attention': True,
36 | 'temporal_attn_times': 1,
37 | 'use_checkpoint': True,
38 | 'use_fps_condition': False,
39 | 'use_camera_condition': True,
40 |     'use_lgm_refine': True, # Turn this off if you want to simply finetune a naive i2vgen-xl
41 | 'use_sim_mask': False
42 | }
--------------------------------------------------------------------------------
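
The inference entrance reads this file via `--cfg` (see `inference.py` and `utils/config.py`). For quick experiments it can be convenient to inspect or override a key such as `test_model` programmatically before launching; the sketch below is a minimal, hypothetical helper using PyYAML (an assumption — the repo's own `Config` class does its own parsing):

```python
# Minimal sketch (assumes PyYAML): load the inference config, point `test_model`
# at a different checkpoint, and write a local copy to pass to `inference.py --cfg`.
# The output filename is hypothetical; the repo itself reads the YAML via utils/config.py.
import yaml

with open("configs/i2vgen_xl_infer.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["guide_scale"], cfg["max_frames"], cfg["UNet"]["type"])

cfg["test_model"] = "./pretrained_models/i2v_00882000.pth"  # replace with your checkpoint path
with open("configs/i2vgen_xl_infer_local.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```
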
/configs/i2vgen_xl_train.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: train_i2v_entrance
2 | ENABLE: true
3 | use_ema: true
4 | num_workers: 6
5 | frame_lens: [24]
6 | sample_fps: [8]
7 | resolution: [256, 256]
8 | vit_resolution: [224, 224]
9 |
10 | lgm_pretrain: './pretrained_models/model.safetensors'
11 |
12 | vid_dataset: {
13 | 'type': 'Video_I2V_Dataset',
14 | 'data_list': ['./data/valid_paths_v4_cap_filter_thres_28_catfilter19w.json', ],
15 | 'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ],
16 | 'caption_dir': './data/text_captions_cap3d.json',
17 | 'vit_resolution': [224, 224],
18 | 'resolution': [256, 256],
19 | 'get_first_frame': True,
20 | 'max_words': 1000,
21 | 'prepare_lgm': True,
22 | }
23 |
24 | img_dataset: {
25 | 'type': 'ImageDataset',
26 | 'data_list': ['data/img_list.txt', ],
27 | 'data_dir_list': ['data/images', ],
28 | 'vit_resolution': [224, 224],
29 | 'resolution': [256, 256],
30 | 'max_words': 1000
31 | }
32 |
33 | embedder: {
34 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder',
35 | 'layer': 'penultimate',
36 | 'vit_resolution': [224, 224],
37 | 'pretrained': './pretrained_models/modelscope_i2v/I2VGen-XL/open_clip_pytorch_model.bin'
38 | }
39 |
40 | UNet: {
41 | 'type': 'UNetSD_I2VGen',
42 | 'in_dim': 4,
43 | 'y_dim': 1024,
44 | 'upper_len': 128,
45 | 'context_dim': 1024,
46 | 'concat_dim': 4,
47 | 'out_dim': 4,
48 | 'dim_mult': [1, 2, 4, 4],
49 | 'num_heads': 8,
50 | 'default_fps': 8,
51 | 'head_dim': 64,
52 | 'num_res_blocks': 2,
53 | 'dropout': 0.1,
54 | 'temporal_attention': True,
55 | 'temporal_attn_times': 1,
56 | 'use_checkpoint': True,
57 | 'use_fps_condition': False,
58 | 'use_camera_condition': True,
59 |     'use_lgm_refine': False, # Turn this off if you want to simply finetune a naive i2vgen-xl
60 | 'use_sim_mask': False
61 | }
62 |
63 | Diffusion: {
64 | 'type': 'DiffusionDDIM',
65 | 'schedule': 'cosine', # cosine
66 | 'schedule_param': {
67 | 'num_timesteps': 1000,
68 | 'cosine_s': 0.008,
69 | 'zero_terminal_snr': True,
70 | },
71 | 'mean_type': 'v',
72 | 'loss_type': 'mse',
73 | 'var_type': 'fixed_small',
74 | 'rescale_timesteps': False,
75 | 'noise_strength': 0.1
76 | }
77 |
78 | batch_sizes: {
79 | "24": 8,
80 | }
81 |
82 | visual_train: {
83 | 'type': 'VisualTrainTextImageToVideo',
84 | 'partial_keys': [
85 | ['y', 'image', 'local_image', 'fps', 'camera_data', 'gs_data']
86 | ],
87 | 'use_offset_noise': True,
88 | 'guide_scale': 6.0,
89 | }
90 |
91 | Pretrain: {
92 | 'type': pretrain_specific_strategies,
93 | 'fix_weight': False,
94 | 'grad_scale': 0.5,
95 | 'resume_checkpoint': './pretrained_models/modelscope_i2v/I2VGen-XL/i2vgen_xl_00854500.pth',
96 | 'sd_keys_path': './pretrained_models/modelscope_i2v/I2VGen-XL/stable_diffusion_image_key_temporal_attention_x1.json',
97 | }
98 |
99 | chunk_size: 4
100 | decoder_bs: 4
101 | lr: 0.00003
102 |
103 | noise_strength: 0.1
104 | # classifier-free guidance
105 | p_zero: 0.0
106 | guide_scale: 3.0
107 | num_steps: 1000000
108 |
109 | use_zero_infer: True
110 | viz_interval: 200 # 200
111 | save_ckp_interval: 500 # 500
112 |
113 | # Log
114 | log_dir: "workspace/experiments_i2v"
115 | log_interval: 1
116 | seed: 6666
117 |
--------------------------------------------------------------------------------
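
The `Diffusion` block above selects a cosine noise schedule (`cosine_s: 0.008`), v-prediction, and zero terminal SNR. The repo's authoritative implementation lives in `tools/modules/diffusions/schedules.py` (not reproduced in this dump); the following is only a sketch of the standard cosine alpha-bar schedule from improved DDPM, under the assumption that `schedule: 'cosine'` follows that formulation:

```python
# Standard cosine alpha-bar schedule (Nichol & Dhariwal, improved DDPM).
# Assumption: the repo's `schedule: 'cosine'` with `cosine_s: 0.008` follows this
# formulation; the authoritative code is tools/modules/diffusions/schedules.py.
import math
import torch

def cosine_betas(num_timesteps: int = 1000, cosine_s: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    def alpha_bar(t: float) -> float:
        return math.cos((t + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
    betas = []
    for i in range(num_timesteps):
        t1, t2 = i / num_timesteps, (i + 1) / num_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)

betas = cosine_betas()
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
# zero_terminal_snr: True would additionally rescale so the final alpha_bar is exactly zero.
print(betas.shape, float(alphas_cumprod[-1]))
```
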
/configs/t2v_infer.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: inference_text2video_entrance
2 | use_fp16: False
3 | guide_scale: 9.0
4 | chunk_size: 4
5 | decoder_bs: 4
6 | max_frames: 24
7 | target_fps: 8 # FPS Conditions, not encoding fps
8 | scale: 8
9 | batch_size: 1
10 | use_zero_infer: True
11 |
12 | round: 2
13 | seed: 11
14 |
15 | test_list_path: ./data/test_prompts.txt
16 | vldm_cfg: configs/t2v_train.yaml
17 | test_model: ./pretrained_models/t2v_00333000.pth
18 | log_dir: ./workspace/visualization/t2v
19 |
20 | UNet: {
21 | 'type': 'UNetSD_T2VBase',
22 | 'in_dim': 4,
23 | 'y_dim': 1024,
24 | 'upper_len': 128,
25 | 'context_dim': 1024,
26 | 'out_dim': 4,
27 | 'dim_mult': [1, 2, 4, 4],
28 | 'num_heads': 8,
29 | 'default_fps': 8,
30 | 'head_dim': 64,
31 | 'num_res_blocks': 2,
32 | 'dropout': 0.1,
33 | 'misc_dropout': 0.4,
34 | 'temporal_attention': True,
35 | 'temporal_attn_times': 1,
36 | 'use_checkpoint': True,
37 | 'use_fps_condition': False,
38 |     'use_camera_condition': True, # Turn this off if you train on multi-view images with fixed poses.
39 | 'use_lgm_refine': True,
40 | 'use_sim_mask': False
41 | }
--------------------------------------------------------------------------------
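
`guide_scale: 9.0` is the classifier-free guidance weight applied at sampling time. The combination itself happens inside the DDIM sampler (`tools/modules/diffusions/diffusion_ddim.py`, not shown here); the sketch below is only the textbook CFG formula, with toy latent shapes implied by this config (4 channels, 24 frames):

```python
# Textbook classifier-free guidance: combine conditional and unconditional model
# outputs with a guidance weight. Illustration only; the repo applies this inside
# its DDIM sampler.
import torch

def classifier_free_guidance(out_cond: torch.Tensor, out_uncond: torch.Tensor, guide_scale: float = 9.0) -> torch.Tensor:
    return out_uncond + guide_scale * (out_cond - out_uncond)

# toy shapes: [batch, channels=4, frames=24, h, w] latents, matching this config
out_cond = torch.randn(1, 4, 24, 32, 32)
out_uncond = torch.randn(1, 4, 24, 32, 32)
guided = classifier_free_guidance(out_cond, out_uncond)
print(guided.shape)
```
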
/configs/t2v_train.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: train_t2v_entrance
2 | ENABLE: true
3 | use_ema: false
4 | num_workers: 10
5 | frame_lens: [24]
6 | sample_fps: [8]
7 | resolution: [256, 256]
8 | vit_resolution: [224, 224]
9 | lgm_pretrain: './pretrained_models/model.safetensors'
10 |
11 | vid_dataset: {
12 | 'type': 'VideoDataset',
13 | 'data_list': ['./data/lvis_thres_28.json', ],
14 | 'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ],
15 | 'caption_dir': './data/text_captions_cap3d.json',
16 | 'vit_resolution': [224, 224],
17 | 'resolution': [256, 256],
18 | 'get_first_frame': True,
19 | 'max_words': 1000,
20 | 'prepare_lgm': True,
21 | }
22 |
23 | img_dataset: {
24 | 'type': 'ImageDataset',
25 | 'data_list': ['data/img_list.txt', ],
26 | 'data_dir_list': ['data/images', ],
27 | 'vit_resolution': [224, 224],
28 | 'resolution': [256, 256],
29 | 'max_words': 1000
30 | }
31 | embedder: {
32 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder',
33 | 'layer': 'penultimate',
34 | 'vit_resolution': [224, 224],
35 | 'pretrained': './pretrained_models/modelscope_t2v/open_clip_pytorch_model.bin'
36 | }
37 |
38 | UNet: {
39 | 'type': 'UNetSD_T2VBase',
40 | 'in_dim': 4,
41 | 'y_dim': 1024,
42 | 'upper_len': 128,
43 | 'context_dim': 1024,
44 | 'out_dim': 4,
45 | 'dim_mult': [1, 2, 4, 4],
46 | 'num_heads': 8,
47 | 'default_fps': 8,
48 | 'head_dim': 64,
49 | 'num_res_blocks': 2,
50 | 'dropout': 0.1,
51 | 'misc_dropout': 0.4,
52 | 'temporal_attention': True,
53 | 'temporal_attn_times': 1,
54 | 'use_checkpoint': True,
55 | 'use_fps_condition': False,
56 |     'use_camera_condition': True, # Turn this off if you train on multi-view images with fixed poses.
57 | 'use_lgm_refine': False,
58 | 'use_sim_mask': False
59 | }
60 |
61 | Diffusion: {
62 | 'type': 'DiffusionDDIM',
63 | 'schedule': 'linear_sd', # cosine
64 | 'schedule_param': {
65 | 'num_timesteps': 1000,
66 | 'init_beta': 0.00085,
67 | 'last_beta': 0.0120,
68 | 'zero_terminal_snr': False,
69 | },
70 | 'mean_type': 'eps', # eps for baseline with no lgm reg
71 | 'loss_type': 'mse',
72 | 'var_type': 'fixed_small',
73 | 'rescale_timesteps': False,
74 | 'noise_strength': 0.0
75 | }
76 |
77 | batch_sizes: {
78 | "1": 32,
79 | "24": 8,
80 | }
81 |
82 | visual_train: {
83 | 'type': 'VisualTrainTextImageToVideo',
84 | 'partial_keys': [
85 | ['y', 'fps', 'camera_data', 'gs_data'],
86 | ],
87 | 'use_offset_noise': False,
88 | 'guide_scale': 9.0,
89 | }
90 |
91 | Pretrain: {
92 | 'type': pretrain_specific_strategies,
93 | 'fix_weight': False,
94 | 'grad_scale': 0.5,
95 | 'resume_checkpoint': './pretrained_models/modelscope_t2v/model_scope_0267000.pth',
96 | 'sd_keys_path': 'data/stable_diffusion_image_key_temporal_attention_x1.json',
97 | }
98 |
99 | chunk_size: 4
100 | decoder_bs: 4
101 | lr: 0.00003 # 0.00003
102 |
103 | noise_strength: 0.0 # no noise
104 | # classifier-free guidance
105 | p_zero: 0.1
106 | guide_scale: 3.0
107 | num_steps: 1000000
108 |
109 | use_zero_infer: True
110 | viz_interval: 50 # 200
111 | save_ckp_interval: 500 # 500
112 |
113 | # Log
114 | log_dir: "workspace/experiment_t2v"
115 | log_interval: 1
116 | seed: 0
--------------------------------------------------------------------------------
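
Relative to the i2vgen config, this one switches to a `linear_sd` schedule (`init_beta: 0.00085`, `last_beta: 0.0120`) with `eps` prediction. As an assumption about what `linear_sd` denotes (the real code is in `tools/modules/diffusions/schedules.py`), the usual Stable-Diffusion-style scaled-linear construction looks like this:

```python
# Scaled-linear ("linear_sd") beta schedule as commonly used by Stable Diffusion:
# linear in sqrt(beta) between init_beta and last_beta. Assumption about the
# repo's 'linear_sd'; see tools/modules/diffusions/schedules.py for the real one.
import torch

def linear_sd_betas(num_timesteps: int = 1000, init_beta: float = 0.00085, last_beta: float = 0.0120) -> torch.Tensor:
    return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2

betas = linear_sd_betas()
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(float(betas[0]), float(betas[-1]), float(alphas_cumprod[-1]))
```
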
/configs/t2v_train_laion.yaml:
--------------------------------------------------------------------------------
1 | TASK_TYPE: train_t2v_entrance
2 | ENABLE: true
3 | use_ema: false
4 | num_workers: 4
5 | frame_lens: [1, 24, 24, 24, 24, 24, 24, 24]
6 | sample_fps: [1, 8, 8, 8, 8, 8, 8, 8]
7 | resolution: [256, 256]
8 | vit_resolution: [224, 224]
9 | lgm_pretrain: './pretrained_models/model.safetensors'
10 |
11 | vid_dataset: {
12 | 'type': 'VideoDataset',
13 | # 'data_list': ['/mnt/cap/muyuan/code/StableVideoDiffusion/StableVideoDiffusion/valid_paths_v4_cap_filter_thres_28.json', ],
14 | 'data_list': ['./data/lvis_thres_28.json', ],
15 | 'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ],
16 | 'caption_dir': './data/text_captions_cap3d.json',
17 | 'vit_resolution': [224, 224],
18 | 'resolution': [256, 256],
19 | 'get_first_frame': True,
20 | 'max_words': 1000,
21 | 'prepare_lgm': False,
22 | }
23 |
24 | img_dataset: {
25 | 'type': 'LAIONImageDataset',
26 | 'data_list': ['{00000..60580}.tar', ],
27 | 'data_dir_list': ['/mnt/laion/dataset/laion2b-en-ath5/improved_aesthetics_5plus/laion-2ben-5_0/', ],
28 | 'vit_resolution': [224, 224],
29 | 'resolution': [256, 256],
30 | 'max_words': 1000,
31 | }
32 |
33 | embedder: {
34 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder',
35 | 'layer': 'penultimate',
36 | 'vit_resolution': [224, 224],
37 | 'pretrained': './pretrained_models/modelscope_t2v/open_clip_pytorch_model.bin'
38 | }
39 |
40 | UNet: {
41 | 'type': 'UNetSD_T2VBase',
42 | 'in_dim': 4,
43 | 'y_dim': 1024,
44 | 'upper_len': 128,
45 | 'context_dim': 1024,
46 | 'out_dim': 4,
47 | 'dim_mult': [1, 2, 4, 4],
48 | 'num_heads': 8,
49 | 'default_fps': 8,
50 | 'head_dim': 64,
51 | 'num_res_blocks': 2,
52 | 'dropout': 0.1,
53 | 'misc_dropout': 0.4,
54 | 'temporal_attention': True,
55 | 'temporal_attn_times': 1,
56 | 'use_checkpoint': True,
57 | 'use_fps_condition': False,
58 |     'use_camera_condition': True, # Turn this off if you train on multi-view images with fixed poses.
59 |     'use_sync_attention': False, # Turn this off if you do not wish to use SyncAttention.
60 |     'use_flexicube_reg': False, # Turn this off if you do not wish to use a 3D regularization.
61 |     'use_lgm_reg': False, # Turn this off if you do not wish to use an LGM regularization.
62 | 'use_lgm_refine': False,
63 | 'use_sim_mask': False
64 | }
65 | # Diffusion: {
66 | # 'type': 'DiffusionDDIM',
67 | # 'schedule': 'cosine', # cosine
68 | # 'schedule_param': {
69 | # 'num_timesteps': 1000,
70 | # 'cosine_s': 0.008,
71 | # 'zero_terminal_snr': True,
72 | # },
73 | # 'mean_type': 'v',
74 | # 'loss_type': 'mse',
75 | # 'var_type': 'fixed_small',
76 | # 'rescale_timesteps': False,
77 | # 'noise_strength': 0.1
78 | # }
79 |
80 | Diffusion: {
81 | 'type': 'DiffusionDDIM',
82 | 'schedule': 'linear_sd', # cosine
83 | 'schedule_param': {
84 | 'num_timesteps': 1000,
85 | 'init_beta': 0.00085,
86 | 'last_beta': 0.0120,
87 | 'zero_terminal_snr': False,
88 | },
89 | 'mean_type': 'eps', # eps for baseline with no lgm reg
90 | 'loss_type': 'mse',
91 | 'var_type': 'fixed_small',
92 | 'rescale_timesteps': False,
93 | 'noise_strength': 0.0
94 | }
95 |
96 | batch_sizes: {
97 | "1": 196,
98 | "24": 24,
99 | }
100 |
101 | visual_train: {
102 | 'type': 'VisualTrainTextImageToVideo',
103 | 'partial_keys': [
104 | ['y', 'fps', 'camera_data', 'gs_data'],
105 | ],
106 | 'use_offset_noise': False,
107 | 'guide_scale': 9.0,
108 | }
109 |
110 | Pretrain: {
111 | 'type': pretrain_specific_strategies,
112 | 'fix_weight': False,
113 | 'grad_scale': 0.5,
114 |     'resume_checkpoint': './pretrained_models/t2v_00333000.pth',
115 | 'sd_keys_path': 'data/stable_diffusion_image_key_temporal_attention_x1.json',
116 | }
117 |
118 | chunk_size: 4
119 | decoder_bs: 4
120 | lr: 0.00003 # 0.00003
121 |
122 | noise_strength: 0.0 # no noise
123 | # classifier-free guidance
124 | p_zero: 0.1
125 | guide_scale: 3.0
126 | num_steps: 1000000
127 |
128 | use_zero_infer: True
129 | viz_interval: 200 # 200
130 | save_ckp_interval: 500 # 500
131 |
132 | # Log
133 | log_dir: "workspace/experiments_laion"
134 | log_interval: 1
135 | seed: 0
--------------------------------------------------------------------------------
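
Compared with `t2v_train.yaml`, this config interleaves single-frame LAION batches (`frame_lens` entry `1`, batch size 196) with 24-frame G-Objaverse batches (batch size 24). The sketch below only illustrates how `frame_lens`, `sample_fps`, and `batch_sizes` plausibly fit together; the actual sampling logic lives in the training entrances under `tools/train/`, which are not reproduced here:

```python
# Illustration of how frame_lens / sample_fps / batch_sizes plausibly relate:
# each step picks one frame_lens entry (1-frame image batch vs. 24-frame video
# batch) and looks up the batch size by frame length. Assumption about the
# training loop; the real logic lives under tools/train/.
import random

frame_lens = [1, 24, 24, 24, 24, 24, 24, 24]   # from this config: 1/8 image steps, 7/8 video steps
sample_fps = [1, 8, 8, 8, 8, 8, 8, 8]
batch_sizes = {"1": 196, "24": 24}

for step in range(4):
    idx = random.randrange(len(frame_lens))
    frames, fps = frame_lens[idx], sample_fps[idx]
    bs = batch_sizes[str(frames)]
    print(f"step {step}: {frames} frame(s) @ {fps} fps, batch size {bs}")
```
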
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__init__.py
--------------------------------------------------------------------------------
/core/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/gs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/gs.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/options.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/core/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | import os
11 | import warnings
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
17 | try:
18 | if XFORMERS_ENABLED:
19 | from xformers.ops import memory_efficient_attention, unbind
20 |
21 | XFORMERS_AVAILABLE = True
22 | warnings.warn("xFormers is available (Attention)")
23 | else:
24 | warnings.warn("xFormers is disabled (Attention)")
25 | raise ImportError
26 | except ImportError:
27 | XFORMERS_AVAILABLE = False
28 | warnings.warn("xFormers is not available (Attention)")
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(
33 | self,
34 | dim: int,
35 | num_heads: int = 8,
36 | qkv_bias: bool = False,
37 | proj_bias: bool = True,
38 | attn_drop: float = 0.0,
39 | proj_drop: float = 0.0,
40 | ) -> None:
41 | super().__init__()
42 | self.num_heads = num_heads
43 | head_dim = dim // num_heads
44 | self.scale = head_dim**-0.5
45 |
46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
47 | self.attn_drop = nn.Dropout(attn_drop)
48 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
49 | self.proj_drop = nn.Dropout(proj_drop)
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | B, N, C = x.shape
53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
54 |
55 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
56 | attn = q @ k.transpose(-2, -1)
57 |
58 | attn = attn.softmax(dim=-1)
59 | attn = self.attn_drop(attn)
60 |
61 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
62 | x = self.proj(x)
63 | x = self.proj_drop(x)
64 | return x
65 |
66 |
67 | class MemEffAttention(Attention):
68 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
69 | if not XFORMERS_AVAILABLE:
70 | if attn_bias is not None:
71 | raise AssertionError("xFormers is required for using nested tensors")
72 | return super().forward(x)
73 |
74 | B, N, C = x.shape
75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
76 |
77 | q, k, v = unbind(qkv, 2)
78 |
79 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
80 | x = x.reshape([B, N, C])
81 |
82 | x = self.proj(x)
83 | x = self.proj_drop(x)
84 | return x
85 |
86 |
87 | class CrossAttention(nn.Module):
88 | def __init__(
89 | self,
90 | dim: int,
91 | dim_q: int,
92 | dim_k: int,
93 | dim_v: int,
94 | num_heads: int = 8,
95 | qkv_bias: bool = False,
96 | proj_bias: bool = True,
97 | attn_drop: float = 0.0,
98 | proj_drop: float = 0.0,
99 | ) -> None:
100 | super().__init__()
101 | self.dim = dim
102 | self.num_heads = num_heads
103 | head_dim = dim // num_heads
104 | self.scale = head_dim**-0.5
105 |
106 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
107 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
108 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
109 | self.attn_drop = nn.Dropout(attn_drop)
110 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
111 | self.proj_drop = nn.Dropout(proj_drop)
112 |
113 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
114 | # q: [B, N, Cq]
115 | # k: [B, M, Ck]
116 | # v: [B, M, Cv]
117 | # return: [B, N, C]
118 |
119 | B, N, _ = q.shape
120 | M = k.shape[1]
121 |
122 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh]
123 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
124 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
125 |
126 | attn = q @ k.transpose(-2, -1) # [B, nh, N, M]
127 |
128 | attn = attn.softmax(dim=-1) # [B, nh, N, M]
129 | attn = self.attn_drop(attn)
130 |
131 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
132 | x = self.proj(x)
133 | x = self.proj_drop(x)
134 | return x
135 |
136 |
137 | class MemEffCrossAttention(CrossAttention):
138 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
139 | if not XFORMERS_AVAILABLE:
140 | if attn_bias is not None:
141 | raise AssertionError("xFormers is required for using nested tensors")
142 |             return super().forward(q, k, v)
143 |
144 | B, N, _ = q.shape
145 | M = k.shape[1]
146 |
147 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh]
148 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
149 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
150 |
151 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
152 | x = x.reshape(B, N, -1)
153 |
154 | x = self.proj(x)
155 | x = self.proj_drop(x)
156 | return x
157 |
--------------------------------------------------------------------------------
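
A small shape-level usage sketch for the modules above. The plain `Attention`/`CrossAttention` variants are used so it runs on CPU without xformers; `MemEffAttention`/`MemEffCrossAttention` additionally require xformers:

```python
# Shape-level usage of Attention and CrossAttention from core/attention.py.
# Runs on CPU and does not require xformers (the memory-efficient variants do).
import torch
from core.attention import Attention, CrossAttention

self_attn = Attention(dim=256, num_heads=8, qkv_bias=True)
x = torch.randn(2, 64, 256)            # [B, N, C]
print(self_attn(x).shape)              # torch.Size([2, 64, 256])

cross_attn = CrossAttention(dim=256, dim_q=128, dim_k=512, dim_v=512, num_heads=8)
q = torch.randn(2, 64, 128)            # [B, N, Cq]
kv = torch.randn(2, 77, 512)           # [B, M, Ck] == [B, M, Cv]
print(cross_attn(q, kv, kv).shape)     # torch.Size([2, 64, 256])
```
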
/core/gs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from diff_gaussian_rasterization import (
8 | GaussianRasterizationSettings,
9 | GaussianRasterizer,
10 | )
11 |
12 | from core.options import Options
13 |
14 | import kiui
15 |
16 | class GaussianRenderer:
17 | def __init__(self, opt: Options):
18 |
19 | self.opt = opt
20 | self.bg_color = torch.tensor([1,1,1], dtype=torch.float32, device="cuda")
21 |
22 | # intrinsics
23 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
24 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
25 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov
26 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov
27 | self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
28 | self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
29 | self.proj_matrix[2, 3] = 1
30 |
31 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
32 | # gaussians: [B, N, 14]
33 | # cam_view, cam_view_proj: [B, V, 4, 4]
34 | # cam_pos: [B, V, 3]
35 |
36 | device = gaussians.device
37 | B, V = cam_view.shape[:2]
38 |
39 | # loop of loop...
40 | images = []
41 | alphas = []
42 | for b in range(B):
43 | # pos, opacity, scale, rotation, shs
44 | means3D = gaussians[b, :, 0:3].contiguous().float()
45 | opacity = gaussians[b, :, 3:4].contiguous().float()
46 | scales = gaussians[b, :, 4:7].contiguous().float()
47 | rotations = gaussians[b, :, 7:11].contiguous().float()
48 | rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 4]
49 |
50 | for v in range(V):
51 |
52 | # render novel views
53 | view_matrix = cam_view[b, v].float()
54 | view_proj_matrix = cam_view_proj[b, v].float()
55 | campos = cam_pos[b, v].float()
56 |
57 | raster_settings = GaussianRasterizationSettings(
58 | image_height=self.opt.output_size,
59 | image_width=self.opt.output_size,
60 | tanfovx=self.tan_half_fov,
61 | tanfovy=self.tan_half_fov,
62 | bg=self.bg_color if bg_color is None else bg_color,
63 | scale_modifier=scale_modifier,
64 | viewmatrix=view_matrix,
65 | projmatrix=view_proj_matrix,
66 | sh_degree=0,
67 | campos=campos,
68 | prefiltered=False,
69 | debug=False,
70 | )
71 |
72 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
73 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
74 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
75 | means3D=means3D,
76 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),
77 | shs=None,
78 | colors_precomp=rgbs,
79 | opacities=opacity,
80 | scales=scales,
81 | rotations=rotations,
82 | cov3D_precomp=None,
83 | )
84 | rendered_image = rendered_image.clamp(0, 1)
85 | images.append(rendered_image)
86 | alphas.append(rendered_alpha)
87 |
88 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) # we use 4 for latent
89 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
90 |
91 | return {
92 | "image": images, # [B, V, 4, H, W]
93 | "alpha": alphas, # [B, V, 1, H, W]
94 | }
95 |
96 |
97 | def save_ply(self, gaussians, path, compatible=True):
98 | # gaussians: [B, N, 14]
99 | # compatible: save pre-activated gaussians as in the original paper
100 |
101 | assert gaussians.shape[0] == 1, 'only support batch size 1'
102 |
103 | from plyfile import PlyData, PlyElement
104 |
105 | means3D = gaussians[0, :, 0:3].contiguous().float()
106 | opacity = gaussians[0, :, 3:4].contiguous().float()
107 | scales = gaussians[0, :, 4:7].contiguous().float()
108 | rotations = gaussians[0, :, 7:11].contiguous().float()
109 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
110 |
111 | # prune by opacity
112 | mask = opacity.squeeze(-1) >= 0.005
113 | means3D = means3D[mask]
114 | opacity = opacity[mask]
115 | scales = scales[mask]
116 | rotations = rotations[mask]
117 | shs = shs[mask]
118 |
119 | # invert activation to make it compatible with the original ply format
120 | if compatible:
121 | opacity = kiui.op.inverse_sigmoid(opacity)
122 | scales = torch.log(scales + 1e-8)
123 | shs = (shs - 0.5) / 0.28209479177387814
124 |
125 | xyzs = means3D.detach().cpu().numpy()
126 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
127 | opacities = opacity.detach().cpu().numpy()
128 | scales = scales.detach().cpu().numpy()
129 | rotations = rotations.detach().cpu().numpy()
130 |
131 | l = ['x', 'y', 'z']
132 | # All channels except the 3 DC
133 | for i in range(f_dc.shape[1]):
134 | l.append('f_dc_{}'.format(i))
135 | l.append('opacity')
136 | for i in range(scales.shape[1]):
137 | l.append('scale_{}'.format(i))
138 | for i in range(rotations.shape[1]):
139 | l.append('rot_{}'.format(i))
140 |
141 | dtype_full = [(attribute, 'f4') for attribute in l]
142 |
143 | elements = np.empty(xyzs.shape[0], dtype=dtype_full)
144 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
145 | elements[:] = list(map(tuple, attributes))
146 | el = PlyElement.describe(elements, 'vertex')
147 |
148 | PlyData([el]).write(path)
149 |
150 | def load_ply(self, path, compatible=True):
151 |
152 | from plyfile import PlyData, PlyElement
153 |
154 | plydata = PlyData.read(path)
155 |
156 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
157 | np.asarray(plydata.elements[0]["y"]),
158 | np.asarray(plydata.elements[0]["z"])), axis=1)
159 | print("Number of points at loading : ", xyz.shape[0])
160 |
161 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
162 |
163 | shs = np.zeros((xyz.shape[0], 3))
164 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
165 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
166 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
167 |
168 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
169 | scales = np.zeros((xyz.shape[0], len(scale_names)))
170 | for idx, attr_name in enumerate(scale_names):
171 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
172 |
173 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
174 | rots = np.zeros((xyz.shape[0], len(rot_names)))
175 | for idx, attr_name in enumerate(rot_names):
176 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
177 |
178 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
179 | gaussians = torch.from_numpy(gaussians).float() # cpu
180 |
181 | if compatible:
182 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
183 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
184 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
185 |
186 | return gaussians
--------------------------------------------------------------------------------
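
A minimal sketch of the 14-dim Gaussian layout this renderer expects (xyz 3 + opacity 1 + scale 3 + rotation 4 + RGB 3), exercised through `save_ply`/`load_ply`. Note the constructor allocates its background color on `cuda`, so a GPU machine with `diff_gaussian_rasterization` installed (see `install.sh`) is assumed; the output filename is hypothetical:

```python
# Sketch: build a random but validly-activated Gaussian tensor [B, N, 14]
# (xyz 3 + opacity 1 + scale 3 + rotation 4 + rgb 3) and round-trip it as a .ply.
# Assumes a CUDA machine with diff_gaussian_rasterization installed (install.sh).
import torch
import torch.nn.functional as F
from core.options import Options
from core.gs import GaussianRenderer

opt = Options()
renderer = GaussianRenderer(opt)

N = 1024
xyz = torch.rand(1, N, 3) * 2 - 1                      # positions in [-1, 1]^3
opacity = torch.rand(1, N, 1) * 0.9 + 0.05             # in (0, 1); pruning threshold is 0.005
scale = torch.rand(1, N, 3) * 0.02 + 1e-3              # positive scales
rotation = F.normalize(torch.randn(1, N, 4), dim=-1)   # unit quaternions
rgb = torch.rand(1, N, 3)                              # colors in [0, 1]
gaussians = torch.cat([xyz, opacity, scale, rotation, rgb], dim=-1)  # [1, N, 14]

renderer.save_ply(gaussians, "random_gaussians.ply")   # compatible=True inverts the activations
loaded = renderer.load_ply("random_gaussians.ply")     # back to activated values, on CPU
print(loaded.shape)                                    # [M, 14] after opacity pruning
```
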
/core/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | import kiui
7 | from kiui.lpips import LPIPS
8 |
9 | from core.unet import UNet
10 | from core.options import Options
11 | from core.gs import GaussianRenderer
12 |
13 |
14 | class LGM(nn.Module):
15 | def __init__(
16 | self,
17 | opt: Options,
18 | ):
19 | super().__init__()
20 |
21 | self.opt = opt
22 |
23 | # unet
24 | self.unet = UNet(
25 | 9, 14,
26 | down_channels=self.opt.down_channels,
27 | down_attention=self.opt.down_attention,
28 | mid_attention=self.opt.mid_attention,
29 | up_channels=self.opt.up_channels,
30 | up_attention=self.opt.up_attention,
31 | )
32 | # x = F.interpolate(x, scale_factor=2.0, mode='nearest')
33 | # last conv
34 | self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again
35 | # Gaussian Renderer
36 | self.gs = GaussianRenderer(opt)
37 |
38 | # activations...
39 | self.pos_act = lambda x: x.clamp(-1, 1)
40 | self.scale_act = lambda x: 0.1 * F.softplus(x)
41 | self.opacity_act = lambda x: torch.sigmoid(x)
42 | self.rot_act = F.normalize
43 | self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again
44 |
45 | # LPIPS loss
46 | if self.opt.lambda_lpips > 0:
47 | self.lpips_loss = LPIPS(net='vgg')
48 | self.lpips_loss.requires_grad_(False)
49 |
50 |
51 | def state_dict(self, **kwargs):
52 | # remove lpips_loss
53 | state_dict = super().state_dict(**kwargs)
54 | for k in list(state_dict.keys()):
55 | if 'lpips_loss' in k:
56 | del state_dict[k]
57 | return state_dict
58 |
59 |
60 | def prepare_default_rays(self, device, elevation=0):
61 |
62 | from kiui.cam import orbit_camera
63 | from core.utils import get_rays
64 |
65 | cam_poses = np.stack([
66 | orbit_camera(elevation, 0, radius=self.opt.cam_radius),
67 | orbit_camera(elevation, 90, radius=self.opt.cam_radius),
68 | orbit_camera(elevation, 180, radius=self.opt.cam_radius),
69 | orbit_camera(elevation, 270, radius=self.opt.cam_radius),
70 | ], axis=0) # [4, 4, 4]
71 | cam_poses = torch.from_numpy(cam_poses)
72 | # print("default_rays:", cam_poses)
73 | rays_embeddings = []
74 | for i in range(cam_poses.shape[0]):
75 | rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
76 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
77 | rays_embeddings.append(rays_plucker)
78 |
79 | ## visualize rays for plotting figure
80 | # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)
81 |
82 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]
83 |
84 | return rays_embeddings
85 |
86 |
87 | def forward_gaussians(self, images):
88 | # images: [B, 4, 9, H, W]
89 | # return: Gaussians: [B, dim_t]
90 |
91 | B, V, C, H, W = images.shape
92 | images = images.view(B*V, C, H, W)
93 |
94 | x = self.unet(images) # [B*24, 14, h, w]
95 | x = self.conv(x) # [B*24, 14, h, w]
96 |
97 | x = x.reshape(B, self.opt.num_input_views, 14, self.opt.splat_size, self.opt.splat_size) # hard code: 24??
98 |
99 | ## visualize multi-view gaussian features for plotting figure
100 | # tmp_alpha = self.opacity_act(x[0, :, 3:4])
101 | # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)
102 | # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5
103 | # kiui.vis.plot_image(tmp_img_rgb, save=True)
104 | # kiui.vis.plot_image(tmp_img_pos, save=True)
105 |
106 | x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
107 |
108 | pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
109 | opacity = self.opacity_act(x[..., 3:4])
110 | scale = self.scale_act(x[..., 4:7])
111 | rotation = self.rot_act(x[..., 7:11])
112 | rgbs = self.rgb_act(x[..., 11:])
113 |
114 | gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
115 |
116 | return gaussians
117 |
118 | def infer(self, data, step_ratio=1, bg_color_factor=0.5):
119 | results = {}
120 |
121 | images = data['input'] # [B, 4, 9, h, W], input features
122 |
123 | # use the first view to predict gaussians
124 | gaussians = self.forward_gaussians(images) # [B, N, 14]
125 |
126 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)*bg_color_factor
127 |
128 | # use the other views for rendering and supervision
129 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
130 | pred_images = results['image'] # [B, V, C, output_size, output_size]
131 |
132 | results['images_pred'] = pred_images
133 |
134 | return results
135 |
136 | def forward(self, data, step_ratio=1):
137 | # data: output of the dataloader
138 | # return: loss
139 |
140 | results = {}
141 | loss = 0
142 |
143 | images = data['input'] # [B, 4, 9, h, W], input features
144 |
145 | # use the first view to predict gaussians
146 | gaussians = self.forward_gaussians(images) # [B, N, 14]
147 |
148 | results['gaussians'] = gaussians
149 |
150 | # random bg for training
151 | if self.training:
152 | bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device)
153 | else:
154 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)
155 |
156 | # use the other views for rendering and supervision
157 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
158 | pred_images = results['image'] # [B, V, C, output_size, output_size]
159 | pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size]
160 |
161 | results['images_pred'] = pred_images
162 | results['alphas_pred'] = pred_alphas
163 |
164 | gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
165 | gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks
166 |
167 | gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
168 |
169 | loss_mse = F.mse_loss(pred_images.half(), gt_images.half()) + F.mse_loss(pred_alphas.half(), gt_masks.half())
170 | loss = loss + loss_mse
171 |
172 | if self.opt.lambda_lpips > 0:
173 | loss_lpips = self.lpips_loss(
174 | # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
175 | # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
176 | # downsampled to at most 256 to reduce memory cost
177 | F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False).half(),
178 | F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False).half(),
179 | ).mean()
180 | results['loss_lpips'] = loss_lpips
181 | loss = loss + self.opt.lambda_lpips * loss_lpips
182 |
183 | results['loss'] = loss
184 |
185 | # metric
186 | with torch.no_grad():
187 | psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
188 | results['psnr'] = psnr
189 |
190 |
191 |
192 | return results
--------------------------------------------------------------------------------
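
A shape-level sketch of driving `LGM` directly: build a `[B, 4, 9, H, W]` input (3 RGB + 6 Plücker ray channels, matching the 9-channel UNet constructed above) and map it to per-pixel Gaussians with `forward_gaussians`. LPIPS is disabled to avoid downloading VGG weights, and a CUDA device is assumed because `GaussianRenderer` allocates on the GPU; `core/unet.py` is listed in the tree but not reproduced in this dump:

```python
# Shape-level sketch of LGM.forward_gaussians (no rendering, no loss).
# Assumes a CUDA machine; lambda_lpips=0 skips the LPIPS/VGG download.
import torch
from core.options import Options
from core.models import LGM

opt = Options(lambda_lpips=0.0)
model = LGM(opt).cuda().eval()

# 4 orbit views, each 3 RGB channels + 6 Plucker ray channels = 9 (UNet input dim)
rays = model.prepare_default_rays("cuda")                      # [4, 6, input_size, input_size]
rgb = torch.rand(1, 4, 3, opt.input_size, opt.input_size, device="cuda")
images = torch.cat([rgb, rays.unsqueeze(0)], dim=2)            # [1, 4, 9, H, W]

with torch.no_grad():
    gaussians = model.forward_gaussians(images)                # [1, 4 * splat_size**2, 14]
print(gaussians.shape)
```
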
/core/options.py:
--------------------------------------------------------------------------------
1 | import tyro
2 | from dataclasses import dataclass
3 | from typing import Tuple, Literal, Dict, Optional
4 |
5 |
6 | @dataclass
7 | class Options:
8 | ### model
9 | # Unet image input size
10 | input_size: int = 256
11 | # Unet definition
12 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
13 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
14 | mid_attention: bool = True
15 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
16 | up_attention: Tuple[bool, ...] = (True, True, True, False)
17 | # Unet output size, dependent on the input_size and U-Net structure!
18 | splat_size: int = 64
19 | # gaussian render size
20 | output_size: int = 256
21 |
22 | ### dataset
23 | # data mode (only support s3 now)
24 | data_mode: Literal['s3'] = 's3'
25 | # fovy of the dataset
26 | fovy: float = 39.6 # 39.6 # 49.1
27 | # camera near plane
28 | znear: float = 0.5 # 0.1 # 0.5
29 | # camera far plane
30 | zfar: float = 2.5 # 1000 # 2.5
31 | # number of all views (input + output)
32 | num_views: int = 8
33 | # number of views
34 | num_input_views: int = 4
35 | # camera radius
36 | cam_radius: float = 1.5 # to better use [-1, 1]^3 space
37 | # num workers
38 | num_workers: int = 8
39 |
40 | ### training
41 | # workspace
42 | workspace: str = './workspace'
43 | # resume
44 | resume: Optional[str] = "/mnt/cap/muyuan/code/StableVideoDiffusion/StableVideoDiffusion/i2vgen-xl/LGM/pretrained/model_fp16.safetensors"
45 | # batch size (per-GPU)
46 | batch_size: int = 8
47 | # gradient accumulation
48 | gradient_accumulation_steps: int = 1
49 | # training epochs
50 | num_epochs: int = 30
51 | # lpips loss weight
52 | lambda_lpips: float = 1.0
53 | # gradient clip
54 | gradient_clip: float = 1.0
55 | # mixed precision
56 | mixed_precision: str = 'bf16'
57 | # learning rate
58 | lr: float = 1e-4
59 | # augmentation prob for grid distortion
60 | prob_grid_distortion: float = 0.5
61 | # augmentation prob for camera jitter
62 | prob_cam_jitter: float = 0.5
63 |
64 | ### testing
65 | # test image path
66 | test_path: Optional[str] = None
67 |
68 | ### misc
69 | # nvdiffrast backend setting
70 | force_cuda_rast: bool = False
71 | # render fancy video with gaussian scaling effect
72 | fancy_video: bool = False
73 |
74 |
75 | # all the default settings
76 | config_defaults: Dict[str, Options] = {}
77 | config_doc: Dict[str, str] = {}
78 |
79 | config_doc['lrm'] = 'the default settings for LGM'
80 | config_defaults['lrm'] = Options()
81 |
82 | config_doc['small'] = 'small model with lower resolution Gaussians'
83 | config_defaults['small'] = Options(
84 | input_size=256,
85 | splat_size=64,
86 | output_size=256,
87 | batch_size=4,
88 | gradient_accumulation_steps=1,
89 | mixed_precision='bf16',
90 | )
91 |
92 | config_doc['big'] = 'big model with higher resolution Gaussians'
93 | config_defaults['big'] = Options(
94 | input_size=256,
95 | up_channels=(1024, 1024, 512, 256, 128), # one more decoder
96 | up_attention=(True, True, True, False, False),
97 | splat_size=128,
98 | output_size=512, # render & supervise Gaussians at a higher resolution.
99 | batch_size=8,
100 | num_views=8,
101 | gradient_accumulation_steps=1,
102 | mixed_precision='bf16',
103 | )
104 |
105 | config_doc['tiny'] = 'tiny model for ablation'
106 | config_defaults['tiny'] = Options(
107 | input_size=256,
108 | down_channels=(32, 64, 128, 256),
109 | down_attention=(False, False, False, True),
110 | up_channels=(256, 128, 64),
111 | up_attention=(True, False, False),
112 | splat_size=128,
113 | output_size=256,
114 | batch_size=8,
115 | num_views=8,
116 | gradient_accumulation_steps=1,
117 | mixed_precision='bf16',
118 | )
119 |
120 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
121 |
--------------------------------------------------------------------------------
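
`AllConfigs` turns the presets above (`lrm`, `small`, `big`, `tiny`) into tyro subcommands. A minimal, hypothetical entry point showing how it is meant to be consumed:

```python
# Hypothetical entry point showing how AllConfigs is consumed with tyro: each
# preset ('lrm', 'small', 'big', 'tiny') becomes a CLI subcommand whose fields
# can be overridden, e.g.  python lgm_opts_demo.py big --workspace ./workspace_big
import tyro
from core.options import AllConfigs, Options

if __name__ == "__main__":
    opt: Options = tyro.cli(AllConfigs)
    print(opt.splat_size, opt.output_size, opt.workspace)
```
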
/core/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import roma
8 | from kiui.op import safe_normalize
9 |
10 | def get_rays(pose, h, w, fovy, opengl=True):
11 |
12 | x, y = torch.meshgrid(
13 | torch.arange(w, device=pose.device),
14 | torch.arange(h, device=pose.device),
15 | indexing="xy",
16 | )
17 | x = x.flatten()
18 | y = y.flatten()
19 |
20 | cx = w * 0.5
21 | cy = h * 0.5
22 |
23 | focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
24 |
25 | camera_dirs = F.pad(
26 | torch.stack(
27 | [
28 | (x - cx + 0.5) / focal,
29 | (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
30 | ],
31 | dim=-1,
32 | ),
33 | (0, 1),
34 | value=(-1.0 if opengl else 1.0),
35 | ) # [hw, 3]
36 |
37 | rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
38 | rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
39 |
40 | rays_o = rays_o.view(h, w, 3)
41 | rays_d = safe_normalize(rays_d).view(h, w, 3)
42 |
43 | return rays_o, rays_d
44 |
45 | def orbit_camera_jitter(poses, strength=0.1):
46 | # poses: [B, 4, 4], assume orbit camera in opengl format
47 | # random orbital rotate
48 |
49 | B = poses.shape[0]
50 | rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)
51 | rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)
52 |
53 | rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)
54 | R = rot @ poses[:, :3, :3]
55 | T = rot @ poses[:, :3, 3:]
56 |
57 | new_poses = poses.clone()
58 | new_poses[:, :3, :3] = R
59 | new_poses[:, :3, 3:] = T
60 |
61 | return new_poses
62 |
63 | def grid_distortion(images, strength=0.5):
64 | # images: [B, C, H, W]
65 | # num_steps: int, grid resolution for distortion
66 | # strength: float in [0, 1], strength of distortion
67 |
68 | B, C, H, W = images.shape
69 |
70 | num_steps = np.random.randint(8, 17)
71 | grid_steps = torch.linspace(-1, 1, num_steps)
72 |
73 | # have to loop batch...
74 | grids = []
75 | for b in range(B):
76 | # construct displacement
77 | x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
78 | x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
79 | x_steps = (x_steps * W).long() # [num_steps]
80 | x_steps[0] = 0
81 | x_steps[-1] = W
82 | xs = []
83 | for i in range(num_steps - 1):
84 | xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))
85 | xs = torch.cat(xs, dim=0) # [W]
86 |
87 | y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
88 | y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
89 | y_steps = (y_steps * H).long() # [num_steps]
90 | y_steps[0] = 0
91 | y_steps[-1] = H
92 | ys = []
93 | for i in range(num_steps - 1):
94 | ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))
95 | ys = torch.cat(ys, dim=0) # [H]
96 |
97 | # construct grid
98 | grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
99 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]
100 |
101 | grids.append(grid)
102 |
103 | grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]
104 |
105 | # grid sample
106 | images = F.grid_sample(images, grids, align_corners=False)
107 |
108 | return images
109 |
110 |
--------------------------------------------------------------------------------
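
Usage note for /core/utils.py above: a minimal sketch of how get_rays and grid_distortion might be exercised, assuming the repository root is on PYTHONPATH and that kiui and roma (imported by core/utils.py) are installed. The identity pose, 256x256 size, and 49.1-degree fovy are illustrative placeholders rather than values taken from the configs.

import torch
from core.utils import get_rays, grid_distortion

# Illustrative pose: identity camera-to-world matrix (camera at the origin,
# looking down -z in the default OpenGL convention).
pose = torch.eye(4)
rays_o, rays_d = get_rays(pose, h=256, w=256, fovy=49.1)  # each [256, 256, 3]

# Random grid warp over a dummy image batch, as used for data augmentation.
images = torch.rand(2, 3, 256, 256)
warped = grid_distortion(images, strength=0.5)            # [2, 3, 256, 256]
print(rays_o.shape, rays_d.shape, warped.shape)
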
/data/images/demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/data/images/demo1.png
--------------------------------------------------------------------------------
/data/images/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/data/images/demo2.png
--------------------------------------------------------------------------------
/data/test_images.txt:
--------------------------------------------------------------------------------
1 | ./data/images/demo1.png
2 | ./data/images/demo2.png
--------------------------------------------------------------------------------
/data/test_prompts.txt:
--------------------------------------------------------------------------------
1 | Futuristic space helmet
2 | dragon armor
3 | A medieval shield with a cross and wooden handle
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import math
6 | import random
7 | import logging
8 | import itertools
9 | import numpy as np
10 |
11 | from utils.config import Config
12 | from utils.registry_class import INFER_ENGINE
13 |
14 | from tools import *
15 |
16 | if __name__ == '__main__':
17 | cfg_update = Config(load=True)
18 | INFER_ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict)
19 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
2 | pip install -r requirements.txt
3 | pip install ninja
4 | git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
5 | pip install ./diff-gaussian-rasterization
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict==1.10
2 | tokenizers==0.12.1
3 | numpy>=1.19.2
4 | ftfy==6.1.1
5 | transformers==4.18.0
6 | imageio==2.15.0
7 | fairscale==0.4.6
8 | ipdb
9 | open-clip-torch==2.0.2
10 | xformers==0.0.13
11 | chardet==5.1.0
12 | torchdiffeq==0.2.3
13 | opencv-python==4.4.0.46
14 | opencv-python-headless==4.7.0.68
15 | torchsde==0.2.6
16 | simplejson==3.18.4
17 | motion-vector-extractor==1.0.6
18 | scikit-learn
19 | scikit-image
20 | rotary-embedding-torch==0.2.1
21 | pynvml==11.5.0
22 | triton==2.0.0.dev20221120
23 | pytorch-lightning==1.4.2
24 | torchmetrics==0.6.0
25 | gradio==3.39.0
26 | imageio-ffmpeg
27 | kornia
28 | tyro
29 | dearpygui
30 | einops
31 | lpips
32 | matplotlib
33 | packaging
34 | Pillow
35 | pygltflib
36 | rembg[gpu,cli]
37 | rich
38 | safetensors
39 | scipy
40 | tqdm
41 | trimesh
42 | kiui >= 0.2.3
43 | roma
44 | plyfile
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .annotator import *
2 | from .datasets import *
3 | from .modules import *
4 | from .train import *
5 | from .hooks import *
6 | from .inferences import *
7 | # from .prior import *
8 | from .basic_funcs import *
9 |
--------------------------------------------------------------------------------
/tools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/annotator/canny/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from tools.annotator.util import HWC3
5 | # import gradio as gr
6 |
7 | class CannyDetector:
8 | def __call__(self, img, low_threshold = None, high_threshold = None, random_threshold = True):
9 |
10 | ### GPT-4 suggestions
11 | # In the cv2.Canny() function, the low threshold and high threshold are used to determine the edges based on the gradient values in the image.
12 | # There isn't a one-size-fits-all solution for these threshold values, as the optimal values depend on the specific image and the application.
13 | # However, there are some general guidelines and empirical values you can use as a starting point:
14 | # 1. Ratio: A common recommendation is to use a ratio of 1:2 or 1:3 between the low threshold and the high threshold.
15 | # This means if your low threshold is 50, the high threshold should be around 100 or 150.
16 | # 2. Empirical values: As a starting point, you can use low threshold values in the range of 50-100 and high threshold values in the range of 100-200.
17 | # You may need to fine-tune these values based on the specific image and desired edge detection results.
18 | # 3. Automatic threshold calculation: To automatically calculate the threshold values, you can use the median or mean value of the image's pixel intensities as the low threshold,
19 | # and the high threshold can be set as twice or three times the low threshold.
20 |
21 | ### Convert to numpy
22 | if isinstance(img, torch.Tensor): # (h, w, c)
23 | img = img.cpu().numpy()
24 | img_np = cv2.convertScaleAbs((img * 255.))
25 | elif isinstance(img, np.ndarray): # (h, w, c)
26 | img_np = img # we assume values are in the range from 0 to 255.
27 | else:
28 | assert False
29 |
30 | ### Select the threshold
31 | if (low_threshold is None) and (high_threshold is None):
32 | median_intensity = np.median(img_np)
33 | if random_threshold is False:
34 | low_threshold = int(max(0, (1 - 0.33) * median_intensity))
35 | high_threshold = int(min(255, (1 + 0.33) * median_intensity))
36 | else:
37 | random_canny = np.random.uniform(0.1, 0.4)
38 | # Might try other values
39 | low_threshold = int(max(0, (1 - random_canny) * median_intensity))
40 | high_threshold = 2 * low_threshold
41 |
42 | ### Detect canny edge
43 | canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
44 | ### Convert to 3 channels
45 | # canny_edge = HWC3(canny_edge)
46 |
47 | canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim = -1).float().cuda() / 255.0
48 | # canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0)
49 | # canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone()
50 | # return cv2.Canny(img, low_threshold, high_threshold)
51 | return canny_condition
--------------------------------------------------------------------------------
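
Usage note for /tools/annotator/canny/__init__.py above: a minimal sketch of calling CannyDetector with the deterministic median-based thresholds described in the comments (low = (1 - 0.33) * median, high = (1 + 0.33) * median). It assumes a CUDA device, since the detector moves its output to the GPU, and that the full tools package imports resolve; the random input image is a placeholder.

import numpy as np
from tools.annotator.canny import CannyDetector

# Placeholder frame: a random uint8 HxWx3 image with values in [0, 255].
img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)

detector = CannyDetector()
# random_threshold=False -> thresholds derived from the median pixel intensity.
edges = detector(img, random_threshold=False)  # float32 CUDA tensor, [256, 256, 1], values in [0, 1]
print(edges.shape, edges.min().item(), edges.max().item())
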
/tools/annotator/histogram/__init__.py:
--------------------------------------------------------------------------------
1 | from .palette import *
--------------------------------------------------------------------------------
/tools/annotator/histogram/palette.py:
--------------------------------------------------------------------------------
1 | r"""Modified from ``https://github.com/sergeyk/rayleigh''.
2 | """
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb
7 | from skimage.io import imsave
8 | from sklearn.metrics import euclidean_distances
9 |
10 | __all__ = ['Palette']
11 |
12 | def rgb2hex(rgb):
13 | return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
14 |
15 | def hex2rgb(hex):
16 | rgb = hex.strip('#')
17 | fn = lambda u: round(int(u, 16) / 255.0, 5)
18 | return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
19 |
20 | class Palette(object):
21 | r"""Create a color palette (codebook) in the form of a 2D grid of colors.
22 | Further, the rightmost column has num_hues gradations from black to white.
23 |
24 | Parameters:
25 | num_hues: number of colors with full lightness and saturation, in the middle.
26 | num_sat: number of rows above middle row that show the same hues with decreasing saturation.
27 | """
28 | def __init__(self, num_hues=11, num_sat=5, num_light=4):
29 | n = num_sat + 2 * num_light
30 |
31 | # hues
32 | if num_hues == 8:
33 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1))
34 | elif num_hues == 9:
35 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1))
36 | elif num_hues == 10:
37 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1))
38 | elif num_hues == 11:
39 | hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1))
40 | else:
41 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
42 |
43 | # saturations
44 | sats = np.hstack((
45 | np.linspace(0, 1, num_sat + 2)[1:-1],
46 | 1,
47 | [1] * num_light,
48 | [0.4] * (num_light - 1)))
49 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
50 |
51 | # lights
52 | lights = np.hstack((
53 | [1] * num_sat,
54 | 1,
55 | np.linspace(1, 0.2, num_light + 2)[1:-1],
56 | np.linspace(1, 0.2, num_light + 2)[1:-2]))
57 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
58 |
59 | # colors
60 | rgb = hsv2rgb(np.dstack([hues, sats, lights]))
61 | gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
62 | self.thumbnail = np.hstack([rgb, gray])
63 |
64 | # flatten
65 | rgb = rgb.T.reshape(3, -1).T
66 | gray = gray.T.reshape(3, -1).T
67 | self.rgb = np.vstack((rgb, gray))
68 | self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
69 | self.hex = [rgb2hex(u) for u in self.rgb]
70 | self.lab_dists = euclidean_distances(self.lab, squared=True)
71 |
72 | def histogram(self, rgb_img, sigma=20):
73 | # compute histogram
74 | lab = rgb2lab(rgb_img).reshape((-1, 3))
75 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
76 | hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
77 |
78 | # smooth histogram
79 | if sigma > 0:
80 | weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
81 | weight = weight / weight.sum(1)[:, np.newaxis]
82 | hist = (weight * hist).sum(1)
83 | hist[hist < 1e-5] = 0
84 | return hist
85 |
86 | def get_palette_image(self, hist, percentile=90, width=200, height=50):
87 | # curate histogram
88 | ind = np.argsort(-hist)
89 | ind = ind[hist[ind] > np.percentile(hist, percentile)]
90 | hist = hist[ind] / hist[ind].sum()
91 |
92 | # draw palette
93 | nums = np.array(hist * width, dtype=int)
94 | array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)])
95 | array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
96 | if array.shape[1] < width:
97 | array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
98 | return array
99 |
100 | def quantize_image(self, rgb_img):
101 | lab = rgb2lab(rgb_img).reshape((-1, 3))
102 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
103 | quantized_lab = self.lab[min_ind]
104 | img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
105 | return img
106 |
107 | def export(self, dirname):
108 | if not osp.exists(dirname):
109 | os.makedirs(dirname)
110 |
111 | # save thumbnail
112 | imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
113 |
114 | # save html
115 | with open(osp.join(dirname, 'palette.html'), 'w') as f:
116 | html = '''
117 | <style>
118 | span {
119 | width: 20px;
120 | height: 20px;
121 | margin: 2px;
122 | padding: 0px;
123 | display: inline-block;
124 | }
125 | </style>
126 | '''
127 | for row in self.thumbnail:
128 | for col in row:
129 | html += '<a id="{0}"><span style="background-color: {0}" /></a>\n'.format(rgb2hex(col))
130 | html += '<br />\n'
131 | f.write(html)
132 |
--------------------------------------------------------------------------------
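
Usage note for /tools/annotator/histogram/palette.py above: a minimal sketch of building a Palette and querying it, assuming the tools package imports resolve. The 64x64 random image is a placeholder for a real RGB frame in [0, 1].

import numpy as np
from tools.annotator.histogram.palette import Palette

# Placeholder image: random RGB values in [0, 1].
rgb_img = np.random.rand(64, 64, 3)

palette = Palette(num_hues=11, num_sat=5, num_light=4)
hist = palette.histogram(rgb_img, sigma=20)  # smoothed color histogram over the 156-color codebook
quantized = palette.quantize_image(rgb_img)  # image snapped to the nearest palette colors
print(hist.shape, quantized.shape)           # (156,) (64, 64, 3)
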
/tools/annotator/sketch/__init__.py:
--------------------------------------------------------------------------------
1 | from .pidinet import *
2 | from .sketch_simplification import *
--------------------------------------------------------------------------------
/tools/annotator/sketch/sketch_simplification.py:
--------------------------------------------------------------------------------
1 | r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
2 | """
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 | # from canvas import DOWNLOAD_TO_CACHE
9 | from artist import DOWNLOAD_TO_CACHE
10 |
11 | __all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse',
12 | 'sketch_to_pencil_v1', 'sketch_to_pencil_v2']
13 |
14 | class SketchSimplification(nn.Module):
15 | r"""NOTE:
16 | 1. Input image should have only one gray channel.
17 | 2. Input image size should be divisible by 8.
18 | 3. Sketch strokes in the input/output image are dark while the background is light.
19 | """
20 | def __init__(self, mean, std):
21 | assert isinstance(mean, float) and isinstance(std, float)
22 | super(SketchSimplification, self).__init__()
23 | self.mean = mean
24 | self.std = std
25 |
26 | # layers
27 | self.layers = nn.Sequential(
28 | nn.Conv2d(1, 48, 5, 2, 2),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(48, 128, 3, 1, 1),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(128, 128, 3, 1, 1),
33 | nn.ReLU(inplace=True),
34 | nn.Conv2d(128, 128, 3, 2, 1),
35 | nn.ReLU(inplace=True),
36 | nn.Conv2d(128, 256, 3, 1, 1),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(256, 256, 3, 1, 1),
39 | nn.ReLU(inplace=True),
40 | nn.Conv2d(256, 256, 3, 2, 1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(256, 512, 3, 1, 1),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(512, 1024, 3, 1, 1),
45 | nn.ReLU(inplace=True),
46 | nn.Conv2d(1024, 1024, 3, 1, 1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(1024, 1024, 3, 1, 1),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(1024, 1024, 3, 1, 1),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(1024, 512, 3, 1, 1),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(512, 256, 3, 1, 1),
55 | nn.ReLU(inplace=True),
56 | nn.ConvTranspose2d(256, 256, 4, 2, 1),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(256, 256, 3, 1, 1),
59 | nn.ReLU(inplace=True),
60 | nn.Conv2d(256, 128, 3, 1, 1),
61 | nn.ReLU(inplace=True),
62 | nn.ConvTranspose2d(128, 128, 4, 2, 1),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(128, 128, 3, 1, 1),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(128, 48, 3, 1, 1),
67 | nn.ReLU(inplace=True),
68 | nn.ConvTranspose2d(48, 48, 4, 2, 1),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(48, 24, 3, 1, 1),
71 | nn.ReLU(inplace=True),
72 | nn.Conv2d(24, 1, 3, 1, 1),
73 | nn.Sigmoid())
74 |
75 | def forward(self, x):
76 | r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.
77 | """
78 | x = (x - self.mean) / self.std
79 | return self.layers(x)
80 |
81 | def sketch_simplification_gan(pretrained=False):
82 | model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797)
83 | if pretrained:
84 | # model.load_state_dict(torch.load(
85 | # DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_gan.pth'),
86 | # map_location='cpu'))
87 | model.load_state_dict(torch.load(
88 | DOWNLOAD_TO_CACHE('VideoComposer/Hangjie/models/sketch_simplification/sketch_simplification_gan.pth'),
89 | map_location='cpu'))
90 | return model
91 |
92 | def sketch_simplification_mse(pretrained=False):
93 | model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507)
94 | if pretrained:
95 | model.load_state_dict(torch.load(
96 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'),
97 | map_location='cpu'))
98 | return model
99 |
100 | def sketch_to_pencil_v1(pretrained=False):
101 | model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048)
102 | if pretrained:
103 | model.load_state_dict(torch.load(
104 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'),
105 | map_location='cpu'))
106 | return model
107 |
108 | def sketch_to_pencil_v2(pretrained=False):
109 | model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571)
110 | if pretrained:
111 | model.load_state_dict(torch.load(
112 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'),
113 | map_location='cpu'))
114 | return model
115 |
--------------------------------------------------------------------------------
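
Usage note for /tools/annotator/sketch/sketch_simplification.py above: a minimal sketch of running the un-pretrained network, illustrating the constraints from the class docstring (one gray channel, spatial size divisible by 8). It assumes the artist package that provides DOWNLOAD_TO_CACHE is importable, since this module imports it at load time; the all-white input is a placeholder.

import torch
from tools.annotator.sketch.sketch_simplification import sketch_simplification_gan

model = sketch_simplification_gan(pretrained=False).eval()

# One gray channel, H and W divisible by 8, dark strokes on a light background.
x = torch.ones(1, 1, 256, 256)
with torch.no_grad():
    y = model(x)  # [1, 1, 256, 256], values in (0, 1) from the final Sigmoid
print(y.shape)
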
/tools/annotator/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 |
5 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
6 |
7 | def HWC3(x):
8 | assert x.dtype == np.uint8
9 | if x.ndim == 2:
10 | x = x[:, :, None]
11 | assert x.ndim == 3
12 | H, W, C = x.shape
13 | assert C == 1 or C == 3 or C == 4
14 | if C == 3:
15 | return x
16 | if C == 1:
17 | return np.concatenate([x, x, x], axis=2)
18 | if C == 4:
19 | color = x[:, :, 0:3].astype(np.float32)
20 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
21 | y = color * alpha + 255.0 * (1.0 - alpha)
22 | y = y.clip(0, 255).astype(np.uint8)
23 | return y
24 |
25 |
26 | def resize_image(input_image, resolution):
27 | H, W, C = input_image.shape
28 | H = float(H)
29 | W = float(W)
30 | k = float(resolution) / min(H, W)
31 | H *= k
32 | W *= k
33 | H = int(np.round(H / 64.0)) * 64
34 | W = int(np.round(W / 64.0)) * 64
35 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
36 | return img
--------------------------------------------------------------------------------
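
Worked example for /tools/annotator/util.py above: resize_image scales the shorter side to roughly the requested resolution and rounds both sides to multiples of 64, so a 480x640 input at resolution 512 comes out as 512x704. The random image is a placeholder, and the tools package imports are assumed to resolve.

import numpy as np
from tools.annotator.util import HWC3, resize_image

img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
img = HWC3(img)               # already 3-channel, so returned unchanged
out = resize_image(img, 512)  # k = 512/480, then both sides rounded to multiples of 64
print(out.shape)              # (512, 704, 3)
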
/tools/basic_funcs/__init__.py:
--------------------------------------------------------------------------------
1 | from .pretrain_functions import *
2 |
--------------------------------------------------------------------------------
/tools/basic_funcs/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/basic_funcs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/basic_funcs/__pycache__/pretrain_functions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/pretrain_functions.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/basic_funcs/__pycache__/pretrain_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/pretrain_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/basic_funcs/pretrain_functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import logging
5 | import collections
6 |
7 | from utils.registry_class import PRETRAIN
8 |
9 | @PRETRAIN.register_function()
10 | def pretrain_specific_strategies(
11 | model,
12 | resume_checkpoint,
13 | sd_keys_path=None,
14 | grad_scale=1,
15 | fix_weight=False,
16 | **kwargs
17 | ):
18 |
19 | state_dict = torch.load(resume_checkpoint, map_location='cpu')
20 | if 'state_dict' in state_dict:
21 | state_dict = state_dict['state_dict']
22 |
23 | # [1] load model
24 | try:
25 | ret = model.load_state_dict(state_dict, strict=False)
26 | logging.info(f'load a fixed model with {ret}')
27 | except:
28 | model_dict = model.state_dict()
29 | key_list = list(state_dict.keys())
30 | for skey, item in state_dict.items():
31 | if skey not in model_dict:
32 | logging.info(f'Skip {skey}')
33 | continue
34 | if item.shape != model_dict[skey].shape:
35 | logging.info(f'Skip {skey} with different shape {item.shape} {model_dict[skey].shape}')
36 | continue
37 | model_dict[skey].copy_(item)
38 | model.load_state_dict(model_dict)
39 |
40 | # [2] assign strategies
41 | total_size = 0
42 | state_dict = {} if sd_keys_path is None else json.load(open(sd_keys_path))
43 | for k, p in model.named_parameters():
44 | if k in state_dict:
45 | total_size += p.numel()
46 | if fix_weight:
47 | p.requires_grad=False
48 | else:
49 | p.register_hook(lambda grad: grad_scale * grad)
50 |
51 | resume_step = int(os.path.basename(resume_checkpoint).split('_')[-1].split('.')[0])
52 | logging.info(f'Successfully load step {resume_step} model from {resume_checkpoint}')
53 | logging.info(f'load a fixed model with {int(total_size / (1024 ** 2))}M parameters')
54 | return model, resume_step
55 |
56 |
57 |
58 | @PRETRAIN.register_function()
59 | def pretrain_from_sd():
60 | pass
61 |
62 |
63 | @PRETRAIN.register_function()
64 | def pretrain_ema_model():
65 | pass
66 |
--------------------------------------------------------------------------------
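
Usage note for /tools/basic_funcs/pretrain_functions.py above: the grad_scale branch relies on Parameter.register_hook to shrink the gradients of parameters whose keys appear in sd_keys_path. Below is a standalone sketch of that mechanism; the toy nn.Linear and the 0.1 factor are chosen purely for illustration.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
grad_scale = 0.1
# Same hook as in pretrain_specific_strategies: scale this parameter's gradient.
layer.weight.register_hook(lambda grad: grad_scale * grad)

layer(torch.randn(2, 4)).sum().backward()
# layer.weight.grad is now 0.1x what it would otherwise be; layer.bias.grad is untouched.
print(layer.weight.grad.abs().mean().item(), layer.bias.grad.abs().mean().item())
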
/tools/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_dataset import *
2 | from .video_dataset import *
3 | from .video_i2v_dataset import *
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/image_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/image_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/image_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/video_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/video_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/video_i2v_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_i2v_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/datasets/__pycache__/video_i2v_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_i2v_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/datasets/image_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import random
5 | import logging
6 | import tempfile
7 | import numpy as np
8 | from copy import copy
9 | from PIL import Image
10 | from io import BytesIO
11 | from torch.utils.data import Dataset
12 | from utils.registry_class import DATASETS
13 |
14 | @DATASETS.register_class()
15 | class ImageDataset(Dataset):
16 | def __init__(self,
17 | data_list,
18 | data_dir_list,
19 | max_words=1000,
20 | vit_resolution=[224, 224],
21 | resolution=(384, 256),
22 | max_frames=1,
23 | transforms=None,
24 | vit_transforms=None,
25 | **kwargs):
26 |
27 | self.max_frames = max_frames
28 | self.resolution = resolution
29 | self.transforms = transforms
30 | self.vit_resolution = vit_resolution
31 | self.vit_transforms = vit_transforms
32 |
33 | image_list = []
34 | for item_path, data_dir in zip(data_list, data_dir_list):
35 | lines = open(item_path, 'r').readlines()
36 | lines = [[data_dir, item.strip()] for item in lines]
37 | image_list.extend(lines)
38 | self.image_list = image_list
39 |
40 | def __len__(self):
41 | return len(self.image_list)
42 |
43 | def __getitem__(self, index):
44 | data_dir, file_path = self.image_list[index]
45 | img_key = file_path.split('|||')[0]
46 | try:
47 | ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
48 | except Exception as e:
49 | logging.info('{} get frames failed... with error: {}'.format(img_key, e))
50 | caption = ''
51 | img_key = ''
52 | ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
53 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
54 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
55 | return ref_frame, vit_frame, video_data, caption, img_key
56 |
57 | def _get_image_data(self, data_dir, file_path):
58 | frame_list = []
59 | img_key, caption = file_path.split('|||')
60 | file_path = os.path.join(data_dir, img_key)
61 | for _ in range(5):
62 | try:
63 | image = Image.open(file_path)
64 | if image.mode != 'RGB':
65 | image = image.convert('RGB')
66 | frame_list.append(image)
67 | break
68 | except Exception as e:
69 | logging.info('{} read video frame failed with error: {}'.format(img_key, e))
70 | continue
71 |
72 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
73 | try:
74 | if len(frame_list) > 0:
75 | mid_frame = frame_list[0]
76 | vit_frame = self.vit_transforms(mid_frame)
77 | frame_tensor = self.transforms(frame_list)
78 | video_data[:len(frame_list), ...] = frame_tensor
79 | else:
80 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
81 | except:
82 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
83 | ref_frame = copy(video_data[0])
84 |
85 | return ref_frame, vit_frame, video_data, caption
86 |
87 |
--------------------------------------------------------------------------------
/tools/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import cv2
4 | import torch
5 | import random
6 | import logging
7 | import tempfile
8 | import numpy as np
9 | from functools import partial
10 | from copy import copy
11 | from PIL import Image
12 | from io import BytesIO
13 | from torch.utils.data import Dataset
14 | import torchvision.transforms.functional as TF
15 | import albumentations
16 | import PIL
17 | from PIL import Image, ImageFile
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 | import webdataset as wds
20 |
21 | try:
22 | from utils.registry_class import DATASETS
23 | except Exception as ex:
24 | print("#" * 20)
25 | print("import error, try fixed by appending path")
26 | import sys
27 | sys.path.append("./")
28 | from utils.registry_class import DATASETS
29 |
30 |
31 |
32 | def HWC3(x):
33 | assert x.dtype == np.uint8
34 | if x.ndim == 2:
35 | x = x[:, :, None]
36 | assert x.ndim == 3
37 | H, W, C = x.shape
38 | assert C == 1 or C == 3 or C == 4
39 | if C == 3:
40 | return x
41 | if C == 1:
42 | return np.concatenate([x, x, x], axis=2)
43 | if C == 4:
44 | color = x[:, :, 0:3].astype(np.float32)
45 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
46 | y = color * alpha + 255.0 * (1.0 - alpha)
47 | y = y.clip(0, 255).astype(np.uint8)
48 | return y
49 |
50 |
51 | def resize_image(input_image, resolution):
52 | H, W, C = input_image.shape
53 | H = float(H)
54 | W = float(W)
55 | k = float(resolution) / min(H, W)
56 | H *= k
57 | W *= k
58 | H = int(np.round(H / 64.0)) * 64
59 | W = int(np.round(W / 64.0)) * 64
60 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
61 | return img
62 |
63 |
64 | def my_decoder(key, value):
65 | # solve the issue: https://github.com/webdataset/webdataset/issues/206
66 |
67 | if key.endswith('.jpg'):
68 | # return Image.open(BytesIO(value))
69 | return np.asarray(Image.open(BytesIO(value)).convert('RGB'))
70 |
71 | return None
72 |
73 |
74 | class filter_fake:
75 |
76 | def __init__(self, punsafety=0.2, aest=4.5):
77 | self.punsafety = punsafety
78 | self.aest = aest
79 |
80 | def __call__(self, src):
81 | for sample in src:
82 | img, prompt, json = sample
83 | # watermark filter
84 | if json['pwatermark'] is not None:
85 | if json['pwatermark'] > 0.3:
86 | continue
87 |
88 | # nsfw filter
89 | if json['punsafe'] is not None:
90 | if json['punsafe'] > self.punsafety:
91 | continue
92 |
93 | # aesthetic filter
94 | if json['AESTHETIC_SCORE'] is not None:
95 | if json['AESTHETIC_SCORE'] < self.aest:
96 | continue
97 |
98 | # ratio filter
99 | w, h = json['width'], json['height']
100 | if max(w / h, h / w) > 3:
101 | continue
102 |
103 | yield img, prompt, json['AESTHETIC_SCORE'], json['key']
104 |
105 |
106 | class Laion2b_Process(object):
107 |
108 | def __init__(self,
109 | size=None,
110 | degradation=None,
111 | downscale_f=4,
112 | min_crop_f=0.8,
113 | max_crop_f=1.,
114 | random_crop=True,
115 | debug: bool = False):
116 | """
117 | Imagenet Superresolution Dataloader
118 | Performs following ops in order:
119 | 1. crops a crop of size s from image either as random or center crop
120 | 2. resizes crop to size with cv2.area_interpolation
121 | 3. degrades resized crop with degradation_fn
122 |
123 | :param size: resizing to size after cropping
124 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
125 | :param downscale_f: Low Resolution Downsample factor
126 | :param min_crop_f: determines crop size s,
127 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
128 | :param max_crop_f: ""
129 | :param data_root:
130 | :param random_crop:
131 | """
132 | # downscale_f = 0
133 |
134 | assert size
135 | assert (size / downscale_f).is_integer()
136 | self.size = size
137 | self.LR_size = int(size / downscale_f)
138 | self.min_crop_f = min_crop_f
139 | self.max_crop_f = max_crop_f
140 | assert (max_crop_f <= 1.)
141 | self.center_crop = not random_crop
142 |
143 | self.image_rescaler = albumentations.SmallestMaxSize(
144 | max_size=size, interpolation=cv2.INTER_AREA)
145 |
146 |
147 | def __call__(self, samples):
148 | example = {}
149 | image, caption, aesthetics, key = samples
150 |
151 | image = np.array(image).astype(np.uint8)
152 |
153 | min_side_len = min(image.shape[:2])
154 | crop_side_len = min_side_len * np.random.uniform(
155 | self.min_crop_f, self.max_crop_f, size=None)
156 | crop_side_len = int(crop_side_len)
157 |
158 | if self.center_crop:
159 | self.cropper = albumentations.CenterCrop(
160 | height=crop_side_len, width=crop_side_len)
161 | else:
162 | self.cropper = albumentations.RandomCrop(
163 | height=crop_side_len, width=crop_side_len)
164 |
165 | image = self.cropper(image=image)['image']
166 | image = self.image_rescaler(image=image)['image']
167 |
168 | # -1, 1
169 | ref_image = (image / 127.5 - 1.0).astype(np.float32)
170 | ref_image = ref_image.transpose(2, 0, 1)
171 | vit_image = ref_image
172 | video_data = ref_image[np.newaxis, :, :, :]
173 |
174 |
175 | # example['image'] = image
176 | # # depth prior is set to 384
177 | # example['prior'] = resize_image(HWC3(image), 384)
178 | # example['caption'] = caption
179 | # example['aesthetics'] = aesthetics
180 | # example['key'] = key
181 |
182 | return ref_image, vit_image, video_data, caption, key
183 |
184 |
185 | @DATASETS.register_class()
186 | class LAIONImageDataset():
187 | def __init__(self,
188 | data_list,
189 | data_dir_list,
190 | max_words=1000,
191 | vit_resolution=[224, 224],
192 | resolution=(256, 256),
193 | max_frames=1,
194 | transforms=None,
195 | vit_transforms=None,
196 | **kwargs):
197 |
198 | aest = kwargs.get("aest", 4.0)
199 | punsafety = kwargs.get("punsafety", 0.2)
200 | min_crop_f = kwargs.get("min_crop_f", 1.0)
201 | self.num_samples = kwargs.get("num_samples", 60580*2000)
202 |
203 | assert resolution[0] == resolution[1]
204 | assert len(data_dir_list) == 1
205 | assert len(data_list) == 1
206 |
207 | self.web_dataset = wds.WebDataset(os.path.join(data_dir_list[0], data_list[0]), resampled=True).decode(
208 | my_decoder, 'rgb8').shuffle(1000).to_tuple(
209 | 'jpg', 'txt', 'json').compose(
210 | filter_fake(aest=aest, punsafety=punsafety)).map(
211 | Laion2b_Process(
212 | size=resolution[0],
213 | min_crop_f=min_crop_f)
214 | )
215 |
216 | def create_dataloader(self, batch_size, world_size, workers):
217 | num_samples = self.num_samples
218 | self.dataset = self.web_dataset.batched(batch_size, partial=False)
219 | round_fn = math.ceil
220 | global_batch_size = batch_size * world_size
221 | num_batches = round_fn(num_samples / global_batch_size)
222 | num_workers = max(1, workers)
223 | num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
224 | num_batches = num_worker_batches * num_workers
225 | num_samples = num_batches * global_batch_size
226 | dataset = self.dataset.with_epoch(num_worker_batches) # each worker is iterating over this
227 |
228 | self.dataloader = wds.WebLoader(
229 | dataset,
230 | batch_size=None,
231 | shuffle=False,
232 | num_workers=workers,
233 | persistent_workers=workers > 0,
234 | )
235 |
236 | self.dataloader.num_batches = num_batches
237 | self.dataloader.num_samples = num_samples
238 |
239 | print("#"*50)
240 | print(f"dataloader, num_batches:{num_batches}, num_samples:{num_samples}")
241 | print("#"*50)
242 | return self.dataloader
243 |
244 |
245 |
246 | if __name__ == "__main__":
247 | dataset = LAIONImageDataset(
248 | data_list=['{00000..00001}.tar'],
249 | data_dir_list=['/home/gxd/projects/Normal-Depth-Diffusion-Model/tools/download_dataset/laion-2ben-5_aes/'],
250 | max_words=1000,
251 | resolution=(256, 256),
252 | vit_resolution=(224, 224),
253 | max_frames=24,
254 | sample_fps=1,
255 | transforms=None,
256 | vit_transforms=None,
257 | get_first_frame=True,
258 | num_samples=1000,
259 | debug=True)
260 |
261 | batch_size = 20
262 | world_size = 1
263 | workers = 10
264 |
265 | dataloader = dataset.create_dataloader(batch_size, world_size, workers)
266 |
267 | import tqdm
268 | key_list = []
269 | for data in tqdm.tqdm(dataloader):
270 | pass
271 | print(data[0].shape, data[1].shape, data[2].shape)
272 | key_list.extend(data[4])
273 | print(len(key_list), len(set(key_list)))
274 |
275 |
276 |
--------------------------------------------------------------------------------
/tools/datasets/video_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import torch
5 | import random
6 | import logging
7 | import tempfile
8 | import numpy as np
9 | from copy import copy
10 | from PIL import Image
11 | import torch.nn.functional as F
12 | from torch.utils.data import Dataset
13 | from utils.registry_class import DATASETS
14 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter
15 |
16 | def read_camera_matrix_single(json_file):
17 | with open(json_file, 'r', encoding='utf8') as reader:
18 | json_content = json.load(reader)
19 |
20 | cond_camera_matrix = np.eye(4)
21 | cond_camera_matrix[:3, 0] = np.array(json_content['x'])
22 | cond_camera_matrix[:3, 1] = -np.array(json_content['y'])
23 | cond_camera_matrix[:3, 2] = -np.array(json_content['z'])
24 | cond_camera_matrix[:3, 3] = np.array(json_content['origin'])
25 |
26 |
27 | camera_matrix = np.eye(4)
28 | camera_matrix[:3, 0] = np.array(json_content['x'])
29 | camera_matrix[:3, 1] = np.array(json_content['y'])
30 | camera_matrix[:3, 2] = np.array(json_content['z'])
31 | camera_matrix[:3, 3] = np.array(json_content['origin'])
32 |
33 | return camera_matrix, cond_camera_matrix
34 |
35 | @DATASETS.register_class()
36 | class VideoDataset(Dataset):
37 | def __init__(self,
38 | data_list,
39 | data_dir_list,
40 | caption_dir,
41 | max_words=1000,
42 | resolution=(384, 256),
43 | vit_resolution=(224, 224),
44 | max_frames=16,
45 | sample_fps=8,
46 | transforms=None,
47 | vit_transforms=None,
48 | get_first_frame=True,
49 | prepare_lgm=False,
50 | **kwargs):
51 | self.prepare_lgm = prepare_lgm
52 | self.max_words = max_words
53 | self.max_frames = max_frames
54 | self.resolution = resolution
55 | self.vit_resolution = vit_resolution
56 | self.sample_fps = sample_fps
57 | self.transforms = transforms
58 | self.vit_transforms = vit_transforms
59 | self.get_first_frame = get_first_frame
60 |
61 | # @NOTE instead we read json
62 | image_list = []
63 | self.captions = json.load(open(caption_dir))
64 | for item_path, data_dir in zip(data_list, data_dir_list):
65 | lines = json.load(open(item_path))
66 | lines = [[data_dir, item] for item in lines]
67 | image_list.extend(lines)
68 | self.image_list = image_list
69 | self.replica = 1000
70 |
71 | if self.prepare_lgm:
72 | from core.options import config_defaults
73 | self.opt = config_defaults['big']
74 | # default camera intrinsics
75 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
76 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
77 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov
78 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov
79 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
80 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
81 | self.proj_matrix[2, 3] = 1
82 |
83 | def __getitem__(self, index):
84 | index = index % len(self.image_list)
85 | data_dir, file_path = self.image_list[index]
86 | video_key = file_path
87 | caption = self.captions[file_path] + ", 3d asset"
88 |
89 | try:
90 | ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data = self._get_video_data(data_dir, file_path)
91 | if self.prepare_lgm:
92 | results = self.prepare_gs(camera_data.clone(), fullreso_mask_data.clone(), fullreso_video_data.clone())
93 | results['images_output'] = fullreso_video_data # GT renderings of [512, 512] resolution in the range [0,1]
94 | except Exception as e:
95 | print(e)
96 | return self.__getitem__((index+1)%len(self)) # next available data
97 |
98 | if self.prepare_lgm:
99 | return results, ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key
100 | else:
101 | return ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key
102 |
103 | def prepare_gs(self, camera_data, mask_data, video_data): # mask_data [24,512,512,1]
104 |
105 | results = {}
106 |
107 | mask_data = mask_data.permute(0,3,1,2)
108 | results['masks_output'] = mask_data/255.0 # TODO normalize to [0, 1]
109 |
110 | T = camera_data.shape[0]
111 | camera_data = camera_data.view(T,4,4).contiguous()
112 |
113 | camera_data[:,1] *= -1
114 | camera_data[:,[1, 2]] = camera_data[:,[2, 1]]
115 | cam_dis = np.sqrt(camera_data[0,0,3]**2 + camera_data[0,1,3]**2 + camera_data[0,2,3]**2)
116 |
117 | # normalized camera feats as in paper (transform the first pose to a fixed position)
118 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cam_dis], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(camera_data[0])
119 | cam_poses = transform.unsqueeze(0) @ camera_data # [V, 4, 4]
120 |
121 | cam_poses_input = cam_poses.clone()
122 |
123 | rays_embeddings = []
124 | for i in range(T):
125 | rays_o, rays_d = get_rays(cam_poses_input[i], 256, 256, self.opt.fovy) # [h, w, 3]
126 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
127 | rays_embeddings.append(rays_plucker)
128 |
129 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V=24, 6, h, w]
130 | results['input'] = rays_embeddings
131 |
132 | # opengl to colmap camera for gs renderer
133 | cam_poses_input[:,:3,1:3] *= -1
134 |
135 | # cameras needed by gaussian rasterizer
136 | cam_view = torch.inverse(cam_poses_input).transpose(1, 2) # [V, 4, 4]
137 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
138 | cam_pos = - cam_poses_input[:, :3, 3] # [V, 3]
139 |
140 | results['cam_view'] = cam_view
141 | results['cam_view_proj'] = cam_view_proj
142 | results['cam_pos'] = cam_pos
143 |
144 | return results
145 |
146 | def _get_video_data(self, data_dir, file_path):
147 | prefix = os.path.join(data_dir, file_path, 'campos_512_v4')
148 |
149 | frames_path = [os.path.join(prefix, "{:05d}/{:05d}.png".format(frame_idx, frame_idx)) for frame_idx in range(24)]
150 | camera_path = [os.path.join(prefix, "{:05d}/{:05d}.json".format(frame_idx, frame_idx)) for frame_idx in range(24)]
151 |
152 | frame_list = []
153 | fullreso_frame_list = []
154 | camera_list = []
155 | mask_list = []
156 | fullreso_mask_list = []
157 | for frame_idx, frame_path in enumerate(frames_path):
158 | img = Image.open(frame_path).convert('RGBA')
159 | mask = torch.from_numpy(np.array(img.resize((self.resolution[1], self.resolution[0])))[:,:,-1]).unsqueeze(-1)
160 | mask_list.append(mask)
161 | fullreso_mask = torch.from_numpy(np.array(img)[:,:,-1]).unsqueeze(-1)
162 | fullreso_mask_list.append(fullreso_mask)
163 |
164 | width = img.width
165 | height = img.height
166 | # grey_scale = random.randint(128, 130) # random gray color
167 | grey_scale = 128
168 | image = Image.new('RGB', size=(width, height), color=(grey_scale,grey_scale,grey_scale))
169 | image.paste(img,(0,0),mask=img)
170 |
171 | fullreso_frame_list.append(torch.from_numpy(np.array(image)/255.0).float()) # for LGM rendering NOTE notice the data range [0,1]
172 | frame_list.append(image.resize((self.resolution[1], self.resolution[0])))
173 |
174 | _, camera_embedding = read_camera_matrix_single(camera_path[frame_idx])
175 | camera_list.append(torch.from_numpy(camera_embedding.flatten().astype(np.float32)))
176 |
177 | camera_data = torch.stack(camera_list, dim=0) # [24,16]
178 | mask_data = torch.stack(mask_list, dim=0)
179 | fullreso_mask_data = torch.stack(fullreso_mask_list, dim=0)
180 |
181 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
182 | fullreso_video_data = torch.zeros(self.max_frames, 3, 512, 512)
183 | if self.get_first_frame:
184 | ref_idx = 0
185 | else:
186 | ref_idx = int(len(frame_list)/2)
187 |
188 | mid_frame = copy(frame_list[ref_idx])
189 | vit_frame = self.vit_transforms(mid_frame)
190 | frames = self.transforms(frame_list)
191 | video_data[:len(frame_list), ...] = frames
192 |
193 | fullreso_video_data[:len(fullreso_frame_list), ...] = torch.stack(fullreso_frame_list, dim=0).permute(0,3,1,2)
194 |
195 | ref_frame = copy(frames[ref_idx])
196 |
197 | return ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data
198 |
199 | def __len__(self):
200 | return len(self.image_list)*self.replica
--------------------------------------------------------------------------------
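
Usage note for /tools/datasets/video_dataset.py above: prepare_gs feeds the Gaussian reconstructor per-view Plücker ray embeddings, cat([cross(rays_o, rays_d), rays_d]). Below is a standalone sketch of that embedding for a single pose; the 1.5 camera distance and identity orientation are illustrative placeholders, and config_defaults['big'] is assumed to expose fovy as the dataset code uses it.

import torch
from core.options import config_defaults
from core.utils import get_rays

opt = config_defaults['big']

# Illustrative OpenGL pose: identity orientation, camera pushed cam_dis units along +z,
# mirroring the normalized first pose that prepare_gs constructs.
cam_dis = 1.5
pose = torch.eye(4)
pose[2, 3] = cam_dis

rays_o, rays_d = get_rays(pose, 256, 256, opt.fovy)                              # [256, 256, 3] each
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [256, 256, 6]
ray_embedding = rays_plucker.permute(2, 0, 1)                                    # [6, 256, 256] per view
print(ray_embedding.shape)
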
/tools/datasets/video_i2v_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import torch
5 | import random
6 | import logging
7 | import tempfile
8 | import numpy as np
9 | from copy import copy
10 | from PIL import Image
11 | import torch.nn.functional as F
12 | from torch.utils.data import Dataset
13 | from utils.registry_class import DATASETS
14 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter
15 |
16 | def read_camera_matrix_single(json_file):
17 | with open(json_file, 'r', encoding='utf8') as reader:
18 | json_content = json.load(reader)
19 |
20 | cond_camera_matrix = np.eye(4)
21 | cond_camera_matrix[:3, 0] = np.array(json_content['x'])
22 | cond_camera_matrix[:3, 1] = -np.array(json_content['y'])
23 | cond_camera_matrix[:3, 2] = -np.array(json_content['z'])
24 | cond_camera_matrix[:3, 3] = np.array(json_content['origin'])
25 |
26 |
27 | camera_matrix = np.eye(4)
28 | camera_matrix[:3, 0] = np.array(json_content['x'])
29 | camera_matrix[:3, 1] = np.array(json_content['y'])
30 | camera_matrix[:3, 2] = np.array(json_content['z'])
31 | camera_matrix[:3, 3] = np.array(json_content['origin'])
32 |
33 | return camera_matrix, cond_camera_matrix
34 |
35 | @DATASETS.register_class()
36 | class Video_I2V_Dataset(Dataset):
37 | def __init__(self,
38 | data_list,
39 | data_dir_list,
40 | caption_dir,
41 | max_words=1000,
42 | resolution=(384, 256),
43 | vit_resolution=(224, 224),
44 | max_frames=16,
45 | sample_fps=8,
46 | transforms=None,
47 | vit_transforms=None,
48 | get_first_frame=True,
49 | prepare_lgm=False,
50 | **kwargs):
51 |
52 | self.prepare_lgm = prepare_lgm
53 | self.max_words = max_words
54 | self.max_frames = max_frames
55 | self.resolution = resolution
56 | self.vit_resolution = vit_resolution
57 | self.sample_fps = sample_fps
58 | self.transforms = transforms
59 | self.vit_transforms = vit_transforms
60 | self.get_first_frame = get_first_frame
61 |
62 | # @NOTE instead we read json
63 | image_list = []
64 | # self.captions = json.load(open(caption_dir))
65 | self.captions = None
66 | for item_path, data_dir in zip(data_list, data_dir_list):
67 | lines = json.load(open(item_path))
68 | lines = [[data_dir, item] for item in lines]
69 | image_list.extend(lines)
70 | self.image_list = image_list
71 | self.replica = 1000
72 |
73 | if self.prepare_lgm:
74 | from core.options import config_defaults
75 | self.opt = config_defaults['big']
76 | # default camera intrinsics
77 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
78 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
79 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov
80 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov
81 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
82 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
83 | self.proj_matrix[2, 3] = 1
84 |
85 | def __getitem__(self, index):
86 | index = index % len(self.image_list)
87 | data_dir, file_path = self.image_list[index]
88 | video_key = file_path
89 | caption = ""
90 |
91 | try:
92 | ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data = self._get_video_data(data_dir, file_path)
93 | if self.prepare_lgm:
94 | results = self.prepare_gs(camera_data.clone(), fullreso_mask_data.clone(), fullreso_video_data.clone())
95 | results['images_output'] = fullreso_video_data # GT renderings of [512, 512] resolution in the range [0,1]
96 | except Exception as e:
97 | print(e)
98 | return self.__getitem__((index+1)%len(self)) # next available data
99 |
100 | if self.prepare_lgm:
101 | return results, ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key
102 | else:
103 | return ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key
104 |
105 | def prepare_gs(self, camera_data, mask_data, video_data): # mask_data [24,512,512,1]
106 |
107 | results = {}
108 |
109 | mask_data = mask_data.permute(0,3,1,2)
110 | results['masks_output'] = mask_data/255.0 # TODO normalize to [0, 1]
111 |
112 | T = camera_data.shape[0]
113 | camera_data = camera_data.view(T,4,4).contiguous()
114 |
115 | camera_data[:,1] *= -1
116 | camera_data[:,[1, 2]] = camera_data[:,[2, 1]]
117 | cam_dis = np.sqrt(camera_data[0,0,3]**2 + camera_data[0,1,3]**2 + camera_data[0,2,3]**2)
118 |
119 | # normalized camera feats as in paper (transform the first pose to a fixed position)
120 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cam_dis], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(camera_data[0])
121 | cam_poses = transform.unsqueeze(0) @ camera_data # [V, 4, 4]
122 |
123 | cam_poses_input = cam_poses.clone()
124 |
125 | rays_embeddings = []
126 | for i in range(T):
127 | rays_o, rays_d = get_rays(cam_poses_input[i], 256, 256, self.opt.fovy) # [h, w, 3]
128 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
129 | rays_embeddings.append(rays_plucker)
130 |
131 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V=24, 6, h, w]
132 | results['input'] = rays_embeddings
133 |
134 | # opengl to colmap camera for gs renderer
135 | cam_poses_input[:,:3,1:3] *= -1
136 |
137 | # cameras needed by gaussian rasterizer
138 | cam_view = torch.inverse(cam_poses_input).transpose(1, 2) # [V, 4, 4]
139 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
140 | cam_pos = - cam_poses_input[:, :3, 3] # [V, 3]
141 |
142 | results['cam_view'] = cam_view
143 | results['cam_view_proj'] = cam_view_proj
144 | results['cam_pos'] = cam_pos
145 |
146 | return results
147 |
148 | def _get_video_data(self, data_dir, file_path):
149 | prefix = os.path.join(data_dir, file_path, 'campos_512_v4')
150 |
151 | frames_path = [os.path.join(prefix, "{:05d}/{:05d}.png".format(frame_idx, frame_idx)) for frame_idx in range(24)]
152 | camera_path = [os.path.join(prefix, "{:05d}/{:05d}.json".format(frame_idx, frame_idx)) for frame_idx in range(24)]
153 |
154 | frame_list = []
155 | fullreso_frame_list = []
156 | camera_list = []
157 | mask_list = []
158 | fullreso_mask_list = []
159 | for frame_idx, frame_path in enumerate(frames_path):
160 | img = Image.open(frame_path).convert('RGBA')
161 | mask = torch.from_numpy(np.array(img.resize((self.resolution[1], self.resolution[0])))[:,:,-1]).unsqueeze(-1)
162 | mask_list.append(mask)
163 | fullreso_mask = torch.from_numpy(np.array(img)[:,:,-1]).unsqueeze(-1)
164 | fullreso_mask_list.append(fullreso_mask)
165 |
166 | width = img.width
167 | height = img.height
168 | grey_scale = 255
169 | image = Image.new('RGB', size=(width, height), color=(grey_scale,grey_scale,grey_scale))
170 | image.paste(img,(0,0),mask=img)
171 |
172 | fullreso_frame_list.append(torch.from_numpy(np.array(image)/255.0).float()) # for LGM rendering NOTE notice the data range [0,1]
173 | frame_list.append(image.resize((self.resolution[1], self.resolution[0])))
174 |
175 | _, camera_embedding = read_camera_matrix_single(camera_path[frame_idx])
176 | camera_list.append(torch.from_numpy(camera_embedding.flatten().astype(np.float32)))
177 |
178 | camera_data = torch.stack(camera_list, dim=0) # [24,16]
179 | mask_data = torch.stack(mask_list, dim=0)
180 | fullreso_mask_data = torch.stack(fullreso_mask_list, dim=0)
181 |
182 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
183 |
184 | fullreso_video_data = torch.zeros(self.max_frames, 3, 512, 512)
185 |
186 | if self.get_first_frame:
187 | ref_idx = 0
188 | else:
189 | ref_idx = int(len(frame_list)/2)
190 |
191 | mid_frame = copy(frame_list[ref_idx])
192 | vit_frame = self.vit_transforms(mid_frame)
193 | frames = self.transforms(frame_list)
194 | video_data[:len(frame_list), ...] = frames
195 |
196 | if True: # random augmentation
197 | split_idx = np.random.randint(0, len(frame_list))
198 | video_data = torch.cat([video_data[split_idx:], video_data[:split_idx]], dim=0)
199 |
200 | fullreso_video_data[:len(fullreso_frame_list), ...] = torch.stack(fullreso_frame_list, dim=0).permute(0,3,1,2)
201 |
202 | ref_frame = copy(frames[ref_idx])
203 |
204 | return ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data
205 |
206 | def __len__(self):
207 | return len(self.image_list)*self.replica
208 |
209 |
--------------------------------------------------------------------------------
/tools/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | from .visual_train_it2v_video import *
2 |
--------------------------------------------------------------------------------
/tools/hooks/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/hooks/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/hooks/__pycache__/visual_train_it2v_video.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/visual_train_it2v_video.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/hooks/__pycache__/visual_train_it2v_video.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/visual_train_it2v_video.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/hooks/visual_train_it2v_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pynvml
4 | import logging
5 | from einops import rearrange
6 | import torch.cuda.amp as amp
7 |
8 | from utils.video_op import save_video_refimg_and_text
9 | from utils.registry_class import VISUAL
10 |
11 | from PIL import Image
12 | import numpy as np
13 |
14 |
15 | @VISUAL.register_class()
16 | class VisualTrainTextImageToVideo(object):
17 | def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs):
18 | super(VisualTrainTextImageToVideo, self).__init__(**kwargs)
19 | self.cfg = cfg_global
20 | self.viz_num = viz_num
21 | self.diffusion = diffusion
22 | self.autoencoder = autoencoder
23 | self.guide_scale = guide_scale
24 | self.partial_keys_list = partial_keys
25 | self.use_offset_noise = use_offset_noise
26 |
27 | def prepare_model_kwargs(self, partial_keys, full_model_kwargs):
28 | """
29 | """
30 | model_kwargs = [{}, {}]
31 | for partial_key in partial_keys:
32 | model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
33 | model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
34 | return model_kwargs
35 |
36 | @torch.no_grad()
37 | def run(self,
38 | model,
39 | video_data,
40 | captions,
41 | step=0,
42 | ref_frame=None,
43 | visual_kwards=[],
44 | **kwargs):
45 |
46 | cfg = self.cfg
47 | viz_num = min(self.viz_num, video_data.size(0))
48 |
49 | # save latent video_data first shape:[B,C,F,H,W]
50 | save_vid_data = video_data.clone().detach()
51 | for idx in range(save_vid_data.shape[0]):
52 | save_vid = save_vid_data[idx].permute(1,0,2,3)
53 | save_vid = torch.cat(save_vid.chunk(24),dim=-1).squeeze(0)
54 | save_vid = torch.cat(save_vid.chunk(4),dim=-2).squeeze(0)
55 | max_value = save_vid.max()
56 | min_value = save_vid.min()
57 |
58 | file_name = f'rank{cfg.rank:02d}_index{idx:02d}.png'
59 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
60 | os.makedirs(os.path.join(cfg.log_dir, f'sample_{step:06d}'), exist_ok=True)
61 | save_vid = (save_vid - min_value)/(max_value - min_value)
62 |             Image.fromarray((save_vid.cpu().numpy()*255).astype(np.uint8)).save(local_path)
63 |
64 | noise = torch.randn_like(video_data[:viz_num])
65 | if self.use_offset_noise:
66 | noise_strength = getattr(cfg, 'noise_strength', 0)
67 | b, c, f, *_ = video_data[:viz_num].shape
68 | noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device)
69 |
70 | # import ipdb; ipdb.set_trace()
71 | # print memory
72 | pynvml.nvmlInit()
73 | handle=pynvml.nvmlDeviceGetHandleByIndex(0)
74 | meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
75 | logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB')
76 |
77 | for keys in self.partial_keys_list:
78 | model_kwargs = self.prepare_model_kwargs(keys, visual_kwards)
79 | pre_name = '_'.join(keys)
80 | with amp.autocast(enabled=cfg.use_fp16):
81 | video_data = self.diffusion.ddim_sample_loop(
82 | noise=noise.clone(),
83 | model=model.eval(),
84 | model_kwargs=model_kwargs,
85 | guide_scale=self.guide_scale,
86 | ddim_timesteps=cfg.ddim_timesteps,
87 | eta=0.0)
88 |
89 | # save latent video_data pred shape:[B,C,F,H,W]
90 | save_vid_data_pred = video_data.clone().detach()
91 | for idx in range(save_vid_data_pred.shape[0]):
92 | save_vid = save_vid_data_pred[idx].permute(1,0,2,3)
93 | save_vid = torch.cat(save_vid.chunk(24),dim=-1).squeeze(0)
94 | save_vid = torch.cat(save_vid.chunk(4),dim=-2).squeeze(0)
95 | max_value = save_vid.max()
96 | min_value = save_vid.min()
97 |
98 | file_name = f'rank{cfg.rank:02d}_index{idx:02d}_pred.png'
99 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
100 | os.makedirs(os.path.join(cfg.log_dir, f'sample_{step:06d}'), exist_ok=True)
101 | save_vid = (save_vid - min_value)/(max_value - min_value)
102 |                 Image.fromarray((save_vid.cpu().numpy()*255).astype(np.uint8)).save(local_path)
103 |
104 | video_data = 1. / cfg.scale_factor * video_data # [64, 4, 32, 48]
105 | video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
106 | chunk_size = min(cfg.decoder_bs, video_data.shape[0])
107 | video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size,dim=0)
108 | decode_data = []
109 | for vd_data in video_data_list:
110 | gen_frames = self.autoencoder.decode(vd_data)
111 | decode_data.append(gen_frames)
112 | video_data = torch.cat(decode_data, dim=0)
113 | video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = viz_num)
114 |
115 | text_size = cfg.resolution[-1]
116 | ref_frame = ref_frame[:viz_num]
117 | file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}'
118 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
119 | os.makedirs(os.path.dirname(local_path), exist_ok=True)
120 | try:
121 | save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
122 | except Exception as e:
123 | logging.info(f'Step: {step} save text or video error with {e}')
--------------------------------------------------------------------------------
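Note: the latent-grid dump in the hook above tiles a [C, F, H, W] latent into one grayscale contact sheet, assuming 4 latent channels and 24 frames (the hard-coded chunk(24)/chunk(4)). A minimal standalone sketch of that tiling, using a random tensor in place of the real VAE latents:

    import torch
    import numpy as np
    from PIL import Image

    def tile_latent_to_image(latent: torch.Tensor, out_path: str) -> None:
        """Tile a [C, F, H, W] latent into a (C*H) x (F*W) grayscale contact sheet."""
        c, f, h, w = latent.shape
        grid = latent.permute(1, 0, 2, 3)                   # [F, C, H, W]
        grid = torch.cat(grid.chunk(f), dim=-1).squeeze(0)  # [C, H, F*W]  (frames side by side)
        grid = torch.cat(grid.chunk(c), dim=-2).squeeze(0)  # [C*H, F*W]   (channels stacked)
        grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)
        Image.fromarray((grid.cpu().numpy() * 255).astype(np.uint8)).save(out_path)

    # Random stand-in for a 4-channel, 24-frame latent as used by the hook above.
    tile_latent_to_image(torch.randn(4, 24, 32, 32), 'latent_grid.png')
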
/tools/hooks/visual_train_t2v.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pynvml
4 | import logging
5 | from einops import rearrange
6 | import torch.cuda.amp as amp
7 |
8 | from utils.video_op import save_video_refimg_and_text
9 | from utils.registry_class import VISUAL
10 |
11 | from PIL import Image
12 | import numpy as np
13 |
14 | @VISUAL.register_class()
15 | class VisualTrainTextToVideo(object):
16 | def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs):
17 | super(VisualTrainTextToVideo, self).__init__(**kwargs)
18 | self.cfg = cfg_global
19 | self.viz_num = viz_num
20 | self.diffusion = diffusion
21 | self.autoencoder = autoencoder
22 | self.guide_scale = guide_scale
23 | self.partial_keys_list = partial_keys
24 | self.use_offset_noise = use_offset_noise
25 |
26 | def prepare_model_kwargs(self, partial_keys, full_model_kwargs):
27 |         """Filter the [cond, uncond] model kwargs pair
28 |         down to the requested conditioning keys."""
29 | model_kwargs = [{}, {}]
30 | for partial_key in partial_keys:
31 | model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
32 | model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
33 | return model_kwargs
34 |
35 | @torch.no_grad()
36 | def run(self,
37 | model,
38 | video_data,
39 | captions,
40 | step=0,
41 | ref_frame=None,
42 | visual_kwards=[],
43 | **kwargs):
44 | cfg = self.cfg
45 | viz_num = self.viz_num
46 |
47 |
48 |
49 | noise = torch.randn_like(video_data[:viz_num]) # viz_num: 8
50 | if self.use_offset_noise:
51 | noise_strength = getattr(cfg, 'noise_strength', 0)
52 | b, c, f, *_ = video_data[:viz_num].shape
53 | noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device)
54 |
55 | # print memory
56 | pynvml.nvmlInit()
57 | handle=pynvml.nvmlDeviceGetHandleByIndex(0)
58 | meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
59 | logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB')
60 |
61 | for keys in self.partial_keys_list:
62 | model_kwargs = self.prepare_model_kwargs(keys, visual_kwards)
63 | pre_name = '_'.join(keys)
64 | with amp.autocast(enabled=cfg.use_fp16):
65 | video_data = self.diffusion.ddim_sample_loop(
66 | noise=noise.clone(),
67 | model=model.eval(),
68 | model_kwargs=model_kwargs,
69 | guide_scale=self.guide_scale,
70 | ddim_timesteps=cfg.ddim_timesteps,
71 | eta=0.0)
72 |
73 | video_data = 1. / cfg.scale_factor * video_data # [64, 4, 32, 48]
74 | video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
75 | chunk_size = min(cfg.decoder_bs, video_data.shape[0])
76 | video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size,dim=0)
77 | decode_data = []
78 | for vd_data in video_data_list:
79 | gen_frames = self.autoencoder.decode(vd_data)
80 | decode_data.append(gen_frames)
81 | video_data = torch.cat(decode_data, dim=0)
82 | video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = viz_num)
83 |
84 | text_size = cfg.resolution[-1]
85 | ref_frame = ref_frame[:viz_num]
86 | file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}'
87 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
88 | os.makedirs(os.path.dirname(local_path), exist_ok=True)
89 | try:
90 | save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
91 | except Exception as e:
92 | logging.info(f'Step: {step} save text or video error with {e}')
93 |
94 |
95 |
--------------------------------------------------------------------------------
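Note: both training hooks pass model_kwargs as a [conditional, unconditional] pair of dicts for classifier-free guidance, and prepare_model_kwargs only filters that pair down to the requested keys. A minimal illustration of the filtering; the key names below are placeholders, not the repository's actual conditioning keys:

    def prepare_model_kwargs(partial_keys, full_model_kwargs):
        # full_model_kwargs is a [cond, uncond] pair of dicts; keep only partial_keys.
        model_kwargs = [{}, {}]
        for key in partial_keys:
            model_kwargs[0][key] = full_model_kwargs[0][key]
            model_kwargs[1][key] = full_model_kwargs[1][key]
        return model_kwargs

    cond = {'y': 'text embedding', 'image': 'clip image embedding', 'fps': 8}
    uncond = {'y': 'empty embedding', 'image': 'zero image embedding', 'fps': 8}
    print(prepare_model_kwargs(['y', 'image'], [cond, uncond]))
    # -> [{'y': 'text embedding', 'image': 'clip image embedding'},
    #     {'y': 'empty embedding', 'image': 'zero image embedding'}]
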
/tools/inferences/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference_i2vgen_entrance import *
2 | from .inference_text2video_entrance import *
3 |
4 |
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/inference_text2video_entrance.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_text2video_entrance.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/inferences/__pycache__/inference_text2video_entrance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_text2video_entrance.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_embedder import FrozenOpenCLIPEmbedder
2 | from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL
3 | from .clip_embedder import *
4 | from .autoencoder import *
5 | from .unet import *
6 | from .diffusions import *
--------------------------------------------------------------------------------
/tools/modules/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/autoencoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/autoencoder.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/clip_embedder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/clip_embedder.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/clip_embedder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/clip_embedder.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/clip_embedder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import open_clip
5 | import numpy as np
6 | import torch.nn as nn
7 | import torchvision.transforms as T
8 | from torch.utils.checkpoint import checkpoint
9 | from utils.registry_class import EMBEDDER
10 |
11 |
12 | @EMBEDDER.register_class()
13 | class FrozenOpenCLIPEmbedder(nn.Module):
14 | """
15 | Uses the OpenCLIP transformer encoder for text
16 | """
17 | LAYERS = [
18 | #"pooled",
19 | "last",
20 | "penultimate"
21 | ]
22 | def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
23 | freeze=True, layer="last"):
24 | super().__init__()
25 | assert layer in self.LAYERS
26 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
27 | del model.visual
28 | self.model = model
29 |
30 | self.device = device
31 | self.max_length = max_length
32 | if freeze:
33 | self.freeze()
34 | self.layer = layer
35 | if self.layer == "last":
36 | self.layer_idx = 0
37 | elif self.layer == "penultimate":
38 | self.layer_idx = 1
39 | else:
40 | raise NotImplementedError()
41 |
42 | def freeze(self):
43 | self.model = self.model.eval()
44 | for param in self.parameters():
45 | param.requires_grad = False
46 |
47 | def forward(self, text):
48 | tokens = open_clip.tokenize(text)
49 | z = self.encode_with_transformer(tokens.to(self.device))
50 | return z
51 |
52 | def encode_with_transformer(self, text):
53 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
54 | x = x + self.model.positional_embedding
55 | x = x.permute(1, 0, 2) # NLD -> LND
56 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
57 | x = x.permute(1, 0, 2) # LND -> NLD
58 | x = self.model.ln_final(x)
59 | return x
60 |
61 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
62 | for i, r in enumerate(self.model.transformer.resblocks):
63 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
64 | break
65 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
66 | x = checkpoint(r, x, attn_mask)
67 | else:
68 | x = r(x, attn_mask=attn_mask)
69 | return x
70 |
71 | def encode(self, text):
72 | return self(text)
73 |
74 |
75 | @EMBEDDER.register_class()
76 | class FrozenOpenCLIPVisualEmbedder(nn.Module):
77 | """
78 |     Uses the OpenCLIP vision transformer (ViT) to encode images
79 | """
80 | LAYERS = [
81 | #"pooled",
82 | "last",
83 | "penultimate"
84 | ]
85 | def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77,
86 | freeze=True, layer="last"):
87 | super().__init__()
88 | assert layer in self.LAYERS
89 | model, _, preprocess = open_clip.create_model_and_transforms(
90 | arch, device=torch.device('cpu'), pretrained=pretrained)
91 |         # Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
92 | del model.transformer
93 | self.model = model
94 | data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255
95 | self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
96 |
97 | self.device = device
98 | self.max_length = max_length # 77
99 | if freeze:
100 | self.freeze()
101 | self.layer = layer # 'penultimate'
102 | if self.layer == "last":
103 | self.layer_idx = 0
104 | elif self.layer == "penultimate":
105 | self.layer_idx = 1
106 | else:
107 | raise NotImplementedError()
108 |
109 | def freeze(self): # model.encode_image(torch.randn(2,3,224,224))
110 | self.model = self.model.eval()
111 | for param in self.parameters():
112 | param.requires_grad = False
113 |
114 | def forward(self, image):
115 | # tokens = open_clip.tokenize(text)
116 | z = self.model.encode_image(image.to(self.device))
117 | return z
118 |
119 | def encode_with_transformer(self, text):
120 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
121 | x = x + self.model.positional_embedding
122 | x = x.permute(1, 0, 2) # NLD -> LND
123 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
124 | x = x.permute(1, 0, 2) # LND -> NLD
125 | x = self.model.ln_final(x)
126 |
127 | return x
128 |
129 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
130 | for i, r in enumerate(self.model.transformer.resblocks):
131 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
132 | break
133 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
134 | x = checkpoint(r, x, attn_mask)
135 | else:
136 | x = r(x, attn_mask=attn_mask)
137 | return x
138 |
139 | def encode(self, text):
140 | return self(text)
141 |
142 |
143 |
144 | @EMBEDDER.register_class()
145 | class FrozenOpenCLIPTtxtVisualEmbedder(nn.Module):
146 | """
147 |     Uses the OpenCLIP text and image encoders jointly
148 | """
149 | LAYERS = [
150 | #"pooled",
151 | "last",
152 | "penultimate"
153 | ]
154 | def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
155 | freeze=True, layer="last", **kwargs):
156 | super().__init__()
157 | assert layer in self.LAYERS
158 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
159 | self.model = model
160 |
161 | self.device = device
162 | self.max_length = max_length
163 | if freeze:
164 | self.freeze()
165 | self.layer = layer
166 | if self.layer == "last":
167 | self.layer_idx = 0
168 | elif self.layer == "penultimate":
169 | self.layer_idx = 1
170 | else:
171 | raise NotImplementedError()
172 |
173 | def freeze(self):
174 | self.model = self.model.eval()
175 | for param in self.parameters():
176 | param.requires_grad = False
177 |
178 | # def forward(self, text):
179 | # tokens = open_clip.tokenize(text)
180 | # z = self.encode_with_transformer(tokens.to(self.device))
181 | # return z
182 |
183 | def forward(self, image=None, text=None):
184 | # xi = self.encode_image(image) if image is not None else None
185 | xi = self.model.encode_image(image.to(self.device)) if image is not None else None
186 | # tokens = open_clip.tokenize(text, truncate=True)
187 | tokens = open_clip.tokenize(text)
188 | xt, x = self.encode_with_transformer(tokens.to(self.device))
189 | return xi, xt, x
190 |
191 | def encode_with_transformer(self, text):
192 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
193 | x = x + self.model.positional_embedding
194 | x = x.permute(1, 0, 2) # NLD -> LND
195 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
196 | x = x.permute(1, 0, 2) # LND -> NLD
197 | x = self.model.ln_final(x)
198 | xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
199 | return xt, x
200 |
201 | # def encode_with_transformer(self, text):
202 | # x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
203 | # x = x + self.model.positional_embedding
204 | # x = x.permute(1, 0, 2) # NLD -> LND
205 | # x = self.model.transformer(x)
206 | # x = x.permute(1, 0, 2) # LND -> NLD
207 | # x = self.model.ln_final(x)
208 | # xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
209 | # # text embedding, token embedding
210 | # return xt, x
211 |
212 | def encode_image(self, image):
213 | return self.model.visual(image)
214 |
215 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
216 | for i, r in enumerate(self.model.transformer.resblocks):
217 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
218 | break
219 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
220 | x = checkpoint(r, x, attn_mask)
221 | else:
222 | x = r(x, attn_mask=attn_mask)
223 | return x
224 |
225 | def encode(self, text):
226 |
227 |         return self(text=text)
228 |
229 |
--------------------------------------------------------------------------------
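Note: a minimal usage sketch for FrozenOpenCLIPEmbedder above. The checkpoint path is an assumption; point it at the open_clip_pytorch_model.bin checkpoint referenced in tools/modules/config.py:

    import torch
    from tools.modules.clip_embedder import FrozenOpenCLIPEmbedder

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Hypothetical checkpoint location; substitute the repository's actual path.
    embedder = FrozenOpenCLIPEmbedder(
        pretrained='./pretrained_models/modelscope_t2v/open_clip_pytorch_model.bin',
        layer='penultimate',
        device=device,
    ).to(device)

    with torch.no_grad():
        context = embedder(['a blue car rotating on a turntable'])
    print(context.shape)  # expected: [1, 77, 1024] for ViT-H-14
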
/tools/modules/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import os.path as osp
4 | from datetime import datetime
5 | from easydict import EasyDict
6 | import os
7 |
8 | cfg = EasyDict(__name__='Config: VideoLDM Decoder')
9 |
10 | # -------------------------------distributed training--------------------------
11 | pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
12 | gpus_per_machine = torch.cuda.device_count()
13 | world_size = pmi_world_size * gpus_per_machine
14 | # -----------------------------------------------------------------------------
15 |
16 |
17 | # ---------------------------Dataset Parameter---------------------------------
18 | cfg.mean = [0.5, 0.5, 0.5]
19 | cfg.std = [0.5, 0.5, 0.5]
20 | cfg.max_words = 1000
21 | cfg.num_workers = 8
22 | cfg.prefetch_factor = 2
23 |
24 | # Placeholder
25 | cfg.resolution = [448, 256]
26 | cfg.vit_out_dim = 1024
27 | cfg.vit_resolution = 336
28 | cfg.depth_clamp = 10.0
29 | cfg.misc_size = 384
30 | cfg.depth_std = 20.0
31 |
32 | cfg.frame_lens = [32, 32, 32, 1]
33 | cfg.sample_fps = [4, ]
34 | cfg.vid_dataset = {
35 | 'type': 'VideoBaseDataset',
36 | 'data_list': [],
37 | 'max_words': cfg.max_words,
38 | 'resolution': cfg.resolution}
39 | cfg.img_dataset = {
40 | 'type': 'ImageBaseDataset',
41 | 'data_list': ['laion_400m',],
42 | 'max_words': cfg.max_words,
43 | 'resolution': cfg.resolution}
44 |
45 | cfg.batch_sizes = {
46 | str(1):256,
47 | str(4):4,
48 | str(8):4,
49 | str(16):4}
50 | # -----------------------------------------------------------------------------
51 |
52 |
53 | # ---------------------------Mode Parameters-----------------------------------
54 | # Diffusion
55 | cfg.Diffusion = {
56 | 'type': 'DiffusionDDIM',
57 | 'schedule': 'cosine', # cosine
58 | 'schedule_param': {
59 | 'num_timesteps': 1000,
60 | 'cosine_s': 0.008,
61 | 'zero_terminal_snr': True,
62 | },
63 | 'mean_type': 'v', # [v, eps]
64 | 'loss_type': 'mse',
65 | 'var_type': 'fixed_small',
66 | 'rescale_timesteps': False,
67 | 'noise_strength': 0.1,
68 | 'ddim_timesteps': 50
69 | }
70 | cfg.ddim_timesteps = 50 # official: 250
71 | cfg.use_div_loss = False
72 | # classifier-free guidance
73 | cfg.p_zero = 0.9
74 | cfg.guide_scale = 3.0
75 |
76 | # clip vision encoder
77 | cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
78 | cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
79 |
80 | # Model
81 | cfg.scale_factor = 0.18215
82 | cfg.use_checkpoint = True
83 | cfg.use_sharded_ddp = False
84 | cfg.use_fsdp = False
85 | cfg.use_fp16 = True
86 | cfg.temporal_attention = True
87 |
88 | cfg.UNet = {
89 | 'type': 'UNetSD',
90 | 'in_dim': 4,
91 | 'dim': 320,
92 | 'y_dim': cfg.vit_out_dim,
93 | 'context_dim': 1024,
94 | 'out_dim': 8,
95 | 'dim_mult': [1, 2, 4, 4],
96 | 'num_heads': 8,
97 | 'head_dim': 64,
98 | 'num_res_blocks': 2,
99 | 'attn_scales': [1 / 1, 1 / 2, 1 / 4],
100 | 'dropout': 0.1,
101 | 'temporal_attention': cfg.temporal_attention,
102 | 'temporal_attn_times': 1,
103 | 'use_checkpoint': cfg.use_checkpoint,
104 | 'use_fps_condition': False,
105 | 'use_sim_mask': False
106 | }
107 |
108 | # autoencoder from Stable Diffusion
109 | cfg.guidances = []
110 | cfg.auto_encoder = {
111 | 'type': 'AutoencoderKL',
112 | 'ddconfig': {
113 | 'double_z': True,
114 | 'z_channels': 4,
115 | 'resolution': 256,
116 | 'in_channels': 3,
117 | 'out_ch': 3,
118 | 'ch': 128,
119 | 'ch_mult': [1, 2, 4, 4],
120 | 'num_res_blocks': 2,
121 | 'attn_resolutions': [],
122 | 'dropout': 0.0,
123 | 'video_kernel_size': [3, 1, 1]
124 | },
125 | 'embed_dim': 4,
126 | 'pretrained': './pretrained_models/modelscope_t2v/VQGAN_autoencoder.pth'
127 | }
128 | # clip embedder
129 | cfg.embedder = {
130 | 'type': 'FrozenOpenCLIPEmbedder',
131 | 'layer': 'penultimate',
132 | 'pretrained': 'modelscope_t2v/open_clip_pytorch_model.bin'
133 | }
134 | # -----------------------------------------------------------------------------
135 |
136 | # ---------------------------Training Settings---------------------------------
137 | # training and optimizer
138 | cfg.ema_decay = 0.9999
139 | cfg.num_steps = 600000
140 | cfg.lr = 5e-5
141 | cfg.weight_decay = 0.0
142 | cfg.betas = (0.9, 0.999)
143 | cfg.eps = 1.0e-8
144 | cfg.chunk_size = 16
145 | cfg.decoder_bs = 8
146 | cfg.alpha = 0.7
147 | cfg.save_ckp_interval = 1000
148 |
149 | # scheduler
150 | cfg.warmup_steps = 10
151 | cfg.decay_mode = 'cosine'
152 |
153 | # acceleration
154 | cfg.use_ema = True
155 | if world_size<2:
156 | cfg.use_ema = False
157 | cfg.load_from = None
158 | # -----------------------------------------------------------------------------
159 |
160 |
161 | # ----------------------------Pretrain Settings---------------------------------
162 | cfg.Pretrain = {
163 | 'type': 'pretrain_specific_strategies',
164 | 'fix_weight': False,
165 | 'grad_scale': 0.2,
166 | 'resume_checkpoint': 'models/jiuniu_0267000.pth',
167 | 'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json',
168 | }
169 | # -----------------------------------------------------------------------------
170 |
171 |
172 | # -----------------------------Visual-------------------------------------------
173 | # Visual videos
174 | cfg.viz_interval = 1000
175 | cfg.visual_train = {
176 | 'type': 'VisualTrainTextImageToVideo',
177 | }
178 | cfg.visual_inference = {
179 | 'type': 'VisualGeneratedVideos',
180 | }
181 | cfg.inference_list_path = ''
182 |
183 | # logging
184 | cfg.log_interval = 100
185 |
186 | ### Default log_dir
187 | cfg.log_dir = 'workspace/temp_dir'
188 | # -----------------------------------------------------------------------------
189 |
190 |
191 | # ---------------------------Others--------------------------------------------
192 | # seed
193 | cfg.seed = 8888
194 | # negative prompt used for classifier-free guidance
195 | cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, disfigured, disconnected limbs, Ugly faces, incomplete arms'
196 | # -----------------------------------------------------------------------------
197 |
198 |
--------------------------------------------------------------------------------
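Note: tools/modules/config.py only provides the defaults; experiments override them by merging a YAML file from configs/ on top of this cfg (see utils/assign_cfg.py). A small sketch of overriding a few entries in the same update-in-place style; the values below are illustrative only:

    from tools.modules.config import cfg

    cfg.ddim_timesteps = 250              # e.g. a finer DDIM sampling schedule
    cfg.guide_scale = 9.0                 # stronger classifier-free guidance
    cfg.UNet.update({'use_fps_condition': True})
    print(cfg.Diffusion['schedule'], cfg.ddim_timesteps, cfg.UNet['use_fps_condition'])
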
/tools/modules/diffusions/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusion_ddim import *
2 |
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/__pycache__/schedules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/schedules.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/diffusions/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
5 |
6 | def kl_divergence(mu1, logvar1, mu2, logvar2):
7 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2))
8 |
9 | def standard_normal_cdf(x):
10 | r"""A fast approximation of the cumulative distribution function of the standard normal.
11 | """
12 | return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
13 |
14 | def discretized_gaussian_log_likelihood(x0, mean, log_scale):
15 | assert x0.shape == mean.shape == log_scale.shape
16 | cx = x0 - mean
17 | inv_stdv = torch.exp(-log_scale)
18 | cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
19 | cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
20 | log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
21 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
22 | cdf_delta = cdf_plus - cdf_min
23 | log_probs = torch.where(
24 | x0 < -0.999,
25 | log_cdf_plus,
26 | torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))))
27 | assert log_probs.shape == x0.shape
28 | return log_probs
29 |
--------------------------------------------------------------------------------
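Note: kl_divergence above is the closed-form KL between two diagonal Gaussians parameterised by mean and log-variance. A quick sanity check that it vanishes for identical distributions and matches the per-dimension scalar formula otherwise:

    import torch
    from tools.modules.diffusions.losses import kl_divergence

    mu1, logvar1 = torch.zeros(3), torch.zeros(3)
    mu2, logvar2 = torch.ones(3), torch.log(torch.full((3,), 4.0))

    print(kl_divergence(mu1, logvar1, mu1, logvar1))  # zeros: KL(p || p) = 0
    print(kl_divergence(mu1, logvar1, mu2, logvar2))
    # per dimension: 0.5 * (-1 + ln 4 + 1/4 + 1/4) ≈ 0.4431
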
/tools/modules/diffusions/schedules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | def beta_schedule(schedule='cosine',
6 | num_timesteps=1000,
7 | zero_terminal_snr=False,
8 | **kwargs):
9 | # compute betas
10 | betas = {
11 | 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
12 | 'linear': linear_schedule,
13 | 'linear_sd': linear_sd_schedule,
14 | 'quadratic': quadratic_schedule,
15 | 'cosine': cosine_schedule
16 | }[schedule](num_timesteps, **kwargs)
17 |
18 | if zero_terminal_snr and betas.max() != 1.0:
19 | betas = rescale_zero_terminal_snr(betas)
20 |
21 | return betas
22 |
23 |
24 | def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
25 | scale = 1000.0 / num_timesteps
26 | init_beta = init_beta or scale * 0.0001
27 |     last_beta = last_beta or scale * 0.02
28 | return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
29 |
30 | def logsnr_cosine_interp_schedule(
31 | num_timesteps,
32 | scale_min=2,
33 | scale_max=4,
34 | logsnr_min=-15,
35 | logsnr_max=15,
36 | **kwargs):
37 | return logsnrs_to_sigmas(
38 | _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max))
39 |
40 | def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
41 | return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
42 |
43 |
44 | def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
45 | init_beta = init_beta or 0.0015
46 | last_beta = last_beta or 0.0195
47 | return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
48 |
49 |
50 | def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
51 | betas = []
52 | for step in range(num_timesteps):
53 | t1 = step / num_timesteps
54 | t2 = (step + 1) / num_timesteps
55 | fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
56 | betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
57 | return torch.tensor(betas, dtype=torch.float64)
58 |
59 |
60 | # def cosine_schedule(n, cosine_s=0.008, **kwargs):
61 | # ramp = torch.linspace(0, 1, n + 1)
62 | # square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
63 | # betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
64 | # return betas_to_sigmas(betas)
65 |
66 |
67 | def betas_to_sigmas(betas):
68 | return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
69 |
70 |
71 | def sigmas_to_betas(sigmas):
72 | square_alphas = 1 - sigmas**2
73 | betas = 1 - torch.cat(
74 | [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
75 | return betas
76 |
77 |
78 |
79 | def sigmas_to_logsnrs(sigmas):
80 | square_sigmas = sigmas**2
81 | return torch.log(square_sigmas / (1 - square_sigmas))
82 |
83 |
84 | def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
85 | t_min = math.atan(math.exp(-0.5 * logsnr_min))
86 | t_max = math.atan(math.exp(-0.5 * logsnr_max))
87 | t = torch.linspace(1, 0, n)
88 | logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
89 | return logsnrs
90 |
91 |
92 | def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
93 | logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
94 | logsnrs += 2 * math.log(1 / scale)
95 | return logsnrs
96 |
97 | def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
98 | ramp = torch.linspace(1, 0, n)
99 | min_inv_rho = sigma_min**(1 / rho)
100 | max_inv_rho = sigma_max**(1 / rho)
101 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
102 | sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
103 | return sigmas
104 |
105 | def _logsnr_cosine_interp(n,
106 | logsnr_min=-15,
107 | logsnr_max=15,
108 | scale_min=2,
109 | scale_max=4):
110 | t = torch.linspace(1, 0, n)
111 | logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
112 | logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
113 | logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
114 | return logsnrs
115 |
116 |
117 | def logsnrs_to_sigmas(logsnrs):
118 | return torch.sqrt(torch.sigmoid(-logsnrs))
119 |
120 |
121 | def rescale_zero_terminal_snr(betas):
122 | """
123 | Rescale Schedule to Zero Terminal SNR
124 | """
125 | # Convert betas to alphas_bar_sqrt
126 | alphas = 1 - betas
127 | alphas_bar = alphas.cumprod(0)
128 | alphas_bar_sqrt = alphas_bar.sqrt()
129 |
130 |     # Store old values.
131 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
132 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
133 | # Shift so last timestep is zero.
134 | alphas_bar_sqrt -= alphas_bar_sqrt_T
135 | # Scale so first timestep is back to old value.
136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
137 |
138 | # Convert alphas_bar_sqrt to betas
139 | alphas_bar = alphas_bar_sqrt ** 2
140 | alphas = alphas_bar[1:] / alphas_bar[:-1]
141 | alphas = torch.cat([alphas_bar[0:1], alphas])
142 | betas = 1 - alphas
143 | return betas
144 |
145 |
--------------------------------------------------------------------------------
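Note: with the defaults in tools/modules/config.py (cosine schedule, zero_terminal_snr=True), the helpers above produce a 1000-step beta schedule whose terminal cumulative alpha is driven to zero, i.e. the last timestep carries no signal. A small check:

    import torch
    from tools.modules.diffusions.schedules import beta_schedule

    betas = beta_schedule('cosine', num_timesteps=1000, cosine_s=0.008, zero_terminal_snr=True)
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)
    print(betas.shape, float(alphas_bar[-1]))  # terminal alpha_bar ~ 0 => terminal SNR ~ 0
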
/tools/modules/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_i2vgen import *
2 | from .unet_t2v import *
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_attn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_attn.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_net.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_net.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_net.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/depthwise_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/unet_i2vgen.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_i2vgen.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/unet_i2vgen.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_i2vgen.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/unet_t2v.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_t2v.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/unet_t2v.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_t2v.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/modules/unet/mha_flash.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.cuda.amp as amp
4 | import torch.nn.functional as F
5 | import math
6 | import os
7 | import time
8 | import numpy as np
9 | import random
10 |
11 | from flash_attn.flash_attention import FlashAttention  # required by FlashAttentionBlock below (flash-attn v1.x API)
12 |
13 | class FlashAttentionBlock(nn.Module):
14 |
15 | def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
16 | # consider head_dim first, then num_heads
17 | num_heads = dim // head_dim if head_dim else num_heads
18 | head_dim = dim // num_heads
19 | assert num_heads * head_dim == dim
20 | super(FlashAttentionBlock, self).__init__()
21 | self.dim = dim
22 | self.context_dim = context_dim
23 | self.num_heads = num_heads
24 | self.head_dim = head_dim
25 | self.scale = math.pow(head_dim, -0.25)
26 |
27 | # layers
28 | self.norm = nn.GroupNorm(32, dim)
29 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
30 | if context_dim is not None:
31 | self.context_kv = nn.Linear(context_dim, dim * 2)
32 | self.proj = nn.Conv2d(dim, dim, 1)
33 |
34 | if self.head_dim <= 128 and (self.head_dim % 8) == 0:
35 | new_scale = math.pow(head_dim, -0.5)
36 | self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
37 |
38 | # zero out the last layer params
39 | nn.init.zeros_(self.proj.weight)
40 | # self.apply(self._init_weight)
41 |
42 |
43 | def _init_weight(self, module):
44 | if isinstance(module, nn.Linear):
45 | module.weight.data.normal_(mean=0.0, std=0.15)
46 | if module.bias is not None:
47 | module.bias.data.zero_()
48 | elif isinstance(module, nn.Conv2d):
49 | module.weight.data.normal_(mean=0.0, std=0.15)
50 | if module.bias is not None:
51 | module.bias.data.zero_()
52 |
53 | def forward(self, x, context=None):
54 | r"""x: [B, C, H, W].
55 | context: [B, L, C] or None.
56 | """
57 | identity = x
58 | b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
59 |
60 | # compute query, key, value
61 | x = self.norm(x)
62 | q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
63 | if context is not None:
64 | ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
65 | k = torch.cat([ck, k], dim=-1)
66 | v = torch.cat([cv, v], dim=-1)
67 | cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
68 | q = torch.cat([q, cq], dim=-1)
69 |
70 | qkv = torch.cat([q,k,v], dim=1)
71 | origin_dtype = qkv.dtype
72 | qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
73 | out, _ = self.flash_attn(qkv)
74 |         out = out.to(origin_dtype)
75 |
76 | if context is not None:
77 | out = out[:, :-4, :, :]
78 | out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
79 |
80 | # output
81 | x = self.proj(out)
82 | return x + identity
83 |
84 | if __name__ == '__main__':
85 | batch_size = 8
86 | flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
87 |
88 | x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
89 | context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
90 | # context = None
91 | flash_net.eval()
92 |
93 | with amp.autocast(enabled=True):
94 | # warm up
95 | for i in range(5):
96 | y = flash_net(x, context)
97 | torch.cuda.synchronize()
98 | s1 = time.time()
99 | for i in range(10):
100 | y = flash_net(x, context)
101 | torch.cuda.synchronize()
102 | s2 = time.time()
103 |
104 | print(f'Average cost time {(s2-s1)*1000/10} ms')
--------------------------------------------------------------------------------
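Note: the block above depends on the optional flash-attn package (v1.x API). As a reference point only, the same kind of spatial self-attention over a [B, C, H, W] feature map can be sketched with plain PyTorch (>= 2.0) scaled_dot_product_attention; this is an illustration, not the repository's code path, and it omits the learned q/k/v and output projections:

    import torch
    import torch.nn.functional as F

    def plain_spatial_attention(x: torch.Tensor, num_heads: int = 8) -> torch.Tensor:
        """Self-attention over the H*W tokens of a [B, C, H, W] map (no learned projections)."""
        b, c, h, w = x.shape
        d = c // num_heads
        tokens = x.flatten(2).transpose(1, 2)                            # [B, H*W, C]
        heads = tokens.reshape(b, h * w, num_heads, d).transpose(1, 2)   # [B, heads, H*W, d]
        out = F.scaled_dot_product_attention(heads, heads, heads)        # q = k = v for brevity
        return out.transpose(1, 2).reshape(b, h * w, c).transpose(1, 2).reshape(b, c, h, w)

    y = plain_spatial_attention(torch.randn(2, 1280, 16, 16))
    print(y.shape)  # torch.Size([2, 1280, 16, 16])
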
/tools/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_t2v_enterance import *
2 | from .train_i2v_enterance import *
--------------------------------------------------------------------------------
/tools/train/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/train/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/train/__pycache__/train_i2v_enterance.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_i2v_enterance.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/train/__pycache__/train_i2v_enterance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_i2v_enterance.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/train/__pycache__/train_t2v_enterance.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_t2v_enterance.cpython-310.pyc
--------------------------------------------------------------------------------
/tools/train/__pycache__/train_t2v_enterance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_t2v_enterance.cpython-38.pyc
--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import math
6 | import random
7 | import logging
8 | import itertools
9 | import numpy as np
10 |
11 | from utils.config import Config
12 | from utils.registry_class import ENGINE
13 |
14 | from tools import *
15 |
16 | if __name__ == '__main__':
17 | cfg_update = Config(load=True)
18 | ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict)
19 |
--------------------------------------------------------------------------------
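Note: train_net.py is a thin dispatcher: utils.config.Config parses the YAML given on the command line, and the YAML's TASK_TYPE string selects an engine registered as a side effect of `from tools import *`. A minimal mirror of that registry pattern (an illustration only, not the code in utils/registry_class.py; the task name below is a placeholder):

    ENGINES = {}

    def register(name):
        def deco(fn):
            ENGINES[name] = fn
            return fn
        return deco

    @register('train_t2v')          # placeholder task name, not a real TASK_TYPE value
    def train_t2v(cfg_update, **kwargs):
        print('launch text-to-video training with', len(cfg_update), 'overrides')

    def build(spec, **kwargs):
        # Mirrors ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=...)
        return ENGINES[spec['type']](**kwargs)

    build({'type': 'train_t2v'}, cfg_update={'lr': 5e-5})
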
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/assign_cfg.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/assign_cfg.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/assign_cfg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/assign_cfg.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/camera_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/camera_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/camera_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/camera_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/distributed.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/distributed.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logging.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/logging.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/logging.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/multi_port.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/multi_port.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/multi_port.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/multi_port.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/registry.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/registry_class.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry_class.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/registry_class.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry_class.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/seed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/seed.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/seed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/seed.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/video_op.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/video_op.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/video_op.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/video_op.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/assign_cfg.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from copy import deepcopy, copy
3 |
4 |
5 | # Build the prior and vldm (video LDM) configs from the shared base config
6 | def assign_prior_mudule_cfg(cfg):
7 | '''
8 | '''
9 | #
10 | prior_cfg = deepcopy(cfg)
11 | vldm_cfg = deepcopy(cfg)
12 |
13 | with open(cfg.prior_cfg, 'r') as f:
14 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
15 | # _cfg_update = _cfg_update.cfg_dict
16 | for k, v in _cfg_update.items():
17 | if isinstance(v, dict) and k in cfg:
18 | prior_cfg[k].update(v)
19 | else:
20 | prior_cfg[k] = v
21 |
22 | with open(cfg.vldm_cfg, 'r') as f:
23 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
24 | # _cfg_update = _cfg_update.cfg_dict
25 | for k, v in _cfg_update.items():
26 | if isinstance(v, dict) and k in cfg:
27 | vldm_cfg[k].update(v)
28 | else:
29 | vldm_cfg[k] = v
30 |
31 | return prior_cfg, vldm_cfg
32 |
33 |
34 | # Build the vldm and vsr (video super-resolution) configs from the shared base config
35 | def assign_vldm_vsr_mudule_cfg(cfg):
36 | '''
37 | '''
38 | #
39 | vldm_cfg = deepcopy(cfg)
40 | vsr_cfg = deepcopy(cfg)
41 |
42 | with open(cfg.vldm_cfg, 'r') as f:
43 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
44 | # _cfg_update = _cfg_update.cfg_dict
45 | for k, v in _cfg_update.items():
46 | if isinstance(v, dict) and k in cfg:
47 | vldm_cfg[k].update(v)
48 | else:
49 | vldm_cfg[k] = v
50 |
51 | with open(cfg.vsr_cfg, 'r') as f:
52 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
53 | # _cfg_update = _cfg_update.cfg_dict
54 | for k, v in _cfg_update.items():
55 | if isinstance(v, dict) and k in cfg:
56 | vsr_cfg[k].update(v)
57 | else:
58 | vsr_cfg[k] = v
59 |
60 | return vldm_cfg, vsr_cfg
61 |
62 |
63 | # Build a single config, updated from the yaml file referenced by tname
64 | def assign_signle_cfg(cfg, _cfg_update, tname):
65 | '''
66 | '''
67 | #
68 | vldm_cfg = deepcopy(cfg)
69 | with open(_cfg_update[tname], 'r') as f:
70 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
71 | # _cfg_update = _cfg_update.cfg_dict
72 | for k, v in _cfg_update.items():
73 | if isinstance(v, dict) and k in cfg:
74 | vldm_cfg[k].update(v)
75 | else:
76 | vldm_cfg[k] = v
77 | return vldm_cfg
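
A minimal sketch of the shared merge rule used by all three helpers (plain dicts and made-up keys stand in for the Config object): dict-valued keys already present in cfg are shallow-merged, every other key from the override yaml simply replaces or creates the entry.

    from copy import deepcopy

    cfg = {"model": {"dim": 320, "depth": 2}, "seed": 8888}
    override = {"model": {"dim": 640}, "batch_size": 4}

    merged = deepcopy(cfg)
    for k, v in override.items():
        if isinstance(v, dict) and k in cfg:
            merged[k].update(v)   # shallow merge: "depth" survives, "dim" is replaced
        else:
            merged[k] = v         # everything else is overwritten or added

    assert merged == {"model": {"dim": 640, "depth": 2}, "seed": 8888, "batch_size": 4}
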
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def create_camera_to_world_matrix(elevation, azimuth, camera_distance=1):
5 | elevation = np.radians(elevation)
6 | azimuth = np.radians(azimuth)
7 | # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
8 | x = camera_distance * np.cos(elevation) * np.sin(azimuth)
9 | y = camera_distance * np.sin(elevation)
10 | z = camera_distance * np.cos(elevation) * np.cos(azimuth)
11 |
12 | # Calculate camera position, target, and up vectors
13 | camera_pos = np.array([x, y, z])
14 | target = np.array([0, 0, 0])
15 | up = np.array([0, 1, 0])
16 |
17 | # Construct view matrix
18 | forward = target - camera_pos
19 | forward /= np.linalg.norm(forward)
20 | right = np.cross(forward, up)
21 | right /= np.linalg.norm(right)
22 | new_up = np.cross(right, forward)
23 | new_up /= np.linalg.norm(new_up)
24 | cam2world = np.eye(4)
25 | cam2world[:3, :3] = np.array([right, new_up, -forward]).T
26 | cam2world[:3, 3] = camera_pos
27 | return cam2world
28 |
29 |
30 | def convert_opengl_to_blender(camera_matrix):
31 | if isinstance(camera_matrix, np.ndarray):
32 | # Construct transformation matrix to convert from OpenGL space to Blender space
33 | flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
34 | [0, 0, 0, 1]])
35 | camera_matrix_blender = np.dot(flip_yz, camera_matrix)
36 | else:
37 | # Construct transformation matrix to convert from OpenGL space to Blender space
38 | flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
39 | [0, 0, 0, 1]])
40 | if camera_matrix.ndim == 3:
41 | flip_yz = flip_yz.unsqueeze(0)
42 | camera_matrix_blender = torch.matmul(
43 | flip_yz.to(camera_matrix), camera_matrix)
44 | return camera_matrix_blender
45 |
46 | def get_camera(num_frames,
47 | elevation=15,
48 | azimuth_start=0,
49 | azimuth_span=360,
50 | blender_coord=True,
51 | camera_distance=1.):
52 | angle_gap = azimuth_span / num_frames
53 | cameras = []
54 | for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start,
55 | angle_gap):
56 | camera_matrix = create_camera_to_world_matrix(elevation, azimuth,
57 | camera_distance)
58 |
59 | if blender_coord:
60 | camera_matrix = convert_opengl_to_blender(camera_matrix)
61 | cameras.append(camera_matrix.flatten())
62 | return torch.tensor(np.stack(cameras, 0)).float()
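
A minimal usage sketch (frame count, elevation and distance are illustrative): get_camera returns one flattened 4x4 camera-to-world matrix per frame, evenly spaced over the requested azimuth span.

    from utils.camera_utils import get_camera

    # 24 poses on a full 360-degree orbit at 15-degree elevation
    cameras = get_camera(num_frames=24, elevation=15, azimuth_start=0,
                         azimuth_span=360, blender_coord=True, camera_distance=1.5)
    print(cameras.shape)  # torch.Size([24, 16])
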
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import json
4 | import copy
5 | import argparse
6 |
7 | import utils.logging as logging
8 | logger = logging.get_logger(__name__)
9 |
10 | class Config(object):
11 | def __init__(self, load=True, cfg_dict=None, cfg_level=None):
12 | self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
13 | if load:
14 | self.args = self._parse_args()
15 | logger.info("Loading config from {}.".format(self.args.cfg_file))
16 | self.need_initialization = True
17 | cfg_base = self._initialize_cfg()
18 | cfg_dict = self._load_yaml(self.args)
19 | cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
20 | cfg_dict = self._update_from_args(cfg_dict)
21 | self.cfg_dict = cfg_dict
22 | self._update_dict(cfg_dict)
23 |
24 | def _parse_args(self):
25 | parser = argparse.ArgumentParser(
26 |             description="Argparser for configuring the VideoMV codebase"
27 | )
28 | parser.add_argument(
29 | "--cfg",
30 | dest="cfg_file",
31 | help="Path to the configuration file",
32 | default='configs/i2vgen_xl_infer.yaml'
33 | )
34 | parser.add_argument(
35 | "--init_method",
36 | help="Initialization method, includes TCP or shared file-system",
37 | default="tcp://localhost:9999",
38 | type=str,
39 | )
40 | parser.add_argument(
41 | '--debug',
42 | action='store_true',
43 | default=False,
44 | help='Into debug information'
45 | )
46 | parser.add_argument(
47 | "opts",
48 | help="other configurations",
49 | default=None,
50 | nargs=argparse.REMAINDER)
51 | return parser.parse_args()
52 |
53 | def _path_join(self, path_list):
54 | path = ""
55 | for p in path_list:
56 | path+= p + '/'
57 | return path[:-1]
58 |
59 | def _update_from_args(self, cfg_dict):
60 | args = self.args
61 | for var in vars(args):
62 | cfg_dict[var] = getattr(args, var)
63 | return cfg_dict
64 |
65 | def _initialize_cfg(self):
66 | if self.need_initialization:
67 | self.need_initialization = False
68 | if os.path.exists('./configs/base.yaml'):
69 | with open("./configs/base.yaml", 'r') as f:
70 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
71 | else:
72 |                 with open(os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "configs", "base.yaml"), 'r') as f:
73 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
74 | return cfg
75 |
76 | def _load_yaml(self, args, file_name=""):
77 | assert args.cfg_file is not None
78 | if not file_name == "": # reading from base file
79 | with open(file_name, 'r') as f:
80 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
81 | else:
82 | if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
83 | args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
84 | with open(args.cfg_file, 'r') as f:
85 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
86 | file_name = args.cfg_file
87 |
88 | if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
89 | # return cfg if the base file is being accessed
90 | cfg = self._merge_cfg_from_command_update(args, cfg)
91 | return cfg
92 |
93 | if "_BASE" in cfg.keys():
94 | if cfg["_BASE"][1] == '.':
95 | prev_count = cfg["_BASE"].count('..')
96 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
97 | else:
98 | cfg_base_file = cfg["_BASE"].replace(
99 | "./",
100 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
101 | )
102 | cfg_base = self._load_yaml(args, cfg_base_file)
103 | cfg = self._merge_cfg_from_base(cfg_base, cfg)
104 | else:
105 | if "_BASE_RUN" in cfg.keys():
106 | if cfg["_BASE_RUN"][1] == '.':
107 | prev_count = cfg["_BASE_RUN"].count('..')
108 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
109 | else:
110 | cfg_base_file = cfg["_BASE_RUN"].replace(
111 | "./",
112 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
113 | )
114 | cfg_base = self._load_yaml(args, cfg_base_file)
115 | cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
116 | if "_BASE_MODEL" in cfg.keys():
117 | if cfg["_BASE_MODEL"][1] == '.':
118 | prev_count = cfg["_BASE_MODEL"].count('..')
119 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
120 | else:
121 | cfg_base_file = cfg["_BASE_MODEL"].replace(
122 | "./",
123 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
124 | )
125 | cfg_base = self._load_yaml(args, cfg_base_file)
126 | cfg = self._merge_cfg_from_base(cfg_base, cfg)
127 | cfg = self._merge_cfg_from_command(args, cfg)
128 | return cfg
129 |
130 | def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
131 | for k,v in cfg_new.items():
132 | if k in cfg_base.keys():
133 | if isinstance(v, dict):
134 | self._merge_cfg_from_base(cfg_base[k], v)
135 | else:
136 | cfg_base[k] = v
137 | else:
138 | if "BASE" not in k or preserve_base:
139 | cfg_base[k] = v
140 | return cfg_base
141 |
142 | def _merge_cfg_from_command_update(self, args, cfg):
143 | if len(args.opts) == 0:
144 | return cfg
145 |
146 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
147 | args.opts, len(args.opts)
148 | )
149 | keys = args.opts[0::2]
150 | vals = args.opts[1::2]
151 |
152 | for key, val in zip(keys, vals):
153 | cfg[key] = val
154 |
155 | return cfg
156 |
157 | def _merge_cfg_from_command(self, args, cfg):
158 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
159 | args.opts, len(args.opts)
160 | )
161 | keys = args.opts[0::2]
162 | vals = args.opts[1::2]
163 |
164 | # maximum supported depth 3
165 | for idx, key in enumerate(keys):
166 | key_split = key.split('.')
167 |             assert len(key_split) <= 4, 'Key depth error.\nMaximum depth: 3\nGot depth: {}'.format(
168 | len(key_split)
169 | )
170 |             assert key_split[0] in cfg.keys(), 'Non-existent key: {}.'.format(
171 | key_split[0]
172 | )
173 | if len(key_split) == 2:
174 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
175 | key
176 | )
177 | elif len(key_split) == 3:
178 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
179 | key
180 | )
181 |                 assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
182 | key
183 | )
184 | elif len(key_split) == 4:
185 |                 assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
186 | key
187 | )
188 |                 assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
189 | key
190 | )
191 |                 assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existent key: {}.'.format(
192 | key
193 | )
194 | if len(key_split) == 1:
195 | cfg[key_split[0]] = vals[idx]
196 | elif len(key_split) == 2:
197 | cfg[key_split[0]][key_split[1]] = vals[idx]
198 | elif len(key_split) == 3:
199 | cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
200 | elif len(key_split) == 4:
201 | cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
202 | return cfg
203 |
204 | def _update_dict(self, cfg_dict):
205 | def recur(key, elem):
206 | if type(elem) is dict:
207 | return key, Config(load=False, cfg_dict=elem, cfg_level=key)
208 | else:
209 | if type(elem) is str and elem[1:3]=="e-":
210 | elem = float(elem)
211 | return key, elem
212 | dic = dict(recur(k, v) for k, v in cfg_dict.items())
213 | self.__dict__.update(dic)
214 |
215 | def get_args(self):
216 | return self.args
217 |
218 | def __repr__(self):
219 | return "{}\n".format(self.dump())
220 |
221 | def dump(self):
222 | return json.dumps(self.cfg_dict, indent=2)
223 |
224 | def deep_copy(self):
225 | return copy.deepcopy(self)
226 |
227 | if __name__ == '__main__':
228 | # debug
229 | cfg = Config(load=True)
230 | print(cfg.DATA)
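
A minimal sketch of how this class is normally driven (the script name and override key are illustrative): --cfg selects the experiment yaml, which is merged over configs/base.yaml, and trailing "KEY VALUE" pairs passed as opts override individual entries.

    # python inference.py --cfg configs/i2vgen_xl_infer.yaml SOME_KEY some_value
    from utils.config import Config

    cfg = Config(load=True)     # parses --cfg / --init_method / --debug / opts from sys.argv
    print(cfg.cfg_file)         # path passed via --cfg
    print(cfg.init_method)      # "tcp://localhost:9999" unless overridden
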
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Logging."""
5 |
6 | import builtins
7 | import decimal
8 | import functools
9 | import logging
10 | import os
11 | import sys
12 | import simplejson
13 | # from fvcore.common.file_io import PathManager
14 |
15 | import utils.distributed as du
16 |
17 |
18 | def _suppress_print():
19 | """
20 | Suppresses printing from the current process.
21 | """
22 |
23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
24 | pass
25 |
26 | builtins.print = print_pass
27 |
28 |
29 | # @functools.lru_cache(maxsize=None)
30 | # def _cached_log_stream(filename):
31 | # return PathManager.open(filename, "a")
32 |
33 |
34 | def setup_logging(cfg, log_file):
35 | """
36 | Sets up the logging for multiple processes. Only enable the logging for the
37 | master process, and suppress logging for the non-master processes.
38 | """
39 | if du.is_master_proc():
40 | # Enable logging for the master process.
41 | logging.root.handlers = []
42 | else:
43 | # Suppress logging for non-master processes.
44 | _suppress_print()
45 |
46 | logger = logging.getLogger()
47 | logger.setLevel(logging.INFO)
48 | logger.propagate = False
49 | plain_formatter = logging.Formatter(
50 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
51 | datefmt="%m/%d %H:%M:%S",
52 | )
53 |
54 | if du.is_master_proc():
55 | ch = logging.StreamHandler(stream=sys.stdout)
56 | ch.setLevel(logging.DEBUG)
57 | ch.setFormatter(plain_formatter)
58 | logger.addHandler(ch)
59 |
60 | if log_file is not None and du.is_master_proc(du.get_world_size()):
61 | filename = os.path.join(cfg.OUTPUT_DIR, log_file)
62 | fh = logging.FileHandler(filename)
63 | fh.setLevel(logging.DEBUG)
64 | fh.setFormatter(plain_formatter)
65 | logger.addHandler(fh)
66 |
67 |
68 | def get_logger(name):
69 | """
70 | Retrieve the logger with the specified name or, if name is None, return a
71 | logger which is the root logger of the hierarchy.
72 | Args:
73 | name (string): name of the logger.
74 | """
75 | return logging.getLogger(name)
76 |
77 |
78 | def log_json_stats(stats):
79 | """
80 | Logs json stats.
81 | Args:
82 | stats (dict): a dictionary of statistical information to log.
83 | """
84 | stats = {
85 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
86 | for k, v in stats.items()
87 | }
88 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
89 | logger = get_logger(__name__)
90 | logger.info("{:s}".format(json_stats))
91 |
--------------------------------------------------------------------------------
/utils/multi_port.py:
--------------------------------------------------------------------------------
1 | import socket
2 | from contextlib import closing
3 |
4 | def find_free_port():
5 | """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """
6 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
7 | s.bind(('', 0))
8 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
9 | return str(s.getsockname()[1])
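
A short sketch (the pairing with torch.distributed is illustrative): pick a free local port instead of the hard-coded tcp://localhost:9999 default in utils/config.py. The usual caveat applies: the port is only guaranteed free at the moment it is probed.

    from utils.multi_port import find_free_port

    init_method = 'tcp://localhost:' + find_free_port()
    print(init_method)  # e.g. tcp://localhost:51733
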
--------------------------------------------------------------------------------
/utils/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import *
2 | from .adafactor import *
3 |
--------------------------------------------------------------------------------
/utils/optim/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/optim/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/optim/__pycache__/adafactor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/adafactor.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/optim/__pycache__/adafactor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/adafactor.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/optim/__pycache__/lr_scheduler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/lr_scheduler.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/optim/__pycache__/lr_scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/lr_scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/optim/adafactor.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 | __all__ = ['Adafactor']
7 |
8 | class Adafactor(Optimizer):
9 | """
10 |     AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam. Original fairseq code:
11 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
12 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
13 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
14 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
15 | `relative_step=False`.
16 | Arguments:
17 | params (`Iterable[nn.parameter.Parameter]`):
18 | Iterable of parameters to optimize or dictionaries defining parameter groups.
19 | lr (`float`, *optional*):
20 | The external learning rate.
21 | eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
22 | Regularization constants for square gradient and parameter scale respectively
23 |         clip_threshold (`float`, *optional*, defaults to 1.0):
24 | Threshold of root mean square of final gradient update
25 | decay_rate (`float`, *optional*, defaults to -0.8):
26 | Coefficient used to compute running averages of square
27 | beta1 (`float`, *optional*):
28 | Coefficient used for computing running averages of gradient
29 | weight_decay (`float`, *optional*, defaults to 0):
30 | Weight decay (L2 penalty)
31 | scale_parameter (`bool`, *optional*, defaults to `True`):
32 | If True, learning rate is scaled by root mean square
33 | relative_step (`bool`, *optional*, defaults to `True`):
34 | If True, time-dependent learning rate is computed instead of external learning rate
35 | warmup_init (`bool`, *optional*, defaults to `False`):
36 | Time-dependent learning rate computation depends on whether warm-up initialization is being used
37 |     This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested it.
38 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
39 | - Training without LR warmup or clip_threshold is not recommended.
40 | - use scheduled LR warm-up to fixed LR
41 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
42 | - Disable relative updates
43 | - Use scale_parameter=False
44 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
45 | Example:
46 | ```python
47 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
48 | ```
49 | Others reported the following combination to work well:
50 | ```python
51 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
52 | ```
53 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
54 | scheduler as following:
55 | ```python
56 | from transformers.optimization import Adafactor, AdafactorSchedule
57 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
58 | lr_scheduler = AdafactorSchedule(optimizer)
59 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
60 | ```
61 | Usage:
62 | ```python
63 | # replace AdamW with Adafactor
64 | optimizer = Adafactor(
65 | model.parameters(),
66 | lr=1e-3,
67 | eps=(1e-30, 1e-3),
68 | clip_threshold=1.0,
69 | decay_rate=-0.8,
70 | beta1=None,
71 | weight_decay=0.0,
72 | relative_step=False,
73 | scale_parameter=False,
74 | warmup_init=False,
75 | )
76 | ```"""
77 |
78 | def __init__(
79 | self,
80 | params,
81 | lr=None,
82 | eps=(1e-30, 1e-3),
83 | clip_threshold=1.0,
84 | decay_rate=-0.8,
85 | beta1=None,
86 | weight_decay=0.0,
87 | scale_parameter=True,
88 | relative_step=True,
89 | warmup_init=False,
90 | ):
91 | r"""require_version("torch>=1.5.0") # add_ with alpha
92 | """
93 | if lr is not None and relative_step:
94 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
95 | if warmup_init and not relative_step:
96 | raise ValueError("`warmup_init=True` requires `relative_step=True`")
97 |
98 | defaults = dict(
99 | lr=lr,
100 | eps=eps,
101 | clip_threshold=clip_threshold,
102 | decay_rate=decay_rate,
103 | beta1=beta1,
104 | weight_decay=weight_decay,
105 | scale_parameter=scale_parameter,
106 | relative_step=relative_step,
107 | warmup_init=warmup_init,
108 | )
109 | super().__init__(params, defaults)
110 |
111 | @staticmethod
112 | def _get_lr(param_group, param_state):
113 | rel_step_sz = param_group["lr"]
114 | if param_group["relative_step"]:
115 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
116 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
117 | param_scale = 1.0
118 | if param_group["scale_parameter"]:
119 | param_scale = max(param_group["eps"][1], param_state["RMS"])
120 | return param_scale * rel_step_sz
121 |
122 | @staticmethod
123 | def _get_options(param_group, param_shape):
124 | factored = len(param_shape) >= 2
125 | use_first_moment = param_group["beta1"] is not None
126 | return factored, use_first_moment
127 |
128 | @staticmethod
129 | def _rms(tensor):
130 | return tensor.norm(2) / (tensor.numel() ** 0.5)
131 |
132 | @staticmethod
133 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
134 | # copy from fairseq's adafactor implementation:
135 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
136 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
137 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
138 | return torch.mul(r_factor, c_factor)
139 |
140 | def step(self, closure=None):
141 | """
142 | Performs a single optimization step
143 | Arguments:
144 | closure (callable, optional): A closure that reevaluates the model
145 | and returns the loss.
146 | """
147 | loss = None
148 | if closure is not None:
149 | loss = closure()
150 |
151 | for group in self.param_groups:
152 | for p in group["params"]:
153 | if p.grad is None:
154 | continue
155 | grad = p.grad.data
156 | if grad.dtype in {torch.float16, torch.bfloat16}:
157 | grad = grad.float()
158 | if grad.is_sparse:
159 | raise RuntimeError("Adafactor does not support sparse gradients.")
160 |
161 | state = self.state[p]
162 | grad_shape = grad.shape
163 |
164 | factored, use_first_moment = self._get_options(group, grad_shape)
165 | # State Initialization
166 | if len(state) == 0:
167 | state["step"] = 0
168 |
169 | if use_first_moment:
170 | # Exponential moving average of gradient values
171 | state["exp_avg"] = torch.zeros_like(grad)
172 | if factored:
173 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
174 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
175 | else:
176 | state["exp_avg_sq"] = torch.zeros_like(grad)
177 |
178 | state["RMS"] = 0
179 | else:
180 | if use_first_moment:
181 | state["exp_avg"] = state["exp_avg"].to(grad)
182 | if factored:
183 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
184 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
185 | else:
186 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
187 |
188 | p_data_fp32 = p.data
189 | if p.data.dtype in {torch.float16, torch.bfloat16}:
190 | p_data_fp32 = p_data_fp32.float()
191 |
192 | state["step"] += 1
193 | state["RMS"] = self._rms(p_data_fp32)
194 | lr = self._get_lr(group, state)
195 |
196 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
197 | update = (grad**2) + group["eps"][0]
198 | if factored:
199 | exp_avg_sq_row = state["exp_avg_sq_row"]
200 | exp_avg_sq_col = state["exp_avg_sq_col"]
201 |
202 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
203 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
204 |
205 | # Approximation of exponential moving average of square of gradient
206 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
207 | update.mul_(grad)
208 | else:
209 | exp_avg_sq = state["exp_avg_sq"]
210 |
211 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
212 | update = exp_avg_sq.rsqrt().mul_(grad)
213 |
214 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
215 | update.mul_(lr)
216 |
217 | if use_first_moment:
218 | exp_avg = state["exp_avg"]
219 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
220 | update = exp_avg
221 |
222 | if group["weight_decay"] != 0:
223 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
224 |
225 | p_data_fp32.add_(-update)
226 |
227 | if p.data.dtype in {torch.float16, torch.bfloat16}:
228 | p.data.copy_(p_data_fp32)
229 |
230 | return loss
231 |
--------------------------------------------------------------------------------
/utils/optim/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 | __all__ = ['AnnealingLR']
5 |
6 | class AnnealingLR(_LRScheduler):
7 |
8 | def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1):
9 | assert decay_mode in ['linear', 'cosine', 'none']
10 | self.optimizer = optimizer
11 | self.base_lr = base_lr
12 | self.warmup_steps = warmup_steps
13 | self.total_steps = total_steps
14 | self.decay_mode = decay_mode
15 | self.min_lr = min_lr
16 | self.current_step = last_step + 1
17 | self.step(self.current_step)
18 |
19 | def get_lr(self):
20 | if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
21 | return self.base_lr * self.current_step / self.warmup_steps
22 | else:
23 | ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
24 | ratio = min(1.0, max(0.0, ratio))
25 | if self.decay_mode == 'linear':
26 | return self.base_lr * (1 - ratio)
27 | elif self.decay_mode == 'cosine':
28 | return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
29 | else:
30 | return self.base_lr
31 |
32 | def step(self, current_step=None):
33 | if current_step is None:
34 | current_step = self.current_step + 1
35 | self.current_step = current_step
36 | new_lr = max(self.min_lr, self.get_lr())
37 | if isinstance(self.optimizer, list):
38 | for o in self.optimizer:
39 | for group in o.param_groups:
40 | group['lr'] = new_lr
41 | else:
42 | for group in self.optimizer.param_groups:
43 | group['lr'] = new_lr
44 |
45 | def state_dict(self):
46 | return {
47 | 'base_lr': self.base_lr,
48 | 'warmup_steps': self.warmup_steps,
49 | 'total_steps': self.total_steps,
50 | 'decay_mode': self.decay_mode,
51 | 'current_step': self.current_step}
52 |
53 | def load_state_dict(self, state_dict):
54 | self.base_lr = state_dict['base_lr']
55 | self.warmup_steps = state_dict['warmup_steps']
56 | self.total_steps = state_dict['total_steps']
57 | self.decay_mode = state_dict['decay_mode']
58 | self.current_step = state_dict['current_step']
59 |
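
A minimal usage sketch (model, step counts and learning rates are illustrative): linear warm-up over the first warmup_steps, then cosine decay towards min_lr.

    import torch
    from utils.optim.lr_scheduler import AnnealingLR

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = AnnealingLR(optimizer, base_lr=1e-4, warmup_steps=1000,
                            total_steps=100000, decay_mode='cosine', min_lr=1e-6)

    for step in range(1, 101):
        # ... forward / backward / optimizer.step() ...
        scheduler.step(step)  # writes the scheduled lr into every param_group
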
--------------------------------------------------------------------------------
/utils/recenter_i2v.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torchvision
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 | import os, sys
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from PIL import Image
14 | import torch
15 | import time
16 | import cv2
17 | import PIL
18 |
19 | def add_margin(pil_img, color=0, size=256):
20 | width, height = pil_img.size
21 | result = Image.new(pil_img.mode, (size, size), color)
22 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
23 | return result
24 |
25 | def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
26 | image_input = Image.open(image_path)
27 |
28 | if crop_size!=-1:
29 | alpha_np = np.asarray(image_input)[:, :, 3]
30 | coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
31 | min_x, min_y = np.min(coords, 0)
32 | max_x, max_y = np.max(coords, 0)
33 | ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
34 | h, w = ref_img_.height, ref_img_.width
35 | scale = crop_size / max(h, w)
36 | h_, w_ = int(scale * h), int(scale * w)
37 | ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
38 | image_input = add_margin(ref_img_, size=image_size)
39 | else:
40 | image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
41 | image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
42 |
43 | image_input = np.asarray(image_input)
44 | image_input = image_input.astype(np.float32) / 255.0
45 | if image_input.shape[-1]==4:
46 | ref_mask = image_input[:, :, 3:]
47 | image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
48 | return image_input
49 |
50 | root_dir = sys.argv[1]
51 | items = [os.path.join(root_dir, item) for item in os.listdir(root_dir)]
52 | for idx, item in enumerate(items):
53 | res = prepare_inputs(item, 15, 200)
54 |     Image.fromarray((res*255.0).astype(np.uint8)).save(os.path.join("./data/images", "{:05d}.png".format(idx)))
55 |
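
Intended invocation (derived from the sys.argv usage above): python utils/recenter_i2v.py <directory_of_RGBA_images>. With crop_size=200 the script crops each image to its alpha bounding box, rescales it, recenters it on a 256x256 canvas, composites it onto a white background, and writes the results as 00000.png, 00001.png, ... into ./data/images/ (which must already exist).
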
--------------------------------------------------------------------------------
/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
2 |
3 | # Registry class & build_from_config function partially modified from
4 | # https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
5 | # Copyright 2018-2020 Open-MMLab. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import copy
20 | import inspect
21 | import warnings
22 |
23 |
24 | def build_from_config(cfg, registry, **kwargs):
25 | """ Default builder function.
26 |
27 | Args:
28 | cfg (dict): A dict which contains parameters passes to target class or function.
29 | Must contains key 'type', indicates the target class or function name.
30 | registry (Registry): An registry to search target class or function.
31 | kwargs (dict, optional): Other params not in config dict.
32 |
33 | Returns:
34 | Target class object or object returned by invoking function.
35 |
36 | Raises:
37 | TypeError:
38 | KeyError:
39 | Exception:
40 | """
41 | if not isinstance(cfg, dict):
42 | raise TypeError(f"config must be type dict, got {type(cfg)}")
43 | if "type" not in cfg:
44 | raise KeyError(f"config must contain key type, got {cfg}")
45 | if not isinstance(registry, Registry):
46 | raise TypeError(f"registry must be type Registry, got {type(registry)}")
47 |
48 | cfg = copy.deepcopy(cfg)
49 |
50 | req_type = cfg.pop("type")
51 | req_type_entry = req_type
52 | if isinstance(req_type, str):
53 | req_type_entry = registry.get(req_type)
54 | if req_type_entry is None:
55 | raise KeyError(f"{req_type} not found in {registry.name} registry")
56 |
57 | if kwargs is not None:
58 | cfg.update(kwargs)
59 |
60 | if inspect.isclass(req_type_entry):
61 | try:
62 | return req_type_entry(**cfg)
63 | except Exception as e:
64 | raise Exception(f"Failed to init class {req_type_entry}, with {e}")
65 | elif inspect.isfunction(req_type_entry):
66 | try:
67 | return req_type_entry(**cfg)
68 | except Exception as e:
69 | raise Exception(f"Failed to invoke function {req_type_entry}, with {e}")
70 | else:
71 | raise TypeError(f"type must be str or class, got {type(req_type_entry)}")
72 |
73 |
74 | class Registry(object):
75 | """ A registry maps key to classes or functions.
76 |
77 | Example:
78 | >>> MODELS = Registry('MODELS')
79 | >>> @MODELS.register_class()
80 | >>> class ResNet(object):
81 | >>> pass
82 | >>> resnet = MODELS.build(dict(type="ResNet"))
83 | >>>
84 | >>> import torchvision
85 | >>> @MODELS.register_function("InceptionV3")
86 | >>> def get_inception_v3(pretrained=False, progress=True):
87 | >>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress)
88 | >>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True))
89 |
90 | Args:
91 | name (str): Registry name.
92 | build_func (func, None): Instance construct function. Default is build_from_config.
93 | allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function.
94 | """
95 |
96 | def __init__(self, name, build_func=None, allow_types=("class", "function")):
97 | self.name = name
98 | self.allow_types = allow_types
99 | self.class_map = {}
100 | self.func_map = {}
101 | self.build_func = build_func or build_from_config
102 |
103 | def get(self, req_type):
104 | return self.class_map.get(req_type) or self.func_map.get(req_type)
105 |
106 | def build(self, *args, **kwargs):
107 | return self.build_func(*args, **kwargs, registry=self)
108 |
109 | def register_class(self, name=None):
110 | def _register(cls):
111 | if not inspect.isclass(cls):
112 | raise TypeError(f"Module must be type class, got {type(cls)}")
113 | if "class" not in self.allow_types:
114 | raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class")
115 | module_name = name or cls.__name__
116 | if module_name in self.class_map:
117 | warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, "
118 | f"will be replaced by {cls}")
119 | self.class_map[module_name] = cls
120 | return cls
121 |
122 | return _register
123 |
124 | def register_function(self, name=None):
125 | def _register(func):
126 | if not inspect.isfunction(func):
127 | raise TypeError(f"Registry must be type function, got {type(func)}")
128 | if "function" not in self.allow_types:
129 | raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function")
130 | func_name = name or func.__name__
131 |             if func_name in self.func_map:
132 | warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, "
133 | f"will be replaced by {func}")
134 | self.func_map[func_name] = func
135 | return func
136 |
137 | return _register
138 |
139 | def _list(self):
140 | keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys()))
141 | descriptions = []
142 | for key in keys:
143 | if key in self.class_map:
144 | descriptions.append(f"{key}: {self.class_map[key]}")
145 | else:
146 |                 descriptions.append(
147 |                     f"{key}: {self.func_map[key]}")
148 | return "\n".join(descriptions)
149 |
150 | def __repr__(self):
151 | description = self._list()
152 | description = '\n'.join(['\t' + s for s in description.split('\n')])
153 | return f"{self.__class__.__name__} [{self.name}], \n" + description
154 |
155 |
156 |
--------------------------------------------------------------------------------
/utils/registry_class.py:
--------------------------------------------------------------------------------
1 | from .registry import Registry, build_from_config
2 |
3 | def build_func(cfg, registry, **kwargs):
4 | """
5 | Except for config, if passing a list of dataset config, then return the concat type of it
6 | """
7 | return build_from_config(cfg, registry, **kwargs)
8 |
9 | AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func)
10 | DATASETS = Registry("DATASETS", build_func=build_func)
11 | DIFFUSION = Registry("DIFFUSION", build_func=build_func)
12 | DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func)
13 | EMBEDDER = Registry("EMBEDDER", build_func=build_func)
14 | ENGINE = Registry("ENGINE", build_func=build_func)
15 | INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func)
16 | MODEL = Registry("MODEL", build_func=build_func)
17 | PRETRAIN = Registry("PRETRAIN", build_func=build_func)
18 | VISUAL = Registry("VISUAL", build_func=build_func)
19 |
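
A minimal sketch (the class and its kwargs are made up): modules elsewhere in the codebase decorate themselves into these shared registries and are later instantiated from config dicts.

    from utils.registry_class import MODEL

    @MODEL.register_class()
    class ToyUNet:
        def __init__(self, in_dim=4):
            self.in_dim = in_dim

    net = MODEL.build(dict(type='ToyUNet', in_dim=8))
    print(type(net).__name__, net.in_dim)  # ToyUNet 8
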
--------------------------------------------------------------------------------
/utils/seed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 |
6 | def setup_seed(seed):
7 | torch.manual_seed(seed)
8 | torch.cuda.manual_seed_all(seed)
9 | np.random.seed(seed)
10 | random.seed(seed)
11 | torch.backends.cudnn.deterministic = True
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def to_device(batch, device, non_blocking=False):
4 | if isinstance(batch, (list, tuple)):
5 | return type(batch)([
6 | to_device(u, device, non_blocking)
7 | for u in batch])
8 | elif isinstance(batch, dict):
9 | return type(batch)([
10 | (k, to_device(v, device, non_blocking))
11 | for k, v in batch.items()])
12 | elif isinstance(batch, torch.Tensor) and batch.device != device:
13 | batch = batch.to(device, non_blocking=non_blocking)
14 | else:
15 | return batch
16 | return batch
17 |
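
A short sketch: nested lists/tuples/dicts are moved recursively, and non-tensor leaves (e.g. caption strings) pass through unchanged.

    import torch
    from utils.util import to_device

    batch = {'video': torch.randn(1, 3, 16, 32, 32), 'caption': ['a toy example']}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = to_device(batch, device, non_blocking=True)
    print(batch['video'].device)
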
--------------------------------------------------------------------------------
/utils/video_op.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import cv2
5 | import glob
6 | import math
7 | import torch
8 | import gzip
9 | import copy
10 | import time
11 | import json
12 | import pickle
13 | import base64
14 | import imageio
15 | import hashlib
16 | import requests
17 | import binascii
18 | import zipfile
19 | # import skvideo.io
20 | import numpy as np
21 | from io import BytesIO
22 | import urllib.request
23 | import torch.nn.functional as F
24 | import torchvision.utils as tvutils
25 | from multiprocessing.pool import ThreadPool as Pool
26 | from einops import rearrange
27 | from PIL import Image, ImageDraw, ImageFont
28 |
29 |
30 | def gen_text_image(captions, text_size):
31 |     num_char = 38  # wrap caption text at 38 characters per line
32 | font_size = int(text_size / 20)
33 | font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
34 | text_image_list = []
35 | for text in captions:
36 | txt_img = Image.new("RGB", (text_size, text_size), color="white")
37 | draw = ImageDraw.Draw(txt_img)
38 | lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
39 | draw.text((0, 0), lines, fill="black", font=font)
40 | txt_img = np.array(txt_img)
41 | text_image_list.append(txt_img)
42 | text_images = np.stack(text_image_list, axis=0)
43 | text_images = torch.from_numpy(text_images)
44 | return text_images
45 |
46 | @torch.no_grad()
47 | def save_video_refimg_and_text(
48 | local_path,
49 | ref_frame,
50 | gen_video,
51 | captions,
52 | mean=[0.5, 0.5, 0.5],
53 | std=[0.5, 0.5, 0.5],
54 | text_size=256,
55 | nrow=4,
56 | save_fps=8,
57 | retry=5):
58 | '''
59 | gen_video: BxCxFxHxW
60 | '''
61 | nrow = max(int(gen_video.size(0) / 2), 1)
62 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
63 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
64 |
65 | text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3
66 | text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3
67 | text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3
68 |
69 | ref_frame = ref_frame.unsqueeze(2)
70 | ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)
71 | ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3
72 | ref_frame.clamp_(0, 1)
73 | ref_frame = ref_frame * 255.0
74 | ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')
75 |
76 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
77 | gen_video.clamp_(0, 1)
78 | gen_video = gen_video * 255.0
79 |
80 | images = rearrange(gen_video, 'b c f h w -> b f h w c')
81 | images = torch.cat([ref_frame, images, text_images], dim=3)
82 |
83 | images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
84 | images = [(img.numpy()).astype('uint8') for img in images]
85 |
86 | for _ in [None] * retry:
87 | try:
88 | if len(images) == 1:
89 | local_path = local_path + '.png'
90 | cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
91 | else:
92 | local_path = local_path + '.mp4'
93 | frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
94 | os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
95 | for fid, frame in enumerate(images):
96 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
97 | cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
98 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
99 | os.system(cmd); os.system(f'rm -rf {frame_dir}')
100 | # os.system(f'rm -rf {local_path}')
101 | exception = None
102 | break
103 | except Exception as e:
104 | exception = e
105 | continue
106 |
107 |
108 | @torch.no_grad()
109 | def save_i2vgen_video(
110 | local_path,
111 | image_id,
112 | gen_video,
113 | captions,
114 | mean=[0.5, 0.5, 0.5],
115 | std=[0.5, 0.5, 0.5],
116 | text_size=256,
117 | retry=5,
118 | save_fps = 8
119 | ):
120 | '''
121 | Save both the generated video and the input conditions.
122 | '''
123 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
124 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
125 |
126 | text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3
127 | text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3
128 | text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3
129 |
130 | image_id = image_id.unsqueeze(2) # B, C, F, H, W
131 | image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448
132 | image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448
133 | image_id.clamp_(0, 1)
134 | image_id = image_id * 255.0
135 | image_id = rearrange(image_id, 'b c f h w -> b f h w c')
136 |
137 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
138 | gen_video.clamp_(0, 1)
139 | gen_video = gen_video * 255.0
140 |
141 | images = rearrange(gen_video, 'b c f h w -> b f h w c')
142 | images = torch.cat([image_id, images, text_images], dim=3)
143 | images = images[0]
144 | images = [(img.numpy()).astype('uint8') for img in images]
145 |
146 | exception = None
147 | for _ in [None] * retry:
148 | try:
149 | frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
150 | os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
151 | for fid, frame in enumerate(images):
152 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
153 | cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
154 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
155 | os.system(cmd); os.system(f'rm -rf {frame_dir}')
156 | break
157 | except Exception as e:
158 | exception = e
159 | continue
160 |
161 | if exception is not None:
162 | raise exception
163 |
164 |
165 | @torch.no_grad()
166 | def save_i2vgen_video_safe(
167 | local_path,
168 | gen_video,
169 | captions,
170 | mean=[0.5, 0.5, 0.5],
171 | std=[0.5, 0.5, 0.5],
172 | text_size=256,
173 | retry=5,
174 | save_fps = 8
175 | ):
176 | '''
177 | Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.
178 | '''
179 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
180 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
181 |
182 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
183 | gen_video.clamp_(0, 1)
184 | gen_video = gen_video * 255.0
185 |
186 | images = rearrange(gen_video, 'b c f h w -> b f h w c')
187 | images = images[0]
188 | images = [(img.numpy()).astype('uint8') for img in images]
189 | num_image = len(images)
190 | exception = None
191 | for _ in [None] * retry:
192 | try:
193 | if num_image == 1:
194 | local_path = local_path + '.png'
195 | cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
196 | else:
197 | os.makedirs(local_path.replace(".mp4", ""), exist_ok=True)
198 |
199 | writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8)
200 | for fid, frame in enumerate(images):
201 | # if fid == num_image-1: # Fix known bugs.
202 | # ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size)
203 | # if ratio > 0.4: continue
204 | writer.append_data(frame)
205 | cv2.imwrite(os.path.join(local_path.replace(".mp4", ""), "{:05d}.png".format(fid)), frame[:,:,::-1])
206 | writer.close()
207 | break
208 | except Exception as e:
209 | exception = e
210 | continue
211 |
212 | if exception is not None:
213 | raise exception
214 |
215 |
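
A minimal sketch for the last helper (shapes, path and prompt are illustrative): gen_video is expected as BxCxFxHxW, normalized with mean/std 0.5 (roughly in [-1, 1]); the function de-normalizes it, writes local_path as an H.264 .mp4 via imageio, and additionally dumps every frame as a PNG into a sibling folder named after the video.

    import torch
    from utils.video_op import save_i2vgen_video_safe

    gen_video = torch.rand(1, 3, 16, 256, 256) * 2 - 1   # stand-in for a sampled video
    save_i2vgen_video_safe('outputs/sample.mp4', gen_video,
                           captions=['a toy prompt'], save_fps=8)
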
--------------------------------------------------------------------------------