├── .gitignore
├── LICENSE
├── MCVD_demo_SMMNIST.ipynb
├── README.md
├── configs
│   ├── bair.yml
│   ├── bair_big.yml
│   ├── bair_big_spade.yml
│   ├── bedroom.yml
│   ├── celeba.yml
│   ├── church.yml
│   ├── cifar10.yml
│   ├── cityscapes.yml
│   ├── cityscapes_big.yml
│   ├── cityscapes_big_spade.yml
│   ├── ffhq.yml
│   ├── kth64_big.yml
│   ├── kth64_big_spade.yml
│   ├── smmnist_DDPM_big5.yml
│   ├── smmnist_DDPM_big5_spade.yml
│   ├── smmnist_DDPM_small5.yml
│   ├── smmnist_DDPM_small5_3d_32Gb.yml
│   ├── tower.yml
│   └── ucf101.yml
├── datasets
│   ├── __init__.py
│   ├── bair.py
│   ├── bair_convert.py
│   ├── bair_download.sh
│   ├── celeba.py
│   ├── cityscapes.py
│   ├── cityscapes_convert.py
│   ├── cityscapes_download.sh
│   ├── ffhq.py
│   ├── ffhq_tfrecords.py
│   ├── h5.py
│   ├── imagenet.py
│   ├── kinetics600_convert.py
│   ├── kth.py
│   ├── kth_convert.py
│   ├── kth_download.sh
│   ├── kth_sequences.txt
│   ├── moving_mnist.py
│   ├── stochastic_moving_mnist.py
│   ├── ucf101.py
│   ├── ucf101_convert.py
│   ├── ucf101_download.sh
│   ├── utils.py
│   └── vision.py
├── evaluation
│   ├── fid_PR.py
│   ├── fid_score_OLD.py
│   ├── inception.py
│   ├── nearest_neighbor.py
│   └── pr.py
├── example_scripts
│   ├── final
│   │   ├── base_1f.sh
│   │   ├── base_1f_2.sh
│   │   ├── base_1f_4.sh
│   │   ├── base_1f_vidgen.sh
│   │   ├── base_1f_vidgen_short.sh
│   │   ├── sampling_scripts.sh
│   │   ├── simple_sample.py
│   │   └── training_scripts.sh
│   └── video_gen_metrics.sh
├── load_model_from_ckpt.py
├── losses
│   ├── __init__.py
│   └── dsm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── better
│   │   ├── __init__.py
│   │   ├── layers.py
│   │   ├── layers3d.py
│   │   ├── layerspp.py
│   │   ├── ncsnpp_more.py
│   │   ├── normalization.py
│   │   ├── op
│   │   │   ├── __init__.py
│   │   │   ├── fused_act.py
│   │   │   ├── fused_bias_act.cpp
│   │   │   ├── fused_bias_act_kernel.cu
│   │   │   ├── upfirdn2d.cpp
│   │   │   ├── upfirdn2d.py
│   │   │   └── upfirdn2d_kernel.cu
│   │   ├── up_or_down_sampling.py
│   │   └── utils.py
│   ├── dist_model.py
│   ├── ema.py
│   ├── eval_models.py
│   ├── fvd
│   │   ├── __init__.py
│   │   ├── convert_tf_pretrained.py
│   │   ├── fvd.py
│   │   └── pytorch_i3d.py
│   ├── networks_basic.py
│   ├── pndm.py
│   ├── pretrained_networks.py
│   ├── unet.py
│   └── weights
│       ├── v0.0
│       │   ├── alex.pth
│       │   ├── squeeze.pth
│       │   └── vgg.pth
│       └── v0.1
│           ├── alex.pth
│           ├── squeeze.pth
│           └── vgg.pth
├── quick_sample.py
├── requirements.txt
└── runners
    ├── __init__.py
    └── ncsn_runner.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.DS_Store 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vikram Voleti (vikram.voleti@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
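The configs/*.yml files that follow hold the training, sampling, data, model, and optim settings for each experiment. A minimal sketch of loading one of them for attribute-style access (the dict2namespace helper below is illustrative, not necessarily the exact code in main.py):

    import argparse
    import yaml

    def dict2namespace(d):
        # Recursively turn a nested dict into argparse.Namespace objects,
        # so values can be read as config.data.num_frames, config.model.ngf, etc.
        ns = argparse.Namespace()
        for key, value in d.items():
            setattr(ns, key, dict2namespace(value) if isinstance(value, dict) else value)
        return ns

    with open("configs/bair.yml", "r") as f:
        config = dict2namespace(yaml.safe_load(f))

    print(config.data.num_frames, config.data.num_frames_cond)  # 10 2 for bair.yml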
22 | -------------------------------------------------------------------------------- /configs/bair.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 32 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "BAIR" 58 | image_size: 64 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.2 66 | num_workers: 0 67 | num_frames: 10 68 | num_frames_cond: 2 69 | num_frames_future: 0 70 | prob_mask_cond: 0.0 71 | prob_mask_future: 0.0 72 | prob_mask_sync: false 73 | 74 | model: 75 | depth: deeper 76 | version: DDPM 77 | gamma: false 78 | arch: unetmore 79 | type: v1 80 | time_conditional: true 81 | dropout: 0.1 82 | sigma_dist: linear 83 | sigma_begin: 0.02 84 | sigma_end: 0.0001 85 | num_classes: 1000 86 | ema: true 87 | ema_rate: 0.999 88 | spec_norm: false 89 | normalization: InstanceNorm++ 90 | nonlinearity: swish 91 | ngf: 32 92 | ch_mult: 93 | - 1 94 | - 2 95 | - 2 96 | - 2 97 | num_res_blocks: 3 # 8 for traditional 98 | attn_resolutions: 99 | - 8 100 | - 16 101 | - 32 # can use only 16 for traditional 102 | n_head_channels: 64 # -1 for traditional 103 | conditional: true 104 | noise_in_cond: false 105 | output_all_frames: false # could be useful especially for 3d models 106 | cond_emb: false 107 | spade: false 108 | spade_dim: 128 109 | 110 | optim: 111 | weight_decay: 0.000 112 | optimizer: "Adam" 113 | lr: 0.0001 114 | warmup: 5000 115 | beta1: 0.9 116 | amsgrad: false 117 | eps: 0.00000001 118 | grad_clip: 1.0 119 | -------------------------------------------------------------------------------- /configs/bair_big.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 
| num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "BAIR" 58 | image_size: 64 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | num_workers: 0 67 | test_subset: -1 68 | num_frames: 5 69 | num_frames_cond: 2 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deeper 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 96 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 96 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: false 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0001 115 | warmup: 5000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/bair_big_spade.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "BAIR" 58 | image_size: 64 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | test_subset: -1 
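A note on the data settings in these video configs: num_frames is the number of frames the model denoises at once, num_frames_cond / num_frames_future are the numbers of past and future conditioning frames, and prob_mask_cond / prob_mask_future are the probabilities of dropping that conditioning during training, which is what lets a single network handle prediction, generation, and interpolation. A sketch of such Bernoulli masking (names and shapes are hypothetical; the repository's actual masking lives in its runner/model code):

    import torch

    def mask_cond_frames(cond_frames, prob_mask_cond):
        # cond_frames: (B, T_cond, C, H, W) past frames used as conditioning.
        # With probability prob_mask_cond per sample, zero out the conditioning
        # so the model also learns the unconditional case.
        B = cond_frames.shape[0]
        keep = (torch.rand(B, device=cond_frames.device) >= prob_mask_cond).float()
        return cond_frames * keep.view(B, 1, 1, 1, 1), keep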
67 | num_workers: 0 68 | num_frames: 5 69 | num_frames_cond: 2 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deeper 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 96 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 96 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: true 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0001 115 | warmup: 5000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/bedroom.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 128 4 | n_epochs: 500000 5 | n_iters: 150001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_all_sigmas: false 11 | 12 | sampling: 13 | batch_size: 36 14 | data_init: false 15 | step_lr: 0.0000018 16 | n_steps_each: 3 17 | ckpt_id: 0 18 | final_only: true 19 | fid: false 20 | ssim: true 21 | fvd: true 22 | denoise: true 23 | num_samples4fid: 10000 24 | inpainting: false 25 | interpolation: false 26 | n_interpolations: 10 27 | clip_before: true 28 | max_data_iter: 100000 29 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 30 | one_frame_at_a_time: false 31 | preds_per_test: 1 32 | 33 | fast_fid: 34 | batch_size: 1000 35 | num_samples: 1000 36 | step_lr: 0.0000018 37 | n_steps_each: 3 38 | begin_ckpt: 100000 39 | end_ckpt: 150000 40 | verbose: false 41 | ensemble: false 42 | 43 | test: 44 | begin_ckpt: 5000 45 | end_ckpt: 150000 46 | batch_size: 100 47 | 48 | data: 49 | dataset: "LSUN" 50 | category: "bedroom" 51 | image_size: 128 52 | channels: 3 53 | logit_transform: false 54 | uniform_dequantization: false 55 | gaussian_dequantization: false 56 | random_flip: true 57 | rescaled: false 58 | num_workers: 32 59 | num_frames: 1 60 | num_frames_cond: 0 61 | num_frames_future: 0 62 | prob_mask_cond: 0.0 63 | prob_mask_future: 0.0 64 | prob_mask_sync: false 65 | 66 | model: 67 | depth: deeper 68 | sigma_begin: 190 69 | num_classes: 1086 70 | ema: true 71 | ema_rate: 0.999 72 | spec_norm: false 73 | sigma_dist: geometric 74 | sigma_end: 0.01 75 | normalization: InstanceNorm++ 76 | nonlinearity: elu 77 | ngf: 128 78 | ch_mult: 79 | - 1 80 | - 2 81 | - 2 82 | - 2 83 | num_res_blocks: 1 # 8 for traditional 84 | attn_resolutions: 85 | - 8 86 | - 16 87 | - 32 # can use only 16 for traditional 88 | n_head_channels: 64 # -1 for traditional 89 | conditional: false 90 | noise_in_cond: false 91 | output_all_frames: false # could be useful especially for 3d models 92 | cond_emb: false 93 | spade: false 94 | spade_dim: 128 95 | 96 | optim: 97 | weight_decay: 0.000 98 | optimizer: "Adam" 99 | lr: 0.0001 100 | beta1: 0.9 101 | amsgrad: false 102 | eps: 
0.00000001 103 | -------------------------------------------------------------------------------- /configs/celeba.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 128 4 | n_epochs: 500000 5 | n_iters: 210001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | anneal_power: 2 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 64 15 | data_init: false 16 | step_lr: 0.0000033 17 | n_steps_each: 5 18 | ckpt_id: 0 19 | final_only: true 20 | fid: false 21 | ssim: true 22 | fvd: true 23 | denoise: true 24 | num_samples4fid: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | clip_before: true 29 | max_data_iter: 100000 30 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 31 | one_frame_at_a_time: false 32 | preds_per_test: 1 33 | 34 | fast_fid: 35 | batch_size: 1000 36 | num_samples: 1000 37 | step_lr: 0.0000033 38 | n_steps_each: 5 39 | begin_ckpt: 5000 40 | end_ckpt: 210000 41 | verbose: false 42 | ensemble: false 43 | 44 | test: 45 | begin_ckpt: 5000 46 | end_ckpt: 210000 47 | batch_size: 100 48 | 49 | data: 50 | dataset: "CELEBA" 51 | image_size: 64 52 | channels: 3 53 | logit_transform: false 54 | uniform_dequantization: false 55 | gaussian_dequantization: false 56 | random_flip: true 57 | rescaled: false 58 | num_workers: 32 59 | num_frames: 1 60 | num_frames_cond: 0 61 | num_frames_future: 0 62 | prob_mask_cond: 0.0 63 | prob_mask_future: 0.0 64 | prob_mask_sync: false 65 | 66 | model: 67 | depth: deep 68 | sigma_begin: 90 69 | num_classes: 500 70 | ema: true 71 | ema_rate: 0.999 72 | spec_norm: false 73 | sigma_dist: geometric 74 | sigma_end: 0.01 75 | normalization: InstanceNorm++ 76 | nonlinearity: elu 77 | ngf: 128 78 | ch_mult: 79 | - 1 80 | - 2 81 | - 2 82 | - 2 83 | num_res_blocks: 1 # 8 for traditional 84 | attn_resolutions: 85 | - 8 86 | - 16 87 | - 32 # can use only 16 for traditional 88 | n_head_channels: 64 # -1 for traditional 89 | conditional: false 90 | noise_in_cond: false 91 | output_all_frames: false # could be useful especially for 3d models 92 | cond_emb: false 93 | spade: false 94 | spade_dim: 128 95 | 96 | optim: 97 | weight_decay: 0.000 98 | optimizer: "Adam" 99 | lr: 0.0001 100 | beta1: 0.9 101 | amsgrad: false 102 | eps: 0.00000001 103 | -------------------------------------------------------------------------------- /configs/church.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 128 4 | n_epochs: 500000 5 | n_iters: 200001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | anneal_power: 2 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 96 15 | data_init: false 16 | step_lr: 0.0000049 17 | n_steps_each: 4 18 | ckpt_id: 0 19 | final_only: true 20 | fid: false 21 | ssim: true 22 | fvd: true 23 | denoise: true 24 | num_samples4fid: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 12 28 | clip_before: true 29 | max_data_iter: 100000 30 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 31 | one_frame_at_a_time: false 32 | preds_per_test: 1 33 | 34 | fast_fid: 35 | batch_size: 1000 36 | num_samples: 1000 37 | step_lr: 0.0000049 38 | n_steps_each: 4 39 | begin_ckpt: 100000 40 | end_ckpt: 200000 41 | verbose: false 42 | ensemble: false 
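The single-image configs in this group (bedroom, celeba, church, cifar10, ffhq, tower) use sigma_dist: geometric, i.e. an NCSN/SMLD-style set of num_classes noise scales spaced geometrically between sigma_begin and sigma_end. A sketch of how such a sequence is typically built (np.geomspace here is illustrative; the repository may construct it differently):

    import numpy as np

    def geometric_sigmas(sigma_begin, sigma_end, num_classes):
        # num_classes noise levels decreasing geometrically from sigma_begin to sigma_end.
        return np.geomspace(sigma_begin, sigma_end, num_classes)

    sigmas = geometric_sigmas(140.0, 0.01, 788)  # values from configs/church.yml
    print(sigmas[0], sigmas[-1])                 # 140.0 ... 0.01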
43 | 44 | test: 45 | begin_ckpt: 5000 46 | end_ckpt: 200000 47 | batch_size: 100 48 | 49 | data: 50 | dataset: "LSUN" 51 | category: "church_outdoor" 52 | image_size: 64 53 | channels: 3 54 | logit_transform: false 55 | uniform_dequantization: false 56 | gaussian_dequantization: false 57 | random_flip: true 58 | rescaled: false 59 | num_workers: 32 60 | num_frames: 1 61 | num_frames_cond: 0 62 | num_frames_future: 0 63 | prob_mask_cond: 0.0 64 | prob_mask_future: 0.0 65 | prob_mask_sync: false 66 | 67 | model: 68 | depth: deeper 69 | sigma_begin: 140 70 | num_classes: 788 71 | ema: true 72 | ema_rate: 0.999 73 | spec_norm: false 74 | sigma_dist: geometric 75 | sigma_end: 0.01 76 | normalization: InstanceNorm++ 77 | nonlinearity: elu 78 | ngf: 128 79 | ch_mult: 80 | - 1 81 | - 2 82 | - 2 83 | - 2 84 | num_res_blocks: 1 # 8 for traditional 85 | attn_resolutions: 86 | - 8 87 | - 16 88 | - 32 # can use only 16 for traditional 89 | n_head_channels: 64 # -1 for traditional 90 | conditional: false 91 | noise_in_cond: false 92 | output_all_frames: false # could be useful especially for 3d models 93 | cond_emb: false 94 | spade: false 95 | spade_dim: 128 96 | 97 | optim: 98 | weight_decay: 0.000 99 | optimizer: "Adam" 100 | lr: 0.0001 101 | beta1: 0.9 102 | amsgrad: false 103 | eps: 0.00000001 104 | -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 128 4 | n_epochs: 500000 5 | n_iters: 300001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_all_sigmas: false 11 | 12 | sampling: 13 | batch_size: 100 14 | data_init: false 15 | step_lr: 0.0000062 16 | n_steps_each: 5 17 | ckpt_id: 0 18 | final_only: true 19 | fid: false 20 | ssim: true 21 | fvd: true 22 | denoise: true 23 | num_samples4fid: 10000 24 | inpainting: false 25 | interpolation: false 26 | n_interpolations: 15 27 | clip_before: true 28 | max_data_iter: 100000 29 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 30 | one_frame_at_a_time: false 31 | preds_per_test: 1 32 | 33 | fast_fid: 34 | batch_size: 1000 35 | num_samples: 1000 36 | step_lr: 0.0000062 37 | n_steps_each: 5 38 | begin_ckpt: 5000 39 | end_ckpt: 300000 40 | verbose: false 41 | ensemble: false 42 | 43 | test: 44 | begin_ckpt: 5000 45 | end_ckpt: 300000 46 | batch_size: 100 47 | 48 | data: 49 | dataset: "CIFAR10" 50 | image_size: 32 51 | channels: 3 52 | logit_transform: false 53 | uniform_dequantization: false 54 | gaussian_dequantization: false 55 | random_flip: true 56 | rescaled: false 57 | num_workers: 0 58 | num_frames: 1 59 | num_frames_cond: 0 60 | num_frames_future: 0 61 | prob_mask_cond: 0.0 62 | prob_mask_future: 0.0 63 | prob_mask_sync: false 64 | 65 | model: 66 | depth: deep 67 | version: SMLD 68 | arch: ncsn 69 | sigma_dist: geometric 70 | sigma_begin: 50 71 | sigma_end: 0.01 72 | num_classes: 232 73 | ema: true 74 | ema_rate: 0.999 75 | spec_norm: false 76 | normalization: InstanceNorm++ 77 | nonlinearity: elu 78 | ngf: 128 79 | ch_mult: 80 | - 1 81 | - 2 82 | - 2 83 | - 2 84 | num_res_blocks: 1 # 8 for traditional 85 | attn_resolutions: 86 | - 8 87 | - 16 88 | - 32 # can use only 16 for traditional 89 | n_head_channels: 64 # -1 for traditional 90 | conditional: false 91 | noise_in_cond: false 92 | output_all_frames: false # could be useful especially for 3d models 93 | cond_emb: false 94 | 
spade: false 95 | spade_dim: 128 96 | 97 | optim: 98 | weight_decay: 0.000 99 | optimizer: "Adam" 100 | lr: 0.0001 101 | warmup: 0 102 | beta1: 0.9 103 | amsgrad: false 104 | eps: 0.00000001 105 | -------------------------------------------------------------------------------- /configs/cityscapes.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 32 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "Cityscapes" 58 | image_size: 128 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | num_workers: 0 67 | num_frames: 2 68 | num_frames_cond: 2 69 | num_frames_future: 0 70 | prob_mask_cond: 0.0 71 | prob_mask_future: 0.0 72 | prob_mask_sync: false 73 | 74 | model: 75 | depth: deeper 76 | version: DDPM 77 | gamma: false 78 | arch: unetmore 79 | type: v1 80 | time_conditional: true 81 | dropout: 0.0 82 | sigma_dist: linear 83 | sigma_begin: 0.02 84 | sigma_end: 0.0001 85 | num_classes: 1000 86 | ema: true 87 | ema_rate: 0.999 88 | spec_norm: false 89 | normalization: InstanceNorm++ 90 | nonlinearity: swish 91 | ngf: 32 92 | ch_mult: 93 | - 1 94 | - 2 95 | - 2 96 | - 2 97 | num_res_blocks: 3 # 8 for traditional 98 | attn_resolutions: 99 | - 8 100 | - 16 101 | - 32 # can use only 16 for traditional 102 | n_head_channels: 64 # -1 for traditional 103 | conditional: true 104 | noise_in_cond: false 105 | output_all_frames: false # could be useful especially for 3d models 106 | cond_emb: false 107 | spade: false 108 | spade_dim: 128 109 | 110 | optim: 111 | weight_decay: 0.000 112 | optimizer: "Adam" 113 | lr: 0.0001 114 | warmup: 5000 115 | beta1: 0.9 116 | amsgrad: false 117 | eps: 0.00000001 118 | grad_clip: 1.0 119 | -------------------------------------------------------------------------------- /configs/cityscapes_big.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 
| num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "Cityscapes" 58 | image_size: 128 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | num_workers: 0 67 | num_frames: 5 68 | num_frames_cond: 2 69 | num_frames_future: 0 70 | prob_mask_cond: 0.0 71 | prob_mask_future: 0.0 72 | prob_mask_sync: false 73 | 74 | model: 75 | depth: deeper 76 | version: DDPM 77 | gamma: false 78 | arch: unetmore 79 | type: v1 80 | time_conditional: true 81 | dropout: 0.0 82 | sigma_dist: linear 83 | sigma_begin: 0.02 84 | sigma_end: 0.0001 85 | num_classes: 1000 86 | ema: true 87 | ema_rate: 0.999 88 | spec_norm: false 89 | normalization: InstanceNorm++ 90 | nonlinearity: swish 91 | ngf: 128 92 | ch_mult: 93 | - 1 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 128 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: false 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0001 115 | warmup: 5000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/cityscapes_big_spade.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 32 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "Cityscapes" 58 | 
image_size: 128 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | num_workers: 0 67 | num_frames: 5 68 | num_frames_cond: 2 69 | num_frames_future: 0 70 | prob_mask_cond: 0.0 71 | prob_mask_future: 0.0 72 | prob_mask_sync: false 73 | 74 | model: 75 | depth: deeper 76 | version: DDPM 77 | gamma: false 78 | arch: unetmore 79 | type: v1 80 | time_conditional: true 81 | dropout: 0.0 82 | sigma_dist: linear 83 | sigma_begin: 0.02 84 | sigma_end: 0.0001 85 | num_classes: 1000 86 | ema: true 87 | ema_rate: 0.999 88 | spec_norm: false 89 | normalization: InstanceNorm++ 90 | nonlinearity: swish 91 | ngf: 192 92 | ch_mult: 93 | - 1 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 192 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: true 109 | spade_dim: 256 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0001 115 | warmup: 5000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/ffhq.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 32 4 | n_epochs: 500000 5 | n_iters: 80001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | anneal_power: 2 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 36 15 | data_init: false 16 | step_lr: 0.0000009 17 | n_steps_each: 3 18 | ckpt_id: 0 19 | final_only: true 20 | fid: false 21 | ssim: true 22 | fvd: true 23 | denoise: true 24 | num_samples4fid: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 8 28 | clip_before: true 29 | max_data_iter: 100000 30 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 31 | one_frame_at_a_time: false 32 | preds_per_test: 1 33 | 34 | fast_fid: 35 | batch_size: 1000 36 | num_samples: 1000 37 | step_lr: 0.0000009 38 | n_steps_each: 3 39 | begin_ckpt: 100000 40 | end_ckpt: 80000 41 | verbose: false 42 | ensemble: false 43 | 44 | test: 45 | begin_ckpt: 5000 46 | end_ckpt: 80000 47 | batch_size: 100 48 | 49 | data: 50 | dataset: "FFHQ" 51 | image_size: 256 52 | channels: 3 53 | logit_transform: false 54 | uniform_dequantization: false 55 | gaussian_dequantization: false 56 | random_flip: true 57 | rescaled: false 58 | num_workers: 8 59 | num_frames: 1 60 | num_frames_cond: 0 61 | num_frames_future: 0 62 | prob_mask_cond: 0.0 63 | prob_mask_future: 0.0 64 | prob_mask_sync: false 65 | 66 | model: 67 | depth: deepest 68 | sigma_begin: 348 69 | num_classes: 2311 70 | ema: true 71 | ema_rate: 0.999 72 | spec_norm: false 73 | sigma_dist: geometric 74 | sigma_end: 0.01 75 | normalization: InstanceNorm++ 76 | nonlinearity: elu 77 | ngf: 128 78 | ch_mult: 79 | - 1 80 | - 2 81 | - 2 82 | - 2 83 | num_res_blocks: 1 # 8 for traditional 84 | attn_resolutions: 85 | - 8 86 | - 16 87 | - 32 # can use only 16 for traditional 88 | n_head_channels: 64 # -1 for traditional 89 | conditional: false 90 | noise_in_cond: false 91 | output_all_frames: false # could be useful especially for 3d 
models 92 | cond_emb: false 93 | spade: false 94 | spade_dim: 128 95 | 96 | optim: 97 | weight_decay: 0.000 98 | optimizer: "Adam" 99 | lr: 0.0001 100 | beta1: 0.9 101 | amsgrad: false 102 | eps: 0.001 103 | -------------------------------------------------------------------------------- /configs/kth64_big.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | one_frame_at_a_time: false 36 | preds_per_test: 1 37 | 38 | fast_fid: 39 | batch_size: 1000 40 | num_samples: 1000 41 | begin_ckpt: 5000 42 | freq: 5000 43 | end_ckpt: 300000 44 | pr_nn_k: 3 45 | verbose: false 46 | ensemble: false 47 | step_lr: 0.0 48 | n_steps_each: 0 49 | 50 | test: 51 | begin_ckpt: 5000 52 | end_ckpt: 300000 53 | batch_size: 100 54 | 55 | data: 56 | dataset: "KTH" 57 | image_size: 64 58 | channels: 1 59 | logit_transform: false 60 | uniform_dequantization: false 61 | gaussian_dequantization: false 62 | random_flip: true 63 | rescaled: true 64 | num_workers: 0 65 | num_frames: 5 66 | num_frames_cond: 10 67 | num_frames_future: 0 68 | prob_mask_cond: 0.0 69 | prob_mask_future: 0.0 70 | prob_mask_sync: false 71 | 72 | model: 73 | depth: deeper 74 | version: DDPM 75 | gamma: false 76 | arch: unetmore 77 | type: v1 78 | time_conditional: true 79 | dropout: 0.1 80 | sigma_dist: linear 81 | sigma_begin: 0.02 82 | sigma_end: 0.0001 83 | num_classes: 1000 84 | ema: true 85 | ema_rate: 0.999 86 | spec_norm: false 87 | normalization: InstanceNorm++ 88 | nonlinearity: swish 89 | ngf: 96 90 | ch_mult: 91 | - 1 92 | - 2 93 | - 3 94 | - 4 95 | num_res_blocks: 2 # 8 for traditional 96 | attn_resolutions: 97 | - 8 98 | - 16 99 | - 32 # can use only 16 for traditional 100 | n_head_channels: 96 # -1 for traditional 101 | conditional: true 102 | noise_in_cond: false 103 | output_all_frames: false # could be useful especially for 3d models 104 | cond_emb: false 105 | spade: false 106 | spade_dim: 128 107 | 108 | optim: 109 | weight_decay: 0.000 110 | optimizer: "Adam" 111 | lr: 0.0002 112 | warmup: 5000 113 | beta1: 0.9 114 | amsgrad: false 115 | eps: 0.00000001 116 | grad_clip: 1.0 117 | -------------------------------------------------------------------------------- /configs/kth64_big_spade.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | 
n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | one_frame_at_a_time: false 36 | preds_per_test: 1 37 | 38 | fast_fid: 39 | batch_size: 1000 40 | num_samples: 1000 41 | begin_ckpt: 5000 42 | freq: 5000 43 | end_ckpt: 300000 44 | pr_nn_k: 3 45 | verbose: false 46 | ensemble: false 47 | step_lr: 0.0 48 | n_steps_each: 0 49 | 50 | test: 51 | begin_ckpt: 5000 52 | end_ckpt: 300000 53 | batch_size: 100 54 | 55 | data: 56 | dataset: "KTH" 57 | image_size: 64 58 | channels: 1 59 | logit_transform: false 60 | uniform_dequantization: false 61 | gaussian_dequantization: false 62 | random_flip: true 63 | rescaled: true 64 | num_workers: 0 65 | num_frames: 5 66 | num_frames_cond: 10 67 | num_frames_future: 0 68 | prob_mask_cond: 0.0 69 | prob_mask_future: 0.0 70 | prob_mask_sync: false 71 | 72 | model: 73 | depth: deeper 74 | version: DDPM 75 | gamma: false 76 | arch: unetmore 77 | type: v1 78 | time_conditional: true 79 | dropout: 0.1 80 | sigma_dist: linear 81 | sigma_begin: 0.02 82 | sigma_end: 0.0001 83 | num_classes: 1000 84 | ema: true 85 | ema_rate: 0.999 86 | spec_norm: false 87 | normalization: InstanceNorm++ 88 | nonlinearity: swish 89 | ngf: 96 90 | ch_mult: 91 | - 1 92 | - 2 93 | - 3 94 | - 4 95 | num_res_blocks: 2 # 8 for traditional 96 | attn_resolutions: 97 | - 8 98 | - 16 99 | - 32 # can use only 16 for traditional 100 | n_head_channels: 96 # -1 for traditional 101 | conditional: true 102 | noise_in_cond: false 103 | output_all_frames: false # could be useful especially for 3d models 104 | cond_emb: false 105 | spade: true 106 | spade_dim: 128 107 | 108 | optim: 109 | weight_decay: 0.000 110 | optimizer: "Adam" 111 | lr: 0.0002 112 | warmup: 5000 113 | beta1: 0.9 114 | amsgrad: false 115 | eps: 0.00000001 116 | grad_clip: 1.0 117 | -------------------------------------------------------------------------------- /configs/smmnist_DDPM_big5.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_freq: 50 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 1000 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "StochasticMovingMNIST" 58 | image_size: 64 59 | channels: 1 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | num_workers: 0 66 | num_digits: 2 67 | step_length: 0.1 68 | 
num_frames: 5 69 | num_frames_cond: 5 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deep 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 64 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 64 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: false 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0002 115 | warmup: 1000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/smmnist_DDPM_big5_spade.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_freq: 50 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 1000 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "StochasticMovingMNIST" 58 | image_size: 64 59 | channels: 1 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | num_workers: 0 66 | num_digits: 2 67 | step_length: 0.1 68 | num_frames: 5 69 | num_frames_cond: 5 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deep 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 64 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 3 97 | - 4 98 | num_res_blocks: 2 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 64 # -1 for traditional 104 | conditional: true 
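In the DDPM-style video configs (version: DDPM, sigma_dist: linear), sigma_begin: 0.02, sigma_end: 0.0001 and num_classes: 1000 read like the usual DDPM beta schedule: 1000 diffusion steps with betas increasing linearly from 1e-4 to 2e-2. A sketch under that assumption (the repository's schedule code may differ in detail):

    import numpy as np

    def linear_betas(sigma_end, sigma_begin, num_classes):
        # Betas rise linearly from sigma_end (1e-4) to sigma_begin (2e-2) over
        # num_classes steps; alphas_cumprod is the usual prod over t of (1 - beta_t).
        betas = np.linspace(sigma_end, sigma_begin, num_classes)
        alphas_cumprod = np.cumprod(1.0 - betas)
        return betas, alphas_cumprod

    betas, a_bar = linear_betas(0.0001, 0.02, 1000)
    print(betas[0], betas[-1], a_bar[-1])  # 1e-4, 2e-2, ~4e-5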
105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: true 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0002 115 | warmup: 1000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/smmnist_DDPM_small5.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_freq: 50 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 1000 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "StochasticMovingMNIST" 58 | image_size: 64 59 | channels: 1 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | num_workers: 0 66 | num_digits: 2 67 | step_length: 0.1 68 | num_frames: 2 69 | num_frames_cond: 5 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deep 77 | version: DDPM 78 | gamma: false 79 | arch: unet 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 32 93 | ch_mult: 94 | - 1 95 | - 2 96 | - 2 97 | - 2 98 | num_res_blocks: 1 # 8 for traditional 99 | attn_resolutions: 100 | - 8 101 | - 16 102 | - 32 # can use only 16 for traditional 103 | n_head_channels: 64 # -1 for traditional 104 | conditional: true 105 | noise_in_cond: false 106 | output_all_frames: false # could be useful especially for 3d models 107 | cond_emb: false 108 | spade: false 109 | spade_dim: 128 110 | 111 | optim: 112 | weight_decay: 0.000 113 | optimizer: "Adam" 114 | lr: 0.0002 115 | warmup: 1000 116 | beta1: 0.9 117 | amsgrad: false 118 | eps: 0.00000001 119 | grad_clip: 1.0 120 | -------------------------------------------------------------------------------- /configs/smmnist_DDPM_small5_3d_32Gb.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | log_freq: 50 11 | 
log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 1000 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 20 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "StochasticMovingMNIST" 58 | image_size: 64 59 | channels: 1 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | num_workers: 0 66 | num_digits: 2 67 | step_length: 0.1 68 | num_frames: 2 69 | num_frames_cond: 5 70 | num_frames_future: 0 71 | prob_mask_cond: 0.0 72 | prob_mask_future: 0.0 73 | prob_mask_sync: false 74 | 75 | model: 76 | depth: deep 77 | version: DDPM 78 | gamma: false 79 | arch: unetmore3d 80 | type: v1 81 | time_conditional: true 82 | dropout: 0.1 83 | sigma_dist: linear 84 | sigma_begin: 0.02 85 | sigma_end: 0.0001 86 | num_classes: 1000 87 | ema: true 88 | ema_rate: 0.999 89 | spec_norm: false 90 | normalization: InstanceNorm++ 91 | nonlinearity: swish 92 | ngf: 12 93 | ch_mult: 94 | - 1 95 | - 1 96 | - 2 97 | - 2 98 | num_res_blocks: 1 99 | attn_resolutions: 100 | - 16 # only 16 for traditional 101 | n_head_channels: -1 # -1 for traditional 102 | conditional: true 103 | noise_in_cond: false 104 | output_all_frames: false # could be useful especially for 3d models 105 | cond_emb: false 106 | spade: false 107 | spade_dim: 128 108 | 109 | optim: 110 | weight_decay: 0.000 111 | optimizer: "Adam" 112 | lr: 0.0002 113 | warmup: 1000 114 | beta1: 0.9 115 | amsgrad: false 116 | eps: 0.00000001 117 | grad_clip: 1.0 118 | -------------------------------------------------------------------------------- /configs/tower.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 128 4 | n_epochs: 500000 5 | n_iters: 150001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 100 10 | anneal_power: 2 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 36 15 | data_init: false 16 | step_lr: 0.0000018 17 | n_steps_each: 3 18 | ckpt_id: 0 19 | final_only: true 20 | fid: false 21 | ssim: true 22 | fvd: true 23 | denoise: true 24 | num_samples4fid: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 10 28 | clip_before: true 29 | max_data_iter: 100000 30 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 31 | one_frame_at_a_time: false 32 | preds_per_test: 1 33 | 34 | fast_fid: 35 | batch_size: 1000 36 | num_samples: 1000 37 | step_lr: 0.0000018 38 | n_steps_each: 3 39 | begin_ckpt: 100000 40 | end_ckpt: 150000 41 | verbose: false 42 | ensemble: false 43 | 44 | test: 45 | begin_ckpt: 5000 46 | end_ckpt: 150000 47 | batch_size: 100 48 | 49 | 
data: 50 | dataset: "LSUN" 51 | category: "tower" 52 | image_size: 128 53 | channels: 3 54 | logit_transform: false 55 | uniform_dequantization: false 56 | gaussian_dequantization: false 57 | random_flip: true 58 | rescaled: false 59 | num_workers: 32 60 | num_frames: 1 61 | num_frames_cond: 5 62 | num_frames_future: 0 63 | prob_mask_cond: 0.0 64 | prob_mask_future: 0.0 65 | prob_mask_sync: false 66 | 67 | model: 68 | depth: deeper 69 | sigma_begin: 190 70 | num_classes: 1086 71 | ema: true 72 | ema_rate: 0.999 73 | spec_norm: false 74 | sigma_dist: geometric 75 | sigma_end: 0.01 76 | normalization: InstanceNorm++ 77 | nonlinearity: elu 78 | ngf: 128 79 | ch_mult: 80 | - 1 81 | - 2 82 | - 2 83 | - 2 84 | num_res_blocks: 1 # 8 for traditional 85 | attn_resolutions: 86 | - 8 87 | - 16 88 | - 32 # can use only 16 for traditional 89 | n_head_channels: 64 # -1 for traditional 90 | conditional: false 91 | noise_in_cond: false 92 | output_all_frames: false # could be useful especially for 3d models 93 | cond_emb: false 94 | spade: false 95 | spade_dim: 128 96 | 97 | optim: 98 | weight_decay: 0.000 99 | optimizer: "Adam" 100 | lr: 0.0001 101 | beta1: 0.9 102 | amsgrad: false 103 | eps: 0.00000001 104 | -------------------------------------------------------------------------------- /configs/ucf101.yml: -------------------------------------------------------------------------------- 1 | training: 2 | L1: false 3 | batch_size: 64 4 | n_epochs: 1000000 5 | n_iters: 3000001 6 | snapshot_freq: 50000 7 | snapshot_sampling: true 8 | sample_freq: 50000 9 | val_freq: 1000 10 | log_freq: 100 11 | log_all_sigmas: false 12 | 13 | sampling: 14 | batch_size: 100 15 | data_init: false 16 | ckpt_id: 0 17 | final_only: true 18 | fid: false 19 | ssim: true 20 | fvd: true 21 | denoise: true 22 | subsample: 100 23 | num_samples4fid: 10000 24 | num_samples4fvd: 10000 25 | inpainting: false 26 | interpolation: false 27 | n_interpolations: 15 28 | consistent: true 29 | step_lr: 0.0 30 | n_steps_each: 0 31 | train: false 32 | num_frames_pred: 28 33 | clip_before: true 34 | max_data_iter: 100000 35 | init_prev_t: -1.0 # if >0, we start next_frame at prev_frame starting with noise t=init_prev_t 36 | one_frame_at_a_time: false 37 | preds_per_test: 1 38 | 39 | fast_fid: 40 | batch_size: 1000 41 | num_samples: 1000 42 | begin_ckpt: 5000 43 | freq: 5000 44 | end_ckpt: 300000 45 | pr_nn_k: 3 46 | verbose: false 47 | ensemble: false 48 | step_lr: 0.0 49 | n_steps_each: 0 50 | 51 | test: 52 | begin_ckpt: 5000 53 | end_ckpt: 300000 54 | batch_size: 100 55 | 56 | data: 57 | dataset: "UCF101" 58 | image_size: 64 59 | channels: 3 60 | logit_transform: false 61 | uniform_dequantization: false 62 | gaussian_dequantization: false 63 | random_flip: true 64 | rescaled: true 65 | color_jitter: 0.0 66 | num_workers: 4 67 | num_frames: 4 68 | num_frames_cond: 4 69 | num_frames_future: 0 70 | prob_mask_cond: 0.0 71 | prob_mask_future: 0.0 72 | prob_mask_sync: false 73 | 74 | model: 75 | depth: deeper 76 | version: DDPM 77 | gamma: false 78 | arch: unetmore 79 | type: v1 80 | time_conditional: true 81 | dropout: 0.1 82 | sigma_dist: linear 83 | sigma_begin: 0.02 84 | sigma_end: 0.0001 85 | num_classes: 1000 86 | ema: true 87 | ema_rate: 0.999 88 | spec_norm: false 89 | normalization: InstanceNorm++ 90 | nonlinearity: swish 91 | ngf: 192 92 | ch_mult: 93 | - 1 94 | - 2 95 | - 3 96 | - 4 97 | num_res_blocks: 2 # 8 for traditional 98 | attn_resolutions: 99 | - 8 100 | - 16 101 | - 32 # can use only 16 for traditional 102 | n_head_channels: 96 # 
-1 for traditional 103 | conditional: true 104 | noise_in_cond: false 105 | output_all_frames: false # could be useful especially for 3d models 106 | cond_emb: false 107 | spade: false 108 | spade_dim: 128 109 | 110 | optim: 111 | weight_decay: 0.000 112 | optimizer: "Adam" 113 | lr: 0.0001 114 | warmup: 5000 115 | beta1: 0.9 116 | amsgrad: false 117 | eps: 0.00000001 118 | grad_clip: 1.0 119 | -------------------------------------------------------------------------------- /datasets/bair.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/bair.py 2 | import numpy as np 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | from .h5 import HDF5Dataset 10 | 11 | 12 | class BAIRDataset(Dataset): 13 | 14 | def __init__(self, data_path, frames_per_sample=5, random_time=True, random_horizontal_flip=True, color_jitter=0, 15 | total_videos=-1, with_target=True): 16 | 17 | self.data_path = data_path # '/path/to/Datasets/BAIR_h5/train' (with shard_0001.hdf5 in it), or /path/to/BAIR_h5/train/shard_0001.hdf5 18 | self.frames_per_sample = frames_per_sample 19 | self.random_time = random_time 20 | self.random_horizontal_flip = random_horizontal_flip 21 | self.color_jitter = color_jitter 22 | self.total_videos = total_videos # If we wish to restrict total number of videos (e.g. for val) 23 | self.with_target = with_target 24 | 25 | self.jitter = transforms.ColorJitter(hue=color_jitter) 26 | 27 | # Read h5 files as dataset 28 | self.videos_ds = HDF5Dataset(self.data_path) 29 | 30 | print(f"Dataset length: {self.__len__()}") 31 | 32 | def window_stack(self, a, width=3, step=1): 33 | return torch.stack([a[i:1+i-width or None:step] for i in range(width)]).transpose(0, 1) 34 | 35 | def len_of_vid(self, index): 36 | video_index = index % self.__len__() 37 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 38 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 39 | video_len = f['len'][str(idx_in_shard)][()] 40 | return video_len 41 | 42 | def __len__(self): 43 | return self.total_videos if self.total_videos > 0 else len(self.videos_ds) 44 | 45 | def max_index(self): 46 | return len(self.videos_ds) 47 | 48 | def __getitem__(self, index, time_idx=0): 49 | 50 | # Use `index` to select the video, and then 51 | # randomly choose a `frames_per_sample` window of frames in the video 52 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 53 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 54 | 55 | prefinals = [] 56 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 57 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 58 | video_len = f['len'][str(idx_in_shard)][()] 59 | if self.random_time and video_len > self.frames_per_sample: 60 | time_idx = np.random.choice(video_len - self.frames_per_sample) 61 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 62 | # byte_str = f[str(idx_in_shard)][str(i)][()] 63 | # img = Image.frombytes('RGB', (64, 64), byte_str) 64 | # arr = np.expand_dims(np.array(img.getdata()).reshape(img.size[1], img.size[0], 3), 0) 65 | img = f[str(idx_in_shard)][str(i)][()] 66 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img)) 67 | prefinals.append(arr) 68 | 69 | data = torch.stack(prefinals) 70 | data = self.jitter(data) 71 | 72 | if self.with_target: 73 | 
return data, torch.tensor(1) 74 | else: 75 | return data 76 | -------------------------------------------------------------------------------- /datasets/bair_convert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/convert_bair.py 2 | import argparse 3 | import glob 4 | import imageio 5 | import io 6 | import numpy as np 7 | import os 8 | import sys 9 | import tensorflow as tf 10 | 11 | from PIL import Image 12 | from tensorflow.python.platform import gfile 13 | from tqdm import tqdm 14 | 15 | from h5 import HDF5Maker 16 | 17 | 18 | def get_seq(data_dir, dname): 19 | data_dir = '%s/softmotion30_44k/%s' % (data_dir, dname) 20 | 21 | filenames = gfile.Glob(os.path.join(data_dir, '*')) 22 | if not filenames: 23 | raise RuntimeError('No data files found.') 24 | 25 | for f in filenames: 26 | k = 0 27 | # tf.enable_eager_execution() 28 | for serialized_example in tf.python_io.tf_record_iterator(f): 29 | example = tf.train.Example() 30 | example.ParseFromString(serialized_example) 31 | image_seq = [] 32 | for i in range(30): 33 | image_name = str(i) + '/image_aux1/encoded' 34 | byte_str = example.features.feature[image_name].bytes_list.value[0] 35 | # image_seq.append(byte_str) 36 | img = Image.frombytes('RGB', (64, 64), byte_str) 37 | arr = np.array(img.getdata()).reshape(img.size[1], img.size[0], 3) 38 | image_seq.append(arr) 39 | # image_seq = np.concatenate(image_seq, axis=0) 40 | k = k + 1 41 | yield f, k, image_seq 42 | 43 | 44 | def make_h5_from_bair(bair_dir, split='train', out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 45 | 46 | # H5 maker 47 | h5_maker = HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 48 | 49 | seq_generator = get_seq(bair_dir, split) 50 | 51 | filenames = gfile.Glob(os.path.join('%s/softmotion30_44k/%s' % (bair_dir, split), '*')) 52 | for file in tqdm(filenames): 53 | 54 | # num = sum(1 for _ in tf.python_io.tf_record_iterator(file)) 55 | num = 256 56 | for i in tqdm(range(num)): 57 | 58 | try: 59 | f, k, seq = next(seq_generator) 60 | # h5_maker.add_data(seq, dtype=None) 61 | h5_maker.add_data(seq, dtype='uint8') 62 | 63 | except StopIteration: 64 | break 65 | 66 | except (KeyboardInterrupt, SystemExit): 67 | print("Ctrl+C!!") 68 | break 69 | 70 | except: 71 | e = sys.exc_info()[0] 72 | print("ERROR:", e) 73 | 74 | h5_maker.close() 75 | 76 | 77 | if __name__ == "__main__": 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 81 | parser.add_argument('--bair_dir', type=str, help="Directory with videos") 82 | parser.add_argument('--vids_per_shard', type=int, default=100000) 83 | parser.add_argument('--force_h5', type=eval, default=False) 84 | 85 | args = parser.parse_args() 86 | 87 | make_h5_from_bair(out_dir=os.path.join(args.out_dir, 'train'), bair_dir=args.bair_dir, split='train', vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 88 | make_h5_from_bair(out_dir=os.path.join(args.out_dir, 'test'), bair_dir=args.bair_dir, split='test', vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 89 | -------------------------------------------------------------------------------- /datasets/bair_download.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR=$1 2 | if [ -z $TARGET_DIR ] 3 | then 4 | echo "Must specify target directory" 5 | else 6 | mkdir $TARGET_DIR/ 7 | 
URL=http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar 8 | wget $URL -P $TARGET_DIR 9 | tar -xvf $TARGET_DIR/bair_robot_pushing_dataset_v0.tar -C $TARGET_DIR 10 | fi 11 | -------------------------------------------------------------------------------- /datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 
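    # Note: the entries commented out in file_list below are the 7z archives mentioned above;
    # download() and _check_integrity() only operate on the uncommented zip/txt entries.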
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
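The CelebA class above follows the standard torchvision dataset interface, so it can be wrapped directly in a DataLoader. A minimal usage sketch, assuming the repo root is on the Python path; the placeholder root directory, the 64x64 resize, and the batch size are illustrative assumptions, not values taken from this repo:

from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.celeba import CelebA

transform = transforms.Compose([
    transforms.Resize(64),        # shorter side -> 64
    transforms.CenterCrop(64),    # square 64x64 crop
    transforms.ToTensor(),        # [3, 64, 64] in [0, 1]
])

# '/path/to/data' is a placeholder; it must contain the extracted 'celeba' folder,
# or pass download=True to fetch it via download() above.
dataset = CelebA('/path/to/data', split='train', target_type='attr',
                 transform=transform, download=False)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

images, attrs = next(iter(loader))  # images: [64, 3, 64, 64], attrs: [64, 40] binary attribute labels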
/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | from .h5 import HDF5Dataset 9 | 10 | 11 | class CityscapesDataset(Dataset): 12 | 13 | def __init__(self, data_path, frames_per_sample=5, random_time=True, random_horizontal_flip=True, color_jitter=0, 14 | total_videos=-1, with_target=True): 15 | 16 | self.data_path = data_path # '/path/to/Datasets/Cityscapes128_h5/train' (with shard_0001.hdf5 in it) 17 | self.frames_per_sample = frames_per_sample 18 | self.random_time = random_time 19 | self.random_horizontal_flip = random_horizontal_flip 20 | self.color_jitter = color_jitter 21 | self.total_videos = total_videos # If we wish to restrict total number of videos (e.g. for val) 22 | self.with_target = with_target 23 | 24 | self.jitter = transforms.ColorJitter(hue=color_jitter) 25 | 26 | # Read h5 files as dataset 27 | self.videos_ds = HDF5Dataset(self.data_path) 28 | 29 | print(f"Dataset length: {self.__len__()}") 30 | 31 | def window_stack(self, a, width=3, step=1): 32 | return torch.stack([a[i:1+i-width or None:step] for i in range(width)]).transpose(0, 1) 33 | 34 | def len_of_vid(self, index): 35 | video_index = index % self.__len__() 36 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 37 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 38 | video_len = f['len'][str(idx_in_shard)][()] 39 | return video_len 40 | 41 | def __len__(self): 42 | return self.total_videos if self.total_videos > 0 else len(self.videos_ds) 43 | 44 | def max_index(self): 45 | return len(self.videos_ds) 46 | 47 | def __getitem__(self, index, time_idx=0): 48 | 49 | # Use `index` to select the video, and then 50 | # randomly choose a `frames_per_sample` window of frames in the video 51 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 52 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 53 | 54 | prefinals = [] 55 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 56 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 57 | video_len = f['len'][str(idx_in_shard)][()] 58 | if self.random_time and video_len > self.frames_per_sample: 59 | time_idx = np.random.choice(video_len - self.frames_per_sample) 60 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 61 | img = f[str(idx_in_shard)][str(i)][()] 62 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img)) 63 | prefinals.append(arr) 64 | 65 | data = torch.stack(prefinals) 66 | data = self.jitter(data) 67 | 68 | if self.with_target: 69 | return data, torch.tensor(1) 70 | else: 71 | return data 72 | -------------------------------------------------------------------------------- /datasets/cityscapes_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | from functools import partial 9 | from multiprocessing import Pool 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | from h5 import HDF5Maker 14 | 15 | 16 | def center_crop(image): 17 | h, w, c = image.shape 18 | new_h, new_w = h if h < w else w, w if w < h else h 19 | r_min, r_max = h//2 - new_h//2, h//2 + new_h//2 20 | c_min, c_max = w//2 - new_w//2, w//2 + new_w//2 21 | return 
image[r_min:r_max, c_min:c_max, :] 22 | 23 | 24 | def read_video(video_files, image_size): 25 | frames = [] 26 | for file in video_files: 27 | frame = cv2.imread(file) 28 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 29 | img_cc = center_crop(img) 30 | pil_im = Image.fromarray(img_cc) 31 | pil_im_rsz = pil_im.resize((image_size, image_size), Image.LANCZOS) 32 | frames.append(np.array(pil_im_rsz)) 33 | return frames 34 | 35 | 36 | def filename_to_num(filename): 37 | return 1000.*sum([ord(x) for x in os.path.basename(filename).split('_')[0]]) + 100.*int(os.path.basename(filename).split('_')[1]) + int(os.path.basename(filename).split('_')[2]) 38 | 39 | 40 | def process_video(video_files, image_size): 41 | frames = [] 42 | try: 43 | frames = read_video(video_files, image_size) 44 | except StopIteration: 45 | pass 46 | # break 47 | except (KeyboardInterrupt, SystemExit): 48 | print("Ctrl+C!!") 49 | # break 50 | except: 51 | e = sys.exc_info()[0] 52 | print("ERROR:", e) 53 | return frames 54 | 55 | 56 | def make_h5_from_cityscapes_multi(cityscapes_dir, image_size, out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 57 | 58 | # H5 maker 59 | h5_maker = HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 60 | 61 | filenames_all = sorted(glob.glob(os.path.join(cityscapes_dir, '*', '*.png'))) 62 | videos = np.array(filenames_all).reshape(-1, 30) 63 | 64 | p_video = partial(process_video, image_size=image_size) 65 | 66 | # Process videos 100 at a time 67 | pbar = tqdm(total=len(videos)) 68 | for i in range(int(np.ceil(len(videos)/100))): 69 | 70 | # pool 71 | with Pool() as pool: 72 | # tic = time.time() 73 | frames_all = pool.imap(p_video, videos[i*100:(i+1)*100]) 74 | # add frames to h5 75 | for frames in frames_all: 76 | if len(frames) > 0: 77 | h5_maker.add_data(frames, dtype='uint8') 78 | # toc = time.time() 79 | 80 | pbar.update(len(videos[i*100:(i+1)*100])) 81 | 82 | pbar.close() 83 | h5_maker.close() 84 | 85 | 86 | def make_h5_from_cityscapes(cityscapes_dir, image_size, out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 87 | 88 | # H5 maker 89 | h5_maker = HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 90 | 91 | filenames_all = sorted(glob.glob(os.path.join(cityscapes_dir, '*', '*.png'))) 92 | 93 | videos = np.array(filenames_all).reshape(-1, 30) 94 | 95 | for video_files in tqdm(videos): 96 | 97 | try: 98 | frames = read_video(video_files, image_size) 99 | h5_maker.add_data(frames, dtype='uint8') 100 | 101 | except StopIteration: 102 | break 103 | 104 | except (KeyboardInterrupt, SystemExit): 105 | print("Ctrl+C!!") 106 | break 107 | 108 | except: 109 | e = sys.exc_info()[0] 110 | print("ERROR:", e) 111 | 112 | h5_maker.close() 113 | 114 | 115 | if __name__ == "__main__": 116 | 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 119 | parser.add_argument('--leftImg8bit_sequence_dir', type=str, help="Path to 'leftImg8bit_sequence' ") 120 | parser.add_argument('--image_size', type=int, default=128) 121 | parser.add_argument('--vids_per_shard', type=int, default=100000) 122 | parser.add_argument('--force_h5', type=eval, default=False) 123 | 124 | args = parser.parse_args() 125 | 126 | make_h5_from_cityscapes_multi(out_dir=os.path.join(args.out_dir, "train"), cityscapes_dir=os.path.join(args.leftImg8bit_sequence_dir, "train"), image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 127 | 
make_h5_from_cityscapes_multi(out_dir=os.path.join(args.out_dir, "val"), cityscapes_dir=os.path.join(args.leftImg8bit_sequence_dir, "val"), image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 128 | make_h5_from_cityscapes_multi(out_dir=os.path.join(args.out_dir, "test"), cityscapes_dir=os.path.join(args.leftImg8bit_sequence_dir, "test"), image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 129 | -------------------------------------------------------------------------------- /datasets/cityscapes_download.sh: -------------------------------------------------------------------------------- 1 | # https://github.com/cemsaz/city-scapes-script 2 | TARGET_DIR=$1 3 | USERNAME=$2 4 | PASSWORD=$3 5 | if [ -z $TARGET_DIR ] 6 | then 7 | echo "Must specify target directory" 8 | else 9 | mkdir $TARGET_DIR/ 10 | # Login 11 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data "username=$1&password=$2&submit=Login" https://www.cityscapes-dataset.com/login/ -P $TARGET_DIR 12 | # Download leftImg8bit_sequence_trainvaltest.zip (324GB) 13 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=14 -P $TARGET_DIR 14 | # Unzip 15 | unzip $TARGET_DIR/leftImg8bit_sequence_trainvaltest.zip -d $TARGET_DIR 16 | fi 17 | -------------------------------------------------------------------------------- /datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FFHQ(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | target = 0 40 | 41 | return img, target -------------------------------------------------------------------------------- /datasets/ffhq_tfrecords.py: -------------------------------------------------------------------------------- 1 | # https://github.com/podgorskiy/StyleGan/blob/master/dataloader.py 2 | 3 | # Copyright 2019 Stanislav Pidhorskyi 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | 18 | # import dareblopy as db 19 | import numpy as np 20 | import torch 21 | 22 | 23 | class TFRecordsDataLoader: 24 | def __init__(self, tfrecords_paths, batch_size, 25 | ch=3, img_size=None, length=None, seed=0, buffer_size_mb=200): 26 | self.iterator = None 27 | self.filenames = tfrecords_paths 28 | self.batch_size = batch_size 29 | self.ch = ch 30 | self.img_size = img_size 31 | self.length = length 32 | self.seed = seed 33 | self.buffer_size_mb = buffer_size_mb 34 | 35 | if self.img_size is None or self.ch is None: 36 | raw_dataset = tf.data.TFRecordDataset(self.filenames[0]) 37 | for raw_record in raw_dataset.take(1): pass 38 | example = tf.train.Example() 39 | example.ParseFromString(raw_record.numpy()) 40 | # print(example) 41 | result = {} 42 | # example.features.feature is the dictionary 43 | for key, feature in example.features.feature.items(): 44 | # The values are the Feature objects which contain a `kind` which contains: 45 | # one of three fields: bytes_list, float_list, int64_list 46 | kind = feature.WhichOneof('kind') 47 | result[key] = np.array(getattr(feature, kind).value) 48 | # ch, img_size 49 | self.ch = result['shape'][0] 50 | self.img_size = result['shape'][-1] 51 | 52 | if self.length is None: 53 | import tensorflow as tf 54 | tf.compat.v1.enable_eager_execution() 55 | self.length = 0 56 | for file in self.filenames: 57 | self.length += sum(1 for _ in tf.data.TFRecordDataset(file)) 58 | 59 | self.features = { 60 | # 'shape': db.FixedLenFeature([3], db.int64), 61 | 'data': db.FixedLenFeature([ch, img_size, img_size], db.uint8) 62 | } 63 | 64 | self.buffer_size = 1024 ** 2 * self.buffer_size_mb // (3 * img_size * img_size) 65 | 66 | self.iterator = db.ParsedTFRecordsDatasetIterator(self.filenames, self.features, self.batch_size, self.buffer_size, seed=self.seed) 67 | 68 | def transform(self, x): 69 | return torch.from_numpy(x[0]), torch.zeros(len(x[0])) 70 | 71 | def __iter__(self): 72 | return map(self.transform, self.iterator) 73 | 74 | def __len__(self): 75 | return self.length // self.batch_size 76 | 77 | 78 | class FFHQ_TFRecordsDataLoader(TFRecordsDataLoader): 79 | def __init__(self, tfrecords_paths, batch_size, img_size, 80 | seed=0, length=70000, buffer_size_mb=200): 81 | super().__init__(tfrecords_paths, batch_size, img_size=img_size, seed=seed, 82 | length=length, buffer_size_mb=buffer_size_mb) 83 | -------------------------------------------------------------------------------- /datasets/kth.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/kth.py 2 | import numpy as np 3 | import os 4 | import pickle 5 | import torch 6 | 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | 11 | from .h5 import HDF5Dataset 12 | 13 | 14 | class KTHDataset(Dataset): 15 | 16 | def __init__(self, data_dir, frames_per_sample=5, train=True, random_time=True, random_horizontal_flip=True, 17 | total_videos=-1, with_target=True, start_at=0): 18 | 19 | self.data_dir = data_dir # '/path/to/Datasets/KTH64_h5' (with shard_0001.hdf5 and persons.pkl in it) 20 | self.train = train 21 | self.frames_per_sample = frames_per_sample 22 | self.random_time = random_time 23 | self.random_horizontal_flip = random_horizontal_flip 24 | self.total_videos = total_videos # If we wish to restrict total number of videos (e.g. 
for val) 25 | self.with_target = with_target 26 | self.start_at = start_at 27 | 28 | # Read h5 files as dataset 29 | self.videos_ds = HDF5Dataset(self.data_dir) 30 | 31 | # Persons 32 | with open(os.path.join(data_dir, 'persons.pkl'), 'rb') as f: 33 | self.persons = pickle.load(f) 34 | 35 | # Train 36 | self.train_persons = list(range(1, 21)) 37 | self.train_idx = sum([self.persons[p] for p in self.train_persons], []) 38 | # Test 39 | self.test_persons = list(range(21, 26)) 40 | self.test_idx = sum([self.persons[p] for p in self.test_persons], []) 41 | 42 | print(f"Dataset length: {self.__len__()}") 43 | 44 | def len_of_vid(self, index): 45 | video_index = index % self.__len__() 46 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 47 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 48 | video_len = f['len'][str(idx_in_shard)][()] 49 | return video_len 50 | 51 | def __len__(self): 52 | return self.total_videos if self.total_videos > 0 else len(self.train_idx) if self.train else len(self.test_idx) 53 | 54 | def max_index(self): 55 | return len(self.train_idx) if self.train else len(self.test_idx) 56 | 57 | def __getitem__(self, index, time_idx=0): 58 | 59 | # Use `index` to select the video, and then 60 | # randomly choose a `frames_per_sample` window of frames in the video 61 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 62 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 63 | idx = self.train_idx[int(idx_in_shard)] if self.train else self.test_idx[int(idx_in_shard)] 64 | 65 | prefinals = [] 66 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 67 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 68 | video_len = f['len'][str(idx)][()] - self.start_at 69 | if self.random_time and video_len > self.frames_per_sample: 70 | time_idx = np.random.choice(video_len - self.frames_per_sample) 71 | time_idx += self.start_at 72 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 73 | img = f[str(idx)][str(i)][()] 74 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img)) 75 | prefinals.append(arr) 76 | target = int(f['target'][str(idx)][()]) 77 | 78 | if self.with_target: 79 | return torch.stack(prefinals), torch.tensor(target) 80 | else: 81 | return torch.stack(prefinals) 82 | -------------------------------------------------------------------------------- /datasets/kth_convert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/convert_bair.py 2 | import argparse 3 | import cv2 4 | import glob 5 | import numpy as np 6 | import os 7 | import pickle 8 | import sys 9 | 10 | from tqdm import tqdm 11 | 12 | from h5 import HDF5Maker 13 | 14 | 15 | class KTH_HDF5Maker(HDF5Maker): 16 | 17 | def add_video_info(self): 18 | pass 19 | 20 | def create_video_groups(self): 21 | self.writer.create_group('len') 22 | self.writer.create_group('person') 23 | self.writer.create_group('target') 24 | 25 | def add_video_data(self, data, dtype=None): 26 | data, person, target = data 27 | self.writer['len'].create_dataset(str(self.count), data=len(data)) 28 | self.writer['person'].create_dataset(str(self.count), data=person, dtype='uint8') 29 | self.writer['target'].create_dataset(str(self.count), data=target, dtype='uint8') 30 | self.writer.create_group(str(self.count)) 31 | for i, frame in enumerate(data): 32 | 
self.writer[str(self.count)].create_dataset(str(i), data=frame, dtype=dtype, compression="lzf") 33 | 34 | 35 | def read_video(video, image_size): 36 | 37 | cap = cv2.VideoCapture(video) 38 | frames = [] 39 | 40 | while True: 41 | 42 | # Capture frame-by-frame 43 | ret, frame = cap.read() 44 | 45 | # if frame is read correctly ret is True 46 | if not ret: 47 | # print("Can't receive frame (stream end?). Exiting ...") 48 | break 49 | 50 | # Our operations on the frame come here 51 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 52 | image = cv2.resize(gray, (image_size, image_size)) 53 | frames.append(image) 54 | 55 | cap.release() 56 | return frames 57 | 58 | 59 | def show_video(frames): 60 | import matplotlib.pyplot as plt 61 | from matplotlib.animation import FuncAnimation 62 | im1 = plt.imshow(frames[0]) 63 | def update(frame): 64 | im1.set_data(frame) 65 | ani = FuncAnimation(plt.gcf(), update, frames=frames, interval=10, repeat=False) 66 | plt.show() 67 | 68 | 69 | def append_to_dict_list(d, key, value): 70 | if key not in d: 71 | d[key] = [] 72 | d[key].append(value) 73 | 74 | 75 | def make_h5_from_kth(kth_dir, image_size=64, out_dir='./h5_ds', vids_per_shard=1000000, force_h5=False): 76 | 77 | # H5 maker 78 | h5_maker = KTH_HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 79 | 80 | # data_root = '/path/to/Datasets/KTH' 81 | # image_size = 64 82 | classes = ['boxing', 'handclapping', 'handwaving', 'jogging', 'running', 'walking'] 83 | # frame_rate = 25 84 | 85 | count = 0 86 | persons = {} 87 | targets = {} 88 | for video in tqdm(sorted(glob.glob(os.path.join(kth_dir, 'raw', '*', '*')))): 89 | 90 | try: 91 | frames = read_video(video, image_size) 92 | person = int(os.path.basename(video).split('_')[0].split('person')[-1]) 93 | target = classes.index(video.split('/')[-2]) 94 | h5_maker.add_data((frames, person, target), dtype='uint8') 95 | append_to_dict_list(persons, person, count) 96 | append_to_dict_list(targets, target, count) 97 | count += 1 98 | 99 | except StopIteration: 100 | break 101 | 102 | except (KeyboardInterrupt, SystemExit): 103 | print("Ctrl+C!!") 104 | break 105 | 106 | except: 107 | e = sys.exc_info()[0] 108 | print("ERROR:", e) 109 | 110 | h5_maker.close() 111 | 112 | # Save persons 113 | print("Writing", os.path.join(out_dir, 'persons.pkl')) 114 | with open(os.path.join(out_dir, 'persons.pkl'), 'wb') as f: 115 | pickle.dump(persons, f) 116 | 117 | # Save targets 118 | print("Writing", os.path.join(out_dir, 'targets.pkl')) 119 | with open(os.path.join(out_dir, 'targets.pkl'), 'wb') as f: 120 | pickle.dump(targets, f) 121 | 122 | 123 | if __name__ == "__main__": 124 | 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 127 | parser.add_argument('--kth_dir', type=str, help="Directory with KTH") 128 | parser.add_argument('--image_size', type=int, default=64) 129 | parser.add_argument('--vids_per_shard', type=int, default=1000000) 130 | parser.add_argument('--force_h5', type=eval, default=False) 131 | 132 | args = parser.parse_args() 133 | 134 | make_h5_from_kth(out_dir=args.out_dir, kth_dir=args.kth_dir, image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 135 | -------------------------------------------------------------------------------- /datasets/kth_download.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR=$1 2 | if [ -z $TARGET_DIR ] 3 | then 4 | echo "Must specify target 
directory" 5 | else 6 | mkdir $TARGET_DIR/processed 7 | URL=http://www.cs.nyu.edu/~denton/datasets/kth.tar.gz 8 | wget $URL -P $TARGET_DIR/processed 9 | tar -zxvf $TARGET_DIR/processed/kth.tar.gz -C $TARGET_DIR/processed/ 10 | rm $TARGET_DIR/processed/kth.tar.gz 11 | 12 | mkdir $TARGET_DIR/raw 13 | for c in walking jogging running handwaving handclapping boxing 14 | do 15 | URL=http://www.nada.kth.se/cvap/actions/"$c".zip 16 | wget $URL -P $TARGET_DIR/raw 17 | mkdir $TARGET_DIR/raw/$c 18 | unzip $TARGET_DIR/raw/"$c".zip -d $TARGET_DIR/raw/$c 19 | rm $TARGET_DIR/raw/"$c".zip 20 | done 21 | 22 | fi -------------------------------------------------------------------------------- /datasets/moving_mnist.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import math 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.data as data 10 | import torch.utils.model_zoo as model_zoo 11 | import torchvision.transforms as transforms 12 | 13 | if sys.version_info[0] == 2: 14 | from urllib import urlretrieve 15 | else: 16 | from urllib.request import urlretrieve 17 | 18 | import progressbar 19 | 20 | from collections import OrderedDict 21 | 22 | pbar = None 23 | 24 | 25 | def mmnist_data_loader(collate_fn=None, n_frames=10, num_digits=1, with_target=False, 26 | batch_size=100, n_workers=8, is_train=True, drop_last=True, 27 | dset_path=os.path.dirname(os.path.realpath(__file__))): 28 | dset = MovingMNIST(dset_path, is_train, n_frames, num_digits, with_target=with_target) 29 | dloader = data.DataLoader(dset, batch_size=batch_size, shuffle=is_train, collate_fn=collate_fn, 30 | num_workers=n_workers, drop_last=drop_last, pin_memory=True) 31 | # Returns images of size [1, 64, 64] in [-1, 1] 32 | return dloader 33 | 34 | 35 | class ToTensor(object): 36 | """Converts a numpy.ndarray (... x H x W x C) in the range 37 | [0, 255] to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0, 1.0]. 38 | """ 39 | def __init__(self, scale=True): 40 | self.scale = scale 41 | def __call__(self, arr): 42 | if isinstance(arr, np.ndarray): 43 | video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3)) 44 | if self.scale: 45 | return video.float().div(255) 46 | else: 47 | return video.float() 48 | else: 49 | raise NotImplementedError 50 | 51 | 52 | # def load_mnist(root): 53 | # # Load MNIST dataset for generating training data. 
54 | # path = os.path.join(root, 'train-images-idx3-ubyte.gz') 55 | # with gzip.open(path, 'rb') as f: 56 | # mnist = np.frombuffer(f.read(), np.uint8, offset=16) 57 | # mnist = mnist.reshape(-1, 28, 28) 58 | # return mnist 59 | 60 | 61 | # def load_fixed_set(root, is_train): 62 | # # Load the fixed dataset 63 | # filename = 'mnist_test_seq.npy' 64 | # path = os.path.join(root, filename) 65 | # dataset = np.load(path) 66 | # dataset = dataset[..., np.newaxis] 67 | # return dataset 68 | 69 | 70 | # loads mnist from web on demand 71 | def load_mnist(root, is_train=True): 72 | 73 | def load_mnist_images(filename): 74 | if not os.path.exists(os.path.join(root, filename)): 75 | download(root, filename) 76 | with gzip.open(os.path.join(root, filename), 'rb') as f: 77 | data = np.frombuffer(f.read(), np.uint8, offset=16) 78 | data = data.reshape(-1, 28, 28) 79 | return data 80 | 81 | if is_train: 82 | return load_mnist_images('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') 83 | return load_mnist_images('http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy') 84 | 85 | 86 | def download(root, filename): 87 | def show_progress(block_num, block_size, total_size): 88 | global pbar 89 | if pbar is None: 90 | pbar = progressbar.ProgressBar(maxval=total_size) 91 | pbar.start() 92 | 93 | downloaded = block_num * block_size 94 | if downloaded < total_size: 95 | pbar.update(downloaded) 96 | else: 97 | pbar.finish() 98 | pbar = None 99 | print("Downloading %s" % os.path.basename(filename)) 100 | os.makedirs(root, exist_ok=True) 101 | urlretrieve(filename, os.path.join(root, os.path.basename(filename)), show_progress) 102 | 103 | 104 | def load_mnist(root): 105 | # Load MNIST dataset for generating training data. 106 | path = os.path.join(root, 'train-images-idx3-ubyte.gz') 107 | if not os.path.exists(path): 108 | download(root, 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') 109 | with gzip.open(path, 'rb') as f: 110 | mnist = np.frombuffer(f.read(), np.uint8, offset=16) 111 | mnist = mnist.reshape(-1, 28, 28) 112 | return mnist 113 | 114 | 115 | def load_fixed_set(root): 116 | # Load the fixed dataset 117 | path = os.path.join(root, 'mnist_test_seq.npy') 118 | if not os.path.exists(path): 119 | download(root, 'http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy') 120 | dataset = np.load(path) 121 | dataset = dataset[..., np.newaxis] 122 | return dataset 123 | 124 | 125 | class MovingMNIST(data.Dataset): 126 | def __init__(self, root, is_train, n_frames, num_digits, transform=transforms.Compose([ToTensor()]), step_length=0.1, with_target=False): 127 | super(MovingMNIST, self).__init__() 128 | 129 | self.dataset = None 130 | if is_train: 131 | self.mnist = load_mnist(root) 132 | else: 133 | if num_digits != 2: 134 | self.mnist = load_mnist(root) 135 | else: 136 | self.dataset = load_fixed_set(root) 137 | self.length = int(1e4) if self.dataset is None else self.dataset.shape[1] 138 | 139 | self.is_train = is_train 140 | self.num_digits = num_digits 141 | self.n_frames = n_frames 142 | self.transform = transform 143 | self.with_target = with_target 144 | 145 | # For generating data 146 | self.image_size_ = 64 147 | self.digit_size_ = 28 148 | self.step_length_ = step_length 149 | 150 | def get_random_trajectory(self, seq_length): 151 | ''' Generate a random sequence of a MNIST digit ''' 152 | canvas_size = self.image_size_ - self.digit_size_ 153 | x = random.random() 154 | y = random.random() 155 | theta = random.random() * 2 * np.pi 156 | v_y 
= np.sin(theta) 157 | v_x = np.cos(theta) 158 | 159 | start_y = np.zeros(seq_length) 160 | start_x = np.zeros(seq_length) 161 | for i in range(seq_length): 162 | # Take a step along velocity. 163 | y += v_y * self.step_length_ 164 | x += v_x * self.step_length_ 165 | 166 | # Bounce off edges. 167 | if x <= 0: 168 | x = 0 169 | v_x = -v_x 170 | if x >= 1.0: 171 | x = 1.0 172 | v_x = -v_x 173 | if y <= 0: 174 | y = 0 175 | v_y = -v_y 176 | if y >= 1.0: 177 | y = 1.0 178 | v_y = -v_y 179 | start_y[i] = y 180 | start_x[i] = x 181 | 182 | # Scale to the size of the canvas. 183 | start_y = (canvas_size * start_y).astype(np.int32) 184 | start_x = (canvas_size * start_x).astype(np.int32) 185 | return start_y, start_x 186 | 187 | def generate_moving_mnist(self, num_digits=2): 188 | ''' 189 | Get random trajectories for the digits and generate a video. 190 | ''' 191 | data = np.zeros((self.n_frames, self.image_size_, self.image_size_), dtype=np.float32) 192 | for n in range(num_digits): 193 | # Trajectory 194 | start_y, start_x = self.get_random_trajectory(self.n_frames) 195 | ind = random.randint(0, self.mnist.shape[0] - 1) 196 | digit_image = self.mnist[ind] 197 | for i in range(self.n_frames): 198 | top = start_y[i] 199 | left = start_x[i] 200 | bottom = top + self.digit_size_ 201 | right = left + self.digit_size_ 202 | # Draw digit 203 | data[i, top:bottom, left:right] = np.maximum(data[i, top:bottom, left:right], digit_image) 204 | 205 | data = data[..., np.newaxis] 206 | 207 | return data 208 | 209 | def __getitem__(self, idx): 210 | if self.is_train or self.num_digits != 2: 211 | # Generate data on the fly 212 | images = self.generate_moving_mnist(self.num_digits) 213 | else: 214 | images = self.dataset[:, idx, ...] 215 | 216 | if self.with_target: 217 | targets = np.array(images > 127, dtype=float) * 255.0 218 | 219 | if self.transform is not None: 220 | images = self.transform(images) 221 | if self.with_target: 222 | targets = self.transform(targets) 223 | 224 | if self.with_target: 225 | return images, targets 226 | else: 227 | return images 228 | 229 | def __len__(self): 230 | return self.length 231 | -------------------------------------------------------------------------------- /datasets/stochastic_moving_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torchvision import datasets, transforms 5 | 6 | 7 | class ToTensor(object): 8 | """Converts a numpy.ndarray (... x H x W x C) to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0, 1.0]. 
9 | """ 10 | def __init__(self, scale=True): 11 | self.scale = scale 12 | def __call__(self, arr): 13 | if isinstance(arr, np.ndarray): 14 | video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3)) 15 | if self.scale: 16 | return video.float() 17 | else: 18 | return video.float() 19 | else: 20 | raise NotImplementedError 21 | 22 | 23 | # https://github.com/edenton/svg/blob/master/data/moving_mnist.py 24 | class StochasticMovingMNIST(object): 25 | 26 | """Data Handler that creates Bouncing MNIST dataset on the fly.""" 27 | 28 | def __init__(self, data_root, train=True, seq_len=20, num_digits=2, image_size=64, deterministic=False, 29 | step_length=0.1, total_videos=-1, with_target=False, transform=transforms.Compose([ToTensor()])): 30 | path = data_root 31 | self.seq_len = seq_len 32 | self.num_digits = num_digits 33 | self.image_size = image_size 34 | self.step_length = step_length 35 | self.with_target = with_target 36 | self.transform = transform 37 | self.deterministic = deterministic 38 | 39 | self.seed_is_set = False # multi threaded loading 40 | self.digit_size = 32 41 | self.channels = 1 42 | 43 | self.data = datasets.MNIST( 44 | path, 45 | train=train, 46 | download=True, 47 | transform=transforms.Compose( 48 | [transforms.Resize(self.digit_size), 49 | transforms.ToTensor()])) 50 | 51 | self.N = len(self.data) if total_videos == -1 else total_videos 52 | 53 | print(f"Dataset length: {self.__len__()}") 54 | 55 | def set_seed(self, seed): 56 | if not self.seed_is_set: 57 | self.seed_is_set = True 58 | np.random.seed(seed) 59 | 60 | def __len__(self): 61 | return self.N 62 | 63 | def __getitem__(self, index): 64 | self.set_seed(index) 65 | image_size = self.image_size 66 | digit_size = self.digit_size 67 | x = np.zeros((self.seq_len, 68 | image_size, 69 | image_size, 70 | self.channels), 71 | dtype=np.float32) 72 | for n in range(self.num_digits): 73 | idx = np.random.randint(self.N) 74 | digit, _ = self.data[idx] 75 | 76 | sx = np.random.randint(image_size-digit_size) 77 | sy = np.random.randint(image_size-digit_size) 78 | dx = np.random.randint(-4, 5) 79 | dy = np.random.randint(-4, 5) 80 | for t in range(self.seq_len): 81 | if sy < 0: 82 | sy = 0 83 | if self.deterministic: 84 | dy = -dy 85 | else: 86 | dy = np.random.randint(1, 5) 87 | dx = np.random.randint(-4, 5) 88 | elif sy >= image_size-32: 89 | sy = image_size-32-1 90 | if self.deterministic: 91 | dy = -dy 92 | else: 93 | dy = np.random.randint(-4, 0) 94 | dx = np.random.randint(-4, 5) 95 | 96 | if sx < 0: 97 | sx = 0 98 | if self.deterministic: 99 | dx = -dx 100 | else: 101 | dx = np.random.randint(1, 5) 102 | dy = np.random.randint(-4, 5) 103 | elif sx >= image_size-32: 104 | sx = image_size-32-1 105 | if self.deterministic: 106 | dx = -dx 107 | else: 108 | dx = np.random.randint(-4, 0) 109 | dy = np.random.randint(-4, 5) 110 | 111 | x[t, sy:sy+32, sx:sx+32, 0] += digit.numpy().squeeze() 112 | sy += dy 113 | sx += dx 114 | 115 | x[x>1] = 1. 
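        # Digits are composited with "+=" in the loop above, so overlapping strokes can exceed 1;
        # the clip keeps frame values in [0, 1] before the optional binary targets (x >= 0.5) are built below.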
116 | 117 | if self.with_target: 118 | targets = np.array(x >= 0.5, dtype=float) 119 | 120 | if self.transform is not None: 121 | x = self.transform(x) 122 | if self.with_target: 123 | targets = self.transform(targets) 124 | 125 | if self.with_target: 126 | return x, targets 127 | else: 128 | return x 129 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/kth.py 2 | import numpy as np 3 | import os 4 | import pickle 5 | import torch 6 | 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | 11 | from .h5 import HDF5Dataset 12 | 13 | 14 | class UCF101Dataset(Dataset): 15 | 16 | def __init__(self, data_path, frames_per_sample=5, image_size=64, train=True, random_time=True, random_horizontal_flip=True, 17 | total_videos=-1, skip_videos=0, with_target=True): 18 | 19 | self.data_path = data_path # '/path/to/Datasets/UCF101_64_h5' (with .hdf5 file in it), or to the hdf5 file itself 20 | self.train = train 21 | self.frames_per_sample = frames_per_sample 22 | self.image_size = image_size 23 | self.random_time = random_time 24 | self.random_horizontal_flip = random_horizontal_flip 25 | self.total_videos = total_videos # If we wish to restrict total number of videos (e.g. for val) 26 | self.with_target = with_target 27 | 28 | # Read h5 files as dataset 29 | self.videos_ds = HDF5Dataset(self.data_path) 30 | 31 | # Train 32 | # self.num_train_vids = 9624 33 | # self.num_test_vids = 3696 # -> 369 : https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 34 | with self.videos_ds.opener(self.videos_ds.shard_paths[0]) as f: 35 | self.num_train_vids = f['num_train'][()] 36 | self.num_test_vids = f['num_test'][()]//10 # https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 37 | 38 | print(f"Dataset length: {self.__len__()}") 39 | 40 | def len_of_vid(self, index): 41 | video_index = index % self.__len__() 42 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 43 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 44 | video_len = f['len'][str(idx_in_shard)][()] 45 | return video_len 46 | 47 | def __len__(self): 48 | return self.total_videos if self.total_videos > 0 else self.num_train_vids if self.train else self.num_test_vids 49 | 50 | def max_index(self): 51 | return self.num_train_vids if self.train else self.num_test_vids 52 | 53 | def __getitem__(self, index, time_idx=0): 54 | 55 | # Use `index` to select the video, and then 56 | # randomly choose a `frames_per_sample` window of frames in the video 57 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 58 | if not self.train: 59 | video_index = video_index * 10 + self.num_train_vids # https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 60 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 61 | 62 | # random crop 63 | crop_c = np.random.randint(int(self.image_size/240*320) - self.image_size) if self.train else int((self.image_size/240*320 - self.image_size)/2) 64 | 65 | # random horizontal flip 66 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 67 | 68 | # read data 69 | prefinals = [] 70 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 71 | target = int(f['target'][str(idx_in_shard)][()]) 72 | # slice data 73 | video_len = 
f['len'][str(idx_in_shard)][()] 74 | if self.random_time and video_len > self.frames_per_sample: 75 | time_idx = np.random.choice(video_len - self.frames_per_sample) 76 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 77 | img = f[str(idx_in_shard)][str(i)][()] 78 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img[:, crop_c:crop_c + self.image_size])) 79 | prefinals.append(arr) 80 | 81 | video = torch.stack(prefinals) 82 | 83 | if self.with_target: 84 | return video, torch.tensor(target) 85 | else: 86 | return video 87 | -------------------------------------------------------------------------------- /datasets/ucf101_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import imageio 5 | import numpy as np 6 | import os 7 | import sys 8 | 9 | from functools import partial 10 | from multiprocessing import Pool 11 | from PIL import Image 12 | from tqdm import tqdm 13 | 14 | from h5 import HDF5Maker 15 | 16 | 17 | class UCF101_HDF5Maker(HDF5Maker): 18 | 19 | def create_video_groups(self): 20 | self.writer.create_group('len') 21 | self.writer.create_group('data') 22 | self.writer.create_group('target') 23 | 24 | def add_video_data(self, data, dtype=None): 25 | data, target = data 26 | self.writer['len'].create_dataset(str(self.count), data=len(data)) 27 | self.writer['target'].create_dataset(str(self.count), data=target, dtype='uint8') 28 | self.writer.create_group(str(self.count)) 29 | for i, frame in enumerate(data): 30 | self.writer[str(self.count)].create_dataset(str(i), data=frame, dtype=dtype, compression="lzf") 31 | 32 | 33 | def center_crop(image): 34 | h, w, c = image.shape 35 | new_h, new_w = h if h < w else w, w if w < h else h 36 | r_min, r_max = h//2 - new_h//2, h//2 + new_h//2 37 | c_min, c_max = w//2 - new_w//2, w//2 + new_w//2 38 | return image[r_min:r_max, c_min:c_max, :] 39 | 40 | 41 | def read_video(video_file, image_size): 42 | frames = [] 43 | reader = imageio.get_reader(video_file) 44 | h, w = 240, 320 45 | new_h = image_size 46 | new_w = int(new_h / h * w) 47 | for img in reader: 48 | # img_cc = center_crop(img) 49 | pil_im = Image.fromarray(img) 50 | pil_im_rsz = pil_im.resize((new_w, new_h), Image.LANCZOS) 51 | frames.append(np.array(pil_im_rsz)) 52 | # frames.append(np.array(img)) 53 | return np.stack(frames) 54 | 55 | 56 | def process_video(video_file, image_size): 57 | frames = [] 58 | try: 59 | frames = read_video(video_file, image_size) 60 | except StopIteration: 61 | pass 62 | # break 63 | except (KeyboardInterrupt, SystemExit): 64 | print("Ctrl+C!!") 65 | return "break" 66 | except: 67 | e = sys.exc_info()[0] 68 | print("ERROR:", e) 69 | return frames 70 | 71 | 72 | def read_splits(splits_dir, split_idx, ucf_dir): 73 | # train 74 | txt_train = os.path.join(splits_dir, f"trainlist0{split_idx}.txt") 75 | vids_train = open(txt_train, 'r').read().splitlines() 76 | vids_train = [os.path.join(ucf_dir, line.split('.avi')[0] + '.avi') for line in vids_train] 77 | # test 78 | txt_test = os.path.join(splits_dir, f"testlist0{split_idx}.txt") 79 | vids_test = open(txt_test, 'r').read().splitlines() 80 | vids_test = [os.path.join(ucf_dir, line) for line in vids_test] 81 | # classes 82 | classes = {line.split(' ')[-1]: int(line.split(' ')[0])-1 for line in open(os.path.join(splits_dir, 'classInd.txt'), 'r').read().splitlines()} 83 | classes_train = [classes[os.path.basename(os.path.dirname(f))] for f in vids_train] 84 | 
classes_test = [classes[os.path.basename(os.path.dirname(f))] for f in vids_test] 85 | return vids_train, vids_test, classes_train, classes_test 86 | 87 | 88 | # def make_h5_from_ucf_multi(ucf_dir, splits_dir, split_idx, out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 89 | 90 | # # H5 maker 91 | # h5_maker = UCF101_HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 92 | 93 | # vids_train, vids_test, classes_train, classes_test = read_splits(splits_dir, split_idx, ucf_dir) 94 | # print("Train:", len(vids_train), "\nTest", len(vids_test)) 95 | 96 | # h5_maker.writer.create_dataset('num_train', data=len(vids_train)) 97 | # h5_maker.writer.create_dataset('num_test', data=len(vids_test)) 98 | # videos = vids_train + vids_test 99 | # classes = classes_train + classes_test 100 | 101 | # # Process videos 100 at a time 102 | # pbar = tqdm(total=len(videos)) 103 | # for i in range(int(np.ceil(len(videos)/100))): 104 | 105 | # # pool 106 | # with Pool() as pool: 107 | # # tic = time.time() 108 | # results = pool.imap(process_video, [(v, c) for v, c in zip(videos[i*100:(i+1)*100], classes[i*100:(i+1)*100])]) 109 | # # add frames to h5 110 | # for result in results: 111 | # frames, t = result 112 | # if len(frames) > 0: 113 | # h5_maker.add_data(result, dtype='uint8') 114 | # # toc = time.time() 115 | 116 | # pbar.update(len(videos[i*100:(i+1)*100])) 117 | 118 | # pbar.close() 119 | # h5_maker.close() 120 | 121 | 122 | def make_h5_from_ucf(ucf_dir, splits_dir, split_idx, image_size, out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 123 | 124 | # H5 maker 125 | h5_maker = UCF101_HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 126 | 127 | vids_train, vids_test, classes_train, classes_test = read_splits(splits_dir, split_idx, ucf_dir) 128 | print("Train:", len(vids_train), "\nTest", len(vids_test)) 129 | 130 | h5_maker.writer.create_dataset('num_train', data=len(vids_train)) 131 | h5_maker.writer.create_dataset('num_test', data=len(vids_test)) 132 | videos = vids_train + vids_test 133 | classes = classes_train + classes_test 134 | 135 | for i in tqdm(range(len(videos))): 136 | frames = process_video(videos[i], image_size) 137 | if isinstance(frames, str) and frames == "break": 138 | break 139 | h5_maker.add_data((frames, classes[i]), dtype='uint8') 140 | 141 | h5_maker.close() 142 | 143 | 144 | if __name__ == "__main__": 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 148 | parser.add_argument('--ucf_dir', type=str, help="Path to UCF-101 videos") 149 | parser.add_argument('--splits_dir', type=str, help="Path to ucfTrainTestlist") 150 | parser.add_argument('--split_idx', type=int, choices=[1, 2, 3], default=3, help="Which split to use") 151 | parser.add_argument('--image_size', type=int, default=64) 152 | parser.add_argument('--vids_per_shard', type=int, default=100000) 153 | parser.add_argument('--force_h5', type=eval, default=False) 154 | 155 | args = parser.parse_args() 156 | 157 | make_h5_from_ucf(out_dir=args.out_dir, ucf_dir=args.ucf_dir, splits_dir=args.splits_dir, split_idx=args.split_idx, 158 | image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 159 | -------------------------------------------------------------------------------- /datasets/ucf101_download.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR=$1 2 | if [ -z $TARGET_DIR ] 3 | then 4 | echo 
"Must specify target directory" 5 | else 6 | mkdir $TARGET_DIR/ 7 | # Download UCF101.rar (6.5GB) 8 | wget -P $TARGET_DIR https://www.crcv.ucf.edu/data/UCF101/UCF101.rar 9 | # Unrar 10 | unrar x $TARGET_DIR/UCF101.rar $TARGET_DIR 11 | # Download splits 12 | wget -P $TARGET_DIR https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip 13 | # Unzip 14 | unzip $TARGET_DIR/UCF101TrainTestSplits-RecognitionTask.zip -d $TARGET_DIR 15 | fi 16 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 
81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | def __getitem__(self, index): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | def __repr__(self): 21 | head = "Dataset " + self.__class__.__name__ 22 | body = ["Number of datapoints: {}".format(self.__len__())] 23 | if self.root is not None: 24 | body.append("Root location: {}".format(self.root)) 25 | body += self.extra_repr().splitlines() 26 | if hasattr(self, 'transform') and self.transform is not None: 27 | body += self._format_transform_repr(self.transform, 28 | "Transforms: ") 29 | if hasattr(self, 'target_transform') and self.target_transform is not None: 30 | body += self._format_transform_repr(self.target_transform, 31 | "Target transforms: ") 32 | lines = [head] + [" " * self._repr_indent + line for line in body] 33 | return '\n'.join(lines) 34 | 35 | def _format_transform_repr(self, transform, head): 36 | lines = transform.__repr__().splitlines() 37 | return (["{}{}".format(head, lines[0])] + 38 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 39 | 40 | def extra_repr(self): 41 | return "" 42 | -------------------------------------------------------------------------------- /evaluation/nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | """ 2 | prdc 3 | Copyright (c) 2020-present NAVER Corp. 
4 | Modified by Yang Song (yangsong@cs.stanford.edu) 5 | MIT license 6 | """ 7 | import sklearn.metrics 8 | import pathlib 9 | 10 | import numpy as np 11 | import torch 12 | from torchvision.datasets import LSUN, CelebA, CIFAR10 13 | from datasets.ffhq import FFHQ 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import Compose, Resize, CenterCrop, RandomHorizontalFlip, ToPILImage, ToTensor 16 | from torchvision.utils import save_image 17 | from scipy import linalg 18 | from torch.nn.functional import adaptive_avg_pool2d 19 | import os 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--path', type=str, required=True) 23 | parser.add_argument('--k', type=int, default=9) 24 | parser.add_argument('--n_samples', type=int, default=10) 25 | parser.add_argument('--dataset', type=str, required=True) 26 | parser.add_argument('-i', type=str, required=True) 27 | 28 | from PIL import Image 29 | 30 | try: 31 | from tqdm import tqdm 32 | except ImportError: 33 | # If not tqdm is not available, provide a mock version of it 34 | def tqdm(x): return x 35 | 36 | from evaluation.inception import InceptionV3 37 | 38 | def imread(filename): 39 | """ 40 | Loads an image file into a (height, width, 3) uint8 ndarray. 41 | """ 42 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3] 43 | 44 | 45 | def get_activations(model, images, dims=2048): 46 | # Reshape to (n_images, 3, height, width) 47 | with torch.no_grad(): 48 | pred = model(images)[0] 49 | 50 | # If model output is not scalar, apply global spatial average pooling. 51 | # This happens if you choose a dimensionality not equal 2048. 52 | if pred.size(2) != 1 or pred.size(3) != 1: 53 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 54 | 55 | return pred.reshape(pred.size(0), -1) 56 | 57 | 58 | def _compute_features_of_path(path, model, batch_size, dims, cuda): 59 | if path.endswith('.npz'): 60 | f = np.load(path) 61 | act = f['features'][:] 62 | f.close() 63 | else: 64 | path = pathlib.Path(path) 65 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 66 | act = get_activations(files, model, batch_size, dims, cuda, verbose=False) 67 | return act 68 | 69 | 70 | def get_nearest_neighbors(dataset, path, name, n_samples, k=10, cuda=True): 71 | if not os.path.exists(path): 72 | raise RuntimeError('Invalid path: %s' % path) 73 | 74 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] 75 | 76 | model = InceptionV3([block_idx]) 77 | if cuda: 78 | model.cuda() 79 | model.eval() 80 | 81 | flipper = RandomHorizontalFlip(p=1.) 
82 | to_pil = ToPILImage() 83 | to_tensor = ToTensor() 84 | data_features = [] 85 | data = [] 86 | for x, _ in tqdm(dataset, desc="sweeping the whole dataset"): 87 | if cuda: x = x.cuda() 88 | data_features.append(get_activations(model, x).cpu()) 89 | data.append(x.cpu()) 90 | 91 | data_features = torch.cat(data_features, dim=0) 92 | data = torch.cat(data, dim=0) 93 | 94 | samples = torch.load(path)[:n_samples] 95 | flipped_samples = torch.stack([to_tensor(flipper(to_pil(img))) for img in samples], dim=0) 96 | if cuda: 97 | samples = samples.cuda() 98 | flipped_samples = flipped_samples.cuda() 99 | 100 | sample_features = get_activations(model, samples).cpu() 101 | flip_sample_feature = get_activations(model, flipped_samples).cpu() 102 | sample_cdist = torch.cdist(sample_features, data_features) 103 | flip_sample_cdist = torch.cdist(flip_sample_feature, data_features) 104 | 105 | plot_data = [] 106 | for i in tqdm(range(len(samples)), desc='find nns and save images'): 107 | plot_data.append(samples[i].cpu()) 108 | all_dists = torch.min(sample_cdist[i], flip_sample_cdist[i]) 109 | indices = torch.topk(-all_dists, k=k)[1] 110 | for ind in indices: 111 | plot_data.append(data[ind]) 112 | 113 | plot_data = torch.stack(plot_data, dim=0) 114 | save_image(plot_data, '{}.png'.format(name), nrow=k+1) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = parser.parse_args() 119 | if args.dataset == 'church': 120 | transforms = Compose([ 121 | Resize(96), 122 | CenterCrop(96), 123 | ToTensor() 124 | ]) 125 | dataset = LSUN('exp/datasets/lsun', ['church_outdoor_train'], transform=transforms) 126 | 127 | elif args.dataset == 'tower' or args.dataset == 'bedroom': 128 | transforms = Compose([ 129 | Resize(128), 130 | CenterCrop(128), 131 | ToTensor() 132 | ]) 133 | dataset = LSUN('exp/datasets/lsun', ['{}_train'.format(args.dataset)], transform=transforms) 134 | 135 | elif args.dataset == 'celeba': 136 | transforms = Compose([ 137 | CenterCrop(140), 138 | Resize(64), 139 | ToTensor(), 140 | ]) 141 | dataset = CelebA('exp/datasets/celeba', split='train', transform=transforms) 142 | 143 | elif args.dataset == 'cifar10': 144 | dataset = CIFAR10('exp/datasets/cifar10', train=True, transform=ToTensor()) 145 | elif args.dataset == 'ffhq': 146 | dataset = FFHQ(path='exp/datasets/FFHQ', transform=ToTensor(), resolution=256) 147 | 148 | dataloader = DataLoader(dataset, batch_size=128, drop_last=False) 149 | get_nearest_neighbors(dataloader, args.path, args.i, args.n_samples, args.k, torch.cuda.is_available()) 150 | -------------------------------------------------------------------------------- /evaluation/pr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calc_cdist(feat1, feat2, batch_size=10000): 4 | dists = [] 5 | for feat2_batch in feat2.split(batch_size): 6 | dists.append(torch.cdist(feat1, feat2_batch).cpu()) 7 | return torch.cat(dists, dim=1) 8 | 9 | 10 | def calculate_precision_recall_part(feat_r, feat_g, k=3, batch_size=10000): 11 | # Precision 12 | NNk_r = [] 13 | for feat_r_batch in feat_r.split(batch_size): 14 | NNk_r.append(calc_cdist(feat_r_batch, feat_r, batch_size).kthvalue(k+1).values) 15 | NNk_r = torch.cat(NNk_r) 16 | precision = [] 17 | for feat_g_batch in feat_g.split(batch_size): 18 | dist_g_r_batch = calc_cdist(feat_g_batch, feat_r, batch_size) 19 | precision.append((dist_g_r_batch <= NNk_r).any(dim=1).float()) 20 | precision = torch.cat(precision).mean().item() 21 | # Recall 22 | NNk_g = [] 23 | for feat_g_batch in 
feat_g.split(batch_size): 24 | NNk_g.append(calc_cdist(feat_g_batch, feat_g, batch_size).kthvalue(k+1).values) 25 | NNk_g = torch.cat(NNk_g) 26 | recall = [] 27 | for feat_r_batch in feat_r.split(batch_size): 28 | dist_r_g_batch = calc_cdist(feat_r_batch, feat_g, batch_size) 29 | recall.append((dist_r_g_batch <= NNk_g).any(dim=1).float()) 30 | recall = torch.cat(recall).mean().item() 31 | return precision, recall 32 | 33 | 34 | def calc_cdist_full(feat1, feat2, batch_size=10000): 35 | dists = [] 36 | for feat1_batch in feat1.split(batch_size): 37 | dists_batch = [] 38 | for feat2_batch in feat2.split(batch_size): 39 | dists_batch.append(torch.cdist(feat1_batch, feat2_batch).cpu()) 40 | dists.append(torch.cat(dists_batch, dim=1)) 41 | return torch.cat(dists, dim=0) 42 | 43 | 44 | def calculate_precision_recall_full(feat_r, feat_g, k=3, batch_size=10000): 45 | NNk_r = calc_cdist_full(feat_r, feat_r, batch_size).kthvalue(k+1).values 46 | NNk_g = calc_cdist_full(feat_g, feat_g, batch_size).kthvalue(k+1).values 47 | dist_g_r = calc_cdist_full(feat_g, feat_r, batch_size) 48 | dist_r_g = dist_g_r.T 49 | # Precision 50 | precision = (dist_g_r <= NNk_r).any(dim=1).float().mean().item() 51 | # Recall 52 | recall = (dist_r_g <= NNk_g).any(dim=1).float().mean().item() 53 | return precision, recall 54 | 55 | 56 | def calculate_precision_recall(feat_r, feat_g, device=torch.device('cuda'), k=3, 57 | batch_size=10000, save_cpu_ram=False, **kwargs): 58 | feat_r = feat_r.to(device) 59 | feat_g = feat_g.to(device) 60 | if save_cpu_ram: 61 | return calculate_precision_recall_part(feat_r, feat_g, k, batch_size) 62 | else: 63 | return calculate_precision_recall_full(feat_r, feat_g, k, batch_size) 64 | 65 | -------------------------------------------------------------------------------- /example_scripts/final/base_1f.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #module load python/3.8.10 StdEnv/2020 cuda/11.0 cudnn/8.0.3 4 | #source $HOME/vidgen2/bin/activate 5 | #rsync -avz --no-g --no-p /path/to/mask-cond-video-diffusion $SLURM_TMPDIR 6 | #cd $SLURM_TMPDIR/mask-cond-video-diffusion 7 | 8 | ## Example 9 | #config="kth64" 10 | #data="/path/to/datasets/KTH64_h5" 11 | #devices="0,1" 12 | #exp=/scratch/${user}/checkpoints/my_exp 13 | #config_mod="sampling.num_frames_pred=20 data.num_frames=5 data.num_frames_cond=10 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore" 14 | 15 | # Test ddim 16 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} 17 | -------------------------------------------------------------------------------- /example_scripts/final/base_1f_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #module load python/3.8.10 StdEnv/2020 cuda/11.0 cudnn/8.0.3 4 | #source $HOME/vidgen2/bin/activate 5 | #rsync -avz --no-g --no-p /path/to/mask-cond-video-diffusion $SLURM_TMPDIR 6 | #cd $SLURM_TMPDIR/mask-cond-video-diffusion 7 | 8 | ## Example 9 | #config="kth64" 10 | #data="/path/to/datasets/KTH64_h5" 11 | #devices="0,1" 12 | #exp=/scratch/${user}/checkpoints/my_exp 13 | #config_mod="sampling.num_frames_pred=20 data.num_frames=5 data.num_frames_cond=10 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM 
model.arch=unetmore" 14 | 15 | # Test ddim 16 | CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} 17 | -------------------------------------------------------------------------------- /example_scripts/final/base_1f_4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #module load python/3.8.10 StdEnv/2020 cuda/11.0 cudnn/8.0.3 4 | #source $HOME/vidgen2/bin/activate 5 | #rsync -avz --no-g --no-p /path/to/mask-cond-video-diffusion $SLURM_TMPDIR 6 | #cd $SLURM_TMPDIR/mask-cond-video-diffusion 7 | 8 | ## Example 9 | #config="kth64" 10 | #data="/path/to/datasets/KTH64_h5" 11 | #devices="0,1,2,3" 12 | #exp=/scratch/${user}/checkpoints/my_exp 13 | #config_mod="sampling.num_frames_pred=20 data.num_frames=5 data.num_frames_cond=10 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore" 14 | 15 | # Test ddim 16 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} 17 | -------------------------------------------------------------------------------- /example_scripts/final/base_1f_vidgen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #module load python/3.8.10 StdEnv/2020 cuda/11.0 cudnn/8.0.3 4 | #source $HOME/vidgen2/bin/activate 5 | #rsync -avz --no-g --no-p /path/to/mask-cond-video-diffusion $SLURM_TMPDIR 6 | #cd $SLURM_TMPDIR/mask-cond-video-diffusion 7 | 8 | ## Example 9 | #config="kth64" 10 | #data="/home/${user}/scratch/datasets/KTH64_h5" 11 | #devices="0,1" 12 | #exp=/scratch/${user}/checkpoints/my_exp 13 | #config_mod="sampling.num_frames_pred=20 data.num_frames=5 data.num_frames_cond=10 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore" 14 | #ckpt=60000 15 | 16 | # Test ddpm 100 17 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=100 model.version=DDPM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDPM_100_nfp_${nfp} 18 | 19 | # Test ddim 100 20 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=100 model.version=DDIM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDIM_100_nfp_${nfp} 21 | 22 | # Test ddpm 1000 23 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=1000 model.version=DDPM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDPM_1000_nfp_${nfp} 24 | -------------------------------------------------------------------------------- /example_scripts/final/base_1f_vidgen_short.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #module load python/3.8.10 StdEnv/2020 cuda/11.0 cudnn/8.0.3 4 | #source $HOME/vidgen2/bin/activate 5 | #rsync -avz --no-g --no-p /path/to/mask-cond-video-diffusion $SLURM_TMPDIR 6 | #cd $SLURM_TMPDIR/mask-cond-video-diffusion 7 | 
8 | ## Example 9 | #config="kth64" 10 | #data="/home/${user}/scratch/datasets/KTH64_h5" 11 | #devices="0,1" 12 | #exp=/scratch/${user}/checkpoints/my_exp 13 | #config_mod="sampling.num_frames_pred=20 data.num_frames=5 data.num_frames_cond=10 training.batch_size=64 sampling.subsample=100 sampling.clip_before=True sampling.batch_size=100 sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore" 14 | #ckpt=60000 15 | 16 | # Test ddpm 100 17 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=100 model.version=DDPM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDPM_100_nfp_${nfp} 18 | 19 | # Test ddim 100 20 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=100 model.version=DDIM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDIM_100_nfp_${nfp} 21 | 22 | # Test ddpm 1000 23 | #CUDA_VISIBLE_DEVICES=0 python main.py --config configs/${config}.yml --data_path ${data} --exp ${exp} --ni --config_mod ${config_mod} sampling.num_frames_pred=${nfp} sampling.preds_per_test=10 sampling.subsample=1000 model.version=DDPM --ckpt ${ckpt} --video_gen -v videos_${ckpt}_DDPM_1000_nfp_${nfp} 24 | -------------------------------------------------------------------------------- /example_scripts/final/simple_sample.py: -------------------------------------------------------------------------------- 1 | 2 | # CUDA_VISIBLE_DEVICES=3 python -i load_model_from_ckpt.py --ckpt_path /path/to/logs/checkpoint.pt 3 | 4 | # Load CIFAR10 5 | import torch 6 | from torchvision.datasets import CIFAR10 7 | ds = CIFAR10('/path/to/data/cifar10', train=True) 8 | data = torch.from_numpy(ds.data) 9 | 10 | # Transform data 11 | from datasets import get_dataset, data_transform, inverse_data_transform 12 | 13 | # Sampler 14 | from models import ddpm_sampler 15 | 16 | 17 | all_samples = ddpm_sampler(init_samples, scorenet, 18 | n_steps_each=config.sampling.n_steps_each, 19 | step_lr=config.sampling.step_lr, verbose=True, 20 | final_only=config.sampling.final_only, 21 | denoise=config.sampling.denoise, 22 | subsample_steps=getattr(config.sampling, 'subsample', None)) 23 | -------------------------------------------------------------------------------- /example_scripts/video_gen_metrics.sh: -------------------------------------------------------------------------------- 1 | EXP=$1 2 | CKPT=$2 3 | NUMFRAMESPRED=$3 4 | PREDSPERTEST=$4 5 | DATAPATH=$5 6 | NAME=$6 7 | GPU1=$7 8 | # GPU2=$8 9 | 10 | CUDA_VISIBLE_DEVICES=$GPU1 python main.py --config $EXP/logs/config.yml --data_path $DATAPATH --exp $EXP --ckpt $CKPT --seed 0 --video_gen -v videos_${CKPT}_${NAME}_DDPM_100_traj${PREDSPERTEST} --config_mod sampling.fvd=True model.version="DDPM" sampling.subsample=100 sampling.num_frames_pred=$NUMFRAMESPRED sampling.preds_per_test=$PREDSPERTEST sampling.max_data_iter=100000000 11 | # CUDA_VISIBLE_DEVICES=$GPU2 python main.py --config $EXP/logs/config.yml --data_path $DATAPATH --exp $EXP --ckpt $CKPT --seed 0 --video_gen -v videos_${CKPT}_${NAME}_DDIM_100_traj${PREDSPERTEST} --config_mod sampling.fvd=True model.version="DDIM" sampling.subsample=100 sampling.num_frames_pred=$NUMFRAMESPRED sampling.preds_per_test=$PREDSPERTEST sampling.max_data_iter=100000000 & 12 | 
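# Illustrative invocation (a sketch only; the experiment path, checkpoint step, dataset path and name below are placeholders, not values from the repo):
#   bash example_scripts/video_gen_metrics.sh /path/to/checkpoints/my_exp 60000 20 10 /path/to/datasets/KTH64_h5 kth64 0
# The seven positional arguments are read above as EXP, CKPT, NUMFRAMESPRED, PREDSPERTEST, DATAPATH, NAME and GPU1;
# the commented-out second command would rerun the same metrics with DDIM sampling on a second GPU (GPU2).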
-------------------------------------------------------------------------------- /load_model_from_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import torch 5 | import yaml 6 | 7 | from collections import OrderedDict 8 | from functools import partial 9 | from imageio import mimwrite 10 | from torch.utils.data import DataLoader 11 | from torchvision.utils import make_grid, save_image 12 | 13 | try: 14 | from torchvision.transforms.functional import resize, InterpolationMode 15 | interp = InterpolationMode.NEAREST 16 | except: 17 | from torchvision.transforms.functional import resize 18 | interp = 0 19 | 20 | from datasets import get_dataset, data_transform, inverse_data_transform 21 | from main import dict2namespace 22 | from models import get_sigmas, anneal_Langevin_dynamics, anneal_Langevin_dynamics_consistent, ddpm_sampler, ddim_sampler, FPNDM_sampler 23 | from models.ema import EMAHelper 24 | from runners.ncsn_runner import get_model 25 | 26 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 27 | # device = torch.device('cpu') 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 32 | parser.add_argument('--ckpt_path', type=str, required=True, help='Path to checkpoint.pt') 33 | parser.add_argument('--data_path', type=str, default='/mnt/data/scratch/data/CIFAR10', help='Path to the dataset') 34 | args = parser.parse_args() 35 | return args.ckpt_path, args.data_path 36 | 37 | 38 | # Make and load model 39 | def load_model(ckpt_path, device=device): 40 | # Parse config file 41 | with open(os.path.join(os.path.dirname(ckpt_path), 'config.yml'), 'r') as f: 42 | config = yaml.load(f, Loader=yaml.FullLoader) 43 | # Load config file 44 | config = dict2namespace(config) 45 | config.device = device 46 | # Load model 47 | scorenet = get_model(config) 48 | if config.device != torch.device('cpu'): 49 | scorenet = torch.nn.DataParallel(scorenet) 50 | states = torch.load(ckpt_path, map_location=config.device) 51 | else: 52 | states = torch.load(ckpt_path, map_location='cpu') 53 | states[0] = OrderedDict([(k.replace('module.', ''), v) for k, v in states[0].items()]) 54 | scorenet.load_state_dict(states[0], strict=False) 55 | if config.model.ema: 56 | ema_helper = EMAHelper(mu=config.model.ema_rate) 57 | ema_helper.register(scorenet) 58 | ema_helper.load_state_dict(states[-1]) 59 | ema_helper.ema(scorenet) 60 | scorenet.eval() 61 | return scorenet, config 62 | 63 | 64 | def get_sampler_from_config(config): 65 | version = getattr(config.model, 'version', "DDPM") 66 | # Sampler 67 | if version == "SMLD": 68 | consistent = getattr(config.sampling, 'consistent', False) 69 | sampler = anneal_Langevin_dynamics_consistent if consistent else anneal_Langevin_dynamics 70 | elif version == "DDPM": 71 | sampler = partial(ddpm_sampler, config=config) 72 | elif version == "DDIM": 73 | sampler = partial(ddim_sampler, config=config) 74 | elif version == "FPNDM": 75 | sampler = partial(FPNDM_sampler, config=config) 76 | return sampler 77 | 78 | 79 | def get_sampler(config): 80 | sampler = get_sampler_from_config(config) 81 | sampler_partial = partial(sampler, n_steps_each=config.sampling.n_steps_each, 82 | step_lr=config.sampling.step_lr, just_beta=False, 83 | final_only=True, denoise=config.sampling.denoise, 84 | subsample_steps=getattr(config.sampling, 'subsample', None), 85 | clip_before=getattr(config.sampling, 'clip_before', 
True), 86 | verbose=False, log=False, gamma=getattr(config.model, 'gamma', False)) 87 | def sampler_fn(init, scorenet, cond, cond_mask, subsample=getattr(config.sampling, 'subsample', None), verbose=False): 88 | init = init.to(config.device) 89 | cond = cond.to(config.device) 90 | if cond_mask is not None: 91 | cond_mask = cond_mask.to(config.device) 92 | return inverse_data_transform(config, sampler_partial(init, scorenet, cond=cond, cond_mask=cond_mask, 93 | subsample_steps=subsample, verbose=verbose)[-1].to('cpu')) 94 | return sampler_fn 95 | 96 | 97 | def init_samples(n_init_samples, config, net=None): # net (the score network) is only needed when model.gamma is True 98 | from torch.distributions.gamma import Gamma # initial samples; import only used by the gamma-noise branch below 99 | # n_init_samples = min(36, config.training.batch_size) 100 | version = getattr(config.model, 'version', "DDPM") 101 | init_samples_shape = (n_init_samples, config.data.channels*config.data.num_frames, config.data.image_size, config.data.image_size) 102 | if version == "SMLD": 103 | init_samples = torch.rand(init_samples_shape) 104 | init_samples = data_transform(config, init_samples) 105 | elif version == "DDPM" or version == "DDIM" or version == "FPNDM": 106 | if getattr(config.model, 'gamma', False): 107 | used_k, used_theta = net.k_cum[0], net.theta_t[0] 108 | z = Gamma(torch.full(init_samples_shape, used_k), torch.full(init_samples_shape, 1 / used_theta)).sample().to(config.device) 109 | init_samples = z - used_k*used_theta # we don't scale here 110 | else: 111 | init_samples = torch.randn(init_samples_shape) 112 | return init_samples 113 | 114 | 115 | if __name__ == '__main__': 116 | # data_path = '/path/to/data/CIFAR10' 117 | ckpt_path, data_path = parse_args() 118 | 119 | scorenet, config = load_model(ckpt_path, device) 120 | 121 | # Initial samples 122 | dataset, test_dataset = get_dataset(data_path, config) 123 | dataloader = DataLoader(dataset, batch_size=config.training.batch_size, shuffle=True, 124 | num_workers=config.data.num_workers) 125 | train_iter = iter(dataloader) 126 | x, y = next(train_iter) 127 | test_loader = DataLoader(test_dataset, batch_size=config.training.batch_size, shuffle=False, 128 | num_workers=config.data.num_workers, drop_last=True) 129 | test_iter = iter(test_loader) 130 | test_x, test_y = next(test_iter) 131 | 132 | net = scorenet.module if hasattr(scorenet, 'module') else scorenet 133 | version = getattr(net, 'version', 'SMLD').upper() 134 | net_type = getattr(net, 'type') if isinstance(getattr(net, 'type'), str) else 'v1' 135 | 136 | if version == "SMLD": 137 | sigmas = net.sigmas 138 | labels = torch.randint(0, len(sigmas), (x.shape[0],), device=x.device) 139 | used_sigmas = sigmas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 140 | device = sigmas.device 141 | 142 | elif version == "DDPM" or version == "DDIM": 143 | alphas = net.alphas 144 | labels = torch.randint(0, len(alphas), (x.shape[0],), device=x.device) 145 | used_alphas = alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 146 | device = alphas.device 147 | 148 | 149 | # CUDA_VISIBLE_DEVICES=3 python -i load_model_from_ckpt.py --ckpt_path /path/to/ncsnv2/cifar10/BASELINE_DDPM_800k/logs/checkpoint.pt 150 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999),
amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer)) 15 | 16 | 17 | def warmup_lr(optimizer, step, warmup, max_lr): 18 | if step > warmup: 19 | return max_lr 20 | lr = max_lr * min(float(step) / max(warmup, 1), 1.0) 21 | for param_group in optimizer.param_groups: 22 | param_group["lr"] = lr 23 | return lr 24 | 25 | -------------------------------------------------------------------------------- /losses/dsm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from functools import partial 4 | from torch.distributions.gamma import Gamma 5 | 6 | 7 | def anneal_dsm_score_estimation(scorenet, x, labels=None, loss_type='a', hook=None, cond=None, cond_mask=None, gamma=False, L1=False, all_frames=False): 8 | 9 | net = scorenet.module if hasattr(scorenet, 'module') else scorenet 10 | version = getattr(net, 'version', 'SMLD').upper() 11 | net_type = getattr(net, 'type') if isinstance(getattr(net, 'type'), str) else 'v1' 12 | 13 | if all_frames: 14 | x = torch.cat([x, cond], dim=1) 15 | cond = None 16 | 17 | # z, perturbed_x 18 | if version == "SMLD": 19 | sigmas = net.sigmas 20 | if labels is None: 21 | labels = torch.randint(0, len(sigmas), (x.shape[0],), device=x.device) 22 | used_sigmas = sigmas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 23 | z = torch.randn_like(x) 24 | perturbed_x = x + used_sigmas * z 25 | elif version == "DDPM" or version == "DDIM" or version == "FPNDM": 26 | alphas = net.alphas 27 | if labels is None: 28 | labels = torch.randint(0, len(alphas), (x.shape[0],), device=x.device) 29 | used_alphas = alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 30 | if gamma: 31 | used_k = net.k_cum[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))).repeat(1, *x.shape[1:]) 32 | used_theta = net.theta_t[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))).repeat(1, *x.shape[1:]) 33 | z = Gamma(used_k, 1 / used_theta).sample() 34 | z = (z - used_k*used_theta) / (1 - used_alphas).sqrt() 35 | else: 36 | z = torch.randn_like(x) 37 | perturbed_x = used_alphas.sqrt() * x + (1 - used_alphas).sqrt() * z 38 | scorenet = partial(scorenet, cond=cond) 39 | 40 | # Loss 41 | if L1: 42 | def pow_(x): 43 | return x.abs() 44 | else: 45 | def pow_(x): 46 | return 1 / 2. 
* x.square() 47 | loss = pow_((z - scorenet(perturbed_x, labels, cond_mask=cond_mask)).reshape(len(x), -1)).sum(dim=-1) 48 | 49 | if hook is not None: 50 | hook(loss, labels) 51 | 52 | return loss.mean(dim=0) 53 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | class BaseModel(): 6 | def __init__(self): 7 | pass; 8 | 9 | def name(self): 10 | return 'BaseModel' 11 | 12 | def initialize(self, use_gpu=True, gpu_ids=[0]): 13 | self.use_gpu = use_gpu 14 | self.gpu_ids = gpu_ids 15 | 16 | def forward(self): 17 | pass 18 | 19 | def get_image_paths(self): 20 | pass 21 | 22 | def optimize_parameters(self): 23 | pass 24 | 25 | def get_current_visuals(self): 26 | return self.input 27 | 28 | def get_current_errors(self): 29 | return {} 30 | 31 | def save(self, label): 32 | pass 33 | 34 | # helper saving function that can be used by subclasses 35 | def save_network(self, network, path, network_label, epoch_label): 36 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 37 | save_path = os.path.join(path, save_filename) 38 | torch.save(network.state_dict(), save_path) 39 | 40 | # helper loading function that can be used by subclasses 41 | def load_network(self, network, network_label, epoch_label): 42 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | print('Loading network from %s'%save_path) 45 | network.load_state_dict(torch.load(save_path)) 46 | 47 | def update_learning_rate(): 48 | pass 49 | 50 | def get_image_paths(self): 51 | return self.image_paths 52 | 53 | def save_done(self, flag=False): 54 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 55 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') -------------------------------------------------------------------------------- /models/better/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /models/better/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | 
super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = 
self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /models/better/op/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/better/op/__init__.py -------------------------------------------------------------------------------- /models/better/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system("unset TORCH_CUDA_ARCH_LIST") 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | 13 | 14 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 15 | if input.device.type == "cpu": 16 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 17 | return ( 18 | F.leaky_relu( 19 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 20 | ) 21 | * scale 22 | ) 23 | 24 | else: 25 | 26 | fused = load( 27 | "fused", 28 | sources=[ 29 | os.path.join(module_path, "fused_bias_act.cpp"), 30 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 31 | ], 32 | ) 33 | 34 | class FusedLeakyReLUFunctionBackward(Function): 35 | 36 | @staticmethod 37 | def forward(ctx, grad_output, out, negative_slope, scale): 38 | ctx.save_for_backward(out) 39 | ctx.negative_slope = negative_slope 40 | ctx.scale = scale 41 | 42 | empty = grad_output.new_empty(0) 43 | 44 | grad_input = fused.fused_bias_act( 45 | grad_output, empty, out, 3, 1, negative_slope, scale 46 | ) 47 | 48 | dim = [0] 49 | 50 | if grad_input.ndim > 2: 51 | dim += list(range(2, grad_input.ndim)) 52 | 53 | grad_bias = grad_input.sum(dim).detach() 54 | 55 | return grad_input, grad_bias 56 | 57 | @staticmethod 58 | def backward(ctx, gradgrad_input, gradgrad_bias): 59 | out, = ctx.saved_tensors 60 | gradgrad_out = fused.fused_bias_act( 61 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 62 | ) 63 | 64 | return gradgrad_out, None, None, None 65 | 66 | 67 | class FusedLeakyReLUFunction(Function): 68 | 69 | @staticmethod 70 | def forward(ctx, input, bias, negative_slope, scale): 71 | empty = input.new_empty(0) 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | return grad_input, grad_bias, None, None 88 | 89 | 90 | class FusedLeakyReLU(nn.Module): 91 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 92 | super().__init__() 93 | 94 | self.bias = nn.Parameter(torch.zeros(channel)) 95 | self.negative_slope = negative_slope 96 | self.scale = scale 97 | 98 | 
def forward(self, input): 99 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 100 | 101 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 102 | -------------------------------------------------------------------------------- /models/better/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/better/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ?
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/better/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /models/better/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system("unset TORCH_CUDA_ARCH_LIST") 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | 12 | 13 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 14 | if input.device.type == "cpu": 15 | out = upfirdn2d_native( 16 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 17 | ) 18 | 19 | else: 20 | 21 | upfirdn2d_op = load( 22 | "upfirdn2d", 23 | sources=[ 24 | os.path.join(module_path, "upfirdn2d.cpp"), 25 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 26 | ], 27 | ) 28 | 29 | class UpFirDn2dBackward(Function): 30 | 31 | @staticmethod 32 | def forward( 33 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 34 | ): 35 | 36 | up_x, up_y = up 37 | down_x, down_y = down 38 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 39 | 40 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 41 | 42 | grad_input = upfirdn2d_op.upfirdn2d( 43 | grad_output, 44 | grad_kernel, 45 | down_x, 46 | down_y, 47 | up_x, 48 | up_y, 49 | g_pad_x0, 50 | g_pad_x1, 51 | g_pad_y0, 52 | g_pad_y1, 53 | ) 54 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 55 | 56 | ctx.save_for_backward(kernel) 57 | 58 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 59 | 60 | ctx.up_x = up_x 61 | ctx.up_y = up_y 62 | ctx.down_x = down_x 63 | ctx.down_y = down_y 64 | ctx.pad_x0 = pad_x0 65 |
ctx.pad_x1 = pad_x1 66 | ctx.pad_y0 = pad_y0 67 | ctx.pad_y1 = pad_y1 68 | ctx.in_size = in_size 69 | ctx.out_size = out_size 70 | 71 | return grad_input 72 | 73 | @staticmethod 74 | def backward(ctx, gradgrad_input): 75 | kernel, = ctx.saved_tensors 76 | 77 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 78 | 79 | gradgrad_out = upfirdn2d_op.upfirdn2d( 80 | gradgrad_input, 81 | kernel, 82 | ctx.up_x, 83 | ctx.up_y, 84 | ctx.down_x, 85 | ctx.down_y, 86 | ctx.pad_x0, 87 | ctx.pad_x1, 88 | ctx.pad_y0, 89 | ctx.pad_y1, 90 | ) 91 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 92 | gradgrad_out = gradgrad_out.view( 93 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 94 | ) 95 | 96 | return gradgrad_out, None, None, None, None, None, None, None, None 97 | 98 | 99 | class UpFirDn2d(Function): 100 | 101 | @staticmethod 102 | def forward(ctx, input, kernel, up, down, pad): 103 | up_x, up_y = up 104 | down_x, down_y = down 105 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 106 | 107 | kernel_h, kernel_w = kernel.shape 108 | batch, channel, in_h, in_w = input.shape 109 | ctx.in_size = input.shape 110 | 111 | input = input.reshape(-1, in_h, in_w, 1) 112 | 113 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 114 | 115 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 116 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 117 | ctx.out_size = (out_h, out_w) 118 | 119 | ctx.up = (up_x, up_y) 120 | ctx.down = (down_x, down_y) 121 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 122 | 123 | g_pad_x0 = kernel_w - pad_x0 - 1 124 | g_pad_y0 = kernel_h - pad_y0 - 1 125 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 126 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 127 | 128 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 129 | 130 | out = upfirdn2d_op.upfirdn2d( 131 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 132 | ) 133 | # out = out.view(major, out_h, out_w, minor) 134 | out = out.view(-1, channel, out_h, out_w) 135 | 136 | return out 137 | 138 | @staticmethod 139 | def backward(ctx, grad_output): 140 | kernel, grad_kernel = ctx.saved_tensors 141 | 142 | grad_input = UpFirDn2dBackward.apply( 143 | grad_output, 144 | kernel, 145 | grad_kernel, 146 | ctx.up, 147 | ctx.down, 148 | ctx.pad, 149 | ctx.g_pad, 150 | ctx.in_size, 151 | ctx.out_size, 152 | ) 153 | 154 | return grad_input, None, None, None, None 155 | 156 | out = UpFirDn2d.apply( 157 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 158 | ) 159 | 160 | return out 161 | 162 | 163 | def upfirdn2d_native( 164 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 165 | ): 166 | _, channel, in_h, in_w = input.shape 167 | input = input.reshape(-1, in_h, in_w, 1) 168 | 169 | _, in_h, in_w, minor = input.shape 170 | kernel_h, kernel_w = kernel.shape 171 | 172 | out = input.view(-1, in_h, 1, in_w, 1, minor) 173 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 174 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 175 | 176 | out = F.pad( 177 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 178 | ) 179 | out = out[ 180 | :, 181 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 182 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 183 | :, 184 | ] 185 | 186 | out = out.permute(0, 3, 1, 2) 187 | out = out.reshape( 188 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * 
up_x + pad_x0 + pad_x1] 189 | ) 190 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 191 | out = F.conv2d(out, w) 192 | out = out.reshape( 193 | -1, 194 | minor, 195 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 196 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 197 | ) 198 | out = out.permute(0, 2, 3, 1) 199 | out = out[:, ::down_y, ::down_x, :] 200 | 201 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 202 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 203 | 204 | return out.view(-1, channel, out_h, out_w) 205 | -------------------------------------------------------------------------------- /models/better/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | import sde_lib 21 | import numpy as np 22 | 23 | 24 | _MODELS = {} 25 | 26 | 27 | def register_model(cls=None, *, name=None): 28 | """A decorator for registering model classes.""" 29 | 30 | def _register(cls): 31 | if name is None: 32 | local_name = cls.__name__ 33 | else: 34 | local_name = name 35 | if local_name in _MODELS: 36 | raise ValueError(f'Already registered model with name: {local_name}') 37 | _MODELS[local_name] = cls 38 | return cls 39 | 40 | if cls is None: 41 | return _register 42 | else: 43 | return _register(cls) 44 | 45 | 46 | def get_model(name): 47 | return _MODELS[name] 48 | 49 | 50 | def get_sigmas(config): 51 | """Get sigmas --- the set of noise levels for SMLD from config files. 52 | Args: 53 | config: A ConfigDict object parsed from the config file 54 | Returns: 55 | sigmas: a jax numpy arrary of noise levels 56 | """ 57 | sigmas = np.exp( 58 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 59 | 60 | return sigmas 61 | 62 | 63 | def get_ddpm_params(config): 64 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 65 | num_diffusion_timesteps = 1000 66 | # parameters need to be adapted if number of time steps differs from 1000 67 | beta_start = config.model.beta_min / config.model.num_scales 68 | beta_end = config.model.beta_max / config.model.num_scales 69 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 70 | 71 | alphas = 1. - betas 72 | alphas_cumprod = np.cumprod(alphas, axis=0) 73 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 74 | sqrt_1m_alphas_cumprod = np.sqrt(1. 
- alphas_cumprod) 75 | 76 | return { 77 | 'betas': betas, 78 | 'alphas': alphas, 79 | 'alphas_cumprod': alphas_cumprod, 80 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 81 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 82 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 83 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 84 | 'num_diffusion_timesteps': num_diffusion_timesteps 85 | } 86 | 87 | 88 | def create_model(config): 89 | """Create the score model.""" 90 | model_name = config.model.name 91 | score_model = get_model(model_name)(config) 92 | score_model = score_model.to(config.device) 93 | score_model = torch.nn.DataParallel(score_model) 94 | return score_model 95 | 96 | 97 | def get_model_fn(model, train=False): 98 | """Create a function to give the output of the score-based model. 99 | 100 | Args: 101 | model: The score model. 102 | train: `True` for training and `False` for evaluation. 103 | 104 | Returns: 105 | A model function. 106 | """ 107 | 108 | def model_fn(x, labels): 109 | """Compute the output of the score-based model. 110 | 111 | Args: 112 | x: A mini-batch of input data. 113 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 114 | for different models. 115 | 116 | Returns: 117 | A tuple of (model output, new mutable states) 118 | """ 119 | if not train: 120 | model.eval() 121 | return model(x, labels) 122 | else: 123 | model.train() 124 | return model(x, labels) 125 | 126 | return model_fn 127 | 128 | 129 | def get_score_fn(sde, model, train=False, continuous=False): 130 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 131 | 132 | Args: 133 | sde: An `sde_lib.SDE` object that represents the forward SDE. 134 | model: A score model. 135 | train: `True` for training and `False` for evaluation. 136 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 137 | 138 | Returns: 139 | A score function. 140 | """ 141 | model_fn = get_model_fn(model, train=train) 142 | 143 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 144 | def score_fn(x, t): 145 | # Scale neural network output by standard deviation and flip sign 146 | if continuous or isinstance(sde, sde_lib.subVPSDE): 147 | # For VP-trained models, t=0 corresponds to the lowest noise level 148 | # The maximum value of time embedding is assumed to 999 for 149 | # continuously-trained models. 
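# The labels fed to the network are t * 999 so that continuous time values in
# (0, 1] land in the same [0, 999] range as the discrete timestep embedding;
# `std` below is the standard deviation of the perturbation kernel p_t(x | x_0)
# given by sde.marginal_prob, and the final `-score / std` converts the
# predicted noise into the score.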
150 | labels = t * 999 151 | score = model_fn(x, labels) 152 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 153 | else: 154 | # For VP-trained models, t=0 corresponds to the lowest noise level 155 | labels = t * (sde.N - 1) 156 | score = model_fn(x, labels) 157 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 158 | 159 | score = -score / std[:, None, None, None] 160 | return score 161 | 162 | elif isinstance(sde, sde_lib.VESDE): 163 | def score_fn(x, t): 164 | if continuous: 165 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 166 | else: 167 | # For VE-trained models, t=0 corresponds to the highest noise level 168 | labels = sde.T - t 169 | labels *= sde.N - 1 170 | labels = torch.round(labels).long() 171 | 172 | score = model_fn(x, labels) 173 | return score 174 | 175 | else: 176 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 177 | 178 | return score_fn 179 | 180 | 181 | def to_flattened_numpy(x): 182 | """Flatten a torch tensor `x` and convert it to numpy.""" 183 | return x.detach().cpu().numpy().reshape((-1,)) 184 | 185 | 186 | def from_flattened_numpy(x, shape): 187 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 188 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 22 | 23 | def ema(self, module): 24 | if isinstance(module, nn.DataParallel): 25 | module = module.module 26 | for name, param in module.named_parameters(): 27 | if param.requires_grad: 28 | param.data.copy_(self.shadow[name].data) 29 | 30 | def ema_copy(self, module): 31 | if isinstance(module, nn.DataParallel): 32 | inner_module = module.module 33 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) 34 | module_copy.load_state_dict(inner_module.state_dict()) 35 | module_copy = nn.DataParallel(module_copy) 36 | else: 37 | module_copy = type(module)(module.config).to(module.config.device) 38 | module_copy.load_state_dict(module.state_dict()) 39 | # module_copy = copy.deepcopy(module) 40 | self.ema(module_copy) 41 | return module_copy 42 | 43 | def state_dict(self): 44 | return self.shadow 45 | 46 | def load_state_dict(self, state_dict): 47 | self.shadow = state_dict 48 | 49 | 50 | # import glob, torch, tqdm 51 | # ckpt_files = sorted(glob.glob("*.pt")) 52 | # for file in tqdm.tqdm(ckpt_files): 53 | # a = torch.load(file) 54 | # a[0]['module.unet.all_modules.52.Norm_0.weight'] = a[0].pop('module.unet.all_modules.52.weight') 55 | # a[0]['module.unet.all_modules.52.Norm_0.bias'] = a[0].pop('module.unet.all_modules.52.bias') 56 | # a[-1]['unet.all_modules.52.Norm_0.weight'] = a[-1].pop('unet.all_modules.52.weight') 57 | # a[-1]['unet.all_modules.52.Norm_0.bias'] = a[-1].pop('unet.all_modules.52.bias') 58 | # torch.save(a, file) 59 | -------------------------------------------------------------------------------- /models/eval_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import models.dist_model as dist_model 4 | import numpy as np 5 | import models.dist_model as dist_model 6 | 7 | # Taken from https://github.com/psh01087/Vid-ODE/blob/main/eval_models/__init__.py 8 | class PerceptualLoss(torch.nn.Module): 9 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, device='cpu'): # VGG using our perceptually-learned weights (LPIPS metric) 10 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 11 | super(PerceptualLoss, self).__init__() 12 | print('Setting up Perceptual loss...') 13 | self.device = device 14 | self.spatial = spatial 15 | self.model = dist_model.DistModel() 16 | self.model.initialize(model=model, net=net, colorspace=colorspace, spatial=self.spatial, device=device) 17 | print('...[%s] initialized'%self.model.name()) 18 | print('...Done') 19 | 20 | def forward(self, pred, target, normalize=False): 21 | """ 22 | Pred and target are Variables. 
23 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 24 | If normalize is False, assumes the images are already between [-1,+1] 25 | Inputs pred and target are Nx3xHxW 26 | Output pytorch Variable N long 27 | """ 28 | 29 | if normalize: 30 | target = 2 * target - 1 31 | pred = 2 * pred - 1 32 | 33 | return self.model.forward(target, pred) 34 | 35 | def normalize_tensor(in_feat,eps=1e-10): 36 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 37 | return in_feat/(norm_factor+eps) 38 | 39 | def l2(p0, p1, range=255.): 40 | return .5*np.mean((p0 / range - p1 / range)**2) 41 | 42 | def psnr(p0, p1, peak=255.): 43 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 44 | 45 | #def dssim(p0, p1, range=255.): 46 | # return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 47 | 48 | def rgb2lab(in_img,mean_cent=False): 49 | from skimage import color 50 | img_lab = color.rgb2lab(in_img) 51 | if(mean_cent): 52 | img_lab[:,:,0] = img_lab[:,:,0]-50 53 | return img_lab 54 | 55 | def tensor2np(tensor_obj): 56 | # change dimension of a tensor object into a numpy array 57 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 58 | 59 | def np2tensor(np_obj): 60 | # change dimenion of np array into tensor array 61 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 62 | 63 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 64 | # image tensor to lab tensor 65 | from skimage import color 66 | 67 | img = tensor2im(image_tensor) 68 | img_lab = color.rgb2lab(img) 69 | if(mc_only): 70 | img_lab[:,:,0] = img_lab[:,:,0]-50 71 | if(to_norm and not mc_only): 72 | img_lab[:,:,0] = img_lab[:,:,0]-50 73 | img_lab = img_lab/100. 74 | 75 | return np2tensor(img_lab) 76 | 77 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 78 | from skimage import color 79 | import warnings 80 | warnings.filterwarnings("ignore") 81 | 82 | lab = tensor2np(lab_tensor)*100. 83 | lab[:,:,0] = lab[:,:,0]+50 84 | 85 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 86 | if(return_inbnd): 87 | # convert back to lab, see if we match 88 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 89 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 90 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 91 | return (im2tensor(rgb_back),mask) 92 | else: 93 | return im2tensor(rgb_back) 94 | 95 | def rgb2lab(input): 96 | from skimage import color 97 | return color.rgb2lab(input / 255.) 98 | 99 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 100 | image_numpy = image_tensor[0].cpu().float().numpy() 101 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 102 | return image_numpy.astype(imtype) 103 | 104 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 105 | return torch.Tensor((image / factor - cent) 106 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 107 | 108 | def tensor2vec(vector_tensor): 109 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 110 | 111 | def voc_ap(rec, prec, use_07_metric=False): 112 | """ ap = voc_ap(rec, prec, [use_07_metric]) 113 | Compute VOC AP given precision and recall. 114 | If use_07_metric is true, uses the 115 | VOC 07 11 point method (default:False). 116 | """ 117 | if use_07_metric: 118 | # 11 point metric 119 | ap = 0. 120 | for t in np.arange(0., 1.1, 0.1): 121 | if np.sum(rec >= t) == 0: 122 | p = 0 123 | else: 124 | p = np.max(prec[rec >= t]) 125 | ap = ap + p / 11. 
126 | else: 127 | # correct AP calculation 128 | # first append sentinel values at the end 129 | mrec = np.concatenate(([0.], rec, [1.])) 130 | mpre = np.concatenate(([0.], prec, [0.])) 131 | 132 | # compute the precision envelope 133 | for i in range(mpre.size - 1, 0, -1): 134 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 135 | 136 | # to calculate area under PR curve, look for points 137 | # where X axis (recall) changes value 138 | i = np.where(mrec[1:] != mrec[:-1])[0] 139 | 140 | # and sum (\Delta recall) * prec 141 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 142 | return ap 143 | 144 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 145 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 146 | image_numpy = image_tensor[0].cpu().float().numpy() 147 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 148 | return image_numpy.astype(imtype) 149 | 150 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 151 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 152 | return torch.Tensor((image / factor - cent) 153 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) -------------------------------------------------------------------------------- /models/fvd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/fvd/__init__.py -------------------------------------------------------------------------------- /models/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from skimage import color 11 | from . import pretrained_networks as pn 12 | 13 | from . 
import eval_models as util 14 | 15 | def spatial_average(in_tens, keepdim=True): 16 | return in_tens.mean([2,3],keepdim=keepdim) 17 | 18 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 19 | in_H = in_tens.shape[2] 20 | scale_factor = 1.*out_H/in_H 21 | 22 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 23 | 24 | # Learned perceptual metric 25 | class PNetLin(nn.Module): 26 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 27 | super(PNetLin, self).__init__() 28 | 29 | self.pnet_type = pnet_type 30 | self.pnet_tune = pnet_tune 31 | self.pnet_rand = pnet_rand 32 | self.spatial = spatial 33 | self.lpips = lpips 34 | self.version = version 35 | self.scaling_layer = ScalingLayer() 36 | 37 | if(self.pnet_type in ['vgg','vgg16']): 38 | net_type = pn.vgg16 39 | self.chns = [64,128,256,512,512] 40 | elif(self.pnet_type=='alex'): 41 | net_type = pn.alexnet 42 | self.chns = [64,192,384,256,256] 43 | elif(self.pnet_type=='squeeze'): 44 | net_type = pn.squeezenet 45 | self.chns = [64,128,256,384,384,512,512] 46 | self.L = len(self.chns) 47 | 48 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 49 | 50 | if(lpips): 51 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 52 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 53 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 54 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 55 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 56 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 57 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 58 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 59 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 60 | self.lins+=[self.lin5,self.lin6] 61 | 62 | def forward(self, in0, in1, retPerLayer=False): 63 | # v0.0 - original release had a bug, where input was not scaled 64 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 65 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 66 | feats0, feats1, diffs = {}, {}, {} 67 | 68 | for kk in range(self.L): 69 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 70 | diffs[kk] = (feats0[kk]-feats1[kk])**2 71 | if(self.lpips): 72 | if(self.spatial): 73 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 74 | else: 75 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 76 | else: 77 | if(self.spatial): 78 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 79 | else: 80 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 81 | val = res[0] 82 | for l in range(1,self.L): 83 | val += res[l] 84 | 85 | if(retPerLayer): 86 | return (val, res) 87 | else: 88 | return val 89 | 90 | class ScalingLayer(nn.Module): 91 | def __init__(self): 92 | super(ScalingLayer, self).__init__() 93 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 94 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 95 | 96 | def forward(self, inp): 97 | return (inp - self.shift) / self.scale 98 | 99 | 100 | class NetLinLayer(nn.Module): 101 | ''' 
A single linear layer which does a 1x1 conv ''' 102 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 103 | super(NetLinLayer, self).__init__() 104 | 105 | layers = [nn.Dropout(),] if(use_dropout) else [] 106 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 107 | self.model = nn.Sequential(*layers) 108 | 109 | 110 | class Dist2LogitLayer(nn.Module): 111 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 112 | def __init__(self, chn_mid=32, use_sigmoid=True): 113 | super(Dist2LogitLayer, self).__init__() 114 | 115 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 116 | layers += [nn.LeakyReLU(0.2,True),] 117 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 118 | layers += [nn.LeakyReLU(0.2,True),] 119 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 120 | if(use_sigmoid): 121 | layers += [nn.Sigmoid(),] 122 | self.model = nn.Sequential(*layers) 123 | 124 | def forward(self,d0,d1,eps=0.1): 125 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 126 | 127 | class BCERankingLoss(nn.Module): 128 | def __init__(self, chn_mid=32): 129 | super(BCERankingLoss, self).__init__() 130 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 131 | # self.parameters = list(self.net.parameters()) 132 | self.loss = torch.nn.BCELoss() 133 | 134 | def forward(self, d0, d1, judge): 135 | per = (judge+1.)/2. 136 | self.logit = self.net.forward(d0,d1) 137 | return self.loss(self.logit, per) 138 | 139 | # L2, DSSIM metrics 140 | class FakeNet(nn.Module): 141 | def __init__(self, device='cpu', colorspace='Lab'): 142 | super(FakeNet, self).__init__() 143 | self.device = device 144 | self.colorspace=colorspace 145 | 146 | class L2(FakeNet): 147 | 148 | def forward(self, in0, in1, retPerLayer=None): 149 | assert(in0.size()[0]==1) # currently only supports batchSize 1 150 | 151 | if(self.colorspace=='RGB'): 152 | (N,C,X,Y) = in0.size() 153 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 154 | return value 155 | elif(self.colorspace=='Lab'): 156 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 157 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 158 | ret_var = Variable( torch.Tensor((value,) ) ).to(self.device) 159 | return ret_var 160 | 161 | class DSSIM(FakeNet): 162 | 163 | def forward(self, in0, in1, retPerLayer=None): 164 | assert(in0.size()[0]==1) # currently only supports batchSize 1 165 | 166 | if(self.colorspace=='RGB'): 167 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 168 | elif(self.colorspace=='Lab'): 169 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 170 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 171 | ret_var = Variable( torch.Tensor((value,) ) ).to(self.device) 172 | return ret_var 173 | 174 | def print_network(net): 175 | num_params = 0 176 | for param in net.parameters(): 177 | num_params += param.numel() 178 | print('Network',net) 179 | print('Total number of parameters: %d' % num_params) 180 | -------------------------------------------------------------------------------- /models/pndm.py: -------------------------------------------------------------------------------- 1 | ## Modified from 
https://github.com/luping-liu/PNDM/blob/f285e8e6da36049ea29e97b741fb71e531505ec8/runner/method.py#L20 2 | 3 | def runge_kutta(x, t_list, model, alphas_cump, ets, clip_before=False): 4 | e_1 = model(x, t_list[0]) 5 | ets.append(e_1) 6 | x_2 = transfer(x, t_list[0], t_list[1], e_1, alphas_cump, clip_before) 7 | 8 | e_2 = model(x_2, t_list[1]) 9 | x_3 = transfer(x, t_list[0], t_list[1], e_2, alphas_cump, clip_before) 10 | 11 | e_3 = model(x_3, t_list[1]) 12 | x_4 = transfer(x, t_list[0], t_list[2], e_3, alphas_cump, clip_before) 13 | 14 | e_4 = model(x_4, t_list[2]) 15 | et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4) 16 | 17 | return et, ets 18 | 19 | def transfer(x, t, t_next, et, alphas_cump, clip_before=False): 20 | at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1) 21 | at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1) 22 | 23 | # x0 = (1 / c_alpha.sqrt()) * (x_mod - (1 - c_alpha).sqrt() * grad) 24 | # x_mod = c_alpha_prev.sqrt() * x0 + (1 - c_alpha_prev).sqrt() * grad 25 | 26 | x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - \ 27 | 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) 28 | 29 | x_next = x + x_delta 30 | if clip_before: 31 | x_next = x_next.clip_(-1, 1) 32 | 33 | return x_next 34 | 35 | def gen_order_1(img, t, t_next, model, alphas_cump, ets, clip_before=False): ## DDIM 36 | noise = model(img, t) 37 | ets.append(noise) 38 | img_next = transfer(img, t, t_next, noise, alphas_cump, clip_before) 39 | return img_next, ets 40 | 41 | def gen_order_4(img, t, t_next, model, alphas_cump, ets, clip_before=False): ## F-PNDM 42 | t_list = [t, (t+t_next)/2, t_next] 43 | #print(t_list) 44 | if len(ets) > 2: 45 | noise_ = model(img, t) 46 | ets.append(noise_) 47 | noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) 48 | else: 49 | noise, ets = runge_kutta(img, t_list, model, alphas_cump, ets, clip_before) 50 | 51 | img_next = transfer(img, t, t_next, noise, alphas_cump, clip_before) 52 | return img_next, ets 53 | -------------------------------------------------------------------------------- /models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | 
param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def 
__init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /models/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /models/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /models/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/voletiv/mcvd-pytorch/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /quick_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import torch 5 | import yaml 6 | 7 | from collections import OrderedDict 8 | from imageio import mimwrite 9 | from 
torch.utils.data import DataLoader 10 | from torchvision.utils import make_grid, save_image 11 | 12 | try: 13 | from torchvision.transforms.functional import resize, InterpolationMode 14 | interp = InterpolationMode.NEAREST 15 | except: 16 | from torchvision.transforms.functional import resize 17 | interp = 0 18 | 19 | from datasets import get_dataset, data_transform, inverse_data_transform 20 | from main import dict2namespace 21 | from models import get_sigmas, anneal_Langevin_dynamics 22 | from models.ema import EMAHelper 23 | from runners.ncsn_runner import get_model, conditioning_fn 24 | 25 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 26 | # device = torch.device('cpu') 27 | 28 | from models import ddpm_sampler 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 33 | parser.add_argument('--ckpt_path', type=str, required=True, help='Path to checkpoint.pt') 34 | parser.add_argument('--data_path', type=str, help='Path to the dataset') 35 | parser.add_argument('--save_path', type=str, help='Path to the dataset') 36 | args = parser.parse_args() 37 | return args.ckpt_path, args.data_path, args.save_path 38 | 39 | 40 | # Make and load model 41 | def load_model(ckpt_path, device): 42 | # Parse config file 43 | with open(os.path.join(os.path.dirname(ckpt_path), 'config.yml'), 'r') as f: 44 | config = yaml.load(f, Loader=yaml.FullLoader) 45 | # Load config file 46 | config = dict2namespace(config) 47 | config.device = device 48 | # Load model 49 | scorenet = get_model(config) 50 | if config.device != torch.device('cpu'): 51 | scorenet = torch.nn.DataParallel(scorenet) 52 | states = torch.load(ckpt_path, map_location=config.device) 53 | else: 54 | states = torch.load(ckpt_path, map_location='cpu') 55 | states[0] = OrderedDict([(k.replace('module.', ''), v) for k, v in states[0].items()]) 56 | scorenet.load_state_dict(states[0], strict=False) 57 | if config.model.ema: 58 | ema_helper = EMAHelper(mu=config.model.ema_rate) 59 | ema_helper.register(scorenet) 60 | ema_helper.load_state_dict(states[-1]) 61 | ema_helper.ema(scorenet) 62 | scorenet.eval() 63 | return scorenet, config 64 | 65 | 66 | if __name__ == '__main__': 67 | # data_path = '/path/to/data/CIFAR10' 68 | ckpt_path, data_path, save_path = parse_args() 69 | 70 | scorenet, config = load_model(ckpt_path, device) 71 | 72 | # Initial samples 73 | dataset, test_dataset = get_dataset(data_path, config) 74 | dataloader = DataLoader(dataset, batch_size=config.training.batch_size, shuffle=True, 75 | num_workers=config.data.num_workers) 76 | train_iter = iter(dataloader) 77 | x, y = next(train_iter) 78 | test_loader = DataLoader(test_dataset, batch_size=config.training.batch_size, shuffle=False, 79 | num_workers=config.data.num_workers, drop_last=True) 80 | test_iter = iter(test_loader) 81 | test_x, test_y = next(test_iter) 82 | 83 | net = scorenet.module if hasattr(scorenet, 'module') else scorenet 84 | version = getattr(net, 'version', 'SMLD').upper() 85 | net_type = getattr(net, 'type') if isinstance(getattr(net, 'type'), str) else 'v1' 86 | 87 | if version == "SMLD": 88 | sigmas = net.sigmas 89 | labels = torch.randint(0, len(sigmas), (x.shape[0],), device=x.device) 90 | used_sigmas = sigmas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 91 | device = sigmas.device 92 | 93 | elif version == "DDPM" or version == "DDIM": 94 | alphas = net.alphas 95 | labels = torch.randint(0, len(alphas), (x.shape[0],), device=x.device) 96 | used_alphas = 
alphas[labels].reshape(x.shape[0], *([1] * len(x.shape[1:]))) 97 | device = alphas.device 98 | 99 | for batch, (X, y) in enumerate(dataloader): 100 | break 101 | 102 | X = X.to(config.device) 103 | X = data_transform(config, X) 104 | 105 | conditional = config.data.num_frames_cond > 0 106 | cond = None 107 | if conditional: 108 | X, cond = conditioning_fn(config, X) 109 | 110 | init_samples = torch.randn(len(X), config.data.channels*config.data.num_frames, 111 | config.data.image_size, config.data.image_size, 112 | device=config.device) 113 | 114 | all_samples = ddpm_sampler(init_samples, scorenet, cond=cond[:len(init_samples)], 115 | n_steps_each=config.sampling.n_steps_each, 116 | step_lr=config.sampling.step_lr, just_beta=False, 117 | final_only=True, denoise=config.sampling.denoise, 118 | subsample_steps=getattr(config.sampling, 'subsample', None), 119 | verbose=True) 120 | 121 | sample = all_samples[-1].reshape(all_samples[-1].shape[0], config.data.channels, 122 | config.data.image_size, config.data.image_size) 123 | 124 | sample = inverse_data_transform(config, sample) 125 | 126 | image_grid = make_grid(sample, np.sqrt(config.training.batch_size)) 127 | step = 0 128 | save_image(image_grid, 129 | os.path.join(save_path, 'image_grid_{}.png'.format(step))) 130 | torch.save(sample, os.path.join(save_path, 'samples_{}.pt'.format(step))) 131 | 132 | # CUDA_VISIBLE_DEVICES=3 python -i load_model_from_ckpt.py --ckpt_path /path/to/ncsnv2/cifar10/BASELINE_DDPM_800k/logs/checkpoint.pt 133 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 2 | numpy 3 | PyYAML 4 | imageio 5 | imageio-ffmpeg 6 | matplotlib 7 | opencv-python 8 | scikit-image 9 | tqdm 10 | h5py 11 | progressbar 12 | psutil 13 | ninja 14 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from runners.ncsn_runner import * 2 | --------------------------------------------------------------------------------
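A minimal usage sketch for the EMAHelper defined in models/ema.py above (this block is not a file from the repository): register the model once, call update() after every optimizer step, and apply the shadow weights before evaluation. The tiny model, optimizer and data are hypothetical placeholders, and the import assumes the repository root is on PYTHONPATH.

import torch
import torch.nn as nn

from models.ema import EMAHelper

# Placeholder model/optimizer; only the EMAHelper calls mirror the repo's API.
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

ema_helper = EMAHelper(mu=0.999)
ema_helper.register(model)            # snapshot initial parameters

for step in range(100):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_helper.update(model)          # shadow = (1 - mu) * param + mu * shadow

# ema_copy() rebuilds the module as type(module)(module.config), which this
# plain nn.Linear does not support, so the shadow weights are applied in place
# before evaluation instead:
ema_helper.ema(model)
model.eval()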
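Likewise, a hedged sketch of calling the PerceptualLoss wrapper from models/eval_models.py for LPIPS-style distances; the random tensors are placeholders, the inputs are assumed to lie in [0, 1] (hence normalize=True), and loading net='alex' relies on the bundled weights under models/weights/.

import torch

from models.eval_models import PerceptualLoss

lpips_fn = PerceptualLoss(model='net-lin', net='alex', colorspace='rgb',
                          spatial=False, device='cpu')

pred = torch.rand(4, 3, 64, 64)      # N x 3 x H x W, values in [0, 1]
target = torch.rand(4, 3, 64, 64)

# normalize=True rescales [0, 1] inputs to [-1, +1] before the LPIPS forward pass.
dist = lpips_fn(pred, target, normalize=True)  # one distance per sample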