├── Readme.md
├── configs
│   └── latent-diffusion
│       ├── talking-inference.yaml
│       └── talking.yaml
├── data
│   ├── data_test.txt
│   ├── data_train.txt
│   └── train_name.txt
├── inference.sh
├── ldm
│   ├── __pycache__
│   │   └── util.cpython-37.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── base.cpython-37.pyc
│   │   │   └── talk_data_ref_smooth.cpython-37.pyc
│   │   ├── base.py
│   │   ├── talk_data_ref_smooth.py
│   │   └── talk_data_ref_smooth_inference.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __pycache__
│   │   │   └── autoencoder.cpython-37.pyc
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── ddim_ldm_ref_inpaint.cpython-37.pyc
│   │       │   └── ddpm_talking.cpython-37.pyc
│   │       ├── ddim_ldm_ref_inpaint.py
│   │       └── ddpm_talking.py
│   ├── modules
│   │   ├── __pycache__
│   │   │   ├── attention.cpython-37.pyc
│   │   │   ├── ema.cpython-37.pyc
│   │   │   └── x_transformer.cpython-37.pyc
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── model.cpython-37.pyc
│   │   │   │   ├── openaimodel.cpython-37.pyc
│   │   │   │   └── util.cpython-37.pyc
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── distributions.cpython-37.pyc
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── modules.cpython-37.pyc
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── main.py
├── models
│   └── Readme.md
├── requirements.txt
├── run.sh
└── scripts
    └── inference.py

/Readme.md:
--------------------------------------------------------------------------------
# DiffTalk #
The PyTorch implementation of our CVPR 2023 paper "DiffTalk: Crafting Diffusion Models for Generalized Audio-Driven Portraits Animation".

[[Project]](https://sstzal.github.io/DiffTalk/) [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Shen_DiffTalk_Crafting_Diffusion_Models_for_Generalized_Audio-Driven_Portraits_Animation_CVPR_2023_paper.pdf) [[Video Demo]](https://youtu.be/tup5kbsOJXc)

## Requirements
- python 3.7.0
- pytorch 1.10.0
- pytorch-lightning 1.2.5
- torchvision 0.11.0

For more details, please refer to `requirements.txt`. We conduct the experiments on 8 NVIDIA 3090Ti GPUs.

Put the first-stage [model](https://cloud.tsinghua.edu.cn/f/7eb11fc208144ed0ad20/?dl=1) in `./models`.

## Dataset
Please download the HDTF dataset for training and testing, and process it as follows.

**Data Preprocessing:**

1. Convert all videos to 25 fps.
2. Extract the audio signals and facial landmarks.
3. Put the processed data in `./data/HDTF`, and organize the data directory as follows.
4. Construct `data_train.txt` and `data_test.txt` as follows (a minimal preprocessing sketch is given after the examples below).

./data/HDTF:

    |——data/HDTF
        |——images
            |——0_0.jpg
            |——0_1.jpg
            |——...
            |——N_M.jpg
        |——landmarks
            |——0_0.lms
            |——0_1.lms
            |——...
            |——N_M.lms
        |——audio_smooth
            |——0_0.npy
            |——0_1.npy
            |——...
            |——N_M.npy

./data/data_train(test).txt:

    0_0
    0_1
    0_2
    ...
    N_M

N is the total number of videos (classes), and M is the number of frames per video (the class size).

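For reference, the sketch below illustrates steps 1 and 4. It is not part of this repository: it assumes `ffmpeg` is installed, that frames have already been extracted as `<video-id>_<frame-id>.jpg`, and the paths in the example are placeholders; landmark and audio-feature extraction must be done with your own tools so that the resulting `.lms` and `.npy` files match what the loaders in `ldm/data/talk_data_ref_smooth*.py` expect.

```
# Preprocessing sketch (not part of the repo). Assumes ffmpeg is installed and
# that frames are already extracted as <video-id>_<frame-id>.jpg.
import subprocess
from pathlib import Path


def convert_to_25fps(src_video, dst_video):
    """Step 1: re-encode a video to 25 fps with ffmpeg."""
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(src_video), "-r", "25", str(dst_video)],
        check=True,
    )


def write_split_file(image_dir, out_txt):
    """Step 4: list every '<video-id>_<frame-id>' found in the images directory."""
    names = sorted(
        (p.stem for p in Path(image_dir).glob("*.jpg")),
        key=lambda s: tuple(int(x) for x in s.split("_")),
    )
    Path(out_txt).write_text("\n".join(names) + "\n")


if __name__ == "__main__":
    # Example paths only; adapt them to your own video list and train/test split.
    convert_to_25fps("raw_videos/RD_Radio1_000.mp4", "videos_25fps/RD_Radio1_000.mp4")
    write_split_file("./data/HDTF/images", "./data/data_train.txt")
```
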
## Training
```
sh run.sh
```

## Test
```
sh inference.sh
```
## Weaknesses
1. DiffTalk models talking-head generation as an iterative denoising process, so it needs more time to synthesize a frame than most GAN-based approaches. This is a common problem of LDM-based works.
2. The model is trained on the HDTF dataset, and it sometimes fails on identities from other datasets.
3. When driving a portrait with more challenging cross-identity audio, the audio-lip synchronization of the synthesized video is slightly inferior to that under the self-driven setting.
4. During inference, the network is also sensitive to the mask shape in z_T: the mask needs to cover the mouth region completely, and its shape must not leak any lip-shape information.

## Acknowledgement
This code is built upon the publicly available [latent-diffusion](https://github.com/CompVis/latent-diffusion) codebase. Thanks to the authors of latent-diffusion for making their excellent work and code publicly available.

## Citation ##
Please cite the following paper if you use this repository in your research.

```
@inproceedings{shen2023difftalk,
  author={Shen, Shuai and Zhao, Wenliang and Meng, Zibin and Li, Wanhua and Zhu, Zheng and Zhou, Jie and Lu, Jiwen},
  title={DiffTalk: Crafting Diffusion Models for Generalized Audio-Driven Portraits Animation},
  booktitle={CVPR},
  year={2023}
}
```
--------------------------------------------------------------------------------
/configs/latent-diffusion/talking-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   base_learning_rate: 1.0e-06
3 |   target: ldm.models.diffusion.ddpm_talking.LatentDiffusion
4 |   params:
5 |     linear_start: 0.0015
6 |     linear_end: 0.0195
7 |     num_timesteps_cond: 1
8 |     log_every_t: 200
9 |     timesteps: 1000
10 |     first_stage_key: image
11 |     cond_stage_key: audio
12 |     image_size: 64
13 |     channels: 3
14 |     cond_stage_trainable: true
15 |     conditioning_key: crossattn
16 |     monitor: val/loss_simple_ema
17 |     unet_config:
18 |       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 |       params:
20 |         image_size: 64
21 |         in_channels: 9
22 |         out_channels: 3
23 |         model_channels: 256
24 |         attention_resolutions:
25 |         - 8
26 |         - 4
27 |         - 2
28 |         num_res_blocks: 2
29 |         channel_mult:
30 |         - 1
31 |         - 2
32 |         - 3
33 |         - 4
34 |         num_head_channels: 32
35 |         use_spatial_transformer: true
36 |         transformer_depth: 1
37 |         context_dim: 128
38 |     first_stage_config:
39 |       target: ldm.models.autoencoder.VQModelInterface
40 |       params:
41 |         embed_dim: 3
42 |         n_embed: 8192
43 |         ckpt_path: ./models/model.ckpt
44 |         ddconfig:
45 |           double_z: false
46 |           z_channels: 3
47 |           resolution: 256
48 |           in_channels: 3
49 |           out_ch: 3
50 |           ch: 128
51 |           ch_mult:
52 |           - 1
53 |           - 2
54 |           - 4
55 |           num_res_blocks: 2
56 |           attn_resolutions: []
57 |           dropout: 0.0
58 |         lossconfig:
59 |           target: torch.nn.Identity
60 |     cond_stage_config_audio:
61 |       target: ldm.modules.encoders.modules.AudioNet
62 |       params:
63 |         dim_aud: 64
64 |         win_size: 16
65 |     cond_stage_config_audio_smooth:
66 |       target: ldm.modules.encoders.modules.AudioAttNet
67 |       params:
68 |         dim_aud: 32
69 |     cond_stage_config_ldm:
70 |       target: ldm.modules.encoders.modules.LdmNet
71 | data:
72 |   target: main.DataModuleFromConfig
73 |   params:
74 |     batch_size: 8
75 | 
num_workers: 12 76 | wrap: false 77 | train: 78 | target: ldm.data.talk_data_ref_smooth_inference.TalkTrain 79 | params: 80 | size: 256 81 | validation: 82 | target: ldm.data.talk_data_ref_smooth_inference.TalkValidation 83 | params: 84 | size: 256 85 | 86 | 87 | lightning: 88 | callbacks: 89 | image_logger: 90 | target: main.ImageLogger 91 | params: 92 | batch_frequency: 5000 93 | max_images: 8 94 | increase_log_steps: False 95 | 96 | trainer: 97 | benchmark: True 98 | gradient_clip_val: 0.1 -------------------------------------------------------------------------------- /configs/latent-diffusion/talking.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm_talking.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: audio 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 9 22 | out_channels: 3 23 | model_channels: 256 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 128 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ckpt_path: ./models/model.ckpt 44 | ddconfig: 45 | double_z: false 46 | z_channels: 3 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | num_res_blocks: 2 56 | attn_resolutions: [] 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config_audio: 61 | target: ldm.modules.encoders.modules.AudioNet 62 | params: 63 | dim_aud: 64 64 | win_size: 16 65 | cond_stage_config_audio_smooth: 66 | target: ldm.modules.encoders.modules.AudioAttNet 67 | params: 68 | dim_aud: 32 69 | cond_stage_config_ldm: 70 | target: ldm.modules.encoders.modules.LdmNet 71 | data: 72 | target: main.DataModuleFromConfig 73 | params: 74 | batch_size: 8 75 | num_workers: 12 76 | wrap: false 77 | train: 78 | target: ldm.data.talk_data_ref_smooth.TalkTrain 79 | params: 80 | size: 256 81 | validation: 82 | target: ldm.data.talk_data_ref_smooth.TalkValidation 83 | params: 84 | size: 256 85 | 86 | 87 | lightning: 88 | callbacks: 89 | image_logger: 90 | target: main.ImageLogger 91 | params: 92 | batch_frequency: 5000 93 | max_images: 8 94 | increase_log_steps: False 95 | 96 | trainer: 97 | benchmark: True 98 | gradient_clip_val: 0.1 -------------------------------------------------------------------------------- /data/train_name.txt: -------------------------------------------------------------------------------- 1 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio46_000.mp4 2 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio43_000.mp4 3 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio33_000.mp4 4 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_003.mp4 5 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio40_000.mp4 6 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio35_000.mp4 7 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio8_000.mp4 8 | 
/mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio47_000.mp4 9 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio45_000.mp4 10 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio7_000.mp4 11 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio53_000.mp4 12 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_000.mp4 13 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio12_000.mp4 14 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_004.mp4 15 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio50_000.mp4 16 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio20_000.mp4 17 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_002.mp4 18 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio41_000.mp4 19 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio13_000.mp4 20 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio26_000.mp4 21 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio30_000.mp4 22 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio42_000.mp4 23 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_001.mp4 24 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio16_000.mp4 25 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio3_000.mp4 26 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio17_000.mp4 27 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio29_000.mp4 28 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio38_000.mp4 29 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio1_000.mp4 30 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio39_000.mp4 31 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio5_000.mp4 32 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio44_000.mp4 33 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio27_000.mp4 34 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio28_000.mp4 35 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio4_000.mp4 36 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio54_000.mp4 37 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio23_000.mp4 38 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_006.mp4 39 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio37_000.mp4 40 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio21_000.mp4 41 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio36_000.mp4 42 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio22_000.mp4 43 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_009.mp4 44 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio49_000.mp4 45 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio11_000.mp4 46 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_007.mp4 47 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio19_000.mp4 48 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_005.mp4 49 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio10_000.mp4 50 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio25_000.mp4 51 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio32_000.mp4 52 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio52_000.mp4 53 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio14_000.mp4 54 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio51_000.mp4 55 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio34_008.mp4 56 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio9_000.mp4 57 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio18_000.mp4 58 | 
/mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio31_000.mp4 59 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio2_000.mp4 60 | /mnt/cfs/algorithm/public_data/HDTF/video_fps/RD_Radio11_001.mp4 61 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JebHensarling2_003.mp4 62 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JonKyl_000.mp4 63 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_GerryConnolly_000.mp4 64 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_FrankPallone1_000.mp4 65 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_HillaryClinton_000.mp4 66 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_BarbaraLee1_000.mp4 67 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_SteveDaines0_000.mp4 68 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_AmyKlobuchar1_002.mp4 69 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_MarkwayneMullin_000.mp4 70 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_StenyHoyer_000.mp4 71 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_BarackObama_001.mp4 72 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_SheldonWhitehouse0_000.mp4 73 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JohnKasich1_001.mp4 74 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JoeCrowley0_000.mp4 75 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_SaxbyChambliss_000.mp4 76 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_RichardBlumenthal_000.mp4 77 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_GregWalden1_000.mp4 78 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JoniErnst1_000.mp4 79 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_RandPaul1_000.mp4 80 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_ErikPaulsen_003.mp4 81 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JohnKasich3_000.mp4 82 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_HakeemJeffries_000.mp4 83 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JackReed0_000.mp4 84 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_DianeBlack0_000.mp4 85 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_NancyPelosi3_000.mp4 86 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_MikeJohanns_000.mp4 87 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JoeManchin_000.mp4 88 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_ChrisCoons1_000.mp4 89 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_DavidVitter_000.mp4 90 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JackyRosen_000.mp4 91 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_PatrickLeahy0_000.mp4 92 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_ErikPaulsen_002.mp4 93 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_JoePitts_000.mp4 94 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JoeCrowley1_001.mp4 95 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_TerriSewell0_000.mp4 96 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_MartinHeinrich_000.mp4 97 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WDA_JackieSpeier_000.mp4 98 | /mnt/cfs/algorithm/public_data/HDTF/video_fps_/WRA_LisaMurkowski0_000.mp4 99 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | python scripts/inference.py -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/data/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/talk_data_ref_smooth.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/data/__pycache__/talk_data_ref_smooth.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/talk_data_ref_smooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import random 8 | import pdb 9 | 10 | class TALKBase(Dataset): 11 | def __init__(self, 12 | txt_file, 13 | data_root, 14 | size=None, 15 | interpolation="bicubic", 16 | flip_p=0.5 17 | ): 18 | self.data_paths = txt_file 19 | self.data_root = data_root 20 | with open(self.data_paths, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | image_list_path = os.path.join(data_root, 'data.txt') 24 | with open(image_list_path, "r") as f: 25 | self.image_num = f.read().splitlines() 26 | 27 | self.labels = { 28 | "frame_id": [int(l.split('_')[0]) for l in self.image_paths], 29 | "image_path_": [os.path.join(self.data_root, 'images', l+'.jpg') for l in self.image_paths], 30 | "audio_smooth_path_": [os.path.join(self.data_root, 'audio_smooth', l + '.npy') for l in self.image_paths], 31 
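# "landmark_path_" below points to the per-frame facial-landmark file, and "reference_path"
# picks a random frame from the same video that lies at least 60 frames away from the current
# one; it is loaded later as example["reference_img"], the identity reference.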
| "landmark_path_": [os.path.join(self.data_root, 'landmarks', l+'.lms') for l in self.image_paths], 32 | "reference_path": [l.split('_')[0] + '_' + str(random.choice(list(set(range(1, int(self.image_num[int(l.split('_')[0])-1].split()[1])))-set(range(int(l.split('_')[1])-60, int(l.split('_')[1])+60))))) 33 | for l in self.image_paths], 34 | } 35 | 36 | self.size = size 37 | self.interpolation = {"linear": PIL.Image.LINEAR, 38 | "bilinear": PIL.Image.BILINEAR, 39 | "bicubic": PIL.Image.BICUBIC, 40 | "lanczos": PIL.Image.LANCZOS, 41 | }[interpolation] 42 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 43 | 44 | def __len__(self): 45 | return self._length 46 | 47 | def __getitem__(self, i): 48 | example = dict((k, self.labels[k][i]) for k in self.labels) 49 | 50 | image = Image.open(example["image_path_"]) 51 | if not image.mode == "RGB": 52 | image = image.convert("RGB") 53 | 54 | # default to score-sde preprocessing 55 | img = np.array(image).astype(np.uint8) 56 | image = Image.fromarray(img) 57 | h, w = image.size 58 | if self.size is not None: 59 | image = image.resize((self.size, self.size), resample=self.interpolation) 60 | 61 | image = np.array(image).astype(np.uint8) 62 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 63 | 64 | landmarks = np.loadtxt(example["landmark_path_"], dtype=np.float32) 65 | landmarks_img = landmarks[13:48] 66 | landmarks_img2 = landmarks[0:4] 67 | landmarks_img = np.concatenate((landmarks_img2, landmarks_img)) 68 | scaler = h / self.size 69 | example["landmarks"] = (landmarks_img / scaler) 70 | 71 | #mask 72 | mask = np.ones((self.size, self.size)) 73 | mask[(landmarks[30][1] / scaler).astype(int):, :] = 0. 74 | mask = mask[..., None] 75 | image_mask = (image * mask).astype(np.uint8) 76 | example["image_mask"] = (image_mask / 127.5 - 1.0).astype(np.float32) 77 | 78 | example["audio_smooth"] = np.load(example["audio_smooth_path_"]).astype(np.float32) 79 | 80 | #add for reference 81 | image_r = Image.open(os.path.join(self.data_root, 'images', example["reference_path"] +'.jpg')) 82 | if not image_r.mode == "RGB": 83 | image_r = image_r.convert("RGB") 84 | 85 | img_r = np.array(image_r).astype(np.uint8) 86 | image_r = Image.fromarray(img_r) 87 | image_r = image_r.resize((self.size, self.size), resample=self.interpolation) 88 | image_r = np.array(image_r).astype(np.uint8) 89 | example["reference_img"] = (image_r / 127.5 - 1.0).astype(np.float32) 90 | 91 | return example 92 | 93 | 94 | class TalkTrain(TALKBase): 95 | def __init__(self, **kwargs): 96 | super().__init__(txt_file="./data/data_train.txt", data_root="./data/HDTF", **kwargs) 97 | 98 | class TalkValidation(TALKBase): 99 | def __init__(self, flip_p=0., **kwargs): 100 | super().__init__(txt_file="./data/data_test.txt", data_root="./data/HDTF", flip_p=flip_p, **kwargs) 101 | -------------------------------------------------------------------------------- /ldm/data/talk_data_ref_smooth_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import random 8 | import cv2 9 | 10 | 11 | class TALKBase(Dataset): 12 | def __init__(self, 13 | txt_file, 14 | data_root, 15 | size=None, 16 | interpolation="bicubic", 17 | flip_p=0.5 18 | ): 19 | self.data_paths = txt_file 20 | self.data_root = data_root 21 | with open(self.data_paths, "r") as f: 22 | self.image_paths = f.read().splitlines() 23 | 
self._length = len(self.image_paths) 24 | image_list_path = os.path.join(data_root, 'data.txt') 25 | with open(image_list_path, "r") as f: 26 | self.image_num = f.read().splitlines() 27 | 28 | self.labels = { 29 | "frame_id": [int(l.split('_')[0]) for l in self.image_paths], 30 | "image_path_": [os.path.join(self.data_root, 'images', l+'.jpg') for l in self.image_paths], 31 | "audio_smooth_path_": [os.path.join(self.data_root, 'audio_smooth', '105_' + l.split('_')[1] + '.npy') for l in self.image_paths], 32 | "landmark_path_": [os.path.join(self.data_root, 'landmarks', l+'.lms') for l in self.image_paths], 33 | "reference_path": [l.split('_')[0] + '_' + str(random.choice(list(set(range(1, int(self.image_num[int(l.split('_')[0])-1].split()[1])))-set(range(int(l.split('_')[1])-60, int(l.split('_')[1])+60))))) 34 | for l in self.image_paths], 35 | } 36 | 37 | self.size = size 38 | self.interpolation = {"linear": PIL.Image.LINEAR, 39 | "bilinear": PIL.Image.BILINEAR, 40 | "bicubic": PIL.Image.BICUBIC, 41 | "lanczos": PIL.Image.LANCZOS, 42 | }[interpolation] 43 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 44 | 45 | def __len__(self): 46 | return self._length 47 | 48 | def __getitem__(self, i): 49 | example = dict((k, self.labels[k][i]) for k in self.labels) 50 | 51 | image = Image.open(example["image_path_"]) 52 | if not image.mode == "RGB": 53 | image = image.convert("RGB") 54 | 55 | img = np.array(image).astype(np.uint8) 56 | image = Image.fromarray(img) 57 | h, w = image.size 58 | if self.size is not None: 59 | image = image.resize((self.size, self.size), resample=self.interpolation) 60 | image2 = image.resize((64, 64), resample=PIL.Image.BICUBIC) 61 | 62 | image = np.array(image).astype(np.uint8) 63 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 64 | 65 | landmarks = np.loadtxt(example["landmark_path_"], dtype=np.float32) 66 | landmarks_img = landmarks[13:48] 67 | landmarks_img2 = landmarks[0:4] 68 | landmarks_img = np.concatenate((landmarks_img2, landmarks_img)) 69 | scaler = h / self.size 70 | example["landmarks"] = (landmarks_img / scaler) 71 | example["landmarks_all"] = (landmarks / scaler) 72 | example["scaler"] = scaler 73 | 74 | #inference mask 75 | inference_mask = np.ones((h, w)) 76 | points = landmarks[2:15] 77 | points = np.concatenate((points, landmarks[33:34])).astype('int32') 78 | inference_mask = cv2.fillPoly(inference_mask, pts=[points], color=(0, 0, 0)) 79 | inference_mask = (inference_mask > 0).astype(int) 80 | inference_mask = Image.fromarray(inference_mask.astype(np.uint8)) 81 | inference_mask = inference_mask.resize((64, 64), resample=self.interpolation) 82 | inference_mask = np.array(inference_mask) 83 | example["inference_mask"] = inference_mask 84 | 85 | #mask 86 | mask = np.ones((self.size, self.size)) 87 | # zeros will be filled in 88 | mask[(landmarks[33][1] / scaler).astype(int):, :] = 0. 
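# Assuming the standard 68-point landmark ordering, index 33 is the nose tip, so every image row
# below the nose was just zeroed; the lines below broadcast this mask over the RGB channels to
# build the masked conditioning image whose lower face the model has to re-synthesize.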
89 | mask = mask[..., None] 90 | image_mask = (image * mask).astype(np.uint8) 91 | example["image_mask"] = (image_mask / 127.5 - 1.0).astype(np.float32) 92 | 93 | example["audio_smooth"] = np.load(example["audio_smooth_path_"]) .astype(np.float32) 94 | 95 | reference_path = example["reference_path"].split('_')[0] 96 | image_r = Image.open(os.path.join(self.data_root, 'images', reference_path + '_1.jpg')) 97 | if not image_r.mode == "RGB": 98 | image_r = image_r.convert("RGB") 99 | 100 | img_r = np.array(image_r).astype(np.uint8) 101 | image_r = Image.fromarray(img_r) 102 | image_r = image_r.resize((self.size, self.size), resample=self.interpolation) 103 | image_r = np.array(image_r).astype(np.uint8) 104 | example["reference_img"] = (image_r / 127.5 - 1.0).astype(np.float32) 105 | 106 | return example 107 | 108 | 109 | class TalkTrain(TALKBase): 110 | def __init__(self, **kwargs): 111 | super().__init__(txt_file="./data/data_train.txt", data_root="./data/HDTF", **kwargs) 112 | 113 | 114 | class TalkValidation(TALKBase): 115 | def __init__(self, flip_p=0., **kwargs): 116 | super().__init__(txt_file="./data/data_test.txt", data_root="./data/HDTF", 117 | flip_p=flip_p, **kwargs) -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
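# schedule(n) below: within each cycle the multiplier ramps linearly from f_start to f_max over
# warm_up_steps, then follows a cosine decay from f_max down to f_min for the rest of the cycle;
# it is meant to be wrapped in torch.optim.lr_scheduler.LambdaLR with a base learning rate of 1.0.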
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/models/__pycache__/autoencoder.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 7 | 8 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 9 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 10 | 11 | from ldm.util import instantiate_from_config 12 | 13 | 14 | class VQModel(pl.LightningModule): 15 | def __init__(self, 16 | ddconfig, 17 | lossconfig, 18 | n_embed, 19 | embed_dim, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | image_key="image", 23 | colorize_nlabels=None, 24 | monitor=None, 25 | batch_resize_range=None, 26 | scheduler_config=None, 27 | lr_g_factor=1.0, 28 | remap=None, 29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 30 | use_ema=False 31 | ): 32 | super().__init__() 33 | self.embed_dim = embed_dim 34 | self.n_embed = n_embed 35 | self.image_key = image_key 36 | self.encoder = Encoder(**ddconfig) 37 | self.decoder = Decoder(**ddconfig) 38 | self.loss = instantiate_from_config(lossconfig) 39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 40 | remap=remap, 41 | 
sane_index_shape=sane_index_shape) 42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 44 | if colorize_nlabels is not None: 45 | assert type(colorize_nlabels)==int 46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 47 | if monitor is not None: 48 | self.monitor = monitor 49 | self.batch_resize_range = batch_resize_range 50 | if self.batch_resize_range is not None: 51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") 52 | 53 | self.use_ema = use_ema 54 | if self.use_ema: 55 | self.model_ema = LitEma(self) 56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 57 | 58 | if ckpt_path is not None: 59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 60 | self.scheduler_config = scheduler_config 61 | self.lr_g_factor = lr_g_factor 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu")["state_dict"] 80 | keys = list(sd.keys()) 81 | for k in keys: 82 | for ik in ignore_keys: 83 | if k.startswith(ik): 84 | print("Deleting key {} from state_dict.".format(k)) 85 | del sd[k] 86 | missing, unexpected = self.load_state_dict(sd, strict=False) 87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 88 | if len(missing) > 0: 89 | print(f"Missing Keys: {missing}") 90 | print(f"Unexpected Keys: {unexpected}") 91 | 92 | def on_train_batch_end(self, *args, **kwargs): 93 | if self.use_ema: 94 | self.model_ema(self) 95 | 96 | def encode(self, x): 97 | h = self.encoder(x) 98 | h = self.quant_conv(h) 99 | quant, emb_loss, info = self.quantize(h) 100 | return quant, emb_loss, info 101 | 102 | def encode_to_prequant(self, x): 103 | h = self.encoder(x) 104 | h = self.quant_conv(h) 105 | return h 106 | 107 | def decode(self, quant): 108 | quant = self.post_quant_conv(quant) 109 | dec = self.decoder(quant) 110 | return dec 111 | 112 | def decode_code(self, code_b): 113 | quant_b = self.quantize.embed_code(code_b) 114 | dec = self.decode(quant_b) 115 | return dec 116 | 117 | def forward(self, input, return_pred_indices=False): 118 | quant, diff, (_,_,ind) = self.encode(input) 119 | dec = self.decode(quant) 120 | if return_pred_indices: 121 | return dec, diff, ind 122 | return dec, diff 123 | 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 129 | if self.batch_resize_range is not None: 130 | lower_size = self.batch_resize_range[0] 131 | upper_size = self.batch_resize_range[1] 132 | if self.global_step <= 4: 133 | # do the first few batches with max size to avoid later oom 134 | new_resize = upper_size 135 | else: 136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) 137 | if new_resize != x.shape[2]: 138 | x = F.interpolate(x, size=new_resize, mode="bicubic") 139 | x = x.detach() 140 | return x 141 | 142 | def training_step(self, 
batch, batch_idx, optimizer_idx): 143 | x = self.get_input(batch, self.image_key) 144 | xrec, qloss, ind = self(x, return_pred_indices=True) 145 | 146 | if optimizer_idx == 0: 147 | # autoencode 148 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 149 | last_layer=self.get_last_layer(), split="train", 150 | predicted_indices=ind) 151 | 152 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 153 | return aeloss 154 | 155 | if optimizer_idx == 1: 156 | # discriminator 157 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 158 | last_layer=self.get_last_layer(), split="train") 159 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 160 | return discloss 161 | 162 | def validation_step(self, batch, batch_idx): 163 | log_dict = self._validation_step(batch, batch_idx) 164 | with self.ema_scope(): 165 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") 166 | return log_dict 167 | 168 | def _validation_step(self, batch, batch_idx, suffix=""): 169 | x = self.get_input(batch, self.image_key) 170 | xrec, qloss, ind = self(x, return_pred_indices=True) 171 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, 172 | self.global_step, 173 | last_layer=self.get_last_layer(), 174 | split="val"+suffix, 175 | predicted_indices=ind 176 | ) 177 | 178 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, 179 | self.global_step, 180 | last_layer=self.get_last_layer(), 181 | split="val"+suffix, 182 | predicted_indices=ind 183 | ) 184 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] 185 | self.log(f"val{suffix}/rec_loss", rec_loss, 186 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 187 | self.log(f"val{suffix}/aeloss", aeloss, 188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 189 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 190 | del log_dict_ae[f"val{suffix}/rec_loss"] 191 | self.log_dict(log_dict_ae) 192 | self.log_dict(log_dict_disc) 193 | return self.log_dict 194 | 195 | def configure_optimizers(self): 196 | lr_d = self.learning_rate 197 | lr_g = self.lr_g_factor*self.learning_rate 198 | print("lr_d", lr_d) 199 | print("lr_g", lr_g) 200 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 201 | list(self.decoder.parameters())+ 202 | list(self.quantize.parameters())+ 203 | list(self.quant_conv.parameters())+ 204 | list(self.post_quant_conv.parameters()), 205 | lr=lr_g, betas=(0.5, 0.9)) 206 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 207 | lr=lr_d, betas=(0.5, 0.9)) 208 | 209 | if self.scheduler_config is not None: 210 | scheduler = instantiate_from_config(self.scheduler_config) 211 | 212 | print("Setting up LambdaLR scheduler...") 213 | scheduler = [ 214 | { 215 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 216 | 'interval': 'step', 217 | 'frequency': 1 218 | }, 219 | { 220 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 221 | 'interval': 'step', 222 | 'frequency': 1 223 | }, 224 | ] 225 | return [opt_ae, opt_disc], scheduler 226 | return [opt_ae, opt_disc], [] 227 | 228 | def get_last_layer(self): 229 | return self.decoder.conv_out.weight 230 | 231 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): 232 | log = dict() 233 | x = self.get_input(batch, self.image_key) 234 | x = x.to(self.device) 235 | if only_inputs: 236 | log["inputs"] = x 237 | return log 238 | xrec, _ = self(x) 239 | 
if x.shape[1] > 3: 240 | # colorize with random projection 241 | assert xrec.shape[1] > 3 242 | x = self.to_rgb(x) 243 | xrec = self.to_rgb(xrec) 244 | log["inputs"] = x 245 | log["reconstructions"] = xrec 246 | if plot_ema: 247 | with self.ema_scope(): 248 | xrec_ema, _ = self(x) 249 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) 250 | log["reconstructions_ema"] = xrec_ema 251 | return log 252 | 253 | def to_rgb(self, x): 254 | assert self.image_key == "segmentation" 255 | if not hasattr(self, "colorize"): 256 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 257 | x = F.conv2d(x, weight=self.colorize) 258 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 259 | return x 260 | 261 | 262 | class VQModelInterface(VQModel): 263 | def __init__(self, embed_dim, *args, **kwargs): 264 | super().__init__(embed_dim=embed_dim, *args, **kwargs) 265 | self.embed_dim = embed_dim 266 | 267 | def encode(self, x): 268 | h = self.encoder(x) 269 | h = self.quant_conv(h) 270 | return h 271 | 272 | def decode(self, h, force_not_quantize=False): 273 | # also go through quantization layer 274 | if not force_not_quantize: 275 | quant, emb_loss, info = self.quantize(h) 276 | else: 277 | quant = h 278 | quant = self.post_quant_conv(quant) 279 | dec = self.decoder(quant) 280 | return dec 281 | 282 | 283 | class AutoencoderKL(pl.LightningModule): 284 | def __init__(self, 285 | ddconfig, 286 | lossconfig, 287 | embed_dim, 288 | ckpt_path=None, 289 | ignore_keys=[], 290 | image_key="image", 291 | colorize_nlabels=None, 292 | monitor=None, 293 | ): 294 | super().__init__() 295 | self.image_key = image_key 296 | self.encoder = Encoder(**ddconfig) 297 | self.decoder = Decoder(**ddconfig) 298 | self.loss = instantiate_from_config(lossconfig) 299 | assert ddconfig["double_z"] 300 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 301 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 302 | self.embed_dim = embed_dim 303 | if colorize_nlabels is not None: 304 | assert type(colorize_nlabels)==int 305 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 306 | if monitor is not None: 307 | self.monitor = monitor 308 | if ckpt_path is not None: 309 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 310 | 311 | def init_from_ckpt(self, path, ignore_keys=list()): 312 | sd = torch.load(path, map_location="cpu")["state_dict"] 313 | keys = list(sd.keys()) 314 | for k in keys: 315 | for ik in ignore_keys: 316 | if k.startswith(ik): 317 | print("Deleting key {} from state_dict.".format(k)) 318 | del sd[k] 319 | self.load_state_dict(sd, strict=False) 320 | print(f"Restored from {path}") 321 | 322 | def encode(self, x): 323 | h = self.encoder(x) 324 | moments = self.quant_conv(h) 325 | posterior = DiagonalGaussianDistribution(moments) 326 | return posterior 327 | 328 | def decode(self, z): 329 | z = self.post_quant_conv(z) 330 | dec = self.decoder(z) 331 | return dec 332 | 333 | def forward(self, input, sample_posterior=True): 334 | posterior = self.encode(input) 335 | if sample_posterior: 336 | z = posterior.sample() 337 | else: 338 | z = posterior.mode() 339 | dec = self.decode(z) 340 | return dec, posterior 341 | 342 | def get_input(self, batch, k): 343 | x = batch[k] 344 | if len(x.shape) == 3: 345 | x = x[..., None] 346 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 347 | return x 348 | 349 | def training_step(self, batch, batch_idx, optimizer_idx): 350 | inputs = self.get_input(batch, 
self.image_key) 351 | reconstructions, posterior = self(inputs) 352 | 353 | if optimizer_idx == 0: 354 | # train encoder+decoder+logvar 355 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 356 | last_layer=self.get_last_layer(), split="train") 357 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 358 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 359 | return aeloss 360 | 361 | if optimizer_idx == 1: 362 | # train the discriminator 363 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 364 | last_layer=self.get_last_layer(), split="train") 365 | 366 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 367 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 368 | return discloss 369 | 370 | def validation_step(self, batch, batch_idx): 371 | inputs = self.get_input(batch, self.image_key) 372 | reconstructions, posterior = self(inputs) 373 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 374 | last_layer=self.get_last_layer(), split="val") 375 | 376 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 377 | last_layer=self.get_last_layer(), split="val") 378 | 379 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 380 | self.log_dict(log_dict_ae) 381 | self.log_dict(log_dict_disc) 382 | return self.log_dict 383 | 384 | def configure_optimizers(self): 385 | lr = self.learning_rate 386 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 387 | list(self.decoder.parameters())+ 388 | list(self.quant_conv.parameters())+ 389 | list(self.post_quant_conv.parameters()), 390 | lr=lr, betas=(0.5, 0.9)) 391 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 392 | lr=lr, betas=(0.5, 0.9)) 393 | return [opt_ae, opt_disc], [] 394 | 395 | def get_last_layer(self): 396 | return self.decoder.conv_out.weight 397 | 398 | @torch.no_grad() 399 | def log_images(self, batch, only_inputs=False, **kwargs): 400 | log = dict() 401 | x = self.get_input(batch, self.image_key) 402 | x = x.to(self.device) 403 | if not only_inputs: 404 | xrec, posterior = self(x) 405 | if x.shape[1] > 3: 406 | # colorize with random projection 407 | assert xrec.shape[1] > 3 408 | x = self.to_rgb(x) 409 | xrec = self.to_rgb(xrec) 410 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 411 | log["reconstructions"] = xrec 412 | log["inputs"] = x 413 | return log 414 | 415 | def to_rgb(self, x): 416 | assert self.image_key == "segmentation" 417 | if not hasattr(self, "colorize"): 418 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 419 | x = F.conv2d(x, weight=self.colorize) 420 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
421 | return x 422 | 423 | 424 | class IdentityFirstStage(torch.nn.Module): 425 | def __init__(self, *args, vq_interface=False, **kwargs): 426 | self.vq_interface = vq_interface 427 | super().__init__() 428 | 429 | def encode(self, x, *args, **kwargs): 430 | return x 431 | 432 | def decode(self, x, *args, **kwargs): 433 | return x 434 | 435 | def quantize(self, x, *args, **kwargs): 436 | if self.vq_interface: 437 | return x, None, [None, None, None] 438 | return x 439 | 440 | def forward(self, x, *args, **kwargs): 441 | return x 442 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim_ldm_ref_inpaint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/models/diffusion/__pycache__/ddim_ldm_ref_inpaint.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm_talking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/models/diffusion/__pycache__/ddpm_talking.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/ddim_ldm_ref_inpaint.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | import pdb 8 | 9 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 10 | 11 | 12 | class DDIMSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 28 | alphas_cumprod = self.model.alphas_cumprod 29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 31 | 32 | self.register_buffer('betas', to_torch(self.model.betas)) 33 | self.register_buffer('alphas_cumprod', 
to_torch(alphas_cumprod)) 34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 35 | 36 | # calculations for diffusion q(x_t | x_{t-1}) and others 37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 42 | 43 | # ddim sampling parameters 44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 45 | ddim_timesteps=self.ddim_timesteps, 46 | eta=ddim_eta,verbose=verbose) 47 | self.register_buffer('ddim_sigmas', ddim_sigmas) 48 | self.register_buffer('ddim_alphas', ddim_alphas) 49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 55 | 56 | @torch.no_grad() 57 | def sample(self, 58 | S, 59 | batch_size, 60 | shape, 61 | conditioning=None, 62 | callback=None, 63 | normals_sequence=None, 64 | img_callback=None, 65 | quantize_x0=False, 66 | eta=0., 67 | mask=None, 68 | x0=None, 69 | temperature=1., 70 | noise_dropout=0., 71 | score_corrector=None, 72 | corrector_kwargs=None, 73 | verbose=True, 74 | x_T=None, 75 | log_every_t=100, 76 | unconditional_guidance_scale=1., 77 | unconditional_conditioning=None, 78 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
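# when unconditional_guidance_scale > 1 and unconditional_conditioning is given,
# p_sample_ddim applies classifier-free guidance: e_t_uncond + scale * (e_t - e_t_uncond)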
79 | **kwargs 80 | ): 81 | if conditioning is not None: 82 | if isinstance(conditioning, dict): 83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 84 | if cbs != batch_size: 85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 86 | else: 87 | if conditioning.shape[0] != batch_size: 88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 89 | 90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 91 | # sampling 92 | C, H, W = shape 93 | size = (batch_size, C, H, W) 94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 95 | 96 | samples, intermediates = self.ddim_sampling(conditioning, size, 97 | callback=callback, 98 | img_callback=img_callback, 99 | quantize_denoised=quantize_x0, 100 | mask=mask, x0=x0, 101 | ddim_use_original_steps=False, 102 | noise_dropout=noise_dropout, 103 | temperature=temperature, 104 | score_corrector=score_corrector, 105 | corrector_kwargs=corrector_kwargs, 106 | x_T=x_T, 107 | log_every_t=log_every_t, 108 | unconditional_guidance_scale=unconditional_guidance_scale, 109 | unconditional_conditioning=unconditional_conditioning, 110 | ) 111 | return samples, intermediates 112 | 113 | @torch.no_grad() 114 | def ddim_sampling(self, cond, shape, 115 | x_T=None, ddim_use_original_steps=False, 116 | callback=None, timesteps=None, quantize_denoised=False, 117 | mask=None, x0=None, img_callback=None, log_every_t=100, 118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 119 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 120 | device = self.model.betas.device 121 | b = shape[0] 122 | if x_T is None: 123 | img = torch.randn(shape, device=device) 124 | else: 125 | img = x_T 126 | 127 | if timesteps is None: 128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 129 | elif timesteps is not None and not ddim_use_original_steps: 130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 131 | timesteps = self.ddim_timesteps[:subset_end] 132 | 133 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 136 | print(f"Running DDIM Sampling with {total_steps} timesteps") 137 | 138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 139 | 140 | for i, step in enumerate(iterator): 141 | index = total_steps - i - 1 142 | ts = torch.full((b,), step, device=device, dtype=torch.long) 143 | 144 | if mask is not None: 145 | assert x0 is not None 146 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 147 | img = img_orig * mask + (1. 
- mask) * img 148 | 149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 150 | quantize_denoised=quantize_denoised, temperature=temperature, 151 | noise_dropout=noise_dropout, score_corrector=score_corrector, 152 | corrector_kwargs=corrector_kwargs, 153 | unconditional_guidance_scale=unconditional_guidance_scale, 154 | unconditional_conditioning=unconditional_conditioning) 155 | img, pred_x0 = outs 156 | if callback: callback(i) 157 | if img_callback: img_callback(pred_x0, i) 158 | 159 | if index % log_every_t == 0 or index == total_steps - 1: 160 | intermediates['x_inter'].append(img) 161 | intermediates['pred_x0'].append(pred_x0) 162 | 163 | return img, intermediates 164 | 165 | @torch.no_grad() 166 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 167 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 168 | unconditional_guidance_scale=1., unconditional_conditioning=None): 169 | b, *_, device = *x.shape, x.device 170 | 171 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 172 | cond_audio = c['audio'] 173 | cond_lip = c['lip'] 174 | cond_ldm = c['ldm'] 175 | cond_mask = c['mask_image'] 176 | e_t = self.model.apply_model(x, t, [cond_audio, cond_lip, cond_ldm, cond_mask]) 177 | else: 178 | x_in = torch.cat([x] * 2) 179 | t_in = torch.cat([t] * 2) 180 | c_in = torch.cat([unconditional_conditioning, c]) 181 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 182 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 183 | 184 | if score_corrector is not None: 185 | assert self.model.parameterization == "eps" 186 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 187 | 188 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 189 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 190 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 191 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 192 | # select parameters corresponding to the currently considered timestep 193 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 194 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 195 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 196 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 197 | 198 | # current prediction for x_0 199 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 200 | if quantize_denoised: 201 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 202 | # direction pointing to x_t 203 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 204 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 205 | if noise_dropout > 0.: 206 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 207 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 208 | return x_prev, pred_x0 209 | -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/__pycache__/ema.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/__pycache__/x_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
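(Added note) In this file zero_module is applied to the SpatialTransformer's output projection, so a freshly initialised transformer block starts out as an identity mapping.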
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads#14 172 | 173 | q = self.to_q(x)#x:[10, 1024, 448], q:[10, 1024, 448] 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc 
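(Added note, not part of the repository) A minimal smoke test for the CrossAttention and SpatialTransformer modules defined in ldm/modules/attention.py above; it assumes the repo's requirements are installed and the repository root is on PYTHONPATH, and all tensor sizes below are purely illustrative:

```
import torch
from ldm.modules.attention import CrossAttention, SpatialTransformer

x = torch.randn(2, 64, 8, 8)    # dummy feature map: (batch, channels, height, width)
ctx = torch.randn(2, 2, 128)    # dummy conditioning tokens: (batch, tokens, context_dim)

# Cross-attention over the flattened spatial positions.
attn = CrossAttention(query_dim=64, context_dim=128, heads=4, dim_head=16)
tokens = x.flatten(2).transpose(1, 2)        # (batch, h*w, channels)
print(attn(tokens, context=ctx).shape)       # torch.Size([2, 64, 64])

# Full block as used inside the UNet: norm -> proj_in -> attention -> zero-init proj_out (+ residual).
st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=128)
print(st(x, context=ctx).shape)              # torch.Size([2, 64, 8, 8])
```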
-------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | 8 | from ldm.util import instantiate_from_config 9 | 10 | 11 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 12 | if schedule == "linear": 13 | betas = ( 14 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 15 | ) 16 | 17 | elif schedule == "cosine": 18 | timesteps = ( 19 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 20 | ) 21 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 22 | alphas = torch.cos(alphas).pow(2) 23 | alphas = alphas / alphas[0] 24 | betas = 1 - alphas[1:] / alphas[:-1] 25 | betas = np.clip(betas, a_min=0, a_max=0.999) 26 | 27 | elif schedule == "sqrt_linear": 28 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 29 | elif schedule == "sqrt": 30 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 31 | else: 32 | raise ValueError(f"schedule '{schedule}' unknown.") 33 | return betas.numpy() 34 | 35 | 36 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 37 | if ddim_discr_method == 'uniform': 38 | c = num_ddpm_timesteps // num_ddim_timesteps 39 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 40 | elif ddim_discr_method == 'quad': 41 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 42 | else: 43 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 44 | 45 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 46 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 47 | steps_out = ddim_timesteps + 1 48 | if verbose: 49 | print(f'Selected timesteps for ddim sampler: {steps_out}') 50 | return steps_out 51 | 52 | 53 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 54 | # select alphas for computing the variance schedule 55 | alphas = alphacums[ddim_timesteps] 56 | alphas_prev = 
np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 57 | 58 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 59 | if verbose: 60 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 61 | print(f'For the chosen value of eta, which is {eta}, ' 62 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 63 | return sigmas, alphas, alphas_prev 64 | 65 | 66 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 67 | """ 68 | Create a beta schedule that discretizes the given alpha_t_bar function, 69 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 70 | :param num_diffusion_timesteps: the number of betas to produce. 71 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 72 | produces the cumulative product of (1-beta) up to that 73 | part of the diffusion process. 74 | :param max_beta: the maximum beta to use; use values lower than 1 to 75 | prevent singularities. 76 | """ 77 | betas = [] 78 | for i in range(num_diffusion_timesteps): 79 | t1 = i / num_diffusion_timesteps 80 | t2 = (i + 1) / num_diffusion_timesteps 81 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 82 | return np.array(betas) 83 | 84 | 85 | def extract_into_tensor(a, t, x_shape): 86 | b, *_ = t.shape 87 | out = a.gather(-1, t) 88 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 89 | 90 | 91 | def checkpoint(func, inputs, params, flag): 92 | """ 93 | Evaluate a function without caching intermediate activations, allowing for 94 | reduced memory at the expense of extra compute in the backward pass. 95 | :param func: the function to evaluate. 96 | :param inputs: the argument sequence to pass to `func`. 97 | :param params: a sequence of parameters `func` depends on but does not 98 | explicitly take as arguments. 99 | :param flag: if False, disable gradient checkpointing. 100 | """ 101 | if flag: 102 | args = tuple(inputs) + tuple(params) 103 | return CheckpointFunction.apply(func, len(inputs), *args) 104 | else: 105 | return func(*inputs) 106 | 107 | 108 | class CheckpointFunction(torch.autograd.Function): 109 | @staticmethod 110 | def forward(ctx, run_function, length, *args): 111 | ctx.run_function = run_function 112 | ctx.input_tensors = list(args[:length]) 113 | ctx.input_params = list(args[length:]) 114 | 115 | with torch.no_grad(): 116 | output_tensors = ctx.run_function(*ctx.input_tensors) 117 | return output_tensors 118 | 119 | @staticmethod 120 | def backward(ctx, *output_grads): 121 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 122 | with torch.enable_grad(): 123 | # Fixes a bug where the first op in run_function modifies the 124 | # Tensor storage in place, which is not allowed for detach()'d 125 | # Tensors. 126 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 127 | output_tensors = ctx.run_function(*shallow_copies) 128 | input_grads = torch.autograd.grad( 129 | output_tensors, 130 | ctx.input_tensors + ctx.input_params, 131 | output_grads, 132 | allow_unused=True, 133 | ) 134 | del ctx.input_tensors 135 | del ctx.input_params 136 | del output_tensors 137 | return (None, None) + input_grads 138 | 139 | 140 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 
145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an [N x dim] Tensor of positional embeddings. 148 | """ 149 | if not repeat_only: 150 | half = dim // 2 151 | freqs = torch.exp( 152 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 153 | ).to(device=timesteps.device) 154 | args = timesteps[:, None].float() * freqs[None] 155 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 156 | if dim % 2: 157 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 158 | else: 159 | embedding = repeat(timesteps, 'b -> b d', d=dim) 160 | return embedding 161 | 162 | 163 | def zero_module(module): 164 | """ 165 | Zero out the parameters of a module and return it. 166 | """ 167 | for p in module.parameters(): 168 | p.detach().zero_() 169 | return module 170 | 171 | 172 | def scale_module(module, scale): 173 | """ 174 | Scale the parameters of a module and return it. 175 | """ 176 | for p in module.parameters(): 177 | p.detach().mul_(scale) 178 | return module 179 | 180 | 181 | def mean_flat(tensor): 182 | """ 183 | Take the mean over all non-batch dimensions. 184 | """ 185 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 186 | 187 | 188 | def normalization(channels): 189 | """ 190 | Make a standard normalization layer. 191 | :param channels: number of input channels. 192 | :return: an nn.Module for normalization. 193 | """ 194 | return GroupNorm32(32, channels) 195 | 196 | 197 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 198 | class SiLU(nn.Module): 199 | def forward(self, x): 200 | return x * torch.sigmoid(x) 201 | 202 | 203 | class GroupNorm32(nn.GroupNorm): 204 | def forward(self, x): 205 | return super().forward(x.float()).type(x.dtype) 206 | 207 | def conv_nd(dims, *args, **kwargs): 208 | """ 209 | Create a 1D, 2D, or 3D convolution module. 210 | """ 211 | if dims == 1: 212 | return nn.Conv1d(*args, **kwargs) 213 | elif dims == 2: 214 | return nn.Conv2d(*args, **kwargs) 215 | elif dims == 3: 216 | return nn.Conv3d(*args, **kwargs) 217 | raise ValueError(f"unsupported dimensions: {dims}") 218 | 219 | 220 | def linear(*args, **kwargs): 221 | """ 222 | Create a linear module. 223 | """ 224 | return nn.Linear(*args, **kwargs) 225 | 226 | 227 | def avg_pool_nd(dims, *args, **kwargs): 228 | """ 229 | Create a 1D, 2D, or 3D average pooling module. 
230 | """ 231 | if dims == 1: 232 | return nn.AvgPool1d(*args, **kwargs) 233 | elif dims == 2: 234 | return nn.AvgPool2d(*args, **kwargs) 235 | elif dims == 3: 236 | return nn.AvgPool3d(*args, **kwargs) 237 | raise ValueError(f"unsupported dimensions: {dims}") 238 | 239 | 240 | class HybridConditioner(nn.Module): 241 | 242 | def __init__(self, c_concat_config, c_crossattn_config): 243 | super().__init__() 244 | self.concat_conditioner = instantiate_from_config(c_concat_config) 245 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 246 | 247 | def forward(self, c_concat, c_crossattn): 248 | c_concat = self.concat_conditioner(c_concat) 249 | c_crossattn = self.crossattn_conditioner(c_crossattn) 250 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 251 | 252 | 253 | def noise_like(shape, device, repeat=False): 254 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 255 | noise = lambda: torch.randn(shape, device=device) 256 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 
43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | 67 | tensor = None 68 | for obj in (mean1, logvar1, mean2, logvar2): 69 | if isinstance(obj, torch.Tensor): 70 | tensor = obj 71 | break 72 | assert tensor is not None, "at least one argument must be a Tensor" 73 | 74 | # Force variances to be Tensors. Broadcasting helps convert scalars to 75 | # Tensors, but it does not work for torch.exp(). 76 | logvar1, logvar2 = [ 77 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 78 | for x in (logvar1, logvar2) 79 | ] 80 | 81 | return 0.5 * ( 82 | -1.0 83 | + logvar2 84 | - logvar1 85 | + torch.exp(logvar1 - logvar2) 86 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 87 | ) 88 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 
61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/encoders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sstzal/DiffTalk/ce38ae67a7688ab9987dab5dfee176e748add65a/ldm/modules/encoders/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | from einops import rearrange, repeat 5 | import pdb 6 | 7 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 8 | 9 | 10 | class AbstractEncoder(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def encode(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | 18 | 19 | class ClassEmbedder(nn.Module): 20 | def __init__(self, embed_dim, n_classes=1000, key='class'): 21 | super().__init__() 22 | self.key = key 23 | self.embedding = nn.Embedding(n_classes, embed_dim) 24 | 25 | def forward(self, batch, key=None): 26 | if key is None: 27 | key = self.key 28 | # this is for use in crossattn 29 | c = batch[key][:, None] 30 | c = self.embedding(c) 31 | return c 32 | 33 | 34 | class TransformerEmbedder(AbstractEncoder): 35 | """Some transformer encoder layers""" 36 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 37 | super().__init__() 38 | self.device = device 39 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 40 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 41 | 42 | def forward(self, tokens): 43 | tokens = tokens.to(self.device) # meh 44 | z = self.transformer(tokens, return_embeddings=True) 45 | return z 46 | 47 | def encode(self, x): 48 | return self(x) 49 | 50 | 51 | class BERTTokenizer(AbstractEncoder): 52 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 53 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 54 | super().__init__() 55 | from transformers import BertTokenizerFast # TODO: add to requirements 56 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 57 | self.device = device 58 | self.vq_interface = vq_interface 59 | self.max_length = max_length 60 | 61 | def forward(self, text): 62 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 63 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 64 | tokens = batch_encoding["input_ids"].to(self.device) 65 | return tokens 66 | 67 | @torch.no_grad() 68 | def encode(self, text): 69 | tokens = self(text) 70 | if not self.vq_interface: 71 | return tokens 72 | return None, None, [None, None, tokens] 73 | 74 | def decode(self, text): 75 | return text 76 | 77 | 78 | class BERTEmbedder(AbstractEncoder): 79 | """Uses the BERT tokenizer model and adds some transformer encoder layers""" 80 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 81 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 82 | super().__init__() 83 | self.use_tknz_fn = use_tokenizer 84 | if self.use_tknz_fn: 85 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 86 | self.device = device 87 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 88 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 89 | emb_dropout=embedding_dropout) 90 | 91 | def forward(self, text): 92 | if self.use_tknz_fn: 93 | tokens = self.tknz_fn(text)#.to(self.device) 94 | else: 95 | tokens = text 96 | z = self.transformer(tokens, return_embeddings=True) 97 | return z 98 | 99 | def encode(self, text): 100 | # output of length 77 101 | return self(text) 102 | 103 | 104 | class SpatialRescaler(nn.Module): 105 | def __init__(self, 106 | n_stages=1, 107 | method='bilinear', 108 | multiplier=0.5, 109 | in_channels=3, 110 | out_channels=None, 111 | bias=False): 112 | super().__init__() 113 | self.n_stages = n_stages 114 | assert self.n_stages >= 0 115 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 116 | self.multiplier = multiplier 117 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 118 | self.remap_output = out_channels is not None 119 | if self.remap_output: 120 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 121 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 122 | 123 | def forward(self,x): 124 | for stage in range(self.n_stages): 125 | x = self.interpolator(x, scale_factor=self.multiplier) 126 | 127 | 128 | if self.remap_output: 129 | x = self.channel_mapper(x) 130 | return x 131 | 132 | def encode(self, x): 133 | return self(x) 134 | 135 | 136 | class FrozenCLIPTextEmbedder(nn.Module): 137 | """ 138 | Uses the CLIP transformer encoder for text.
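(Added note) This class and FrozenClipImageEmbedder below rely on the external `clip` and `kornia` packages, which are not imported anywhere in this file; both must be installed and imported before these classes can be used.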
139 | """ 140 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 141 | super().__init__() 142 | self.model, _ = clip.load(version, jit=False, device="cpu") 143 | self.device = device 144 | self.max_length = max_length 145 | self.n_repeat = n_repeat 146 | self.normalize = normalize 147 | 148 | def freeze(self): 149 | self.model = self.model.eval() 150 | for param in self.parameters(): 151 | param.requires_grad = False 152 | 153 | def forward(self, text): 154 | tokens = clip.tokenize(text).to(self.device) 155 | z = self.model.encode_text(tokens) 156 | if self.normalize: 157 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 158 | return z 159 | 160 | def encode(self, text): 161 | z = self(text) 162 | if z.ndim==2: 163 | z = z[:, None, :] 164 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 165 | return z 166 | 167 | 168 | class FrozenClipImageEmbedder(nn.Module): 169 | """ 170 | Uses the CLIP image encoder. 171 | """ 172 | def __init__( 173 | self, 174 | model, 175 | jit=False, 176 | device='cuda' if torch.cuda.is_available() else 'cpu', 177 | antialias=False, 178 | ): 179 | super().__init__() 180 | self.model, _ = clip.load(name=model, device=device, jit=jit) 181 | 182 | self.antialias = antialias 183 | 184 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 185 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 186 | 187 | def preprocess(self, x): 188 | # normalize to [0,1] 189 | x = kornia.geometry.resize(x, (224, 224), 190 | interpolation='bicubic',align_corners=True, 191 | antialias=self.antialias) 192 | x = (x + 1.) / 2. 193 | # renormalize according to clip 194 | x = kornia.enhance.normalize(x, self.mean, self.std) 195 | return x 196 | 197 | def forward(self, x): 198 | # x is assumed to be in range [-1,1] 199 | return self.model.encode_image(self.preprocess(x)) 200 | 201 | 202 | #audio encoder 203 | # Audio feature extractor 204 | class AudioAttNet(nn.Module): 205 | def __init__(self, dim_aud=76, seq_len=8): 206 | super(AudioAttNet, self).__init__() 207 | self.seq_len = seq_len 208 | self.dim_aud = dim_aud 209 | self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len 210 | nn.Conv1d(self.dim_aud, 16, kernel_size=3, 211 | stride=1, padding=1, bias=True), 212 | nn.LeakyReLU(0.02, True), 213 | nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), 214 | nn.LeakyReLU(0.02, True), 215 | nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), 216 | nn.LeakyReLU(0.02, True), 217 | nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), 218 | nn.LeakyReLU(0.02, True), 219 | nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), 220 | nn.LeakyReLU(0.02, True) 221 | ) 222 | self.attentionNet = nn.Sequential( 223 | nn.Linear(in_features=self.seq_len, 224 | out_features=self.seq_len, bias=True), 225 | nn.Softmax(dim=2) 226 | ) 227 | 228 | def forward(self, x): 229 | y = x[..., :self.dim_aud].permute(0, 2, 1) 230 | y = self.attentionConvNet(y) 231 | y = self.attentionNet(y) 232 | return torch.matmul(y,x).squeeze(1) 233 | 234 | 235 | # Audio feature extractor 236 | class AudioNet(nn.Module): 237 | def __init__(self, dim_aud=76, win_size=16): 238 | super(AudioNet, self).__init__() 239 | self.win_size = win_size 240 | self.dim_aud = dim_aud 241 | self.encoder_conv = nn.Sequential( 242 | nn.Conv1d(29, 32, kernel_size=3, stride=2, 243 | padding=1, bias=True), 244 | nn.LeakyReLU(0.02, True), 245 | nn.Conv1d(32, 
32, kernel_size=3, stride=2, 246 | padding=1, bias=True), 247 | nn.LeakyReLU(0.02, True), 248 | nn.Conv1d(32, 64, kernel_size=3, stride=2, 249 | padding=1, bias=True), 250 | nn.LeakyReLU(0.02, True), 251 | nn.Conv1d(64, 64, kernel_size=3, stride=2, 252 | padding=1, bias=True), 253 | nn.LeakyReLU(0.02, True), 254 | ) 255 | self.encoder_fc1 = nn.Sequential( 256 | nn.Linear(64, 64), 257 | nn.LeakyReLU(0.02, True), 258 | nn.Linear(64, dim_aud), 259 | ) 260 | 261 | def forward(self, x): 262 | half_w = int(self.win_size/2) 263 | x = x[:, 8-half_w:8+half_w, :].permute(0, 2, 1) 264 | x = self.encoder_conv(x).squeeze(-1) 265 | x = self.encoder_fc1(x).squeeze() 266 | return x 267 | 268 | 269 | #lipID encoder 270 | # Audio feature extractor 271 | class LipNet(nn.Module): 272 | def __init__(self, dim_out=64): 273 | super(LipNet, self).__init__() 274 | self.dim_out = dim_out 275 | self.encoder_conv = nn.Sequential( 276 | nn.Conv2d(3, 32, kernel_size=3, stride=2, 277 | padding=1, bias=True), 278 | nn.BatchNorm2d(32), 279 | nn.ReLU(), 280 | nn.Conv2d(32, 64, kernel_size=3, stride=2, 281 | padding=1, bias=True), 282 | nn.BatchNorm2d(64), 283 | nn.ReLU(), 284 | nn.Conv2d(64, 64, kernel_size=3, stride=1, 285 | padding=1, bias=True), 286 | nn.BatchNorm2d(64), 287 | nn.ReLU(), 288 | nn.Conv2d(64, 64, kernel_size=3, stride=2, 289 | padding=1, bias=True), 290 | nn.BatchNorm2d(64), 291 | nn.ReLU(), 292 | nn.AvgPool2d(3, stride=2), 293 | ) 294 | self.encoder_fc1 = nn.Sequential( 295 | nn.Linear(64*2*4, 128), 296 | nn.LeakyReLU(0.02, True), 297 | nn.Linear(128, dim_out), 298 | ) 299 | 300 | def forward(self, x): 301 | x = self.encoder_conv(x) 302 | x = x.reshape(x.shape[0], -1) 303 | x = self.encoder_fc1(x).squeeze() 304 | return x 305 | 306 | 307 | 308 | #landmark encoder 309 | #landmarks feature extractor 310 | class LdmNet(nn.Module): 311 | def __init__(self, dim_out=64): 312 | super(LdmNet, self).__init__() 313 | self.dim_out = dim_out 314 | self.encoder_fc1 = nn.Sequential( 315 | nn.Linear(39*2, 128), 316 | nn.LeakyReLU(0.02, True), 317 | nn.Linear(128, dim_out), 318 | nn.LeakyReLU(0.02, True), 319 | nn.Linear(dim_out, dim_out), 320 | ) 321 | 322 | def forward(self, x): 323 | x = x.reshape(x.shape[0], -1) 324 | x = self.encoder_fc1(x) 325 | return x 326 | 327 | 328 | class LatentCode(nn.Module): 329 | def __init__(self, class_length=176): 330 | super(LatentCode, self).__init__() 331 | init_l = torch.zeros(class_length, 32) 332 | self.latent_code = nn.Parameter(init_l) 333 | 334 | def forward(self, class_id): 335 | code = self.latent_code[class_id] 336 | return code -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/bsrgan.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import cv2 5 | import torch 6 | 7 | from functools import partial 8 | import random 9 | from scipy import ndimage 10 | import scipy 11 | import scipy.stats as ss 12 | from scipy.interpolate import interp2d 13 | from scipy.linalg import orth 14 | import albumentations 15 | 
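# (added note) the cv2 / scipy / albumentations imports above serve this BSRGAN-style degradation pipeline, carried over from the latent-diffusion codebase this repository builds on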
16 | import ldm.modules.image_degradation.utils_image as util 17 | 18 | 19 | def modcrop_np(img, sf): 20 | ''' 21 | Args: 22 | img: numpy image, WxH or WxHxC 23 | sf: scale factor 24 | Return: 25 | cropped image 26 | ''' 27 | w, h = img.shape[:2] 28 | im = np.copy(img) 29 | return im[:w - w % sf, :h - h % sf, ...] 30 | 31 | 32 | 33 | 34 | def analytic_kernel(k): 35 | """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" 36 | k_size = k.shape[0] 37 | # Calculate the big kernels size 38 | big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) 39 | # Loop over the small kernel to fill the big one 40 | for r in range(k_size): 41 | for c in range(k_size): 42 | big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k 43 | # Crop the edges of the big kernel to ignore very small values and increase run time of SR 44 | crop = k_size // 2 45 | cropped_big_k = big_k[crop:-crop, crop:-crop] 46 | # Normalize to 1 47 | return cropped_big_k / cropped_big_k.sum() 48 | 49 | 50 | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): 51 | """ generate an anisotropic Gaussian kernel 52 | Args: 53 | ksize : e.g., 15, kernel size 54 | theta : [0, pi], rotation angle range 55 | l1 : [0.1,50], scaling of eigenvalues 56 | l2 : [0.1,l1], scaling of eigenvalues 57 | If l1 = l2, will get an isotropic Gaussian kernel. 58 | Returns: 59 | k : kernel 60 | """ 61 | 62 | v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) 63 | V = np.array([[v[0], v[1]], [v[1], -v[0]]]) 64 | D = np.array([[l1, 0], [0, l2]]) 65 | Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) 66 | k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) 67 | 68 | return k 69 | 70 | 71 | def gm_blur_kernel(mean, cov, size=15): 72 | center = size / 2.0 + 0.5 73 | k = np.zeros([size, size]) 74 | for y in range(size): 75 | for x in range(size): 76 | cy = y - center + 1 77 | cx = x - center + 1 78 | k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) 79 | 80 | k = k / np.sum(k) 81 | return k 82 | 83 | 84 | def shift_pixel(x, sf, upper_left=True): 85 | """shift pixel for super-resolution with different scale factors 86 | Args: 87 | x: WxHxC or WxH 88 | sf: scale factor 89 | upper_left: shift direction 90 | """ 91 | h, w = x.shape[:2] 92 | shift = (sf - 1) * 0.5 93 | xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) 94 | if upper_left: 95 | x1 = xv + shift 96 | y1 = yv + shift 97 | else: 98 | x1 = xv - shift 99 | y1 = yv - shift 100 | 101 | x1 = np.clip(x1, 0, w - 1) 102 | y1 = np.clip(y1, 0, h - 1) 103 | 104 | if x.ndim == 2: 105 | x = interp2d(xv, yv, x)(x1, y1) 106 | if x.ndim == 3: 107 | for i in range(x.shape[-1]): 108 | x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) 109 | 110 | return x 111 | 112 | 113 | def blur(x, k): 114 | ''' 115 | x: image, NxcxHxW 116 | k: kernel, Nx1xhxw 117 | ''' 118 | n, c = x.shape[:2] 119 | p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 120 | x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') 121 | k = k.repeat(1, c, 1, 1) 122 | k = k.view(-1, 1, k.shape[2], k.shape[3]) 123 | x = x.view(1, -1, x.shape[2], x.shape[3]) 124 | x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) 125 | x = x.view(n, c, x.shape[2], x.shape[3]) 126 | 127 | return x 128 | 129 | 130 | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): 131 | 132 | # Set random eigen-vals (lambdas) and angle (theta) for COV matrix 133 | 
lambda_1 = min_var + np.random.rand() * (max_var - min_var) 134 | lambda_2 = min_var + np.random.rand() * (max_var - min_var) 135 | theta = np.random.rand() * np.pi # random theta 136 | noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 137 | 138 | # Set COV matrix using Lambdas and Theta 139 | LAMBDA = np.diag([lambda_1, lambda_2]) 140 | Q = np.array([[np.cos(theta), -np.sin(theta)], 141 | [np.sin(theta), np.cos(theta)]]) 142 | SIGMA = Q @ LAMBDA @ Q.T 143 | INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] 144 | 145 | # Set expectation position (shifting kernel for aligned image) 146 | MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) 147 | MU = MU[None, None, :, None] 148 | 149 | # Create meshgrid for Gaussian 150 | [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) 151 | Z = np.stack([X, Y], 2)[:, :, :, None] 152 | 153 | # Calcualte Gaussian for every pixel of the kernel 154 | ZZ = Z - MU 155 | ZZ_t = ZZ.transpose(0, 1, 3, 2) 156 | raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) 157 | 158 | # shift the kernel so it will be centered 159 | # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) 160 | 161 | # Normalize the kernel and return 162 | # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) 163 | kernel = raw_kernel / np.sum(raw_kernel) 164 | return kernel 165 | 166 | 167 | def fspecial_gaussian(hsize, sigma): 168 | hsize = [hsize, hsize] 169 | siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] 170 | std = sigma 171 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) 172 | arg = -(x * x + y * y) / (2 * std * std) 173 | h = np.exp(arg) 174 | h[h < scipy.finfo(float).eps * h.max()] = 0 175 | sumh = h.sum() 176 | if sumh != 0: 177 | h = h / sumh 178 | return h 179 | 180 | 181 | def fspecial_laplacian(alpha): 182 | alpha = max([0, min([alpha, 1])]) 183 | h1 = alpha / (alpha + 1) 184 | h2 = (1 - alpha) / (alpha + 1) 185 | h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] 186 | h = np.array(h) 187 | return h 188 | 189 | 190 | def fspecial(filter_type, *args, **kwargs): 191 | 192 | if filter_type == 'gaussian': 193 | return fspecial_gaussian(*args, **kwargs) 194 | if filter_type == 'laplacian': 195 | return fspecial_laplacian(*args, **kwargs) 196 | 197 | 198 | 199 | def bicubic_degradation(x, sf=3): 200 | ''' 201 | Args: 202 | x: HxWxC image, [0, 1] 203 | sf: down-scale factor 204 | Return: 205 | bicubicly downsampled LR image 206 | ''' 207 | x = util.imresize_np(x, scale=1 / sf) 208 | return x 209 | 210 | 211 | def srmd_degradation(x, k, sf=3): 212 | ''' blur + bicubic downsampling 213 | Args: 214 | x: HxWxC image, [0, 1] 215 | k: hxw, double 216 | sf: down-scale factor 217 | Return: 218 | downsampled LR image 219 | Reference: 220 | @inproceedings{zhang2018learning, 221 | title={Learning a single convolutional super-resolution network for multiple degradations}, 222 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, 223 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 224 | pages={3262--3271}, 225 | year={2018} 226 | } 227 | ''' 228 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' 229 | x = bicubic_degradation(x, sf=sf) 230 | return x 231 | 232 | 233 | def dpsr_degradation(x, k, sf=3): 234 | ''' bicubic downsampling + blur 235 | Args: 236 | x: HxWxC image, [0, 1] 237 | k: hxw, double 238 | sf: down-scale factor 239 | Return: 240 | downsampled LR image 241 | 
Reference: 242 | @inproceedings{zhang2019deep, 243 | title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, 244 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, 245 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 246 | pages={1671--1681}, 247 | year={2019} 248 | } 249 | ''' 250 | x = bicubic_degradation(x, sf=sf) 251 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 252 | return x 253 | 254 | 255 | def classical_degradation(x, k, sf=3): 256 | ''' blur + downsampling 257 | Args: 258 | x: HxWxC image, [0, 1]/[0, 255] 259 | k: hxw, double 260 | sf: down-scale factor 261 | Return: 262 | downsampled LR image 263 | ''' 264 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 265 | # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) 266 | st = 0 267 | return x[st::sf, st::sf, ...] 268 | 269 | 270 | def add_sharpening(img, weight=0.5, radius=50, threshold=10): 271 | """USM sharpening. borrowed from real-ESRGAN 272 | Input image: I; Blurry image: B. 273 | 1. K = I + weight * (I - B) 274 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 275 | 3. Blur mask: 276 | 4. Out = Mask * K + (1 - Mask) * I 277 | Args: 278 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 279 | weight (float): Sharp weight. Default: 1. 280 | radius (float): Kernel size of Gaussian blur. Default: 50. 281 | threshold (int): 282 | """ 283 | if radius % 2 == 0: 284 | radius += 1 285 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 286 | residual = img - blur 287 | mask = np.abs(residual) * 255 > threshold 288 | mask = mask.astype('float32') 289 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 290 | 291 | K = img + weight * residual 292 | K = np.clip(K, 0, 1) 293 | return soft_mask * K + (1 - soft_mask) * img 294 | 295 | 296 | def add_blur(img, sf=4): 297 | wd2 = 4.0 + sf 298 | wd = 2.0 + 0.2 * sf 299 | if random.random() < 0.5: 300 | l1 = wd2 * random.random() 301 | l2 = wd2 * random.random() 302 | k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) 303 | else: 304 | k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) 305 | img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 306 | 307 | return img 308 | 309 | 310 | def add_resize(img, sf=4): 311 | rnum = np.random.rand() 312 | if rnum > 0.8: # up 313 | sf1 = random.uniform(1, 2) 314 | elif rnum < 0.7: # down 315 | sf1 = random.uniform(0.5 / sf, 1) 316 | else: 317 | sf1 = 1.0 318 | img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) 319 | img = np.clip(img, 0.0, 1.0) 320 | 321 | return img 322 | 323 | 324 | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): 325 | noise_level = random.randint(noise_level1, noise_level2) 326 | rnum = np.random.rand() 327 | if rnum > 0.6: # add color Gaussian noise 328 | img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) 329 | elif rnum < 0.4: # add grayscale Gaussian noise 330 | img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) 331 | else: # add noise 332 | L = noise_level2 / 255. 
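# (added note) U^T D U below builds a random positive semi-definite covariance, so the sampled noise is correlated across the three colour channels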
333 | D = np.diag(np.random.rand(3)) 334 | U = orth(np.random.rand(3, 3)) 335 | conv = np.dot(np.dot(np.transpose(U), D), U) 336 | img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) 337 | img = np.clip(img, 0.0, 1.0) 338 | return img 339 | 340 | 341 | def add_speckle_noise(img, noise_level1=2, noise_level2=25): 342 | noise_level = random.randint(noise_level1, noise_level2) 343 | img = np.clip(img, 0.0, 1.0) 344 | rnum = random.random() 345 | if rnum > 0.6: 346 | img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) 347 | elif rnum < 0.4: 348 | img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) 349 | else: 350 | L = noise_level2 / 255. 351 | D = np.diag(np.random.rand(3)) 352 | U = orth(np.random.rand(3, 3)) 353 | conv = np.dot(np.dot(np.transpose(U), D), U) 354 | img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) 355 | img = np.clip(img, 0.0, 1.0) 356 | return img 357 | 358 | 359 | def add_Poisson_noise(img): 360 | img = np.clip((img * 255.0).round(), 0, 255) / 255. 361 | vals = 10 ** (2 * random.random() + 2.0) # [2, 4] 362 | if random.random() < 0.5: 363 | img = np.random.poisson(img * vals).astype(np.float32) / vals 364 | else: 365 | img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) 366 | img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. 367 | noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray 368 | img += noise_gray[:, :, np.newaxis] 369 | img = np.clip(img, 0.0, 1.0) 370 | return img 371 | 372 | 373 | def add_JPEG_noise(img): 374 | quality_factor = random.randint(30, 95) 375 | img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) 376 | result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) 377 | img = cv2.imdecode(encimg, 1) 378 | img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) 379 | return img 380 | 381 | 382 | def random_crop(lq, hq, sf=4, lq_patchsize=64): 383 | h, w = lq.shape[:2] 384 | rnd_h = random.randint(0, h - lq_patchsize) 385 | rnd_w = random.randint(0, w - lq_patchsize) 386 | lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] 387 | 388 | rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) 389 | hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] 390 | return lq, hq 391 | 392 | 393 | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): 394 | """ 395 | img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) 396 | sf: scale factor 397 | isp_model: camera ISP model 398 | Returns 399 | ------- 400 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] 401 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] 402 | """ 403 | isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 404 | sf_ori = sf 405 | 406 | h1, w1 = img.shape[:2] 407 | img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop 408 | h, w = img.shape[:2] 409 | 410 | if h < lq_patchsize * sf or w < lq_patchsize * sf: 411 | raise ValueError(f'img size ({h1}X{w1}) is too small!') 412 | 413 | hq = img.copy() 414 | 415 | if sf == 4 and random.random() < scale2_prob: # downsample1 416 | if np.random.rand() < 0.5: 417 | img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), 418 | interpolation=random.choice([1, 2, 3])) 419 | else: 420 | img = util.imresize_np(img, 1 / 2, True) 421 | img = np.clip(img, 0.0, 1.0) 422 | sf = 2 423 | 424 | shuffle_order = random.sample(range(7), 7) 425 | idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) 426 | if idx1 > idx2: # keep downsample3 last 427 | shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] 428 | 429 | for i in shuffle_order: 430 | 431 | if i == 0: 432 | img = add_blur(img, sf=sf) 433 | 434 | elif i == 1: 435 | img = add_blur(img, sf=sf) 436 | 437 | elif i == 2: 438 | a, b = img.shape[1], img.shape[0] 439 | # downsample2 440 | if random.random() < 0.75: 441 | sf1 = random.uniform(1, 2 * sf) 442 | img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), 443 | interpolation=random.choice([1, 2, 3])) 444 | else: 445 | k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) 446 | k_shifted = shift_pixel(k, sf) 447 | k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel 448 | img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') 449 | img = img[0::sf, 0::sf, ...] # nearest downsampling 450 | img = np.clip(img, 0.0, 1.0) 451 | 452 | elif i == 3: 453 | # downsample3 454 | img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) 455 | img = np.clip(img, 0.0, 1.0) 456 | 457 | elif i == 4: 458 | # add Gaussian noise 459 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) 460 | 461 | elif i == 5: 462 | # add JPEG noise 463 | if random.random() < jpeg_prob: 464 | img = add_JPEG_noise(img) 465 | 466 | elif i == 6: 467 | # add processed camera sensor noise 468 | if random.random() < isp_prob and isp_model is not None: 469 | with torch.no_grad(): 470 | img, hq = isp_model.forward(img.copy(), hq) 471 | 472 | # add final JPEG compression noise 473 | img = add_JPEG_noise(img) 474 | 475 | # random crop 476 | img, hq = random_crop(img, hq, sf_ori, lq_patchsize) 477 | 478 | return img, hq 479 | 480 | 481 | # todo no isp_model? 482 | def degradation_bsrgan_variant(image, sf=4, isp_model=None): 483 | """ 484 | sf: scale factor 485 | isp_model: camera ISP model 486 | Returns 487 | ------- 488 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] 489 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] 490 | """ 491 | image = util.uint2single(image) 492 | isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 493 | sf_ori = sf 494 | 495 | h1, w1 = image.shape[:2] 496 | image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop 497 | h, w = image.shape[:2] 498 | 499 | hq = image.copy() 500 | 501 | if sf == 4 and random.random() < scale2_prob: # downsample1 502 | if np.random.rand() < 0.5: 503 | image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), 504 | interpolation=random.choice([1, 2, 3])) 505 | else: 506 | image = util.imresize_np(image, 1 / 2, True) 507 | image = np.clip(image, 0.0, 1.0) 508 | sf = 2 509 | 510 | shuffle_order = random.sample(range(7), 7) 511 | idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) 512 | if idx1 > idx2: # keep downsample3 last 513 | shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] 514 | 515 | for i in shuffle_order: 516 | 517 | if i == 0: 518 | image = add_blur(image, sf=sf) 519 | 520 | elif i == 1: 521 | image = add_blur(image, sf=sf) 522 | 523 | elif i == 2: 524 | a, b = image.shape[1], image.shape[0] 525 | # downsample2 526 | if random.random() < 0.75: 527 | sf1 = random.uniform(1, 2 * sf) 528 | image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), 529 | interpolation=random.choice([1, 2, 3])) 530 | else: 531 | k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) 532 | k_shifted = shift_pixel(k, sf) 533 | k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel 534 | image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') 535 | image = image[0::sf, 0::sf, ...] # nearest downsampling 536 | image = np.clip(image, 0.0, 1.0) 537 | 538 | elif i == 3: 539 | # downsample3 540 | image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) 541 | image = np.clip(image, 0.0, 1.0) 542 | 543 | elif i == 4: 544 | # add Gaussian noise 545 | image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) 546 | 547 | elif i == 5: 548 | # add JPEG noise 549 | if random.random() < jpeg_prob: 550 | image = add_JPEG_noise(image) 551 | 552 | # elif i == 6: 553 | # # add processed camera sensor noise 554 | # if random.random() < isp_prob and isp_model is not None: 555 | # with torch.no_grad(): 556 | # img, hq = isp_model.forward(img.copy(), hq) 557 | 558 | # add final JPEG compression noise 559 | image = add_JPEG_noise(image) 560 | image = util.single2uint(image) 561 | example = {"image":image} 562 | return example 563 | 564 | 565 | # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... 566 | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): 567 | """ 568 | img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) 569 | sf: scale factor 570 | use_shuffle: the degradation shuffle 571 | use_sharp: sharpening the img 572 | Returns 573 | ------- 574 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] 575 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] 576 | """ 577 | 578 | h1, w1 = img.shape[:2] 579 | img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop 580 | h, w = img.shape[:2] 581 | 582 | if h < lq_patchsize * sf or w < lq_patchsize * sf: 583 | raise ValueError(f'img size ({h1}X{w1}) is too small!') 584 | 585 | if use_sharp: 586 | img = add_sharpening(img) 587 | hq = img.copy() 588 | 589 | if random.random() < shuffle_prob: 590 | shuffle_order = random.sample(range(13), 13) 591 | else: 592 | shuffle_order = list(range(13)) 593 | # local shuffle for noise, JPEG is always the last one 594 | shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) 595 | shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) 596 | 597 | poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 598 | 599 | for i in shuffle_order: 600 | if i == 0: 601 | img = add_blur(img, sf=sf) 602 | elif i == 1: 603 | img = add_resize(img, sf=sf) 604 | elif i == 2: 605 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) 606 | elif i == 3: 607 | if random.random() < poisson_prob: 608 | img = add_Poisson_noise(img) 609 | elif i == 4: 610 | if random.random() < speckle_prob: 611 | img = add_speckle_noise(img) 612 | elif i == 5: 613 | if random.random() < isp_prob and isp_model is not None: 614 | with torch.no_grad(): 615 | img, hq = isp_model.forward(img.copy(), hq) 616 | elif i == 6: 617 | img = add_JPEG_noise(img) 618 | elif i == 7: 619 | img = add_blur(img, sf=sf) 620 | elif i == 8: 621 | img = add_resize(img, sf=sf) 622 | elif i == 9: 623 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) 624 | elif i == 10: 625 | if random.random() < poisson_prob: 626 | img = add_Poisson_noise(img) 627 | elif i == 11: 628 | if random.random() < speckle_prob: 629 | img = add_speckle_noise(img) 630 | elif i == 12: 631 | if random.random() < isp_prob and isp_model is not None: 632 | with torch.no_grad(): 633 | img, hq = isp_model.forward(img.copy(), hq) 634 | else: 635 | print('check the shuffle!') 636 | 637 | # resize to desired size 638 | img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), 639 | interpolation=random.choice([1, 2, 3])) 640 | 641 | # add final JPEG compression noise 642 | img = add_JPEG_noise(img) 643 | 644 | # random crop 645 | img, hq = random_crop(img, hq, sf, lq_patchsize) 646 | 647 | return img, hq 648 | 649 | 650 | if __name__ == '__main__': 651 | print("hey") 652 | img = util.imread_uint('utils/test.png', 3) 653 | print(img) 654 | img = util.uint2single(img) 655 | print(img) 656 | img = img[:448, :448] 657 | h = img.shape[0] // 4 658 | print("resizing to", h) 659 | sf = 4 660 | deg_fn = partial(degradation_bsrgan_variant, sf=sf) 661 | for i in range(20): 662 | print(i) 663 | img_lq = deg_fn(img) 664 | print(img_lq) 665 | img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] 666 | print(img_lq.shape) 667 | print("bicubic", img_lq_bicubic.shape) 668 | print(img_hq.shape) 669 | lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), 670 | interpolation=0) 671 | lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), 672 | interpolation=0) 673 | img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) 674 | util.imsave(img_concat, str(i) + '.png') 675 | 676 | 677 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/bsrgan_light.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import cv2 4 | import torch 5 | 6 | from functools import partial 7 | import random 8 | from scipy import ndimage 9 | import scipy 10 | import scipy.stats as ss 11 | from scipy.interpolate import interp2d 12 | from scipy.linalg import orth 13 | import albumentations 14 | 15 | import ldm.modules.image_degradation.utils_image as util 16 | 17 | 18 | def modcrop_np(img, sf): 19 | ''' 20 | Args: 21 | img: numpy image, WxH or WxHxC 22 | sf: scale factor 23 | Return: 24 | cropped image 25 | ''' 26 | w, h = img.shape[:2] 27 | im = np.copy(img) 28 | return im[:w - w % sf, :h - h % sf, ...] 29 | 30 | 31 | 32 | 33 | def analytic_kernel(k): 34 | """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" 35 | k_size = k.shape[0] 36 | # Calculate the big kernels size 37 | big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) 38 | # Loop over the small kernel to fill the big one 39 | for r in range(k_size): 40 | for c in range(k_size): 41 | big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k 42 | # Crop the edges of the big kernel to ignore very small values and increase run time of SR 43 | crop = k_size // 2 44 | cropped_big_k = big_k[crop:-crop, crop:-crop] 45 | # Normalize to 1 46 | return cropped_big_k / cropped_big_k.sum() 47 | 48 | 49 | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): 50 | """ generate an anisotropic Gaussian kernel 51 | Args: 52 | ksize : e.g., 15, kernel size 53 | theta : [0, pi], rotation angle range 54 | l1 : [0.1,50], scaling of eigenvalues 55 | l2 : [0.1,l1], scaling of eigenvalues 56 | If l1 = l2, will get an isotropic Gaussian kernel. 57 | Returns: 58 | k : kernel 59 | """ 60 | 61 | v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) 62 | V = np.array([[v[0], v[1]], [v[1], -v[0]]]) 63 | D = np.array([[l1, 0], [0, l2]]) 64 | Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) 65 | k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) 66 | 67 | return k 68 | 69 | 70 | def gm_blur_kernel(mean, cov, size=15): 71 | center = size / 2.0 + 0.5 72 | k = np.zeros([size, size]) 73 | for y in range(size): 74 | for x in range(size): 75 | cy = y - center + 1 76 | cx = x - center + 1 77 | k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) 78 | 79 | k = k / np.sum(k) 80 | return k 81 | 82 | 83 | def shift_pixel(x, sf, upper_left=True): 84 | """shift pixel for super-resolution with different scale factors 85 | Args: 86 | x: WxHxC or WxH 87 | sf: scale factor 88 | upper_left: shift direction 89 | """ 90 | h, w = x.shape[:2] 91 | shift = (sf - 1) * 0.5 92 | xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) 93 | if upper_left: 94 | x1 = xv + shift 95 | y1 = yv + shift 96 | else: 97 | x1 = xv - shift 98 | y1 = yv - shift 99 | 100 | x1 = np.clip(x1, 0, w - 1) 101 | y1 = np.clip(y1, 0, h - 1) 102 | 103 | if x.ndim == 2: 104 | x = interp2d(xv, yv, x)(x1, y1) 105 | if x.ndim == 3: 106 | for i in range(x.shape[-1]): 107 | x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) 108 | 109 | return x 110 | 111 | 112 | def blur(x, k): 113 | ''' 114 | x: image, NxcxHxW 115 | k: kernel, Nx1xhxw 116 | ''' 117 | n, c = x.shape[:2] 118 | p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 119 | x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') 120 | k = k.repeat(1, c, 1, 1) 121 | k = k.view(-1, 1, k.shape[2], k.shape[3]) 122 | x = x.view(1, 
-1, x.shape[2], x.shape[3]) 123 | x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) 124 | x = x.view(n, c, x.shape[2], x.shape[3]) 125 | 126 | return x 127 | 128 | 129 | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): 130 | 131 | # Set random eigen-vals (lambdas) and angle (theta) for COV matrix 132 | lambda_1 = min_var + np.random.rand() * (max_var - min_var) 133 | lambda_2 = min_var + np.random.rand() * (max_var - min_var) 134 | theta = np.random.rand() * np.pi # random theta 135 | noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 136 | 137 | # Set COV matrix using Lambdas and Theta 138 | LAMBDA = np.diag([lambda_1, lambda_2]) 139 | Q = np.array([[np.cos(theta), -np.sin(theta)], 140 | [np.sin(theta), np.cos(theta)]]) 141 | SIGMA = Q @ LAMBDA @ Q.T 142 | INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] 143 | 144 | # Set expectation position (shifting kernel for aligned image) 145 | MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) 146 | MU = MU[None, None, :, None] 147 | 148 | # Create meshgrid for Gaussian 149 | [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) 150 | Z = np.stack([X, Y], 2)[:, :, :, None] 151 | 152 | # Calcualte Gaussian for every pixel of the kernel 153 | ZZ = Z - MU 154 | ZZ_t = ZZ.transpose(0, 1, 3, 2) 155 | raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) 156 | 157 | # shift the kernel so it will be centered 158 | # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) 159 | 160 | # Normalize the kernel and return 161 | # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) 162 | kernel = raw_kernel / np.sum(raw_kernel) 163 | return kernel 164 | 165 | 166 | def fspecial_gaussian(hsize, sigma): 167 | hsize = [hsize, hsize] 168 | siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] 169 | std = sigma 170 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) 171 | arg = -(x * x + y * y) / (2 * std * std) 172 | h = np.exp(arg) 173 | h[h < scipy.finfo(float).eps * h.max()] = 0 174 | sumh = h.sum() 175 | if sumh != 0: 176 | h = h / sumh 177 | return h 178 | 179 | 180 | def fspecial_laplacian(alpha): 181 | alpha = max([0, min([alpha, 1])]) 182 | h1 = alpha / (alpha + 1) 183 | h2 = (1 - alpha) / (alpha + 1) 184 | h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] 185 | h = np.array(h) 186 | return h 187 | 188 | 189 | def fspecial(filter_type, *args, **kwargs): 190 | 191 | if filter_type == 'gaussian': 192 | return fspecial_gaussian(*args, **kwargs) 193 | if filter_type == 'laplacian': 194 | return fspecial_laplacian(*args, **kwargs) 195 | 196 | 197 | """ 198 | # -------------------------------------------- 199 | # degradation models 200 | # -------------------------------------------- 201 | """ 202 | 203 | 204 | def bicubic_degradation(x, sf=3): 205 | ''' 206 | Args: 207 | x: HxWxC image, [0, 1] 208 | sf: down-scale factor 209 | Return: 210 | bicubicly downsampled LR image 211 | ''' 212 | x = util.imresize_np(x, scale=1 / sf) 213 | return x 214 | 215 | 216 | def srmd_degradation(x, k, sf=3): 217 | ''' blur + bicubic downsampling 218 | Args: 219 | x: HxWxC image, [0, 1] 220 | k: hxw, double 221 | sf: down-scale factor 222 | Return: 223 | downsampled LR image 224 | Reference: 225 | @inproceedings{zhang2018learning, 226 | title={Learning a single convolutional super-resolution network for multiple degradations}, 227 | author={Zhang, Kai and 
Zuo, Wangmeng and Zhang, Lei}, 228 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 229 | pages={3262--3271}, 230 | year={2018} 231 | } 232 | ''' 233 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' 234 | x = bicubic_degradation(x, sf=sf) 235 | return x 236 | 237 | 238 | def dpsr_degradation(x, k, sf=3): 239 | ''' bicubic downsampling + blur 240 | Args: 241 | x: HxWxC image, [0, 1] 242 | k: hxw, double 243 | sf: down-scale factor 244 | Return: 245 | downsampled LR image 246 | Reference: 247 | @inproceedings{zhang2019deep, 248 | title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, 249 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, 250 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 251 | pages={1671--1681}, 252 | year={2019} 253 | } 254 | ''' 255 | x = bicubic_degradation(x, sf=sf) 256 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 257 | return x 258 | 259 | 260 | def classical_degradation(x, k, sf=3): 261 | ''' blur + downsampling 262 | Args: 263 | x: HxWxC image, [0, 1]/[0, 255] 264 | k: hxw, double 265 | sf: down-scale factor 266 | Return: 267 | downsampled LR image 268 | ''' 269 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 270 | # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) 271 | st = 0 272 | return x[st::sf, st::sf, ...] 273 | 274 | 275 | def add_sharpening(img, weight=0.5, radius=50, threshold=10): 276 | """USM sharpening. borrowed from real-ESRGAN 277 | Input image: I; Blurry image: B. 278 | 1. K = I + weight * (I - B) 279 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 280 | 3. Blur mask: 281 | 4. Out = Mask * K + (1 - Mask) * I 282 | Args: 283 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 284 | weight (float): Sharp weight. Default: 1. 285 | radius (float): Kernel size of Gaussian blur. Default: 50. 
286 | threshold (int): 287 | """ 288 | if radius % 2 == 0: 289 | radius += 1 290 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 291 | residual = img - blur 292 | mask = np.abs(residual) * 255 > threshold 293 | mask = mask.astype('float32') 294 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 295 | 296 | K = img + weight * residual 297 | K = np.clip(K, 0, 1) 298 | return soft_mask * K + (1 - soft_mask) * img 299 | 300 | 301 | def add_blur(img, sf=4): 302 | wd2 = 4.0 + sf 303 | wd = 2.0 + 0.2 * sf 304 | 305 | wd2 = wd2/4 306 | wd = wd/4 307 | 308 | if random.random() < 0.5: 309 | l1 = wd2 * random.random() 310 | l2 = wd2 * random.random() 311 | k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) 312 | else: 313 | k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) 314 | img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 315 | 316 | return img 317 | 318 | 319 | def add_resize(img, sf=4): 320 | rnum = np.random.rand() 321 | if rnum > 0.8: # up 322 | sf1 = random.uniform(1, 2) 323 | elif rnum < 0.7: # down 324 | sf1 = random.uniform(0.5 / sf, 1) 325 | else: 326 | sf1 = 1.0 327 | img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) 328 | img = np.clip(img, 0.0, 1.0) 329 | 330 | return img 331 | 332 | 333 | 334 | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): 335 | noise_level = random.randint(noise_level1, noise_level2) 336 | rnum = np.random.rand() 337 | if rnum > 0.6: # add color Gaussian noise 338 | img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) 339 | elif rnum < 0.4: # add grayscale Gaussian noise 340 | img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) 341 | else: # add noise 342 | L = noise_level2 / 255. 343 | D = np.diag(np.random.rand(3)) 344 | U = orth(np.random.rand(3, 3)) 345 | conv = np.dot(np.dot(np.transpose(U), D), U) 346 | img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) 347 | img = np.clip(img, 0.0, 1.0) 348 | return img 349 | 350 | 351 | def add_speckle_noise(img, noise_level1=2, noise_level2=25): 352 | noise_level = random.randint(noise_level1, noise_level2) 353 | img = np.clip(img, 0.0, 1.0) 354 | rnum = random.random() 355 | if rnum > 0.6: 356 | img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) 357 | elif rnum < 0.4: 358 | img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) 359 | else: 360 | L = noise_level2 / 255. 361 | D = np.diag(np.random.rand(3)) 362 | U = orth(np.random.rand(3, 3)) 363 | conv = np.dot(np.dot(np.transpose(U), D), U) 364 | img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) 365 | img = np.clip(img, 0.0, 1.0) 366 | return img 367 | 368 | 369 | def add_Poisson_noise(img): 370 | img = np.clip((img * 255.0).round(), 0, 255) / 255. 371 | vals = 10 ** (2 * random.random() + 2.0) # [2, 4] 372 | if random.random() < 0.5: 373 | img = np.random.poisson(img * vals).astype(np.float32) / vals 374 | else: 375 | img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) 376 | img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. 
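# grayscale branch of add_Poisson_noise: shot noise is sampled from this luma image and then broadcast back onto all three RGB channels (next lines), so the added noise is correlated across channels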
377 | noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray 378 | img += noise_gray[:, :, np.newaxis] 379 | img = np.clip(img, 0.0, 1.0) 380 | return img 381 | 382 | 383 | def add_JPEG_noise(img): 384 | quality_factor = random.randint(80, 95) 385 | img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) 386 | result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) 387 | img = cv2.imdecode(encimg, 1) 388 | img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) 389 | return img 390 | 391 | 392 | def random_crop(lq, hq, sf=4, lq_patchsize=64): 393 | h, w = lq.shape[:2] 394 | rnd_h = random.randint(0, h - lq_patchsize) 395 | rnd_w = random.randint(0, w - lq_patchsize) 396 | lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] 397 | 398 | rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) 399 | hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] 400 | return lq, hq 401 | 402 | 403 | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): 404 | """ 405 | img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) 406 | sf: scale factor 407 | isp_model: camera ISP model 408 | Returns 409 | ------- 410 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] 411 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] 412 | """ 413 | isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 414 | sf_ori = sf 415 | 416 | h1, w1 = img.shape[:2] 417 | img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop 418 | h, w = img.shape[:2] 419 | 420 | if h < lq_patchsize * sf or w < lq_patchsize * sf: 421 | raise ValueError(f'img size ({h1}X{w1}) is too small!') 422 | 423 | hq = img.copy() 424 | 425 | if sf == 4 and random.random() < scale2_prob: # downsample1 426 | if np.random.rand() < 0.5: 427 | img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), 428 | interpolation=random.choice([1, 2, 3])) 429 | else: 430 | img = util.imresize_np(img, 1 / 2, True) 431 | img = np.clip(img, 0.0, 1.0) 432 | sf = 2 433 | 434 | shuffle_order = random.sample(range(7), 7) 435 | idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) 436 | if idx1 > idx2: # keep downsample3 last 437 | shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] 438 | 439 | for i in shuffle_order: 440 | 441 | if i == 0: 442 | img = add_blur(img, sf=sf) 443 | 444 | elif i == 1: 445 | img = add_blur(img, sf=sf) 446 | 447 | elif i == 2: 448 | a, b = img.shape[1], img.shape[0] 449 | # downsample2 450 | if random.random() < 0.75: 451 | sf1 = random.uniform(1, 2 * sf) 452 | img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), 453 | interpolation=random.choice([1, 2, 3])) 454 | else: 455 | k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) 456 | k_shifted = shift_pixel(k, sf) 457 | k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel 458 | img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') 459 | img = img[0::sf, 0::sf, ...] 
# nearest downsampling 460 | img = np.clip(img, 0.0, 1.0) 461 | 462 | elif i == 3: 463 | # downsample3 464 | img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) 465 | img = np.clip(img, 0.0, 1.0) 466 | 467 | elif i == 4: 468 | # add Gaussian noise 469 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) 470 | 471 | elif i == 5: 472 | # add JPEG noise 473 | if random.random() < jpeg_prob: 474 | img = add_JPEG_noise(img) 475 | 476 | elif i == 6: 477 | # add processed camera sensor noise 478 | if random.random() < isp_prob and isp_model is not None: 479 | with torch.no_grad(): 480 | img, hq = isp_model.forward(img.copy(), hq) 481 | 482 | # add final JPEG compression noise 483 | img = add_JPEG_noise(img) 484 | 485 | # random crop 486 | img, hq = random_crop(img, hq, sf_ori, lq_patchsize) 487 | 488 | return img, hq 489 | 490 | 491 | # todo no isp_model? 492 | def degradation_bsrgan_variant(image, sf=4, isp_model=None): 493 | """ 494 | sf: scale factor 495 | isp_model: camera ISP model 496 | Returns 497 | ------- 498 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] 499 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] 500 | """ 501 | image = util.uint2single(image) 502 | isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 503 | sf_ori = sf 504 | 505 | h1, w1 = image.shape[:2] 506 | image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop 507 | h, w = image.shape[:2] 508 | 509 | hq = image.copy() 510 | 511 | if sf == 4 and random.random() < scale2_prob: # downsample1 512 | if np.random.rand() < 0.5: 513 | image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), 514 | interpolation=random.choice([1, 2, 3])) 515 | else: 516 | image = util.imresize_np(image, 1 / 2, True) 517 | image = np.clip(image, 0.0, 1.0) 518 | sf = 2 519 | 520 | shuffle_order = random.sample(range(7), 7) 521 | idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) 522 | if idx1 > idx2: # keep downsample3 last 523 | shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] 524 | 525 | for i in shuffle_order: 526 | 527 | if i == 0: 528 | image = add_blur(image, sf=sf) 529 | 530 | # elif i == 1: 531 | # image = add_blur(image, sf=sf) 532 | 533 | if i == 0: 534 | pass 535 | 536 | elif i == 2: 537 | a, b = image.shape[1], image.shape[0] 538 | # downsample2 539 | if random.random() < 0.8: 540 | sf1 = random.uniform(1, 2 * sf) 541 | image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), 542 | interpolation=random.choice([1, 2, 3])) 543 | else: 544 | k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) 545 | k_shifted = shift_pixel(k, sf) 546 | k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel 547 | image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') 548 | image = image[0::sf, 0::sf, ...] 
# nearest downsampling 549 | 550 | image = np.clip(image, 0.0, 1.0) 551 | 552 | elif i == 3: 553 | # downsample3 554 | image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) 555 | image = np.clip(image, 0.0, 1.0) 556 | 557 | elif i == 4: 558 | # add Gaussian noise 559 | image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) 560 | 561 | elif i == 5: 562 | # add JPEG noise 563 | if random.random() < jpeg_prob: 564 | image = add_JPEG_noise(image) 565 | # 566 | # elif i == 6: 567 | # # add processed camera sensor noise 568 | # if random.random() < isp_prob and isp_model is not None: 569 | # with torch.no_grad(): 570 | # img, hq = isp_model.forward(img.copy(), hq) 571 | 572 | # add final JPEG compression noise 573 | image = add_JPEG_noise(image) 574 | image = util.single2uint(image) 575 | example = {"image": image} 576 | return example 577 | 578 | 579 | 580 | 581 | if __name__ == '__main__': 582 | print("hey") 583 | img = util.imread_uint('utils/test.png', 3) 584 | img = img[:448, :448] 585 | h = img.shape[0] // 4 586 | print("resizing to", h) 587 | sf = 4 588 | deg_fn = partial(degradation_bsrgan_variant, sf=sf) 589 | for i in range(20): 590 | print(i) 591 | img_hq = img 592 | img_lq = deg_fn(img)["image"] 593 | img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) 594 | print(img_lq) 595 | img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] 596 | print(img_lq.shape) 597 | print("bicubic", img_lq_bicubic.shape) 598 | print(img_hq.shape) 599 | lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), 600 | interpolation=0) 601 | lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), 602 | (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), 603 | interpolation=0) 604 | img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) 605 | util.imsave(img_concat, str(i) + '.png') 606 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
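# LPIPSWithDiscriminator (below) combines an L1 reconstruction term, an LPIPS perceptual term, a KL penalty on the latent posterior, and a PatchGAN-style NLayerDiscriminator trained with a hinge or vanilla GAN loss; the discriminator weight is rescaled adaptively from the gradient norms of the reconstruction and generator objectives.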
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | 28 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 29 | avg_probs = encodings.mean(0) 30 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 31 | cluster_use = torch.sum(avg_probs > 0) 32 | return perplexity, cluster_use 33 | 34 | def l1(x, y): 35 | return torch.abs(x-y) 36 | 37 | 38 | def l2(x, y): 39 | return torch.pow((x-y), 2) 40 | 41 | 42 | class VQLPIPSWithDiscriminator(nn.Module): 43 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 44 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 45 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 46 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 47 | pixel_loss="l1"): 48 | super().__init__() 49 | assert disc_loss in ["hinge", "vanilla"] 50 | assert perceptual_loss in ["lpips", "clips", "dists"] 51 | assert pixel_loss in ["l1", "l2"] 52 | self.codebook_weight = codebook_weight 53 | self.pixel_weight = pixelloss_weight 54 | if perceptual_loss == "lpips": 55 | print(f"{self.__class__.__name__}: Running with LPIPS.") 56 | self.perceptual_loss = LPIPS().eval() 57 | else: 58 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 59 | self.perceptual_weight = perceptual_weight 60 | 61 | if pixel_loss == "l1": 62 | self.pixel_loss = l1 63 | else: 64 | self.pixel_loss = l2 65 | 66 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 67 | n_layers=disc_num_layers, 68 | use_actnorm=use_actnorm, 69 | ndf=disc_ndf 70 | ).apply(weights_init) 71 | self.discriminator_iter_start = disc_start 72 | if disc_loss == "hinge": 73 | self.disc_loss = hinge_d_loss 74 | elif disc_loss == "vanilla": 75 | self.disc_loss = vanilla_d_loss 76 | else: 77 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 78 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 79 | self.disc_factor = disc_factor 80 | self.discriminator_weight = disc_weight 81 | self.disc_conditional = disc_conditional 82 | self.n_classes = n_classes 83 | 84 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 85 | if last_layer is not None: 86 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 87 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 88 | else: 89 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 90 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 91 | 92 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 93 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 94 | d_weight = d_weight * self.discriminator_weight 95 | return d_weight 96 | 97 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 98 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 99 | if not exists(codebook_loss): 100 | codebook_loss = torch.tensor([0.]).to(inputs.device) 101 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 102 | rec_loss = self.pixel_loss(inputs.contiguous(), 
reconstructions.contiguous()) 103 | if self.perceptual_weight > 0: 104 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 105 | rec_loss = rec_loss + self.perceptual_weight * p_loss 106 | else: 107 | p_loss = torch.tensor([0.0]) 108 | 109 | nll_loss = rec_loss 110 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 111 | nll_loss = torch.mean(nll_loss) 112 | 113 | # now the GAN part 114 | if optimizer_idx == 0: 115 | # generator update 116 | if cond is None: 117 | assert not self.disc_conditional 118 | logits_fake = self.discriminator(reconstructions.contiguous()) 119 | else: 120 | assert self.disc_conditional 121 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 122 | g_loss = -torch.mean(logits_fake) 123 | 124 | try: 125 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 126 | except RuntimeError: 127 | assert not self.training 128 | d_weight = torch.tensor(0.0) 129 | 130 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 131 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 132 | 133 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 134 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 135 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 136 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 137 | "{}/p_loss".format(split): p_loss.detach().mean(), 138 | "{}/d_weight".format(split): d_weight.detach(), 139 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 140 | "{}/g_loss".format(split): g_loss.detach().mean(), 141 | } 142 | if predicted_indices is not None: 143 | assert self.n_classes is not None 144 | with torch.no_grad(): 145 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 146 | log[f"{split}/perplexity"] = perplexity 147 | log[f"{split}/cluster_usage"] = cluster_usage 148 | return loss, log 149 | 150 | if optimizer_idx == 1: 151 | # second pass for discriminator update 152 | if cond is None: 153 | logits_real = self.discriminator(inputs.contiguous().detach()) 154 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 155 | else: 156 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 157 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 158 | 159 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 160 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 161 | 162 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 163 | "{}/logits_real".format(split): logits_real.detach().mean(), 164 | "{}/logits_fake".format(split): logits_fake.detach().mean() 165 | } 166 | return d_loss, log 167 | -------------------------------------------------------------------------------- /ldm/modules/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | from functools import partial 6 | from inspect import isfunction 7 | from collections import namedtuple 8 | from einops import rearrange, repeat, reduce 9 | 10 | # constants 11 | 12 | DEFAULT_DIM_HEAD = 64 13 | 14 | Intermediates = 
namedtuple('Intermediates', [ 15 | 'pre_softmax_attn', 16 | 'post_softmax_attn' 17 | ]) 18 | 19 | LayerIntermediates = namedtuple('Intermediates', [ 20 | 'hiddens', 21 | 'attn_intermediates' 22 | ]) 23 | 24 | 25 | class AbsolutePositionalEmbedding(nn.Module): 26 | def __init__(self, dim, max_seq_len): 27 | super().__init__() 28 | self.emb = nn.Embedding(max_seq_len, dim) 29 | self.init_() 30 | 31 | def init_(self): 32 | nn.init.normal_(self.emb.weight, std=0.02) 33 | 34 | def forward(self, x): 35 | n = torch.arange(x.shape[1], device=x.device) 36 | return self.emb(n)[None, :, :] 37 | 38 | 39 | class FixedPositionalEmbedding(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 43 | self.register_buffer('inv_freq', inv_freq) 44 | 45 | def forward(self, x, seq_dim=1, offset=0): 46 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 47 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 48 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 49 | return emb[None, :, :] 50 | 51 | 52 | # helpers 53 | 54 | def exists(val): 55 | return val is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def always(val): 65 | def inner(*args, **kwargs): 66 | return val 67 | return inner 68 | 69 | 70 | def not_equals(val): 71 | def inner(x): 72 | return x != val 73 | return inner 74 | 75 | 76 | def equals(val): 77 | def inner(x): 78 | return x == val 79 | return inner 80 | 81 | 82 | def max_neg_value(tensor): 83 | return -torch.finfo(tensor.dtype).max 84 | 85 | 86 | # keyword argument helpers 87 | 88 | def pick_and_pop(keys, d): 89 | values = list(map(lambda key: d.pop(key), keys)) 90 | return dict(zip(keys, values)) 91 | 92 | 93 | def group_dict_by_key(cond, d): 94 | return_val = [dict(), dict()] 95 | for key in d.keys(): 96 | match = bool(cond(key)) 97 | ind = int(not match) 98 | return_val[ind][key] = d[key] 99 | return (*return_val,) 100 | 101 | 102 | def string_begins_with(prefix, str): 103 | return str.startswith(prefix) 104 | 105 | 106 | def group_by_key_prefix(prefix, d): 107 | return group_dict_by_key(partial(string_begins_with, prefix), d) 108 | 109 | 110 | def groupby_prefix_and_trim(prefix, d): 111 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 112 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 113 | return kwargs_without_prefix, kwargs 114 | 115 | 116 | # classes 117 | class Scale(nn.Module): 118 | def __init__(self, value, fn): 119 | super().__init__() 120 | self.value = value 121 | self.fn = fn 122 | 123 | def forward(self, x, **kwargs): 124 | x, *rest = self.fn(x, **kwargs) 125 | return (x * self.value, *rest) 126 | 127 | 128 | class Rezero(nn.Module): 129 | def __init__(self, fn): 130 | super().__init__() 131 | self.fn = fn 132 | self.g = nn.Parameter(torch.zeros(1)) 133 | 134 | def forward(self, x, **kwargs): 135 | x, *rest = self.fn(x, **kwargs) 136 | return (x * self.g, *rest) 137 | 138 | 139 | class ScaleNorm(nn.Module): 140 | def __init__(self, dim, eps=1e-5): 141 | super().__init__() 142 | self.scale = dim ** -0.5 143 | self.eps = eps 144 | self.g = nn.Parameter(torch.ones(1)) 145 | 146 | def forward(self, x): 147 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 148 | return x / norm.clamp(min=self.eps) * self.g 149 | 150 | 151 | class 
RMSNorm(nn.Module): 152 | def __init__(self, dim, eps=1e-8): 153 | super().__init__() 154 | self.scale = dim ** -0.5 155 | self.eps = eps 156 | self.g = nn.Parameter(torch.ones(dim)) 157 | 158 | def forward(self, x): 159 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 160 | return x / norm.clamp(min=self.eps) * self.g 161 | 162 | 163 | class Residual(nn.Module): 164 | def forward(self, x, residual): 165 | return x + residual 166 | 167 | 168 | class GRUGating(nn.Module): 169 | def __init__(self, dim): 170 | super().__init__() 171 | self.gru = nn.GRUCell(dim, dim) 172 | 173 | def forward(self, x, residual): 174 | gated_output = self.gru( 175 | rearrange(x, 'b n d -> (b n) d'), 176 | rearrange(residual, 'b n d -> (b n) d') 177 | ) 178 | 179 | return gated_output.reshape_as(x) 180 | 181 | 182 | # feedforward 183 | 184 | class GEGLU(nn.Module): 185 | def __init__(self, dim_in, dim_out): 186 | super().__init__() 187 | self.proj = nn.Linear(dim_in, dim_out * 2) 188 | 189 | def forward(self, x): 190 | x, gate = self.proj(x).chunk(2, dim=-1) 191 | return x * F.gelu(gate) 192 | 193 | 194 | class FeedForward(nn.Module): 195 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 196 | super().__init__() 197 | inner_dim = int(dim * mult) 198 | dim_out = default(dim_out, dim) 199 | project_in = nn.Sequential( 200 | nn.Linear(dim, inner_dim), 201 | nn.GELU() 202 | ) if not glu else GEGLU(dim, inner_dim) 203 | 204 | self.net = nn.Sequential( 205 | project_in, 206 | nn.Dropout(dropout), 207 | nn.Linear(inner_dim, dim_out) 208 | ) 209 | 210 | def forward(self, x): 211 | return self.net(x) 212 | 213 | 214 | # attention. 215 | class Attention(nn.Module): 216 | def __init__( 217 | self, 218 | dim, 219 | dim_head=DEFAULT_DIM_HEAD, 220 | heads=8, 221 | causal=False, 222 | mask=None, 223 | talking_heads=False, 224 | sparse_topk=None, 225 | use_entmax15=False, 226 | num_mem_kv=0, 227 | dropout=0., 228 | on_attn=False 229 | ): 230 | super().__init__() 231 | if use_entmax15: 232 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 233 | self.scale = dim_head ** -0.5 234 | self.heads = heads 235 | self.causal = causal 236 | self.mask = mask 237 | 238 | inner_dim = dim_head * heads 239 | 240 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 242 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 243 | self.dropout = nn.Dropout(dropout) 244 | 245 | # talking heads 246 | self.talking_heads = talking_heads 247 | if talking_heads: 248 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 250 | 251 | # explicit topk sparse attention 252 | self.sparse_topk = sparse_topk 253 | 254 | # entmax 255 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 256 | self.attn_fn = F.softmax 257 | 258 | # add memory key / values 259 | self.num_mem_kv = num_mem_kv 260 | if num_mem_kv > 0: 261 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 263 | 264 | # attention on attention 265 | self.attn_on_attn = on_attn 266 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 267 | 268 | def forward( 269 | self, 270 | x, 271 | context=None, 272 | mask=None, 273 | context_mask=None, 274 | rel_pos=None, 275 | sinusoidal_emb=None, 276 | prev_attn=None, 277 | mem=None 278 | ): 279 | b, n, _, h, 
talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 280 | kv_input = default(context, x) 281 | 282 | q_input = x 283 | k_input = kv_input 284 | v_input = kv_input 285 | 286 | if exists(mem): 287 | k_input = torch.cat((mem, k_input), dim=-2) 288 | v_input = torch.cat((mem, v_input), dim=-2) 289 | 290 | if exists(sinusoidal_emb): 291 | # in shortformer, the query would start at a position offset depending on the past cached memory 292 | offset = k_input.shape[-2] - q_input.shape[-2] 293 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 294 | k_input = k_input + sinusoidal_emb(k_input) 295 | 296 | q = self.to_q(q_input) 297 | k = self.to_k(k_input) 298 | v = self.to_v(v_input) 299 | 300 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 301 | 302 | input_mask = None 303 | if any(map(exists, (mask, context_mask))): 304 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 305 | k_mask = q_mask if not exists(context) else context_mask 306 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 307 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 308 | k_mask = rearrange(k_mask, 'b j -> b () () j') 309 | input_mask = q_mask * k_mask 310 | 311 | if self.num_mem_kv > 0: 312 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 313 | k = torch.cat((mem_k, k), dim=-2) 314 | v = torch.cat((mem_v, v), dim=-2) 315 | if exists(input_mask): 316 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 317 | 318 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 319 | mask_value = max_neg_value(dots) 320 | 321 | if exists(prev_attn): 322 | dots = dots + prev_attn 323 | 324 | pre_softmax_attn = dots 325 | 326 | if talking_heads: 327 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 328 | 329 | if exists(rel_pos): 330 | dots = rel_pos(dots) 331 | 332 | if exists(input_mask): 333 | dots.masked_fill_(~input_mask, mask_value) 334 | del input_mask 335 | 336 | if self.causal: 337 | i, j = dots.shape[-2:] 338 | r = torch.arange(i, device=device) 339 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 340 | mask = F.pad(mask, (j - i, 0), value=False) 341 | dots.masked_fill_(mask, mask_value) 342 | del mask 343 | 344 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 345 | top, _ = dots.topk(self.sparse_topk, dim=-1) 346 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 347 | mask = dots < vk 348 | dots.masked_fill_(mask, mask_value) 349 | del mask 350 | 351 | attn = self.attn_fn(dots, dim=-1) 352 | post_softmax_attn = attn 353 | 354 | attn = self.dropout(attn) 355 | 356 | if talking_heads: 357 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 358 | 359 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 360 | out = rearrange(out, 'b h n d -> b n (h d)') 361 | 362 | intermediates = Intermediates( 363 | pre_softmax_attn=pre_softmax_attn, 364 | post_softmax_attn=post_softmax_attn 365 | ) 366 | 367 | return self.to_out(out), intermediates 368 | 369 | 370 | class AttentionLayers(nn.Module): 371 | def __init__( 372 | self, 373 | dim, 374 | depth, 375 | heads=8, 376 | causal=False, 377 | cross_attend=False, 378 | only_cross=False, 379 | use_scalenorm=False, 380 | use_rmsnorm=False, 381 | use_rezero=False, 382 | rel_pos_num_buckets=32, 383 | rel_pos_max_distance=128, 384 | position_infused_attn=False, 385 | custom_layers=None, 386 | 
sandwich_coef=None, 387 | par_ratio=None, 388 | residual_attn=False, 389 | cross_residual_attn=False, 390 | macaron=False, 391 | pre_norm=True, 392 | gate_residual=False, 393 | **kwargs 394 | ): 395 | super().__init__() 396 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 397 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 398 | 399 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 400 | 401 | self.dim = dim 402 | self.depth = depth 403 | self.layers = nn.ModuleList([]) 404 | 405 | self.has_pos_emb = position_infused_attn 406 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 407 | self.rotary_pos_emb = always(None) 408 | 409 | assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 410 | self.rel_pos = None 411 | 412 | self.pre_norm = pre_norm 413 | 414 | self.residual_attn = residual_attn 415 | self.cross_residual_attn = cross_residual_attn 416 | 417 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 418 | norm_class = RMSNorm if use_rmsnorm else norm_class 419 | norm_fn = partial(norm_class, dim) 420 | 421 | norm_fn = nn.Identity if use_rezero else norm_fn 422 | branch_fn = Rezero if use_rezero else None 423 | 424 | if cross_attend and not only_cross: 425 | default_block = ('a', 'c', 'f') 426 | elif cross_attend and only_cross: 427 | default_block = ('c', 'f') 428 | else: 429 | default_block = ('a', 'f') 430 | 431 | if macaron: 432 | default_block = ('f',) + default_block 433 | 434 | if exists(custom_layers): 435 | layer_types = custom_layers 436 | elif exists(par_ratio): 437 | par_depth = depth * len(default_block) 438 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 439 | default_block = tuple(filter(not_equals('f'), default_block)) 440 | par_attn = par_depth // par_ratio 441 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 442 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 443 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 444 | par_block = default_block + ('f',) * (par_width - len(default_block)) 445 | par_head = par_block * par_attn 446 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 447 | elif exists(sandwich_coef): 448 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 449 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 450 | else: 451 | layer_types = default_block * depth 452 | 453 | self.layer_types = layer_types 454 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 455 | 456 | for layer_type in self.layer_types: 457 | if layer_type == 'a': 458 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 459 | elif layer_type == 'c': 460 | layer = Attention(dim, heads=heads, **attn_kwargs) 461 | elif layer_type == 'f': 462 | layer = FeedForward(dim, **ff_kwargs) 463 | layer = layer if not macaron else Scale(0.5, layer) 464 | else: 465 | raise Exception(f'invalid layer type {layer_type}') 466 | 467 | if isinstance(layer, Attention) and exists(branch_fn): 468 | layer = branch_fn(layer) 469 | 470 | if gate_residual: 471 | residual_fn = GRUGating(dim) 472 | else: 473 | residual_fn = Residual() 474 | 475 | self.layers.append(nn.ModuleList([ 476 | norm_fn(), 477 | layer, 478 | residual_fn 479 | ])) 480 | 481 | def forward( 482 | self, 483 | x, 484 | context=None, 485 
| mask=None, 486 | context_mask=None, 487 | mems=None, 488 | return_hiddens=False 489 | ): 490 | hiddens = [] 491 | intermediates = [] 492 | prev_attn = None 493 | prev_cross_attn = None 494 | 495 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 496 | 497 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 498 | is_last = ind == (len(self.layers) - 1) 499 | 500 | if layer_type == 'a': 501 | hiddens.append(x) 502 | layer_mem = mems.pop(0) 503 | 504 | residual = x 505 | 506 | if self.pre_norm: 507 | x = norm(x) 508 | 509 | if layer_type == 'a': 510 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 511 | prev_attn=prev_attn, mem=layer_mem) 512 | elif layer_type == 'c': 513 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 514 | elif layer_type == 'f': 515 | out = block(x) 516 | 517 | x = residual_fn(out, residual) 518 | 519 | if layer_type in ('a', 'c'): 520 | intermediates.append(inter) 521 | 522 | if layer_type == 'a' and self.residual_attn: 523 | prev_attn = inter.pre_softmax_attn 524 | elif layer_type == 'c' and self.cross_residual_attn: 525 | prev_cross_attn = inter.pre_softmax_attn 526 | 527 | if not self.pre_norm and not is_last: 528 | x = norm(x) 529 | 530 | if return_hiddens: 531 | intermediates = LayerIntermediates( 532 | hiddens=hiddens, 533 | attn_intermediates=intermediates 534 | ) 535 | 536 | return x, intermediates 537 | 538 | return x 539 | 540 | 541 | class Encoder(AttentionLayers): 542 | def __init__(self, **kwargs): 543 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 544 | super().__init__(causal=False, **kwargs) 545 | 546 | 547 | 548 | class TransformerWrapper(nn.Module): 549 | def __init__( 550 | self, 551 | *, 552 | num_tokens, 553 | max_seq_len, 554 | attn_layers, 555 | emb_dim=None, 556 | max_mem_len=0., 557 | emb_dropout=0., 558 | num_memory_tokens=None, 559 | tie_embedding=False, 560 | use_pos_emb=True 561 | ): 562 | super().__init__() 563 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 564 | 565 | dim = attn_layers.dim 566 | emb_dim = default(emb_dim, dim) 567 | 568 | self.max_seq_len = max_seq_len 569 | self.max_mem_len = max_mem_len 570 | self.num_tokens = num_tokens 571 | 572 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 573 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 574 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 575 | self.emb_dropout = nn.Dropout(emb_dropout) 576 | 577 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 578 | self.attn_layers = attn_layers 579 | self.norm = nn.LayerNorm(dim) 580 | 581 | self.init_() 582 | 583 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 584 | 585 | # memory tokens (like [cls]) from Memory Transformers paper 586 | num_memory_tokens = default(num_memory_tokens, 0) 587 | self.num_memory_tokens = num_memory_tokens 588 | if num_memory_tokens > 0: 589 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 590 | 591 | # let funnel encoder know number of memory tokens, if specified 592 | if hasattr(attn_layers, 'num_memory_tokens'): 593 | attn_layers.num_memory_tokens = num_memory_tokens 594 | 595 | def init_(self): 596 | nn.init.normal_(self.token_emb.weight, std=0.02) 597 | 598 | def forward( 599 | self, 600 | x, 601 | 
return_embeddings=False, 602 | mask=None, 603 | return_mems=False, 604 | return_attn=False, 605 | mems=None, 606 | **kwargs 607 | ): 608 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 609 | x = self.token_emb(x) 610 | x += self.pos_emb(x) 611 | x = self.emb_dropout(x) 612 | 613 | x = self.project_emb(x) 614 | 615 | if num_mem > 0: 616 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 617 | x = torch.cat((mem, x), dim=1) 618 | 619 | # auto-handle masking after appending memory tokens 620 | if exists(mask): 621 | mask = F.pad(mask, (num_mem, 0), value=True) 622 | 623 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 624 | x = self.norm(x) 625 | 626 | mem, x = x[:, :num_mem], x[:, num_mem:] 627 | 628 | out = self.to_logits(x) if not return_embeddings else x 629 | 630 | if return_mems: 631 | hiddens = intermediates.hiddens 632 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 633 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 634 | return out, new_mems 635 | 636 | if return_attn: 637 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 638 | return out, attn_maps 639 | 640 | return out 641 | 642 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. 
Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | 65 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 66 | 67 | 68 | def count_params(model, verbose=False): 69 | total_params = sum(p.numel() for p in model.parameters()) 70 | if verbose: 71 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 72 | return total_params 73 | 74 | 75 | def instantiate_from_config(config): 76 | if not "target" in config: 77 | if config == '__is_first_stage__': 78 | return None 79 | elif config == "__is_unconditional__": 80 | return None 81 | raise KeyError("Expected key `target` to instantiate.") 82 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 83 | 84 | 85 | def get_obj_from_str(string, reload=False): 86 | module, cls = string.rsplit(".", 1) 87 | if reload: 88 | module_imp = importlib.import_module(module) 89 | importlib.reload(module_imp) 90 | return getattr(importlib.import_module(module, package=None), cls) 91 | 92 | 93 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 94 | # create dummy dataset instance 95 | 96 | # run prefetching 97 | if idx_to_fn: 98 | res = func(data, worker_id=idx) 99 | else: 100 | res = func(data) 101 | Q.put([idx, res]) 102 | Q.put("Done") 103 | 104 | 105 | def parallel_data_prefetch( 106 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 107 | ): 108 | # if target_data_type not in ["ndarray", "list"]: 109 | # raise ValueError( 110 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 111 | # ) 112 | if isinstance(data, np.ndarray) and target_data_type == "list": 113 | raise ValueError("list expected but function got ndarray.") 114 | elif isinstance(data, abc.Iterable): 115 | if isinstance(data, dict): 116 | print( 117 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 118 | ) 119 | data = list(data.values()) 120 | if target_data_type == "ndarray": 121 | data = np.asarray(data) 122 | else: 123 | data = list(data) 124 | else: 125 | raise TypeError( 126 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
127 | ) 128 | 129 | if cpu_intensive: 130 | Q = mp.Queue(1000) 131 | proc = mp.Process 132 | else: 133 | Q = Queue(1000) 134 | proc = Thread 135 | # spawn processes 136 | if target_data_type == "ndarray": 137 | arguments = [ 138 | [func, Q, part, i, use_worker_id] 139 | for i, part in enumerate(np.array_split(data, n_proc)) 140 | ] 141 | else: 142 | step = ( 143 | int(len(data) / n_proc + 1) 144 | if len(data) % n_proc != 0 145 | else int(len(data) / n_proc) 146 | ) 147 | arguments = [ 148 | [func, Q, part, i, use_worker_id] 149 | for i, part in enumerate( 150 | [data[i: i + step] for i in range(0, len(data), step)] 151 | ) 152 | ] 153 | processes = [] 154 | for i in range(n_proc): 155 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 156 | processes += [p] 157 | 158 | # start processes 159 | print(f"Start prefetching...") 160 | import time 161 | 162 | start = time.time() 163 | gather_res = [[] for _ in range(n_proc)] 164 | try: 165 | for p in processes: 166 | p.start() 167 | 168 | k = 0 169 | while k < n_proc: 170 | # get result 171 | res = Q.get() 172 | if res == "Done": 173 | k += 1 174 | else: 175 | gather_res[res[0]] = res[1] 176 | 177 | except Exception as e: 178 | print("Exception: ", e) 179 | for p in processes: 180 | p.terminate() 181 | 182 | raise e 183 | finally: 184 | for p in processes: 185 | p.join() 186 | print(f"Prefetching complete. [{time.time() - start} sec.]") 187 | 188 | if target_data_type == 'ndarray': 189 | if not isinstance(gather_res[0], np.ndarray): 190 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 191 | 192 | # order outputs 193 | return np.concatenate(gather_res, axis=0) 194 | elif target_data_type == 'list': 195 | out = [] 196 | for r in gather_res: 197 | out.extend(r) 198 | return out 199 | else: 200 | return gather_res 201 | -------------------------------------------------------------------------------- /models/Readme.md: -------------------------------------------------------------------------------- 1 | # Pretrained Model for the First Stage # 2 | 3 | Please put the first-stage model here as `model.ckpt`.
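4 | 
5 | For example, a minimal sketch of this step, run from the repository root (the source path of the downloaded checkpoint is a placeholder; only the destination `models/model.ckpt` is fixed by this repository):
6 | 
7 | ```
8 | # the source path below is hypothetical -- use wherever the checkpoint was downloaded to
9 | mkdir -p models
10 | cp /path/to/downloaded_first_stage.ckpt models/model.ckpt
11 | ```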
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aiosignal==1.2.0 4 | antlr4-python3-runtime==4.8 5 | appdirs==1.4.4 6 | async-timeout==4.0.2 7 | asynctest==0.13.0 8 | attrs==22.1.0 9 | audioread==3.0.0 10 | cachetools==5.2.0 11 | certifi @ file:///croot/certifi_1665076670883/work/certifi 12 | cffi==1.15.1 13 | charset-normalizer==2.1.1 14 | click==8.1.3 15 | -e git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1#egg=clip 16 | ConfigArgParse==1.5.3 17 | cycler==0.11.0 18 | Cython==0.29.32 19 | decorator==5.1.1 20 | einops==0.3.0 21 | face-alignment==1.3.5 22 | ffmpeg==1.4 23 | fonttools==4.38.0 24 | frozenlist==1.3.1 25 | fsspec==2022.10.0 26 | ftfy==6.1.1 27 | future==0.18.2 28 | google-auth==2.13.0 29 | google-auth-oauthlib==0.4.6 30 | grpcio==1.50.0 31 | idna==3.4 32 | imageio==2.22.2 33 | importlib-metadata==5.0.0 34 | joblib==1.2.0 35 | kiwisolver==1.4.4 36 | librosa==0.9.2 37 | llvmlite==0.39.1 38 | lpips==0.1.4 39 | Markdown==3.4.1 40 | MarkupSafe==2.1.1 41 | matplotlib==3.5.3 42 | mkl-fft==1.3.1 43 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work 44 | mkl-service==2.4.0 45 | multidict==6.0.2 46 | mutagen==1.46.0 47 | networkx==2.6.3 48 | numba==0.56.3 49 | numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1653915516269/work 50 | oauthlib==3.2.2 51 | omegaconf==2.1.1 52 | opencv-python==4.6.0.66 53 | packaging==21.3 54 | pandas==1.1.5 55 | Pillow==9.2.0 56 | pooch==1.6.0 57 | protobuf==3.19.6 58 | pyasn1==0.4.8 59 | pyasn1-modules==0.2.8 60 | pycparser==2.21 61 | pydub==0.25.1 62 | pynormalize==0.1.4 63 | pyparsing==3.0.9 64 | pysptk==0.1.21 65 | python-dateutil==2.8.2 66 | python-speech-features==0.6 67 | pytorch-lightning==1.2.5 68 | pytz==2022.5 69 | PyWavelets==1.3.0 70 | pyworld==0.3.1 71 | PyYAML==6.0 72 | regex==2022.9.13 73 | requests==2.28.1 74 | requests-oauthlib==1.3.1 75 | resampy==0.4.2 76 | Resemblyzer==0.1.1.dev0 77 | rsa==4.9 78 | scenedetect==0.6.0.3 79 | scikit-image==0.19.3 80 | scikit-learn==1.0.2 81 | scikit-video==1.1.11 82 | scipy==1.7.3 83 | seaborn==0.12.1 84 | six @ file:///tmp/build/80754af9/six_1644875935023/work 85 | sklearn==0.0 86 | soundfile==0.11.0 87 | -e git+https://github.com/CompVis/taming-transformers.git@24268930bf1dce879235a7fddd0b2355b84d7ea6#egg=taming_transformers 88 | tensorboard==2.10.1 89 | tensorboard-data-server==0.6.1 90 | tensorboard-plugin-wit==1.8.1 91 | tensorboardX==2.5.1 92 | test-tube==0.7.5 93 | threadpoolctl==3.1.0 94 | tifffile==2021.11.2 95 | torch==1.10.0 96 | torchaudio==0.10.0 97 | torchmetrics==0.10.0 98 | torchvision==0.11.0 99 | tqdm==4.64.1 100 | typing==3.7.4.3 101 | typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work 102 | urllib3==1.26.12 103 | wcwidth==0.2.5 104 | webrtcvad==2.0.10 105 | Werkzeug==2.2.2 106 | yarl==1.8.1 107 | zipp==3.9.0 108 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --base configs/latent-diffusion/talking.yaml -t --gpus 0,1,2,3,4,5,6,7, -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
from omegaconf import OmegaConf 3 | import sys 4 | import os 5 | from ldm.util import instantiate_from_config 6 | from ldm.models.diffusion.ddim_ldm_ref_inpaint import DDIMSampler 7 | import numpy as np 8 | from PIL import Image 9 | from einops import rearrange 10 | from torchvision.utils import make_grid 11 | from torch.utils.data import random_split, DataLoader, Dataset, Subset 12 | from omegaconf import OmegaConf 13 | import configargparse 14 | import pdb 15 | 16 | def config_parser(): 17 | parser = configargparse.ArgumentParser() 18 | parser.add_argument('--batchsize', type=int, default=20,) 19 | parser.add_argument('--numworkers', type=int, default=12,) 20 | parser.add_argument('--save_dir', type=str, default='./logs/inference', ) 21 | 22 | return parser 23 | 24 | def load_model_from_config(config, ckpt): 25 | print(f"Loading model from {ckpt}") 26 | pl_sd = torch.load(ckpt) 27 | sd = pl_sd["state_dict"] 28 | model = instantiate_from_config(config.model) 29 | m, u = model.load_state_dict(sd, strict=False) 30 | model.cuda() 31 | model.eval() 32 | return model 33 | 34 | def get_model(): 35 | config = OmegaConf.load("configs/latent-diffusion/talking-inference.yaml") 36 | model = load_model_from_config(config, "logs/xxx.ckpt") 37 | return model 38 | 39 | 40 | parser = config_parser() 41 | args = parser.parse_args() 42 | 43 | model = get_model() 44 | sampler = DDIMSampler(model) 45 | 46 | 47 | ddim_steps = 200 48 | ddim_eta = 0.0 49 | use_ddim = True 50 | log = dict() 51 | samples = [] 52 | samples_inpainting= [] 53 | xrec_img = [] 54 | 55 | 56 | # init and save configs 57 | config_file= 'configs/latent-diffusion/talking-inference.yaml' 58 | config = OmegaConf.load(config_file) 59 | data = instantiate_from_config(config.data) 60 | dataset_configs = config.data['params']['validation'] 61 | datasets = dict([('validation', instantiate_from_config(dataset_configs))]) 62 | 63 | print("#### Data #####") 64 | for k in datasets: 65 | print(f"{k}, {datasets[k].__class__.__name__}, {len(datasets[k])}") 66 | 67 | val_dataloader = DataLoader(datasets["validation"], batch_size=args.batchsize, num_workers=args.numworkers, shuffle=False) 68 | 69 | with torch.no_grad(): 70 | for i, batch in enumerate(val_dataloader): 71 | samples = [] 72 | samples_inpainting = [] 73 | xrec_img = [] 74 | z, c_audio, c_lip, c_ldm, c_mask, x, xrec, xc_audio, xc_lip = model.get_input(batch, 'image', 75 | return_first_stage_outputs=True, 76 | force_c_encode=True, 77 | return_original_cond=True, 78 | bs=args.batchsize) 79 | shape = (z.shape[1], z.shape[2], z.shape[3]) 80 | N = min(x.shape[0], args.batchsize) 81 | c = {'audio': c_audio, 'lip': c_lip, 'ldm': c_ldm, 'mask_image': c_mask} 82 | 83 | b, h, w = z.shape[0], z.shape[2], z.shape[3] 84 | landmarks = batch["landmarks_all"] 85 | landmarks = landmarks / 4 86 | mask = batch["inference_mask"].to(model.device) 87 | mask = mask[:, None, ...] 
88 | with model.ema_scope(): 89 | samples_ddim, _ = sampler.sample(ddim_steps, N, shape, c, x0=z[:N], verbose=False, eta=ddim_eta, mask=mask) 90 | 91 | x_samples_ddim = model.decode_first_stage(samples_ddim) 92 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, 93 | min=0.0, max=1.0) 94 | samples_inpainting.append(x_samples_ddim) 95 | 96 | #save images 97 | samples_inpainting = torch.stack(samples_inpainting, 0) 98 | samples_inpainting = rearrange(samples_inpainting, 'n b c h w -> (n b) c h w') 99 | save_path = os.path.join(args.save_dir, '105_a105_mask_face') 100 | if not os.path.exists(save_path): 101 | os.mkdir(save_path) 102 | for j in range(samples_inpainting.shape[0]): 103 | samples_inpainting_img = 255. * rearrange(samples_inpainting[j], 'c h w -> h w c').cpu().numpy() 104 | img = Image.fromarray(samples_inpainting_img.astype(np.uint8)) 105 | img.save(os.path.join(save_path, '{:04d}_{:04d}.jpg'.format(i, j))) 106 | 107 | 108 | --------------------------------------------------------------------------------
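Usage note for `scripts/inference.py`: the checkpoint path `logs/xxx.ckpt` in `get_model()` is a placeholder and must be edited to point at a trained DiffTalk checkpoint before running. With that done, a hedged example invocation from the repository root, using only the flags defined in `config_parser()` (the values shown are simply the defaults), is:

```
python scripts/inference.py --batchsize 20 --numworkers 12 --save_dir ./logs/inference
```

The synthesized frames are then written under `<save_dir>/105_a105_mask_face/` as `0000_0000.jpg`-style files by the save loop at the end of the script. Note that the `--save_dir` directory itself must already exist, since the script only calls `os.mkdir` on the final subdirectory.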