├── LICENSE ├── README.md ├── assets └── f.png ├── configs ├── base.yaml ├── i2vgen_xl_infer.yaml ├── i2vgen_xl_train.yaml ├── t2v_infer.yaml ├── t2v_train.yaml └── t2v_train_laion.yaml ├── core ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── attention.cpython-38.pyc │ ├── gs.cpython-38.pyc │ ├── models.cpython-38.pyc │ ├── options.cpython-38.pyc │ ├── unet.cpython-38.pyc │ └── utils.cpython-38.pyc ├── attention.py ├── gs.py ├── models.py ├── options.py ├── provider_objaverse.py ├── unet.py └── utils.py ├── data ├── images │ ├── demo1.png │ └── demo2.png ├── lvis_thres_28.json ├── stable_diffusion_image_key_temporal_attention_x1.json ├── test_images.txt ├── test_prompts.txt ├── text_captions_cap3d.json └── valid_paths_v4_cap_filter_thres_28_catfilter19w.json ├── inference.py ├── install.sh ├── requirements.txt ├── tools ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── annotator │ ├── canny │ │ └── __init__.py │ ├── histogram │ │ ├── __init__.py │ │ └── palette.py │ ├── sketch │ │ ├── __init__.py │ │ ├── pidinet.py │ │ └── sketch_simplification.py │ └── util.py ├── basic_funcs │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── pretrain_functions.cpython-310.pyc │ │ └── pretrain_functions.cpython-38.pyc │ └── pretrain_functions.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── image_dataset.cpython-310.pyc │ │ ├── image_dataset.cpython-38.pyc │ │ ├── video_dataset.cpython-310.pyc │ │ ├── video_dataset.cpython-38.pyc │ │ ├── video_i2v_dataset.cpython-310.pyc │ │ └── video_i2v_dataset.cpython-38.pyc │ ├── image_dataset.py │ ├── laion_dataset.py │ ├── video_dataset.py │ └── video_i2v_dataset.py ├── hooks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── visual_train_it2v_video.cpython-310.pyc │ │ └── visual_train_it2v_video.cpython-38.pyc │ ├── visual_train_it2v_video.py │ └── visual_train_t2v.py ├── inferences │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── inference_i2vgen_entrance.cpython-310.pyc │ │ ├── inference_i2vgen_entrance.cpython-38.pyc │ │ ├── inference_text2video_entrance.cpython-310.pyc │ │ └── inference_text2video_entrance.cpython-38.pyc │ ├── inference_i2vgen_entrance.py │ └── inference_text2video_entrance.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── autoencoder.cpython-310.pyc │ │ ├── autoencoder.cpython-38.pyc │ │ ├── clip_embedder.cpython-310.pyc │ │ ├── clip_embedder.cpython-38.pyc │ │ ├── config.cpython-310.pyc │ │ └── config.cpython-38.pyc │ ├── autoencoder.py │ ├── clip_embedder.py │ ├── config.py │ ├── diffusions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── diffusion_ddim.cpython-310.pyc │ │ │ ├── diffusion_ddim.cpython-38.pyc │ │ │ ├── losses.cpython-310.pyc │ │ │ ├── losses.cpython-38.pyc │ │ │ ├── schedules.cpython-310.pyc │ │ │ └── schedules.cpython-38.pyc │ │ ├── diffusion_ddim.py │ │ ├── losses.py │ │ └── schedules.py │ └── unet │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── depthwise_attn.cpython-310.pyc │ │ ├── depthwise_attn.cpython-38.pyc │ │ ├── depthwise_net.cpython-310.pyc │ │ ├── depthwise_net.cpython-38.pyc │ │ ├── depthwise_utils.cpython-310.pyc │ │ ├── 
depthwise_utils.cpython-38.pyc │ │ ├── unet_i2vgen.cpython-310.pyc │ │ ├── unet_i2vgen.cpython-38.pyc │ │ ├── unet_t2v.cpython-310.pyc │ │ ├── unet_t2v.cpython-38.pyc │ │ ├── util.cpython-310.pyc │ │ └── util.cpython-38.pyc │ │ ├── mha_flash.py │ │ ├── unet_i2vgen.py │ │ ├── unet_t2v.py │ │ └── util.py └── train │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── train_i2v_enterance.cpython-310.pyc │ ├── train_i2v_enterance.cpython-38.pyc │ ├── train_t2v_enterance.cpython-310.pyc │ └── train_t2v_enterance.cpython-38.pyc │ ├── prev_t2v.py │ ├── train_i2v_enterance.py │ └── train_t2v_enterance.py ├── train_net.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-38.pyc ├── assign_cfg.cpython-310.pyc ├── assign_cfg.cpython-38.pyc ├── camera_utils.cpython-310.pyc ├── camera_utils.cpython-38.pyc ├── config.cpython-310.pyc ├── config.cpython-38.pyc ├── distributed.cpython-310.pyc ├── distributed.cpython-38.pyc ├── logging.cpython-310.pyc ├── logging.cpython-38.pyc ├── multi_port.cpython-310.pyc ├── multi_port.cpython-38.pyc ├── registry.cpython-310.pyc ├── registry.cpython-38.pyc ├── registry_class.cpython-310.pyc ├── registry_class.cpython-38.pyc ├── seed.cpython-310.pyc ├── seed.cpython-38.pyc ├── transforms.cpython-310.pyc ├── transforms.cpython-38.pyc ├── util.cpython-310.pyc ├── util.cpython-38.pyc ├── video_op.cpython-310.pyc └── video_op.cpython-38.pyc ├── assign_cfg.py ├── camera_utils.py ├── config.py ├── distributed.py ├── logging.py ├── multi_port.py ├── optim ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── adafactor.cpython-310.pyc │ ├── adafactor.cpython-38.pyc │ ├── lr_scheduler.cpython-310.pyc │ └── lr_scheduler.cpython-38.pyc ├── adafactor.py └── lr_scheduler.py ├── recenter_i2v.py ├── registry.py ├── registry_class.py ├── seed.py ├── transforms.py ├── util.py └── video_op.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## VideoMV: Consistent Multi-View Generation Based on Large Video Generative Model. 
2 | 3 | [Qi Zuo\*](https://scholar.google.com/citations?view_op=list_works&hl=en&user=UDnHe2IAAAAJ), 4 | [Xiaodong Gu\*](https://scholar.google.com.hk/citations?user=aJPO514AAAAJ&hl=zh-CN&oi=ao), 5 | [Lingteng Qiu](https://lingtengqiu.github.io/), 6 | [Yuan Dong](dy283090@alibaba-inc.com), 7 | [Zhengyi Zhao](bushe.zzy@alibaba-inc.com), 8 | [Weihao Yuan](https://weihao-yuan.com/), 9 | [Rui Peng](https://prstrive.github.io/), 10 | [Siyu Zhu](https://sites.google.com/site/zhusiyucs/home/), 11 | [Zilong Dong](https://scholar.google.com/citations?user=GHOQKCwAAAAJ&hl=zh-CN&oi=ao), 12 | [Liefeng Bo](https://research.cs.washington.edu/istc/lfb/), 13 | [Qixing Huang](https://www.cs.utexas.edu/~huangqx/) 14 | 15 | https://github.com/alibaba/VideoMV/assets/58206232/3a78e28d-bda4-4d4c-a2ae-994d0320a301 16 | 17 | ## [Project page](https://aigc3d.github.io/VideoMV) | [Paper](https://arxiv.org/abs/2403.12010) | [YouTube](https://www.youtube.com/watch?v=zxjX5p0p0Ks) | [3D Rendering Dataset](https://aigc3d.github.io/gobjaverse) 18 | 19 | ## TODO :triangular_flag_on_post: 20 | - [ ] Release GS, NeuS, and NeRF reconstruction code. 21 | - [x] News: Released the text-to-mv (G-Objaverse + Laion) training code and pretrained model (2024.04.22). Check the Inference and Training guidelines below. 22 | 23 | Multi-view images generated with prompts from DreamFusion420: 24 | 25 | https://github.com/alibaba/VideoMV/assets/58206232/3a4e84e9-a4b2-4ecc-a3e8-7a898e6c3f1a 26 | 27 | 28 | - [x] Release the training code. 29 | - [x] Release the multi-view inference code and pretrained weights (G-Objaverse). 30 | 31 | ## Architecture 32 | 33 | ![architecture](assets/f.png) 34 | 35 | ## Install 36 | 37 | - System requirement: Ubuntu 20.04 38 | - Tested GPUs: A100 39 | 40 | Install the requirements with the following script:
41 | 42 | ```bash 43 | git clone https://github.com/alibaba/VideoMV.git 44 | conda create -n VideoMV python=3.8 45 | conda activate VideoMV 46 | cd VideoMV && bash install.sh 47 | ``` 48 | 49 | ## Inference 50 | 51 | ```bash 52 | # Download our pretrained models 53 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/pretrained_models.zip 54 | unzip pretrained_models.zip 55 | # text-to-mv sampling 56 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/t2v_infer.yaml 57 | # text-to-mv sampling using the pretrained model trained on Laion + G-Objaverse 58 | wget oss://virutalbuy-public/share/aigc3d/videomv_laion/non_ema_00365000.pth 59 | # set [test_model] to the location of [non_ema_00365000.pth] 60 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/t2v_infer.yaml 61 | 62 | 63 | # image-to-mv sampling 64 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/i2vgen_xl_infer.yaml 65 | 66 | # To test your own prompts: add them to ./data/test_prompts.txt 67 | 68 | # To test your own images: use Background-Remover (https://www.remove.bg/) to extract the foreground 69 | # and place all the images in /path/to/your_dir 70 | # Then run 71 | python -m utils.recenter_i2v /path/to/your_dir 72 | # The recentered results will be saved in ./data/images 73 | # Add the test image paths to ./data/test_images.txt 74 | # Then run 75 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/i2vgen_xl_infer.yaml 76 | ``` 77 | 78 | ## Training 79 | 80 | ```bash 81 | # Download our dataset (G-Objaverse) following the instructions at 82 | # https://github.com/modelscope/richdreamer/tree/main/dataset/gobjaverse 83 | # Set vid_dataset.data_dir_list to your downloaded data root 84 | # in ./configs/t2v_train.yaml and ./configs/i2vgen_xl_train.yaml 85 | 86 | # Text-to-mv finetuning 87 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train.yaml 88 | # Text-to-mv finetuning using both Laion and G-Objaverse. 89 | # (Note: we use 24 A100 GPUs to train on both datasets. If your compute resources are insufficient, do not try this!) 90 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train_laion.yaml 91 | 92 | # Text-to-mv feed-forward reconstruction finetuning. 93 | # Set UNet.use_lgm_refine to 'True' in ./configs/t2v_train.yaml. Then 94 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/t2v_train.yaml 95 | 96 | 97 | # Image-to-mv finetuning 98 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/i2vgen_xl_train.yaml 99 | # Image-to-mv feed-forward reconstruction finetuning. 100 | # Set UNet.use_lgm_refine to 'True' in ./configs/i2vgen_xl_train.yaml. Then 101 | CUDA_VISIBLE_DEVICES=0 python train_net.py --cfg ./configs/i2vgen_xl_train.yaml 102 | ``` 103 | 104 | ## Tips 105 | 106 | - You will observe a sudden convergence in Text-to-MV finetuning (within ~5 minutes). 107 | 108 | - You will not observe a sudden convergence in Image-to-MV finetuning; it usually takes about half a day to reach an initial convergence. 109 | 110 | - Remove the background of test images with [Background-Remover](https://www.remove.bg/) instead of rembg for better results; artifacts in the segmentation mask degrade the quality of the multi-view generations (a consolidated sketch of this preprocessing-plus-inference workflow is given below). 111 | 112 | ## Future Works 113 | 114 | - Dense-view Large Reconstruction Model. 115 | 116 | - More general and higher-quality Text-to-MV using better video diffusion models (like HiGen) and novel finetuning techniques.
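
## Putting it together: image-to-MV

The snippet below chains the image-to-MV steps from the Inference section into one session. It is a minimal sketch of the documented workflow, not an additional tool: `/path/to/your_dir` is a placeholder for your own folder of background-removed PNG cutouts, and regenerating `./data/test_images.txt` with `ls` is just one convenient option that overwrites the two shipped demo paths (use `>>` to append instead).

```bash
# Sketch of the image-to-MV workflow described in the Inference section.
# Assumes /path/to/your_dir holds background-removed PNG cutouts (see the Tips section).
python -m utils.recenter_i2v /path/to/your_dir        # recentered copies are written to ./data/images
ls ./data/images/*.png > ./data/test_images.txt       # one image path per line, as the test loader expects
CUDA_VISIBLE_DEVICES=0 python inference.py --cfg ./configs/i2vgen_xl_infer.yaml
```

Outputs are written under `workspace/visualization/i2v`, the `log_dir` set in `configs/i2vgen_xl_infer.yaml`.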
117 | 118 | ## Acknowledgement 119 | 120 | This work is built on many amazing research works and open-source projects: 121 | 122 | - [VGen](https://github.com/ali-vilab/VGen) 123 | - [LGM](https://github.com/3DTopia/LGM) 124 | - [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer) 125 | - [GaussianSplatting](https://github.com/graphdeco-inria/gaussian-splatting) 126 | 127 | Thanks for their excellent work and great contribution to 3D generation area. 128 | 129 | We would like to express our special gratitude to [Jiaxiang Tang](https://github.com/ashawkey), [Yuan Liu](https://github.com/liuyuan-pal) for the valuable discussion in LGM and SyncDreamer. 130 | 131 | 132 | ## Citation 133 | 134 | ``` 135 | @misc{zuo2024videomv, 136 | title={VideoMV: Consistent Multi-View Generation Based on Large Video Generative Model}, 137 | author={Qi Zuo and Xiaodong Gu and Lingteng Qiu and Yuan Dong and Zhengyi Zhao and Weihao Yuan and Rui Peng and Siyu Zhu and Zilong Dong and Liefeng Bo and Qixing Huang}, 138 | year={2024}, 139 | eprint={2403.12010}, 140 | archivePrefix={arXiv}, 141 | primaryClass={cs.CV} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /assets/f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/assets/f.png -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | ENABLE: true 2 | DATASET: webvid10m -------------------------------------------------------------------------------- /configs/i2vgen_xl_infer.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: inference_i2vgen_entrance 2 | use_fp16: True 3 | guide_scale: 6.0 4 | use_fp16: True 5 | chunk_size: 2 6 | decoder_bs: 2 7 | max_frames: 24 8 | target_fps: 8 # FPS Conditions, not the encoding fps 9 | scale: 8 10 | seed: 9999 11 | round: 4 12 | batch_size: 1 13 | use_zero_infer: True 14 | 15 | # For important input 16 | vldm_cfg: configs/i2vgen_xl_train.yaml 17 | test_list_path: data/test_images.txt 18 | test_model: ./pretrained_models/i2v_00882000.pth 19 | log_dir: "workspace/visualization/i2v" 20 | 21 | UNet: { 22 | 'type': 'UNetSD_I2VGen', 23 | 'in_dim': 4, 24 | 'y_dim': 1024, 25 | 'upper_len': 128, 26 | 'context_dim': 1024, 27 | 'concat_dim': 4, 28 | 'out_dim': 4, 29 | 'dim_mult': [1, 2, 4, 4], 30 | 'num_heads': 8, 31 | 'default_fps': 8, 32 | 'head_dim': 64, 33 | 'num_res_blocks': 2, 34 | 'dropout': 0.1, 35 | 'temporal_attention': True, 36 | 'temporal_attn_times': 1, 37 | 'use_checkpoint': True, 38 | 'use_fps_condition': False, 39 | 'use_camera_condition': True, 40 | 'use_lgm_refine': True, # Turn off this if you want to simply fintune a naive i2vgen-xl 41 | 'use_sim_mask': False 42 | } -------------------------------------------------------------------------------- /configs/i2vgen_xl_train.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: train_i2v_entrance 2 | ENABLE: true 3 | use_ema: true 4 | num_workers: 6 5 | frame_lens: [24] 6 | sample_fps: [8] 7 | resolution: [256, 256] 8 | vit_resolution: [224, 224] 9 | 10 | lgm_pretrain: './pretrained_models/model.safetensors' 11 | 12 | vid_dataset: { 13 | 'type': 'Video_I2V_Dataset', 14 | 'data_list': ['./data/valid_paths_v4_cap_filter_thres_28_catfilter19w.json', ], 15 | 
'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ], 16 | 'caption_dir': './data/text_captions_cap3d.json', 17 | 'vit_resolution': [224, 224], 18 | 'resolution': [256, 256], 19 | 'get_first_frame': True, 20 | 'max_words': 1000, 21 | 'prepare_lgm': True, 22 | } 23 | 24 | img_dataset: { 25 | 'type': 'ImageDataset', 26 | 'data_list': ['data/img_list.txt', ], 27 | 'data_dir_list': ['data/images', ], 28 | 'vit_resolution': [224, 224], 29 | 'resolution': [256, 256], 30 | 'max_words': 1000 31 | } 32 | 33 | embedder: { 34 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder', 35 | 'layer': 'penultimate', 36 | 'vit_resolution': [224, 224], 37 | 'pretrained': './pretrained_models/modelscope_i2v/I2VGen-XL/open_clip_pytorch_model.bin' 38 | } 39 | 40 | UNet: { 41 | 'type': 'UNetSD_I2VGen', 42 | 'in_dim': 4, 43 | 'y_dim': 1024, 44 | 'upper_len': 128, 45 | 'context_dim': 1024, 46 | 'concat_dim': 4, 47 | 'out_dim': 4, 48 | 'dim_mult': [1, 2, 4, 4], 49 | 'num_heads': 8, 50 | 'default_fps': 8, 51 | 'head_dim': 64, 52 | 'num_res_blocks': 2, 53 | 'dropout': 0.1, 54 | 'temporal_attention': True, 55 | 'temporal_attn_times': 1, 56 | 'use_checkpoint': True, 57 | 'use_fps_condition': False, 58 | 'use_camera_condition': True, 59 | 'use_lgm_refine': False, # Turn off this if you want to simply fintune a naive i2vgen-xl 60 | 'use_sim_mask': False 61 | } 62 | 63 | Diffusion: { 64 | 'type': 'DiffusionDDIM', 65 | 'schedule': 'cosine', # cosine 66 | 'schedule_param': { 67 | 'num_timesteps': 1000, 68 | 'cosine_s': 0.008, 69 | 'zero_terminal_snr': True, 70 | }, 71 | 'mean_type': 'v', 72 | 'loss_type': 'mse', 73 | 'var_type': 'fixed_small', 74 | 'rescale_timesteps': False, 75 | 'noise_strength': 0.1 76 | } 77 | 78 | batch_sizes: { 79 | "24": 8, 80 | } 81 | 82 | visual_train: { 83 | 'type': 'VisualTrainTextImageToVideo', 84 | 'partial_keys': [ 85 | ['y', 'image', 'local_image', 'fps', 'camera_data', 'gs_data'] 86 | ], 87 | 'use_offset_noise': True, 88 | 'guide_scale': 6.0, 89 | } 90 | 91 | Pretrain: { 92 | 'type': pretrain_specific_strategies, 93 | 'fix_weight': False, 94 | 'grad_scale': 0.5, 95 | 'resume_checkpoint': './pretrained_models/modelscope_i2v/I2VGen-XL/i2vgen_xl_00854500.pth', 96 | 'sd_keys_path': './pretrained_models/modelscope_i2v/I2VGen-XL/stable_diffusion_image_key_temporal_attention_x1.json', 97 | } 98 | 99 | chunk_size: 4 100 | decoder_bs: 4 101 | lr: 0.00003 102 | 103 | noise_strength: 0.1 104 | # classifier-free guidance 105 | p_zero: 0.0 106 | guide_scale: 3.0 107 | num_steps: 1000000 108 | 109 | use_zero_infer: True 110 | viz_interval: 200 # 200 111 | save_ckp_interval: 500 # 500 112 | 113 | # Log 114 | log_dir: "workspace/experiments_i2v" 115 | log_interval: 1 116 | seed: 6666 117 | -------------------------------------------------------------------------------- /configs/t2v_infer.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: inference_text2video_entrance 2 | use_fp16: False 3 | guide_scale: 9.0 4 | chunk_size: 4 5 | decoder_bs: 4 6 | max_frames: 24 7 | target_fps: 8 # FPS Conditions, not encoding fps 8 | scale: 8 9 | batch_size: 1 10 | use_zero_infer: True 11 | 12 | round: 2 13 | seed: 11 14 | 15 | test_list_path: ./data/test_prompts.txt 16 | vldm_cfg: configs/t2v_train.yaml 17 | test_model: ./pretrained_modesl/t2v_00333000.pth 18 | log_dir: ./workspace/visualization/t2v 19 | 20 | UNet: { 21 | 'type': 'UNetSD_T2VBase', 22 | 'in_dim': 4, 23 | 'y_dim': 1024, 24 | 'upper_len': 128, 25 | 'context_dim': 1024, 26 | 'out_dim': 4, 27 | 'dim_mult': [1, 2, 4, 4], 28 
| 'num_heads': 8, 29 | 'default_fps': 8, 30 | 'head_dim': 64, 31 | 'num_res_blocks': 2, 32 | 'dropout': 0.1, 33 | 'misc_dropout': 0.4, 34 | 'temporal_attention': True, 35 | 'temporal_attn_times': 1, 36 | 'use_checkpoint': True, 37 | 'use_fps_condition': False, 38 | 'use_camera_condition': True, # Turn off this if you are trained on multi-view images with fixed poses. 39 | 'use_lgm_refine': True, 40 | 'use_sim_mask': False 41 | } -------------------------------------------------------------------------------- /configs/t2v_train.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: train_t2v_entrance 2 | ENABLE: true 3 | use_ema: false 4 | num_workers: 10 5 | frame_lens: [24] 6 | sample_fps: [8] 7 | resolution: [256, 256] 8 | vit_resolution: [224, 224] 9 | lgm_pretrain: './pretrained_models/model.safetensors' 10 | 11 | vid_dataset: { 12 | 'type': 'VideoDataset', 13 | 'data_list': ['./data/lvis_thres_28.json', ], 14 | 'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ], 15 | 'caption_dir': './data/text_captions_cap3d.json', 16 | 'vit_resolution': [224, 224], 17 | 'resolution': [256, 256], 18 | 'get_first_frame': True, 19 | 'max_words': 1000, 20 | 'prepare_lgm': True, 21 | } 22 | 23 | img_dataset: { 24 | 'type': 'ImageDataset', 25 | 'data_list': ['data/img_list.txt', ], 26 | 'data_dir_list': ['data/images', ], 27 | 'vit_resolution': [224, 224], 28 | 'resolution': [256, 256], 29 | 'max_words': 1000 30 | } 31 | embedder: { 32 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder', 33 | 'layer': 'penultimate', 34 | 'vit_resolution': [224, 224], 35 | 'pretrained': './pretrained_models/modelscope_t2v/open_clip_pytorch_model.bin' 36 | } 37 | 38 | UNet: { 39 | 'type': 'UNetSD_T2VBase', 40 | 'in_dim': 4, 41 | 'y_dim': 1024, 42 | 'upper_len': 128, 43 | 'context_dim': 1024, 44 | 'out_dim': 4, 45 | 'dim_mult': [1, 2, 4, 4], 46 | 'num_heads': 8, 47 | 'default_fps': 8, 48 | 'head_dim': 64, 49 | 'num_res_blocks': 2, 50 | 'dropout': 0.1, 51 | 'misc_dropout': 0.4, 52 | 'temporal_attention': True, 53 | 'temporal_attn_times': 1, 54 | 'use_checkpoint': True, 55 | 'use_fps_condition': False, 56 | 'use_camera_condition': True, # Turn off this if you are trained on multi-view images with fixed poses. 
57 | 'use_lgm_refine': False, 58 | 'use_sim_mask': False 59 | } 60 | 61 | Diffusion: { 62 | 'type': 'DiffusionDDIM', 63 | 'schedule': 'linear_sd', # cosine 64 | 'schedule_param': { 65 | 'num_timesteps': 1000, 66 | 'init_beta': 0.00085, 67 | 'last_beta': 0.0120, 68 | 'zero_terminal_snr': False, 69 | }, 70 | 'mean_type': 'eps', # eps for baseline with no lgm reg 71 | 'loss_type': 'mse', 72 | 'var_type': 'fixed_small', 73 | 'rescale_timesteps': False, 74 | 'noise_strength': 0.0 75 | } 76 | 77 | batch_sizes: { 78 | "1": 32, 79 | "24": 8, 80 | } 81 | 82 | visual_train: { 83 | 'type': 'VisualTrainTextImageToVideo', 84 | 'partial_keys': [ 85 | ['y', 'fps', 'camera_data', 'gs_data'], 86 | ], 87 | 'use_offset_noise': False, 88 | 'guide_scale': 9.0, 89 | } 90 | 91 | Pretrain: { 92 | 'type': pretrain_specific_strategies, 93 | 'fix_weight': False, 94 | 'grad_scale': 0.5, 95 | 'resume_checkpoint': './pretrained_models/modelscope_t2v/model_scope_0267000.pth', 96 | 'sd_keys_path': 'data/stable_diffusion_image_key_temporal_attention_x1.json', 97 | } 98 | 99 | chunk_size: 4 100 | decoder_bs: 4 101 | lr: 0.00003 # 0.00003 102 | 103 | noise_strength: 0.0 # no noise 104 | # classifier-free guidance 105 | p_zero: 0.1 106 | guide_scale: 3.0 107 | num_steps: 1000000 108 | 109 | use_zero_infer: True 110 | viz_interval: 50 # 200 111 | save_ckp_interval: 500 # 500 112 | 113 | # Log 114 | log_dir: "workspace/experiment_t2v" 115 | log_interval: 1 116 | seed: 0 -------------------------------------------------------------------------------- /configs/t2v_train_laion.yaml: -------------------------------------------------------------------------------- 1 | TASK_TYPE: train_t2v_entrance 2 | ENABLE: true 3 | use_ema: false 4 | num_workers: 4 5 | frame_lens: [1, 24, 24, 24, 24, 24, 24, 24] 6 | sample_fps: [1, 8, 8, 8, 8, 8, 8, 8] 7 | resolution: [256, 256] 8 | vit_resolution: [224, 224] 9 | lgm_pretrain: './pretrained_models/model.safetensors' 10 | 11 | vid_dataset: { 12 | 'type': 'VideoDataset', 13 | # 'data_list': ['/mnt/cap/muyuan/code/StableVideoDiffusion/StableVideoDiffusion/valid_paths_v4_cap_filter_thres_28.json', ], 14 | 'data_list': ['./data/lvis_thres_28.json', ], 15 | 'data_dir_list': ['/mnt/objaverse/dataset/raw/0', ], 16 | 'caption_dir': './data/text_captions_cap3d.json', 17 | 'vit_resolution': [224, 224], 18 | 'resolution': [256, 256], 19 | 'get_first_frame': True, 20 | 'max_words': 1000, 21 | 'prepare_lgm': False, 22 | } 23 | 24 | img_dataset: { 25 | 'type': 'LAIONImageDataset', 26 | 'data_list': ['{00000..60580}.tar', ], 27 | 'data_dir_list': ['/mnt/laion/dataset/laion2b-en-ath5/improved_aesthetics_5plus/laion-2ben-5_0/', ], 28 | 'vit_resolution': [224, 224], 29 | 'resolution': [256, 256], 30 | 'max_words': 1000, 31 | } 32 | 33 | embedder: { 34 | 'type': 'FrozenOpenCLIPTtxtVisualEmbedder', 35 | 'layer': 'penultimate', 36 | 'vit_resolution': [224, 224], 37 | 'pretrained': './pretrained_models/modelscope_t2v/open_clip_pytorch_model.bin' 38 | } 39 | 40 | UNet: { 41 | 'type': 'UNetSD_T2VBase', 42 | 'in_dim': 4, 43 | 'y_dim': 1024, 44 | 'upper_len': 128, 45 | 'context_dim': 1024, 46 | 'out_dim': 4, 47 | 'dim_mult': [1, 2, 4, 4], 48 | 'num_heads': 8, 49 | 'default_fps': 8, 50 | 'head_dim': 64, 51 | 'num_res_blocks': 2, 52 | 'dropout': 0.1, 53 | 'misc_dropout': 0.4, 54 | 'temporal_attention': True, 55 | 'temporal_attn_times': 1, 56 | 'use_checkpoint': True, 57 | 'use_fps_condition': False, 58 | 'use_camera_condition': True, # Turn off this if you are trained on multi-view images with fixed poses. 
59 | 'use_sync_attention': False, # Turn off this if you do not wish to use SyncAttention. 60 | 'use_flexicube_reg': False, # Turn off this if you do not wish to use a 3D reguralization. 61 | 'use_lgm_reg': False, # Turn off this if you do not wish to use a lgm reguralization. 62 | 'use_lgm_refine': False, 63 | 'use_sim_mask': False 64 | } 65 | # Diffusion: { 66 | # 'type': 'DiffusionDDIM', 67 | # 'schedule': 'cosine', # cosine 68 | # 'schedule_param': { 69 | # 'num_timesteps': 1000, 70 | # 'cosine_s': 0.008, 71 | # 'zero_terminal_snr': True, 72 | # }, 73 | # 'mean_type': 'v', 74 | # 'loss_type': 'mse', 75 | # 'var_type': 'fixed_small', 76 | # 'rescale_timesteps': False, 77 | # 'noise_strength': 0.1 78 | # } 79 | 80 | Diffusion: { 81 | 'type': 'DiffusionDDIM', 82 | 'schedule': 'linear_sd', # cosine 83 | 'schedule_param': { 84 | 'num_timesteps': 1000, 85 | 'init_beta': 0.00085, 86 | 'last_beta': 0.0120, 87 | 'zero_terminal_snr': False, 88 | }, 89 | 'mean_type': 'eps', # eps for baseline with no lgm reg 90 | 'loss_type': 'mse', 91 | 'var_type': 'fixed_small', 92 | 'rescale_timesteps': False, 93 | 'noise_strength': 0.0 94 | } 95 | 96 | batch_sizes: { 97 | "1": 196, 98 | "24": 24, 99 | } 100 | 101 | visual_train: { 102 | 'type': 'VisualTrainTextImageToVideo', 103 | 'partial_keys': [ 104 | ['y', 'fps', 'camera_data', 'gs_data'], 105 | ], 106 | 'use_offset_noise': False, 107 | 'guide_scale': 9.0, 108 | } 109 | 110 | Pretrain: { 111 | 'type': pretrain_specific_strategies, 112 | 'fix_weight': False, 113 | 'grad_scale': 0.5, 114 | 'resume_checkpoint': './pretrained_modesl/t2v_00333000.pth', 115 | 'sd_keys_path': 'data/stable_diffusion_image_key_temporal_attention_x1.json', 116 | } 117 | 118 | chunk_size: 4 119 | decoder_bs: 4 120 | lr: 0.00003 # 0.00003 121 | 122 | noise_strength: 0.0 # no noise 123 | # classifier-free guidance 124 | p_zero: 0.1 125 | guide_scale: 3.0 126 | num_steps: 1000000 127 | 128 | use_zero_infer: True 129 | viz_interval: 200 # 200 130 | save_ckp_interval: 500 # 500 131 | 132 | # Log 133 | log_dir: "workspace/experiments_laion" 134 | log_interval: 1 135 | seed: 0 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__init__.py -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/gs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/gs.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/models.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/core/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /core/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import os 11 | import warnings 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 17 | try: 18 | if XFORMERS_ENABLED: 19 | from xformers.ops import memory_efficient_attention, unbind 20 | 21 | XFORMERS_AVAILABLE = True 22 | warnings.warn("xFormers is available (Attention)") 23 | else: 24 | warnings.warn("xFormers is disabled (Attention)") 25 | raise ImportError 26 | except ImportError: 27 | XFORMERS_AVAILABLE = False 28 | warnings.warn("xFormers is not available (Attention)") 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | num_heads: int = 8, 36 | qkv_bias: bool = False, 37 | proj_bias: bool = True, 38 | attn_drop: float = 0.0, 39 | proj_drop: float = 0.0, 40 | ) -> None: 41 | super().__init__() 42 | self.num_heads = num_heads 43 | head_dim = dim // num_heads 44 | self.scale = head_dim**-0.5 45 | 46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | B, N, C = x.shape 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | 55 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 56 | attn = q @ k.transpose(-2, -1) 57 | 58 | attn = attn.softmax(dim=-1) 59 | attn = self.attn_drop(attn) 60 | 61 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 62 | x = self.proj(x) 63 | x = self.proj_drop(x) 64 | return x 65 | 66 | 67 | class MemEffAttention(Attention): 68 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 69 | if not XFORMERS_AVAILABLE: 70 | if 
attn_bias is not None: 71 | raise AssertionError("xFormers is required for using nested tensors") 72 | return super().forward(x) 73 | 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 76 | 77 | q, k, v = unbind(qkv, 2) 78 | 79 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 80 | x = x.reshape([B, N, C]) 81 | 82 | x = self.proj(x) 83 | x = self.proj_drop(x) 84 | return x 85 | 86 | 87 | class CrossAttention(nn.Module): 88 | def __init__( 89 | self, 90 | dim: int, 91 | dim_q: int, 92 | dim_k: int, 93 | dim_v: int, 94 | num_heads: int = 8, 95 | qkv_bias: bool = False, 96 | proj_bias: bool = True, 97 | attn_drop: float = 0.0, 98 | proj_drop: float = 0.0, 99 | ) -> None: 100 | super().__init__() 101 | self.dim = dim 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = head_dim**-0.5 105 | 106 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) 107 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) 108 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 114 | # q: [B, N, Cq] 115 | # k: [B, M, Ck] 116 | # v: [B, M, Cv] 117 | # return: [B, N, C] 118 | 119 | B, N, _ = q.shape 120 | M = k.shape[1] 121 | 122 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] 123 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 124 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 125 | 126 | attn = q @ k.transpose(-2, -1) # [B, nh, N, M] 127 | 128 | attn = attn.softmax(dim=-1) # [B, nh, N, M] 129 | attn = self.attn_drop(attn) 130 | 131 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | return x 135 | 136 | 137 | class MemEffCrossAttention(CrossAttention): 138 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: 139 | if not XFORMERS_AVAILABLE: 140 | if attn_bias is not None: 141 | raise AssertionError("xFormers is required for using nested tensors") 142 | return super().forward(x) 143 | 144 | B, N, _ = q.shape 145 | M = k.shape[1] 146 | 147 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] 148 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 149 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 150 | 151 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 152 | x = x.reshape(B, N, -1) 153 | 154 | x = self.proj(x) 155 | x = self.proj_drop(x) 156 | return x 157 | -------------------------------------------------------------------------------- /core/gs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diff_gaussian_rasterization import ( 8 | GaussianRasterizationSettings, 9 | GaussianRasterizer, 10 | ) 11 | 12 | from core.options import Options 13 | 14 | import kiui 15 | 16 | class GaussianRenderer: 17 | def __init__(self, 
opt: Options): 18 | 19 | self.opt = opt 20 | self.bg_color = torch.tensor([1,1,1], dtype=torch.float32, device="cuda") 21 | 22 | # intrinsics 23 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 24 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 25 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 26 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 27 | self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) 28 | self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) 29 | self.proj_matrix[2, 3] = 1 30 | 31 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): 32 | # gaussians: [B, N, 14] 33 | # cam_view, cam_view_proj: [B, V, 4, 4] 34 | # cam_pos: [B, V, 3] 35 | 36 | device = gaussians.device 37 | B, V = cam_view.shape[:2] 38 | 39 | # loop of loop... 40 | images = [] 41 | alphas = [] 42 | for b in range(B): 43 | # pos, opacity, scale, rotation, shs 44 | means3D = gaussians[b, :, 0:3].contiguous().float() 45 | opacity = gaussians[b, :, 3:4].contiguous().float() 46 | scales = gaussians[b, :, 4:7].contiguous().float() 47 | rotations = gaussians[b, :, 7:11].contiguous().float() 48 | rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 4] 49 | 50 | for v in range(V): 51 | 52 | # render novel views 53 | view_matrix = cam_view[b, v].float() 54 | view_proj_matrix = cam_view_proj[b, v].float() 55 | campos = cam_pos[b, v].float() 56 | 57 | raster_settings = GaussianRasterizationSettings( 58 | image_height=self.opt.output_size, 59 | image_width=self.opt.output_size, 60 | tanfovx=self.tan_half_fov, 61 | tanfovy=self.tan_half_fov, 62 | bg=self.bg_color if bg_color is None else bg_color, 63 | scale_modifier=scale_modifier, 64 | viewmatrix=view_matrix, 65 | projmatrix=view_proj_matrix, 66 | sh_degree=0, 67 | campos=campos, 68 | prefiltered=False, 69 | debug=False, 70 | ) 71 | 72 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 73 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
74 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 75 | means3D=means3D, 76 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), 77 | shs=None, 78 | colors_precomp=rgbs, 79 | opacities=opacity, 80 | scales=scales, 81 | rotations=rotations, 82 | cov3D_precomp=None, 83 | ) 84 | rendered_image = rendered_image.clamp(0, 1) 85 | images.append(rendered_image) 86 | alphas.append(rendered_alpha) 87 | 88 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) # we use 4 for latent 89 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) 90 | 91 | return { 92 | "image": images, # [B, V, 4, H, W] 93 | "alpha": alphas, # [B, V, 1, H, W] 94 | } 95 | 96 | 97 | def save_ply(self, gaussians, path, compatible=True): 98 | # gaussians: [B, N, 14] 99 | # compatible: save pre-activated gaussians as in the original paper 100 | 101 | assert gaussians.shape[0] == 1, 'only support batch size 1' 102 | 103 | from plyfile import PlyData, PlyElement 104 | 105 | means3D = gaussians[0, :, 0:3].contiguous().float() 106 | opacity = gaussians[0, :, 3:4].contiguous().float() 107 | scales = gaussians[0, :, 4:7].contiguous().float() 108 | rotations = gaussians[0, :, 7:11].contiguous().float() 109 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] 110 | 111 | # prune by opacity 112 | mask = opacity.squeeze(-1) >= 0.005 113 | means3D = means3D[mask] 114 | opacity = opacity[mask] 115 | scales = scales[mask] 116 | rotations = rotations[mask] 117 | shs = shs[mask] 118 | 119 | # invert activation to make it compatible with the original ply format 120 | if compatible: 121 | opacity = kiui.op.inverse_sigmoid(opacity) 122 | scales = torch.log(scales + 1e-8) 123 | shs = (shs - 0.5) / 0.28209479177387814 124 | 125 | xyzs = means3D.detach().cpu().numpy() 126 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 127 | opacities = opacity.detach().cpu().numpy() 128 | scales = scales.detach().cpu().numpy() 129 | rotations = rotations.detach().cpu().numpy() 130 | 131 | l = ['x', 'y', 'z'] 132 | # All channels except the 3 DC 133 | for i in range(f_dc.shape[1]): 134 | l.append('f_dc_{}'.format(i)) 135 | l.append('opacity') 136 | for i in range(scales.shape[1]): 137 | l.append('scale_{}'.format(i)) 138 | for i in range(rotations.shape[1]): 139 | l.append('rot_{}'.format(i)) 140 | 141 | dtype_full = [(attribute, 'f4') for attribute in l] 142 | 143 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 144 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 145 | elements[:] = list(map(tuple, attributes)) 146 | el = PlyElement.describe(elements, 'vertex') 147 | 148 | PlyData([el]).write(path) 149 | 150 | def load_ply(self, path, compatible=True): 151 | 152 | from plyfile import PlyData, PlyElement 153 | 154 | plydata = PlyData.read(path) 155 | 156 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 157 | np.asarray(plydata.elements[0]["y"]), 158 | np.asarray(plydata.elements[0]["z"])), axis=1) 159 | print("Number of points at loading : ", xyz.shape[0]) 160 | 161 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 162 | 163 | shs = np.zeros((xyz.shape[0], 3)) 164 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 165 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) 166 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) 167 | 168 | scale_names = [p.name for p in plydata.elements[0].properties if 
p.name.startswith("scale_")] 169 | scales = np.zeros((xyz.shape[0], len(scale_names))) 170 | for idx, attr_name in enumerate(scale_names): 171 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 172 | 173 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] 174 | rots = np.zeros((xyz.shape[0], len(rot_names))) 175 | for idx, attr_name in enumerate(rot_names): 176 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 177 | 178 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) 179 | gaussians = torch.from_numpy(gaussians).float() # cpu 180 | 181 | if compatible: 182 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) 183 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) 184 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 185 | 186 | return gaussians -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import kiui 7 | from kiui.lpips import LPIPS 8 | 9 | from core.unet import UNet 10 | from core.options import Options 11 | from core.gs import GaussianRenderer 12 | 13 | 14 | class LGM(nn.Module): 15 | def __init__( 16 | self, 17 | opt: Options, 18 | ): 19 | super().__init__() 20 | 21 | self.opt = opt 22 | 23 | # unet 24 | self.unet = UNet( 25 | 9, 14, 26 | down_channels=self.opt.down_channels, 27 | down_attention=self.opt.down_attention, 28 | mid_attention=self.opt.mid_attention, 29 | up_channels=self.opt.up_channels, 30 | up_attention=self.opt.up_attention, 31 | ) 32 | # x = F.interpolate(x, scale_factor=2.0, mode='nearest') 33 | # last conv 34 | self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again 35 | # Gaussian Renderer 36 | self.gs = GaussianRenderer(opt) 37 | 38 | # activations... 
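        # Each activation below maps the raw UNet output into a valid Gaussian-parameter range:
        # positions are clamped to [-1, 1]^3, scales are made positive with a scaled softplus,
        # opacity is squashed into (0, 1) with a sigmoid, rotations are normalized to unit quaternions,
        # and colors are mapped to [0, 1] with a rescaled tanh.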
39 | self.pos_act = lambda x: x.clamp(-1, 1) 40 | self.scale_act = lambda x: 0.1 * F.softplus(x) 41 | self.opacity_act = lambda x: torch.sigmoid(x) 42 | self.rot_act = F.normalize 43 | self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again 44 | 45 | # LPIPS loss 46 | if self.opt.lambda_lpips > 0: 47 | self.lpips_loss = LPIPS(net='vgg') 48 | self.lpips_loss.requires_grad_(False) 49 | 50 | 51 | def state_dict(self, **kwargs): 52 | # remove lpips_loss 53 | state_dict = super().state_dict(**kwargs) 54 | for k in list(state_dict.keys()): 55 | if 'lpips_loss' in k: 56 | del state_dict[k] 57 | return state_dict 58 | 59 | 60 | def prepare_default_rays(self, device, elevation=0): 61 | 62 | from kiui.cam import orbit_camera 63 | from core.utils import get_rays 64 | 65 | cam_poses = np.stack([ 66 | orbit_camera(elevation, 0, radius=self.opt.cam_radius), 67 | orbit_camera(elevation, 90, radius=self.opt.cam_radius), 68 | orbit_camera(elevation, 180, radius=self.opt.cam_radius), 69 | orbit_camera(elevation, 270, radius=self.opt.cam_radius), 70 | ], axis=0) # [4, 4, 4] 71 | cam_poses = torch.from_numpy(cam_poses) 72 | # print("default_rays:", cam_poses) 73 | rays_embeddings = [] 74 | for i in range(cam_poses.shape[0]): 75 | rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 76 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 77 | rays_embeddings.append(rays_plucker) 78 | 79 | ## visualize rays for plotting figure 80 | # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) 81 | 82 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] 83 | 84 | return rays_embeddings 85 | 86 | 87 | def forward_gaussians(self, images): 88 | # images: [B, 4, 9, H, W] 89 | # return: Gaussians: [B, dim_t] 90 | 91 | B, V, C, H, W = images.shape 92 | images = images.view(B*V, C, H, W) 93 | 94 | x = self.unet(images) # [B*24, 14, h, w] 95 | x = self.conv(x) # [B*24, 14, h, w] 96 | 97 | x = x.reshape(B, self.opt.num_input_views, 14, self.opt.splat_size, self.opt.splat_size) # hard code: 24?? 
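        # The 14 channels predicted per splat pixel are split below into position (3), opacity (1),
        # scale (3), rotation quaternion (4), and RGB (3), then passed through the activations defined above.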
98 | 99 | ## visualize multi-view gaussian features for plotting figure 100 | # tmp_alpha = self.opacity_act(x[0, :, 3:4]) 101 | # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha) 102 | # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5 103 | # kiui.vis.plot_image(tmp_img_rgb, save=True) 104 | # kiui.vis.plot_image(tmp_img_pos, save=True) 105 | 106 | x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) 107 | 108 | pos = self.pos_act(x[..., 0:3]) # [B, N, 3] 109 | opacity = self.opacity_act(x[..., 3:4]) 110 | scale = self.scale_act(x[..., 4:7]) 111 | rotation = self.rot_act(x[..., 7:11]) 112 | rgbs = self.rgb_act(x[..., 11:]) 113 | 114 | gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] 115 | 116 | return gaussians 117 | 118 | def infer(self, data, step_ratio=1, bg_color_factor=0.5): 119 | results = {} 120 | 121 | images = data['input'] # [B, 4, 9, h, W], input features 122 | 123 | # use the first view to predict gaussians 124 | gaussians = self.forward_gaussians(images) # [B, N, 14] 125 | 126 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)*bg_color_factor 127 | 128 | # use the other views for rendering and supervision 129 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) 130 | pred_images = results['image'] # [B, V, C, output_size, output_size] 131 | 132 | results['images_pred'] = pred_images 133 | 134 | return results 135 | 136 | def forward(self, data, step_ratio=1): 137 | # data: output of the dataloader 138 | # return: loss 139 | 140 | results = {} 141 | loss = 0 142 | 143 | images = data['input'] # [B, 4, 9, h, W], input features 144 | 145 | # use the first view to predict gaussians 146 | gaussians = self.forward_gaussians(images) # [B, N, 14] 147 | 148 | results['gaussians'] = gaussians 149 | 150 | # random bg for training 151 | if self.training: 152 | bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device) 153 | else: 154 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) 155 | 156 | # use the other views for rendering and supervision 157 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) 158 | pred_images = results['image'] # [B, V, C, output_size, output_size] 159 | pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] 160 | 161 | results['images_pred'] = pred_images 162 | results['alphas_pred'] = pred_alphas 163 | 164 | gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views 165 | gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks 166 | 167 | gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) 168 | 169 | loss_mse = F.mse_loss(pred_images.half(), gt_images.half()) + F.mse_loss(pred_alphas.half(), gt_masks.half()) 170 | loss = loss + loss_mse 171 | 172 | if self.opt.lambda_lpips > 0: 173 | loss_lpips = self.lpips_loss( 174 | # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, 175 | # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, 176 | # downsampled to at most 256 to reduce memory cost 177 | F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False).half(), 178 | F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', 
align_corners=False).half(), 179 | ).mean() 180 | results['loss_lpips'] = loss_lpips 181 | loss = loss + self.opt.lambda_lpips * loss_lpips 182 | 183 | results['loss'] = loss 184 | 185 | # metric 186 | with torch.no_grad(): 187 | psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) 188 | results['psnr'] = psnr 189 | 190 | 191 | 192 | return results -------------------------------------------------------------------------------- /core/options.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Tuple, Literal, Dict, Optional 4 | 5 | 6 | @dataclass 7 | class Options: 8 | ### model 9 | # Unet image input size 10 | input_size: int = 256 11 | # Unet definition 12 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) 13 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) 14 | mid_attention: bool = True 15 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) 16 | up_attention: Tuple[bool, ...] = (True, True, True, False) 17 | # Unet output size, dependent on the input_size and U-Net structure! 18 | splat_size: int = 64 19 | # gaussian render size 20 | output_size: int = 256 21 | 22 | ### dataset 23 | # data mode (only support s3 now) 24 | data_mode: Literal['s3'] = 's3' 25 | # fovy of the dataset 26 | fovy: float = 39.6 # 39.6 # 49.1 27 | # camera near plane 28 | znear: float = 0.5 # 0.1 # 0.5 29 | # camera far plane 30 | zfar: float = 2.5 # 1000 # 2.5 31 | # number of all views (input + output) 32 | num_views: int = 8 33 | # number of views 34 | num_input_views: int = 4 35 | # camera radius 36 | cam_radius: float = 1.5 # to better use [-1, 1]^3 space 37 | # num workers 38 | num_workers: int = 8 39 | 40 | ### training 41 | # workspace 42 | workspace: str = './workspace' 43 | # resume 44 | resume: Optional[str] = "/mnt/cap/muyuan/code/StableVideoDiffusion/StableVideoDiffusion/i2vgen-xl/LGM/pretrained/model_fp16.safetensors" 45 | # batch size (per-GPU) 46 | batch_size: int = 8 47 | # gradient accumulation 48 | gradient_accumulation_steps: int = 1 49 | # training epochs 50 | num_epochs: int = 30 51 | # lpips loss weight 52 | lambda_lpips: float = 1.0 53 | # gradient clip 54 | gradient_clip: float = 1.0 55 | # mixed precision 56 | mixed_precision: str = 'bf16' 57 | # learning rate 58 | lr: float = 1e-4 59 | # augmentation prob for grid distortion 60 | prob_grid_distortion: float = 0.5 61 | # augmentation prob for camera jitter 62 | prob_cam_jitter: float = 0.5 63 | 64 | ### testing 65 | # test image path 66 | test_path: Optional[str] = None 67 | 68 | ### misc 69 | # nvdiffrast backend setting 70 | force_cuda_rast: bool = False 71 | # render fancy video with gaussian scaling effect 72 | fancy_video: bool = False 73 | 74 | 75 | # all the default settings 76 | config_defaults: Dict[str, Options] = {} 77 | config_doc: Dict[str, str] = {} 78 | 79 | config_doc['lrm'] = 'the default settings for LGM' 80 | config_defaults['lrm'] = Options() 81 | 82 | config_doc['small'] = 'small model with lower resolution Gaussians' 83 | config_defaults['small'] = Options( 84 | input_size=256, 85 | splat_size=64, 86 | output_size=256, 87 | batch_size=4, 88 | gradient_accumulation_steps=1, 89 | mixed_precision='bf16', 90 | ) 91 | 92 | config_doc['big'] = 'big model with higher resolution Gaussians' 93 | config_defaults['big'] = Options( 94 | input_size=256, 95 | up_channels=(1024, 1024, 512, 256, 128), # one more decoder 96 | up_attention=(True, True, 
True, False, False), 97 | splat_size=128, 98 | output_size=512, # render & supervise Gaussians at a higher resolution. 99 | batch_size=8, 100 | num_views=8, 101 | gradient_accumulation_steps=1, 102 | mixed_precision='bf16', 103 | ) 104 | 105 | config_doc['tiny'] = 'tiny model for ablation' 106 | config_defaults['tiny'] = Options( 107 | input_size=256, 108 | down_channels=(32, 64, 128, 256), 109 | down_attention=(False, False, False, True), 110 | up_channels=(256, 128, 64), 111 | up_attention=(True, False, False), 112 | splat_size=128, 113 | output_size=256, 114 | batch_size=8, 115 | num_views=8, 116 | gradient_accumulation_steps=1, 117 | mixed_precision='bf16', 118 | ) 119 | 120 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) 121 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import roma 8 | from kiui.op import safe_normalize 9 | 10 | def get_rays(pose, h, w, fovy, opengl=True): 11 | 12 | x, y = torch.meshgrid( 13 | torch.arange(w, device=pose.device), 14 | torch.arange(h, device=pose.device), 15 | indexing="xy", 16 | ) 17 | x = x.flatten() 18 | y = y.flatten() 19 | 20 | cx = w * 0.5 21 | cy = h * 0.5 22 | 23 | focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) 24 | 25 | camera_dirs = F.pad( 26 | torch.stack( 27 | [ 28 | (x - cx + 0.5) / focal, 29 | (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), 30 | ], 31 | dim=-1, 32 | ), 33 | (0, 1), 34 | value=(-1.0 if opengl else 1.0), 35 | ) # [hw, 3] 36 | 37 | rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] 38 | rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] 39 | 40 | rays_o = rays_o.view(h, w, 3) 41 | rays_d = safe_normalize(rays_d).view(h, w, 3) 42 | 43 | return rays_o, rays_d 44 | 45 | def orbit_camera_jitter(poses, strength=0.1): 46 | # poses: [B, 4, 4], assume orbit camera in opengl format 47 | # random orbital rotate 48 | 49 | B = poses.shape[0] 50 | rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) 51 | rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) 52 | 53 | rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) 54 | R = rot @ poses[:, :3, :3] 55 | T = rot @ poses[:, :3, 3:] 56 | 57 | new_poses = poses.clone() 58 | new_poses[:, :3, :3] = R 59 | new_poses[:, :3, 3:] = T 60 | 61 | return new_poses 62 | 63 | def grid_distortion(images, strength=0.5): 64 | # images: [B, C, H, W] 65 | # num_steps: int, grid resolution for distortion 66 | # strength: float in [0, 1], strength of distortion 67 | 68 | B, C, H, W = images.shape 69 | 70 | num_steps = np.random.randint(8, 17) 71 | grid_steps = torch.linspace(-1, 1, num_steps) 72 | 73 | # have to loop batch... 
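    # For each sample: jitter a coarse set of x/y grid lines, linearly interpolate between neighbouring
    # lines to build a full-resolution sampling grid, then warp the image with F.grid_sample below.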
74 | grids = [] 75 | for b in range(B): 76 | # construct displacement 77 | x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 78 | x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 79 | x_steps = (x_steps * W).long() # [num_steps] 80 | x_steps[0] = 0 81 | x_steps[-1] = W 82 | xs = [] 83 | for i in range(num_steps - 1): 84 | xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) 85 | xs = torch.cat(xs, dim=0) # [W] 86 | 87 | y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 88 | y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 89 | y_steps = (y_steps * H).long() # [num_steps] 90 | y_steps[0] = 0 91 | y_steps[-1] = H 92 | ys = [] 93 | for i in range(num_steps - 1): 94 | ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) 95 | ys = torch.cat(ys, dim=0) # [H] 96 | 97 | # construct grid 98 | grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] 99 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] 100 | 101 | grids.append(grid) 102 | 103 | grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] 104 | 105 | # grid sample 106 | images = F.grid_sample(images, grids, align_corners=False) 107 | 108 | return images 109 | 110 | -------------------------------------------------------------------------------- /data/images/demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/data/images/demo1.png -------------------------------------------------------------------------------- /data/images/demo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/data/images/demo2.png -------------------------------------------------------------------------------- /data/test_images.txt: -------------------------------------------------------------------------------- 1 | ./data/images/demo1.png 2 | ./data/images/demo2.png -------------------------------------------------------------------------------- /data/test_prompts.txt: -------------------------------------------------------------------------------- 1 | Futuristic space helmet 2 | dragon armor 3 | A medieval shield with a cross and wooden handle -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import json 5 | import math 6 | import random 7 | import logging 8 | import itertools 9 | import numpy as np 10 | 11 | from utils.config import Config 12 | from utils.registry_class import INFER_ENGINE 13 | 14 | from tools import * 15 | 16 | if __name__ == '__main__': 17 | cfg_update = Config(load=True) 18 | INFER_ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict) 19 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 2 | pip install -r requirements.txt 3 | pip install ninja 4 | git clone --recursive 
https://github.com/ashawkey/diff-gaussian-rasterization 5 | pip install ./diff-gaussian-rasterization -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.10 2 | tokenizers==0.12.1 3 | numpy>=1.19.2 4 | ftfy==6.1.1 5 | transformers==4.18.0 6 | imageio==2.15.0 7 | fairscale==0.4.6 8 | ipdb 9 | open-clip-torch==2.0.2 10 | xformers==0.0.13 11 | chardet==5.1.0 12 | torchdiffeq==0.2.3 13 | opencv-python==4.4.0.46 14 | opencv-python-headless==4.7.0.68 15 | torchsde==0.2.6 16 | simplejson==3.18.4 17 | motion-vector-extractor==1.0.6 18 | scikit-learn 19 | scikit-image 20 | rotary-embedding-torch==0.2.1 21 | pynvml==11.5.0 22 | triton==2.0.0.dev20221120 23 | pytorch-lightning==1.4.2 24 | torchmetrics==0.6.0 25 | gradio==3.39.0 26 | imageio-ffmpeg 27 | kornia 28 | tyro 29 | dearpygui 30 | einops 31 | lpips 32 | matplotlib 33 | packaging 34 | Pillow 35 | pygltflib 36 | rembg[gpu,cli] 37 | rich 38 | safetensors 39 | scipy 40 | tqdm 41 | trimesh 42 | kiui >= 0.2.3 43 | roma 44 | plyfile -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .annotator import * 2 | from .datasets import * 3 | from .modules import * 4 | from .train import * 5 | from .hooks import * 6 | from .inferences import * 7 | # from .prior import * 8 | from .basic_funcs import * 9 | -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from tools.annotator.util import HWC3 5 | # import gradio as gr 6 | 7 | class CannyDetector: 8 | def __call__(self, img, low_threshold = None, high_threshold = None, random_threshold = True): 9 | 10 | ### GPT-4 suggestions 11 | # In the cv2.Canny() function, the low threshold and high threshold are used to determine the edges based on the gradient values in the image. 12 | # There isn't a one-size-fits-all solution for these threshold values, as the optimal values depend on the specific image and the application. 13 | # However, there are some general guidelines and empirical values you can use as a starting point: 14 | # 1. Ratio: A common recommendation is to use a ratio of 1:2 or 1:3 between the low threshold and the high threshold. 15 | # This means if your low threshold is 50, the high threshold should be around 100 or 150. 16 | # 2. Empirical values: As a starting point, you can use low threshold values in the range of 50-100 and high threshold values in the range of 100-200. 17 | # You may need to fine-tune these values based on the specific image and desired edge detection results. 18 | # 3. Automatic threshold calculation: To automatically calculate the threshold values, you can use the median or mean value of the image's pixel intensities as the low threshold, 19 | # and the high threshold can be set as twice or three times the low threshold. 
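    # Worked example of the automatic rule implemented below: if the median pixel
    # intensity is 120 and the sampled random_canny is 0.25, then
    # low_threshold = int((1 - 0.25) * 120) = 90 and high_threshold = 2 * 90 = 180,
    # i.e. the recommended 1:2 ratio between low and high thresholds.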
20 | 21 | ### Convert to numpy 22 | if isinstance(img, torch.Tensor): # (h, w, c) 23 | img = img.cpu().numpy() 24 | img_np = cv2.convertScaleAbs((img * 255.)) 25 | elif isinstance(img, np.ndarray): # (h, w, c) 26 | img_np = img # we assume values are in the range from 0 to 255. 27 | else: 28 | assert False 29 | 30 | ### Select the threshold 31 | if (low_threshold is None) and (high_threshold is None): 32 | median_intensity = np.median(img_np) 33 | if random_threshold is False: 34 | low_threshold = int(max(0, (1 - 0.33) * median_intensity)) 35 | high_threshold = int(min(255, (1 + 0.33) * median_intensity)) 36 | else: 37 | random_canny = np.random.uniform(0.1, 0.4) 38 | # Might try other values 39 | low_threshold = int(max(0, (1 - random_canny) * median_intensity)) 40 | high_threshold = 2 * low_threshold 41 | 42 | ### Detect canny edge 43 | canny_edge = cv2.Canny(img_np, low_threshold, high_threshold) 44 | ### Convert to 3 channels 45 | # canny_edge = HWC3(canny_edge) 46 | 47 | canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim = -1).float().cuda() / 255.0 48 | # canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0) 49 | # canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone() 50 | # return cv2.Canny(img, low_threshold, high_threshold) 51 | return canny_condition -------------------------------------------------------------------------------- /tools/annotator/histogram/__init__.py: -------------------------------------------------------------------------------- 1 | from .palette import * -------------------------------------------------------------------------------- /tools/annotator/histogram/palette.py: -------------------------------------------------------------------------------- 1 | r"""Modified from ``https://github.com/sergeyk/rayleigh''. 2 | """ 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb 7 | from skimage.io import imsave 8 | from sklearn.metrics import euclidean_distances 9 | 10 | __all__ = ['Palette'] 11 | 12 | def rgb2hex(rgb): 13 | return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb]) 14 | 15 | def hex2rgb(hex): 16 | rgb = hex.strip('#') 17 | fn = lambda u: round(int(u, 16) / 255.0, 5) 18 | return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6]) 19 | 20 | class Palette(object): 21 | r"""Create a color palette (codebook) in the form of a 2D grid of colors. 22 | Further, the rightmost column has num_hues gradations from black to white. 23 | 24 | Parameters: 25 | num_hues: number of colors with full lightness and saturation, in the middle. 26 | num_sat: number of rows above middle row that show the same hues with decreasing saturation. 
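        num_light: number of rows below the middle row that show the same hues with decreasing lightness.

    Example (illustrative), using the methods defined below:
        palette = Palette(num_hues=11, num_sat=5, num_light=4)
        hist = palette.histogram(rgb_img)          # soft histogram over the palette colors
        strip = palette.get_palette_image(hist)    # image strip of the dominant colors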
27 | """ 28 | def __init__(self, num_hues=11, num_sat=5, num_light=4): 29 | n = num_sat + 2 * num_light 30 | 31 | # hues 32 | if num_hues == 8: 33 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1)) 34 | elif num_hues == 9: 35 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1)) 36 | elif num_hues == 10: 37 | hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1)) 38 | elif num_hues == 11: 39 | hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1)) 40 | else: 41 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1)) 42 | 43 | # saturations 44 | sats = np.hstack(( 45 | np.linspace(0, 1, num_sat + 2)[1:-1], 46 | 1, 47 | [1] * num_light, 48 | [0.4] * (num_light - 1))) 49 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues)) 50 | 51 | # lights 52 | lights = np.hstack(( 53 | [1] * num_sat, 54 | 1, 55 | np.linspace(1, 0.2, num_light + 2)[1:-1], 56 | np.linspace(1, 0.2, num_light + 2)[1:-2])) 57 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues)) 58 | 59 | # colors 60 | rgb = hsv2rgb(np.dstack([hues, sats, lights])) 61 | gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3)) 62 | self.thumbnail = np.hstack([rgb, gray]) 63 | 64 | # flatten 65 | rgb = rgb.T.reshape(3, -1).T 66 | gray = gray.T.reshape(3, -1).T 67 | self.rgb = np.vstack((rgb, gray)) 68 | self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze() 69 | self.hex = [rgb2hex(u) for u in self.rgb] 70 | self.lab_dists = euclidean_distances(self.lab, squared=True) 71 | 72 | def histogram(self, rgb_img, sigma=20): 73 | # compute histogram 74 | lab = rgb2lab(rgb_img).reshape((-1, 3)) 75 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1) 76 | hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0] 77 | 78 | # smooth histogram 79 | if sigma > 0: 80 | weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2)) 81 | weight = weight / weight.sum(1)[:, np.newaxis] 82 | hist = (weight * hist).sum(1) 83 | hist[hist < 1e-5] = 0 84 | return hist 85 | 86 | def get_palette_image(self, hist, percentile=90, width=200, height=50): 87 | # curate histogram 88 | ind = np.argsort(-hist) 89 | ind = ind[hist[ind] > np.percentile(hist, percentile)] 90 | hist = hist[ind] / hist[ind].sum() 91 | 92 | # draw palette 93 | nums = np.array(hist * width, dtype=int) 94 | array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)]) 95 | array = np.tile(array[np.newaxis, :, :], (height, 1, 1)) 96 | if array.shape[1] < width: 97 | array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1) 98 | return array 99 | 100 | def quantize_image(self, rgb_img): 101 | lab = rgb2lab(rgb_img).reshape((-1, 3)) 102 | min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1) 103 | quantized_lab = self.lab[min_ind] 104 | img = lab2rgb(quantized_lab.reshape(rgb_img.shape)) 105 | return img 106 | 107 | def export(self, dirname): 108 | if not osp.exists(dirname): 109 | os.makedirs(dirname) 110 | 111 | # save thumbnail 112 | imsave(osp.join(dirname, 'palette.png'), self.thumbnail) 113 | 114 | # save html 115 | with open(osp.join(dirname, 'palette.html'), 'w') as f: 116 | html = ''' 117 | 126 | ''' 127 | for row in self.thumbnail: 128 | for col in row: 129 | html += '\n'.format(rgb2hex(col)) 130 | html += '
\n' 131 | f.write(html) 132 | -------------------------------------------------------------------------------- /tools/annotator/sketch/__init__.py: -------------------------------------------------------------------------------- 1 | from .pidinet import * 2 | from .sketch_simplification import * -------------------------------------------------------------------------------- /tools/annotator/sketch/sketch_simplification.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''. 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | # from canvas import DOWNLOAD_TO_CACHE 9 | from artist import DOWNLOAD_TO_CACHE 10 | 11 | __all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse', 12 | 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'] 13 | 14 | class SketchSimplification(nn.Module): 15 | r"""NOTE: 16 | 1. Input image should has only one gray channel. 17 | 2. Input image size should be divisible by 8. 18 | 3. Sketch in the input/output image is in dark color while background in light color. 19 | """ 20 | def __init__(self, mean, std): 21 | assert isinstance(mean, float) and isinstance(std, float) 22 | super(SketchSimplification, self).__init__() 23 | self.mean = mean 24 | self.std = std 25 | 26 | # layers 27 | self.layers = nn.Sequential( 28 | nn.Conv2d(1, 48, 5, 2, 2), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(48, 128, 3, 1, 1), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(128, 128, 3, 1, 1), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(128, 128, 3, 2, 1), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(128, 256, 3, 1, 1), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(256, 256, 3, 1, 1), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(256, 256, 3, 2, 1), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(256, 512, 3, 1, 1), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(512, 1024, 3, 1, 1), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(1024, 1024, 3, 1, 1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(1024, 1024, 3, 1, 1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(1024, 1024, 3, 1, 1), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(1024, 512, 3, 1, 1), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(512, 256, 3, 1, 1), 55 | nn.ReLU(inplace=True), 56 | nn.ConvTranspose2d(256, 256, 4, 2, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(256, 256, 3, 1, 1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(256, 128, 3, 1, 1), 61 | nn.ReLU(inplace=True), 62 | nn.ConvTranspose2d(128, 128, 4, 2, 1), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(128, 128, 3, 1, 1), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(128, 48, 3, 1, 1), 67 | nn.ReLU(inplace=True), 68 | nn.ConvTranspose2d(48, 48, 4, 2, 1), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(48, 24, 3, 1, 1), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(24, 1, 3, 1, 1), 73 | nn.Sigmoid()) 74 | 75 | def forward(self, x): 76 | r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color. 
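        Returns a tensor of the same shape [B, 1, H, W] with values in [0, 1]:
        the simplified sketch, again with dark strokes on a light background.
        H and W must be divisible by 8 (see the class-level NOTE above).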
77 | """ 78 | x = (x - self.mean) / self.std 79 | return self.layers(x) 80 | 81 | def sketch_simplification_gan(pretrained=False): 82 | model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797) 83 | if pretrained: 84 | # model.load_state_dict(torch.load( 85 | # DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_gan.pth'), 86 | # map_location='cpu')) 87 | model.load_state_dict(torch.load( 88 | DOWNLOAD_TO_CACHE('VideoComposer/Hangjie/models/sketch_simplification/sketch_simplification_gan.pth'), 89 | map_location='cpu')) 90 | return model 91 | 92 | def sketch_simplification_mse(pretrained=False): 93 | model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507) 94 | if pretrained: 95 | model.load_state_dict(torch.load( 96 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'), 97 | map_location='cpu')) 98 | return model 99 | 100 | def sketch_to_pencil_v1(pretrained=False): 101 | model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048) 102 | if pretrained: 103 | model.load_state_dict(torch.load( 104 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'), 105 | map_location='cpu')) 106 | return model 107 | 108 | def sketch_to_pencil_v2(pretrained=False): 109 | model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571) 110 | if pretrained: 111 | model.load_state_dict(torch.load( 112 | DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'), 113 | map_location='cpu')) 114 | return model 115 | -------------------------------------------------------------------------------- /tools/annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 6 | 7 | def HWC3(x): 8 | assert x.dtype == np.uint8 9 | if x.ndim == 2: 10 | x = x[:, :, None] 11 | assert x.ndim == 3 12 | H, W, C = x.shape 13 | assert C == 1 or C == 3 or C == 4 14 | if C == 3: 15 | return x 16 | if C == 1: 17 | return np.concatenate([x, x, x], axis=2) 18 | if C == 4: 19 | color = x[:, :, 0:3].astype(np.float32) 20 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 21 | y = color * alpha + 255.0 * (1.0 - alpha) 22 | y = y.clip(0, 255).astype(np.uint8) 23 | return y 24 | 25 | 26 | def resize_image(input_image, resolution): 27 | H, W, C = input_image.shape 28 | H = float(H) 29 | W = float(W) 30 | k = float(resolution) / min(H, W) 31 | H *= k 32 | W *= k 33 | H = int(np.round(H / 64.0)) * 64 34 | W = int(np.round(W / 64.0)) * 64 35 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 36 | return img -------------------------------------------------------------------------------- /tools/basic_funcs/__init__.py: -------------------------------------------------------------------------------- 1 | from .pretrain_functions import * 2 | -------------------------------------------------------------------------------- /tools/basic_funcs/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/basic_funcs/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/basic_funcs/__pycache__/pretrain_functions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/pretrain_functions.cpython-310.pyc -------------------------------------------------------------------------------- /tools/basic_funcs/__pycache__/pretrain_functions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/basic_funcs/__pycache__/pretrain_functions.cpython-38.pyc -------------------------------------------------------------------------------- /tools/basic_funcs/pretrain_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import logging 5 | import collections 6 | 7 | from utils.registry_class import PRETRAIN 8 | 9 | @PRETRAIN.register_function() 10 | def pretrain_specific_strategies( 11 | model, 12 | resume_checkpoint, 13 | sd_keys_path=None, 14 | grad_scale=1, 15 | fix_weight=False, 16 | **kwargs 17 | ): 18 | 19 | state_dict = torch.load(resume_checkpoint, map_location='cpu') 20 | if 'state_dict' in state_dict: 21 | state_dict = state_dict['state_dict'] 22 | 23 | # [1] load model 24 | try: 25 | ret = model.load_state_dict(state_dict, strict=False) 26 | logging.info(f'load a fixed model with {ret}') 27 | except: 28 | model_dict = model.state_dict() 29 | key_list = list(state_dict.keys()) 30 | for skey, item in state_dict.items(): 31 | if skey not in model_dict: 32 | logging.info(f'Skip {skey}') 33 | continue 34 | if item.shape != model_dict[skey].shape: 35 | logging.info(f'Skip {skey} with different shape {item.shape} {model_dict[skey].shape}') 36 | continue 37 | model_dict[skey].copy_(item) 38 | model.load_state_dict(model_dict) 39 | 40 | # [2] assign strategies 41 | total_size = 0 42 | state_dict = {} if sd_keys_path is None else json.load(open(sd_keys_path)) 43 | for k, p in model.named_parameters(): 44 | if k in state_dict: 45 | total_size += p.numel() 46 | if fix_weight: 47 | p.requires_grad=False 48 | else: 49 | p.register_hook(lambda grad: grad_scale * grad) 50 | 51 | resume_step = int(os.path.basename(resume_checkpoint).split('_')[-1].split('.')[0]) 52 | logging.info(f'Successfully load step {resume_step} model from {resume_checkpoint}') 53 | logging.info(f'load a fixed model with {int(total_size / (1024 ** 2))}M parameters') 54 | return model, resume_step 55 | 56 | 57 | 58 | @PRETRAIN.register_function() 59 | def pretrain_from_sd(): 60 | pass 61 | 62 | 63 | @PRETRAIN.register_function() 64 | def pretrain_ema_model(): 65 | pass 66 | -------------------------------------------------------------------------------- /tools/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_dataset import * 2 | from .video_dataset import * 3 | from .video_i2v_dataset import * -------------------------------------------------------------------------------- /tools/datasets/__pycache__/__init__.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/image_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/image_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/image_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/video_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/video_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/video_i2v_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_i2v_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /tools/datasets/__pycache__/video_i2v_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/datasets/__pycache__/video_i2v_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /tools/datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | import logging 6 | import tempfile 7 | import numpy as np 8 | from copy import copy 9 | from PIL import Image 10 | from io import BytesIO 11 | from torch.utils.data import Dataset 12 | from utils.registry_class import DATASETS 13 | 14 | @DATASETS.register_class() 15 | class ImageDataset(Dataset): 16 | def __init__(self, 17 | data_list, 18 | data_dir_list, 19 | max_words=1000, 20 | vit_resolution=[224, 224], 21 | resolution=(384, 256), 22 | max_frames=1, 23 | transforms=None, 24 | 
vit_transforms=None, 25 | **kwargs): 26 | 27 | self.max_frames = max_frames 28 | self.resolution = resolution 29 | self.transforms = transforms 30 | self.vit_resolution = vit_resolution 31 | self.vit_transforms = vit_transforms 32 | 33 | image_list = [] 34 | for item_path, data_dir in zip(data_list, data_dir_list): 35 | lines = open(item_path, 'r').readlines() 36 | lines = [[data_dir, item.strip()] for item in lines] 37 | image_list.extend(lines) 38 | self.image_list = image_list 39 | 40 | def __len__(self): 41 | return len(self.image_list) 42 | 43 | def __getitem__(self, index): 44 | data_dir, file_path = self.image_list[index] 45 | img_key = file_path.split('|||')[0] 46 | try: 47 | ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path) 48 | except Exception as e: 49 | logging.info('{} get frames failed... with error: {}'.format(img_key, e)) 50 | caption = '' 51 | img_key = '' 52 | ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) 53 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) 54 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) 55 | return ref_frame, vit_frame, video_data, caption, img_key 56 | 57 | def _get_image_data(self, data_dir, file_path): 58 | frame_list = [] 59 | img_key, caption = file_path.split('|||') 60 | file_path = os.path.join(data_dir, img_key) 61 | for _ in range(5): 62 | try: 63 | image = Image.open(file_path) 64 | if image.mode != 'RGB': 65 | image = image.convert('RGB') 66 | frame_list.append(image) 67 | break 68 | except Exception as e: 69 | logging.info('{} read video frame failed with error: {}'.format(img_key, e)) 70 | continue 71 | 72 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) 73 | try: 74 | if len(frame_list) > 0: 75 | mid_frame = frame_list[0] 76 | vit_frame = self.vit_transforms(mid_frame) 77 | frame_tensor = self.transforms(frame_list) 78 | video_data[:len(frame_list), ...] 
= frame_tensor 79 | else: 80 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) 81 | except: 82 | vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) 83 | ref_frame = copy(video_data[0]) 84 | 85 | return ref_frame, vit_frame, video_data, caption 86 | 87 | -------------------------------------------------------------------------------- /tools/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import cv2 4 | import torch 5 | import random 6 | import logging 7 | import tempfile 8 | import numpy as np 9 | from functools import partial 10 | from copy import copy 11 | from PIL import Image 12 | from io import BytesIO 13 | from torch.utils.data import Dataset 14 | import torchvision.transforms.functional as TF 15 | import albumentations 16 | import PIL 17 | from PIL import Image, ImageFile 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True 19 | import webdataset as wds 20 | 21 | try: 22 | from utils.registry_class import DATASETS 23 | except Exception as ex: 24 | print("#" * 20) 25 | print("import error, try fixed by appending path") 26 | import sys 27 | sys.path.append("./") 28 | from utils.registry_class import DATASETS 29 | 30 | 31 | 32 | def HWC3(x): 33 | assert x.dtype == np.uint8 34 | if x.ndim == 2: 35 | x = x[:, :, None] 36 | assert x.ndim == 3 37 | H, W, C = x.shape 38 | assert C == 1 or C == 3 or C == 4 39 | if C == 3: 40 | return x 41 | if C == 1: 42 | return np.concatenate([x, x, x], axis=2) 43 | if C == 4: 44 | color = x[:, :, 0:3].astype(np.float32) 45 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 46 | y = color * alpha + 255.0 * (1.0 - alpha) 47 | y = y.clip(0, 255).astype(np.uint8) 48 | return y 49 | 50 | 51 | def resize_image(input_image, resolution): 52 | H, W, C = input_image.shape 53 | H = float(H) 54 | W = float(W) 55 | k = float(resolution) / min(H, W) 56 | H *= k 57 | W *= k 58 | H = int(np.round(H / 64.0)) * 64 59 | W = int(np.round(W / 64.0)) * 64 60 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 61 | return img 62 | 63 | 64 | def my_decoder(key, value): 65 | # solve the issue: https://github.com/webdataset/webdataset/issues/206 66 | 67 | if key.endswith('.jpg'): 68 | # return Image.open(BytesIO(value)) 69 | return np.asarray(Image.open(BytesIO(value)).convert('RGB')) 70 | 71 | return None 72 | 73 | 74 | class filter_fake: 75 | 76 | def __init__(self, punsafety=0.2, aest=4.5): 77 | self.punsafety = punsafety 78 | self.aest = aest 79 | 80 | def __call__(self, src): 81 | for sample in src: 82 | img, prompt, json = sample 83 | # watermark filter 84 | if json['pwatermark'] is not None: 85 | if json['pwatermark'] > 0.3: 86 | continue 87 | 88 | # watermark 89 | if json['punsafe'] is not None: 90 | if json['punsafe'] > self.punsafety: 91 | continue 92 | 93 | # watermark 94 | if json['AESTHETIC_SCORE'] is not None: 95 | if json['AESTHETIC_SCORE'] < self.aest: 96 | continue 97 | 98 | # ratio filter 99 | w, h = json['width'], json['height'] 100 | if max(w / h, h / w) > 3: 101 | continue 102 | 103 | yield img, prompt, json['AESTHETIC_SCORE'], json['key'] 104 | 105 | 106 | class Laion2b_Process(object): 107 | 108 | def __init__(self, 109 | size=None, 110 | degradation=None, 111 | downscale_f=4, 112 | min_crop_f=0.8, 113 | max_crop_f=1., 114 | random_crop=True, 115 | debug: bool = False): 116 | """ 117 | Imagenet Superresolution Dataloader 118 | Performs following ops in order: 119 | 1. 
crops a crop of size s from image either as random or center crop 120 | 2. resizes crop to size with cv2.area_interpolation 121 | 3. degrades resized crop with degradation_fn 122 | 123 | :param size: resizing to size after cropping 124 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light 125 | :param downscale_f: Low Resolution Downsample factor 126 | :param min_crop_f: determines crop size s, 127 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) 128 | :param max_crop_f: "" 129 | :param data_root: 130 | :param random_crop: 131 | """ 132 | # downsacle_f = 0 133 | 134 | assert size 135 | assert (size / downscale_f).is_integer() 136 | self.size = size 137 | self.LR_size = int(size / downscale_f) 138 | self.min_crop_f = min_crop_f 139 | self.max_crop_f = max_crop_f 140 | assert (max_crop_f <= 1.) 141 | self.center_crop = not random_crop 142 | 143 | self.image_rescaler = albumentations.SmallestMaxSize( 144 | max_size=size, interpolation=cv2.INTER_AREA) 145 | 146 | 147 | def __call__(self, samples): 148 | example = {} 149 | image, caption, aesthetics, key = samples 150 | 151 | image = np.array(image).astype(np.uint8) 152 | 153 | min_side_len = min(image.shape[:2]) 154 | crop_side_len = min_side_len * np.random.uniform( 155 | self.min_crop_f, self.max_crop_f, size=None) 156 | crop_side_len = int(crop_side_len) 157 | 158 | if self.center_crop: 159 | self.cropper = albumentations.CenterCrop( 160 | height=crop_side_len, width=crop_side_len) 161 | else: 162 | self.cropper = albumentations.RandomCrop( 163 | height=crop_side_len, width=crop_side_len) 164 | 165 | image = self.cropper(image=image)['image'] 166 | image = self.image_rescaler(image=image)['image'] 167 | 168 | # -1, 1 169 | ref_image = (image / 127.5 - 1.0).astype(np.float32) 170 | ref_image = ref_image.transpose(2, 0, 1) 171 | vit_image = ref_image 172 | video_data = ref_image[np.newaxis, :, :, :] 173 | 174 | 175 | # example['image'] = image 176 | # # depth prior is set to 384 177 | # example['prior'] = resize_image(HWC3(image), 384) 178 | # example['caption'] = caption 179 | # example['aesthetics'] = aesthetics 180 | # example['key'] = key 181 | 182 | return ref_image, vit_image, video_data, caption, key 183 | 184 | 185 | @DATASETS.register_class() 186 | class LAIONImageDataset(): 187 | def __init__(self, 188 | data_list, 189 | data_dir_list, 190 | max_words=1000, 191 | vit_resolution=[224, 224], 192 | resolution=(256, 256), 193 | max_frames=1, 194 | transforms=None, 195 | vit_transforms=None, 196 | **kwargs): 197 | 198 | aest = kwargs.get("aest", 4.0) 199 | punsafety = kwargs.get("punsafety", 0.2) 200 | min_crop_f = kwargs.get("min_crop_f", 1.0) 201 | self.num_samples = kwargs.get("num_samples", 60580*2000) 202 | 203 | assert resolution[0] == resolution[1] 204 | assert len(data_dir_list) == 1 205 | assert len(data_list) == 1 206 | 207 | self.web_dataset = wds.WebDataset(os.path.join(data_dir_list[0], data_list[0]), resampled=True).decode( 208 | my_decoder, 'rgb8').shuffle(1000).to_tuple( 209 | 'jpg', 'txt', 'json').compose( 210 | filter_fake(aest=aest, punsafety=punsafety)).map( 211 | Laion2b_Process( 212 | size=resolution[0], 213 | min_crop_f=min_crop_f) 214 | ) 215 | 216 | def create_dataloader(self, batch_size, world_size, workers): 217 | num_samples = self.num_samples 218 | self.dataset = self.web_dataset.batched(batch_size, partial=False) 219 | round_fn = math.ceil 220 | global_batch_size = batch_size * world_size 221 | num_batches = round_fn(num_samples / global_batch_size) 222 | 
num_workers = max(1, workers) 223 | num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker 224 | num_batches = num_worker_batches * num_workers 225 | num_samples = num_batches * global_batch_size 226 | dataset = self.dataset.with_epoch(num_worker_batches) # each worker is iterating over this 227 | 228 | self.dataloader = wds.WebLoader( 229 | dataset, 230 | batch_size=None, 231 | shuffle=False, 232 | num_workers=workers, 233 | persistent_workers=workers > 0, 234 | ) 235 | 236 | self.dataloader.num_batches = num_batches 237 | self.dataloader.num_samples = num_samples 238 | 239 | print("#"*50) 240 | print(f"dataloder, num_batches:{num_batches}, num_samples:{num_samples}") 241 | print("#"*50) 242 | return self.dataloader 243 | 244 | 245 | 246 | if __name__ == "__main__": 247 | dataset = LAIONImageDataset( 248 | data_list=['{00000..00001}.tar'], 249 | data_dir_list=['/home/gxd/projects/Normal-Depth-Diffusion-Model/tools/download_dataset/laion-2ben-5_aes/'], 250 | max_words=1000, 251 | resolution=(256, 256), 252 | vit_resolution=(224, 224), 253 | max_frames=24, 254 | sample_fps=1, 255 | transforms=None, 256 | vit_transforms=None, 257 | get_first_frame=True, 258 | num_samples=1000, 259 | debug=True) 260 | 261 | batch_size = 20 262 | world_size = 1 263 | workers = 10 264 | 265 | dataloader = dataset.create_dataloader(batch_size, world_size, workers) 266 | 267 | import tqdm 268 | key_list = [] 269 | for data in tqdm.tqdm(dataloader): 270 | pass 271 | print(data[0].shape, data[1].shape, data[2].shape) 272 | key_list.extend(data[4]) 273 | print(len(key_list), len(set(key_list))) 274 | 275 | 276 | -------------------------------------------------------------------------------- /tools/datasets/video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import random 6 | import logging 7 | import tempfile 8 | import numpy as np 9 | from copy import copy 10 | from PIL import Image 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset 13 | from utils.registry_class import DATASETS 14 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter 15 | 16 | def read_camera_matrix_single(json_file): 17 | with open(json_file, 'r', encoding='utf8') as reader: 18 | json_content = json.load(reader) 19 | 20 | cond_camera_matrix = np.eye(4) 21 | cond_camera_matrix[:3, 0] = np.array(json_content['x']) 22 | cond_camera_matrix[:3, 1] = -np.array(json_content['y']) 23 | cond_camera_matrix[:3, 2] = -np.array(json_content['z']) 24 | cond_camera_matrix[:3, 3] = np.array(json_content['origin']) 25 | 26 | 27 | camera_matrix = np.eye(4) 28 | camera_matrix[:3, 0] = np.array(json_content['x']) 29 | camera_matrix[:3, 1] = np.array(json_content['y']) 30 | camera_matrix[:3, 2] = np.array(json_content['z']) 31 | camera_matrix[:3, 3] = np.array(json_content['origin']) 32 | 33 | return camera_matrix, cond_camera_matrix 34 | 35 | @DATASETS.register_class() 36 | class VideoDataset(Dataset): 37 | def __init__(self, 38 | data_list, 39 | data_dir_list, 40 | caption_dir, 41 | max_words=1000, 42 | resolution=(384, 256), 43 | vit_resolution=(224, 224), 44 | max_frames=16, 45 | sample_fps=8, 46 | transforms=None, 47 | vit_transforms=None, 48 | get_first_frame=True, 49 | prepare_lgm=False, 50 | **kwargs): 51 | self.prepare_lgm = prepare_lgm 52 | self.max_words = max_words 53 | self.max_frames = max_frames 54 | self.resolution = resolution 55 | self.vit_resolution = vit_resolution 
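        # NOTE: sample_fps appears to be kept only for interface compatibility with the
        # other dataset configs; _get_video_data below always loads the 24 fixed
        # 'campos_512_v4' renderings, so the value is not used inside this dataset.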
56 | self.sample_fps = sample_fps 57 | self.transforms = transforms 58 | self.vit_transforms = vit_transforms 59 | self.get_first_frame = get_first_frame 60 | 61 | # @NOTE instead we read json 62 | image_list = [] 63 | self.captions = json.load(open(caption_dir)) 64 | for item_path, data_dir in zip(data_list, data_dir_list): 65 | lines = json.load(open(item_path)) 66 | lines = [[data_dir, item] for item in lines] 67 | image_list.extend(lines) 68 | self.image_list = image_list 69 | self.replica = 1000 70 | 71 | if self.prepare_lgm: 72 | from core.options import config_defaults 73 | self.opt = config_defaults['big'] 74 | # default camera intrinsics 75 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 76 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 77 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 78 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 79 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) 80 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) 81 | self.proj_matrix[2, 3] = 1 82 | 83 | def __getitem__(self, index): 84 | index = index % len(self.image_list) 85 | data_dir, file_path = self.image_list[index] 86 | video_key = file_path 87 | caption = self.captions[file_path] + ", 3d asset" 88 | 89 | try: 90 | ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data = self._get_video_data(data_dir, file_path) 91 | if self.prepare_lgm: 92 | results = self.prepare_gs(camera_data.clone(), fullreso_mask_data.clone(), fullreso_video_data.clone()) 93 | results['images_output'] = fullreso_video_data # GT renderings of [512, 512] resolution in the range [0,1] 94 | except Exception as e: 95 | print(e) 96 | return self.__getitem__((index+1)%len(self)) # next available data 97 | 98 | if self.prepare_lgm: 99 | return results, ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key 100 | else: 101 | return ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key 102 | 103 | def prepare_gs(self, camera_data, mask_data, video_data): # mask_data [24,512,512,1] 104 | 105 | results = {} 106 | 107 | mask_data = mask_data.permute(0,3,1,2) 108 | results['masks_output'] = mask_data/255.0 # TODO normalize to [0, 1] 109 | 110 | T = camera_data.shape[0] 111 | camera_data = camera_data.view(T,4,4).contiguous() 112 | 113 | camera_data[:,1] *= -1 114 | camera_data[:,[1, 2]] = camera_data[:,[2, 1]] 115 | cam_dis = np.sqrt(camera_data[0,0,3]**2 + camera_data[0,1,3]**2 + camera_data[0,2,3]**2) 116 | 117 | # normalized camera feats as in paper (transform the first pose to a fixed position) 118 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cam_dis], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(camera_data[0]) 119 | cam_poses = transform.unsqueeze(0) @ camera_data # [V, 4, 4] 120 | 121 | cam_poses_input = cam_poses.clone() 122 | 123 | rays_embeddings = [] 124 | for i in range(T): 125 | rays_o, rays_d = get_rays(cam_poses_input[i], 256, 256, self.opt.fovy) # [h, w, 3] 126 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 127 | rays_embeddings.append(rays_plucker) 128 | 129 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V=24, 6, h, w] 130 | results['input'] = rays_embeddings 131 | 132 | # opengl to colmap camera for gs renderer 133 | cam_poses_input[:,:3,1:3] *= -1 134 | 135 | # cameras needed by gaussian 
rasterizer 136 | cam_view = torch.inverse(cam_poses_input).transpose(1, 2) # [V, 4, 4] 137 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] 138 | cam_pos = - cam_poses_input[:, :3, 3] # [V, 3] 139 | 140 | results['cam_view'] = cam_view 141 | results['cam_view_proj'] = cam_view_proj 142 | results['cam_pos'] = cam_pos 143 | 144 | return results 145 | 146 | def _get_video_data(self, data_dir, file_path): 147 | prefix = os.path.join(data_dir, file_path, 'campos_512_v4') 148 | 149 | frames_path = [os.path.join(prefix, "{:05d}/{:05d}.png".format(frame_idx, frame_idx)) for frame_idx in range(24)] 150 | camera_path = [os.path.join(prefix, "{:05d}/{:05d}.json".format(frame_idx, frame_idx)) for frame_idx in range(24)] 151 | 152 | frame_list = [] 153 | fullreso_frame_list = [] 154 | camera_list = [] 155 | mask_list = [] 156 | fullreso_mask_list = [] 157 | for frame_idx, frame_path in enumerate(frames_path): 158 | img = Image.open(frame_path).convert('RGBA') 159 | mask = torch.from_numpy(np.array(img.resize((self.resolution[1], self.resolution[0])))[:,:,-1]).unsqueeze(-1) 160 | mask_list.append(mask) 161 | fullreso_mask = torch.from_numpy(np.array(img)[:,:,-1]).unsqueeze(-1) 162 | fullreso_mask_list.append(fullreso_mask) 163 | 164 | width = img.width 165 | height = img.height 166 | # grey_scale = random.randint(128, 130) # random gray color 167 | grey_scale = 128 168 | image = Image.new('RGB', size=(width, height), color=(grey_scale,grey_scale,grey_scale)) 169 | image.paste(img,(0,0),mask=img) 170 | 171 | fullreso_frame_list.append(torch.from_numpy(np.array(image)/255.0).float()) # for LGM rendering NOTE notice the data range [0,1] 172 | frame_list.append(image.resize((self.resolution[1], self.resolution[0]))) 173 | 174 | _, camera_embedding = read_camera_matrix_single(camera_path[frame_idx]) 175 | camera_list.append(torch.from_numpy(camera_embedding.flatten().astype(np.float32))) 176 | 177 | camera_data = torch.stack(camera_list, dim=0) # [24,16] 178 | mask_data = torch.stack(mask_list, dim=0) 179 | fullreso_mask_data = torch.stack(fullreso_mask_list, dim=0) 180 | 181 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) 182 | fullreso_video_data = torch.zeros(self.max_frames, 3, 512, 512) 183 | if self.get_first_frame: 184 | ref_idx = 0 185 | else: 186 | ref_idx = int(len(frame_list)/2) 187 | 188 | mid_frame = copy(frame_list[ref_idx]) 189 | vit_frame = self.vit_transforms(mid_frame) 190 | frames = self.transforms(frame_list) 191 | video_data[:len(frame_list), ...] = frames 192 | 193 | fullreso_video_data[:len(fullreso_frame_list), ...] 
= torch.stack(fullreso_frame_list, dim=0).permute(0,3,1,2) 194 | 195 | ref_frame = copy(frames[ref_idx]) 196 | 197 | return ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data 198 | 199 | def __len__(self): 200 | return len(self.image_list)*self.replica -------------------------------------------------------------------------------- /tools/datasets/video_i2v_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import random 6 | import logging 7 | import tempfile 8 | import numpy as np 9 | from copy import copy 10 | from PIL import Image 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset 13 | from utils.registry_class import DATASETS 14 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter 15 | 16 | def read_camera_matrix_single(json_file): 17 | with open(json_file, 'r', encoding='utf8') as reader: 18 | json_content = json.load(reader) 19 | 20 | cond_camera_matrix = np.eye(4) 21 | cond_camera_matrix[:3, 0] = np.array(json_content['x']) 22 | cond_camera_matrix[:3, 1] = -np.array(json_content['y']) 23 | cond_camera_matrix[:3, 2] = -np.array(json_content['z']) 24 | cond_camera_matrix[:3, 3] = np.array(json_content['origin']) 25 | 26 | 27 | camera_matrix = np.eye(4) 28 | camera_matrix[:3, 0] = np.array(json_content['x']) 29 | camera_matrix[:3, 1] = np.array(json_content['y']) 30 | camera_matrix[:3, 2] = np.array(json_content['z']) 31 | camera_matrix[:3, 3] = np.array(json_content['origin']) 32 | 33 | return camera_matrix, cond_camera_matrix 34 | 35 | @DATASETS.register_class() 36 | class Video_I2V_Dataset(Dataset): 37 | def __init__(self, 38 | data_list, 39 | data_dir_list, 40 | caption_dir, 41 | max_words=1000, 42 | resolution=(384, 256), 43 | vit_resolution=(224, 224), 44 | max_frames=16, 45 | sample_fps=8, 46 | transforms=None, 47 | vit_transforms=None, 48 | get_first_frame=True, 49 | prepare_lgm=False, 50 | **kwargs): 51 | 52 | self.prepare_lgm = prepare_lgm 53 | self.max_words = max_words 54 | self.max_frames = max_frames 55 | self.resolution = resolution 56 | self.vit_resolution = vit_resolution 57 | self.sample_fps = sample_fps 58 | self.transforms = transforms 59 | self.vit_transforms = vit_transforms 60 | self.get_first_frame = get_first_frame 61 | 62 | # @NOTE instead we read json 63 | image_list = [] 64 | # self.captions = json.load(open(caption_dir)) 65 | self.captions = None 66 | for item_path, data_dir in zip(data_list, data_dir_list): 67 | lines = json.load(open(item_path)) 68 | lines = [[data_dir, item] for item in lines] 69 | image_list.extend(lines) 70 | self.image_list = image_list 71 | self.replica = 1000 72 | 73 | if self.prepare_lgm: 74 | from core.options import config_defaults 75 | self.opt = config_defaults['big'] 76 | # default camera intrinsics 77 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 78 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 79 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 80 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 81 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) 82 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) 83 | self.proj_matrix[2, 3] = 1 84 | 85 | def __getitem__(self, index): 86 | index = index % len(self.image_list) 87 | data_dir, file_path = self.image_list[index] 88 | video_key = file_path 89 | caption = "" 90 
| 91 | try: 92 | ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data = self._get_video_data(data_dir, file_path) 93 | if self.prepare_lgm: 94 | results = self.prepare_gs(camera_data.clone(), fullreso_mask_data.clone(), fullreso_video_data.clone()) 95 | results['images_output'] = fullreso_video_data # GT renderings of [512, 512] resolution in the range [0,1] 96 | except Exception as e: 97 | print(e) 98 | return self.__getitem__((index+1)%len(self)) # next available data 99 | 100 | if self.prepare_lgm: 101 | return results, ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key 102 | else: 103 | return ref_frame, vit_frame, video_data, camera_data, mask_data, caption, video_key 104 | 105 | def prepare_gs(self, camera_data, mask_data, video_data): # mask_data [24,512,512,1] 106 | 107 | results = {} 108 | 109 | mask_data = mask_data.permute(0,3,1,2) 110 | results['masks_output'] = mask_data/255.0 # TODO normalize to [0, 1] 111 | 112 | T = camera_data.shape[0] 113 | camera_data = camera_data.view(T,4,4).contiguous() 114 | 115 | camera_data[:,1] *= -1 116 | camera_data[:,[1, 2]] = camera_data[:,[2, 1]] 117 | cam_dis = np.sqrt(camera_data[0,0,3]**2 + camera_data[0,1,3]**2 + camera_data[0,2,3]**2) 118 | 119 | # normalized camera feats as in paper (transform the first pose to a fixed position) 120 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cam_dis], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(camera_data[0]) 121 | cam_poses = transform.unsqueeze(0) @ camera_data # [V, 4, 4] 122 | 123 | cam_poses_input = cam_poses.clone() 124 | 125 | rays_embeddings = [] 126 | for i in range(T): 127 | rays_o, rays_d = get_rays(cam_poses_input[i], 256, 256, self.opt.fovy) # [h, w, 3] 128 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 129 | rays_embeddings.append(rays_plucker) 130 | 131 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V=24, 6, h, w] 132 | results['input'] = rays_embeddings 133 | 134 | # opengl to colmap camera for gs renderer 135 | cam_poses_input[:,:3,1:3] *= -1 136 | 137 | # cameras needed by gaussian rasterizer 138 | cam_view = torch.inverse(cam_poses_input).transpose(1, 2) # [V, 4, 4] 139 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] 140 | cam_pos = - cam_poses_input[:, :3, 3] # [V, 3] 141 | 142 | results['cam_view'] = cam_view 143 | results['cam_view_proj'] = cam_view_proj 144 | results['cam_pos'] = cam_pos 145 | 146 | return results 147 | 148 | def _get_video_data(self, data_dir, file_path): 149 | prefix = os.path.join(data_dir, file_path, 'campos_512_v4') 150 | 151 | frames_path = [os.path.join(prefix, "{:05d}/{:05d}.png".format(frame_idx, frame_idx)) for frame_idx in range(24)] 152 | camera_path = [os.path.join(prefix, "{:05d}/{:05d}.json".format(frame_idx, frame_idx)) for frame_idx in range(24)] 153 | 154 | frame_list = [] 155 | fullreso_frame_list = [] 156 | camera_list = [] 157 | mask_list = [] 158 | fullreso_mask_list = [] 159 | for frame_idx, frame_path in enumerate(frames_path): 160 | img = Image.open(frame_path).convert('RGBA') 161 | mask = torch.from_numpy(np.array(img.resize((self.resolution[1], self.resolution[0])))[:,:,-1]).unsqueeze(-1) 162 | mask_list.append(mask) 163 | fullreso_mask = torch.from_numpy(np.array(img)[:,:,-1]).unsqueeze(-1) 164 | fullreso_mask_list.append(fullreso_mask) 165 | 166 | width = img.width 167 | height = img.height 168 | grey_scale = 255 169 | image = 
Image.new('RGB', size=(width, height), color=(grey_scale,grey_scale,grey_scale)) 170 | image.paste(img,(0,0),mask=img) 171 | 172 | fullreso_frame_list.append(torch.from_numpy(np.array(image)/255.0).float()) # for LGM rendering NOTE notice the data range [0,1] 173 | frame_list.append(image.resize((self.resolution[1], self.resolution[0]))) 174 | 175 | _, camera_embedding = read_camera_matrix_single(camera_path[frame_idx]) 176 | camera_list.append(torch.from_numpy(camera_embedding.flatten().astype(np.float32))) 177 | 178 | camera_data = torch.stack(camera_list, dim=0) # [24,16] 179 | mask_data = torch.stack(mask_list, dim=0) 180 | fullreso_mask_data = torch.stack(fullreso_mask_list, dim=0) 181 | 182 | video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) 183 | 184 | fullreso_video_data = torch.zeros(self.max_frames, 3, 512, 512) 185 | 186 | if self.get_first_frame: 187 | ref_idx = 0 188 | else: 189 | ref_idx = int(len(frame_list)/2) 190 | 191 | mid_frame = copy(frame_list[ref_idx]) 192 | vit_frame = self.vit_transforms(mid_frame) 193 | frames = self.transforms(frame_list) 194 | video_data[:len(frame_list), ...] = frames 195 | 196 | if True: # random augmentation 197 | split_idx = np.random.randint(0, len(frame_list)) 198 | video_data = torch.cat([video_data[split_idx:], video_data[:split_idx]], dim=0) 199 | 200 | fullreso_video_data[:len(fullreso_frame_list), ...] = torch.stack(fullreso_frame_list, dim=0).permute(0,3,1,2) 201 | 202 | ref_frame = copy(frames[ref_idx]) 203 | 204 | return ref_frame, vit_frame, video_data, fullreso_video_data, camera_data, mask_data, fullreso_mask_data 205 | 206 | def __len__(self): 207 | return len(self.image_list)*self.replica 208 | 209 | -------------------------------------------------------------------------------- /tools/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .visual_train_it2v_video import * 2 | -------------------------------------------------------------------------------- /tools/hooks/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/hooks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/hooks/__pycache__/visual_train_it2v_video.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/visual_train_it2v_video.cpython-310.pyc -------------------------------------------------------------------------------- /tools/hooks/__pycache__/visual_train_it2v_video.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/hooks/__pycache__/visual_train_it2v_video.cpython-38.pyc -------------------------------------------------------------------------------- /tools/hooks/visual_train_it2v_video.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pynvml 4 | import logging 5 | from einops import rearrange 6 | import torch.cuda.amp as amp 7 | 8 | from utils.video_op import save_video_refimg_and_text 9 | from utils.registry_class import VISUAL 10 | 11 | from PIL import Image 12 | import numpy as np 13 | 14 | 15 | @VISUAL.register_class() 16 | class VisualTrainTextImageToVideo(object): 17 | def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs): 18 | super(VisualTrainTextImageToVideo, self).__init__(**kwargs) 19 | self.cfg = cfg_global 20 | self.viz_num = viz_num 21 | self.diffusion = diffusion 22 | self.autoencoder = autoencoder 23 | self.guide_scale = guide_scale 24 | self.partial_keys_list = partial_keys 25 | self.use_offset_noise = use_offset_noise 26 | 27 | def prepare_model_kwargs(self, partial_keys, full_model_kwargs): 28 | """ 29 | """ 30 | model_kwargs = [{}, {}] 31 | for partial_key in partial_keys: 32 | model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] 33 | model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] 34 | return model_kwargs 35 | 36 | @torch.no_grad() 37 | def run(self, 38 | model, 39 | video_data, 40 | captions, 41 | step=0, 42 | ref_frame=None, 43 | visual_kwards=[], 44 | **kwargs): 45 | 46 | cfg = self.cfg 47 | viz_num = min(self.viz_num, video_data.size(0)) 48 | 49 | # save latent video_data first shape:[B,C,F,H,W] 50 | save_vid_data = video_data.clone().detach() 51 | for idx in range(save_vid_data.shape[0]): 52 | save_vid = save_vid_data[idx].permute(1,0,2,3) 53 | save_vid = torch.cat(save_vid.chunk(24),dim=-1).squeeze(0) 54 | save_vid = torch.cat(save_vid.chunk(4),dim=-2).squeeze(0) 55 | max_value = save_vid.max() 56 | min_value = save_vid.min() 57 | 58 | file_name = f'rank{cfg.rank:02d}_index{idx:02d}.png' 59 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}') 60 | os.makedirs(os.path.join(cfg.log_dir, f'sample_{step:06d}'), exist_ok=True) 61 | save_vid = (save_vid - min_value)/(max_value - min_value) 62 | save_vid = Image.fromarray((save_vid.cpu().numpy()*255).astype(np.uint8)).save(local_path) 63 | 64 | noise = torch.randn_like(video_data[:viz_num]) 65 | if self.use_offset_noise: 66 | noise_strength = getattr(cfg, 'noise_strength', 0) 67 | b, c, f, *_ = video_data[:viz_num].shape 68 | noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device) 69 | 70 | # import ipdb; ipdb.set_trace() 71 | # print memory 72 | pynvml.nvmlInit() 73 | handle=pynvml.nvmlDeviceGetHandleByIndex(0) 74 | meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) 75 | logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB') 76 | 77 | for keys in self.partial_keys_list: 78 | model_kwargs = self.prepare_model_kwargs(keys, visual_kwards) 79 | pre_name = '_'.join(keys) 80 | with amp.autocast(enabled=cfg.use_fp16): 81 | video_data = self.diffusion.ddim_sample_loop( 82 | noise=noise.clone(), 83 | model=model.eval(), 84 | model_kwargs=model_kwargs, 85 | guide_scale=self.guide_scale, 86 | ddim_timesteps=cfg.ddim_timesteps, 87 | eta=0.0) 88 | 89 | # save latent video_data pred shape:[B,C,F,H,W] 90 | save_vid_data_pred = video_data.clone().detach() 91 | for idx in range(save_vid_data_pred.shape[0]): 92 | save_vid = save_vid_data_pred[idx].permute(1,0,2,3) 93 | save_vid = torch.cat(save_vid.chunk(24),dim=-1).squeeze(0) 94 | save_vid = 
torch.cat(save_vid.chunk(4),dim=-2).squeeze(0) 95 | max_value = save_vid.max() 96 | min_value = save_vid.min() 97 | 98 | file_name = f'rank{cfg.rank:02d}_index{idx:02d}_pred.png' 99 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}') 100 | os.makedirs(os.path.join(cfg.log_dir, f'sample_{step:06d}'), exist_ok=True) 101 | save_vid = (save_vid - min_value)/(max_value - min_value) 102 | save_vid = Image.fromarray((save_vid.cpu().numpy()*255).astype(np.uint8)).save(local_path) 103 | 104 | video_data = 1. / cfg.scale_factor * video_data # [64, 4, 32, 48] 105 | video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') 106 | chunk_size = min(cfg.decoder_bs, video_data.shape[0]) 107 | video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size,dim=0) 108 | decode_data = [] 109 | for vd_data in video_data_list: 110 | gen_frames = self.autoencoder.decode(vd_data) 111 | decode_data.append(gen_frames) 112 | video_data = torch.cat(decode_data, dim=0) 113 | video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = viz_num) 114 | 115 | text_size = cfg.resolution[-1] 116 | ref_frame = ref_frame[:viz_num] 117 | file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}' 118 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}') 119 | os.makedirs(os.path.dirname(local_path), exist_ok=True) 120 | try: 121 | save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size) 122 | except Exception as e: 123 | logging.info(f'Step: {step} save text or video error with {e}') -------------------------------------------------------------------------------- /tools/hooks/visual_train_t2v.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pynvml 4 | import logging 5 | from einops import rearrange 6 | import torch.cuda.amp as amp 7 | 8 | from utils.video_op import save_video_refimg_and_text 9 | from utils.registry_class import VISUAL 10 | 11 | from PIL import Image 12 | import numpy as np 13 | 14 | @VISUAL.register_class() 15 | class VisualTrainTextToVideo(object): 16 | def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs): 17 | super(VisualTrainTextToVideo, self).__init__(**kwargs) 18 | self.cfg = cfg_global 19 | self.viz_num = viz_num 20 | self.diffusion = diffusion 21 | self.autoencoder = autoencoder 22 | self.guide_scale = guide_scale 23 | self.partial_keys_list = partial_keys 24 | self.use_offset_noise = use_offset_noise 25 | 26 | def prepare_model_kwargs(self, partial_keys, full_model_kwargs): 27 | """ 28 | """ 29 | model_kwargs = [{}, {}] 30 | for partial_key in partial_keys: 31 | model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] 32 | model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] 33 | return model_kwargs 34 | 35 | @torch.no_grad() 36 | def run(self, 37 | model, 38 | video_data, 39 | captions, 40 | step=0, 41 | ref_frame=None, 42 | visual_kwards=[], 43 | **kwargs): 44 | cfg = self.cfg 45 | viz_num = self.viz_num 46 | 47 | 48 | 49 | noise = torch.randn_like(video_data[:viz_num]) # viz_num: 8 50 | if self.use_offset_noise: 51 | noise_strength = getattr(cfg, 'noise_strength', 0) 52 | b, c, f, *_ = video_data[:viz_num].shape 53 | noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device) 54 | 55 | # print memory 56 | pynvml.nvmlInit() 57 | 
handle=pynvml.nvmlDeviceGetHandleByIndex(0) 58 | meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) 59 | logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB') 60 | 61 | for keys in self.partial_keys_list: 62 | model_kwargs = self.prepare_model_kwargs(keys, visual_kwards) 63 | pre_name = '_'.join(keys) 64 | with amp.autocast(enabled=cfg.use_fp16): 65 | video_data = self.diffusion.ddim_sample_loop( 66 | noise=noise.clone(), 67 | model=model.eval(), 68 | model_kwargs=model_kwargs, 69 | guide_scale=self.guide_scale, 70 | ddim_timesteps=cfg.ddim_timesteps, 71 | eta=0.0) 72 | 73 | video_data = 1. / cfg.scale_factor * video_data # [64, 4, 32, 48] 74 | video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') 75 | chunk_size = min(cfg.decoder_bs, video_data.shape[0]) 76 | video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size,dim=0) 77 | decode_data = [] 78 | for vd_data in video_data_list: 79 | gen_frames = self.autoencoder.decode(vd_data) 80 | decode_data.append(gen_frames) 81 | video_data = torch.cat(decode_data, dim=0) 82 | video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = viz_num) 83 | 84 | text_size = cfg.resolution[-1] 85 | ref_frame = ref_frame[:viz_num] 86 | file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}' 87 | local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}') 88 | os.makedirs(os.path.dirname(local_path), exist_ok=True) 89 | try: 90 | save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size) 91 | except Exception as e: 92 | logging.info(f'Step: {step} save text or video error with {e}') 93 | 94 | 95 | -------------------------------------------------------------------------------- /tools/inferences/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_i2vgen_entrance import * 2 | from .inference_text2video_entrance import * 3 | 4 | -------------------------------------------------------------------------------- /tools/inferences/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/inferences/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-310.pyc -------------------------------------------------------------------------------- /tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_i2vgen_entrance.cpython-38.pyc 
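The visualization hooks above (VisualTrainTextImageToVideo and VisualTrainTextToVideo) are never instantiated directly: they are registered with the VISUAL registry via a decorator and later built from a config dict whose 'type' field names the class, the same pattern train_net.py uses with ENGINE further below. The following minimal, self-contained sketch illustrates that register-and-build pattern; SimpleRegistry, VISUAL_DEMO and DummyVisualHook are illustrative names invented for this sketch and are not the actual implementation in utils/registry.py.

class SimpleRegistry:
    # Toy stand-in for the project's registry classes; it only illustrates the decorator/build flow.
    def __init__(self, name):
        self.name = name
        self._classes = {}

    def register_class(self):
        # Used as a decorator factory, mirroring the @VISUAL.register_class() calls above.
        def decorator(cls):
            self._classes[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg, **extra_kwargs):
        # 'type' selects the registered class; every other key becomes a constructor kwarg.
        cfg = dict(cfg)
        cls = self._classes[cfg.pop('type')]
        return cls(**cfg, **extra_kwargs)

VISUAL_DEMO = SimpleRegistry('VISUAL_DEMO')

@VISUAL_DEMO.register_class()
class DummyVisualHook:
    def __init__(self, viz_num, guide_scale=9.0, **kwargs):
        self.viz_num = viz_num
        self.guide_scale = guide_scale

hook = VISUAL_DEMO.build({'type': 'DummyVisualHook', 'viz_num': 4})
print(hook.viz_num, hook.guide_scale)  # -> 4 9.0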
-------------------------------------------------------------------------------- /tools/inferences/__pycache__/inference_text2video_entrance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_text2video_entrance.cpython-310.pyc -------------------------------------------------------------------------------- /tools/inferences/__pycache__/inference_text2video_entrance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/inferences/__pycache__/inference_text2video_entrance.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_embedder import FrozenOpenCLIPEmbedder 2 | from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL 3 | from .clip_embedder import * 4 | from .autoencoder import * 5 | from .unet import * 6 | from .diffusions import * -------------------------------------------------------------------------------- /tools/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/clip_embedder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/clip_embedder.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/clip_embedder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/clip_embedder.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/config.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/clip_embedder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import open_clip 5 | import numpy as np 6 | import torch.nn as nn 7 | import torchvision.transforms as T 8 | 9 | from utils.registry_class import EMBEDDER 10 | 11 | 12 | @EMBEDDER.register_class() 13 | class FrozenOpenCLIPEmbedder(nn.Module): 14 | """ 15 | Uses the OpenCLIP transformer encoder for text 16 | """ 17 | LAYERS = [ 18 | #"pooled", 19 | "last", 20 | "penultimate" 21 | ] 22 | def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, 23 | freeze=True, layer="last"): 24 | super().__init__() 25 | assert layer in self.LAYERS 26 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) 27 | del model.visual 28 | self.model = model 29 | 30 | self.device = device 31 | self.max_length = max_length 32 | if freeze: 33 | self.freeze() 34 | self.layer = layer 35 | if self.layer == "last": 36 | self.layer_idx = 0 37 | elif self.layer == "penultimate": 38 | self.layer_idx = 1 39 | else: 40 | raise NotImplementedError() 41 | 42 | def freeze(self): 43 | self.model = self.model.eval() 44 | for param in self.parameters(): 45 | param.requires_grad = False 46 | 47 | def forward(self, text): 48 | tokens = open_clip.tokenize(text) 49 | z = self.encode_with_transformer(tokens.to(self.device)) 50 | return z 51 | 52 | def encode_with_transformer(self, text): 53 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 54 | x = x + self.model.positional_embedding 55 | x = x.permute(1, 0, 2) # NLD -> LND 56 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 57 | x = x.permute(1, 0, 2) # LND -> NLD 58 | x = self.model.ln_final(x) 59 | return x 60 | 61 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 62 | for i, r in enumerate(self.model.transformer.resblocks): 63 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 64 | break 65 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 66 | x = checkpoint(r, x, attn_mask) 67 | else: 68 | x = r(x, attn_mask=attn_mask) 69 | return x 70 | 71 | def encode(self, text): 72 | return self(text) 73 | 74 | 75 | @EMBEDDER.register_class() 76 | class FrozenOpenCLIPVisualEmbedder(nn.Module): 77 | """ 78 | Uses the OpenCLIP transformer encoder for text 79 | """ 80 | LAYERS = [ 81 | #"pooled", 82 | "last", 83 | "penultimate" 84 | ] 85 | def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77, 86 | freeze=True, layer="last"): 87 | super().__init__() 88 | assert layer in self.LAYERS 89 | model, _, preprocess = open_clip.create_model_and_transforms( 90 | arch, device=torch.device('cpu'), pretrained=pretrained) 91 | # Normalize(mean=(0.48145466, 0.4578275, 
0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 92 | del model.transformer 93 | self.model = model 94 | data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255 95 | self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) 96 | 97 | self.device = device 98 | self.max_length = max_length # 77 99 | if freeze: 100 | self.freeze() 101 | self.layer = layer # 'penultimate' 102 | if self.layer == "last": 103 | self.layer_idx = 0 104 | elif self.layer == "penultimate": 105 | self.layer_idx = 1 106 | else: 107 | raise NotImplementedError() 108 | 109 | def freeze(self): # model.encode_image(torch.randn(2,3,224,224)) 110 | self.model = self.model.eval() 111 | for param in self.parameters(): 112 | param.requires_grad = False 113 | 114 | def forward(self, image): 115 | # tokens = open_clip.tokenize(text) 116 | z = self.model.encode_image(image.to(self.device)) 117 | return z 118 | 119 | def encode_with_transformer(self, text): 120 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 121 | x = x + self.model.positional_embedding 122 | x = x.permute(1, 0, 2) # NLD -> LND 123 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 124 | x = x.permute(1, 0, 2) # LND -> NLD 125 | x = self.model.ln_final(x) 126 | 127 | return x 128 | 129 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 130 | for i, r in enumerate(self.model.transformer.resblocks): 131 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 132 | break 133 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 134 | x = checkpoint(r, x, attn_mask) 135 | else: 136 | x = r(x, attn_mask=attn_mask) 137 | return x 138 | 139 | def encode(self, text): 140 | return self(text) 141 | 142 | 143 | 144 | @EMBEDDER.register_class() 145 | class FrozenOpenCLIPTtxtVisualEmbedder(nn.Module): 146 | """ 147 | Uses the OpenCLIP transformer encoder for text 148 | """ 149 | LAYERS = [ 150 | #"pooled", 151 | "last", 152 | "penultimate" 153 | ] 154 | def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, 155 | freeze=True, layer="last", **kwargs): 156 | super().__init__() 157 | assert layer in self.LAYERS 158 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) 159 | self.model = model 160 | 161 | self.device = device 162 | self.max_length = max_length 163 | if freeze: 164 | self.freeze() 165 | self.layer = layer 166 | if self.layer == "last": 167 | self.layer_idx = 0 168 | elif self.layer == "penultimate": 169 | self.layer_idx = 1 170 | else: 171 | raise NotImplementedError() 172 | 173 | def freeze(self): 174 | self.model = self.model.eval() 175 | for param in self.parameters(): 176 | param.requires_grad = False 177 | 178 | # def forward(self, text): 179 | # tokens = open_clip.tokenize(text) 180 | # z = self.encode_with_transformer(tokens.to(self.device)) 181 | # return z 182 | 183 | def forward(self, image=None, text=None): 184 | # xi = self.encode_image(image) if image is not None else None 185 | xi = self.model.encode_image(image.to(self.device)) if image is not None else None 186 | # tokens = open_clip.tokenize(text, truncate=True) 187 | tokens = open_clip.tokenize(text) 188 | xt, x = self.encode_with_transformer(tokens.to(self.device)) 189 | return xi, xt, x 190 | 191 | def encode_with_transformer(self, text): 192 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 193 | x = x + self.model.positional_embedding 194 | x = 
x.permute(1, 0, 2) # NLD -> LND 195 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 196 | x = x.permute(1, 0, 2) # LND -> NLD 197 | x = self.model.ln_final(x) 198 | xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection 199 | return xt, x 200 | 201 | # def encode_with_transformer(self, text): 202 | # x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 203 | # x = x + self.model.positional_embedding 204 | # x = x.permute(1, 0, 2) # NLD -> LND 205 | # x = self.model.transformer(x) 206 | # x = x.permute(1, 0, 2) # LND -> NLD 207 | # x = self.model.ln_final(x) 208 | # xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection 209 | # # text embedding, token embedding 210 | # return xt, x 211 | 212 | def encode_image(self, image): 213 | return self.model.visual(image) 214 | 215 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 216 | for i, r in enumerate(self.model.transformer.resblocks): 217 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 218 | break 219 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 220 | x = checkpoint(r, x, attn_mask) 221 | else: 222 | x = r(x, attn_mask=attn_mask) 223 | return x 224 | 225 | def encode(self, text): 226 | 227 | return self(text) 228 | 229 | -------------------------------------------------------------------------------- /tools/modules/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os.path as osp 4 | from datetime import datetime 5 | from easydict import EasyDict 6 | import os 7 | 8 | cfg = EasyDict(__name__='Config: VideoLDM Decoder') 9 | 10 | # -------------------------------distributed training-------------------------- 11 | pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) 12 | gpus_per_machine = torch.cuda.device_count() 13 | world_size = pmi_world_size * gpus_per_machine 14 | # ----------------------------------------------------------------------------- 15 | 16 | 17 | # ---------------------------Dataset Parameter--------------------------------- 18 | cfg.mean = [0.5, 0.5, 0.5] 19 | cfg.std = [0.5, 0.5, 0.5] 20 | cfg.max_words = 1000 21 | cfg.num_workers = 8 22 | cfg.prefetch_factor = 2 23 | 24 | # PlaceHolder 25 | cfg.resolution = [448, 256] 26 | cfg.vit_out_dim = 1024 27 | cfg.vit_resolution = 336 28 | cfg.depth_clamp = 10.0 29 | cfg.misc_size = 384 30 | cfg.depth_std = 20.0 31 | 32 | cfg.frame_lens = [32, 32, 32, 1] 33 | cfg.sample_fps = [4, ] 34 | cfg.vid_dataset = { 35 | 'type': 'VideoBaseDataset', 36 | 'data_list': [], 37 | 'max_words': cfg.max_words, 38 | 'resolution': cfg.resolution} 39 | cfg.img_dataset = { 40 | 'type': 'ImageBaseDataset', 41 | 'data_list': ['laion_400m',], 42 | 'max_words': cfg.max_words, 43 | 'resolution': cfg.resolution} 44 | 45 | cfg.batch_sizes = { 46 | str(1):256, 47 | str(4):4, 48 | str(8):4, 49 | str(16):4} 50 | # ----------------------------------------------------------------------------- 51 | 52 | 53 | # ---------------------------Mode Parameters----------------------------------- 54 | # Diffusion 55 | cfg.Diffusion = { 56 | 'type': 'DiffusionDDIM', 57 | 'schedule': 'cosine', # cosine 58 | 'schedule_param': { 59 | 'num_timesteps': 1000, 60 | 'cosine_s': 0.008, 61 | 'zero_terminal_snr': True, 62 | }, 63 | 'mean_type': 'v', # [v, eps] 64 | 'loss_type': 'mse', 65 | 'var_type': 'fixed_small', 66 | 'rescale_timesteps': False, 67 | 'noise_strength': 0.1, 68 | 
'ddim_timesteps': 50 69 | } 70 | cfg.ddim_timesteps = 50 # official: 250 71 | cfg.use_div_loss = False 72 | # classifier-free guidance 73 | cfg.p_zero = 0.9 74 | cfg.guide_scale = 3.0 75 | 76 | # clip vision encoder 77 | cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] 78 | cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] 79 | 80 | # Model 81 | cfg.scale_factor = 0.18215 82 | cfg.use_checkpoint = True 83 | cfg.use_sharded_ddp = False 84 | cfg.use_fsdp = False 85 | cfg.use_fp16 = True 86 | cfg.temporal_attention = True 87 | 88 | cfg.UNet = { 89 | 'type': 'UNetSD', 90 | 'in_dim': 4, 91 | 'dim': 320, 92 | 'y_dim': cfg.vit_out_dim, 93 | 'context_dim': 1024, 94 | 'out_dim': 8, 95 | 'dim_mult': [1, 2, 4, 4], 96 | 'num_heads': 8, 97 | 'head_dim': 64, 98 | 'num_res_blocks': 2, 99 | 'attn_scales': [1 / 1, 1 / 2, 1 / 4], 100 | 'dropout': 0.1, 101 | 'temporal_attention': cfg.temporal_attention, 102 | 'temporal_attn_times': 1, 103 | 'use_checkpoint': cfg.use_checkpoint, 104 | 'use_fps_condition': False, 105 | 'use_sim_mask': False 106 | } 107 | 108 | # auotoencoder from stabel diffusion 109 | cfg.guidances = [] 110 | cfg.auto_encoder = { 111 | 'type': 'AutoencoderKL', 112 | 'ddconfig': { 113 | 'double_z': True, 114 | 'z_channels': 4, 115 | 'resolution': 256, 116 | 'in_channels': 3, 117 | 'out_ch': 3, 118 | 'ch': 128, 119 | 'ch_mult': [1, 2, 4, 4], 120 | 'num_res_blocks': 2, 121 | 'attn_resolutions': [], 122 | 'dropout': 0.0, 123 | 'video_kernel_size': [3, 1, 1] 124 | }, 125 | 'embed_dim': 4, 126 | 'pretrained': './pretrained_models/modelscope_t2v/VQGAN_autoencoder.pth' 127 | } 128 | # clip embedder 129 | cfg.embedder = { 130 | 'type': 'FrozenOpenCLIPEmbedder', 131 | 'layer': 'penultimate', 132 | 'pretrained': 'modelscope_t2v/open_clip_pytorch_model.bin' 133 | } 134 | # ----------------------------------------------------------------------------- 135 | 136 | # ---------------------------Training Settings--------------------------------- 137 | # training and optimizer 138 | cfg.ema_decay = 0.9999 139 | cfg.num_steps = 600000 140 | cfg.lr = 5e-5 141 | cfg.weight_decay = 0.0 142 | cfg.betas = (0.9, 0.999) 143 | cfg.eps = 1.0e-8 144 | cfg.chunk_size = 16 145 | cfg.decoder_bs = 8 146 | cfg.alpha = 0.7 147 | cfg.save_ckp_interval = 1000 148 | 149 | # scheduler 150 | cfg.warmup_steps = 10 151 | cfg.decay_mode = 'cosine' 152 | 153 | # acceleration 154 | cfg.use_ema = True 155 | if world_size<2: 156 | cfg.use_ema = False 157 | cfg.load_from = None 158 | # ----------------------------------------------------------------------------- 159 | 160 | 161 | # ----------------------------Pretrain Settings--------------------------------- 162 | cfg.Pretrain = { 163 | 'type': 'pretrain_specific_strategies', 164 | 'fix_weight': False, 165 | 'grad_scale': 0.2, 166 | 'resume_checkpoint': 'models/jiuniu_0267000.pth', 167 | 'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json', 168 | } 169 | # ----------------------------------------------------------------------------- 170 | 171 | 172 | # -----------------------------Visual------------------------------------------- 173 | # Visual videos 174 | cfg.viz_interval = 1000 175 | cfg.visual_train = { 176 | 'type': 'VisualTrainTextImageToVideo', 177 | } 178 | cfg.visual_inference = { 179 | 'type': 'VisualGeneratedVideos', 180 | } 181 | cfg.inference_list_path = '' 182 | 183 | # logging 184 | cfg.log_interval = 100 185 | 186 | ### Default log_dir 187 | cfg.log_dir = 'workspace/temp_dir' 188 | # 
----------------------------------------------------------------------------- 189 | 190 | 191 | # ---------------------------Others-------------------------------------------- 192 | # seed 193 | cfg.seed = 8888 194 | # motionless static 195 | cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, disfigured, disconnected limbs, Ugly faces, incomplete arms' 196 | # ----------------------------------------------------------------------------- 197 | 198 | -------------------------------------------------------------------------------- /tools/modules/diffusions/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_ddim import * 2 | -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/__pycache__/schedules.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/diffusions/__pycache__/schedules.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/diffusions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] 5 | 6 | def kl_divergence(mu1, logvar1, mu2, logvar2): 7 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2)) 8 | 9 | def standard_normal_cdf(x): 10 | r"""A fast approximation of the cumulative distribution function of the standard normal. 11 | """ 12 | return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | 14 | def discretized_gaussian_log_likelihood(x0, mean, log_scale): 15 | assert x0.shape == mean.shape == log_scale.shape 16 | cx = x0 - mean 17 | inv_stdv = torch.exp(-log_scale) 18 | cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) 19 | cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) 20 | log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) 21 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) 22 | cdf_delta = cdf_plus - cdf_min 23 | log_probs = torch.where( 24 | x0 < -0.999, 25 | log_cdf_plus, 26 | torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)))) 27 | assert log_probs.shape == x0.shape 28 | return log_probs 29 | -------------------------------------------------------------------------------- /tools/modules/diffusions/schedules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def beta_schedule(schedule='cosine', 6 | num_timesteps=1000, 7 | zero_terminal_snr=False, 8 | **kwargs): 9 | # compute betas 10 | betas = { 11 | 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, 12 | 'linear': linear_schedule, 13 | 'linear_sd': linear_sd_schedule, 14 | 'quadratic': quadratic_schedule, 15 | 'cosine': cosine_schedule 16 | }[schedule](num_timesteps, **kwargs) 17 | 18 | if zero_terminal_snr and betas.max() != 1.0: 19 | betas = rescale_zero_terminal_snr(betas) 20 | 21 | return betas 22 | 23 | 24 | def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs): 25 | scale = 1000.0 / num_timesteps 26 | init_beta = init_beta or scale * 0.0001 27 | last_beta = last_beta or scale * 0.02 28 | return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64) 29 | 30 | def logsnr_cosine_interp_schedule( 31 | num_timesteps, 32 | scale_min=2, 33 | scale_max=4, 34 | logsnr_min=-15, 35 | logsnr_max=15, 36 | **kwargs): 37 | return logsnrs_to_sigmas( 38 | _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max)) 39 | 40 | def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs): 41 | return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 42 | 43 | 44 | def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs): 45 | init_beta = init_beta or 0.0015 46 | last_beta = last_beta or 0.0195 47 | return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 48 | 49 | 50 | def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs): 51 | betas = [] 52 | for step in
range(num_timesteps): 53 | t1 = step / num_timesteps 54 | t2 = (step + 1) / num_timesteps 55 | fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2 56 | betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) 57 | return torch.tensor(betas, dtype=torch.float64) 58 | 59 | 60 | # def cosine_schedule(n, cosine_s=0.008, **kwargs): 61 | # ramp = torch.linspace(0, 1, n + 1) 62 | # square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2 63 | # betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999) 64 | # return betas_to_sigmas(betas) 65 | 66 | 67 | def betas_to_sigmas(betas): 68 | return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) 69 | 70 | 71 | def sigmas_to_betas(sigmas): 72 | square_alphas = 1 - sigmas**2 73 | betas = 1 - torch.cat( 74 | [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) 75 | return betas 76 | 77 | 78 | 79 | def sigmas_to_logsnrs(sigmas): 80 | square_sigmas = sigmas**2 81 | return torch.log(square_sigmas / (1 - square_sigmas)) 82 | 83 | 84 | def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): 85 | t_min = math.atan(math.exp(-0.5 * logsnr_min)) 86 | t_max = math.atan(math.exp(-0.5 * logsnr_max)) 87 | t = torch.linspace(1, 0, n) 88 | logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) 89 | return logsnrs 90 | 91 | 92 | def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): 93 | logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) 94 | logsnrs += 2 * math.log(1 / scale) 95 | return logsnrs 96 | 97 | def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): 98 | ramp = torch.linspace(1, 0, n) 99 | min_inv_rho = sigma_min**(1 / rho) 100 | max_inv_rho = sigma_max**(1 / rho) 101 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho 102 | sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) 103 | return sigmas 104 | 105 | def _logsnr_cosine_interp(n, 106 | logsnr_min=-15, 107 | logsnr_max=15, 108 | scale_min=2, 109 | scale_max=4): 110 | t = torch.linspace(1, 0, n) 111 | logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) 112 | logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) 113 | logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max 114 | return logsnrs 115 | 116 | 117 | def logsnrs_to_sigmas(logsnrs): 118 | return torch.sqrt(torch.sigmoid(-logsnrs)) 119 | 120 | 121 | def rescale_zero_terminal_snr(betas): 122 | """ 123 | Rescale Schedule to Zero Terminal SNR 124 | """ 125 | # Convert betas to alphas_bar_sqrt 126 | alphas = 1 - betas 127 | alphas_bar = alphas.cumprod(0) 128 | alphas_bar_sqrt = alphas_bar.sqrt() 129 | 130 | # Store old values. 131 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 132 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 133 | # Shift so last timestep is zero. 134 | alphas_bar_sqrt -= alphas_bar_sqrt_T 135 | # Scale so first timestep is back to old value.
136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 137 | 138 | # Convert alphas_bar_sqrt to betas 139 | alphas_bar = alphas_bar_sqrt ** 2 140 | alphas = alphas_bar[1:] / alphas_bar[:-1] 141 | alphas = torch.cat([alphas_bar[0:1], alphas]) 142 | betas = 1 - alphas 143 | return betas 144 | 145 | -------------------------------------------------------------------------------- /tools/modules/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_i2vgen import * 2 | from .unet_t2v import * -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_attn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_attn.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_attn.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_net.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_net.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_net.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_utils.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/depthwise_utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/depthwise_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/unet_i2vgen.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_i2vgen.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/unet_i2vgen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_i2vgen.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/unet_t2v.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_t2v.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/unet_t2v.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/unet_t2v.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /tools/modules/unet/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/modules/unet/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /tools/modules/unet/mha_flash.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.cuda.amp as amp 4 | import torch.nn.functional as F 5 | import math 6 | import os 7 | import time 8 | import numpy as np 9 | import random 10 | 11 | # from flash_attn.flash_attention import FlashAttention 12 | 13 | class FlashAttentionBlock(nn.Module): 14 | 15 | def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4): 16 | # consider head_dim first, then num_heads 17 | num_heads = dim // head_dim if head_dim else num_heads 18 | head_dim = dim // num_heads 19 | assert num_heads * head_dim == dim 20 | super(FlashAttentionBlock, self).__init__() 21 | self.dim = dim 22 | self.context_dim = context_dim 23 | self.num_heads = num_heads 24 | self.head_dim = head_dim 25 | self.scale = math.pow(head_dim, -0.25) 26 | 27 | # layers 28 | self.norm = nn.GroupNorm(32, dim) 29 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1) 30 | if context_dim is not None: 31 | self.context_kv = nn.Linear(context_dim, dim * 2) 32 | self.proj = 
nn.Conv2d(dim, dim, 1) 33 | 34 | if self.head_dim <= 128 and (self.head_dim % 8) == 0: 35 | new_scale = math.pow(head_dim, -0.5) 36 | self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0) 37 | 38 | # zero out the last layer params 39 | nn.init.zeros_(self.proj.weight) 40 | # self.apply(self._init_weight) 41 | 42 | 43 | def _init_weight(self, module): 44 | if isinstance(module, nn.Linear): 45 | module.weight.data.normal_(mean=0.0, std=0.15) 46 | if module.bias is not None: 47 | module.bias.data.zero_() 48 | elif isinstance(module, nn.Conv2d): 49 | module.weight.data.normal_(mean=0.0, std=0.15) 50 | if module.bias is not None: 51 | module.bias.data.zero_() 52 | 53 | def forward(self, x, context=None): 54 | r"""x: [B, C, H, W]. 55 | context: [B, L, C] or None. 56 | """ 57 | identity = x 58 | b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim 59 | 60 | # compute query, key, value 61 | x = self.norm(x) 62 | q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) 63 | if context is not None: 64 | ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) 65 | k = torch.cat([ck, k], dim=-1) 66 | v = torch.cat([cv, v], dim=-1) 67 | cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) 68 | q = torch.cat([q, cq], dim=-1) 69 | 70 | qkv = torch.cat([q,k,v], dim=1) 71 | origin_dtype = qkv.dtype 72 | qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() 73 | out, _ = self.flash_attn(qkv) 74 | out.to(origin_dtype) 75 | 76 | if context is not None: 77 | out = out[:, :-4, :, :] 78 | out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) 79 | 80 | # output 81 | x = self.proj(out) 82 | return x + identity 83 | 84 | if __name__ == '__main__': 85 | batch_size = 8 86 | flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() 87 | 88 | x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() 89 | context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() 90 | # context = None 91 | flash_net.eval() 92 | 93 | with amp.autocast(enabled=True): 94 | # warm up 95 | for i in range(5): 96 | y = flash_net(x, context) 97 | torch.cuda.synchronize() 98 | s1 = time.time() 99 | for i in range(10): 100 | y = flash_net(x, context) 101 | torch.cuda.synchronize() 102 | s2 = time.time() 103 | 104 | print(f'Average cost time {(s2-s1)*1000/10} ms') -------------------------------------------------------------------------------- /tools/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_t2v_enterance import * 2 | from .train_i2v_enterance import * -------------------------------------------------------------------------------- /tools/train/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/train/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/train/__pycache__/train_i2v_enterance.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_i2v_enterance.cpython-310.pyc -------------------------------------------------------------------------------- /tools/train/__pycache__/train_i2v_enterance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_i2v_enterance.cpython-38.pyc -------------------------------------------------------------------------------- /tools/train/__pycache__/train_t2v_enterance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_t2v_enterance.cpython-310.pyc -------------------------------------------------------------------------------- /tools/train/__pycache__/train_t2v_enterance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/tools/train/__pycache__/train_t2v_enterance.cpython-38.pyc -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import json 5 | import math 6 | import random 7 | import logging 8 | import itertools 9 | import numpy as np 10 | 11 | from utils.config import Config 12 | from utils.registry_class import ENGINE 13 | 14 | from tools import * 15 | 16 | if __name__ == '__main__': 17 | cfg_update = Config(load=True) 18 | ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict) 19 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/assign_cfg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/assign_cfg.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/assign_cfg.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/assign_cfg.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/camera_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/camera_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/camera_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/camera_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/distributed.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/distributed.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/logging.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/multi_port.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/multi_port.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/multi_port.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/multi_port.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/registry.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/registry_class.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry_class.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/registry_class.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/registry_class.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/seed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/seed.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/seed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/seed.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/util.cpython-38.pyc 
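For orientation, train_net.py above is the single launch point: utils/config.Config parses the --cfg yaml (plus optional trailing key/value overrides) and ENGINE.build dispatches to whichever registered entrance the config's TASK_TYPE names. The sketch below restates that flow as a standalone script under two assumptions: the repository root is on PYTHONPATH with its requirements installed, and the yaml path shown is simply the argument parser's default, standing in for whichever config you actually intend to run.

# Equivalent in spirit to: python train_net.py --cfg configs/i2vgen_xl_infer.yaml
import sys

from utils.config import Config
from utils.registry_class import ENGINE
from tools import *  # noqa: F401,F403 -- importing tools registers the task entrances

sys.argv = ['train_net.py', '--cfg', 'configs/i2vgen_xl_infer.yaml']

cfg_update = Config(load=True)  # merges configs/base.yaml, the chosen yaml, and any CLI overrides
ENGINE.build(dict(type=cfg_update.TASK_TYPE),  # TASK_TYPE in the yaml selects the registered entrance
             cfg_update=cfg_update.cfg_dict)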
-------------------------------------------------------------------------------- /utils/__pycache__/video_op.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/video_op.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/video_op.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/__pycache__/video_op.cpython-38.pyc -------------------------------------------------------------------------------- /utils/assign_cfg.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from copy import deepcopy, copy 3 | 4 | 5 | # def get prior and ldm config 6 | def assign_prior_mudule_cfg(cfg): 7 | ''' 8 | ''' 9 | # 10 | prior_cfg = deepcopy(cfg) 11 | vldm_cfg = deepcopy(cfg) 12 | 13 | with open(cfg.prior_cfg, 'r') as f: 14 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) 15 | # _cfg_update = _cfg_update.cfg_dict 16 | for k, v in _cfg_update.items(): 17 | if isinstance(v, dict) and k in cfg: 18 | prior_cfg[k].update(v) 19 | else: 20 | prior_cfg[k] = v 21 | 22 | with open(cfg.vldm_cfg, 'r') as f: 23 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) 24 | # _cfg_update = _cfg_update.cfg_dict 25 | for k, v in _cfg_update.items(): 26 | if isinstance(v, dict) and k in cfg: 27 | vldm_cfg[k].update(v) 28 | else: 29 | vldm_cfg[k] = v 30 | 31 | return prior_cfg, vldm_cfg 32 | 33 | 34 | # def get prior and ldm config 35 | def assign_vldm_vsr_mudule_cfg(cfg): 36 | ''' 37 | ''' 38 | # 39 | vldm_cfg = deepcopy(cfg) 40 | vsr_cfg = deepcopy(cfg) 41 | 42 | with open(cfg.vldm_cfg, 'r') as f: 43 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) 44 | # _cfg_update = _cfg_update.cfg_dict 45 | for k, v in _cfg_update.items(): 46 | if isinstance(v, dict) and k in cfg: 47 | vldm_cfg[k].update(v) 48 | else: 49 | vldm_cfg[k] = v 50 | 51 | with open(cfg.vsr_cfg, 'r') as f: 52 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) 53 | # _cfg_update = _cfg_update.cfg_dict 54 | for k, v in _cfg_update.items(): 55 | if isinstance(v, dict) and k in cfg: 56 | vsr_cfg[k].update(v) 57 | else: 58 | vsr_cfg[k] = v 59 | 60 | return vldm_cfg, vsr_cfg 61 | 62 | 63 | # def get prior and ldm config 64 | def assign_signle_cfg(cfg, _cfg_update, tname): 65 | ''' 66 | ''' 67 | # 68 | vldm_cfg = deepcopy(cfg) 69 | with open(_cfg_update[tname], 'r') as f: 70 | _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) 71 | # _cfg_update = _cfg_update.cfg_dict 72 | for k, v in _cfg_update.items(): 73 | if isinstance(v, dict) and k in cfg: 74 | vldm_cfg[k].update(v) 75 | else: 76 | vldm_cfg[k] = v 77 | return vldm_cfg -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def create_camera_to_world_matrix(elevation, azimuth, camera_distance=1): 5 | elevation = np.radians(elevation) 6 | azimuth = np.radians(azimuth) 7 | # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere 8 | x = camera_distance * np.cos(elevation) * np.sin(azimuth) 9 | y = camera_distance * np.sin(elevation) 10 | z = camera_distance * 
np.cos(elevation) * np.cos(azimuth) 11 | 12 | # Calculate camera position, target, and up vectors 13 | camera_pos = np.array([x, y, z]) 14 | target = np.array([0, 0, 0]) 15 | up = np.array([0, 1, 0]) 16 | 17 | # Construct view matrix 18 | forward = target - camera_pos 19 | forward /= np.linalg.norm(forward) 20 | right = np.cross(forward, up) 21 | right /= np.linalg.norm(right) 22 | new_up = np.cross(right, forward) 23 | new_up /= np.linalg.norm(new_up) 24 | cam2world = np.eye(4) 25 | cam2world[:3, :3] = np.array([right, new_up, -forward]).T 26 | cam2world[:3, 3] = camera_pos 27 | return cam2world 28 | 29 | 30 | def convert_opengl_to_blender(camera_matrix): 31 | if isinstance(camera_matrix, np.ndarray): 32 | # Construct transformation matrix to convert from OpenGL space to Blender space 33 | flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], 34 | [0, 0, 0, 1]]) 35 | camera_matrix_blender = np.dot(flip_yz, camera_matrix) 36 | else: 37 | # Construct transformation matrix to convert from OpenGL space to Blender space 38 | flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], 39 | [0, 0, 0, 1]]) 40 | if camera_matrix.ndim == 3: 41 | flip_yz = flip_yz.unsqueeze(0) 42 | camera_matrix_blender = torch.matmul( 43 | flip_yz.to(camera_matrix), camera_matrix) 44 | return camera_matrix_blender 45 | 46 | def get_camera(num_frames, 47 | elevation=15, 48 | azimuth_start=0, 49 | azimuth_span=360, 50 | blender_coord=True, 51 | camera_distance=1.): 52 | angle_gap = azimuth_span / num_frames 53 | cameras = [] 54 | for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, 55 | angle_gap): 56 | camera_matrix = create_camera_to_world_matrix(elevation, azimuth, 57 | camera_distance) 58 | 59 | if blender_coord: 60 | camera_matrix = convert_opengl_to_blender(camera_matrix) 61 | cameras.append(camera_matrix.flatten()) 62 | return torch.tensor(np.stack(cameras, 0)).float() -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import copy 5 | import argparse 6 | 7 | import utils.logging as logging 8 | logger = logging.get_logger(__name__) 9 | 10 | class Config(object): 11 | def __init__(self, load=True, cfg_dict=None, cfg_level=None): 12 | self._level = "cfg" + ("." 
+ cfg_level if cfg_level is not None else "") 13 | if load: 14 | self.args = self._parse_args() 15 | logger.info("Loading config from {}.".format(self.args.cfg_file)) 16 | self.need_initialization = True 17 | cfg_base = self._initialize_cfg() 18 | cfg_dict = self._load_yaml(self.args) 19 | cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) 20 | cfg_dict = self._update_from_args(cfg_dict) 21 | self.cfg_dict = cfg_dict 22 | self._update_dict(cfg_dict) 23 | 24 | def _parse_args(self): 25 | parser = argparse.ArgumentParser( 26 | description="Argparser for configuring [code base name to think of] codebase" 27 | ) 28 | parser.add_argument( 29 | "--cfg", 30 | dest="cfg_file", 31 | help="Path to the configuration file", 32 | default='configs/i2vgen_xl_infer.yaml' 33 | ) 34 | parser.add_argument( 35 | "--init_method", 36 | help="Initialization method, includes TCP or shared file-system", 37 | default="tcp://localhost:9999", 38 | type=str, 39 | ) 40 | parser.add_argument( 41 | '--debug', 42 | action='store_true', 43 | default=False, 44 | help='Into debug information' 45 | ) 46 | parser.add_argument( 47 | "opts", 48 | help="other configurations", 49 | default=None, 50 | nargs=argparse.REMAINDER) 51 | return parser.parse_args() 52 | 53 | def _path_join(self, path_list): 54 | path = "" 55 | for p in path_list: 56 | path+= p + '/' 57 | return path[:-1] 58 | 59 | def _update_from_args(self, cfg_dict): 60 | args = self.args 61 | for var in vars(args): 62 | cfg_dict[var] = getattr(args, var) 63 | return cfg_dict 64 | 65 | def _initialize_cfg(self): 66 | if self.need_initialization: 67 | self.need_initialization = False 68 | if os.path.exists('./configs/base.yaml'): 69 | with open("./configs/base.yaml", 'r') as f: 70 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 71 | else: 72 | with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f: 73 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 74 | return cfg 75 | 76 | def _load_yaml(self, args, file_name=""): 77 | assert args.cfg_file is not None 78 | if not file_name == "": # reading from base file 79 | with open(file_name, 'r') as f: 80 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 81 | else: 82 | if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]: 83 | args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./") 84 | with open(args.cfg_file, 'r') as f: 85 | cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) 86 | file_name = args.cfg_file 87 | 88 | if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys(): 89 | # return cfg if the base file is being accessed 90 | cfg = self._merge_cfg_from_command_update(args, cfg) 91 | return cfg 92 | 93 | if "_BASE" in cfg.keys(): 94 | if cfg["_BASE"][1] == '.': 95 | prev_count = cfg["_BASE"].count('..') 96 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:]) 97 | else: 98 | cfg_base_file = cfg["_BASE"].replace( 99 | "./", 100 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 101 | ) 102 | cfg_base = self._load_yaml(args, cfg_base_file) 103 | cfg = self._merge_cfg_from_base(cfg_base, cfg) 104 | else: 105 | if "_BASE_RUN" in cfg.keys(): 106 | if cfg["_BASE_RUN"][1] == '.': 107 | prev_count = cfg["_BASE_RUN"].count('..') 108 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:]) 109 | else: 110 | cfg_base_file = cfg["_BASE_RUN"].replace( 111 | "./", 112 | 
args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 113 | ) 114 | cfg_base = self._load_yaml(args, cfg_base_file) 115 | cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True) 116 | if "_BASE_MODEL" in cfg.keys(): 117 | if cfg["_BASE_MODEL"][1] == '.': 118 | prev_count = cfg["_BASE_MODEL"].count('..') 119 | cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:]) 120 | else: 121 | cfg_base_file = cfg["_BASE_MODEL"].replace( 122 | "./", 123 | args.cfg_file.replace(args.cfg_file.split('/')[-1], "") 124 | ) 125 | cfg_base = self._load_yaml(args, cfg_base_file) 126 | cfg = self._merge_cfg_from_base(cfg_base, cfg) 127 | cfg = self._merge_cfg_from_command(args, cfg) 128 | return cfg 129 | 130 | def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): 131 | for k,v in cfg_new.items(): 132 | if k in cfg_base.keys(): 133 | if isinstance(v, dict): 134 | self._merge_cfg_from_base(cfg_base[k], v) 135 | else: 136 | cfg_base[k] = v 137 | else: 138 | if "BASE" not in k or preserve_base: 139 | cfg_base[k] = v 140 | return cfg_base 141 | 142 | def _merge_cfg_from_command_update(self, args, cfg): 143 | if len(args.opts) == 0: 144 | return cfg 145 | 146 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( 147 | args.opts, len(args.opts) 148 | ) 149 | keys = args.opts[0::2] 150 | vals = args.opts[1::2] 151 | 152 | for key, val in zip(keys, vals): 153 | cfg[key] = val 154 | 155 | return cfg 156 | 157 | def _merge_cfg_from_command(self, args, cfg): 158 | assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( 159 | args.opts, len(args.opts) 160 | ) 161 | keys = args.opts[0::2] 162 | vals = args.opts[1::2] 163 | 164 | # maximum supported depth 3 165 | for idx, key in enumerate(keys): 166 | key_split = key.split('.') 167 | assert len(key_split) <= 4, 'Key depth error. 
\nMaximum depth: 3\n Get depth: {}'.format( 168 | len(key_split) 169 | ) 170 | assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( 171 | key_split[0] 172 | ) 173 | if len(key_split) == 2: 174 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 175 | key 176 | ) 177 | elif len(key_split) == 3: 178 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 179 | key 180 | ) 181 | assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( 182 | key 183 | ) 184 | elif len(key_split) == 4: 185 | assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( 186 | key 187 | ) 188 | assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( 189 | key 190 | ) 191 | assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format( 192 | key 193 | ) 194 | if len(key_split) == 1: 195 | cfg[key_split[0]] = vals[idx] 196 | elif len(key_split) == 2: 197 | cfg[key_split[0]][key_split[1]] = vals[idx] 198 | elif len(key_split) == 3: 199 | cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] 200 | elif len(key_split) == 4: 201 | cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx] 202 | return cfg 203 | 204 | def _update_dict(self, cfg_dict): 205 | def recur(key, elem): 206 | if type(elem) is dict: 207 | return key, Config(load=False, cfg_dict=elem, cfg_level=key) 208 | else: 209 | if type(elem) is str and elem[1:3]=="e-": 210 | elem = float(elem) 211 | return key, elem 212 | dic = dict(recur(k, v) for k, v in cfg_dict.items()) 213 | self.__dict__.update(dic) 214 | 215 | def get_args(self): 216 | return self.args 217 | 218 | def __repr__(self): 219 | return "{}\n".format(self.dump()) 220 | 221 | def dump(self): 222 | return json.dumps(self.cfg_dict, indent=2) 223 | 224 | def deep_copy(self): 225 | return copy.deepcopy(self) 226 | 227 | if __name__ == '__main__': 228 | # debug 229 | cfg = Config(load=True) 230 | print(cfg.DATA) -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Logging.""" 5 | 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | # from fvcore.common.file_io import PathManager 14 | 15 | import utils.distributed as du 16 | 17 | 18 | def _suppress_print(): 19 | """ 20 | Suppresses printing from the current process. 21 | """ 22 | 23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 24 | pass 25 | 26 | builtins.print = print_pass 27 | 28 | 29 | # @functools.lru_cache(maxsize=None) 30 | # def _cached_log_stream(filename): 31 | # return PathManager.open(filename, "a") 32 | 33 | 34 | def setup_logging(cfg, log_file): 35 | """ 36 | Sets up the logging for multiple processes. Only enable the logging for the 37 | master process, and suppress logging for the non-master processes. 38 | """ 39 | if du.is_master_proc(): 40 | # Enable logging for the master process. 41 | logging.root.handlers = [] 42 | else: 43 | # Suppress logging for non-master processes. 
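# (_suppress_print below swaps builtins.print for a no-op, so only the
#  master rank emits console output in multi-GPU runs.)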
44 | _suppress_print() 45 | 46 | logger = logging.getLogger() 47 | logger.setLevel(logging.INFO) 48 | logger.propagate = False 49 | plain_formatter = logging.Formatter( 50 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 51 | datefmt="%m/%d %H:%M:%S", 52 | ) 53 | 54 | if du.is_master_proc(): 55 | ch = logging.StreamHandler(stream=sys.stdout) 56 | ch.setLevel(logging.DEBUG) 57 | ch.setFormatter(plain_formatter) 58 | logger.addHandler(ch) 59 | 60 | if log_file is not None and du.is_master_proc(du.get_world_size()): 61 | filename = os.path.join(cfg.OUTPUT_DIR, log_file) 62 | fh = logging.FileHandler(filename) 63 | fh.setLevel(logging.DEBUG) 64 | fh.setFormatter(plain_formatter) 65 | logger.addHandler(fh) 66 | 67 | 68 | def get_logger(name): 69 | """ 70 | Retrieve the logger with the specified name or, if name is None, return a 71 | logger which is the root logger of the hierarchy. 72 | Args: 73 | name (string): name of the logger. 74 | """ 75 | return logging.getLogger(name) 76 | 77 | 78 | def log_json_stats(stats): 79 | """ 80 | Logs json stats. 81 | Args: 82 | stats (dict): a dictionary of statistical information to log. 83 | """ 84 | stats = { 85 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 86 | for k, v in stats.items() 87 | } 88 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 89 | logger = get_logger(__name__) 90 | logger.info("{:s}".format(json_stats)) 91 | -------------------------------------------------------------------------------- /utils/multi_port.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from contextlib import closing 3 | 4 | def find_free_port(): 5 | """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """ 6 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 7 | s.bind(('', 0)) 8 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 9 | return str(s.getsockname()[1]) -------------------------------------------------------------------------------- /utils/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import * 2 | from .adafactor import * 3 | -------------------------------------------------------------------------------- /utils/optim/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/optim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/optim/__pycache__/adafactor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/adafactor.cpython-310.pyc -------------------------------------------------------------------------------- /utils/optim/__pycache__/adafactor.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/adafactor.cpython-38.pyc -------------------------------------------------------------------------------- /utils/optim/__pycache__/lr_scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/lr_scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /utils/optim/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/VideoMV/c434922de00a26e8010eb6be6b3466885c38e97b/utils/optim/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /utils/optim/adafactor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | __all__ = ['Adafactor'] 7 | 8 | class Adafactor(Optimizer): 9 | """ 10 | AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: 11 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py 12 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that 13 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and 14 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and 15 | `relative_step=False`. 16 | Arguments: 17 | params (`Iterable[nn.parameter.Parameter]`): 18 | Iterable of parameters to optimize or dictionaries defining parameter groups. 19 | lr (`float`, *optional*): 20 | The external learning rate. 21 | eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)): 22 | Regularization constants for square gradient and parameter scale respectively 23 | clip_threshold (`float`, *optional*, defaults 1.0): 24 | Threshold of root mean square of final gradient update 25 | decay_rate (`float`, *optional*, defaults to -0.8): 26 | Coefficient used to compute running averages of square 27 | beta1 (`float`, *optional*): 28 | Coefficient used for computing running averages of gradient 29 | weight_decay (`float`, *optional*, defaults to 0): 30 | Weight decay (L2 penalty) 31 | scale_parameter (`bool`, *optional*, defaults to `True`): 32 | If True, learning rate is scaled by root mean square 33 | relative_step (`bool`, *optional*, defaults to `True`): 34 | If True, time-dependent learning rate is computed instead of external learning rate 35 | warmup_init (`bool`, *optional*, defaults to `False`): 36 | Time-dependent learning rate computation depends on whether warm-up initialization is being used 37 | This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. 38 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): 39 | - Training without LR warmup or clip_threshold is not recommended. 
40 | - use scheduled LR warm-up to fixed LR 41 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) 42 | - Disable relative updates 43 | - Use scale_parameter=False 44 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor 45 | Example: 46 | ```python 47 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) 48 | ``` 49 | Others reported the following combination to work well: 50 | ```python 51 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 52 | ``` 53 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] 54 | scheduler as following: 55 | ```python 56 | from transformers.optimization import Adafactor, AdafactorSchedule 57 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 58 | lr_scheduler = AdafactorSchedule(optimizer) 59 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) 60 | ``` 61 | Usage: 62 | ```python 63 | # replace AdamW with Adafactor 64 | optimizer = Adafactor( 65 | model.parameters(), 66 | lr=1e-3, 67 | eps=(1e-30, 1e-3), 68 | clip_threshold=1.0, 69 | decay_rate=-0.8, 70 | beta1=None, 71 | weight_decay=0.0, 72 | relative_step=False, 73 | scale_parameter=False, 74 | warmup_init=False, 75 | ) 76 | ```""" 77 | 78 | def __init__( 79 | self, 80 | params, 81 | lr=None, 82 | eps=(1e-30, 1e-3), 83 | clip_threshold=1.0, 84 | decay_rate=-0.8, 85 | beta1=None, 86 | weight_decay=0.0, 87 | scale_parameter=True, 88 | relative_step=True, 89 | warmup_init=False, 90 | ): 91 | r"""require_version("torch>=1.5.0") # add_ with alpha 92 | """ 93 | if lr is not None and relative_step: 94 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") 95 | if warmup_init and not relative_step: 96 | raise ValueError("`warmup_init=True` requires `relative_step=True`") 97 | 98 | defaults = dict( 99 | lr=lr, 100 | eps=eps, 101 | clip_threshold=clip_threshold, 102 | decay_rate=decay_rate, 103 | beta1=beta1, 104 | weight_decay=weight_decay, 105 | scale_parameter=scale_parameter, 106 | relative_step=relative_step, 107 | warmup_init=warmup_init, 108 | ) 109 | super().__init__(params, defaults) 110 | 111 | @staticmethod 112 | def _get_lr(param_group, param_state): 113 | rel_step_sz = param_group["lr"] 114 | if param_group["relative_step"]: 115 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 116 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) 117 | param_scale = 1.0 118 | if param_group["scale_parameter"]: 119 | param_scale = max(param_group["eps"][1], param_state["RMS"]) 120 | return param_scale * rel_step_sz 121 | 122 | @staticmethod 123 | def _get_options(param_group, param_shape): 124 | factored = len(param_shape) >= 2 125 | use_first_moment = param_group["beta1"] is not None 126 | return factored, use_first_moment 127 | 128 | @staticmethod 129 | def _rms(tensor): 130 | return tensor.norm(2) / (tensor.numel() ** 0.5) 131 | 132 | @staticmethod 133 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 134 | # copy from fairseq's adafactor implementation: 135 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 136 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) 137 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 138 | return 
torch.mul(r_factor, c_factor) 139 | 140 | def step(self, closure=None): 141 | """ 142 | Performs a single optimization step 143 | Arguments: 144 | closure (callable, optional): A closure that reevaluates the model 145 | and returns the loss. 146 | """ 147 | loss = None 148 | if closure is not None: 149 | loss = closure() 150 | 151 | for group in self.param_groups: 152 | for p in group["params"]: 153 | if p.grad is None: 154 | continue 155 | grad = p.grad.data 156 | if grad.dtype in {torch.float16, torch.bfloat16}: 157 | grad = grad.float() 158 | if grad.is_sparse: 159 | raise RuntimeError("Adafactor does not support sparse gradients.") 160 | 161 | state = self.state[p] 162 | grad_shape = grad.shape 163 | 164 | factored, use_first_moment = self._get_options(group, grad_shape) 165 | # State Initialization 166 | if len(state) == 0: 167 | state["step"] = 0 168 | 169 | if use_first_moment: 170 | # Exponential moving average of gradient values 171 | state["exp_avg"] = torch.zeros_like(grad) 172 | if factored: 173 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) 174 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 175 | else: 176 | state["exp_avg_sq"] = torch.zeros_like(grad) 177 | 178 | state["RMS"] = 0 179 | else: 180 | if use_first_moment: 181 | state["exp_avg"] = state["exp_avg"].to(grad) 182 | if factored: 183 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) 184 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) 185 | else: 186 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) 187 | 188 | p_data_fp32 = p.data 189 | if p.data.dtype in {torch.float16, torch.bfloat16}: 190 | p_data_fp32 = p_data_fp32.float() 191 | 192 | state["step"] += 1 193 | state["RMS"] = self._rms(p_data_fp32) 194 | lr = self._get_lr(group, state) 195 | 196 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) 197 | update = (grad**2) + group["eps"][0] 198 | if factored: 199 | exp_avg_sq_row = state["exp_avg_sq_row"] 200 | exp_avg_sq_col = state["exp_avg_sq_col"] 201 | 202 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) 203 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) 204 | 205 | # Approximation of exponential moving average of square of gradient 206 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 207 | update.mul_(grad) 208 | else: 209 | exp_avg_sq = state["exp_avg_sq"] 210 | 211 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) 212 | update = exp_avg_sq.rsqrt().mul_(grad) 213 | 214 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) 215 | update.mul_(lr) 216 | 217 | if use_first_moment: 218 | exp_avg = state["exp_avg"] 219 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) 220 | update = exp_avg 221 | 222 | if group["weight_decay"] != 0: 223 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) 224 | 225 | p_data_fp32.add_(-update) 226 | 227 | if p.data.dtype in {torch.float16, torch.bfloat16}: 228 | p.data.copy_(p_data_fp32) 229 | 230 | return loss 231 | -------------------------------------------------------------------------------- /utils/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | __all__ = ['AnnealingLR'] 5 | 6 | class AnnealingLR(_LRScheduler): 7 | 8 | def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', 
min_lr=0.0, last_step=-1): 9 | assert decay_mode in ['linear', 'cosine', 'none'] 10 | self.optimizer = optimizer 11 | self.base_lr = base_lr 12 | self.warmup_steps = warmup_steps 13 | self.total_steps = total_steps 14 | self.decay_mode = decay_mode 15 | self.min_lr = min_lr 16 | self.current_step = last_step + 1 17 | self.step(self.current_step) 18 | 19 | def get_lr(self): 20 | if self.warmup_steps > 0 and self.current_step <= self.warmup_steps: 21 | return self.base_lr * self.current_step / self.warmup_steps 22 | else: 23 | ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) 24 | ratio = min(1.0, max(0.0, ratio)) 25 | if self.decay_mode == 'linear': 26 | return self.base_lr * (1 - ratio) 27 | elif self.decay_mode == 'cosine': 28 | return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0 29 | else: 30 | return self.base_lr 31 | 32 | def step(self, current_step=None): 33 | if current_step is None: 34 | current_step = self.current_step + 1 35 | self.current_step = current_step 36 | new_lr = max(self.min_lr, self.get_lr()) 37 | if isinstance(self.optimizer, list): 38 | for o in self.optimizer: 39 | for group in o.param_groups: 40 | group['lr'] = new_lr 41 | else: 42 | for group in self.optimizer.param_groups: 43 | group['lr'] = new_lr 44 | 45 | def state_dict(self): 46 | return { 47 | 'base_lr': self.base_lr, 48 | 'warmup_steps': self.warmup_steps, 49 | 'total_steps': self.total_steps, 50 | 'decay_mode': self.decay_mode, 51 | 'current_step': self.current_step} 52 | 53 | def load_state_dict(self, state_dict): 54 | self.base_lr = state_dict['base_lr'] 55 | self.warmup_steps = state_dict['warmup_steps'] 56 | self.total_steps = state_dict['total_steps'] 57 | self.decay_mode = state_dict['decay_mode'] 58 | self.current_step = state_dict['current_step'] 59 | -------------------------------------------------------------------------------- /utils/recenter_i2v.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torchvision 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | import os, sys 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | import torch 15 | import time 16 | import cv2 17 | import PIL 18 | 19 | def add_margin(pil_img, color=0, size=256): 20 | width, height = pil_img.size 21 | result = Image.new(pil_img.mode, (size, size), color) 22 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) 23 | return result 24 | 25 | def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256): 26 | image_input = Image.open(image_path) 27 | 28 | if crop_size!=-1: 29 | alpha_np = np.asarray(image_input)[:, :, 3] 30 | coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] 31 | min_x, min_y = np.min(coords, 0) 32 | max_x, max_y = np.max(coords, 0) 33 | ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) 34 | h, w = ref_img_.height, ref_img_.width 35 | scale = crop_size / max(h, w) 36 | h_, w_ = int(scale * h), int(scale * w) 37 | ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC) 38 | image_input = add_margin(ref_img_, size=image_size) 39 | else: 40 | image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) 41 | image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC) 42 | 43 | image_input = np.asarray(image_input) 44 | image_input = 
image_input.astype(np.float32) / 255.0 45 | if image_input.shape[-1]==4: 46 | ref_mask = image_input[:, :, 3:] 47 | image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background 48 | return image_input 49 | 50 | root_dir = sys.argv[1] 51 | items = [os.path.join(root_dir, item) for item in os.listdir(root_dir)] 52 | for idx, item in enumerate(items): 53 | res = prepare_inputs(item, 15, 200) 54 | Image.fromarray((res*255.0).astype(np.uint8)).save("./data/images", "{:05d}.png".format(idx)) 55 | -------------------------------------------------------------------------------- /utils/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. 2 | 3 | # Registry class & build_from_config function partially modified from 4 | # https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py 5 | # Copyright 2018-2020 Open-MMLab. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import copy 20 | import inspect 21 | import warnings 22 | 23 | 24 | def build_from_config(cfg, registry, **kwargs): 25 | """ Default builder function. 26 | 27 | Args: 28 | cfg (dict): A dict which contains parameters passes to target class or function. 29 | Must contains key 'type', indicates the target class or function name. 30 | registry (Registry): An registry to search target class or function. 31 | kwargs (dict, optional): Other params not in config dict. 32 | 33 | Returns: 34 | Target class object or object returned by invoking function. 35 | 36 | Raises: 37 | TypeError: 38 | KeyError: 39 | Exception: 40 | """ 41 | if not isinstance(cfg, dict): 42 | raise TypeError(f"config must be type dict, got {type(cfg)}") 43 | if "type" not in cfg: 44 | raise KeyError(f"config must contain key type, got {cfg}") 45 | if not isinstance(registry, Registry): 46 | raise TypeError(f"registry must be type Registry, got {type(registry)}") 47 | 48 | cfg = copy.deepcopy(cfg) 49 | 50 | req_type = cfg.pop("type") 51 | req_type_entry = req_type 52 | if isinstance(req_type, str): 53 | req_type_entry = registry.get(req_type) 54 | if req_type_entry is None: 55 | raise KeyError(f"{req_type} not found in {registry.name} registry") 56 | 57 | if kwargs is not None: 58 | cfg.update(kwargs) 59 | 60 | if inspect.isclass(req_type_entry): 61 | try: 62 | return req_type_entry(**cfg) 63 | except Exception as e: 64 | raise Exception(f"Failed to init class {req_type_entry}, with {e}") 65 | elif inspect.isfunction(req_type_entry): 66 | try: 67 | return req_type_entry(**cfg) 68 | except Exception as e: 69 | raise Exception(f"Failed to invoke function {req_type_entry}, with {e}") 70 | else: 71 | raise TypeError(f"type must be str or class, got {type(req_type_entry)}") 72 | 73 | 74 | class Registry(object): 75 | """ A registry maps key to classes or functions. 
76 | 77 | Example: 78 | >>> MODELS = Registry('MODELS') 79 | >>> @MODELS.register_class() 80 | >>> class ResNet(object): 81 | >>> pass 82 | >>> resnet = MODELS.build(dict(type="ResNet")) 83 | >>> 84 | >>> import torchvision 85 | >>> @MODELS.register_function("InceptionV3") 86 | >>> def get_inception_v3(pretrained=False, progress=True): 87 | >>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress) 88 | >>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True)) 89 | 90 | Args: 91 | name (str): Registry name. 92 | build_func (func, None): Instance construct function. Default is build_from_config. 93 | allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function. 94 | """ 95 | 96 | def __init__(self, name, build_func=None, allow_types=("class", "function")): 97 | self.name = name 98 | self.allow_types = allow_types 99 | self.class_map = {} 100 | self.func_map = {} 101 | self.build_func = build_func or build_from_config 102 | 103 | def get(self, req_type): 104 | return self.class_map.get(req_type) or self.func_map.get(req_type) 105 | 106 | def build(self, *args, **kwargs): 107 | return self.build_func(*args, **kwargs, registry=self) 108 | 109 | def register_class(self, name=None): 110 | def _register(cls): 111 | if not inspect.isclass(cls): 112 | raise TypeError(f"Module must be type class, got {type(cls)}") 113 | if "class" not in self.allow_types: 114 | raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class") 115 | module_name = name or cls.__name__ 116 | if module_name in self.class_map: 117 | warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, " 118 | f"will be replaced by {cls}") 119 | self.class_map[module_name] = cls 120 | return cls 121 | 122 | return _register 123 | 124 | def register_function(self, name=None): 125 | def _register(func): 126 | if not inspect.isfunction(func): 127 | raise TypeError(f"Registry must be type function, got {type(func)}") 128 | if "function" not in self.allow_types: 129 | raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function") 130 | func_name = name or func.__name__ 131 | if func_name in self.class_map: 132 | warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, " 133 | f"will be replaced by {func}") 134 | self.func_map[func_name] = func 135 | return func 136 | 137 | return _register 138 | 139 | def _list(self): 140 | keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys())) 141 | descriptions = [] 142 | for key in keys: 143 | if key in self.class_map: 144 | descriptions.append(f"{key}: {self.class_map[key]}") 145 | else: 146 | descriptions.append( 147 | f"{key}: ") 148 | return "\n".join(descriptions) 149 | 150 | def __repr__(self): 151 | description = self._list() 152 | description = '\n'.join(['\t' + s for s in description.split('\n')]) 153 | return f"{self.__class__.__name__} [{self.name}], \n" + description 154 | 155 | 156 | -------------------------------------------------------------------------------- /utils/registry_class.py: -------------------------------------------------------------------------------- 1 | from .registry import Registry, build_from_config 2 | 3 | def build_func(cfg, registry, **kwargs): 4 | """ 5 | Except for config, if passing a list of dataset config, then return the concat type of it 6 | """ 7 | return build_from_config(cfg, registry, **kwargs) 8 | 9 | AUTO_ENCODER = Registry("AUTO_ENCODER", 
build_func=build_func) 10 | DATASETS = Registry("DATASETS", build_func=build_func) 11 | DIFFUSION = Registry("DIFFUSION", build_func=build_func) 12 | DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func) 13 | EMBEDDER = Registry("EMBEDDER", build_func=build_func) 14 | ENGINE = Registry("ENGINE", build_func=build_func) 15 | INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func) 16 | MODEL = Registry("MODEL", build_func=build_func) 17 | PRETRAIN = Registry("PRETRAIN", build_func=build_func) 18 | VISUAL = Registry("VISUAL", build_func=build_func) 19 | -------------------------------------------------------------------------------- /utils/seed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | def setup_seed(seed): 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed_all(seed) 9 | np.random.seed(seed) 10 | random.seed(seed) 11 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def to_device(batch, device, non_blocking=False): 4 | if isinstance(batch, (list, tuple)): 5 | return type(batch)([ 6 | to_device(u, device, non_blocking) 7 | for u in batch]) 8 | elif isinstance(batch, dict): 9 | return type(batch)([ 10 | (k, to_device(v, device, non_blocking)) 11 | for k, v in batch.items()]) 12 | elif isinstance(batch, torch.Tensor) and batch.device != device: 13 | batch = batch.to(device, non_blocking=non_blocking) 14 | else: 15 | return batch 16 | return batch 17 | -------------------------------------------------------------------------------- /utils/video_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import cv2 5 | import glob 6 | import math 7 | import torch 8 | import gzip 9 | import copy 10 | import time 11 | import json 12 | import pickle 13 | import base64 14 | import imageio 15 | import hashlib 16 | import requests 17 | import binascii 18 | import zipfile 19 | # import skvideo.io 20 | import numpy as np 21 | from io import BytesIO 22 | import urllib.request 23 | import torch.nn.functional as F 24 | import torchvision.utils as tvutils 25 | from multiprocessing.pool import ThreadPool as Pool 26 | from einops import rearrange 27 | from PIL import Image, ImageDraw, ImageFont 28 | 29 | 30 | def gen_text_image(captions, text_size): 31 | num_char = int(38 * (text_size / text_size)) 32 | font_size = int(text_size / 20) 33 | font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size) 34 | text_image_list = [] 35 | for text in captions: 36 | txt_img = Image.new("RGB", (text_size, text_size), color="white") 37 | draw = ImageDraw.Draw(txt_img) 38 | lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char)) 39 | draw.text((0, 0), lines, fill="black", font=font) 40 | txt_img = np.array(txt_img) 41 | text_image_list.append(txt_img) 42 | text_images = np.stack(text_image_list, axis=0) 43 | text_images = torch.from_numpy(text_images) 44 | return text_images 45 | 46 | @torch.no_grad() 47 | def save_video_refimg_and_text( 48 | local_path, 49 | ref_frame, 50 | gen_video, 51 | captions, 52 | mean=[0.5, 0.5, 0.5], 53 | std=[0.5, 0.5, 0.5], 54 | text_size=256, 55 | nrow=4, 56 | save_fps=8, 57 | retry=5): 58 | ''' 59 | gen_video: BxCxFxHxW 60 | ''' 61 | nrow = 
max(int(gen_video.size(0) / 2), 1) 62 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 63 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 64 | 65 | text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3 66 | text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3 67 | text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3 68 | 69 | ref_frame = ref_frame.unsqueeze(2) 70 | ref_frame = ref_frame.mul_(vid_std).add_(vid_mean) 71 | ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3 72 | ref_frame.clamp_(0, 1) 73 | ref_frame = ref_frame * 255.0 74 | ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c') 75 | 76 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 77 | gen_video.clamp_(0, 1) 78 | gen_video = gen_video * 255.0 79 | 80 | images = rearrange(gen_video, 'b c f h w -> b f h w c') 81 | images = torch.cat([ref_frame, images, text_images], dim=3) 82 | 83 | images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow) 84 | images = [(img.numpy()).astype('uint8') for img in images] 85 | 86 | for _ in [None] * retry: 87 | try: 88 | if len(images) == 1: 89 | local_path = local_path + '.png' 90 | cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 91 | else: 92 | local_path = local_path + '.mp4' 93 | frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) 94 | os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) 95 | for fid, frame in enumerate(images): 96 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) 97 | cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 98 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' 99 | os.system(cmd); os.system(f'rm -rf {frame_dir}') 100 | # os.system(f'rm -rf {local_path}') 101 | exception = None 102 | break 103 | except Exception as e: 104 | exception = e 105 | continue 106 | 107 | 108 | @torch.no_grad() 109 | def save_i2vgen_video( 110 | local_path, 111 | image_id, 112 | gen_video, 113 | captions, 114 | mean=[0.5, 0.5, 0.5], 115 | std=[0.5, 0.5, 0.5], 116 | text_size=256, 117 | retry=5, 118 | save_fps = 8 119 | ): 120 | ''' 121 | Save both the generated video and the input conditions. 
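gen_video: BxCxFxHxW, normalized with the given mean/std (roughly [-1, 1] for the defaults).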
122 | ''' 123 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 124 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 125 | 126 | text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3 127 | text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3 128 | text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3 129 | 130 | image_id = image_id.unsqueeze(2) # B, C, F, H, W 131 | image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448 132 | image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448 133 | image_id.clamp_(0, 1) 134 | image_id = image_id * 255.0 135 | image_id = rearrange(image_id, 'b c f h w -> b f h w c') 136 | 137 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 138 | gen_video.clamp_(0, 1) 139 | gen_video = gen_video * 255.0 140 | 141 | images = rearrange(gen_video, 'b c f h w -> b f h w c') 142 | images = torch.cat([image_id, images, text_images], dim=3) 143 | images = images[0] 144 | images = [(img.numpy()).astype('uint8') for img in images] 145 | 146 | exception = None 147 | for _ in [None] * retry: 148 | try: 149 | frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) 150 | os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) 151 | for fid, frame in enumerate(images): 152 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) 153 | cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 154 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' 155 | os.system(cmd); os.system(f'rm -rf {frame_dir}') 156 | break 157 | except Exception as e: 158 | exception = e 159 | continue 160 | 161 | if exception is not None: 162 | raise exception 163 | 164 | 165 | @torch.no_grad() 166 | def save_i2vgen_video_safe( 167 | local_path, 168 | gen_video, 169 | captions, 170 | mean=[0.5, 0.5, 0.5], 171 | std=[0.5, 0.5, 0.5], 172 | text_size=256, 173 | retry=5, 174 | save_fps = 8 175 | ): 176 | ''' 177 | Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame. 178 | ''' 179 | vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 180 | vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw 181 | 182 | gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 183 | gen_video.clamp_(0, 1) 184 | gen_video = gen_video * 255.0 185 | 186 | images = rearrange(gen_video, 'b c f h w -> b f h w c') 187 | images = images[0] 188 | images = [(img.numpy()).astype('uint8') for img in images] 189 | num_image = len(images) 190 | exception = None 191 | for _ in [None] * retry: 192 | try: 193 | if num_image == 1: 194 | local_path = local_path + '.png' 195 | cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 196 | else: 197 | os.makedirs(local_path.replace(".mp4", ""), exist_ok=True) 198 | 199 | writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8) 200 | for fid, frame in enumerate(images): 201 | # if fid == num_image-1: # Fix known bugs. 
202 | # ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size) 203 | # if ratio > 0.4: continue 204 | writer.append_data(frame) 205 | cv2.imwrite(os.path.join(local_path.replace(".mp4", ""), "{:05d}.png".format(fid)), frame[:,:,::-1]) 206 | writer.close() 207 | break 208 | except Exception as e: 209 | exception = e 210 | continue 211 | 212 | if exception is not None: 213 | raise exception 214 | 215 | --------------------------------------------------------------------------------
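The utilities above are consumed by the training and inference entrances under tools/; the sketches below are illustrative only and do not reproduce any call site in this repository.

A minimal sketch of driving the camera helper from utils/camera_utils.py, assuming the repository root is on PYTHONPATH; the view count, elevation and distance are arbitrary example values:

from utils.camera_utils import get_camera

# 24 evenly spaced azimuths over a full turn at 15 degrees elevation.
# Each row of the result is a flattened 4x4 camera-to-world matrix,
# already converted to the Blender convention because blender_coord=True.
cameras = get_camera(num_frames=24,
                     elevation=15,
                     azimuth_start=0,
                     azimuth_span=360,
                     blender_coord=True,
                     camera_distance=1.0)
print(cameras.shape)  # torch.Size([24, 16])

A similarly hedged sketch of pairing the Adafactor optimizer with the AnnealingLR scheduler from utils/optim (the model, step counts and learning rates are placeholders, not values used by this codebase):

import torch
from utils.optim import Adafactor, AnnealingLR

model = torch.nn.Linear(16, 16)
# External-LR mode: relative_step / scale_parameter / warmup_init disabled,
# as recommended in the Adafactor docstring above.
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False,
                      scale_parameter=False, warmup_init=False)
scheduler = AnnealingLR(optimizer, base_lr=1e-3, warmup_steps=100,
                        total_steps=1000, decay_mode='cosine', min_lr=1e-6)
for step in range(1, 1001):
    # ... forward / backward / optimizer.step() would go here ...
    scheduler.step(step)  # linear warm-up, then cosine decay floored at min_lr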