├── src ├── models │ ├── stage1 │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── utils.py │ │ │ └── quantizer.py │ │ ├── nerf │ │ │ ├── nerf_utils.py │ │ │ └── helper.py │ │ ├── discriminator.py │ │ └── generator.py │ ├── stage2 │ │ ├── __init__.py │ │ ├── diffae │ │ │ ├── __init__.py │ │ │ └── nn.py │ │ ├── diffusion │ │ │ ├── losses.py │ │ │ ├── respace.py │ │ │ ├── nn.py │ │ │ └── resample.py │ │ └── ema.py │ └── __init__.py ├── optimizers │ ├── __init__.py │ └── scheduler.py ├── utils │ ├── structural_losses │ │ ├── tf_nndistance_so.so │ │ ├── tf_approxmatch_g.cu.o │ │ ├── tf_approxmatch_so.so │ │ ├── tf_nndistance_g.cu.o │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── tf_nndistance.cpython-38.pyc │ │ ├── __init__.py │ │ ├── tf_nndistance_compile.sh │ │ ├── tf_approxmatch_compile.sh │ │ ├── makefile │ │ ├── tf_nndistance.py │ │ ├── tf_approxmatch.py │ │ ├── tf_nndistance_g.cu │ │ ├── approxmatch.cu │ │ ├── approxmatch.cpp │ │ └── tf_approxmatch_g.cu │ ├── config2.py │ ├── prdc.py │ ├── config1.py │ ├── utils.py │ ├── logger.py │ ├── metric_voxel.py │ └── fid_utils.py └── datasets │ └── nerf_dataset.py ├── scripts ├── test_stage2_large_CelebAHQ.sh ├── test_stage2_small_CelebAHQ.sh ├── test_stage2_small_SRNCars.sh ├── test_stage2_large_ShapeNet.sh ├── test_stage2_small_ShapeNet.sh ├── train_stage2_large_CelebAHQ.sh ├── train_stage2_small_CelebAHQ.sh ├── train_stage1_large_CelebAHQ.sh ├── train_stage1_small_CelebAHQ.sh ├── train_stage2_large_ShapeNet.sh ├── train_stage2_small_ShapeNet.sh ├── train_stage2_small_SRNCars.sh ├── train_stage1_large_ShapeNet.sh ├── train_stage1_small_ShapeNet.sh ├── test_stage1_small_SRNCars.sh ├── test_stage1_large_CelebAHQ.sh ├── test_stage1_small_CelebAHQ.sh ├── test_stage1_large_ShapeNet.sh ├── test_stage1_small_ShapeNet.sh └── train_stage1_small_SRNCars.sh ├── requirements.txt ├── configs ├── CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K256L4W64H1024_W0_30_epoch800.yaml ├── CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K384L5W128H512_W0_30_epoch800.yaml ├── ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K256L4W64H512-W0-50-epochs800.yaml ├── ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K512L5W128H1024-W0-30-epochs400.yaml └── SRNCars128px-ResidualMLP-L8-D4096-scale1-epoch1000-NV32-NP512-LatentMixtureINR-K256L4W64H128-nouse+elu-epoch1000+1000.yaml ├── README.md └── main_stage2.py /src/models/stage1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/stage2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/stage2/diffae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduler import build_scheduler 2 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_nndistance_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/tf_nndistance_so.so 
-------------------------------------------------------------------------------- /src/utils/structural_losses/tf_approxmatch_g.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/tf_approxmatch_g.cu.o -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_approxmatch_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/tf_approxmatch_so.so -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_nndistance_g.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/tf_nndistance_g.cu.o -------------------------------------------------------------------------------- /src/utils/structural_losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/structural_losses/__pycache__/tf_nndistance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tackgeun/mNIF/HEAD/src/utils/structural_losses/__pycache__/tf_nndistance.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/structural_losses/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .tf_nndistance import nn_distance 3 | #from tf_approxmatch import approx_match, match_cost 4 | except: 5 | print('External Losses (Chamfer-EMD) were not loaded.') 6 | -------------------------------------------------------------------------------- /scripts/test_stage2_large_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python eval_stage2.py -r=experiments/CelebAHQ-large/ResidualMLP-L8-D4096-scale1-epoch1000 --stage2_epoch=999 --stage1_path=experiments/CelebAHQ-large/LatentMixtureINR-K384L5W128H512-lr1e-4+1.0-lrschedule-batch32-epoch800 --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/test_stage2_small_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python eval_stage2.py -r=experiments/CelebAHQ-small/ResidualMLP-L8-D4096-scale1-epoch1000/ --stage2_epoch=999 --stage1_path=experiments/CelebAHQ-small/LatentMixtureINR-K256L4W64H1024-lr1e-4+1.0-lrschedule-batch32-epoch800/ --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/test_stage2_small_SRNCars.sh: -------------------------------------------------------------------------------- 1 | python eval_stage2.py -r=experiments/SRNCars-small/ResidualMLP-L8-D4096-scale1-epoch1000 --stage2_epoch=999 --stage1_path=experiments/SRNCars-small/NV32-NP512-LatentMixtureINR-K256L4W64H128-lr1e-4+1e-4-wd1.0-batch8-epoch1000+1000 --stage1_epoch=999 2 | -------------------------------------------------------------------------------- /scripts/test_stage2_large_ShapeNet.sh: 
-------------------------------------------------------------------------------- 1 | python eval_stage2.py -r=experiments/ShapeNet-large/ResidualMLP-L8-D4096-scale1-epoch1000 --stage2_epoch=999 --stage1_path=experiments/ShapeNet-large/LatentMixtureINR-K512L5W128H1024-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch400 --stage1_epoch=399 2 | -------------------------------------------------------------------------------- /scripts/test_stage2_small_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python eval_stage2.py -r=experiments/ShapeNet-small/ResidualMLP-L8-D4096-scale1-epochs1000 --stage2_epoch=999 --stage1_path=experiments/ShapeNet-small/LatentMixtureINR-K256L4W64H512-W0-50-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch800/ --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/train_stage2_large_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage2.py -r=results/CelebAHQ-large -c=configs/CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K384L5W128H512_W0_30_epoch800.yaml --checkpoint_path=experiments/CelebAHQ-large/LatentMixtureINR-K384L5W128H512-lr1e-4+1.0-lrschedule-batch32-epoch800 --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/train_stage2_small_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage2.py -r=results/CelebAHQ-small -c=configs/CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K256L4W64H1024_W0_30_epoch800.yaml --checkpoint_path=experiments/CelebAHQ-small/LatentMixtureINR-K256L4W64H1024-lr1e-4+1.0-lrschedule-batch32-epoch800/ --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/train_stage1_large_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -r=experiments/mNIF-stage1-M384L5W128H512 --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=384 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --save_freq=100 --num_epochs=800 --use_lr_scheduler 2 | 3 | -------------------------------------------------------------------------------- /scripts/train_stage1_small_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -r=experiments/mNIF-stage1-M256L4W64H1024 --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=1024 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --save_freq=100 --num_epochs=800 --use_lr_scheduler 2 | 3 | -------------------------------------------------------------------------------- /scripts/train_stage2_large_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python main_stage2.py -r=results/ShapeNet-large -c=configs/ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K512L5W128H1024-W0-30-epochs400.yaml --checkpoint_path=experiments/ShapeNet-large/LatentMixtureINR-K512L5W128H1024-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch400 --stage1_epoch=399 2 | -------------------------------------------------------------------------------- /scripts/train_stage2_small_ShapeNet.sh:
-------------------------------------------------------------------------------- 1 | python main_stage2.py -r=results/ShapeNet-small -c=configs/ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K256L4W64H512-W0-50-epochs800.yaml --checkpoint_path=experiments/ShapeNet-small/LatentMixtureINR-K256L4W64H512-W0-50-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch800/ --stage1_epoch=799 2 | -------------------------------------------------------------------------------- /scripts/train_stage2_small_SRNCars.sh: -------------------------------------------------------------------------------- 1 | python main_stage2.py -r=results/SRNCars-small/ -c=configs/SRNCars128px-ResidualMLP-L8-D4096-scale1-epoch1000-NV32-NP512-LatentMixtureINR-K256L4W64H128-nouse+elu-epoch1000+1000.yaml --checkpoint_path=experiments/SRNCars-small/NV32-NP512-LatentMixtureINR-K256L4W64H128-lr1e-4+1e-4-wd1.0-batch8-epoch1000+1000 --stage1_epoch=999 2 | -------------------------------------------------------------------------------- /scripts/train_stage1_large_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -r=results/ShapeNet-large/mNIF-stage1-M512L5W128H1024 --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=1024 --k_mixtures=512 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --save_freq=100 --num_epochs=400 --use_lr_schedule --clip_grad 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/train_stage1_small_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -r=results/ShapeNet-small/mNIF-stage1-M256L4W64H512 --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --w0=50 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --save_freq=100 --num_epochs=800 --use_lr_schedule --clip_grad 2 | 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.1 2 | numpy==1.21.2 3 | omegaconf==2.1.1 4 | pickle5==0.0.11 5 | Pillow==9.2.0 6 | pytorch_lightning==1.5.7 7 | regex>2022.1.18 8 | scipy==1.7.3 9 | torch==1.10.0 10 | git+https://github.com/tackgeun/pytorch-meta.git 11 | torchmetrics 12 | torchvision 13 | torch-fidelity==0.3.0 14 | tqdm==4.61.2 15 | configargparse 16 | setuptools==59.5.0 17 | six 18 | easydict 19 | imageio 20 | pymcubes 21 | -------------------------------------------------------------------------------- /scripts/test_stage1_small_SRNCars.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_autodecoding.py -m=experiments/SRNCars-small/NV32-NP512-LatentMixtureINR-K256L4W64H128-lr1e-4+1e-4-wd1.0-batch8-epoch1000+1000/metainits/epoch999.pth --model_type=latent-layerwise-mixtureinr --w0=30 --width=64 --depth=4 --k_mixtures=256 --hidden_features=128 --lr_outer=1e-4 --lr_inner=1e-4 --weight_decay_inner=1.0 --dataset=srn_cars --resolution=128 --eval --batch_size=1 --subsampled_views=32 --subsampled_pixels=2048 2 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_nndistance_compile.sh: -------------------------------------------------------------------------------- 1 | 
/usr/local/bin/nvcc -std=c++17 -c -o tf_nndistance_g.cu.o tf_nndistance_g.cu -I /orions4-zfs/projects/optas/Virt_Env/tf_1.3/lib/python2.7/site-packages/tensorflow/include -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -O2 && g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I /orions4-zfs/projects/optas/Virt_Env/tf_1.3/lib/python2.7/site-packages/tensorflow/include -L /usr/local/cuda/lib64 -O2 2 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_approxmatch_compile.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | if [ 'tf_approxmatch_g.cu.o' -ot 'tf_approxmatch_g.cu' ] ; then 3 | echo 'nvcc' 4 | /usr/local/cuda-8.0/bin/nvcc tf_approxmatch_g.cu -o tf_approxmatch_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 5 | fi 6 | if [ 'tf_approxmatch_so.so' -ot 'tf_approxmatch.cpp' ] || [ 'tf_approxmatch_so.so' -ot 'tf_approxmatch_g.cu.o' ] ; then 7 | echo 'g++' 8 | g++ -std=c++11 tf_approxmatch.cpp tf_approxmatch_g.cu.o -o tf_approxmatch_so.so -shared -fPIC -I /orions4-zfs/projects/optas/Virt_Env/tf_1.3/lib/python2.7/site-packages/tensorflow/include -I /usr/local/cuda-8.0/include -L /usr/local/cuda-8.0/lib64/ -O2 9 | fi 10 | 11 | -------------------------------------------------------------------------------- /scripts/test_stage1_large_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -m=experiments/CelebAHQ-large/LatentMixtureINR-K384L5W128H512-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=384 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --eval 2 | python main_stage1_cavia.py -m=experiments/CelebAHQ-large/LatentMixtureINR-K384L5W128H512-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=384 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --eval --split=test 3 | 4 | -------------------------------------------------------------------------------- /scripts/test_stage1_small_CelebAHQ.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -m=experiments/CelebAHQ-small/LatentMixtureINR-K256L4W64H1024-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=1024 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --eval 2 | python main_stage1_cavia.py -m=experiments/CelebAHQ-small/LatentMixtureINR-K256L4W64H1024-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=1024 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --eval --split=test 3 | 4 | -------------------------------------------------------------------------------- /scripts/test_stage1_large_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -m=experiments/ShapeNet-large/LatentMixtureINR-K512L5W128H1024-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch400/metainits/epoch399.pth --model_type=latent0.0001-mixtureinr-layerwise 
--hidden_features=1024 --k_mixtures=512 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --eval 2 | python main_stage1_cavia.py -m=experiments/ShapeNet-large/LatentMixtureINR-K512L5W128H1024-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch400/metainits/epoch399.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=1024 --k_mixtures=512 --use_meta_sgd --width=128 --depth=5 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --eval --split=test 3 | 4 | -------------------------------------------------------------------------------- /scripts/test_stage1_small_ShapeNet.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_cavia.py -m=experiments/ShapeNet-small/LatentMixtureINR-K256L4W64H512-W0-50-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --w0=50 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --eval 2 | python main_stage1_cavia.py -m=experiments/ShapeNet-small/LatentMixtureINR-K256L4W64H512-W0-50-subsampling4096-lr1e-4+1.0-lrschedule-batch32-epoch800/metainits/epoch799.pth --model_type=latent0.0001-mixtureinr-layerwise --hidden_features=512 --k_mixtures=256 --use_meta_sgd --width=64 --depth=4 --w0=50 --num_inner=3 --lr_outer=1e-4 --lr_inner=1.0 --batch_size=32 --dataset=shapenet --eval --split=test 3 | 4 | -------------------------------------------------------------------------------- /scripts/train_stage1_small_SRNCars.sh: -------------------------------------------------------------------------------- 1 | python main_stage1_autoencoding.py -r=experiments/NV32-NP512-128px-LatentMixtureINR-nouse+elu-K256L4W64H128-lr1e-4+1e-4-wd1.0-batch8-epoch1000/ --model_type=latent-layerwise-mixtureinr --w0=30 --width=64 --depth=4 --k_mixtures=256 --hidden_features=128 --lr_outer=1e-4 --lr_inner=1e-4 --weight_decay_inner=1.0 --dataset=srn_cars --resolution=128 --num_epochs=1000 --save_freq=500 --use_lr_schedule --batch_size=8 --subsampled_views=32 --subsampled_pixels=512 2 | python main_stage1_autoencoding.py -r=experiments/NV32-NP512-128px-LatentMixtureINR-nouse+elu-K256L4W64H128-lr1e-4+1e-4-wd1.0-batch8-epoch1000+1000/ -p=experiments/NV32-NP512-128px-LatentMixtureINR-nouse+elu-K256L4W64H128-W0-30-lr1e-4+1e-4-wd1.0-batch8-epoch1000/metainits/epoch999.pth --model_type=latent-layerwise-mixtureinr --w0=30 --width=64 --depth=4 --k_mixtures=256 --hidden_features=128 --lr_outer=1e-4 --lr_inner=1e-4 --weight_decay_inner=1.0 --dataset=srn_cars --resolution=128 --num_epochs=1000 --save_freq=100 --use_lr_schedule --batch_size=8 --subsampled_views=32 --subsampled_pixels=512 3 | -------------------------------------------------------------------------------- /src/utils/structural_losses/makefile: -------------------------------------------------------------------------------- 1 | nvcc = /usr/local/cuda/bin/nvcc 2 | cudalib = /usr/local/cuda/lib64 3 | tensorflow = /home/tackgeun/anaconda3/envs/asym-diff/lib/python3.8/site-packages/tensorflow/include/ 4 | 5 | all: tf_approxmatch_so.so tf_approxmatch_g.cu.o tf_nndistance_so.so tf_nndistance_g.cu.o 6 | 7 | 8 | tf_approxmatch_so.so: tf_approxmatch_g.cu.o tf_approxmatch.cpp 9 | g++ -std=c++17 tf_approxmatch.cpp tf_approxmatch_g.cu.o -o tf_approxmatch_so.so -shared -fPIC -I $(tensorflow) -lcudart -L $(cudalib) -O2 
-D_GLIBCXX_USE_CXX17_ABI=0 10 | 11 | 12 | tf_approxmatch_g.cu.o: tf_approxmatch_g.cu 13 | $(nvcc) -D_GLIBCXX_USE_CXX17_ABI=0 -std=c++17 -c -o tf_approxmatch_g.cu.o tf_approxmatch_g.cu -I $(tensorflow) -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -O2 14 | 15 | 16 | tf_nndistance_so.so: tf_nndistance_g.cu.o tf_nndistance.cpp 17 | g++ -std=c++17 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I $(tensorflow) -lcudart -L $(cudalib) -O2 -D_GLIBCXX_USE_CXX17_ABI=0 18 | 19 | 20 | tf_nndistance_g.cu.o: tf_nndistance_g.cu 21 | $(nvcc) -D_GLIBCXX_USE_CXX17_ABI=0 -std=c++17 -c -o tf_nndistance_g.cu.o tf_nndistance_g.cu -I $(tensorflow) -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -O2 22 | 23 | 24 | clean: 25 | rm tf_approxmatch_so.so 26 | rm tf_nndistance_so.so 27 | rm *.cu.o 28 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import pytorch_lightning as pl 11 | 12 | from typing import Optional, Tuple 13 | from omegaconf import OmegaConf 14 | from torch.nn import functional as F 15 | from torch.cuda.amp import autocast 16 | 17 | from src.optimizers.scheduler import build_scheduler 18 | 19 | # stage 1 model builder 20 | def build_model_stage1(stage1_type, config): 21 | if stage1_type == 'asym-diff': 22 | from .stage1.asymautoenc import AsymmetricAutoEncoder 23 | return AsymmetricAutoEncoder(hparams=config.stage1, 24 | hparams_opt=config.optimizer) 25 | elif stage1_type == 'parammix': 26 | from .stage1.metainr import MetaINR 27 | return MetaINR(hparams=config) 28 | 29 | def build_model_stage2(cfg_stage2, cfg_opt, affine): 30 | from .stage2.latentddim import LatentDDIM 31 | return LatentDDIM(hparams=cfg_stage2, hparams_opt=cfg_opt, affine=affine) 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /configs/CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K256L4W64H1024_W0_30_epoch800.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: celebahq 3 | image_resolution: 64 4 | data_dimension: 1024 5 | 6 | stage2: 7 | hparams_diffusion: 8 | steps: 1000 9 | learn_sigma: False 10 | noise_schedule: squaredcos_cap_v2 11 | rescale_learned_sigmas: False 12 | predict_xstart: True 13 | 14 | hparams_model: 15 | net_type: functa 16 | num_layers: 8 17 | num_hid_groups: 16 18 | num_hid_channels: 4096 19 | num_channels: 1024 20 | activation: 'silu' 21 | 22 | num_time_layers: 1 23 | num_time_emb_channels: 64 24 | 25 | feat_type: context 26 | feat_std_scale: 1.0 27 | 28 | hparams_metainr: 29 | width: 64 30 | depth: 4 31 | k_mixtures: 256 32 | latent_channels: 1024 33 | in_channels: 2 34 | out_channels: 3 35 | use_meta_sgd: True 36 | pred_type: image 37 | 38 | optimizer: 39 | use_amp: False 40 | use_ema: True 41 | opt_type: adamw 42 | base_lr: 1.0e-4 43 | grad_clip_norm: 0.0 # means don't clip 44 | betas: [0.9, 0.999] 45 | warmup: 46 | multiplier: 1 47 | warmup_epoch: 0.0 48 | buffer_epoch: 0 49 | min_lr: 0.0 50 | mode: fix 51 | start_from_zero: True 52 | 53 | experiment: 
54 | epochs: 1000 55 | save_ckpt_freq: 1000 56 | -------------------------------------------------------------------------------- /configs/CelebAHQ64px-ResidualMLP_L8_D4096_scale1_epoch1000-mNIF_K384L5W128H512_W0_30_epoch800.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: celebahq 3 | image_resolution: 64 4 | data_dimension: 512 5 | 6 | stage2: 7 | hparams_diffusion: 8 | steps: 1000 9 | learn_sigma: False 10 | noise_schedule: squaredcos_cap_v2 11 | rescale_learned_sigmas: False 12 | predict_xstart: True 13 | 14 | hparams_model: 15 | net_type: functa 16 | num_layers: 8 17 | num_hid_groups: 16 18 | num_hid_channels: 4096 19 | num_channels: 512 20 | activation: 'silu' 21 | 22 | num_time_layers: 1 23 | num_time_emb_channels: 64 24 | 25 | feat_type: context 26 | feat_std_scale: 1.0 27 | 28 | hparams_metainr: 29 | width: 128 30 | depth: 5 31 | k_mixtures: 384 32 | latent_channels: 512 33 | in_channels: 2 34 | out_channels: 3 35 | use_meta_sgd: True 36 | pred_type: image 37 | 38 | optimizer: 39 | use_amp: False 40 | use_ema: True 41 | opt_type: adamw 42 | base_lr: 1.0e-4 43 | grad_clip_norm: 0.0 # means don't clip 44 | betas: [0.9, 0.999] 45 | warmup: 46 | multiplier: 1 47 | warmup_epoch: 0.0 48 | buffer_epoch: 0 49 | min_lr: 0.0 50 | mode: fix 51 | start_from_zero: True 52 | 53 | experiment: 54 | epochs: 1000 55 | save_ckpt_freq: 1000 56 | -------------------------------------------------------------------------------- /configs/ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K256L4W64H512-W0-50-epochs800.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: shapenet 3 | image_resolution: 64 4 | data_dimension: 256 5 | 6 | stage2: 7 | hparams_diffusion: 8 | steps: 1000 9 | learn_sigma: False 10 | noise_schedule: squaredcos_cap_v2 11 | rescale_learned_sigmas: False 12 | predict_xstart: True 13 | 14 | hparams_model: 15 | net_type: functa 16 | num_layers: 8 17 | num_hid_groups: 16 18 | num_hid_channels: 4096 19 | num_channels: 512 20 | activation: 'silu' 21 | 22 | num_time_layers: 1 23 | num_time_emb_channels: 64 24 | 25 | feat_type: context 26 | feat_std_scale: 1.0 27 | 28 | hparams_metainr: 29 | width: 64 30 | depth: 4 31 | w0: 50 32 | k_mixtures: 256 33 | latent_channels: 512 34 | image_resolution: 64 35 | in_channels: 3 36 | out_channels: 1 37 | use_meta_sgd: True 38 | pred_type: voxel 39 | 40 | optimizer: 41 | use_amp: False 42 | use_ema: True 43 | opt_type: adamw 44 | base_lr: 1.0e-4 45 | grad_clip_norm: 0.0 # means don't clip 46 | betas: [0.9, 0.999] 47 | warmup: 48 | multiplier: 1 49 | warmup_epoch: 0.0 50 | buffer_epoch: 0 51 | min_lr: 0.0 52 | mode: fix 53 | start_from_zero: True 54 | 55 | experiment: 56 | epochs: 1000 57 | save_ckpt_freq: 1000 58 | -------------------------------------------------------------------------------- /configs/ShapeNet64px-ResidualMLP-L8-D4096-scale1-epoch1000-MetaLatentMixtureINR-K512L5W128H1024-W0-30-epochs400.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: shapenet 3 | image_resolution: 64 4 | data_dimension: 1024 5 | 6 | stage2: 7 | hparams_diffusion: 8 | steps: 1000 9 | learn_sigma: False 10 | noise_schedule: squaredcos_cap_v2 11 | rescale_learned_sigmas: False 12 | predict_xstart: True 13 | 14 | hparams_model: 15 | net_type: functa 16 | num_layers: 8 17 | num_hid_groups: 16 18 | num_hid_channels: 4096 19 | num_channels: 1024 20 | activation: 
'silu' 21 | 22 | num_time_layers: 1 23 | num_time_emb_channels: 64 24 | 25 | feat_type: context 26 | feat_std_scale: 1.0 27 | 28 | hparams_metainr: 29 | width: 128 30 | depth: 5 31 | w0: 30 32 | k_mixtures: 512 33 | latent_channels: 1024 34 | image_resolution: 64 35 | in_channels: 3 36 | out_channels: 1 37 | use_meta_sgd: True 38 | pred_type: voxel 39 | 40 | optimizer: 41 | use_amp: False 42 | use_ema: True 43 | opt_type: adamw 44 | base_lr: 1.0e-4 45 | grad_clip_norm: 0.0 # means don't clip 46 | betas: [0.9, 0.999] 47 | warmup: 48 | multiplier: 1 49 | warmup_epoch: 0.0 50 | buffer_epoch: 0 51 | min_lr: 0.0 52 | mode: fix 53 | start_from_zero: True 54 | 55 | experiment: 56 | epochs: 1000 57 | save_ckpt_freq: 1000 58 | -------------------------------------------------------------------------------- /configs/SRNCars128px-ResidualMLP-L8-D4096-scale1-epoch1000-NV32-NP512-LatentMixtureINR-K256L4W64H128-nouse+elu-epoch1000+1000.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: srncars 3 | image_resolution: 128 4 | data_dimension: 128 5 | hparams_nerf: 6 | resolution: 128 7 | epoch_for_full_rendering: 5 8 | subsampled_views: 32 9 | subsampled_pixels: 512 10 | num_samples_per_ray: 31 11 | near: 0.8 12 | far: 1.8 13 | randomized: True 14 | prob_mask_sampling: 0.0 15 | rgb_activation: no_use 16 | density_activation: elu 17 | H: 128 18 | W: 128 19 | 20 | stage2: 21 | hparams_diffusion: 22 | steps: 1000 23 | learn_sigma: False 24 | noise_schedule: squaredcos_cap_v2 25 | rescale_learned_sigmas: False 26 | predict_xstart: True 27 | 28 | hparams_model: 29 | net_type: functa 30 | num_layers: 8 31 | num_hid_groups: 16 32 | num_hid_channels: 4096 33 | num_channels: 128 34 | activation: silu 35 | 36 | num_time_layers: 1 37 | num_time_emb_channels: 64 38 | 39 | feat_type: context 40 | feat_std_scale: 1.0 41 | 42 | hparams_metainr: 43 | width: 64 44 | depth: 4 45 | w0: 30 46 | k_mixtures: 256 47 | latent_channels: 128 48 | image_resolution: 128 49 | in_channels: 3 50 | out_channels: 4 51 | use_meta_sgd: False 52 | pred_type: scene 53 | 54 | optimizer: 55 | use_amp: False 56 | use_ema: True 57 | opt_type: adamw 58 | base_lr: 1.0e-4 59 | grad_clip_norm: 0.0 # means don't clip 60 | betas: [0.9, 0.999] 61 | warmup: 62 | multiplier: 1 63 | warmup_epoch: 0.0 64 | buffer_epoch: 0 65 | min_lr: 0.0 66 | mode: fix 67 | start_from_zero: True 68 | 69 | experiment: 70 | epochs: 1000 71 | save_ckpt_freq: 1000 72 | -------------------------------------------------------------------------------- /src/models/stage2/diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. 
Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /src/models/stage2/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = 
dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_nndistance.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import os.path as osp 4 | 5 | base_dir = osp.dirname(osp.abspath(__file__)) 6 | 7 | nn_distance_module = tf.load_op_library(osp.join(base_dir, 'tf_nndistance_so.so')) 8 | 9 | 10 | def nn_distance(xyz1, xyz2): 11 | ''' 12 | Computes the distance of nearest neighbors for a pair of point clouds 13 | input: xyz1: (batch_size,#points_1,3) the first point cloud 14 | input: xyz2: (batch_size,#points_2,3) the second point cloud 15 | output: dist1: (batch_size,#point_1) distance from first to second 16 | output: idx1: (batch_size,#point_1) nearest neighbor from first to second 17 | output: dist2: (batch_size,#point_2) distance from second to first 18 | output: idx2: (batch_size,#point_2) nearest neighbor from second to first 19 | ''' 20 | 21 | return nn_distance_module.nn_distance(xyz1,xyz2) 22 | 23 | #@tf.RegisterShape('NnDistance') 24 | #@ops.RegisterShape('NnDistance') 25 | def _nn_distance_shape(op): 26 | shape1=op.inputs[0].get_shape().with_rank(3) 27 | shape2=op.inputs[1].get_shape().with_rank(3) 28 | return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]), 29 | tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 30 | @ops.RegisterGradient('NnDistance') 31 | def _nn_distance_grad(op,grad_dist1,grad_idx1,grad_dist2,grad_idx2): 32 | xyz1=op.inputs[0] 33 | xyz2=op.inputs[1] 34 | idx1=op.outputs[1] 35 | idx2=op.outputs[3] 36 | return nn_distance_module.nn_distance_grad(xyz1,xyz2,grad_dist1,idx1,grad_dist2,idx2) 37 | 38 | 39 | if __name__=='__main__': 40 | import numpy as np 41 | import random 42 | import time 43 | #from tensorflow.python.kernel_tests.gradient_checker import compute_gradient 44 | random.seed(100) 45 | np.random.seed(100) 46 | with tf.Session('') as sess: 47 | xyz1=np.random.randn(32,16384,3).astype('float32') 48 | xyz2=np.random.randn(32,1024,3).astype('float32') 49 | with tf.device('/gpu:0'): 50 | inp1=tf.Variable(xyz1) 51 | inp2=tf.constant(xyz2) 52 | reta,retb,retc,retd=nn_distance(inp1,inp2) 53 | loss=tf.reduce_sum(reta)+tf.reduce_sum(retc) 54 | 
train=tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss) 55 | sess.run(tf.initialize_all_variables()) 56 | t0=time.time() 57 | t1=t0 58 | best=1e100 59 | for i in range(100): 60 | trainloss,_=sess.run([loss,train]) 61 | newt=time.time() 62 | best=min(best,newt-t1) 63 | print(i,trainloss,(newt-t0)/(i+1),best) 64 | t1=newt 65 | #print sess.run([inp1,retb,inp2,retd]) 66 | #grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2]) 67 | #for i,j in grads: 68 | #print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j)) 69 | #for i in xrange(10): 70 | #t0=time.time() 71 | #a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2}) 72 | #print 'time',time.time()-t0 73 | #print a.shape,b.shape,c.shape,d.shape 74 | #print a.dtype,b.dtype,c.dtype,d.dtype 75 | #samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32') 76 | #dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 77 | #idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 78 | #print np.abs(dist1-a[:,samples]).max() 79 | #print np.abs(idx1-b[:,samples]).max() 80 | #dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 81 | #idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 82 | #print np.abs(dist2-c[:,samples]).max() 83 | #print np.abs(idx2-d[:,samples]).max() 84 | 85 | -------------------------------------------------------------------------------- /src/models/stage1/modules/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class ActNorm(nn.Module): 12 | def __init__(self, num_features, logdet=False, affine=True, 13 | allow_reverse_init=False): 14 | assert affine 15 | super().__init__() 16 | self.logdet = logdet 17 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 18 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 19 | self.allow_reverse_init = allow_reverse_init 20 | 21 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 22 | 23 | def initialize(self, input): 24 | with torch.no_grad(): 25 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 26 | mean = ( 27 | flatten.mean(1) 28 | .unsqueeze(1) 29 | .unsqueeze(2) 30 | .unsqueeze(3) 31 | .permute(1, 0, 2, 3) 32 | ) 33 | std = ( 34 | flatten.std(1) 35 | .unsqueeze(1) 36 | .unsqueeze(2) 37 | .unsqueeze(3) 38 | .permute(1, 0, 2, 3) 39 | ) 40 | 41 | self.loc.data.copy_(-mean) 42 | self.scale.data.copy_(1 / (std + 1e-6)) 43 | 44 | def forward(self, input, reverse=False): 45 | if reverse: 46 | return self.reverse(input) 47 | if len(input.shape) == 2: 48 | input = input[:, :, None, None] 49 | squeeze = True 50 | else: 51 | squeeze = False 52 | 53 | _, _, height, width = input.shape 54 | 55 | if self.training and self.initialized.item() == 0: 56 | self.initialize(input) 57 | self.initialized.fill_(1) 58 | 59 | h = self.scale * (input + self.loc) 60 | 61 | if squeeze: 62 | h = h.squeeze(-1).squeeze(-1) 63 | 64 | if self.logdet: 65 | log_abs = torch.log(torch.abs(self.scale)) 66 | logdet = height*width*torch.sum(log_abs) 67 | logdet = logdet * torch.ones(input.shape[0]).to(input) 68 | return h, logdet 69 | 70 | return h 71 | 72 | def reverse(self, output): 73 | if self.training and self.initialized.item() == 0: 74 | if not self.allow_reverse_init: 75 | raise RuntimeError( 76 | "Initializing ActNorm in reverse direction is " 77 | "disabled by default. Use allow_reverse_init=True to enable." 78 | ) 79 | else: 80 | self.initialize(output) 81 | self.initialized.fill_(1) 82 | 83 | if len(output.shape) == 2: 84 | output = output[:, :, None, None] 85 | squeeze = True 86 | else: 87 | squeeze = False 88 | 89 | h = output / self.scale - self.loc 90 | 91 | if squeeze: 92 | h = h.squeeze(-1).squeeze(-1) 93 | return h 94 | 95 | 96 | def weights_init(m): 97 | classname = m.__class__.__name__ 98 | if classname.find('Conv') != -1: 99 | nn.init.normal_(m.weight.data, 0.0, 0.02) 100 | elif classname.find('BatchNorm') != -1: 101 | nn.init.normal_(m.weight.data, 1.0, 0.02) 102 | nn.init.constant_(m.bias.data, 0) 103 | -------------------------------------------------------------------------------- /src/models/stage2/diffae/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | from enum import Enum 6 | import math 7 | from typing import Optional 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | import torch.nn.functional as F 14 | 15 | 16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
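# A minimal equivalence sketch (an illustrative note, not part of the original module;
# it assumes torch >= 1.7 is available so that F.silu exists): the hand-rolled SiLU
# defined below matches the built-in activation, e.g.
#   x = th.randn(8)
#   assert th.allclose(SiLU()(x), x * th.sigmoid(x))
#   assert th.allclose(SiLU()(x), F.silu(x))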
17 | class SiLU(nn.Module): 18 | # @th.jit.script 19 | def forward(self, x): 20 | return x * th.sigmoid(x) 21 | 22 | 23 | class GroupNorm32(nn.GroupNorm): 24 | def forward(self, x): 25 | return super().forward(x.float()).type(x.dtype) 26 | 27 | 28 | def conv_nd(dims, *args, **kwargs): 29 | """ 30 | Create a 1D, 2D, or 3D convolution module. 31 | """ 32 | if dims == 1: 33 | return nn.Conv1d(*args, **kwargs) 34 | elif dims == 2: 35 | return nn.Conv2d(*args, **kwargs) 36 | elif dims == 3: 37 | return nn.Conv3d(*args, **kwargs) 38 | raise ValueError(f"unsupported dimensions: {dims}") 39 | 40 | 41 | def linear(*args, **kwargs): 42 | """ 43 | Create a linear module. 44 | """ 45 | return nn.Linear(*args, **kwargs) 46 | 47 | 48 | def avg_pool_nd(dims, *args, **kwargs): 49 | """ 50 | Create a 1D, 2D, or 3D average pooling module. 51 | """ 52 | if dims == 1: 53 | return nn.AvgPool1d(*args, **kwargs) 54 | elif dims == 2: 55 | return nn.AvgPool2d(*args, **kwargs) 56 | elif dims == 3: 57 | return nn.AvgPool3d(*args, **kwargs) 58 | raise ValueError(f"unsupported dimensions: {dims}") 59 | 60 | 61 | def update_ema(target_params, source_params, rate=0.99): 62 | """ 63 | Update target parameters to be closer to those of source parameters using 64 | an exponential moving average. 65 | 66 | :param target_params: the target parameter sequence. 67 | :param source_params: the source parameter sequence. 68 | :param rate: the EMA rate (closer to 1 means slower). 69 | """ 70 | for targ, src in zip(target_params, source_params): 71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 72 | 73 | 74 | def zero_module(module): 75 | """ 76 | Zero out the parameters of a module and return it. 77 | """ 78 | for p in module.parameters(): 79 | p.detach().zero_() 80 | return module 81 | 82 | 83 | def scale_module(module, scale): 84 | """ 85 | Scale the parameters of a module and return it. 86 | """ 87 | for p in module.parameters(): 88 | p.detach().mul_(scale) 89 | return module 90 | 91 | 92 | def mean_flat(tensor): 93 | """ 94 | Take the mean over all non-batch dimensions. 95 | """ 96 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 97 | 98 | 99 | def normalization(channels): 100 | """ 101 | Make a standard normalization layer. 102 | 103 | :param channels: number of input channels. 104 | :return: an nn.Module for normalization. 105 | """ 106 | return GroupNorm32(min(32, channels), channels) 107 | 108 | 109 | def timestep_embedding(timesteps, dim, max_period=10000): 110 | """ 111 | Create sinusoidal timestep embeddings. 112 | 113 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 114 | These may be fractional. 115 | :param dim: the dimension of the output. 116 | :param max_period: controls the minimum frequency of the embeddings. 117 | :return: an [N x dim] Tensor of positional embeddings. 
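    A minimal usage sketch (the timestep values here are illustrative assumptions,
    not taken from the original codebase):
        t = th.tensor([0, 250, 999])          # three diffusion timesteps
        emb = timestep_embedding(t, dim=64)   # -> embeddings of shape [3, 64]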
118 | """ 119 | half = dim // 2 120 | freqs = th.exp(-math.log(max_period) * 121 | th.arange(start=0, end=half, dtype=th.float32) / 122 | half).to(device=timesteps.device) 123 | args = timesteps[:, None].float() * freqs[None] 124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 125 | if dim % 2: 126 | embedding = th.cat( 127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1) 128 | return embedding 129 | 130 | 131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False): 132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 133 | if flag: 134 | return torch.utils.checkpoint.checkpoint( 135 | func, *args, preserve_rng_state=preserve_rng_state) 136 | else: 137 | return func(*args) 138 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_approxmatch.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import os.path as osp 4 | 5 | base_dir = osp.dirname(osp.abspath(__file__)) 6 | 7 | approxmatch_module = tf.load_op_library(osp.join(base_dir, 'tf_approxmatch_so.so')) 8 | 9 | 10 | def approx_match(xyz1,xyz2): 11 | ''' 12 | input: 13 | xyz1 : batch_size * #dataset_points * 3 14 | xyz2 : batch_size * #query_points * 3 15 | returns: 16 | match : batch_size * #query_points * #dataset_points 17 | ''' 18 | return approxmatch_module.approx_match(xyz1,xyz2) 19 | ops.NoGradient('ApproxMatch') 20 | #@tf.RegisterShape('ApproxMatch') 21 | @ops.RegisterShape('ApproxMatch') 22 | def _approx_match_shape(op): 23 | shape1=op.inputs[0].get_shape().with_rank(3) 24 | shape2=op.inputs[1].get_shape().with_rank(3) 25 | return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[1]])] 26 | 27 | def match_cost(xyz1,xyz2,match): 28 | ''' 29 | input: 30 | xyz1 : batch_size * #dataset_points * 3 31 | xyz2 : batch_size * #query_points * 3 32 | match : batch_size * #query_points * #dataset_points 33 | returns: 34 | cost : batch_size 35 | ''' 36 | return approxmatch_module.match_cost(xyz1,xyz2,match) 37 | #@tf.RegisterShape('MatchCost') 38 | @ops.RegisterShape('MatchCost') 39 | def _match_cost_shape(op): 40 | shape1=op.inputs[0].get_shape().with_rank(3) 41 | shape2=op.inputs[1].get_shape().with_rank(3) 42 | shape3=op.inputs[2].get_shape().with_rank(3) 43 | return [tf.TensorShape([shape1.dims[0]])] 44 | @tf.RegisterGradient('MatchCost') 45 | def _match_cost_grad(op,grad_cost): 46 | xyz1=op.inputs[0] 47 | xyz2=op.inputs[1] 48 | match=op.inputs[2] 49 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 50 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 51 | 52 | if __name__=='__main__': 53 | alpha=0.5 54 | beta=2.0 55 | import bestmatch 56 | import numpy as np 57 | import math 58 | import random 59 | import cv2 60 | 61 | import tf_nndistance 62 | 63 | npoint=100 64 | 65 | with tf.device('/gpu:2'): 66 | pt_in=tf.placeholder(tf.float32,shape=(1,npoint*4,3)) 67 | mypoints=tf.Variable(np.random.randn(1,npoint,3).astype('float32')) 68 | match=approx_match(pt_in,mypoints) 69 | loss=tf.reduce_sum(match_cost(pt_in,mypoints,match)) 70 | #match=approx_match(mypoints,pt_in) 71 | #loss=tf.reduce_sum(match_cost(mypoints,pt_in,match)) 72 | #distf,_,distb,_=tf_nndistance.nn_distance(pt_in,mypoints) 73 | #loss=tf.reduce_sum((distf+1e-9)**0.5)*0.5+tf.reduce_sum((distb+1e-9)**0.5)*0.5 74 | 
#loss=tf.reduce_max((distf+1e-9)**0.5)*0.5*npoint+tf.reduce_max((distb+1e-9)**0.5)*0.5*npoint 75 | 76 | optimizer=tf.train.GradientDescentOptimizer(1e-4).minimize(loss) 77 | with tf.Session('') as sess: 78 | sess.run(tf.initialize_all_variables()) 79 | while True: 80 | meanloss=0 81 | meantrueloss=0 82 | for i in xrange(1001): 83 | #phi=np.random.rand(4*npoint)*math.pi*2 84 | #tpoints=(np.hstack([np.cos(phi)[:,None],np.sin(phi)[:,None],(phi*0)[:,None]])*random.random())[None,:,:] 85 | #tpoints=((np.random.rand(400)-0.5)[:,None]*[0,2,0]+[(random.random()-0.5)*2,0,0]).astype('float32')[None,:,:] 86 | tpoints=np.hstack([np.linspace(-1,1,400)[:,None],(random.random()*2*np.linspace(1,0,400)**2)[:,None],np.zeros((400,1))])[None,:,:] 87 | trainloss,_=sess.run([loss,optimizer],feed_dict={pt_in:tpoints.astype('float32')}) 88 | trainloss,trainmatch=sess.run([loss,match],feed_dict={pt_in:tpoints.astype('float32')}) 89 | #trainmatch=trainmatch.transpose((0,2,1)) 90 | show=np.zeros((400,400,3),dtype='uint8')^255 91 | trainmypoints=sess.run(mypoints) 92 | for i in xrange(len(tpoints[0])): 93 | u=np.random.choice(range(len(trainmypoints[0])),p=trainmatch[0].T[i]) 94 | cv2.line(show, 95 | (int(tpoints[0][i,1]*100+200),int(tpoints[0][i,0]*100+200)), 96 | (int(trainmypoints[0][u,1]*100+200),int(trainmypoints[0][u,0]*100+200)), 97 | cv2.cv.CV_RGB(0,255,0)) 98 | for x,y,z in tpoints[0]: 99 | cv2.circle(show,(int(y*100+200),int(x*100+200)),2,cv2.cv.CV_RGB(255,0,0)) 100 | for x,y,z in trainmypoints[0]: 101 | cv2.circle(show,(int(y*100+200),int(x*100+200)),3,cv2.cv.CV_RGB(0,0,255)) 102 | cost=((tpoints[0][:,None,:]-np.repeat(trainmypoints[0][None,:,:],4,axis=1))**2).sum(axis=2)**0.5 103 | #trueloss=bestmatch.bestmatch(cost)[0] 104 | print trainloss#,trueloss 105 | cv2.imshow('show',show) 106 | cmd=cv2.waitKey(10)%256 107 | if cmd==ord('q'): 108 | break 109 | -------------------------------------------------------------------------------- /src/optimizers/scheduler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from RQ-Transformer 3 | # https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/optimizer/scheduler.py 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import math 7 | import torch 8 | from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR 9 | import pdb 10 | 11 | def build_scheduler(optimizer, 12 | base_lr, 13 | steps_per_epoch, 14 | final_steps, 15 | warmup_config, 16 | sche_type='cosine', 17 | world_size=None): 18 | 19 | multiplier = warmup_config.multiplier 20 | warmup_steps = warmup_config.warmup_epoch * steps_per_epoch 21 | buffer_steps = warmup_config.buffer_epoch * steps_per_epoch 22 | min_lr = warmup_config.min_lr 23 | mode = warmup_config.mode 24 | start_from_zero = warmup_config.start_from_zero 25 | 26 | 27 | if sche_type == 'cosine' or sche_type is None: 28 | scheduler = CosineAnnealingLR( 29 | optimizer, T_max=final_steps - warmup_steps - buffer_steps, eta_min=min_lr 30 | ) 31 | elif sche_type == 'const': 32 | scheduler = StepLR( 33 | optimizer, factor=0.1, total_iters=(final_steps - warmup_steps - buffer_steps) // 2 34 | ) 35 | else: 36 | raise NotImplementedError(f'{sche_type} is not supported..') 37 | 38 | if warmup_steps > 0.0: 39 | if mode == 'linear': 40 | multiplier = max(1.0, multiplier * world_size) 41 | elif mode == 'sqrt': 42 | multiplier = max(1.0, multiplier * 
math.sqrt(world_size)) 43 | elif mode == 'fix': 44 | multiplier = max(1.0, multiplier) 45 | elif mode == 'none': 46 | pass 47 | else: 48 | raise NotImplementedError(f'{mode} is not a valid warmup policy') 49 | warmup = GradualWarmup( 50 | optimizer, 51 | steps=warmup_steps, 52 | buffer_steps=buffer_steps, 53 | multiplier=multiplier, 54 | start_from_zero=start_from_zero 55 | ) 56 | else: 57 | warmup = None 58 | 59 | scheduler = Scheduler(optimizer, warmup_scheduler=warmup, after_scheduler=scheduler) 60 | return scheduler 61 | 62 | 63 | class GradualWarmup(torch.optim.lr_scheduler._LRScheduler): 64 | def __init__(self, optimizer, steps, buffer_steps, multiplier, start_from_zero=True, last_epoch=-1): 65 | self.steps = steps 66 | self.t_steps = steps + buffer_steps 67 | self.multiplier = multiplier 68 | self.start_from_zero = start_from_zero 69 | 70 | super().__init__(optimizer, last_epoch) 71 | 72 | def get_lr(self): 73 | if self.last_epoch > self.steps: 74 | return [group['lr'] for group in self.optimizer.param_groups] 75 | 76 | if self.start_from_zero: 77 | multiplier = self.multiplier * min(1.0, (self.last_epoch / self.steps)) 78 | else: 79 | multiplier = 1 + ((self.multiplier - 1) * min(1.0, (self.last_epoch / self.steps))) 80 | return [lr * multiplier for lr in self.base_lrs] 81 | 82 | 83 | class Scheduler(torch.optim.lr_scheduler._LRScheduler): 84 | def __init__(self, optimizer, warmup_scheduler, after_scheduler, last_epoch=-1): 85 | self.warmup_scheduler = warmup_scheduler 86 | self.after_scheduler = after_scheduler 87 | 88 | super().__init__(optimizer, last_epoch) 89 | 90 | def step(self, epoch=None): 91 | if self.warmup_scheduler is not None: 92 | self.warmup_scheduler.step(epoch=epoch) 93 | 94 | if self.warmup_scheduler is None or \ 95 | self.warmup_scheduler.last_epoch > self.warmup_scheduler.t_steps: 96 | self.after_scheduler.step(epoch=epoch) 97 | 98 | def get_last_lr(self): 99 | if self.warmup_scheduler is not None and \ 100 | self.warmup_scheduler.last_epoch <= self.warmup_scheduler.t_steps: 101 | return self.warmup_scheduler.get_last_lr() 102 | else: 103 | return self.after_scheduler.get_last_lr() 104 | 105 | def state_dict(self): 106 | return { 107 | 'warmup': None if self.warmup_scheduler is None else self.warmup_scheduler.state_dict(), 108 | 'after': self.after_scheduler.state_dict() 109 | } 110 | 111 | def load_state_dict(self, state_dict): 112 | if self.warmup_scheduler is not None: 113 | self.warmup_scheduler.load_state_dict(state_dict['warmup']) 114 | self.after_scheduler.load_state_dict(state_dict['after']) 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Neural Fields by Mixtures of Neural Implicit Functions 2 | The official implementation of Generative Neural Fields by Mixtures of Neural Implicit Functions 3 | - [Tackgeun You](https://tackgeun.github.io/), [Mijeong Kim](https://mjmjeong.github.io/), [Jungtaek Kim](https://jungtaekkim.github.io/) and [Bohyung Han](https://cv.snu.ac.kr/index.php/~bhhan/), (**NeurIPS 2023**) 4 | 5 | ## Requirements 6 | We have tested our codes on the environment below 7 | - `Python 3.8` / `Pytorch 1.10` / `torchvision 0.11.0` / `CUDA 11.3` / `Ubuntu 18.04` . 8 | 9 | Please run the following command to install the necessary dependencies 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | or you can use provided docker image. 
15 | ```
16 | docker pull tackgeun/mNIF:init
17 | ```
18 | 
19 | ## Dataset Preparation
20 | Here are the benchmarks from the three modalities adopted in our work.
21 | Extract these zip files into ${ROOT}/datasets.
22 | - [CelebAHQ 128px](https://www.dropbox.com/scl/fi/l4p15ecmnm9k5qnq8kkx3/CelebAHQ.zip?rlkey=xn2tllj539fkizp9rn2xgsq4x&dl=0)
23 | - [ShapeNet 64x64x64](https://www.dropbox.com/scl/fi/lj9uwsw1234jfw2bpupfk/shapenet.zip?rlkey=2bkv8xmc7en6ok6mc5oczmn5b&dl=0)
24 | - [SRN Cars 128px](https://www.dropbox.com/scl/fi/4maypw7idr7yis8cwxw4d/srn_cars_lmdb.zip?rlkey=zdd67iy24t20xwjbn3vyn6o1q&dl=0)
25 | 
26 | 
27 | ## Pre-trained Models
28 | Here are the pre-trained models for the three modalities in our work.
29 | Extract these zip files into ${ROOT}/experiments.
30 | - [mNIF (S) CelebAHQ 64px](https://www.dropbox.com/scl/fi/iurr0s79glpehvkdezhuw/CelebAHQ-small.zip?rlkey=6spmw045i2on46glke5l6up4k&dl=0)
31 | - [mNIF (S) ShapeNet 64x64x64](https://www.dropbox.com/scl/fi/8ievjsf0jlmbrof0awlb1/ShapeNet-small.zip?rlkey=m7riec4rmek232p6gswo2mubm&dl=0)
32 | - [mNIF (S) SRN Cars 128px](https://www.dropbox.com/scl/fi/t9qk83eewy0uosn9syt5n/SRNCars-small.zip?rlkey=06jw7gvczv2d61agdeo1mt9yf&dl=0)
33 | 
34 | 
35 | ## Training and Evaluation Commands
36 | Refer to the shell scripts in the `scripts` directory.
37 | ### Training Mixtures of Neural Implicit Functions with Meta-Learning (Image, Voxel)
38 | Train the stage 1 mNIF with fast context adaptation via meta-learning (CAVIA).
39 | ```
40 | sh scripts/train_stage1_small_CelebAHQ.sh
41 | sh scripts/train_stage1_small_ShapeNet.sh
42 | ```
43 | 
44 | ### Evaluation and Test-time Adaptation of Mixtures of Neural Implicit Functions with Meta-Learning (Image, Voxel)
45 | CAVIA conducts adaptation and evaluation of the given samples simultaneously.
46 | ```
47 | sh scripts/test_stage1_small_CelebAHQ.sh
48 | sh scripts/test_stage1_small_ShapeNet.sh
49 | ```
50 | For evaluation, remove the result path given to -r, specify a model with -m=${MODEL_PATH}/metainits/epoch${EPOCH}.pth, and add the --eval flag.
51 | - It also computes context vectors in the latent space, which are saved to ${MODEL_PATH}/contexts/context-epoch${EPOCH}.pth
52 | 
53 | If an out-of-memory error occurs during evaluation, reduce the batch size and lr_inner together, because lr_inner currently depends on the batch size.
54 | - For example, if the model is trained with batch_size=32 and lr_inner=10.0, then batch_size=16 requires lr_inner=5.0.
55 | 
56 | ### Training Mixtures of Neural Implicit Functions with Auto-Decoding (NeRF)
57 | Train the stage 1 mNIF with auto-decoding.
58 | ```
59 | sh scripts/train_stage1_small_SRNCars.sh
60 | ```
61 | 
62 | ### Evaluation of Mixtures of Neural Implicit Functions with Auto-Decoding (NeRF)
63 | Evaluate the stage 1 mNIF trained with auto-decoding.
64 | Unlike CAVIA, the auto-decoding procedure already computes context vectors during stage 1 training.
65 | ```
66 | sh scripts/test_stage1_small_SRNCars.sh
67 | ```
68 | 
69 | 
70 | ### Training the Denoising Diffusion Process
71 | Train the latent diffusion model using the features acquired from context adaptation.
72 | Testing and test-time adaptation of the stage 1 model are required for stage 1 models trained with CAVIA.
73 | ```
74 | sh scripts/train_stage2_small_CelebAHQ.sh
75 | sh scripts/train_stage2_small_ShapeNet.sh
76 | sh scripts/train_stage2_small_SRNCars.sh
77 | ```
78 | 
79 | ### Evaluating the Diffusion Model
80 | ```
81 | sh scripts/test_stage2_small_CelebAHQ.sh
82 | sh scripts/test_stage2_small_ShapeNet.sh
83 | sh scripts/test_stage2_small_SRNCars.sh
84 | ```
85 | 
86 | 
87 | ## Acknowledgement
88 | Our implementation is based on the repositories below.
89 | - Datasets
90 |   - CelebAHQ is a duplicate of the dataset used in Functa
91 |   - Pre-processed [ShapeNet 64x64x64 IMNet](https://drive.google.com/open?id=158so7dnkQQNFSQTj741S3SUbuIXXRrLn) from [IMNet](https://github.com/czq142857/IM-NET)
92 |   - Pre-processed SRNCars dataset from [PixelNeRF](https://github.com/sxyu/pixel-nerf)
93 | - Mixtures of neural implicit functions
94 |   - [SIREN](https://github.com/vsitzmann/siren)
95 |   - [Functa](https://github.com/deepmind/functa)
96 |   - [PixelNeRF](https://github.com/sxyu/pixel-nerf)
97 | - ShapeNet evaluation
98 |   - [GEM](https://github.com/yilundu/gem)
99 | - Latent diffusion model
100 |   - [ADM](https://github.com/openai/guided-diffusion)
101 |   - [Karlo](https://github.com/kakaobrain/karlo)
102 |   - [HQ-Transformer](https://github.com/kakaobrain/hqtransformer)
103 | 
--------------------------------------------------------------------------------
/src/models/stage1/nerf/nerf_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import pdb
5 | '''
6 | NeRF-pytorch & NeRF-Factory
7 | https://github.com/yenchenlin/nerf-pytorch/blob/master/run_nerf_helpers.py
8 | https://github.com/kakaobrain/NeRF-Factory/blob/main/src/model/nerf/helper.py
9 | '''
10 | # data preparing
11 | def get_rays(H, W, focal, c2w, padding=None, compute_radii=False):
12 |     # pytorch's meshgrid has indexing='ij'
13 |     if padding is not None:
14 |         i, j = torch.meshgrid(torch.linspace(-padding, W-1+padding, W+2*padding), torch.linspace(-padding, H-1+padding, H+2*padding))
15 |     else:
16 |         i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))
17 |     i = i.t().to(c2w.device)
18 |     j = j.t().to(c2w.device)
19 |     extra_shift = 0.5
20 |     dirs = torch.stack([(i-W*.5+extra_shift)/focal, -(j-H*.5+extra_shift)/focal, -torch.ones_like(i)], -1)
21 |     # Rotate ray directions from camera frame to the world frame
22 |     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
23 |     # Translate camera frame's origin to the world frame. It is the origin of all rays.
24 |     rays_o = c2w[:3,-1].expand(rays_d.shape)
25 | 
26 |     if compute_radii:
27 |         dx = torch.sqrt(torch.sum((rays_d[:-1, :, :] - rays_d[1:, :, :]) ** 2, -1))
28 |         dx = torch.cat([dx, dx[-2:-1, :]], 0)
29 | 
30 |         # Cut the distance in half, and then round it out so that it's
31 |         # halfway between inscribed by / circumscribed about the pixel.
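        # 2 / sqrt(12) = 1 / sqrt(3) ~= 0.577, which lies between the inscribed
        # pixel radius (0.5 * dx) and the circumscribed one (~0.707 * dx); this is
        # the same base-radius convention used by mip-NeRF for cone casting.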
32 | radii = dx[..., None] * 2 / math.sqrt(12) 33 | return torch.stack((rays_o, rays_d, radii.repeat(1,1,3)), 0) 34 | else: 35 | return torch.stack((rays_o, rays_d), 0) 36 | 37 | def cast_rays(t_vals, origins, directions): 38 | return origins[..., None, :] + t_vals[..., None] * directions[..., None, :] 39 | 40 | def sample_along_rays( 41 | cam_rays, 42 | configs, 43 | ): 44 | # get configs 45 | num_samples = configs.num_samples_per_ray 46 | near, far = configs.near, configs.far 47 | lindisp = configs.lindisp 48 | randomized = configs.randomized # noise 49 | 50 | rays_o, rays_d = cam_rays[0], cam_rays[1] 51 | bsz = rays_o.shape[0] 52 | 53 | t_vals = torch.linspace(0.0, 1.0, num_samples + 1, device=rays_o.device) 54 | if lindisp: 55 | t_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals) 56 | else: 57 | t_vals = near * (1.0 - t_vals) + far * t_vals 58 | 59 | if randomized: 60 | mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1]) 61 | upper = torch.cat([mids, t_vals[..., -1:]], -1) 62 | lower = torch.cat([t_vals[..., :1], mids], -1) 63 | t_rand = torch.rand((bsz, num_samples + 1), device=rays_o.device) 64 | t_vals = lower + (upper - lower) * t_rand 65 | else: 66 | t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1)) 67 | 68 | coords = cast_rays(t_vals, rays_o, rays_d) 69 | return t_vals, coords 70 | 71 | 72 | def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd): 73 | eps = 1e-10 74 | 75 | dists = torch.cat( 76 | [ 77 | t_vals[..., 1:] - t_vals[..., :-1], 78 | torch.ones(t_vals[..., :1].shape, device=t_vals.device) * 1e10, 79 | ], 80 | dim=-1, 81 | ) 82 | dists = dists * torch.norm(dirs[..., None, :], dim=-1) 83 | alpha = 1.0 - torch.exp(-density[..., 0] * dists) 84 | accum_prod = torch.cat( 85 | [ 86 | torch.ones_like(alpha[..., :1]), 87 | torch.cumprod(1.0 - alpha[..., :-1] + eps, dim=-1), 88 | ], 89 | dim=-1, 90 | ) 91 | 92 | weights = alpha * accum_prod 93 | 94 | comp_rgb = (weights[..., None] * rgb).sum(dim=-2) 95 | depth = (weights * t_vals).sum(dim=-1) 96 | acc = weights.sum(dim=-1) 97 | inv_eps = 1 / eps 98 | 99 | if white_bkgd: 100 | comp_rgb = comp_rgb + (1.0 - acc[..., None]) 101 | 102 | return comp_rgb, acc, depth, weights 103 | 104 | def volumetric_rendering_functa(rgb, density, t_vals, dirs, white_bkgd): 105 | eps = 1e-10 106 | distance_between_points = t_vals[..., 1:] - t_vals[..., :-1] 107 | dists = torch.cat( 108 | [ 109 | distance_between_points, 110 | torch.ones(distance_between_points[..., :1].shape, device=t_vals.device) * 1e-3, 111 | ], 112 | dim=-1, 113 | ) 114 | dists = dists * torch.norm(dirs[..., None, :], dim=-1) 115 | alpha = 1.0 - torch.exp(-density[..., 0] * dists) 116 | 117 | trans = torch.minimum(torch.ones_like(alpha), 1.0 - alpha + eps) 118 | trans = torch.cat( 119 | [ 120 | torch.ones_like(trans[..., :1]), 121 | trans[..., :-1], 122 | ], 123 | dim=-1, 124 | ) 125 | 126 | cum_trans = torch.cumprod(trans, dim=-1) 127 | 128 | weights = alpha * cum_trans 129 | 130 | comp_rgb = (weights[..., None] * rgb).sum(dim=-2) 131 | depth = (weights * t_vals).sum(dim=-1) 132 | acc = weights.sum(dim=-1) 133 | 134 | if white_bkgd: 135 | comp_rgb = comp_rgb + (1.0 - acc[..., None]) 136 | 137 | return comp_rgb, acc, depth, weights -------------------------------------------------------------------------------- /src/models/stage2/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from 
Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | import torch as th 7 | 8 | from .gaussian_diffusion import GaussianDiffusion 9 | 10 | 11 | def space_timesteps(num_timesteps, section_counts): 12 | """ 13 | Create a list of timesteps to use from an original diffusion process, 14 | given the number of timesteps we want to take from equally-sized portions 15 | of the original process. 16 | 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | elif section_counts == "fast27": 40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2") 41 | # Help reduce DDIM artifacts from noisiest timesteps. 42 | steps.remove(num_timesteps - 1) 43 | steps.add(num_timesteps - 3) 44 | return steps 45 | section_counts = [int(x) for x in section_counts.split(",")] 46 | size_per = num_timesteps // len(section_counts) 47 | extra = num_timesteps % len(section_counts) 48 | start_idx = 0 49 | all_steps = [] 50 | for i, section_count in enumerate(section_counts): 51 | size = size_per + (1 if i < extra else 0) 52 | if size < section_count: 53 | raise ValueError( 54 | f"cannot divide section of {size} steps into {section_count}" 55 | ) 56 | if section_count <= 1: 57 | frac_stride = 1 58 | else: 59 | frac_stride = (size - 1) / (section_count - 1) 60 | cur_idx = 0.0 61 | taken_steps = [] 62 | for _ in range(section_count): 63 | taken_steps.append(start_idx + round(cur_idx)) 64 | cur_idx += frac_stride 65 | all_steps += taken_steps 66 | start_idx += size 67 | return set(all_steps) 68 | 69 | 70 | class SpacedDiffusion(GaussianDiffusion): 71 | """ 72 | A diffusion process which can skip steps in a base diffusion process. 73 | 74 | :param use_timesteps: a collection (sequence or set) of timesteps from the 75 | original diffusion process to retain. 76 | :param kwargs: the kwargs to create the base diffusion process. 
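
    Example (an illustrative sketch, not part of the original code; it assumes a
    1000-step base process respaced to 250 steps, with `betas` being the full
    1000-step beta schedule and `other_kwargs` the remaining GaussianDiffusion
    keyword arguments):

        use_timesteps = space_timesteps(1000, "250")
        diffusion = SpacedDiffusion(use_timesteps=use_timesteps, betas=betas, **other_kwargs)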
77 | """ 78 | 79 | def __init__(self, use_timesteps, **kwargs): 80 | self.use_timesteps = set(use_timesteps) 81 | self.original_num_steps = len(kwargs["betas"]) 82 | 83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | timestep_map = [] 87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 88 | if i in self.use_timesteps: 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | timestep_map.append(i) 92 | kwargs["betas"] = th.tensor(new_betas).numpy() 93 | super().__init__(**kwargs) 94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) 95 | 96 | def p_mean_variance(self, model, *args, **kwargs): 97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | def wrapped(x, ts, **kwargs): 107 | ts_cpu = ts.detach().to("cpu") 108 | return model( 109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs 110 | ) 111 | 112 | return wrapped 113 | -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_nndistance_g.cu: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | #define EIGEN_USE_GPU 3 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 4 | 5 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 6 | const int batch=512; 7 | __shared__ float buf[batch*3]; 8 | for (int i=blockIdx.x;ibest){ 120 | result[(i*n+j)]=best; 121 | result_i[(i*n+j)]=best_i; 122 | } 123 | } 124 | __syncthreads(); 125 | } 126 | } 127 | } 128 | void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){ 129 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 130 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 131 | } 132 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 133 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 156 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 157 | } 158 | 159 | #endif 160 | -------------------------------------------------------------------------------- /src/utils/config2.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | from datetime import datetime 9 | from typing import Optional, List, Tuple 10 | from dataclasses import dataclass, field 11 | from omegaconf import OmegaConf 12 | from .config1 import LatentMixtureINRConfig, SimpleNeRFConfig 13 | 14 | @dataclass 15 | class DataConfig: 16 | dataset: Optional[str] = None 17 | image_resolution: int = 64 18 | data_dimension: int = 2307 19 | hparams_nerf: Optional[SimpleNeRFConfig] = SimpleNeRFConfig() 20 | 21 | @dataclass 22 | class INRParams: 23 | model_type: str = 'sine' 24 | sparse_type: str = 'none' 25 | hidden_features: int = 32 26 | num_hidden_layers: int = 2 27 | image_resolution: int = 64 28 | 29 | @dataclass 30 | class DiffusionParams: 31 | steps: int = 1000 32 | learn_sigma: bool = False 33 | sigma_small: bool = False 34 | noise_schedule: str = 'linear' 35 | use_kl: bool = False 36 | predict_xstart: bool = False 37 | rescale_learned_sigmas: bool = False 38 | timestep_respacing: str = '1000' 39 | 40 | 41 | @dataclass 42 | class MLPSkipNetParams: 43 | net_type: str = 'skip' 44 | num_layers: int = 20 45 | skip_layers: List[int] = field(default_factory=lambda: [i for i in range(1, 20)]) 46 | activation: str = 'silu' 47 | num_hid_channels: int = 2048 48 | num_hid_groups: int = 0 49 | use_first_hid_layer: bool = False 50 | use_norm: bool = True # use LayerNorm 51 | norm_type: str = 'layernorm' 52 | condition_bias: int = 1 53 | num_channels: int = 2307 54 | num_emb_channels: int = 0 55 | num_out_channels: int = 0 # determined by learn_sigma in diffusion params. 56 | num_time_emb_channels: int = 64 57 | dropout: float = 0.0 58 | last_act: Optional[str] = None 59 | num_time_layers: int = 2 60 | time_last_act: bool = False 61 | 62 | use_cond_index: bool = False # use index condition 63 | num_indices: int = 27000 # for CelebA 64 | type_cond_index: str = 'concat' 65 | 66 | # use_parallel_input: bool = False 67 | # parallel_input: Optional[List[int]] = field(default_factory=lambda: [96, 2211]) 68 | # parallel_embed: Optional[List[int]] = field(default_factory=lambda: [2048, 2048]) 69 | 70 | # @dataclass 71 | # class LatentMixtureINRConfig: 72 | # width: int = 64 73 | # depth: int = 4 74 | # out_channels: int = 3 75 | # image_resolution: int = 64 76 | # w0: float = 30. 77 | # k_mixtures: int = 64 78 | # mixture_type: str = 'layerwise' 79 | # embedding_type: str = 'none' 80 | # use_latent_embedding: bool = True 81 | # normalize_mixture: bool = False 82 | # latent_channels: int = 256 83 | # latent_init_scale: Tuple[float, float] = (0.95, 1.05) 84 | # use_meta_sgd: bool = True 85 | # meta_sgd_init_range: Tuple[float, float] = (0.005, 0.1) 86 | # meta_sgd_clip_range: Tuple[float, float] = (0., 1.) 
87 | # init_path: str = '' 88 | 89 | @dataclass 90 | class Stage2Config: 91 | hparams_diffusion: DiffusionParams = DiffusionParams() 92 | hparams_model: MLPSkipNetParams = MLPSkipNetParams() 93 | hparams_inr: INRParams = INRParams() 94 | hparams_metainr: LatentMixtureINRConfig = LatentMixtureINRConfig() 95 | feat_std_scale: float = 4.0 96 | feat_type: str = 'weight' 97 | image_std_scale: float = 1.0 98 | crop_feature: bool = False 99 | crop_dim: int = 0 100 | 101 | @dataclass 102 | class WarmupConfig: 103 | multiplier: float = 1.0 104 | warmup_epoch: float = 0.0 105 | buffer_epoch: float = 0.0 106 | min_lr: float = 0.0 107 | mode: str = 'fix' 108 | start_from_zero: bool = True 109 | 110 | @dataclass 111 | class OptConfig: 112 | opt_type: str = 'adamw' 113 | betas: Optional[Tuple[float]] = field(default_factory=lambda: [0.9, 0.999]) 114 | base_lr: float = 1e-4 115 | weight_decay: float = 0.01 116 | use_amp: bool = False 117 | use_ema: bool = False 118 | grad_clip_norm: Optional[float] = None 119 | max_steps: Optional[int] = None 120 | steps_per_epoch: Optional[int] = None 121 | sched_type: str = 'cosine' 122 | warmup: WarmupConfig = WarmupConfig() 123 | 124 | 125 | @dataclass 126 | class ExpConfig: 127 | epochs: int = 100 128 | save_ckpt_freq: int = 2 129 | test_freq: int = 1 130 | img_logging_freq: int = 5000 131 | fp16_grad_comp: bool = False 132 | 133 | 134 | @dataclass 135 | class DefaultConfig2: 136 | dataset: DataConfig = DataConfig() 137 | stage2: Stage2Config = Stage2Config() 138 | optimizer: OptConfig = OptConfig() 139 | experiment: ExpConfig = ExpConfig() 140 | 141 | 142 | def update_config(cfg_base, cfg_new): 143 | cfg_update = OmegaConf.merge(cfg_base, cfg_new) 144 | return cfg_update 145 | 146 | def build_config(args): 147 | cfg_base = OmegaConf.structured(DefaultConfig2) 148 | if args.eval: 149 | cfg_new = OmegaConf.load(os.path.join(args.result_path, 'config.yaml')) 150 | cfg_update = update_config(cfg_base, cfg_new) 151 | result_path = args.result_path 152 | else: 153 | cfg_new = OmegaConf.load(args.config_path) 154 | cfg_update = update_config(cfg_base, cfg_new) 155 | #now = datetime.now().strftime('%d%m%Y_%H%M%S') 156 | result_path = os.path.join(args.result_path, 157 | os.path.basename(args.config_path).split('.')[0]) 158 | return cfg_update, result_path 159 | -------------------------------------------------------------------------------- /src/utils/prdc.py: -------------------------------------------------------------------------------- 1 | """ 2 | The final version of code is adopted from the below address 3 | https://github.kakaocorp.com/large-scale/vqvae/blob/447bd328ca3a77cdce5bf57355f3710cdfb0b417/src/utils/prdc.py 4 | 5 | PRDC computation with torch + numpy. 6 | 7 | Modified from: https://github.com/clovaai/generative-evaluation-prdc 8 | prdc 9 | Copyright (c) 2020-present NAVER Corp. 10 | MIT license 11 | """ 12 | import numpy as np 13 | import torch 14 | 15 | __all__ = ['compute_prdc'] 16 | 17 | 18 | def compute_pairwise_distance_sklearn(data_x, data_y=None): 19 | """Legacy method to compute pairwise distance (as in original prdc package) 20 | Args: 21 | data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) 22 | data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) 23 | Returns: 24 | numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. 
25 | """ 26 | import sklearn.metrics 27 | 28 | if data_y is None: 29 | data_y = data_x 30 | dists = sklearn.metrics.pairwise_distances( 31 | data_x, data_y, metric='euclidean', n_jobs=8) 32 | return dists 33 | 34 | 35 | def batch_pairwise_distances(U, V): 36 | """Compute pairwise distances between two batches of feature vectors.""" 37 | 38 | # Squared norms of each row in U and V. 39 | # norm_u as a column and norm_v as a row vectors. 40 | norm_u = U.pow(2.0).sum(1, keepdim=True) # shape: (len(U), 1) 41 | norm_v = V.pow(2.0).sum(1, keepdim=True).transpose(0, 1) # shape: (1, len(V)) 42 | 43 | # Pairwise squared Euclidean distances. 44 | D = norm_u + norm_v - 2. * (U @ V.t()) # shape: (len(U), len(V)) 45 | 46 | return D 47 | 48 | 49 | def compute_pairwise_distance(data_x, 50 | data_y=None, 51 | row_batch_size=10000, 52 | col_batch_size=10000, 53 | ): 54 | """ 55 | Args: 56 | data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) 57 | data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) 58 | Returns: 59 | numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. 60 | """ 61 | 62 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 63 | if data_y is None: 64 | data_y = data_x 65 | 66 | n_x = len(data_x) 67 | n_y = len(data_y) 68 | 69 | dists = np.zeros([n_x, n_y], dtype=np.float32) 70 | 71 | for begin1 in range(0, n_x, row_batch_size): 72 | end1 = min(begin1 + row_batch_size, n_x) 73 | row_batch = data_x[begin1:end1] 74 | row_batch = torch.from_numpy(row_batch).to(device) 75 | 76 | for begin2 in range(0, n_y, col_batch_size): 77 | end2 = min(begin2 + col_batch_size, n_y) 78 | col_batch = data_y[begin2:end2] 79 | col_batch = torch.from_numpy(col_batch).to(device) 80 | 81 | # Compute distances between batches. 82 | batch_dist = batch_pairwise_distances(row_batch, col_batch) 83 | dists[begin1:end1, begin2:end2] = batch_dist.cpu().numpy() 84 | 85 | return dists 86 | 87 | 88 | def get_kth_value(unsorted, k, axis=-1): 89 | """ 90 | Args: 91 | unsorted: numpy.ndarray of any dimensionality. 92 | k: int 93 | Returns: 94 | kth values along the designated axis. 95 | """ 96 | indices = np.argpartition(unsorted, k, axis=axis)[..., :k] 97 | k_smallests = np.take_along_axis(unsorted, indices, axis=axis) 98 | kth_values = k_smallests.max(axis=axis) 99 | return kth_values 100 | 101 | 102 | def compute_nearest_neighbour_distances(input_features, nearest_k): 103 | """ 104 | Args: 105 | input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 106 | nearest_k: int 107 | Returns: 108 | Distances to kth nearest neighbours. 109 | """ 110 | distances = compute_pairwise_distance(input_features) 111 | radii = get_kth_value(distances, k=nearest_k + 1, axis=-1) 112 | return radii 113 | 114 | 115 | def compute_prdc(real_features, fake_features, nearest_k): 116 | """ 117 | Computes precision, recall, density, and coverage given two manifolds. 118 | 119 | Args: 120 | real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 121 | fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 122 | nearest_k: int. 123 | Returns: 124 | dict of precision, recall, density, and coverage. 
125 | """ 126 | 127 | print('Num real: {} Num fake: {}' 128 | .format(real_features.shape[0], fake_features.shape[0])) 129 | 130 | real_nearest_neighbour_distances = compute_nearest_neighbour_distances( 131 | real_features, nearest_k) 132 | fake_nearest_neighbour_distances = compute_nearest_neighbour_distances( 133 | fake_features, nearest_k) 134 | distance_real_fake = compute_pairwise_distance( 135 | real_features, fake_features) 136 | 137 | precision = ( 138 | distance_real_fake < 139 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 140 | ).any(axis=0).mean() 141 | 142 | recall = ( 143 | distance_real_fake < 144 | np.expand_dims(fake_nearest_neighbour_distances, axis=0) 145 | ).any(axis=1).mean() 146 | 147 | density = (1. / float(nearest_k)) * ( 148 | distance_real_fake < 149 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 150 | ).sum(axis=0).mean() 151 | 152 | coverage = ( 153 | distance_real_fake.min(axis=1) < 154 | real_nearest_neighbour_distances 155 | ).mean() 156 | 157 | return dict(precision=precision, recall=recall, 158 | density=density, coverage=coverage) 159 | -------------------------------------------------------------------------------- /src/models/stage2/diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 
80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /src/models/stage2/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 
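
        Example (an illustrative sketch, not part of the original code; it assumes
        the upstream guided-diffusion training loop, where training_losses()
        returns a dict containing a per-sample "loss" tensor):

            t, weights = sampler.sample(batch_size, device)
            losses = diffusion.training_losses(model, x_start, t)["loss"]
            if isinstance(sampler, LossAwareSampler):
                sampler.update_with_local_losses(t, losses.detach())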
82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /src/utils/structural_losses/approxmatch.cu: -------------------------------------------------------------------------------- 1 | //n<=4096, m<=1024 2 | __global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match){ 3 | const int MaxN=4096,MaxM=1024; 4 | __shared__ float remainL[MaxN],remainR[MaxM],ratioR[MaxM],ratioL[MaxN]; 5 | __shared__ int listR[MaxM],lc; 6 | float multiL,multiR; 7 | if (n>=m){ 8 | multiL=1; 9 | multiR=n/m; 10 | }else{ 11 | multiL=m/n; 12 | multiR=1; 13 | } 14 | for (int i=blockIdx.x;i=-2;j--){ 23 | float level=-powf(4.0f,j); 24 | if (j==-2){ 25 | level=0; 26 | } 27 | if (threadIdx.x==0){ 28 | lc=0; 29 | for (int k=0;k0) 31 | listR[lc++]=k; 32 | } 33 | __syncthreads(); 34 | int _lc=lc; 35 | for (int k=threadIdx.x;k>>(b,n,m,xyz1,xyz2,match); 94 | } 95 | __global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ 96 | __shared__ float allsum[512]; 97 | const int Block=256; 98 | __shared__ float buf[Block*3]; 99 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); 138 | } 139 | __global__ void matchcostgrad(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * grad2){ 140 | __shared__ float sum_grad[256*3]; 141 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); 182 | } 183 | 184 | -------------------------------------------------------------------------------- /src/models/stage1/discriminator.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import lpips 11 | 12 | from .CIPS.Discriminators import Discriminator 13 | from .modules.utils import weights_init 14 | 15 | import pdb 16 | 17 | def adopt_weight(weight, global_step, threshold=0, value=0.): 18 | if global_step < threshold: 19 | weight = value 20 | return weight 21 | 22 | 23 | def hinge_d_loss(logits_real, logits_fake): 24 | loss_real = torch.mean(F.relu(1. - logits_real)) 25 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 26 | d_loss = 0.5 * (loss_real + loss_fake) 27 | return d_loss 28 | 29 | 30 | def vanilla_d_loss(logits_real, logits_fake): 31 | d_loss = 0.5 * ( 32 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 33 | torch.mean(torch.nn.functional.softplus(logits_fake))) 34 | return d_loss 35 | 36 | 37 | # def d_r1_loss(real_pred, real_img): 38 | # grad_real, = autograd.grad( 39 | # outputs=real_pred.sum(), inputs=real_img, create_graph=True 40 | # ) 41 | # grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 42 | 43 | # return grad_penalty 44 | 45 | 46 | class VQLPIPSWithDiscriminator(nn.Module): 47 | def __init__(self, 48 | codebook_weight=0.0, 49 | pixelloss_weight=1.0, 50 | perceptual_weight=0.0, 51 | disc_weight=1.0, 52 | disc_start=0, disc_in_channels=3, disc_factor=1.0, 53 | disc_loss="vanilla"): 54 | super().__init__() 55 | assert disc_loss in ["hinge", "vanilla"] 56 | 57 | self.pixel_weight = pixelloss_weight 58 | 59 | self.perceptual_weight = perceptual_weight 60 | if self.perceptual_weight > 0.0: 61 | self.perceptual_loss = lpips.LPIPS(net='vgg', spatial=True) 62 | 63 | self.codebook_weight = codebook_weight 64 | 65 | self.discriminator = Discriminator(size=256, input_size=disc_in_channels) 66 | self.discriminator_iter_start = disc_start 67 | if disc_loss == "hinge": 68 | self.disc_loss = hinge_d_loss 69 | elif disc_loss == "vanilla": 70 | self.disc_loss = vanilla_d_loss 71 | else: 72 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 73 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 74 | self.disc_factor = disc_factor 75 | self.discriminator_weight = disc_weight 76 | 77 | 78 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 79 | global_step, last_layer=None, split="train"): 80 | _inputs, _recons = inputs.contiguous(), reconstructions.contiguous() 81 | 82 | log = dict() 83 | 84 | if self.discriminator_weight == 0.0 or (self.discriminator_weight > 0.0 and optimizer_idx == 0): 85 | rec_loss = F.mse_loss(_inputs, _recons) 86 | log["{}/rec_loss".format(split)] = rec_loss.detach().mean() 87 | if self.pixel_weight > 0.0: 88 | loss = rec_loss 89 | else: 90 | loss = 0 91 | 92 | if self.perceptual_weight > 0.0: 93 | p_loss = self.perceptual_loss(_inputs, _recons).mean() 94 | log["{}/p_loss".format(split)] = p_loss.detach() 95 | loss += self.perceptual_weight * p_loss 96 | 97 | if codebook_loss is not None: 98 | log["{}/quant_loss".format(split)] = codebook_loss.detach().mean() 99 | 100 | # now the GAN part 101 | if self.discriminator_weight > 0.0: 102 | if optimizer_idx == 0: 103 | g_loss, logits_fake = self.forward_logits_fake(reconstructions) 104 | d_weight = self.discriminator_weight 105 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 106 | 107 | loss += d_weight * disc_factor * g_loss 108 | if codebook_loss is not None: 109 | loss += self.codebook_weight * codebook_loss.mean() 110 | 111 | log["{}/total_loss".format(split)] = loss.clone().detach().mean() 112 | log["{}/d_weight".format(split)] = d_weight 113 | log["{}/disc_factor".format(split)] = torch.tensor(disc_factor) 114 | log["{}/g_loss".format(split)] = g_loss.detach().mean() 115 | 116 | return loss, log 117 | 118 | if optimizer_idx == 1: 119 | d_loss, logits_real, logits_fake = self.forward_logits_real_fake(inputs, 120 | reconstructions, 121 | global_step) 122 | 123 | log["{}/disc_loss".format(split)] = d_loss.clone().detach().mean() 124 | log["{}/logits_real".format(split)] = 
logits_real.detach().mean() 125 | log["{}/logits_fake".format(split)] = logits_fake.detach().mean() 126 | 127 | return d_loss, log 128 | else: 129 | log["{}/total_loss".format(split)] = loss.clone().detach().mean() 130 | return loss, log 131 | 132 | def forward_logits_fake(self, reconstructions): 133 | # generator update 134 | logits_fake = self.discriminator(reconstructions.contiguous()) 135 | 136 | g_loss = -torch.mean(logits_fake) 137 | return g_loss, logits_fake 138 | 139 | def forward_logits_real_fake(self, inputs, reconstructions, global_step): 140 | # second pass for discriminator update 141 | logits_real = self.discriminator(inputs.contiguous().detach()) 142 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 143 | 144 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 145 | 146 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 147 | 148 | return d_loss, logits_real, logits_fake 149 | -------------------------------------------------------------------------------- /src/utils/config1.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | from datetime import datetime 9 | from typing import Optional, List, Tuple 10 | from dataclasses import dataclass, field 11 | from omegaconf import OmegaConf 12 | 13 | @dataclass 14 | class SimpleNeRFConfig: 15 | epoch_for_full_rendering: int = 0 16 | subsampled_views: int = 16 17 | subsampled_pixels: int = 512 18 | num_samples_per_ray: int = 31 19 | near: float = 0.8 20 | far: float = 1.8 21 | randomized: bool = False 22 | prob_mask_sampling: float = 0.0 23 | rgb_activation: str = 'sigmoid' 24 | density_activation: str = 'relu' 25 | lindisp: bool = False 26 | white_bkgd: bool = True 27 | chuncking_unit: int = -1 28 | functa_rendering: bool = False 29 | rendering_type: str = 'baseline' 30 | debug: bool = False 31 | resolution: int = 0 32 | H: int = 0 33 | W: int = 0 34 | 35 | 36 | @dataclass 37 | class DataConfig: 38 | dataset: Optional[str] = None 39 | resolution: int = 256 40 | num_subsampling: int = 0 41 | pred_type: str = 'image' 42 | repeat_sampling: bool = False 43 | hparams_nerf: Optional[SimpleNeRFConfig] = SimpleNeRFConfig() 44 | 45 | 46 | @dataclass 47 | class Stage1EncoderHparams: 48 | #type: str = "ViT-B/32" 49 | type: str = "resnet50-conv" 50 | num_embed_conv: int = 1 51 | dim_embed_conv: int = 0 52 | num_groups: int = 16 53 | type_init_norm: str = '' 54 | num_flatten_linear: int = 1 55 | dim_flatten_linear: int = 0 56 | num_residual_embed: int = 0 57 | dim_residual_embed: int = 0 58 | 59 | 60 | @dataclass 61 | class Stage1GeneratorHparams: 62 | type: str = 'CIPSskip' 63 | size: int = 256 # is this resolution 64 | hidden_size: int = 512 65 | n_mlp: int = 8 66 | style_dim: int = 512 67 | lr_mlp: float = 0.01 68 | activation: Optional[str] = None 69 | channel_multiplier: int = 2 70 | 71 | 72 | @dataclass 73 | class Stage1HparamsDisc: 74 | disc_in_channels: int = 3 75 | disc_start: int = 0 76 | disc_weight: float = 0.0 77 | codebook_weight: float = 0.0 78 | pixelloss_weight: float = 1.0 79 | perceptual_weight: float = 0.0 80 | 81 | 82 | @dataclass 83 | class 
Stage1Config: 84 | hparams_enc: Stage1EncoderHparams = Stage1EncoderHparams() 85 | hparams_dec: Stage1GeneratorHparams = Stage1GeneratorHparams() 86 | hparams_disc: Optional[Stage1HparamsDisc] = Stage1HparamsDisc() 87 | 88 | 89 | @dataclass 90 | class WarmupConfig: 91 | multiplier: float = 1.0 92 | warmup_epoch: float = 0.0 93 | buffer_epoch: float = 0.0 94 | min_lr: float = 0.0 95 | mode: str = 'fix' 96 | start_from_zero: bool = True 97 | 98 | 99 | @dataclass 100 | class OptConfig: 101 | opt_type: str = 'adam' 102 | betas: Optional[Tuple[float]] = None 103 | base_lr: float = 1e-4 104 | g_ratio: float = 1.0 105 | d_ratio: float = 1.0 106 | use_amp: bool = True 107 | grad_clip_norm: Optional[float] = 1.0 108 | max_steps: Optional[int] = None 109 | steps_per_epoch: Optional[int] = None 110 | warmup_config: WarmupConfig = WarmupConfig() 111 | 112 | 113 | @dataclass 114 | class ExpConfig: 115 | epochs: int = 100 116 | save_ckpt_freq: int = 2 117 | test_freq: int = 1 118 | img_logging_freq: int = 5000 119 | fp16_grad_comp: bool = False 120 | 121 | 122 | @dataclass 123 | class DefaultConfig: 124 | dataset: DataConfig = DataConfig() 125 | stage1: Stage1Config = Stage1Config() 126 | optimizer: OptConfig = OptConfig() 127 | experiment: ExpConfig = ExpConfig() 128 | 129 | 130 | @dataclass 131 | class LatentModulatedSIRENConfig: 132 | width: int = 256 133 | depth: int = 5 134 | out_channels: int = 3 135 | latent_dim: int = 128 136 | latent_vector_type: str = 'instance' 137 | layer_sizes: Tuple[int, ...] = () 138 | w0: float = 30. 139 | modulate_scale: bool = False 140 | modulate_shift: bool = True 141 | latent_init_scale: float = 0.01 142 | use_meta_sgd: bool = True 143 | meta_sgd_init_range: Tuple[float, float] = (0.005, 0.1) 144 | meta_sgd_clip_range: Tuple[float, float] = (0., 1.) 145 | 146 | 147 | @dataclass 148 | class LatentMixtureINRConfig: 149 | width: int = 64 150 | depth: int = 4 151 | in_channels: int = 2 152 | out_channels: int = 3 153 | image_resolution: int = 64 154 | w0: float = 30. 155 | k_mixtures: int = 64 156 | mixture_type: str = 'layerwise' 157 | embedding_type: str = 'none' 158 | use_latent_embedding: bool = True 159 | std_latent: float = 0.0 160 | #normalize_mixture: bool = False 161 | latent_channels: int = 256 162 | latent_init_scale: Tuple[float, float] = (0.95, 1.05) 163 | use_meta_sgd: bool = True 164 | meta_sgd_init_range: Tuple[float, float] = (0.005, 0.1) 165 | meta_sgd_clip_range: Tuple[float, float] = (0., 1.) 
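    # Following the Functa convention, meta_sgd_init_range gives the initialization
    # range of the learned per-parameter inner-loop learning rates and
    # meta_sgd_clip_range bounds them when use_meta_sgd is True.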
166 | #use_residual_param: bool = False 167 | #type_lipschitz: str = 'none' 168 | #lipschitz_const: float = 1.0 169 | init_path: str = '' 170 | pred_type: str = 'none' 171 | outermost_linear: Optional[bool] = None 172 | 173 | 174 | @dataclass 175 | class MetaOptConfig: 176 | use_amp: bool = False 177 | optim_outer: str = 'adamw' 178 | betas: Optional[Tuple[float]] = None 179 | lr_outer: float = 3e-6 180 | lr_inner: float = 1e-2 181 | num_steps: int = 3 182 | weight_decay_outer: float = 0.0 183 | weight_decay_inner: float = 0.0 184 | sparsity_inner: float = 0.0 185 | sparsity_outer: float = 0.0 186 | 187 | clip_grad: bool = False 188 | grad_clip_norm: float = 4.0 189 | 190 | # num_epochs: int = 50 191 | # save_freq: int = 1 192 | double_precision: bool = False 193 | max_steps: Optional[int] = None 194 | use_lr_scheduler: bool = True 195 | min_lr_outer: float = 0.0 196 | first_order: bool = False 197 | 198 | steps_per_epoch: Optional[int] = None 199 | 200 | 201 | @dataclass 202 | class DefaultMetaINR: 203 | dataset: DataConfig = DataConfig() 204 | model_type: str = 'mixtureinr' 205 | hparams_inr: LatentMixtureINRConfig = LatentMixtureINRConfig() 206 | optimizer: MetaOptConfig = MetaOptConfig() 207 | experiment: ExpConfig = ExpConfig() 208 | 209 | 210 | def update_config(cfg_base, cfg_new): 211 | cfg_update = OmegaConf.merge(cfg_base, cfg_new) 212 | return cfg_update 213 | 214 | 215 | def build_config(args): 216 | if args.stage1_type == 'asym-diff': 217 | cfg_base = OmegaConf.structured(DefaultConfig) 218 | elif args.stage1_type == 'parammix': 219 | cfg_base = OmegaConf.structured(DefaultMetaINR) 220 | 221 | if args.eval: 222 | cfg_new = OmegaConf.load(os.path.join(args.result_path, 'config.yaml')) 223 | cfg_update = update_config(cfg_base, cfg_new) 224 | result_path = args.result_path 225 | else: 226 | cfg_new = OmegaConf.load(args.config_path) 227 | cfg_update = update_config(cfg_base, cfg_new) 228 | now = datetime.now().strftime('%d%m%Y_%H%M%S') 229 | result_path = os.path.join(args.result_path, 230 | os.path.basename(args.config_path).split('.')[0], 231 | now) 232 | return cfg_update, result_path -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import math 4 | import logging 5 | 6 | from scipy import linalg 7 | import numpy as np 8 | 9 | import torch 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | 20 | def cond_mkdir(path): 21 | if not os.path.exists(path): 22 | os.makedirs(path) 23 | 24 | 25 | def logging_model_size(model, logger): 26 | if logger is None: 27 | return 28 | logger.info( 29 | "[OPTION: ALL] #params: %.4fM", sum(p.numel() for p in model.parameters()) / 1e6 30 | ) 31 | logger.info( 32 | "[OPTION: Trainable] #params: %.4fM", sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 33 | ) 34 | 35 | 36 | def image_mse(mask, model_output, gt): 37 | if mask is None: 38 | return {'img_loss': ((model_output['model_out'] - gt['img']) ** 2).mean()} 39 | else: 40 | return {'img_loss': (mask * (model_output['model_out'] - gt['img']) ** 2).mean()} 41 | 42 | def compute_psnr(signal, gt): 43 | mse = max(float(torch.mean((signal-gt)**2)), 1e-8) 44 | psnr = float(-10 * math.log10(mse)) 45 | return psnr 46 | 47 | # 
https://github.com/yilundu/gem/blob/main/experiment_scripts/gen_imnet_autodecoder.py#L32 48 | def sample_points_triangle(vertices, triangles, num_of_points): 49 | epsilon = 1e-6 50 | triangle_area_list = np.zeros([len(triangles)],np.float32) 51 | triangle_normal_list = np.zeros([len(triangles),3],np.float32) 52 | for i in range(len(triangles)): 53 | #area = |u x v|/2 = |u||v|sin(uv)/2 54 | a,b,c = vertices[triangles[i,1]]-vertices[triangles[i,0]] 55 | x,y,z = vertices[triangles[i,2]]-vertices[triangles[i,0]] 56 | ti = b*z-c*y 57 | tj = c*x-a*z 58 | tk = a*y-b*x 59 | area2 = math.sqrt(ti*ti+tj*tj+tk*tk) 60 | if area2100: 84 | print("infinite loop here!") 85 | return point_normal_list 86 | for i in range(len(triangle_index_list)): 87 | if count>=num_of_points: break 88 | dxb = triangle_index_list[i] 89 | prob = sample_prob_list[dxb] 90 | prob_i = int(prob) 91 | prob_f = prob-prob_i 92 | if np.random.random()=1: 103 | u_x = 1-u_x 104 | v_y = 1-v_y 105 | ppp = u*u_x+v*v_y+base 106 | 107 | point_normal_list[count,:3] = ppp 108 | point_normal_list[count,3:] = normal_direction 109 | count += 1 110 | if count>=num_of_points: break 111 | 112 | return point_normal_list 113 | 114 | 115 | # https://github.com/yilundu/gem/blob/main/experiment_scripts/gen_imnet_autodecoder.py#L99 116 | def write_ply_triangle(name, vertices, triangles): 117 | fout = open(name, 'w') 118 | fout.write("ply\n") 119 | fout.write("format ascii 1.0\n") 120 | fout.write("element vertex "+str(len(vertices))+"\n") 121 | fout.write("property float x\n") 122 | fout.write("property float y\n") 123 | fout.write("property float z\n") 124 | fout.write("element face "+str(len(triangles))+"\n") 125 | fout.write("property list uchar int vertex_index\n") 126 | fout.write("end_header\n") 127 | for ii in range(len(vertices)): 128 | fout.write(str(vertices[ii,0])+" "+str(vertices[ii,1])+" "+str(vertices[ii,2])+"\n") 129 | for ii in range(len(triangles)): 130 | fout.write("3 "+str(triangles[ii,0])+" "+str(triangles[ii,1])+" "+str(triangles[ii,2])+"\n") 131 | fout.close() 132 | 133 | def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 134 | """Numpy implementation of the Frechet Distance. 135 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 136 | and X_2 ~ N(mu_2, C_2) is 137 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 138 | 139 | Stable version by Dougal J. Sutherland. 140 | 141 | Params: 142 | -- mu1 : Numpy array containing the activations of a layer of the 143 | inception net (like returned by the function 'get_predictions') 144 | for generated samples. 145 | -- mu2 : The sample mean over activations, precalculated on an 146 | representative data set. 147 | -- sigma1: The covariance matrix over activations for generated samples. 148 | -- sigma2: The covariance matrix over activations, precalculated on an 149 | representative data set. 150 | 151 | Returns: 152 | -- : The Frechet Distance. 
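
    Example (an illustrative sketch, not part of the original code; it pairs this
    function with mean_covar_numpy defined at the bottom of this file):

        mu1, sigma1 = mean_covar_numpy(real_features)
        mu2, sigma2 = mean_covar_numpy(fake_features)
        fid = frechet_distance(mu1, sigma1, mu2, sigma2)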
153 | """ 154 | 155 | mu1 = np.atleast_1d(mu1) 156 | mu2 = np.atleast_1d(mu2) 157 | 158 | sigma1 = np.atleast_2d(sigma1) 159 | sigma2 = np.atleast_2d(sigma2) 160 | 161 | assert mu1.shape == mu2.shape, \ 162 | 'Training and test mean vectors have different lengths' 163 | assert sigma1.shape == sigma2.shape, \ 164 | 'Training and test covariances have different dimensions' 165 | 166 | diff = mu1 - mu2 167 | 168 | # Product might be almost singular 169 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 170 | if not np.isfinite(covmean).all(): 171 | msg = ('fid calculation produces singular product; ' 172 | 'adding %s to diagonal of cov estimates') % eps 173 | logging.warning(msg) 174 | offset = np.eye(sigma1.shape[0]) * eps 175 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 176 | 177 | # Numerical error might give slight imaginary component 178 | if np.iscomplexobj(covmean): 179 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 180 | m = np.max(np.abs(covmean.imag)) 181 | raise ValueError('Imaginary component {}'.format(m)) 182 | covmean = covmean.real 183 | 184 | tr_covmean = np.trace(covmean) 185 | 186 | return (diff.dot(diff) + np.trace(sigma1) + 187 | np.trace(sigma2) - 2 * tr_covmean) 188 | 189 | def mean_covar_numpy(xs): 190 | if isinstance(xs, torch.Tensor): 191 | xs = xs.cpu().numpy() 192 | return np.mean(xs, axis=0), np.cov(xs, rowvar=False) -------------------------------------------------------------------------------- /main_stage2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from torch.utils.data.dataloader import DataLoader 5 | import torchvision 6 | 7 | import pytorch_lightning as pl 8 | from torch.distributed.algorithms.ddp_comm_hooks import default_hooks 9 | from pytorch_lightning.plugins import DDPPlugin 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | 13 | from src.datasets import CelebAHQ, ShapeNet, SRNDatasets, INRWeightWrapper 14 | from src.utils.logger import LatentDDPMLogger 15 | from src.models import build_model_stage2 16 | from src.utils.config2 import build_config 17 | from src.utils.utils import logging_model_size 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument('-c', '--config-path', type=str, default=None, required=True) 23 | parser.add_argument('-r', '--result-path', type=str, default=None, required=True) 24 | parser.add_argument('--checkpoint_path', type=str, default='') 25 | parser.add_argument('--stage1_epoch', type=int, default=-1) 26 | parser.add_argument('--eval', default=False) 27 | 28 | parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training') 29 | parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training') 30 | parser.add_argument('--node_rank', default=-1, type=int, help='node rank for distributed training') 31 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 32 | parser.add_argument('--n-nodes', type=int, default=1) 33 | parser.add_argument('--n-gpus', type=int, default=1) 34 | parser.add_argument('--local_batch_size', type=int, default=64) 35 | parser.add_argument('--valid_batch_size', type=int, default=64) 36 | parser.add_argument('--total_batch_size', type=int, default=64) 37 | parser.add_argument('--seed', type=int, default=0) 38 | parser.add_argument('--dataset_root', type=str, 
default='datasets') 39 | parser.add_argument('--reduce_sample', type=int, default=0) 40 | parser.add_argument('--context_tag', type=str, default='set') 41 | 42 | args = parser.parse_args() 43 | 44 | def setup_callbacks(config, result_path): 45 | # Setup callbacks 46 | ckpt_path = os.path.join(result_path, 'ckpt') 47 | log_path = os.path.join(result_path, 'log') 48 | 49 | checkpoint_callback = ModelCheckpoint( 50 | dirpath=ckpt_path, 51 | filename=config.dataset.dataset+"-lddim{epoch:02d}", 52 | every_n_epochs=config.experiment.save_ckpt_freq, 53 | save_top_k=-1, 54 | save_weights_only=True, 55 | save_last=False # do not save the last 56 | ) 57 | logger_tb = TensorBoardLogger(log_path, name="latent-ddpm") 58 | logger_cu = LatentDDPMLogger(config, result_path) 59 | return checkpoint_callback, logger_tb, logger_cu 60 | 61 | 62 | if __name__ == '__main__': 63 | pl.seed_everything(args.seed) 64 | 65 | # Setup 66 | config, result_path = build_config(args) 67 | ckpt_callback, logger_tb, logger_cu = setup_callbacks(config, result_path) 68 | 69 | if len(args.dataset_root) > 0: 70 | root_path = args.dataset_root 71 | else: 72 | root_path = None 73 | 74 | # Build data modules 75 | dname = config.dataset.dataset.lower() 76 | 77 | if 'celeba' in dname: 78 | data_res = config.dataset.image_resolution 79 | downsampled = False 80 | tf_dataset = 'tf' in dname.lower() 81 | train_dataset = CelebAHQ(split='train', downsampled=downsampled, resolution=data_res, dataset_root=root_path, tf_dataset=tf_dataset) 82 | valid_dataset = CelebAHQ(split='test', downsampled=downsampled, resolution=data_res, dataset_root=root_path, tf_dataset=tf_dataset) 83 | resampling = 'bicubic' 84 | 85 | elif 'cifar10' in dname: 86 | data_res = config.dataset.image_resolution 87 | train_dataset = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True) 88 | valid_dataset = torchvision.datasets.CIFAR10(root=root_path, train=False, download=True) 89 | elif 'shapenet' in dname: 90 | train_dataset = ShapeNet(split='train', sampling=4096, dataset_root=root_path) 91 | valid_dataset = ShapeNet(split='test', sampling=4096, dataset_root=root_path) 92 | elif 'srncars' in dname: 93 | train_dataset = None 94 | valid_dataset = None 95 | 96 | else: 97 | raise ValueError() 98 | 99 | checkpoint_path = args.checkpoint_path 100 | print(checkpoint_path) 101 | input_res = config.stage2.hparams_inr.image_resolution 102 | train_dataset = INRWeightWrapper(train_dataset, 103 | sidelength=input_res, 104 | checkpoint_path=checkpoint_path, 105 | checkpoint_step=args.stage1_epoch, 106 | reduce_sample=args.reduce_sample, 107 | feed_type=config.stage2.feat_type, 108 | context_tag=args.context_tag, 109 | istuple='cifar10' in dname) 110 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.local_batch_size, pin_memory=True, num_workers=8) 111 | 112 | # # Ignore validation dataset because fixed training epoch. 
113 | # valid_checkpoint_path = checkpoint_path + '-test' 114 | # valid_dataset = INRWeightWrapper(valid_dataset, 115 | # sidelength=input_res, 116 | # checkpoint_path=valid_checkpoint_path, 117 | # checkpoint_step=args.ckpt_step, 118 | # reduce_sample=args.reduce_sample, 119 | # feed_type=config.stage2.feat_type) 120 | # valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=args.valid_batch_size, pin_memory=True, num_workers=8) 121 | # valid_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.local_batch_size, pin_memory=True, num_workers=4) 122 | valid_dataloader = None 123 | 124 | # Calculate how many batches are accumulated 125 | total_gpus = args.n_gpus * args.n_nodes 126 | assert args.total_batch_size % total_gpus == 0 127 | grad_accm_steps = args.total_batch_size // (args.local_batch_size * total_gpus) 128 | config.optimizer.max_steps = len(train_dataset) // args.total_batch_size * config.experiment.epochs 129 | config.optimizer.steps_per_epoch = len(train_dataset) // args.total_batch_size 130 | config.stage2.hparams_metainr.init_path = os.path.join(args.checkpoint_path, 'metainits', f'epoch{args.stage1_epoch}.pth') 131 | 132 | # Build a model 133 | model = build_model_stage2(cfg_stage2=config.stage2, cfg_opt=config.optimizer, affine=train_dataset.affine) 134 | logging_model_size(model, logger_cu._logger) 135 | 136 | # Build a trainer 137 | trainer = pl.Trainer(max_epochs=config.experiment.epochs, 138 | accumulate_grad_batches=grad_accm_steps, 139 | gradient_clip_val=config.optimizer.grad_clip_norm, 140 | precision=16 if config.optimizer.use_amp else 32, 141 | callbacks=[ckpt_callback, logger_cu], 142 | accelerator="gpu", 143 | num_nodes=args.n_nodes, 144 | devices=args.n_gpus, 145 | strategy=DDPPlugin(ddp_comm_hook=default_hooks.fp16_compress_hook) if 146 | config.experiment.fp16_grad_comp else "ddp", 147 | logger=logger_tb, 148 | log_every_n_steps=10) 149 | 150 | trainer.fit(model, train_dataloader, valid_dataloader) 151 | -------------------------------------------------------------------------------- /src/utils/structural_losses/approxmatch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | void approxmatch_cpu(int b,int n,int m,float * xyz1,float * xyz2,float * match){ 18 | for (int i=0;i saturatedl(n,double(factorl)),saturatedr(m,double(factorr)); 22 | vector weight(n*m); 23 | for (int j=0;j=-2;j--){ 26 | //printf("i=%d j=%d\n",i,j); 27 | double level=-powf(4.0,j); 28 | if (j==-2) 29 | level=0; 30 | for (int k=0;k ss(m,1e-9); 42 | for (int k=0;k ss2(m,0); 59 | for (int k=0;k1){ 154 | printf("bad i=%d j=%d k=%d u=%f\n",i,j,k,u); 155 | } 156 | s+=u; 157 | } 158 | if (s<0.999 || s>1.001){ 159 | printf("bad i=%d j=%d s=%f\n",i,j,s); 160 | } 161 | } 162 | for (int j=0;j4.001){ 168 | printf("bad i=%d j=%d s=%f\n",i,j,s); 169 | } 170 | } 171 | }*/ 172 | /*for (int j=0;j1e-3) 222 | if (fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k]))>1e-2){ 223 | printf("i %d j %d k %d m %f %f\n",i,j,k,match[i*n*m+k*n+j],match_cpu[i*n*m+j*m+k]); 224 | flag=false; 225 | break; 226 | } 227 | //emax=max(emax,fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k]))); 228 | 
emax+=fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k])); 229 | } 230 | } 231 | printf("emax_match=%f\n",emax/2/n/m); 232 | emax=0; 233 | for (int i=0;i<2;i++) 234 | emax+=fabs(double(cost[i]-cost_cpu[i])); 235 | printf("emax_cost=%f\n",emax/2); 236 | emax=0; 237 | for (int i=0;i<2*m*3;i++) 238 | emax+=fabs(double(grad[i]-grad_cpu[i])); 239 | //for (int i=0;i<3*m;i++){ 240 | //if (grad[i]!=0) 241 | //printf("i %d %f %f\n",i,grad[i],grad_cpu[i]); 242 | //} 243 | printf("emax_grad=%f\n",emax/(2*m*3)); 244 | 245 | cudaFree(xyz1_g); 246 | cudaFree(xyz2_g); 247 | cudaFree(match_g); 248 | cudaFree(cost_g); 249 | cudaFree(grad_g); 250 | 251 | return 0; 252 | } 253 | 254 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | import logging 9 | import math 10 | 11 | import torch 12 | import torchvision 13 | import torch.distributed as dist 14 | 15 | from pytorch_lightning.callbacks import Callback 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | from omegaconf import OmegaConf 18 | 19 | import pdb 20 | 21 | class DefaultLogger(Callback): 22 | def __init__(self, config, result_path, is_eval=False): 23 | super().__init__() 24 | 25 | self._config = config 26 | self._result_path = result_path 27 | self._logger = self._init_logger(is_eval=is_eval) 28 | 29 | @rank_zero_only 30 | def _init_logger(self, is_eval=False): 31 | self.save_config() 32 | logger = logging.getLogger(__name__) 33 | logger.setLevel(logging.INFO) 34 | # create console handler and set level to info 35 | ch = logging.FileHandler(os.path.join(self._result_path, 'eval.log' if is_eval else 'train.log')) 36 | ch.setLevel(logging.INFO) 37 | ch.setFormatter(logging.Formatter( 38 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 39 | datefmt="%m/%d/%Y %H:%M:%S") 40 | ) 41 | # add ch to logger 42 | logger.addHandler(ch) 43 | logger.info(f"Logs will be recorded in {self._result_path}...") 44 | return logger 45 | 46 | @rank_zero_only 47 | def save_config(self): 48 | if not os.path.exists(self._result_path): 49 | os.makedirs(self._result_path) 50 | with open(os.path.join(self._result_path, 'config.yaml'), 'w') as fp: 51 | OmegaConf.save(config=self._config, f=fp) 52 | 53 | @rank_zero_only 54 | def log_metrics(self, trainer, split='valid'): 55 | metrics = [] 56 | for k, v in trainer.callback_metrics.items(): 57 | if split == 'valid': 58 | if k.startswith('valid'): 59 | k = k.split('/')[-1].strip() 60 | metrics.append((k, v)) 61 | elif split == 'test': 62 | if k.startswith('test'): 63 | k = k.split('/')[-1].strip() 64 | metrics.append((k, v)) 65 | else: 66 | if k.startswith('train') and k.endswith('epoch'): 67 | k = k.split('/')[-1].strip()[:-6] 68 | metrics.append((k, v)) 69 | metrics = sorted(metrics, key=lambda x: x[0]) 70 | line = ','.join([f" {metric[0]}:{metric[1].item():.4f}" for metric in metrics]) 71 | line = f'EPOCH:{trainer.current_epoch}, {split.upper()}\t' + line 72 | self._logger.info(line) 73 | 74 | def on_train_epoch_end(self, trainer, pl_module): 75 | if dist.get_rank() == 0: 76 | self.log_metrics(trainer, split='train') 
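# Note: log_metrics above assumes PyTorch Lightning-style metric keys such as 'train/loss_epoch' or 'valid/psnr';
# for a train key it keeps the part after '/' and drops the trailing '_epoch' (6 characters), so 'train/loss_epoch' is logged as 'loss'.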
77 | 78 | def on_validation_epoch_end(self, trainer, pl_module): 79 | if dist.get_rank() == 0: 80 | self.log_metrics(trainer, split='valid') 81 | 82 | class AsymmetricAutoEncoderLogger(DefaultLogger): 83 | def __init__(self, config, result_path, is_eval=False): 84 | super().__init__(config, result_path, is_eval) 85 | 86 | @rank_zero_only 87 | def log_img(self, pl_module, batch, global_step, split="train"): 88 | with torch.no_grad(): 89 | images, _ = batch 90 | images = images.cpu() 91 | 92 | recons = recons.cpu() 93 | 94 | grid_org = (torchvision.utils.make_grid(images, nrow=4) + 1.0) / 2.0 95 | grid_rec = (torchvision.utils.make_grid(recons, nrow=4) + 1.0) / 2.0 96 | grid_rec = torch. clip(grid_rec, min=0, max=1) 97 | 98 | pl_module.logger.experiment.add_image(f"images_org/{split}", grid_org, global_step=global_step) 99 | 100 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): 101 | if hasattr(pl_module.discriminator, 'perceptual_loss'): 102 | pl_module.discriminator.perceptual_loss.eval() 103 | pl_module.generator.encoder.eval() 104 | 105 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 106 | if pl_module._num_opt_steps % self._config.experiment.img_logging_freq == 0: 107 | pl_module.eval() 108 | self.log_img(pl_module, batch, global_step=pl_module._num_opt_steps, split="train") 109 | pl_module.train() 110 | 111 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 112 | if batch_idx == 0: 113 | pl_module.eval() 114 | self.log_img(pl_module, batch, global_step=trainer.current_epoch, split="valid") 115 | 116 | 117 | class MetaINRLogger(DefaultLogger): 118 | def __init__(self, config, result_path, is_eval=False): 119 | super().__init__(config, result_path, is_eval) 120 | 121 | class LatentDDPMLogger(DefaultLogger): 122 | def __init__(self, config, result_path, is_eval=False): 123 | super().__init__(config, result_path, is_eval) 124 | 125 | @rank_zero_only 126 | def log_img(self, pl_module, batch, current_epoch, split="train"): 127 | with torch.no_grad(): 128 | inputs, gts = batch 129 | images = gts['img'] 130 | L = int(math.sqrt(images.size(1))) 131 | images = images.view(-1, L, L, 3).permute(0,3,1,2).cpu() 132 | pl_module.logger.experiment.add_image(f"images_org/{split}", grid_org, global_step=current_epoch) 133 | 134 | @rank_zero_only 135 | def log_sample(self, pl_module, current_epoch, split="train"): 136 | with torch.no_grad(): 137 | if pl_module.feat_type == 'weight': 138 | samples = pl_module.sample(16, pl_module._diffusion_kwargs.timestep_respacing, resolution=32) 139 | samples = torch.clamp((torchvision.utils.make_grid(samples, nrow=4) + 1.0) / 2.0, 0, 1) 140 | pl_module.logger.experiment.add_image(f"samples/32px/{split}", samples, global_step=current_epoch) 141 | 142 | samples = pl_module.sample(16, pl_module._diffusion_kwargs.timestep_respacing, resolution=64) 143 | samples = torch.clamp((torchvision.utils.make_grid(samples, nrow=4) + 1.0) / 2.0, 0, 1) 144 | pl_module.logger.experiment.add_image(f"samples/64px/{split}", samples, global_step=current_epoch) 145 | 146 | samples = pl_module.sample(4, pl_module._diffusion_kwargs.timestep_respacing, resolution=128) 147 | samples = torch.clamp((torchvision.utils.make_grid(samples, nrow=2) + 1.0) / 2.0, 0, 1) 148 | pl_module.logger.experiment.add_image(f"samples/128px/{split}", samples, global_step=current_epoch) 149 | else: 150 | samples = pl_module.sample(16, pl_module._diffusion_kwargs.timestep_respacing) 151 | samples = 
torch.clamp((torchvision.utils.make_grid(samples, nrow=4) + 1.0) / 2.0, 0, 1) 152 | pl_module.logger.experiment.add_image(f"samples/{split}", samples, global_step=current_epoch) 153 | 154 | # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 155 | # if batch_idx == 0: 156 | # self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="train") 157 | 158 | def on_train_epoch_end(self, trainer, pl_module): 159 | self.log_metrics(trainer, split='train') 160 | self.log_sample(pl_module, current_epoch=trainer.current_epoch, split="train") 161 | 162 | def on_validation_epoch_end(self, trainer, pl_module): 163 | self.log_metrics(trainer, split='valid') 164 | -------------------------------------------------------------------------------- /src/models/stage1/modules/quantizer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.distributed as dist_fn 9 | from torch.nn import functional as F 10 | from typing import List, Tuple, Optional 11 | from einops import rearrange 12 | import math 13 | 14 | 15 | class VectorQuantizer(nn.Module): 16 | """ 17 | Simplified VectorQuantizer in the original VQGAN repository 18 | """ 19 | def __init__(self, dim: int, n_embed: int, beta: float) -> None: 20 | super().__init__() 21 | self.n_embed = n_embed 22 | self.dim = dim 23 | self.beta = beta 24 | 25 | self.embedding = nn.Embedding(self.n_embed, self.dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed) 27 | 28 | def forward(self, 29 | z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 30 | z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C] 31 | z_flattened = z.view(-1, self.dim) 32 | 33 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 34 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 35 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 36 | 37 | min_encoding_indices = torch.argmin(d, dim=1) 38 | z_q = self.embedding(min_encoding_indices).view(z.shape) 39 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) 40 | z_q = z + (z_q - z).detach() 41 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 42 | return z_q, loss, min_encoding_indices 43 | 44 | def get_codebook_entry(self, 45 | indices: torch.LongTensor, 46 | shape: Optional[List[int]] = None) -> torch.FloatTensor: 47 | z_q = self.embedding(indices) 48 | if shape is not None: 49 | z_q = z_q.view(shape) 50 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 51 | return z_q 52 | 53 | 54 | class EMAVectorQuantizer(nn.Module): 55 | """ 56 | EMAVectorQuantizer 57 | """ 58 | def __init__(self, 59 | dim: int, 60 | n_embed: int, 61 | beta: float, 62 | decay: float = 0.99, 63 | eps: float = 1e-5, 64 | use_l2_norm: bool = False, 65 | restart_unused_codes: bool = False) -> None: 66 | super().__init__() 67 | self.n_embed = n_embed 68 | self.dim = dim 69 | self.beta = beta 70 | self.decay = decay 71 | self.eps = eps 72 | self.use_l2_norm = use_l2_norm 73 | self.restart_unused_codes = 
restart_unused_codes 74 | self.threshold = 1.0 75 | 76 | embedding = torch.randn(n_embed, dim) 77 | if (self.use_l2_norm): 78 | embedding = F.normalize(embedding, p=2.0, dim=1, eps=1e-6) 79 | self.register_buffer("embedding", embedding) 80 | self.register_buffer("cluster_size", torch.zeros(self.n_embed)) 81 | self.register_buffer("embedding_avg", embedding.clone()) 82 | 83 | @torch.no_grad() 84 | def _tile_with_noise(self, x, target_n): 85 | B, embed_dim = x.shape 86 | n_repeats = (target_n + B - 1) // B 87 | std = x.new_ones(embed_dim) * 0.01 / math.sqrt(embed_dim) 88 | x = x.repeat(n_repeats, 1) 89 | x = x + torch.rand_like(x) * std 90 | return x 91 | 92 | def forward(self, 93 | z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: 94 | z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C] 95 | z_flattened = z.view(-1, self.dim) 96 | if (self.use_l2_norm): 97 | z_flattened = F.normalize(z_flattened, p=2.0, dim=1, eps=1e-6) 98 | 99 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 100 | torch.sum(self.embedding**2, dim=1) - 2 * \ 101 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding, 'n d -> d n')) 102 | 103 | min_encoding_indices = torch.argmin(d, dim=1) 104 | z_q = F.embedding(min_encoding_indices, self.embedding).view(z.shape) 105 | embed_onehot = F.one_hot(min_encoding_indices, self.n_embed).type(z_flattened.dtype) 106 | 107 | if self.training: 108 | embed_onehot_sum = embed_onehot.sum(0) 109 | embed_sum = embed_onehot.transpose(0, 1) @ z_flattened 110 | 111 | dist_fn.all_reduce(embed_onehot_sum, op=dist_fn.ReduceOp.SUM) 112 | dist_fn.all_reduce(embed_sum, op=dist_fn.ReduceOp.SUM) 113 | 114 | self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) 115 | self.embedding_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) 116 | 117 | if self.restart_unused_codes: 118 | n_vectors = z_flattened.size(0) 119 | if n_vectors < self.n_embed: 120 | vectors = self._tile_with_noise(z_flattened, self.n_embed) 121 | else: 122 | vectors = z_flattened 123 | n_vectors = vectors.shape[0] 124 | _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:self.n_embed] 125 | 126 | if dist_fn.is_initialized(): 127 | dist_fn.broadcast(_vectors_random, 0) 128 | 129 | usage = (self.cluster_size.view(-1, 1) >= 1).float() 130 | self.embedding_avg.mul_(usage).add_(_vectors_random * (1-usage)) 131 | self.cluster_size.mul_(usage.view(-1)) 132 | self.cluster_size.add_(torch.ones_like(self.cluster_size) * (1-usage).view(-1)) 133 | 134 | n = self.cluster_size.sum() 135 | cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n 136 | embed_normalized = self.embedding_avg / cluster_size.unsqueeze(1) 137 | 138 | if (self.use_l2_norm): 139 | embed_normalized = F.normalize(embed_normalized, p=2.0, dim=1, eps=1e-6) 140 | 141 | self.embedding.data.copy_(embed_normalized) 142 | 143 | diff = self.beta * torch.mean((z_q.detach() - z) ** 2) 144 | z_q = z + (z_q - z).detach() 145 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 146 | return z_q, diff, min_encoding_indices 147 | 148 | def get_soft_codes(self, 149 | z: torch.FloatTensor, 150 | temp=1.0, 151 | stochastic=False): 152 | 153 | z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C] 154 | z_flattened = z.view(-1, self.dim) 155 | if (self.use_l2_norm): 156 | z_flattened = F.normalize(z_flattened, p=2.0, dim=1, eps=1e-6) 157 | 158 | d = torch.sum(z_flattened ** 2, dim=1, 
keepdim=True) + \ 159 | torch.sum(self.embedding**2, dim=1) - 2 * \ 160 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding, 'n d -> d n')) 161 | 162 | soft_code = F.softmax(-d / temp, dim=1) 163 | 164 | if stochastic: 165 | soft_code_flat = soft_code.reshape(-1, soft_code.shape[-1]) 166 | code = torch.multinomial(soft_code_flat, 1) 167 | code = code.reshape(*soft_code.shape[:-1]) 168 | else: 169 | # min_encoding_indices 170 | code = torch.argmin(d, dim=1) 171 | 172 | z_q = F.embedding(code, self.embedding).view(z.shape) 173 | 174 | diff = self.beta * torch.mean((z_q.detach() - z) ** 2) 175 | z_q = z + (z_q - z).detach() 176 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 177 | return z_q, diff, code, soft_code 178 | 179 | def get_codebook_entry(self, 180 | indices: torch.LongTensor, 181 | shape: Optional[List[int]] = None) -> torch.FloatTensor: 182 | z_q = F.embedding(indices, self.embedding) 183 | if shape is not None: 184 | z_q = z_q.view(shape) 185 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 186 | return z_q 187 | -------------------------------------------------------------------------------- /src/models/stage1/generator.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | from typing import Tuple, List, Optional 10 | from omegaconf import OmegaConf 11 | from einops.layers.torch import Rearrange 12 | import itertools 13 | import torchvision 14 | 15 | import clip 16 | from .CIPS.GeneratorsCIPS import CIPSskip 17 | 18 | import pdb 19 | 20 | class ResMLP(nn.Module): 21 | def __init__(self,nin,nhidden,nblock): 22 | super(ResMLP, self).__init__() 23 | 24 | M = [] 25 | for _ in range(nblock): 26 | M.append(nn.Sequential(nn.Linear(nin, nhidden), nn.LeakyReLU(), nn.Linear(nhidden, nin))) 27 | 28 | self.M = nn.ModuleList(M) 29 | 30 | def forward(self,x): 31 | for m in self.M: 32 | x = x + m(x) 33 | 34 | return x 35 | 36 | class AutoEncoderGenerator(torch.nn.Module): 37 | def __init__(self, 38 | hparams_enc: OmegaConf, 39 | hparams_dec: OmegaConf, 40 | device='cuda'): 41 | super().__init__() 42 | 43 | self.resmlp = None 44 | if hparams_enc.type =="ViT-B/32": 45 | self.encoder, _ = clip.load(hparams_enc.type, device=device) 46 | self.preprocess = lambda img: torch.nn.functional.interpolate(img, size=224) 47 | dim_input = 512 48 | self.enc_type = 'clip-flatten' 49 | elif 'densenet121' in hparams_enc.type: 50 | densenet = torchvision.models.densenet121(pretrained=True) 51 | net = list(densenet.children())[:-1] # last batchnorm w.o. 
classifier 52 | dim_input = 1024 53 | dim_spatial = 8*8 54 | 55 | self.enc_type = 'densenet-conv' 56 | self.encoder = torch.nn.Sequential(*net) 57 | self.preprocess = None 58 | 59 | elif 'resnet' in hparams_enc.type: 60 | if 'resnet18' in hparams_enc.type: 61 | resnet = torchvision.models.resnet18(pretrained=True) 62 | net = list(resnet.children())[:-2] # conv 63 | dim_input = 512 64 | dim_spatial = 8*8 65 | 66 | elif 'resnet50' in hparams_enc.type: 67 | resnet = torchvision.models.resnet50(pretrained=True) 68 | if 'layer2' in hparams_enc.type: 69 | net = list(resnet.children())[:-4] 70 | dim_input = 512 71 | dim_spatial = 32*32 72 | elif 'layer3' in hparams_enc.type: 73 | net = list(resnet.children())[:-3] 74 | dim_input = 1024 75 | dim_spatial = 16*16 76 | else: 77 | net = list(resnet.children())[:-2] 78 | dim_input = 2048 79 | dim_spatial = 8*8 80 | 81 | if 'conv' in hparams_enc.type: 82 | self.enc_type = 'resnet-conv' 83 | self.encoder = torch.nn.Sequential(*net) 84 | self.preprocess = None 85 | else: # flatten 86 | dim_input = dim_input * dim_spatial 87 | dim_spatial = 1*1 88 | self.enc_type = 'resnet-flatten' 89 | resnet.append(Rearrange('b c h w -> b (c h w)')) 90 | self.encoder = torch.nn.Sequential(*net) 91 | self.preprocess = None 92 | 93 | else: 94 | assert(False) 95 | 96 | if hparams_dec.type == 'CIPSskip': 97 | self.decoder = CIPSskip(**hparams_dec) 98 | 99 | dim_output = hparams_dec.style_dim 100 | 101 | if 'conv' in self.enc_type: 102 | embed = [] 103 | if 'groupnorm' in hparams_enc.type_init_norm: 104 | n_group = int(hparams_enc.type_init_norm.split('groupnorm')[-1]) 105 | embed.append(torch.nn.GroupNorm(n_group, dim_input)) 106 | 107 | # conv layers 108 | for ci in range(0, hparams_enc.num_embed_conv): 109 | if ci == 0: 110 | cin = dim_input 111 | else: 112 | if hparams_enc.dim_embed_conv > 0: 113 | cin = hparams_enc.dim_embed_conv 114 | else: 115 | cin = dim_input 116 | 117 | if hparams_enc.dim_embed_conv > 0: 118 | cout = hparams_enc.dim_embed_conv 119 | else: 120 | cout = dim_input 121 | 122 | embed.append(torch.nn.Conv2d(cin, cout, 1)) 123 | embed.append(torch.nn.GroupNorm(hparams_enc.num_groups, cout)) 124 | embed.append(torch.nn.LeakyReLU()) 125 | 126 | embed.pop() # remove last activation 127 | 128 | # flatten layer 129 | embed.append(Rearrange('b c h w -> b (c h w)')) 130 | 131 | num_flatten_linear = hparams_enc.num_flatten_linear 132 | dim_flatten_linear = hparams_enc.dim_flatten_linear 133 | for li in range(0, num_flatten_linear): 134 | if li == 0: 135 | lin = cout*dim_spatial 136 | else: 137 | if dim_flatten_linear > 0: 138 | lin = dim_flatten_linear 139 | else: 140 | lin = dim_output 141 | 142 | if li+1 == num_flatten_linear: 143 | lout = dim_output 144 | else: 145 | if dim_flatten_linear > 0: 146 | lout = dim_flatten_linear 147 | else: 148 | lout = dim_output 149 | 150 | embed.append(torch.nn.Linear(lin, lout, 1)) 151 | embed.append(torch.nn.LeakyReLU()) 152 | 153 | embed.pop() # remove last activation 154 | 155 | num_res = hparams_enc.num_residual_embed 156 | dim_res = hparams_enc.dim_residual_embed 157 | if num_res > 0: 158 | if dim_res == 0: 159 | dim_res = lout 160 | self.resmlp = ResMLP(lout, dim_res, num_res) 161 | 162 | self.embed = nn.Sequential(*embed) 163 | 164 | elif 'flatten' in self.enc_type: 165 | self.embed = nn.Linear(dim_input, dim_output, 1) 166 | 167 | self.res = hparams_dec.size 168 | self.coords = None 169 | 170 | def forward(self, x, coords=None): 171 | latent = self.encode(x) 172 | recon = self.decode(latent, coords) 173 | return recon, 
latent 174 | 175 | def encode(self, x: torch.FloatTensor) -> torch.FloatTensor: 176 | if self.preprocess is not None: 177 | x = self.preprocess(x) 178 | if 'clip' in self.enc_type: 179 | h = self.encoder.encode_image(x) 180 | else: 181 | h = self.encoder(x) 182 | 183 | h = self.embed(h.float()) 184 | 185 | if self.resmlp is not None: 186 | h = self.resmlp(h) 187 | return h 188 | 189 | def decode(self, latent: torch.FloatTensor, coords: Optional[torch.FloatTensor]=None) -> torch.FloatTensor: 190 | # if coords is not given, use the saved one. 191 | if coords is None: 192 | if self.coords is None: 193 | update_coords = True 194 | elif latent.size(0) != self.coords.size(0): 195 | update_coords = True 196 | else: 197 | update_coords = False 198 | 199 | if update_coords: 200 | b = latent.size(0) 201 | self.coords = self.convert_to_coord_format(b, self.res, self.res, integer_values=False) 202 | 203 | coords = self.coords 204 | 205 | recon, _ = self.decoder(coords, [latent]) 206 | return recon 207 | 208 | # adopted from "tensor_transforms.py" CIPS 209 | def convert_to_coord_format(self, b, h, w, device='cuda', integer_values=False): 210 | if integer_values: 211 | x_channel = torch.arange(w, dtype=torch.float, device=device).view(1, 1, 1, -1).repeat(b, 1, w, 1) 212 | y_channel = torch.arange(h, dtype=torch.float, device=device).view(1, 1, -1, 1).repeat(b, 1, 1, h) 213 | else: 214 | x_channel = torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1).repeat(b, 1, w, 1) 215 | y_channel = torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1).repeat(b, 1, 1, h) 216 | return torch.cat((x_channel, y_channel), dim=1) 217 | 218 | def get_parameters(self): 219 | # fusing two generators by chain 220 | return itertools.chain(self.decoder.parameters(), self.embed.parameters()) -------------------------------------------------------------------------------- /src/utils/metric_voxel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.compat.v1 as tf 3 | from src.utils.structural_losses.tf_nndistance import nn_distance 4 | 5 | import pdb 6 | 7 | def iterate_in_chunks(l, n): 8 | '''Yield successive 'n'-sized chunks from iterable 'l'. 9 | Note: last chunk will be smaller than l if n doesn't divide l perfectly. 10 | ''' 11 | for i in range(0, len(l), n): 12 | yield l[i:i + n] 13 | 14 | # https://github.com/optas/latent_3d_points/blob/master/src/evaluation_metrics.py#L33 15 | def minimum_mathing_distance_tf_graph(n_pc_points, batch_size=None, normalize=True, sess=None, verbose=False, use_sqrt=False): 16 | ''' Produces the graph operations necessary to compute the MMD and consequently also the Coverage due to their 'symmetric' nature. 17 | Assuming a "reference" and a "sample" set of point-clouds that will be matched, this function creates the operation that matches 18 | a _single_ "reference" point-cloud to all the "sample" point-clouds given in a batch. Thus, is the building block of the function 19 | ```minimum_mathing_distance`` and ```coverage``` that iterate over the "sample" batches and each "reference" point-cloud. 20 | Args: 21 | n_pc_points (int): how many points each point-cloud of those to be compared has. 22 | batch_size (optional, int): if the iterator code that uses this function will 23 | use a constant batch size for iterating the sample point-clouds you can 24 | specify it hear to speed up the compute. Alternatively, the code is adapted 25 | to read the batch size dynamically. 
26 | normalize (boolean): if True, the matched distances are normalized by dividing them by 27 | the number of points of the compared point-clouds (n_pc_points). 28 | use_sqrt (boolean): When the matching is based on Chamfer (default behavior), if True, 29 | the Chamfer is computed based on the (not-squared) euclidean distances of the 30 | matched points. 31 | use_EMD (boolean): If true, the matchings are based on the EMD. 32 | ''' 33 | tf.disable_v2_behavior() 34 | if normalize: 35 | reducer = tf.reduce_mean 36 | else: 37 | reducer = tf.reduce_sum 38 | 39 | if sess is None: 40 | config = tf.ConfigProto() 41 | config.gpu_options.allow_growth = True 42 | sess = tf.Session(config=config) 43 | 44 | # Placeholders for the point-clouds: 1 for the reference (usually Ground-truth) and one of variable size for the collection 45 | # which is going to be matched with the reference. 46 | ref_pl = tf.placeholder(tf.float32, shape=(1, n_pc_points, 3)) 47 | sample_pl = tf.placeholder(tf.float32, shape=(batch_size, n_pc_points, 3)) 48 | 49 | if batch_size is None: 50 | batch_size = tf.shape(sample_pl)[0] 51 | 52 | ref_repeat = tf.tile(ref_pl, [batch_size, 1, 1]) 53 | ref_repeat = tf.reshape(ref_repeat, [batch_size, n_pc_points, 3]) 54 | 55 | ref_to_s, _, s_to_ref, _ = nn_distance(ref_repeat, sample_pl) 56 | if use_sqrt: 57 | ref_to_s = tf.sqrt(ref_to_s) 58 | s_to_ref = tf.sqrt(s_to_ref) 59 | all_dist_in_batch = reducer(ref_to_s, 1) + reducer(s_to_ref, 1) 60 | 61 | best_in_batch = tf.reduce_min(all_dist_in_batch) # Best distance, of those that were matched to a single ref pc. 62 | location_of_best = tf.argmin(all_dist_in_batch, axis=0) 63 | return ref_pl, sample_pl, best_in_batch, location_of_best, sess 64 | 65 | # https://github.com/optas/latent_3d_points/blob/master/src/evaluation_metrics.py#L90 66 | def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False): 67 | '''Computes the MMD between two sets of point-clouds. 68 | Args: 69 | sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and 70 | compared to a set of "reference" point-clouds. 71 | ref_pcs (numpy array RxKx3): the R point-clouds, each of K points that constitute the set of 72 | "reference" point-clouds. 73 | batch_size (int): specifies how large the batches will be that the computation will use to make 74 | the comparisons of the sample-vs-ref point-clouds. 75 | normalize (boolean): if True, the distances are normalized by dividing them by 76 | the number of points of the point-clouds (n_pc_points). 77 | use_sqrt (boolean): When the matching is based on Chamfer (default behavior), if True, the 78 | Chamfer is computed based on the (not-squared) euclidean distances of the matched 79 | points. 80 | sess (tf.Session, default None): if None, it will make a new Session for this. 81 | use_EMD (boolean): If true, the matchings are based on the EMD. 82 | Returns: 83 | A tuple containing the MMD and all the matched distances of which the MMD is their mean.
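Example (illustrative sketch with toy shapes; values are placeholders):
    sample_pcs = np.random.rand(16, 2048, 3).astype(np.float32)
    ref_pcs = np.random.rand(16, 2048, 3).astype(np.float32)
    mmd, matched_dists = minimum_mathing_distance(sample_pcs, ref_pcs, batch_size=8)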
84 | ''' 85 | 86 | n_ref, n_pc_points, pc_dim = ref_pcs.shape 87 | _, n_pc_points_s, pc_dim_s = sample_pcs.shape 88 | 89 | if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s: 90 | raise ValueError('Incompatible size of point-clouds.') 91 | 92 | ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize, 93 | sess=sess, use_sqrt=use_sqrt) 94 | matched_dists = [] 95 | for i in range(n_ref): 96 | best_in_all_batches = [] 97 | if verbose and i % 50 == 0: 98 | print(i) 99 | for sample_chunk in iterate_in_chunks(sample_pcs, batch_size): 100 | feed_dict = {ref_pl: np.expand_dims(ref_pcs[i], 0), sample_pl: sample_chunk} 101 | b = sess.run(best_in_batch, feed_dict=feed_dict) 102 | best_in_all_batches.append(b) 103 | matched_dists.append(np.min(best_in_all_batches)) 104 | mmd = np.mean(matched_dists) 105 | return mmd, matched_dists 106 | 107 | 108 | # https://github.com/optas/latent_3d_points/blob/master/src/evaluation_metrics.py#L135 109 | def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, ret_dist=False): 110 | '''Computes the Coverage between two sets of point-clouds. 111 | Args: 112 | sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched 113 | and compared to a set of "reference" point-clouds. 114 | ref_pcs (numpy array RxKx3): the R point-clouds, each of K points that constitute the 115 | set of "reference" point-clouds. 116 | batch_size (int): specifies how large the batches will be that the computation will use to 117 | make the comparisons of the sample-vs-ref point-clouds. 118 | normalize (boolean): if True, the distances are normalized by dividing them by 119 | the number of points of the point-clouds (n_pc_points). 120 | use_sqrt (boolean): When the matching is based on Chamfer (default behavior), if True, 121 | the Chamfer is computed based on the (not-squared) euclidean distances of the matched 122 | points. 123 | sess (tf.Session): If None, it will make a new Session for this. 124 | use_EMD (boolean): If true, the matchings are based on the EMD. 125 | ret_dist (boolean): If true, it will also return the distances between each sample pc and 126 | its matched ground-truth. 127 | Returns: the coverage score (float), 128 | the indices of the ref_pcs that are matched with each sample_pc 129 | and optionally the matched distances of the sample_pcs. 130 | ''' 131 | n_ref, n_pc_points, pc_dim = ref_pcs.shape 132 | n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape 133 | 134 | if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s: 135 | raise ValueError('Incompatible Point-Clouds.') 136 | 137 | ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize, 138 | sess=sess, use_sqrt=use_sqrt) 139 | matched_gt = [] 140 | matched_dist = [] 141 | for i in range(n_sam): 142 | best_in_all_batches = [] 143 | loc_in_all_batches = [] 144 | 145 | if verbose and i % 50 == 0: 146 | print(i) 147 | 148 | for ref_chunk in iterate_in_chunks(ref_pcs, batch_size): 149 | feed_dict = {ref_pl: np.expand_dims(sample_pcs[i], 0), sample_pl: ref_chunk} 150 | b, loc = sess.run([best_in_batch, loc_of_best], feed_dict=feed_dict) 151 | best_in_all_batches.append(b) 152 | loc_in_all_batches.append(loc) 153 | 154 | best_in_all_batches = np.array(best_in_all_batches) 155 | b_hit = np.argmin(best_in_all_batches) # In which batch the minimum occurred.
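# Note: the reference set is scanned in fixed-size chunks, so the global index of the matched reference cloud is recovered below as batch_size * b_hit + hit (chunk offset plus the argmin position inside that chunk).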
156 | matched_dist.append(np.min(best_in_all_batches)) 157 | hit = np.array(loc_in_all_batches)[b_hit] 158 | matched_gt.append(batch_size * b_hit + hit) 159 | 160 | cov = len(np.unique(matched_gt)) / float(n_ref) 161 | 162 | if ret_dist: 163 | return cov, matched_gt, matched_dist 164 | else: 165 | return cov, matched_gt -------------------------------------------------------------------------------- /src/utils/structural_losses/tf_approxmatch_g.cu: -------------------------------------------------------------------------------- 1 | __global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ 2 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 3 | float multiL,multiR; 4 | if (n>=m){ 5 | multiL=1; 6 | multiR=n/m; 7 | }else{ 8 | multiL=m/n; 9 | multiR=1; 10 | } 11 | const int Block=1024; 12 | __shared__ float buf[Block*4]; 13 | for (int i=blockIdx.x;i=-2;j--){ 22 | float level=-powf(4.0f,j); 23 | if (j==-2){ 24 | level=0; 25 | } 26 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); 182 | } 183 | __global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ 184 | __shared__ float allsum[512]; 185 | const int Block=1024; 186 | __shared__ float buf[Block*3]; 187 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); 228 | } 229 | __global__ void matchcostgrad2(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ 230 | __shared__ float sum_grad[256*3]; 231 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); 294 | matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); 295 | } 296 | 297 | -------------------------------------------------------------------------------- /src/datasets/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # HQ-Transformer 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # Licensed under the MIT License [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | import glob 9 | import imageio 10 | import lmdb 11 | import pickle 12 | from io import BytesIO 13 | from PIL import Image 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.utils.data import DataLoader 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | from torchvision.transforms import Compose, ToTensor, Normalize 21 | 22 | import numpy as np 23 | 24 | import pdb 25 | 26 | class SRNDatasets(DataLoader): 27 | """ 28 | Dataset from SRN (V. Sitzmann et al. 
2020) 29 | PixelNeRF: https://github.com/sxyu/pixel-nerf/blob/master/src/data/SRNDataset.py 30 | """ 31 | 32 | def __init__( 33 | self, category, opt, split="train", world_scale=1.0, dataset_root='datasets' 34 | ): 35 | """ 36 | :param split train | val | test 37 | :param image_size result image size (resizes if different) 38 | :param world_scale amount to scale entire world by 39 | """ 40 | self.debug = opt.debug 41 | self.zero_to_one = True if opt.rgb_activation == 'sigmoid' else False 42 | self.subsampled_views = opt.subsampled_views 43 | image_size=(opt.resolution, opt.resolution) 44 | 45 | if category == 'cars': 46 | path = dataset_root + '/srn_cars/cars' 47 | else: 48 | raise NotImplementedError("please category name of SRN Dataset") 49 | 50 | self.base_path = path + "_" + split 51 | self.dataset_name = os.path.basename(path) 52 | 53 | 54 | print("Loading SRN dataset", self.base_path, "name:", self.dataset_name) 55 | self.split = split 56 | assert os.path.exists(self.base_path) 57 | 58 | # if category == 'chairs' and split == "train": 59 | # # Ugly thing from SRN's public dataset 60 | # tmp = os.path.join(self.base_path, "chairs_2.0_train") 61 | # if os.path.exists(tmp): 62 | # self.base_path = tmp 63 | 64 | self.intrins = sorted( 65 | glob.glob(os.path.join(self.base_path, "*", "intrinsics.txt")) 66 | ) 67 | 68 | self.image_to_tensor = Compose([ 69 | ToTensor(), 70 | Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 71 | ]) 72 | # self.mask_to_tensor = Compose([ 73 | # ToTensor(), 74 | # Normalize((0.0,), (1.0,)) 75 | # ]) 76 | 77 | 78 | self.image_size = image_size 79 | self.world_scale = world_scale 80 | self._coord_trans = torch.diag( 81 | torch.tensor([1, -1, -1, 1], dtype=torch.float32) 82 | ) 83 | 84 | self.z_near = opt.near 85 | self.z_far = opt.far 86 | 87 | # if category == 'cars': 88 | # self.z_near = 0.8 89 | # self.z_far = 1.8 90 | # elif category == 'chairs': 91 | # self.z_near = 1.25 92 | # self.z_far = 2.75 93 | self.lindisp = False 94 | 95 | def __len__(self): 96 | if self.debug: 97 | return 1 # for debug (NeRF check) 98 | return len(self.intrins) 99 | 100 | def __getitem__(self, index): 101 | intrin_path = self.intrins[index] 102 | dir_path = os.path.dirname(intrin_path) 103 | rgb_paths = sorted(glob.glob(os.path.join(dir_path, "rgb", "*"))) 104 | pose_paths = sorted(glob.glob(os.path.join(dir_path, "pose", "*"))) 105 | 106 | assert len(rgb_paths) == len(pose_paths) 107 | 108 | with open(intrin_path, "r") as intrinfile: 109 | lines = intrinfile.readlines() 110 | focal, cx, cy, _ = map(float, lines[0].split()) 111 | height, width = map(int, lines[-1].split()) 112 | 113 | all_imgs = [] 114 | all_poses = [] 115 | # all_masks = [] 116 | # all_bboxes = [] 117 | for rgb_path, pose_path in zip(rgb_paths, pose_paths): 118 | img = imageio.imread(rgb_path)[..., :3] 119 | img_tensor = self.image_to_tensor(img) 120 | # mask = (img != 255).all(axis=-1)[..., None].astype(np.uint8) * 255 121 | # mask_tensor = self.mask_to_tensor(mask) 122 | 123 | pose = torch.from_numpy( 124 | np.loadtxt(pose_path, dtype=np.float32).reshape(4, 4) 125 | ) 126 | pose = pose @ self._coord_trans 127 | 128 | # rows = np.any(mask, axis=1) 129 | # cols = np.any(mask, axis=0) 130 | # rnz = np.where(rows)[0] 131 | # cnz = np.where(cols)[0] 132 | # if len(rnz) == 0: 133 | # raise RuntimeError( 134 | # "ERROR: Bad image at", rgb_path, "please investigate!" 
135 | # ) 136 | # rmin, rmax = rnz[[0, -1]] 137 | # cmin, cmax = cnz[[0, -1]] 138 | # bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32) 139 | 140 | all_imgs.append(img_tensor) 141 | # all_masks.append(mask_tensor) 142 | all_poses.append(pose) 143 | # all_bboxes.append(bbox) 144 | 145 | all_imgs = torch.stack(all_imgs) 146 | all_poses = torch.stack(all_poses) 147 | # all_masks = torch.stack(all_masks) 148 | # all_bboxes = torch.stack(all_bboxes) 149 | 150 | if all_imgs.shape[-2:] != self.image_size: 151 | scale = self.image_size[0] / all_imgs.shape[-2] 152 | focal *= scale 153 | cx *= scale 154 | cy *= scale 155 | # all_bboxes *= scale 156 | 157 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 158 | # all_masks = F.interpolate(all_masks, size=self.image_size, mode="area") 159 | 160 | if self.world_scale != 1.0: 161 | focal *= self.world_scale 162 | all_poses[:, :3, 3] *= self.world_scale 163 | focal = torch.tensor(focal, dtype=torch.float32) 164 | 165 | # result = { 166 | # "path": dir_path, 167 | # "img_id": index, 168 | # "focal": focal, 169 | # "c": torch.tensor([cx, cy], dtype=torch.float32), 170 | # "images": all_imgs, 171 | # "masks": all_masks, 172 | # "bbox": all_bboxes, 173 | # "poses": all_poses, 174 | # } 175 | if self.zero_to_one: 176 | all_imgs = all_imgs * 0.5 + 0.5 177 | in_dict = {'idx': index, 178 | 'focal': focal, 179 | 'c2w': all_poses} 180 | out_dict = {'img': all_imgs} 181 | 182 | return in_dict, out_dict 183 | 184 | 185 | def format_for_lmdb(*args): 186 | key_parts = [] 187 | for arg in args: 188 | if isinstance(arg, int): 189 | arg = str(arg).zfill(7) 190 | key_parts.append(arg) 191 | return '-'.join(key_parts).encode('utf-8') 192 | 193 | 194 | class SRNDatasetsLMDB(DataLoader): 195 | """ 196 | Dataset from SRN (V. Sitzmann et al. 
2020) 197 | PixelNeRF: https://github.com/sxyu/pixel-nerf/blob/master/src/data/SRNDataset.py 198 | """ 199 | 200 | def __init__( 201 | self, category, opt, split="train", world_scale=1.0, dataset_root='datasets', zero_to_one=False 202 | ): 203 | """ 204 | :param split train | val | test 205 | :param image_size result image size (resizes if different) 206 | :param world_scale amount to scale entire world by 207 | """ 208 | if category != 'cars': 209 | raise NotImplementedError("please category name of SRN Dataset") 210 | 211 | self.base_path = os.path.join(dataset_root, 'srn_cars_lmdb', "cars_" + split) 212 | assert os.path.exists(self.base_path) 213 | 214 | self.env = lmdb.open( 215 | self.base_path, 216 | max_readers=32, 217 | readonly=True, 218 | lock=False, 219 | readahead=False, 220 | meminit=False, 221 | ) 222 | 223 | if not self.env: 224 | raise IOError('Cannot open lmdb dataset', self.base_path) 225 | 226 | print("Loading SRN dataset", self.base_path, "name:") 227 | with self.env.begin(write=False) as txn: 228 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 229 | 230 | self.split = split 231 | 232 | self.image_to_tensor = Compose([ 233 | ToTensor(), 234 | Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 235 | ]) 236 | 237 | self.debug = opt.debug 238 | self.zero_to_one = True if opt.rgb_activation == 'sigmoid' or zero_to_one else False 239 | self.subsampled_views = opt.subsampled_views 240 | self.image_size = (opt.resolution, opt.resolution) 241 | self.world_scale = world_scale 242 | self._coord_trans = torch.diag( 243 | torch.tensor([1, -1, -1, 1], dtype=torch.float32) 244 | ) 245 | 246 | self.z_near = opt.near 247 | self.z_far = opt.far 248 | self.lindisp = False 249 | 250 | def __len__(self): 251 | if self.debug: 252 | return 1 # for debug (NeRF check) 253 | return self.length 254 | 255 | def __getitem__(self, index): 256 | all_imgs = [] 257 | all_poses = [] 258 | 259 | with self.env.begin(write=False) as txn: 260 | intrins_byte = txn.get(format_for_lmdb(index, 'intrinsic')) 261 | intrins = pickle.loads(intrins_byte) 262 | 263 | focal, cx, cy = intrins[0][0], intrins[0][1], intrins[0][2] 264 | 265 | l_key = format_for_lmdb(index, 'length') 266 | length_txt = txn.get(l_key) 267 | for frame_idx in range(0, int(length_txt)): 268 | #i_key = f'{str(index).zfill(7)}img{str(frame_idx).zfill(7)}'.encode('utf-8') 269 | #p_key = f'{str(index).zfill(7)}pose{str(frame_idx).zfill(7)}'.encode('utf-8') 270 | i_key = format_for_lmdb(index, 'img', frame_idx) 271 | p_key = format_for_lmdb(index, 'pose', frame_idx) 272 | 273 | pose = pickle.loads(txn.get(p_key)) 274 | 275 | buffer = BytesIO(txn.get(i_key)) 276 | img = Image.open(buffer) 277 | 278 | img_tensor = self.image_to_tensor(img) 279 | 280 | 281 | all_imgs.append(img_tensor) 282 | 283 | all_poses.append(pose) 284 | 285 | all_imgs = torch.stack(all_imgs) 286 | all_poses = torch.stack(all_poses) 287 | 288 | if all_imgs.shape[-2:] != self.image_size: 289 | scale = self.image_size[0] / all_imgs.shape[-2] 290 | focal *= scale 291 | cx *= scale 292 | cy *= scale 293 | 294 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 295 | 296 | if self.world_scale != 1.0: 297 | focal *= self.world_scale 298 | all_poses[:, :3, 3] *= self.world_scale 299 | focal = torch.tensor(focal, dtype=torch.float32) 300 | 301 | if self.zero_to_one: 302 | all_imgs = all_imgs * 0.5 + 0.5 303 | in_dict = {'idx': index, 304 | 'focal': focal, 305 | 'c2w': all_poses} 306 | out_dict = {'img': all_imgs} 307 | 308 | return in_dict, out_dict 
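# Usage sketch (illustrative; assumes an `opt` namespace providing the fields read in __init__, e.g. debug, rgb_activation, subsampled_views, resolution, near, far):
#   dataset = SRNDatasetsLMDB('cars', opt, split='train', dataset_root='datasets')
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
#   in_dict, out_dict = next(iter(loader))
#   # in_dict['c2w']: (B, NV, 4, 4) poses, in_dict['focal']: (B,), out_dict['img']: (B, NV, 3, H, W)
# LMDB keys follow format_for_lmdb, e.g. b'0000003-img-0000012' for view 12 of scene 3.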
309 | -------------------------------------------------------------------------------- /src/models/stage1/nerf/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | import numpy as np 5 | 6 | # baseline and functa 7 | from .nerf_utils import * 8 | # mip-nerf 9 | from .ray_utils import sample_along_rays_mip 10 | from .ray_utils import volumetric_rendering as volumetric_rendering_mip 11 | 12 | ''' 13 | nerf_utils.py: general nerf functions 14 | helper.py : helping functions for our work 15 | ''' 16 | 17 | MAX_DENSITY = 10. 18 | 19 | def get_rays_batch(H, W, focal, c2w, compute_radii=False): 20 | # TODO: faster 21 | bsz = c2w.shape[0] 22 | all_rays = [] 23 | for i in range(bsz): 24 | # per image 25 | cam_rays = get_rays(H, W, focal[i], c2w[i], compute_radii=compute_radii) 26 | all_rays.append(cam_rays) 27 | results = torch.stack(all_rays, 1) 28 | #results = torch.stack(list(map(lambda f, c: get_rays(H,W,f,c), focal, c2w)), 1) 29 | return results 30 | 31 | def get_samples_for_nerf(model_input, gt, opt, view_sampling=True, pixel_sampling=True, view_num=None): 32 | all_scene_rays = [] 33 | all_scene_rgb = [] 34 | all_scene_idx = [] 35 | ALL_VIEW = gt['img'].shape[1] 36 | bsz = gt['img'].shape[0] 37 | 38 | for i_batch in range(bsz): 39 | focal = model_input['focal'][i_batch] #(ALL_VIEW) 40 | c2w = model_input['c2w'][i_batch] #(ALL_VIEW,4,4) 41 | idx_ = model_input['idx'][i_batch] #1 42 | rgb = gt['img'][i_batch] #(ALL_VIEW,3,H,W) 43 | 44 | # sampling view 45 | if view_sampling and opt.subsampled_views > 0: 46 | NV = opt.subsampled_views if view_num is None else view_num 47 | view_inds = np.random.choice(ALL_VIEW, NV) 48 | focal = focal.repeat(NV) 49 | c2w = c2w[view_inds, :,:] 50 | rgb = rgb[view_inds, :, :, :] 51 | else: 52 | focal = focal.repeat(ALL_VIEW) 53 | NV = ALL_VIEW 54 | 55 | # get origin & direction of all pixels 56 | compute_radii = opt.rendering_type == 'mip-nerf' 57 | cam_rays = get_rays_batch(opt.H, opt.W, focal, c2w, compute_radii=compute_radii) #(2 or 3,NV,H,W,3) 58 | 59 | # sampling [H,W] indices 60 | NM = 3 if compute_radii else 2 61 | 62 | assert cam_rays.size(0) == NM and cam_rays.size(1) == NV and cam_rays.size(4) == 3 63 | 64 | assert rgb.size(0) == NV and rgb.size(1) == 3 65 | cam_rays = cam_rays.permute(0, 1, 4, 2, 3) #(2,NV,3,H,W) 66 | cam_rays = cam_rays.reshape(NM, NV, 3, -1) #(2,NV,3,H*W) 67 | rgb = rgb.reshape(NV,3,-1) #(NV,3,H*W) 68 | 69 | if pixel_sampling and opt.subsampled_pixels > 0: 70 | if 'bbox' in model_input.keys(): 71 | pass 72 | else: 73 | pix_inds = np.random.choice(opt.H * opt.W, opt.subsampled_pixels) 74 | 75 | cam_rays = cam_rays[:,:,:,pix_inds].permute(0, 1, 3, 2) #(2,NV,NP,3) 76 | rgb = rgb[:,:,pix_inds].permute(0, 2, 1) #(NV,NP,3) 77 | else: 78 | cam_rays = cam_rays.permute(0, 1, 3, 2) #(2,NV,NP,3) 79 | rgb = rgb.permute(0, 2, 1) #(NV,NP,3) 80 | 81 | all_scene_rays.append(cam_rays) 82 | all_scene_rgb.append(rgb) 83 | 84 | all_scene_rgb = torch.stack(all_scene_rgb, 0).reshape(bsz, -1, 3) #(B*NV*NP,3) 85 | all_scene_rays = torch.stack(all_scene_rays, 1).reshape(NM, -1, 3) #(3,B*NV*NP,3) 86 | if opt.rendering_type == 'mip-nerf': 87 | t_vals, (coords, coords_covs) = sample_along_rays_mip(all_scene_rays, opt) 88 | else: 89 | t_vals, coords = sample_along_rays(all_scene_rays, opt) 90 | 91 | coords = coords.reshape(bsz, -1, 3) 92 | 93 | # Model input 94 | model_input['coords'] = coords 95 | model_input['rays_d'] = all_scene_rays[1] 96 | 
model_input['t_vals'] = t_vals 97 | del model_input['focal'] 98 | del model_input['c2w'] 99 | 100 | # GT 101 | gt['img'] = all_scene_rgb 102 | return model_input, gt 103 | 104 | 105 | def get_test_samples_for_nerf(model_input, view_inds, opt, focal_ratio=1.0): 106 | all_scene_rays = [] 107 | NV = 1 108 | if focal_ratio != 1.0: 109 | focal = model_input['focal'] * focal_ratio #(ALL_VIEW) 110 | else: 111 | focal = model_input['focal'] #(ALL_VIEW) 112 | #focal = focal.repeat(ALL_VIEW) 113 | c2w = model_input['c2w'][0] #(ALL_VIEW,4,4) 114 | 115 | #focal = focal.repeat(NV) 116 | _c2w = c2w[view_inds:view_inds+1, :,:] 117 | _focal = focal.unsqueeze(0) 118 | 119 | # get origin & direction of all pixels 120 | # cam_rays = get_rays_batch(opt.H, opt.W, _focal, _c2w) #(2,NV,H,W,3) 121 | compute_radii = opt.rendering_type == 'mip-nerf' 122 | cam_rays = get_rays_batch(opt.H, opt.W, _focal, _c2w, compute_radii=compute_radii) #(2 or 3,NV,H,W,3) 123 | 124 | # sampling [H,W] indices 125 | #compute_radii = opt.rendering_type == 'mip-nerf': 126 | 127 | NM = 3 if compute_radii else 2 128 | # if compute_radii: 129 | # NM = 3 130 | # else: 131 | # NM = 2 132 | assert cam_rays.size(0) == NM and cam_rays.size(1) == NV and cam_rays.size(4) == 3 133 | #assert cam_rays.size(0) == 2 and cam_rays.size(1) == NV and cam_rays.size(4) == 3 134 | cam_rays = cam_rays.permute(0, 1, 4, 2, 3) #(2,NV,3,H,W) 135 | cam_rays = cam_rays.reshape(NM, NV, 3, -1) #(2,NV,3,H*W) 136 | cam_rays = cam_rays.permute(0, 1, 3, 2) #(2,NV,NP,3) 137 | 138 | all_scene_rays.append(cam_rays) 139 | 140 | all_scene_rays = torch.stack(all_scene_rays, 1).reshape(NM, -1, 3) #(2,B*NV*NP,3) 141 | if opt.rendering_type == 'mip-nerf': 142 | t_vals, (coords, coords_cov) = sample_along_rays_mip(all_scene_rays, opt) 143 | else: 144 | t_vals, coords = sample_along_rays(all_scene_rays, opt) 145 | 146 | coords = coords.reshape(1, -1, 3) 147 | 148 | # Model input 149 | model_input['coords'] = coords 150 | model_input['rays_d'] = all_scene_rays[1] 151 | model_input['t_vals'] = t_vals 152 | del model_input['focal'] 153 | del model_input['c2w'] 154 | 155 | return model_input 156 | 157 | def nerf_volume_rendering(prediction, opt, out_type='rgb'): 158 | pred_rgb, pred_density = prediction['model_out'][..., :3], prediction['model_out'][..., -1:] 159 | 160 | bsz = pred_rgb.shape[0] 161 | # rgb activation 162 | pred_rgb = pred_rgb.reshape(-1, opt.num_samples_per_ray+1, 3) 163 | if opt.rgb_activation == 'sigmoid': 164 | pred_rgb = torch.sigmoid(pred_rgb) 165 | elif opt.rgb_activation == 'relu': 166 | pred_rgb = F.relu(pred_rgb) 167 | elif 'sine' in opt.rgb_activation: 168 | w0 = float(opt.rgb_activation.split('sine')[-1]) 169 | pred_rgb = torch.sin(w0*pred_rgb) 170 | elif opt.rgb_activation == 'no_use': 171 | pass 172 | else: 173 | raise Exception("check rgb activation") 174 | 175 | # density activation 176 | pred_density = pred_density.reshape(-1, opt.num_samples_per_ray+1, 1) 177 | if opt.density_activation == 'elu': 178 | pred_density = F.elu(pred_density, alpha=0.1) + 0.1 179 | pred_density = torch.clip(pred_density, 0, MAX_DENSITY) 180 | elif opt.density_activation == 'relu': 181 | pred_density = F.relu(pred_density) 182 | elif opt.density_activation == 'leakyrelu': 183 | pred_density = F.leaky_relu(pred_density) + 0.1 184 | elif opt.density_activation == 'shift1': 185 | pred_density = torch.clip(pred_density + 1.0, 0, MAX_DENSITY) 186 | elif opt.density_activation == 'shift': 187 | pred_density = pred_density + 0.5 188 | pred_density = torch.clip(pred_density, 0, 
MAX_DENSITY) 189 | elif opt.density_activation == 'sine5+shift0.9': 190 | pred_density = torch.sin(5.0 * pred_density) 191 | pred_density = torch.clip(pred_density + 0.9, 0, MAX_DENSITY) 192 | elif opt.density_activation == 'sine5+shift1': 193 | pred_density = torch.sin(5.0 * pred_density) 194 | pred_density = torch.clip(pred_density + 1, 0, MAX_DENSITY) 195 | elif opt.density_activation == 'sine5+shift1.1': 196 | pred_density = torch.sin(5.0 * pred_density) 197 | pred_density = torch.clip(pred_density + 1.1, 0, MAX_DENSITY) 198 | elif opt.density_activation == 'sine5+shift0.5+elu': 199 | pred_density = torch.sin(5.0 * pred_density) + 0.5 200 | pred_density = F.elu(pred_density, alpha=0.1) + 0.1 201 | pred_density = torch.clip(pred_density, 0, MAX_DENSITY) 202 | elif opt.density_activation == 'sine5+shift1+elu': 203 | pred_density = torch.sin(5.0 * pred_density) + 1.0 204 | pred_density = F.elu(pred_density, alpha=0.1) + 0.1 205 | pred_density = torch.clip(pred_density, 0, MAX_DENSITY) 206 | elif 'elu+scale' in opt.density_activation: 207 | scale = float(opt.density_activation.split('elu+scale')[-1]) 208 | pred_density = pred_density * scale 209 | pred_density = F.elu(pred_density, alpha=0.1) + 0.1 210 | pred_density = torch.clip(pred_density, 0, MAX_DENSITY) 211 | elif opt.density_activation == 'no_use': 212 | pass 213 | else: 214 | raise Exception("check density activation") 215 | 216 | t_vals, rays_d = prediction['model_in']['t_vals'], prediction['model_in']['rays_d'] 217 | if opt.rendering_type == 'functa': 218 | color, acc, depth, weight = volumetric_rendering_functa(pred_rgb, pred_density, t_vals, rays_d, opt.white_bkgd) 219 | elif opt.rendering_type == 'mip-nerf': 220 | color, acc, depth, weight = volumetric_rendering_mip(pred_rgb, pred_density, t_vals, rays_d, opt.white_bkgd) 221 | else: 222 | color, acc, depth, weight = volumetric_rendering(pred_rgb, pred_density, t_vals, rays_d, opt.white_bkgd) 223 | 224 | 225 | # reshape 226 | color = color.reshape(bsz, -1, 3) 227 | depth = depth.reshape(bsz, -1) 228 | acc = acc.reshape(bsz, -1) 229 | if out_type == 'all': 230 | prediction['model_out'] = { 231 | 'rgb': color, 232 | 'depth': depth, 233 | 'acc': acc, 234 | } 235 | else: 236 | prediction['model_out'] = color 237 | return prediction 238 | 239 | def save_rendering_output(model_output, gt, opt, image_path, max_num=-1): 240 | save_gt = gt['img'].reshape(-1, opt.H, opt.W, 3).permute(0,3,1,2).detach().cpu() 241 | pred_rgb = model_output['model_out']['rgb'].reshape(-1, opt.H, opt.W, 3).permute(0,3,1,2).detach().cpu() 242 | pred_depth = model_output['model_out']['depth'].reshape(-1, opt.H, opt.W, 1).permute(0,3,1,2).detach().cpu() 243 | pred_acc = model_output['model_out']['acc'].reshape(-1, opt.H, opt.W, 1).permute(0,3,1,2).detach().cpu() 244 | pred_depth = ((pred_depth-pred_depth.min())/(pred_depth.max()-pred_depth.min())*2-1).repeat(1,3,1,1) 245 | pred_acc = pred_acc.repeat(1,3,1,1) 246 | combined_image = torch.cat((save_gt, pred_rgb, pred_depth, pred_acc), -1) 247 | if max_num > 0: 248 | save_num = min(combined_image.size(0), max_num) 249 | combined_image = combined_image[:save_num] 250 | 251 | combined_image = torchvision.utils.make_grid(combined_image, nrow=1) 252 | torchvision.utils.save_image(combined_image, image_path) 253 | 254 | def split_dict(dict_, start, end, num_unit): 255 | output_dict = {} 256 | for k, v in dict_.items(): 257 | if k not in ['coords']: 258 | output_dict[k] = v 259 | else: 260 | if end >= v.size(1): 261 | output_dict[k] = v[:,start:,:] 262 | else: 263 | 
output_dict[k] = v[:,start:end,:] 264 | return output_dict, start+num_unit, end+num_unit 265 | -------------------------------------------------------------------------------- /src/utils/fid_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adopted from 3 | # https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import glob 7 | import logging 8 | import os 9 | import sys 10 | import torch 11 | import numpy as np 12 | import torch.nn.functional as F 13 | import torchvision.transforms as transforms 14 | 15 | from scipy import linalg 16 | from torch.utils.data import DataLoader 17 | from datasets import ImageNet 18 | from tqdm import tqdm 19 | 20 | from .inception import InceptionV3 21 | 22 | if int(sys.version.split('.')[1]) < 8: 23 | import pickle5 as pickle 24 | else: 25 | import pickle 26 | 27 | 28 | class InceptionWrapper(InceptionV3): 29 | 30 | def forward(self, inp): 31 | pred = super().forward(inp)[0] 32 | # If model output is not scalar, apply global spatial average pooling. 33 | # This happens if you choose a dimensionality not equal 2048. 34 | if pred.size(2) != 1 or pred.size(3) != 1: 35 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1)) 36 | pred = pred.reshape(pred.shape[0], -1) 37 | 38 | return pred 39 | 40 | def get_logits(self, inp): 41 | _, logits = super().forward(inp, return_logits=True) 42 | 43 | return logits 44 | 45 | 46 | def get_inception_model(dims=2048): 47 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 48 | model = InceptionWrapper([block_idx]) 49 | return model 50 | 51 | 52 | def mean_covar_torch(xs): 53 | mu = torch.mean(xs, dim=0, keepdim=True) 54 | ys = xs - mu 55 | unnormalized_sigma = (ys.T @ ys) 56 | sigma = unnormalized_sigma / (xs.shape[0] - 1) 57 | return mu, sigma 58 | 59 | 60 | def mean_covar_numpy(xs): 61 | if isinstance(xs, torch.Tensor): 62 | xs = xs.cpu().numpy() 63 | return np.mean(xs, axis=0), np.cov(xs, rowvar=False) 64 | 65 | 66 | def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 67 | """Numpy implementation of the Frechet Distance. 68 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 69 | and X_2 ~ N(mu_2, C_2) is 70 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 71 | 72 | Stable version by Dougal J. Sutherland. 73 | 74 | Params: 75 | -- mu1 : Numpy array containing the activations of a layer of the 76 | inception net (like returned by the function 'get_predictions') 77 | for generated samples. 78 | -- mu2 : The sample mean over activations, precalculated on an 79 | representative data set. 80 | -- sigma1: The covariance matrix over activations for generated samples. 81 | -- sigma2: The covariance matrix over activations, precalculated on an 82 | representative data set. 83 | 84 | Returns: 85 | -- : The Frechet Distance. 
86 | """ 87 | 88 | mu1 = np.atleast_1d(mu1) 89 | mu2 = np.atleast_1d(mu2) 90 | 91 | sigma1 = np.atleast_2d(sigma1) 92 | sigma2 = np.atleast_2d(sigma2) 93 | 94 | assert mu1.shape == mu2.shape, \ 95 | 'Training and test mean vectors have different lengths' 96 | assert sigma1.shape == sigma2.shape, \ 97 | 'Training and test covariances have different dimensions' 98 | 99 | diff = mu1 - mu2 100 | 101 | # Product might be almost singular 102 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 103 | if not np.isfinite(covmean).all(): 104 | msg = ('fid calculation produces singular product; ' 105 | 'adding %s to diagonal of cov estimates') % eps 106 | logging.warning(msg) 107 | offset = np.eye(sigma1.shape[0]) * eps 108 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 109 | 110 | # Numerical error might give slight imaginary component 111 | if np.iscomplexobj(covmean): 112 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 113 | m = np.max(np.abs(covmean.imag)) 114 | raise ValueError('Imaginary component {}'.format(m)) 115 | covmean = covmean.real 116 | 117 | tr_covmean = np.trace(covmean) 118 | 119 | return (diff.dot(diff) + np.trace(sigma1) + 120 | np.trace(sigma2) - 2 * tr_covmean) 121 | 122 | 123 | @torch.no_grad() 124 | def compute_statistics_imagenet_val(resolution=128, 125 | batch_size=500, 126 | inception_model=None, 127 | stage1_model=None, 128 | device=torch.device('cuda'), 129 | skip_original=False, 130 | ): 131 | transforms_ = [ 132 | transforms.Resize(resolution), 133 | transforms.CenterCrop(resolution), 134 | transforms.Resize((resolution, resolution)), 135 | transforms.ToTensor(), 136 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 137 | ] 138 | transforms_ = transforms.Compose(transforms_) 139 | 140 | dataset = ImageNet(split='val', transform=transforms_) 141 | 142 | mu_acts, sigma_acts, mu_acts_recon, sigma_acts_recon = \ 143 | compute_statistics_dataset(dataset, 144 | batch_size=batch_size, 145 | inception_model=inception_model, 146 | stage1_model=stage1_model, 147 | device=device, 148 | skip_original=skip_original, 149 | ) 150 | 151 | return mu_acts, sigma_acts, mu_acts_recon, sigma_acts_recon 152 | 153 | 154 | @torch.no_grad() 155 | def compute_statistics_dataset(dataset, 156 | batch_size=500, 157 | inception_model=None, 158 | stage1_model=None, 159 | device=torch.device('cuda'), 160 | skip_original=False, 161 | ): 162 | 163 | if skip_original and stage1_model is None: 164 | return None, None, None, None 165 | 166 | if inception_model is None: 167 | inception_model = get_inception_model().to(device) 168 | 169 | loader = DataLoader(dataset, shuffle=False, pin_memory=True, batch_size=batch_size, num_workers=16) 170 | 171 | inception_model.eval() 172 | if stage1_model: 173 | stage1_model.eval() 174 | 175 | acts = [] 176 | acts_recon = [] 177 | 178 | sample_size_sum = 0.0 179 | sample_sum = torch.tensor(0.0, device=device) 180 | sample_sq_sum = torch.tensor(0.0, device=device) 181 | sample_max = torch.tensor(float('-inf'), device=device) 182 | sample_min = torch.tensor(float('inf'), device=device) 183 | 184 | for xs, _ in tqdm(loader, desc="compute acts"): 185 | xs = xs.to(device, non_blocking=True) 186 | 187 | # we are assuming that dataset returns value in -1 ~ 1 -> remap to 0 ~ 1 188 | xs = torch.clamp(xs*0.5 + 0.5, 0, 1) 189 | 190 | sample_sum += xs.sum() 191 | sample_sq_sum += xs.pow(2.0).sum() 192 | sample_size_sum += xs.numel() 193 | sample_max = max(xs.max(), sample_max) 194 | sample_min = min(xs.min(), sample_min) 195 | 196 | 
act = inception_model(xs).cpu() if not skip_original else None 197 | acts.append(act) 198 | 199 | if stage1_model: 200 | # here we assume that stage1 model input & output values are in -1 ~ 1 range 201 | # this may not cover DiscreteVAE 202 | imgs = 2. * xs - 1. 203 | xs_recon = torch.cat([ 204 | stage1_model(imgs[i:i+1])[0] for i in range(imgs.shape[0]) 205 | ], dim=0) 206 | xs_recon = torch.clamp(xs_recon * 0.5 + 0.5, 0, 1) 207 | act_recon = inception_model(xs_recon).cpu() 208 | acts_recon.append(act_recon) 209 | 210 | sample_mean = sample_sum.item() / sample_size_sum 211 | sample_std = ((sample_sq_sum.item() / sample_size_sum) - (sample_mean ** 2.0)) ** 0.5 212 | logging.info(f'val imgs. stats :: ' 213 | f'max: {sample_max:.4f}, min: {sample_min:.4f}, mean: {sample_mean:.4f}, std: {sample_std:.4f}') 214 | 215 | acts = torch.cat(acts, dim=0) if not skip_original else None 216 | 217 | if skip_original: 218 | mu_acts, sigma_acts = None, None 219 | else: 220 | mu_acts, sigma_acts = mean_covar_numpy(acts) 221 | 222 | if stage1_model: 223 | acts_recon = torch.cat(acts_recon, dim=0) 224 | mu_acts_recon, sigma_acts_recon = mean_covar_numpy(acts_recon) 225 | else: 226 | mu_acts_recon, sigma_acts_recon = None, None 227 | 228 | return mu_acts, sigma_acts, mu_acts_recon, sigma_acts_recon 229 | 230 | 231 | def create_dataset_from_files(path, verbose=False): 232 | samples = [] 233 | pkl_lists = sorted(glob.glob(os.path.join(path, 'samples*.pkl'))) # sorted so the [first, ..., last] log below is meaningful 234 | first_file_name = os.path.basename(pkl_lists[0]) 235 | last_file_name = os.path.basename(pkl_lists[-1]) 236 | logging.info(f'loading generated images from {path}: [{first_file_name}, ..., {last_file_name}]') 237 | 238 | for pkl in tqdm(pkl_lists, desc='loading pickles'): 239 | with open(pkl, 'rb') as f: 240 | # samples.append(pickle.load(f).cpu().numpy()) 241 | s = pickle.load(f) 242 | if isinstance(s, np.ndarray): 243 | s = torch.from_numpy(s) 244 | samples.append(s) 245 | 246 | datasets = [torch.utils.data.TensorDataset(sample) for sample in samples] 247 | dataset = torch.utils.data.ConcatDataset(datasets) 248 | 249 | if verbose: 250 | total_size = sum([sample.numel() for sample in samples]) # samples are torch tensors here, so numel() replaces numpy's .size 251 | sample_mean = sum([sample.sum() for sample in samples]) / total_size 252 | sample_std = (sum([((sample - sample_mean)**2).sum() for sample in samples]) / total_size) ** 0.5 253 | sample_max = max([sample.max() for sample in samples]) 254 | sample_min = min([sample.min() for sample in samples]) 255 | logging.info(f'gen. imgs. stats :: ' 256 | f'max: {sample_max:.4f}, min: {sample_min:.4f}, mean: {sample_mean:.4f}, std: {sample_std:.4f}') 257 | 258 | return dataset 259 | 260 | 261 | @torch.no_grad() 262 | def compute_activations_from_dataset(dataset, 263 | batch_size=500, 264 | inception_model=None, 265 | device=torch.device('cuda'), 266 | normalized=False, 267 | ): 268 | if inception_model is None: 269 | inception_model = get_inception_model().to(device) 270 | 271 | loader = DataLoader(dataset, shuffle=False, pin_memory=True, batch_size=batch_size, num_workers=16) 272 | 273 | acts = [] 274 | inception_model.eval() 275 | 276 | for xs in tqdm(loader, desc="compute acts (gen. 
imgs)"): 277 | xs = xs[0].to(device, non_blocking=True) 278 | if normalized: 279 | xs = 0.5 * xs + 0.5 280 | act = inception_model(xs) 281 | acts.append(act.cpu()) 282 | 283 | acts = torch.cat(acts, dim=0) 284 | return acts 285 | 286 | 287 | def compute_activations_from_files(path, 288 | batch_size=500, 289 | inception_model=None, 290 | device=torch.device('cuda'), 291 | ): 292 | dataset = create_dataset_from_files(path) 293 | return compute_activations_from_dataset(dataset, 294 | batch_size=batch_size, 295 | inception_model=inception_model, 296 | device=device) 297 | 298 | 299 | def compute_statistics_from_files(path, 300 | batch_size=500, 301 | inception_model=None, 302 | device=torch.device('cuda'), 303 | return_acts=False, 304 | ): 305 | acts = compute_activations_from_files(path, 306 | batch_size=batch_size, 307 | inception_model=inception_model, 308 | device=device, 309 | ) 310 | mu_acts, sigma_acts = mean_covar_numpy(acts) 311 | if return_acts: 312 | return mu_acts, sigma_acts, acts 313 | else: 314 | return mu_acts, sigma_acts 315 | --------------------------------------------------------------------------------