├── MOVIS
├── ldm
│ ├── data
│ │ ├── __init__.py
│ │ ├── inpainting
│ │ │ ├── __init__.py
│ │ │ └── synthetic_mask.py
│ │ ├── __pycache__
│ │ │ ├── base.cpython-39.pyc
│ │ │ └── __init__.cpython-39.pyc
│ │ ├── dummy.py
│ │ ├── base.py
│ │ ├── lsun.py
│ │ ├── nerf_like.py
│ │ ├── coco.py
│ │ └── imagenet.py
│ ├── models
│ │ └── diffusion
│ │ │ ├── __init__.py
│ │ │ ├── sampling_util.py
│ │ │ ├── classifier.py
│ │ │ └── plms.py
│ ├── modules
│ │ ├── encoders
│ │ │ ├── __init__.py
│ │ │ └── __pycache__
│ │ │ │ ├── __init__.cpython-39.pyc
│ │ │ │ └── modules.cpython-39.pyc
│ │ ├── distributions
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-39.pyc
│ │ │ │ └── distributions.cpython-39.pyc
│ │ │ └── distributions.py
│ │ ├── diffusionmodules
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── model.cpython-39.pyc
│ │ │ │ ├── util.cpython-39.pyc
│ │ │ │ ├── __init__.cpython-39.pyc
│ │ │ │ └── openaimodel.cpython-39.pyc
│ │ │ └── util.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── contperceptual.py
│ │ │ └── vqperceptual.py
│ │ ├── __pycache__
│ │ │ ├── ema.cpython-39.pyc
│ │ │ ├── attention.cpython-39.pyc
│ │ │ └── x_transformer.cpython-39.pyc
│ │ ├── image_degradation
│ │ │ ├── utils
│ │ │ │ └── test.png
│ │ │ └── __init__.py
│ │ ├── ema.py
│ │ ├── evaluate
│ │ │ ├── ssim.py
│ │ │ ├── frechet_video_distance.py
│ │ │ └── torch_frechet_video_distance.py
│ │ └── attention.py
│ ├── thirdp
│ │ └── psp
│ │ │ ├── __pycache__
│ │ │ │ ├── helpers.cpython-39.pyc
│ │ │ │ ├── id_loss.cpython-39.pyc
│ │ │ │ └── model_irse.cpython-39.pyc
│ │ │ ├── id_loss.py
│ │ │ ├── model_irse.py
│ │ │ └── helpers.py
│ ├── extras.py
│ ├── guidance.py
│ ├── lr_scheduler.py
│ └── util.py
├── .gitignore
├── assets
│ ├── SUNRGBD
│ │ ├── example_0
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_1
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_2
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_3
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_4
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_5
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ ├── example_6
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ └── example_7
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ ├── ScanNet
│ │ ├── example_0
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ └── example_1
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ ├── Youtube
│ │ ├── example_0
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ │ └── example_1
│ │ │ ├── depth.npy
│ │ │ ├── image.png
│ │ │ └── mask.png
│ └── RealEstate10k
│ │ ├── example_0
│ │ │ ├── mask.png
│ │ │ ├── depth.npy
│ │ │ └── image.png
│ │ └── example_1
│ │ │ ├── mask.png
│ │ │ ├── depth.npy
│ │ │ └── image.png
├── eval_batch_3d.sh
├── eval_batch_obj.sh
├── train.sh
├── eval_single.sh
├── requirements.txt
├── configs
│ ├── 3d_mix.yaml
│ ├── inference_c3dfs.yaml
│ └── inference_cobj.yaml
└── eval_single.py
├── .gitignore
├── assets
│ └── teaser.png
└── README.md
/MOVIS/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOVIS/ldm/data/inpainting/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOVIS/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | C_Obj/
2 | MOVIS-test/
3 | MOVIS-test.zip
--------------------------------------------------------------------------------
/MOVIS/.gitignore:
--------------------------------------------------------------------------------
1 | CLIP/
2 | last.ckpt
3 | taming*
4 | sd-image*
5 | *__pycache__
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_0/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_0/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_0/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_0/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_0/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_0/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_1/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_1/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_1/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_1/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_1/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_1/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_2/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_2/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_2/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_2/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_2/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_2/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_3/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_3/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_3/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_3/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_3/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_3/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_4/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_4/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_4/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_4/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_4/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_4/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_5/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_5/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_5/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_5/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_5/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_5/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_6/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_6/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_6/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_6/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_6/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_6/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_7/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_7/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_7/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_7/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/SUNRGBD/example_7/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/SUNRGBD/example_7/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_0/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_0/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_0/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_0/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_0/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_0/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_1/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_1/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_1/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_1/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/ScanNet/example_1/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/ScanNet/example_1/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_0/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_0/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_0/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_0/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_0/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_0/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_1/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_1/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_1/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_1/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/Youtube/example_1/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/Youtube/example_1/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_0/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_0/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_1/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_1/mask.png
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_0/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_0/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_0/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_0/image.png
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_1/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_1/depth.npy
--------------------------------------------------------------------------------
/MOVIS/assets/RealEstate10k/example_1/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/assets/RealEstate10k/example_1/image.png
--------------------------------------------------------------------------------
/MOVIS/eval_batch_3d.sh:
--------------------------------------------------------------------------------
1 | python eval_batch.py \
2 | --config configs/inference_c3dfs.yaml \
3 | --output_dir output/c3dfs \
4 | --exr_res 256
--------------------------------------------------------------------------------
/MOVIS/eval_batch_obj.sh:
--------------------------------------------------------------------------------
1 | python eval_batch.py \
2 | --config configs/inference_cobj.yaml \
3 | --output_dir output/c_obj \
4 | --exr_res 512
--------------------------------------------------------------------------------
/MOVIS/ldm/data/__pycache__/base.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/data/__pycache__/base.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/__pycache__/ema.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/__pycache__/ema.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/__pycache__/attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/__pycache__/attention.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/__pycache__/x_transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/__pycache__/x_transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jason-aplp/MOVIS-code/HEAD/MOVIS/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/MOVIS/train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | -t \
3 | --base configs/3d_mix.yaml \
4 | --gpus 0, \
5 | --scale_lr False \
6 | --num_nodes 1 \
7 | --seed 42 \
8 | --check_val_every_n_epoch 1 \
9 | --finetune_from sd-image-conditioned-v2.ckpt
--------------------------------------------------------------------------------
/MOVIS/eval_single.sh:
--------------------------------------------------------------------------------
1 | # the azimuth angle rotates counterclockwise
2 | python eval_single.py \
3 | --input_image 'assets/SUNRGBD/example_0/image.png' \
4 | --input_depth 'assets/SUNRGBD/example_0/depth.npy' \
5 | --input_mask 'assets/SUNRGBD/example_0/mask.png' \
6 | --azimuth 80 \
7 | --elevation 0 \
8 | --output_path 'test.png'
--------------------------------------------------------------------------------
/MOVIS/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu113
2 | torch==1.12.1
3 | torchvision==0.13.1
4 | albumentations==0.4.3
5 | opencv-python==4.11.0.86
6 | pudb==2019.2
7 | imageio==2.9.0
8 | imageio-ffmpeg==0.4.2
9 | pytorch-lightning==1.4.2
10 | omegaconf==2.1.1
11 | test-tube>=0.7.5
12 | streamlit>=0.73.1
13 | einops==0.3.0
14 | torch-fidelity==0.3.0
15 | transformers==4.22.2
16 | kornia==0.6
17 | webdataset==0.2.5
18 | torchmetrics==0.6.0
19 | fire==0.4.0
20 | gradio==3.21.0
21 | diffusers==0.12.1
22 | datasets[vision]==2.4.0
23 | carvekit-colab==4.1.0
24 | rich>=13.3.2
25 | lovely-numpy>=0.2.8
26 | lovely-tensors>=0.1.14
27 | plotly==5.13.1
28 | kiui
29 | lpips
30 | wandb
31 | OpenEXR
32 | imath
--------------------------------------------------------------------------------
/MOVIS/ldm/thirdp/psp/id_loss.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 | import torch
3 | from torch import nn
4 | from ldm.thirdp.psp.model_irse import Backbone
5 |
6 |
7 | class IDFeatures(nn.Module):
8 | def __init__(self, model_path):
9 | super(IDFeatures, self).__init__()
10 | print('Loading ResNet ArcFace')
11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14 | self.facenet.eval()
15 |
16 | def forward(self, x, crop=False):
17 | # Not sure of the image range here
18 | if crop:
19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
20 | x = x[:, :, 35:223, 32:220]
21 | x = self.face_pool(x)
22 | x_feats = self.facenet(x)
23 | return x_feats
24 |
--------------------------------------------------------------------------------
/MOVIS/ldm/data/dummy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import string
4 | from torch.utils.data import Dataset, Subset
5 |
6 | class DummyData(Dataset):
7 | def __init__(self, length, size):
8 | self.length = length
9 | self.size = size
10 |
11 | def __len__(self):
12 | return self.length
13 |
14 | def __getitem__(self, i):
15 | x = np.random.randn(*self.size)
16 | letters = string.ascii_lowercase
17 | y = ''.join(random.choice(letters) for _ in range(10))
18 | return {"jpg": x, "txt": y}
19 |
20 |
21 | class DummyDataWithEmbeddings(Dataset):
22 | def __init__(self, length, size, emb_size):
23 | self.length = length
24 | self.size = size
25 | self.emb_size = emb_size
26 |
27 | def __len__(self):
28 | return self.length
29 |
30 | def __getitem__(self, i):
31 | x = np.random.randn(*self.size)
32 | y = np.random.randn(*self.emb_size).astype(np.float32)
33 | return {"jpg": x, "txt": y}
34 |
35 |
--------------------------------------------------------------------------------
/MOVIS/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from abc import abstractmethod
4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
5 |
6 |
7 | class Txt2ImgIterableBaseDataset(IterableDataset):
8 | '''
9 | Define an interface to make the IterableDatasets for text2img data chainable
10 | '''
11 | def __init__(self, num_records=0, valid_ids=None, size=256):
12 | super().__init__()
13 | self.num_records = num_records
14 | self.valid_ids = valid_ids
15 | self.sample_ids = valid_ids
16 | self.size = size
17 |
18 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
19 |
20 | def __len__(self):
21 | return self.num_records
22 |
23 | @abstractmethod
24 | def __iter__(self):
25 | pass
26 |
27 |
28 | class PRNGMixin(object):
29 | """
30 | Adds a prng property which is a numpy RandomState which gets
31 | reinitialized whenever the pid changes to avoid synchronized sampling
32 | behavior when used in conjunction with multiprocessing.
33 | """
34 | @property
35 | def prng(self):
36 | currentpid = os.getpid()
37 | if getattr(self, "_initpid", None) != currentpid:
38 | self._initpid = currentpid
39 | self._prng = np.random.RandomState()
40 | return self._prng
41 |
--------------------------------------------------------------------------------
/MOVIS/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import rearrange  # used by renorm_thresholding below
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def renorm_thresholding(x0, value):
15 | # renorm
16 | pred_max = x0.max()
17 | pred_min = x0.min()
18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
20 |
21 | s = torch.quantile(
22 | rearrange(pred_x0, 'b ... -> b (...)').abs(),
23 | value,
24 | dim=-1
25 | )
26 | s.clamp_(min=1.0)
27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
28 |
29 | # clip by threshold
30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
31 |
32 | # temporary hack: numpy on cpu
33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
34 | pred_x0 = torch.tensor(pred_x0).to(x0.device)  # free function: use the input tensor's device
35 |
36 | # re.renorm
37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
39 | return pred_x0
40 |
41 |
42 | def norm_thresholding(x0, value):
43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
44 | return x0 * (value / s)
45 |
46 |
47 | def spatial_norm_thresholding(x0, value):
48 | # b c h w
49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
50 | return x0 * (value / s)
--------------------------------------------------------------------------------
/MOVIS/ldm/extras.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from omegaconf import OmegaConf
3 | import torch
4 | from ldm.util import instantiate_from_config
5 | import logging
6 | from contextlib import contextmanager
7 |
8 |
9 |
10 |
11 | @contextmanager
12 | def all_logging_disabled(highest_level=logging.CRITICAL):
13 | """
14 | A context manager that will prevent any logging messages
15 | triggered during the body from being processed.
16 |
17 | :param highest_level: the maximum logging level in use.
18 | This would only need to be changed if a custom level greater than CRITICAL
19 | is defined.
20 |
21 | https://gist.github.com/simon-weber/7853144
22 | """
23 | # two kind-of hacks here:
24 | # * can't get the highest logging level in effect => delegate to the user
25 | # * can't get the current module-level override => use an undocumented
26 | # (but non-private!) interface
27 |
28 | previous_level = logging.root.manager.disable
29 |
30 | logging.disable(highest_level)
31 |
32 | try:
33 | yield
34 | finally:
35 | logging.disable(previous_level)
36 |
37 | def load_training_dir(train_dir, device, epoch="last"):
38 | """Load a checkpoint and config from training directory"""
39 | train_dir = Path(train_dir)
40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
42 | config = list(train_dir.rglob(f"*-project.yaml"))
43 | assert len(config) > 0, f"didn't find any config in {train_dir}"
44 | if len(config) > 1:
45 | print(f"found {len(config)} matching config files")
46 | config = sorted(config)[-1]
47 | print(f"selecting {config}")
48 | else:
49 | config = config[0]
50 |
51 |
52 | config = OmegaConf.load(config)
53 | return load_model_from_config(config, ckpt[0], device)
54 |
55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
56 | """Loads a model from config and a ckpt
57 | if config is a path will use omegaconf to load
58 | """
59 | if isinstance(config, (str, Path)):
60 | config = OmegaConf.load(config)
61 |
62 | with all_logging_disabled():
63 | print(f"Loading model from {ckpt}")
64 | pl_sd = torch.load(ckpt, map_location="cpu")
65 | global_step = pl_sd["global_step"]
66 | sd = pl_sd["state_dict"]
67 | model = instantiate_from_config(config.model)
68 | m, u = model.load_state_dict(sd, strict=False)
69 | if len(m) > 0 and verbose:
70 | print("missing keys:")
71 | print(m)
72 | if len(u) > 0 and verbose:
73 | print("unexpected keys:")
74 | model.to(device)
75 | model.eval()
76 | model.cond_stage_model.device = device
77 | return model
--------------------------------------------------------------------------------
/MOVIS/ldm/thirdp/psp/model_irse.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
5 |
6 | """
7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8 | """
9 |
10 |
11 | class Backbone(Module):
12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
13 | super(Backbone, self).__init__()
14 | assert input_size in [112, 224], "input_size should be 112 or 224"
15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
17 | blocks = get_blocks(num_layers)
18 | if mode == 'ir':
19 | unit_module = bottleneck_IR
20 | elif mode == 'ir_se':
21 | unit_module = bottleneck_IR_SE
22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
23 | BatchNorm2d(64),
24 | PReLU(64))
25 | if input_size == 112:
26 | self.output_layer = Sequential(BatchNorm2d(512),
27 | Dropout(drop_ratio),
28 | Flatten(),
29 | Linear(512 * 7 * 7, 512),
30 | BatchNorm1d(512, affine=affine))
31 | else:
32 | self.output_layer = Sequential(BatchNorm2d(512),
33 | Dropout(drop_ratio),
34 | Flatten(),
35 | Linear(512 * 14 * 14, 512),
36 | BatchNorm1d(512, affine=affine))
37 |
38 | modules = []
39 | for block in blocks:
40 | for bottleneck in block:
41 | modules.append(unit_module(bottleneck.in_channel,
42 | bottleneck.depth,
43 | bottleneck.stride))
44 | self.body = Sequential(*modules)
45 |
46 | def forward(self, x):
47 | x = self.input_layer(x)
48 | x = self.body(x)
49 | x = self.output_layer(x)
50 | return l2_norm(x)
51 |
52 |
53 | def IR_50(input_size):
54 | """Constructs a ir-50 model."""
55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
56 | return model
57 |
58 |
59 | def IR_101(input_size):
60 | """Constructs a ir-101 model."""
61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
62 | return model
63 |
64 |
65 | def IR_152(input_size):
66 | """Constructs a ir-152 model."""
67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
68 | return model
69 |
70 |
71 | def IR_SE_50(input_size):
72 | """Constructs a ir_se-50 model."""
73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
74 | return model
75 |
76 |
77 | def IR_SE_101(input_size):
78 | """Constructs a ir_se-101 model."""
79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
80 | return model
81 |
82 |
83 | def IR_SE_152(input_size):
84 | """Constructs a ir_se-152 model."""
85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
86 | return model
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MOVIS: Enhancing Multi-Object Novel View Synthesis for Indoor Scenes
2 |
3 | This repository contains the official implementation for [MOVIS: Enhancing Multi-Object Novel View Synthesis for Indoor Scenes](https://arxiv.org/abs/2412.11457).
4 |
5 | ### [Project Page](https://jason-aplp.github.io/MOVIS/) | [Paper](https://arxiv.org/abs/2412.11457) | [Weights](https://huggingface.co/datasets/JasonAplp/MOVIS/blob/main/last.ckpt) | [Dataset](https://huggingface.co/datasets/JasonAplp/MOVIS/tree/main) | [Rendering_Scripts](https://github.com/Jason-aplp/MOVIS-render)
6 |
7 |
8 |
9 |
10 | ## Install
11 |
12 | ```bash
13 | conda create -n movis python=3.9
14 | conda activate movis
15 | cd MOVIS
16 | pip install -r requirements.txt
17 | git clone https://github.com/CompVis/taming-transformers.git
18 | pip install -e taming-transformers/
19 | git clone https://github.com/openai/CLIP.git
20 | pip install -e CLIP/
21 | ```
22 | Download the checkpoint and put it under `MOVIS`.
23 |
24 | ## Single-Image inference
25 |
26 | ```bash
27 | bash eval_single.sh
28 | ```
29 | Revise the parameters in the script accordingly if you want to run a different example.
30 | We use [SAM](https://github.com/facebookresearch/segment-anything) and [Depth-FM](https://github.com/CompVis/depth-fm) to obtain the estimated mask and depth. The background area in the depth map should be cropped out.
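A minimal preprocessing sketch for that last step, assuming non-zero mask pixels mark the objects (the `depth_cropped.npy` output name is only illustrative, not part of the repository):
```python
# Hypothetical sketch: zero out the background of an estimated depth map using the object mask.
import numpy as np
from PIL import Image

depth = np.load("assets/SUNRGBD/example_0/depth.npy")             # (H, W) depth values
mask = np.array(Image.open("assets/SUNRGBD/example_0/mask.png"))  # object mask

if mask.ndim == 3:                # drop colour/alpha channels if present
    mask = mask[..., 0]
foreground = mask > 0             # assumption: non-zero pixels belong to the objects

depth_cropped = np.where(foreground, depth, 0.0).astype(np.float32)
np.save("depth_cropped.npy", depth_cropped)  # pass this file via --input_depth
```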
31 |
32 | ## Dataset inference
33 | Download [C_Obj](https://huggingface.co/datasets/JasonAplp/MOVIS/tree/main/C_Obj) or [C3DFS_test split](https://huggingface.co/datasets/JasonAplp/MOVIS/tree/main/MOVIS-test) for benchmarking.
34 |
35 | ```bash
36 | bash eval_batch_3d.sh
37 | bash eval_batch_obj.sh
38 | ```
39 | Revise the dataset path (data-params-root_dir) in the `configs/inference_cobj.yaml` and `configs/inference_c3dfs.yaml` files before running the inference scripts.
40 |
41 | Note that we also provide the models used in C_Obj. If you only want to use the renderings for benchmarking, change the path to the `renderings` folder.
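If you prefer to update the path programmatically rather than edit the YAML by hand, a small OmegaConf sketch along these lines should work (the `data.params.root_dir` key follows the `data-params-root_dir` hint above; the dataset location is a placeholder):
```python
# Hypothetical helper: point an inference config at the local benchmark data
# before running eval_batch_3d.sh or eval_batch_obj.sh.
from omegaconf import OmegaConf

cfg_path = "configs/inference_c3dfs.yaml"          # or configs/inference_cobj.yaml
cfg = OmegaConf.load(cfg_path)

print("current root_dir:", cfg.data.params.root_dir)
cfg.data.params.root_dir = "/path/to/MOVIS-test"   # placeholder: your dataset location
OmegaConf.save(cfg, cfg_path)                      # write the updated config back
```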
42 |
43 | ## Training
44 | Download the image-conditioned Stable Diffusion checkpoint released by Lambda Labs:
45 | ```bash
46 | wget https://cv.cs.columbia.edu/zero123/assets/sd-image-conditioned-v2.ckpt
47 | ```
48 | Download the dataset from [here](https://huggingface.co/datasets/JasonAplp/MOVIS/tree/main); the dataset structure should look like this:
49 | ```
50 | MOVIS-train/
51 | 000000_004999/
52 | 0/
53 | 1/
54 | ...
55 | 095000_099999/
56 | train_path.json
57 | ```
58 | Run the training script:
59 | ```bash
60 | bash train.sh
61 | ```
62 | Revise the dataset path (data-params-root_dir) in the `configs/3d_mix.yaml` file before running the training script.
63 | Note that this training script is set up for an 8-GPU system, each GPU with 80 GB of VRAM. If you have smaller GPUs, consider using a smaller batch size together with gradient accumulation to obtain a similar effective batch size.
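As a back-of-the-envelope check (assuming `batch_size` in `configs/3d_mix.yaml` is per GPU and that gradient accumulation multiplies the effective batch size, e.g. via Lightning's `accumulate_grad_batches`):
```python
# Rough arithmetic sketch: effective batch = per-GPU batch * GPUs * accumulation steps.
def effective_batch_size(per_gpu_batch, num_gpus, accumulate_grad_batches=1):
    return per_gpu_batch * num_gpus * accumulate_grad_batches

print(effective_batch_size(172, 8))     # 1376, the reference 8-GPU setup
print(effective_batch_size(43, 8, 4))   # 1376, an illustrative smaller-batch alternative
```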
64 |
65 |
66 | ## Acknowledgement
67 | This repository is based on [Zero-1-to-3](https://github.com/cvlab-columbia/zero123). We would like to thank the authors of this work for publicly releasing their code.
68 |
69 | ## Citation
70 | ```
71 | @article{lu2024movis,
72 | title={MOVIS: Enhancing Multi-Object Novel View Synthesis for Indoor Scenes},
73 | author={Lu, Ruijie and Chen, Yixin and Ni, Junfeng and Jia, Baoxiong and Liu, Yu and Wan, Diwen and Zeng, Gang and Huang, Siyuan},
74 | journal={arXiv preprint arXiv:2412.11457},
75 | year={2024}
76 | }
77 | ```
--------------------------------------------------------------------------------
/MOVIS/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
/MOVIS/ldm/guidance.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from scipy import interpolate
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from IPython.display import clear_output
7 | import abc
8 |
9 |
10 | class GuideModel(torch.nn.Module, abc.ABC):
11 | def __init__(self) -> None:
12 | super().__init__()
13 |
14 | @abc.abstractmethod
15 | def preprocess(self, x_img):
16 | pass
17 |
18 | @abc.abstractmethod
19 | def compute_loss(self, inp):
20 | pass
21 |
22 |
23 | class Guider(torch.nn.Module):
24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
25 | """Apply classifier guidance
26 |
27 | Specify a guidance scale as either a scalar
28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g.
29 | [(0, 10), (0.5, 20), (1, 50)]
30 | """
31 | super().__init__()
32 | self.sampler = sampler
33 | self.index = 0
34 | self.show = verbose
35 | self.guide_model = guide_model
36 | self.history = []
37 |
38 | if isinstance(scale, (Tuple, List)):
39 | times = np.array([x[0] for x in scale])
40 | values = np.array([x[1] for x in scale])
41 | self.scale_schedule = {"times": times, "values": values}
42 | else:
43 | self.scale_schedule = float(scale)
44 |
45 | self.ddim_timesteps = sampler.ddim_timesteps
46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
47 |
48 |
49 | def get_scales(self):
50 | if isinstance(self.scale_schedule, float):
51 | return len(self.ddim_timesteps)*[self.scale_schedule]
52 |
53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
55 | return interpolater(fractional_steps)
56 |
57 | def modify_score(self, model, e_t, x, t, c):
58 |
59 | # TODO look up index by t
60 | scale = self.get_scales()[self.index]
61 |
62 | if (scale == 0):
63 | return e_t
64 |
65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
66 | with torch.enable_grad():
67 | x_in = x.detach().requires_grad_(True)
68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
70 |
71 | inp = self.guide_model.preprocess(x_img)
72 | loss = self.guide_model.compute_loss(inp)
73 | grads = torch.autograd.grad(loss.sum(), x_in)[0]
74 | correction = grads * scale
75 |
76 | if self.show:
77 | clear_output(wait=True)
78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
81 | plt.axis('off')
82 | plt.show()
83 | plt.imshow(correction[0][0].detach().cpu())
84 | plt.axis('off')
85 | plt.show()
86 |
87 |
88 | e_t_mod = e_t - sqrt_1ma*correction
89 | if self.show:
90 | fig, axs = plt.subplots(1, 3)
91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
94 | plt.show()
95 | self.index += 1
96 | return e_t_mod
--------------------------------------------------------------------------------
/MOVIS/ldm/modules/evaluate/ssim.py:
--------------------------------------------------------------------------------
1 | # MIT Licence
2 |
3 | # Methods to predict the SSIM, taken from
4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
5 |
6 | from math import exp
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 |
12 | def gaussian(window_size, sigma):
13 | gauss = torch.Tensor(
14 | [
15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
16 | for x in range(window_size)
17 | ]
18 | )
19 | return gauss / gauss.sum()
20 |
21 |
22 | def create_window(window_size, channel):
23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 | window = Variable(
26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()
27 | )
28 | return window
29 |
30 |
31 | def _ssim(
32 | img1, img2, window, window_size, channel, mask=None, size_average=True
33 | ):
34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
36 |
37 | mu1_sq = mu1.pow(2)
38 | mu2_sq = mu2.pow(2)
39 | mu1_mu2 = mu1 * mu2
40 |
41 | sigma1_sq = (
42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
43 | - mu1_sq
44 | )
45 | sigma2_sq = (
46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
47 | - mu2_sq
48 | )
49 | sigma12 = (
50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
51 | - mu1_mu2
52 | )
53 |
54 | C1 = (0.01) ** 2
55 | C2 = (0.03) ** 2
56 |
57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
59 | )
60 |
61 | if not (mask is None):
62 | b = mask.size(0)
63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
65 | dim=1
66 | ).clamp(min=1)
67 | return ssim_map
68 |
69 |
70 |
71 |
72 |
73 | if size_average:
74 | return ssim_map.mean()
75 | else:
76 | return ssim_map.mean(1).mean(1).mean(1)
77 |
78 |
79 | class SSIM(torch.nn.Module):
80 | def __init__(self, window_size=11, size_average=True):
81 | super(SSIM, self).__init__()
82 | self.window_size = window_size
83 | self.size_average = size_average
84 | self.channel = 1
85 | self.window = create_window(window_size, self.channel)
86 |
87 | def forward(self, img1, img2, mask=None):
88 | (_, channel, _, _) = img1.size()
89 |
90 | if (
91 | channel == self.channel
92 | and self.window.data.type() == img1.data.type()
93 | ):
94 | window = self.window
95 | else:
96 | window = create_window(self.window_size, channel)
97 |
98 | if img1.is_cuda:
99 | window = window.cuda(img1.get_device())
100 | window = window.type_as(img1)
101 |
102 | self.window = window
103 | self.channel = channel
104 |
105 | return _ssim(
106 | img1,
107 | img2,
108 | window,
109 | self.window_size,
110 | channel,
111 | mask,
112 | self.size_average,
113 | )
114 |
115 |
116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True):
117 | (_, channel, _, _) = img1.size()
118 | window = create_window(window_size, channel)
119 |
120 | if img1.is_cuda:
121 | window = window.cuda(img1.get_device())
122 | window = window.type_as(img1)
123 |
124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average)
125 |
--------------------------------------------------------------------------------
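A minimal usage sketch for the module above; random tensors stand in for real image batches, and both inputs are assumed to lie in a comparable value range:

    import torch
    from ldm.modules.evaluate.ssim import SSIM, ssim

    img1 = torch.rand(2, 3, 64, 64)
    img2 = torch.rand(2, 3, 64, 64)

    # functional form: returns the mean SSIM over the batch
    print(ssim(img1, img2, window_size=11).item())

    # module form: caches the Gaussian window between calls
    metric = SSIM(window_size=11)
    print(metric(img1, img2).item())

    # with a mask, the score is averaged per sample over the masked region
    mask = torch.ones(2, 1, 64, 64)
    print(ssim(img1, img2, mask=mask))
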
/MOVIS/configs/3d_mix.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "image_target"
11 | cond_stage_key: "image_cond"
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 |
19 | timestep_sampling_config:
20 | target: 'Gaussian'
21 | params:
22 | var: 200
23 |
24 | cape: True
25 | depth1: True
26 | mask_super: True
27 | schedule: True
28 |
29 | scheduler_config: # 10000 warmup steps
30 | target: ldm.lr_scheduler.LambdaLinearScheduler
31 | params:
32 | warm_up_steps: [ 100 ]
33 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
34 | f_start: [ 1.e-6 ]
35 | f_max: [ 1. ]
36 | f_min: [ 1. ]
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 8
43 | out_channels: 8
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
80 | params:
81 | cape: True
82 |
83 |
84 | data:
85 | target: ldm.data.simple.ObjaverseDataModuleFromConfig
86 | params:
87 | debug: True
88 | root_dir: /path/to/training_data
89 | batch_size: 172
90 | num_workers: 16
91 | total_view: 12
92 | depth1: True
93 | mask_super: True
94 | schedule: True
95 | train:
96 | validation: False
97 | image_transforms:
98 | size: 256
99 |
100 | validation:
101 | validation: True
102 | image_transforms:
103 | size: 256
104 |
105 |
106 | lightning:
107 | find_unused_parameters: false
108 | metrics_over_trainsteps_checkpoint: True
109 | modelcheckpoint:
110 | params:
111 | every_n_train_steps: 500
112 | callbacks:
113 | image_logger:
114 | target: main.ImageLogger
115 | params:
116 | batch_frequency: 500
117 | max_images: 32
118 | increase_log_steps: False
119 | log_first_step: True
120 | log_all_val: True
121 | log_images_kwargs:
122 | use_ema_scope: False
123 | inpaint: False
124 | plot_progressive_rows: False
125 | plot_diffusion_rows: False
126 | N: 32
127 | unconditional_guidance_scale: 3.0
128 | unconditional_guidance_label: [""]
129 |
130 | trainer:
131 | benchmark: True
132 | # val_check_interval: 5000000 # really sorry
133 | # val_check_interval: 100 # really sorry
134 | num_sanity_val_steps: 0
135 | accumulate_grad_batches: 1
136 |
137 | logger:
138 | name: "wandb"
139 | entity: "scene123"
140 | offline: False
141 |
--------------------------------------------------------------------------------
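A sketch of how a config in this layout is typically consumed, assuming the usual OmegaConf plus `instantiate_from_config` pattern (the helper lives in `ldm/util.py`); the checkpoint path is a placeholder, not a file shipped with the repository:

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/3d_mix.yaml")
    model = instantiate_from_config(config.model)   # builds the LatentDiffusion target with its params

    ckpt = torch.load("/path/to/checkpoint.ckpt", map_location="cpu")   # placeholder path
    model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()
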
/MOVIS/configs/inference_c3dfs.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "image_target"
11 | cond_stage_key: "image_cond"
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 |
19 | timestep_sampling_config:
20 | target: 'Gaussian'
21 | params:
22 | var: 200
23 |
24 | cape: True
25 | depth1: True
26 | mask_super: True
27 | schedule: True
28 |
29 | scheduler_config: # 10000 warmup steps
30 | target: ldm.lr_scheduler.LambdaLinearScheduler
31 | params:
32 | warm_up_steps: [ 100 ]
33 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
34 | f_start: [ 1.e-6 ]
35 | f_max: [ 1. ]
36 | f_min: [ 1. ]
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 8
43 | out_channels: 8
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
80 | params:
81 | cape: True
82 |
83 |
84 | data:
85 | target: ldm.data.simple.ObjaverseDataModuleFromConfig
86 | params:
87 | debug: True
88 | root_dir: /path/to/c3dfs_test
89 | batch_size: 172
90 | num_workers: 16
91 | total_view: 12
92 | depth1: True
93 | mask_super: True
94 | schedule: True
95 | train:
96 | validation: False
97 | image_transforms:
98 | size: 256
99 |
100 | validation:
101 | validation: True
102 | image_transforms:
103 | size: 256
104 |
105 |
106 | lightning:
107 | find_unused_parameters: false
108 | metrics_over_trainsteps_checkpoint: True
109 | modelcheckpoint:
110 | params:
111 | every_n_train_steps: 500
112 | callbacks:
113 | image_logger:
114 | target: main.ImageLogger
115 | params:
116 | batch_frequency: 500
117 | max_images: 32
118 | increase_log_steps: False
119 | log_first_step: True
120 | log_all_val: True
121 | log_images_kwargs:
122 | use_ema_scope: False
123 | inpaint: False
124 | plot_progressive_rows: False
125 | plot_diffusion_rows: False
126 | N: 32
127 | unconditional_guidance_scale: 3.0
128 | unconditional_guidance_label: [""]
129 |
130 | trainer:
131 | benchmark: True
132 | # val_check_interval: 5000000 # really sorry
133 | # val_check_interval: 100 # really sorry
134 | num_sanity_val_steps: 0
135 | accumulate_grad_batches: 1
136 |
137 | logger:
138 | name: "wandb"
139 | entity: "scene123"
140 | offline: False
141 |
--------------------------------------------------------------------------------
/MOVIS/configs/inference_cobj.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "image_target"
11 | cond_stage_key: "image_cond"
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 |
19 | timestep_sampling_config:
20 | target: 'Gaussian'
21 | params:
22 | var: 200
23 |
24 | cape: True
25 | depth1: True
26 | mask_super: True
27 | schedule: True
28 |
29 | scheduler_config: # 10000 warmup steps
30 | target: ldm.lr_scheduler.LambdaLinearScheduler
31 | params:
32 | warm_up_steps: [ 100 ]
33 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
34 | f_start: [ 1.e-6 ]
35 | f_max: [ 1. ]
36 | f_min: [ 1. ]
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 8
43 | out_channels: 8
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
80 | params:
81 | cape: True
82 |
83 |
84 | data:
85 | target: ldm.data.simple.ObjaverseDataModuleFromConfig
86 | params:
87 | debug: True
88 | root_dir: /path/to/compositional_objaverse
89 | batch_size: 172
90 | num_workers: 16
91 | total_view: 12
92 | depth1: True
93 | mask_super: True
94 | schedule: True
95 | train:
96 | validation: False
97 | image_transforms:
98 | size: 256
99 |
100 | validation:
101 | validation: True
102 | image_transforms:
103 | size: 256
104 |
105 |
106 | lightning:
107 | find_unused_parameters: false
108 | metrics_over_trainsteps_checkpoint: True
109 | modelcheckpoint:
110 | params:
111 | every_n_train_steps: 500
112 | callbacks:
113 | image_logger:
114 | target: main.ImageLogger
115 | params:
116 | batch_frequency: 500
117 | max_images: 32
118 | increase_log_steps: False
119 | log_first_step: True
120 | log_all_val: True
121 | log_images_kwargs:
122 | use_ema_scope: False
123 | inpaint: False
124 | plot_progressive_rows: False
125 | plot_diffusion_rows: False
126 | N: 32
127 | unconditional_guidance_scale: 3.0
128 | unconditional_guidance_label: [""]
129 |
130 | trainer:
131 | benchmark: True
132 | # val_check_interval: 5000000 # really sorry
133 | # val_check_interval: 100 # really sorry
134 | num_sanity_val_steps: 0
135 | accumulate_grad_batches: 1
136 |
137 | logger:
138 | name: "wandb"
139 | entity: "scene123"
140 | offline: False
141 |
--------------------------------------------------------------------------------
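The `timestep_sampling_config` block (target `'Gaussian'`, `var: 200`) appears in all three configs; the actual sampler lives elsewhere in the codebase. Purely as an illustration of what Gaussian timestep sampling can look like, the mean, the clipping, and the reading of `var` as a variance below are assumptions, not the repository's implementation:

    import torch

    def sample_timesteps_gaussian(batch_size, num_timesteps=1000, mean=500.0, var=200.0):
        # Illustrative only: draw diffusion timesteps from a clipped normal
        # distribution instead of the usual uniform distribution over [0, T).
        t = torch.normal(mean, var ** 0.5, size=(batch_size,))
        return t.round().clamp(0, num_timesteps - 1).long()

    print(sample_timesteps_gaussian(8))
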
/MOVIS/ldm/thirdp/psp/helpers.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from collections import namedtuple
4 | import torch
5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
6 |
7 | """
8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
9 | """
10 |
11 |
12 | class Flatten(Module):
13 | def forward(self, input):
14 | return input.view(input.size(0), -1)
15 |
16 |
17 | def l2_norm(input, axis=1):
18 | norm = torch.norm(input, 2, axis, True)
19 | output = torch.div(input, norm)
20 | return output
21 |
22 |
23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
24 | """ A named tuple describing a ResNet block. """
25 |
26 |
27 | def get_block(in_channel, depth, num_units, stride=2):
28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
29 |
30 |
31 | def get_blocks(num_layers):
32 | if num_layers == 50:
33 | blocks = [
34 | get_block(in_channel=64, depth=64, num_units=3),
35 | get_block(in_channel=64, depth=128, num_units=4),
36 | get_block(in_channel=128, depth=256, num_units=14),
37 | get_block(in_channel=256, depth=512, num_units=3)
38 | ]
39 | elif num_layers == 100:
40 | blocks = [
41 | get_block(in_channel=64, depth=64, num_units=3),
42 | get_block(in_channel=64, depth=128, num_units=13),
43 | get_block(in_channel=128, depth=256, num_units=30),
44 | get_block(in_channel=256, depth=512, num_units=3)
45 | ]
46 | elif num_layers == 152:
47 | blocks = [
48 | get_block(in_channel=64, depth=64, num_units=3),
49 | get_block(in_channel=64, depth=128, num_units=8),
50 | get_block(in_channel=128, depth=256, num_units=36),
51 | get_block(in_channel=256, depth=512, num_units=3)
52 | ]
53 | else:
54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
55 | return blocks
56 |
57 |
58 | class SEModule(Module):
59 | def __init__(self, channels, reduction):
60 | super(SEModule, self).__init__()
61 | self.avg_pool = AdaptiveAvgPool2d(1)
62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
63 | self.relu = ReLU(inplace=True)
64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
65 | self.sigmoid = Sigmoid()
66 |
67 | def forward(self, x):
68 | module_input = x
69 | x = self.avg_pool(x)
70 | x = self.fc1(x)
71 | x = self.relu(x)
72 | x = self.fc2(x)
73 | x = self.sigmoid(x)
74 | return module_input * x
75 |
76 |
77 | class bottleneck_IR(Module):
78 | def __init__(self, in_channel, depth, stride):
79 | super(bottleneck_IR, self).__init__()
80 | if in_channel == depth:
81 | self.shortcut_layer = MaxPool2d(1, stride)
82 | else:
83 | self.shortcut_layer = Sequential(
84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
85 | BatchNorm2d(depth)
86 | )
87 | self.res_layer = Sequential(
88 | BatchNorm2d(in_channel),
89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
91 | )
92 |
93 | def forward(self, x):
94 | shortcut = self.shortcut_layer(x)
95 | res = self.res_layer(x)
96 | return res + shortcut
97 |
98 |
99 | class bottleneck_IR_SE(Module):
100 | def __init__(self, in_channel, depth, stride):
101 | super(bottleneck_IR_SE, self).__init__()
102 | if in_channel == depth:
103 | self.shortcut_layer = MaxPool2d(1, stride)
104 | else:
105 | self.shortcut_layer = Sequential(
106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
107 | BatchNorm2d(depth)
108 | )
109 | self.res_layer = Sequential(
110 | BatchNorm2d(in_channel),
111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
112 | PReLU(depth),
113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
114 | BatchNorm2d(depth),
115 | SEModule(depth, 16)
116 | )
117 |
118 | def forward(self, x):
119 | shortcut = self.shortcut_layer(x)
120 | res = self.res_layer(x)
121 | return res + shortcut
--------------------------------------------------------------------------------
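A small sketch of how the helpers above compose: `get_blocks` describes the stage layout and the bottleneck modules realize it. The 64-channel, 112x112 input assumes the usual 3-to-64 input stem that sits in front of this body in the accompanying `model_irse.py`:

    import torch
    from torch.nn import Sequential
    from ldm.thirdp.psp.helpers import get_blocks, bottleneck_IR_SE

    modules = [
        bottleneck_IR_SE(b.in_channel, b.depth, b.stride)
        for stage in get_blocks(50)
        for b in stage
    ]
    body = Sequential(*modules)

    x = torch.randn(1, 64, 112, 112)   # assumed output of a 3->64 input stem
    print(body(x).shape)               # torch.Size([1, 512, 7, 7])
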
/MOVIS/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
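A usage sketch matching the `scheduler_config` in the configs above: the scheduler object is callable and returns a learning-rate multiplier, so it plugs directly into `torch.optim.lr_scheduler.LambdaLR`. The single parameter and the AdamW choice are only for illustration:

    import torch
    from ldm.lr_scheduler import LambdaLinearScheduler

    sched_fn = LambdaLinearScheduler(
        warm_up_steps=[100], f_min=[1.0], f_max=[1.0],
        f_start=[1e-6], cycle_lengths=[10000000000000],
    )

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([param], lr=1.0e-4)     # base LR; the schedule only scales it
    lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

    for step in range(200):
        opt.step()
        lr_sched.step()
    print(opt.param_groups[0]["lr"])   # ~1e-4 after the 100-step linear warm-up
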
/MOVIS/ldm/data/inpainting/synthetic_mask.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw
2 | import numpy as np
3 |
4 | settings = {
5 | "256narrow": {
6 | "p_irr": 1,
7 | "min_n_irr": 4,
8 | "max_n_irr": 50,
9 | "max_l_irr": 40,
10 | "max_w_irr": 10,
11 | "min_n_box": None,
12 | "max_n_box": None,
13 | "min_s_box": None,
14 | "max_s_box": None,
15 | "marg": None,
16 | },
17 | "256train": {
18 | "p_irr": 0.5,
19 | "min_n_irr": 1,
20 | "max_n_irr": 5,
21 | "max_l_irr": 200,
22 | "max_w_irr": 100,
23 | "min_n_box": 1,
24 | "max_n_box": 4,
25 | "min_s_box": 30,
26 | "max_s_box": 150,
27 | "marg": 10,
28 | },
29 | "512train": { # TODO: experimental
30 | "p_irr": 0.5,
31 | "min_n_irr": 1,
32 | "max_n_irr": 5,
33 | "max_l_irr": 450,
34 | "max_w_irr": 250,
35 | "min_n_box": 1,
36 | "max_n_box": 4,
37 | "min_s_box": 30,
38 | "max_s_box": 300,
39 | "marg": 10,
40 | },
41 | "512train-large": { # TODO: experimental
42 | "p_irr": 0.5,
43 | "min_n_irr": 1,
44 | "max_n_irr": 5,
45 | "max_l_irr": 450,
46 | "max_w_irr": 400,
47 | "min_n_box": 1,
48 | "max_n_box": 4,
49 | "min_s_box": 75,
50 | "max_s_box": 450,
51 | "marg": 10,
52 | },
53 | }
54 |
55 |
56 | def gen_segment_mask(mask, start, end, brush_width):
57 | mask = mask > 0
58 | mask = (255 * mask).astype(np.uint8)
59 | mask = Image.fromarray(mask)
60 | draw = ImageDraw.Draw(mask)
61 | draw.line([start, end], fill=255, width=brush_width, joint="curve")
62 | mask = np.array(mask) / 255
63 | return mask
64 |
65 |
66 | def gen_box_mask(mask, masked):
67 | x_0, y_0, w, h = masked
68 | mask[y_0:y_0 + h, x_0:x_0 + w] = 1
69 | return mask
70 |
71 |
72 | def gen_round_mask(mask, masked, radius):
73 | x_0, y_0, w, h = masked
74 | xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
75 |
76 | mask = mask > 0
77 | mask = (255 * mask).astype(np.uint8)
78 | mask = Image.fromarray(mask)
79 | draw = ImageDraw.Draw(mask)
80 | draw.rounded_rectangle(xy, radius=radius, fill=255)
81 | mask = np.array(mask) / 255
82 | return mask
83 |
84 |
85 | def gen_large_mask(prng, img_h, img_w,
86 | marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
87 | min_n_box, max_n_box, min_s_box, max_s_box):
88 | """
89 | img_h: int, an image height
90 | img_w: int, an image width
91 | marg: int, a margin for a box starting coordinate
92 | p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
93 |
94 | min_n_irr: int, min number of segments
95 | max_n_irr: int, max number of segments
96 | max_l_irr: max length of a segment in polygonal chain
97 | max_w_irr: max width of a segment in polygonal chain
98 |
99 | min_n_box: int, min bound for the number of box primitives
100 | max_n_box: int, max bound for the number of box primitives
101 | min_s_box: int, min length of a box side
102 | max_s_box: int, max length of a box side
103 | """
104 |
105 | mask = np.zeros((img_h, img_w))
106 | uniform = prng.randint
107 |
108 | if np.random.uniform(0, 1) < p_irr: # generate polygonal chain
109 | n = uniform(min_n_irr, max_n_irr) # sample number of segments
110 |
111 | for _ in range(n):
112 | y = uniform(0, img_h) # sample a starting point
113 | x = uniform(0, img_w)
114 |
115 | a = uniform(0, 360) # sample angle
116 | l = uniform(10, max_l_irr) # sample segment length
117 | w = uniform(5, max_w_irr) # sample a segment width
118 |
119 | # draw segment starting from (x,y) to (x_,y_) using brush of width w
120 | x_ = x + l * np.sin(a)
121 | y_ = y + l * np.cos(a)
122 |
123 | mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
124 | x, y = x_, y_
125 | else: # generate Box masks
126 | n = uniform(min_n_box, max_n_box) # sample number of rectangles
127 |
128 | for _ in range(n):
129 | h = uniform(min_s_box, max_s_box) # sample box shape
130 | w = uniform(min_s_box, max_s_box)
131 |
132 | x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box
133 | y_0 = uniform(marg, img_h - marg - h)
134 |
135 | if np.random.uniform(0, 1) < 0.5:
136 | mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
137 | else:
138 | r = uniform(0, 60) # sample radius
139 | mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
140 | return mask
141 |
142 |
143 | make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
144 | make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
145 | make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
146 | make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
147 |
148 |
149 | MASK_MODES = {
150 | "256train": make_lama_mask,
151 | "256narrow": make_narrow_lama_mask,
152 | "512train": make_512_lama_mask,
153 | "512train-large": make_512_lama_mask_large
154 | }
155 |
156 | if __name__ == "__main__":
157 | import sys
158 |
159 | out = sys.argv[1]
160 |
161 | prng = np.random.RandomState(1)
162 | kwargs = settings["256train"]
163 | mask = gen_large_mask(prng, 256, 256, **kwargs)
164 | mask = (255 * mask).astype(np.uint8)
165 | mask = Image.fromarray(mask)
166 | mask.save(out)
167 |
--------------------------------------------------------------------------------
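A minimal way to exercise the mask generators above outside the `__main__` block (the output filename is arbitrary):

    import numpy as np
    from PIL import Image
    from ldm.data.inpainting.synthetic_mask import MASK_MODES

    prng = np.random.RandomState(0)
    mask = MASK_MODES["256train"](prng, 256, 256)    # float array in [0, 1], shape (256, 256)
    Image.fromarray((255 * mask).astype(np.uint8)).save("mask_example.png")
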
/MOVIS/ldm/modules/evaluate/frechet_video_distance.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python2, python3
17 | """Minimal Reference implementation for the Frechet Video Distance (FVD).
18 |
19 | FVD is a metric for the quality of video generation models. It is inspired by
20 | the FID (Frechet Inception Distance) used for images, but uses a different
21 | embedding to be better suited to videos.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 |
29 | import six
30 | import tensorflow.compat.v1 as tf
31 | import tensorflow_gan as tfgan
32 | import tensorflow_hub as hub
33 |
34 |
35 | def preprocess(videos, target_resolution):
36 | """Runs some preprocessing on the videos for I3D model.
37 |
38 | Args:
39 | videos: [batch_size, num_frames, height, width, depth] The videos to be
40 | preprocessed. We don't care about the specific dtype of the videos, it can
41 | be anything that tf.image.resize_bilinear accepts. Values are expected to
42 | be in the range 0-255.
43 | target_resolution: (width, height): target video resolution
44 |
45 | Returns:
46 | videos: [batch_size, num_frames, height, width, depth]
47 | """
48 | videos_shape = list(videos.shape)
49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
52 | output_videos = tf.reshape(resized_videos, target_shape)
53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
54 | return scaled_videos
55 |
56 |
57 | def _is_in_graph(tensor_name):
58 |   """Checks whether a given tensor exists in the graph."""
59 | try:
60 | tf.get_default_graph().get_tensor_by_name(tensor_name)
61 | except KeyError:
62 | return False
63 | return True
64 |
65 |
66 | def create_id3_embedding(videos,warmup=False,batch_size=16):
67 |   """Embeds the given videos using the Inflated 3D Convolution network.
68 |
69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the
70 | first call.
71 |
72 | Args:
73 | videos: [batch_size, num_frames, height=224, width=224, depth=3].
74 | Expected range is [-1, 1].
75 |
76 | Returns:
77 | embedding: [batch_size, embedding_size]. embedding_size depends
78 | on the model used.
79 |
80 | Raises:
81 | ValueError: when a provided embedding_layer is not supported.
82 | """
83 |
84 | # batch_size = 16
85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
86 |
87 |
88 | # Making sure that we import the graph separately for
89 | # each different input video tensor.
90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
91 | videos.name).replace(":", "_")
92 |
93 |
94 |
95 | assert_ops = [
96 | tf.Assert(
97 | tf.reduce_max(videos) <= 1.001,
98 | ["max value in frame is > 1", videos]),
99 | tf.Assert(
100 | tf.reduce_min(videos) >= -1.001,
101 | ["min value in frame is < -1", videos]),
102 | tf.assert_equal(
103 | tf.shape(videos)[0],
104 | batch_size, ["invalid frame batch size: ",
105 | tf.shape(videos)],
106 | summarize=6),
107 | ]
108 | with tf.control_dependencies(assert_ops):
109 | videos = tf.identity(videos)
110 |
111 | module_scope = "%s_apply_default/" % module_name
112 |
113 | # To check whether the module has already been loaded into the graph, we look
114 | # for a given tensor name. If this tensor name exists, we assume the function
115 | # has been called before and the graph was imported. Otherwise we import it.
116 | # Note: in theory, the tensor could exist, but have wrong shapes.
117 |   # This will happen if create_id3_embedding is called with a frames_placeholder
118 | # of wrong size/batch size, because even though that will throw a tf.Assert
119 | # on graph-execution time, it will insert the tensor (with wrong shape) into
120 | # the graph. This is why we need the following assert.
121 | if warmup:
122 | video_batch_size = int(videos.shape[0])
123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
125 | if not _is_in_graph(tensor_name):
126 | i3d_model = hub.Module(module_spec, name=module_name)
127 | i3d_model(videos)
128 |
129 | # gets the kinetics-i3d-400-logits layer
130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
132 | return tensor
133 |
134 |
135 | def calculate_fvd(real_activations,
136 | generated_activations):
137 | """Returns a list of ops that compute metrics as funcs of activations.
138 |
139 | Args:
140 | real_activations: [num_samples, embedding_size]
141 | generated_activations: [num_samples, embedding_size]
142 |
143 | Returns:
144 | A scalar that contains the requested FVD.
145 | """
146 | return tfgan.eval.frechet_classifier_distance_from_activations(
147 | real_activations, generated_activations)
148 |
--------------------------------------------------------------------------------
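A sketch of computing FVD with the functions above, following the usual TF1 graph-mode recipe. It assumes `tensorflow`, `tensorflow-gan`, and `tensorflow-hub` are installed and that the I3D weights can be fetched from tf.hub; the batch size of 16 matches the default assert in `create_id3_embedding`:

    import tensorflow.compat.v1 as tf
    from ldm.modules.evaluate import frechet_video_distance as fvd

    tf.disable_eager_execution()
    with tf.Graph().as_default():
        real = tf.zeros([16, 15, 64, 64, 3])            # [batch, frames, H, W, 3], values in 0..255
        fake = tf.ones([16, 15, 64, 64, 3]) * 255.
        result = fvd.calculate_fvd(
            fvd.create_id3_embedding(fvd.preprocess(real, (224, 224))),
            fvd.create_id3_embedding(fvd.preprocess(fake, (224, 224))),
        )
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())
            print("FVD:", sess.run(result))
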
/MOVIS/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
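A sketch of the two-pass (generator / discriminator) interface above, with random tensors standing in for a real autoencoder batch. It assumes the `taming-transformers` dependency is available and that the LPIPS weights can be downloaded on first use; `DiagonalGaussianDistribution` is the posterior class from `ldm/modules/distributions/distributions.py`. In real training `last_layer` is the decoder's final weight tensor; `rec` is passed here only so the adaptive-weight gradients exist:

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
    from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

    loss_fn = LPIPSWithDiscriminator(disc_start=50001)

    x = torch.randn(2, 3, 64, 64)                          # "inputs"
    rec = torch.randn(2, 3, 64, 64, requires_grad=True)    # "reconstructions"
    posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))  # mean/logvar stacked on dim 1

    # optimizer_idx=0: autoencoder loss (rec + KL + adaptively weighted GAN term)
    g_loss, g_log = loss_fn(x, rec, posterior, optimizer_idx=0, global_step=0, last_layer=rec)
    # optimizer_idx=1: discriminator loss on detached reconstructions
    d_loss, d_log = loss_fn(x, rec, posterior, optimizer_idx=1, global_step=0)
    print(g_loss.item(), d_loss.item())
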
/MOVIS/ldm/data/nerf_like.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | import json
4 | import numpy as np
5 | import torch
6 | import imageio
7 | import math
8 | import cv2
9 | from torchvision import transforms
10 |
11 | def cartesian_to_spherical(xyz):
12 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
13 | xy = xyz[:,0]**2 + xyz[:,1]**2
14 | z = np.sqrt(xy + xyz[:,2]**2)
15 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
16 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
17 | azimuth = np.arctan2(xyz[:,1], xyz[:,0])
18 | return np.array([theta, azimuth, z])
19 |
20 |
21 | def get_T(T_target, T_cond):
22 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
23 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
24 |
25 | d_theta = theta_target - theta_cond
26 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
27 | d_z = z_target - z_cond
28 |
29 | d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
30 | return d_T
31 |
32 | def get_spherical(T_target, T_cond):
33 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
34 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
35 |
36 | d_theta = theta_target - theta_cond
37 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
38 | d_z = z_target - z_cond
39 |
40 | d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()])
41 | return d_T
42 |
43 | class RTMV(Dataset):
44 | def __init__(self, root_dir='datasets/RTMV/google_scanned',\
45 | first_K=64, resolution=256, load_target=False):
46 | self.root_dir = root_dir
47 | self.scene_list = sorted(next(os.walk(root_dir))[1])
48 | self.resolution = resolution
49 | self.first_K = first_K
50 | self.load_target = load_target
51 |
52 | def __len__(self):
53 | return len(self.scene_list)
54 |
55 | def __getitem__(self, idx):
56 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
57 | with open(os.path.join(scene_dir, 'transforms.json'), "r") as f:
58 | meta = json.load(f)
59 | imgs = []
60 | poses = []
61 | for i_img in range(self.first_K):
62 | meta_img = meta['frames'][i_img]
63 |
64 | if i_img == 0 or self.load_target:
65 | img_path = os.path.join(scene_dir, meta_img['file_path'])
66 | img = imageio.imread(img_path)
67 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
68 | imgs.append(img)
69 |
70 | c2w = meta_img['transform_matrix']
71 | poses.append(c2w)
72 |
73 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
74 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
75 | imgs = imgs * 2 - 1. # convert to stable diffusion range
76 | poses = torch.tensor(np.array(poses).astype(np.float32))
77 | return imgs, poses
78 |
79 | def blend_rgba(self, img):
80 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
81 | return img
82 |
83 |
84 | class GSO(Dataset):
85 | def __init__(self, root_dir='datasets/GoogleScannedObjects',\
86 | split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'):
87 | self.root_dir = root_dir
88 | with open(os.path.join(root_dir, '%s.json' % split), "r") as f:
89 | self.scene_list = json.load(f)
90 | self.resolution = resolution
91 | self.first_K = first_K
92 | self.load_target = load_target
93 | self.name = name
94 |
95 | def __len__(self):
96 | return len(self.scene_list)
97 |
98 | def __getitem__(self, idx):
99 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
100 | with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f:
101 | meta = json.load(f)
102 | imgs = []
103 | poses = []
104 | for i_img in range(self.first_K):
105 | meta_img = meta['frames'][i_img]
106 |
107 | if i_img == 0 or self.load_target:
108 | img_path = os.path.join(scene_dir, meta_img['file_path'])
109 | img = imageio.imread(img_path)
110 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
111 | imgs.append(img)
112 |
113 | c2w = meta_img['transform_matrix']
114 | poses.append(c2w)
115 |
116 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
117 | mask = imgs[:, :, :, -1]
118 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
119 | imgs = imgs * 2 - 1. # convert to stable diffusion range
120 | poses = torch.tensor(np.array(poses).astype(np.float32))
121 | return imgs, poses
122 |
123 | def blend_rgba(self, img):
124 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
125 | return img
126 |
127 | class WILD(Dataset):
128 | def __init__(self, root_dir='data/nerf_wild',\
129 | first_K=33, resolution=256, load_target=False):
130 | self.root_dir = root_dir
131 | self.scene_list = sorted(next(os.walk(root_dir))[1])
132 | self.resolution = resolution
133 | self.first_K = first_K
134 | self.load_target = load_target
135 |
136 | def __len__(self):
137 | return len(self.scene_list)
138 |
139 | def __getitem__(self, idx):
140 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
141 | with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f:
142 | meta = json.load(f)
143 | imgs = []
144 | poses = []
145 | for i_img in range(self.first_K):
146 | meta_img = meta['frames'][i_img]
147 |
148 | if i_img == 0 or self.load_target:
149 | img_path = os.path.join(scene_dir, meta_img['file_path'])
150 | img = imageio.imread(img_path + '.png')
151 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
152 | imgs.append(img)
153 |
154 | c2w = meta_img['transform_matrix']
155 | poses.append(c2w)
156 |
157 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
158 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
159 | imgs = imgs * 2 - 1. # convert to stable diffusion range
160 | poses = torch.tensor(np.array(poses).astype(np.float32))
161 | return imgs, poses
162 |
163 | def blend_rgba(self, img):
164 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
165 | return img
--------------------------------------------------------------------------------
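The pose helpers at the top of this file are self-contained; a quick check of the relative-pose encoding they produce (the camera positions below are arbitrary):

    import numpy as np
    from ldm.data.nerf_like import get_T, get_spherical

    cam_cond = np.array([1.0, 0.0, 1.0])     # conditioning-view camera position
    cam_target = np.array([0.0, 1.0, 1.5])   # target-view camera position

    print(get_T(cam_target, cam_cond))          # [d_theta, sin(d_azimuth), cos(d_azimuth), d_radius]
    print(get_spherical(cam_target, cam_cond))  # [d_theta (deg), d_azimuth (deg), d_radius]
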
/MOVIS/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
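A small check of `measure_perplexity` above, which reports how evenly a batch of predicted codebook indices uses the `n_embed` entries (importing this module pulls in the `taming-transformers` dependency):

    import torch
    from ldm.modules.losses.vqperceptual import measure_perplexity

    indices = torch.randint(0, 512, (4096,))                 # near-uniform codebook usage
    perplexity, cluster_use = measure_perplexity(indices, n_embed=512)
    print(perplexity.item(), cluster_use.item())             # perplexity close to 512

    collapsed = torch.zeros(4096, dtype=torch.long)          # everything maps to code 0
    print(measure_perplexity(collapsed, n_embed=512)[0].item())   # perplexity close to 1
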
/MOVIS/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 | return{el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
198 | disable_self_attn=False):
199 | super().__init__()
200 | self.disable_self_attn = disable_self_attn
201 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
202 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
203 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
204 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
205 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
206 | self.norm1 = nn.LayerNorm(dim)
207 | self.norm2 = nn.LayerNorm(dim)
208 | self.norm3 = nn.LayerNorm(dim)
209 | self.checkpoint = checkpoint
210 |
211 | def forward(self, x, context=None):
212 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
213 |
214 | def _forward(self, x, context=None):
215 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
216 | x = self.attn2(self.norm2(x), context=context) + x
217 | x = self.ff(self.norm3(x)) + x
218 | return x
219 |
220 |
221 | class SpatialTransformer(nn.Module):
222 | """
223 | Transformer block for image-like data.
224 | First, project the input (aka embedding)
225 | and reshape to b, t, d.
226 | Then apply standard transformer action.
227 | Finally, reshape to image
228 | """
229 | def __init__(self, in_channels, n_heads, d_head,
230 | depth=1, dropout=0., context_dim=None,
231 | disable_self_attn=False):
232 | super().__init__()
233 | self.in_channels = in_channels
234 | inner_dim = n_heads * d_head
235 | self.norm = Normalize(in_channels)
236 |
237 | self.proj_in = nn.Conv2d(in_channels,
238 | inner_dim,
239 | kernel_size=1,
240 | stride=1,
241 | padding=0)
242 |
243 | self.transformer_blocks = nn.ModuleList(
244 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
245 | disable_self_attn=disable_self_attn)
246 | for d in range(depth)]
247 | )
248 |
249 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
250 | in_channels,
251 | kernel_size=1,
252 | stride=1,
253 | padding=0))
254 |
255 | def forward(self, x, context=None):
256 | # note: if no context is given, cross-attention defaults to self-attention
257 | b, c, h, w = x.shape
258 | x_in = x
259 | x = self.norm(x)
260 | x = self.proj_in(x)
261 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
262 | for block in self.transformer_blocks:
263 | x = block(x, context=context)
264 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
265 | x = self.proj_out(x)
266 | return x + x_in
267 |
--------------------------------------------------------------------------------
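A minimal, self-contained sketch (not part of the repository) of the attention arithmetic that CrossAttention.forward above implements: split the channels into heads, fold the heads into the batch, take a scaled dot-product, softmax over the keys, then merge the heads back. The toy shapes below are illustrative only.

# illustrative sketch; shapes are arbitrary
import torch
from torch import einsum
from einops import rearrange

b, n_q, n_kv, heads, dim_head = 2, 16, 77, 8, 64
q = torch.randn(b, n_q, heads * dim_head)
k = torch.randn(b, n_kv, heads * dim_head)
v = torch.randn(b, n_kv, heads * dim_head)

# split the channel dimension into heads and fold the heads into the batch
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5   # scaled dot-product
attn = sim.softmax(dim=-1)                                       # normalize over the keys
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
print(out.shape)  # torch.Size([2, 16, 512])
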
/MOVIS/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 |         betas = torch.clamp(betas, min=0, max=0.999)  # keep a tensor so the final .numpy() works
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
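Illustrative usage sketch (not part of the repository): the two most commonly used helpers above, make_beta_schedule and timestep_embedding, tie together as follows; the cumulative product of (1 - beta) is what the DDIM helpers consume as alphacums.

import numpy as np
import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # (1000,) numpy array
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)        # cumulative alpha_bar_t, fed to the DDIM helpers

t = torch.randint(0, 1000, (4,))                        # one timestep per batch element
emb = timestep_embedding(t, dim=320)                    # (4, 320) sinusoidal embedding
print(betas.shape, alphas_cumprod[-1], emb.shape)
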
/MOVIS/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torchvision
4 | import torch
5 | from torch import optim
6 | import numpy as np
7 |
8 | from inspect import isfunction
9 | from PIL import Image, ImageDraw, ImageFont
10 |
11 | import os
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from PIL import Image
15 | import torch
16 | import time
17 | import cv2
18 | from carvekit.api.high import HiInterface
19 | import PIL
20 |
21 | def pil_rectangle_crop(im):
22 | width, height = im.size # Get dimensions
23 |
24 | if width <= height:
25 | left = 0
26 | right = width
27 | top = (height - width)/2
28 | bottom = (height + width)/2
29 | else:
30 |
31 | top = 0
32 | bottom = height
33 | left = (width - height) / 2
34 |         right = (width + height) / 2
35 |
36 | # Crop the center of the image
37 | im = im.crop((left, top, right, bottom))
38 | return im
39 |
40 | def add_margin(pil_img, color, size=256):
41 | width, height = pil_img.size
42 | result = Image.new(pil_img.mode, (size, size), color)
43 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
44 | return result
45 |
46 |
47 | def create_carvekit_interface():
48 | # Check doc strings for more information
49 | interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
50 | batch_size_seg=5,
51 | batch_size_matting=1,
52 | device='cuda' if torch.cuda.is_available() else 'cpu',
53 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
54 | matting_mask_size=2048,
55 | trimap_prob_threshold=231,
56 | trimap_dilation=30,
57 | trimap_erosion_iters=5,
58 | fp16=False)
59 |
60 | return interface
61 |
62 |
63 | def load_and_preprocess(interface, input_im):
64 | '''
65 | :param input_im (PIL Image).
66 | :return image (H, W, 3) array in [0, 1].
67 | '''
68 | # See https://github.com/Ir1d/image-background-remove-tool
69 | image = input_im.convert('RGB')
70 |
71 | image_without_background = interface([image])[0]
72 | image_without_background = np.array(image_without_background)
73 | est_seg = image_without_background > 127
74 | image = np.array(image)
75 | foreground = est_seg[:, : , -1].astype(np.bool_)
76 | image[~foreground] = [255., 255., 255.]
77 | x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
78 | image = image[y:y+h, x:x+w, :]
79 | image = PIL.Image.fromarray(np.array(image))
80 |
81 |     # resize so the long edge is at most 200, then pad to a 256 x 256 canvas
82 | image.thumbnail([200, 200], Image.Resampling.LANCZOS)
83 | image = add_margin(image, (255, 255, 255), size=256)
84 | image = np.array(image)
85 |
86 | return image
87 |
88 |
89 | def log_txt_as_img(wh, xc, size=10):
90 | # wh a tuple of (width, height)
91 | # xc a list of captions to plot
92 | b = len(xc)
93 | txts = list()
94 | for bi in range(b):
95 | txt = Image.new("RGB", wh, color="white")
96 | draw = ImageDraw.Draw(txt)
97 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
98 | nc = int(40 * (wh[0] / 256))
99 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
100 |
101 | try:
102 | draw.text((0, 0), lines, fill="black", font=font)
103 | except UnicodeEncodeError:
104 |             print("Can't encode string for logging. Skipping.")
105 |
106 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
107 | txts.append(txt)
108 | txts = np.stack(txts)
109 | txts = torch.tensor(txts)
110 | return txts
111 |
112 |
113 | def ismap(x):
114 | if not isinstance(x, torch.Tensor):
115 | return False
116 | return (len(x.shape) == 4) and (x.shape[1] > 3)
117 |
118 |
119 | def isimage(x):
120 | if not isinstance(x,torch.Tensor):
121 | return False
122 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
123 |
124 |
125 | def exists(x):
126 | return x is not None
127 |
128 |
129 | def default(val, d):
130 | if exists(val):
131 | return val
132 | return d() if isfunction(d) else d
133 |
134 |
135 | def mean_flat(tensor):
136 | """
137 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
138 | Take the mean over all non-batch dimensions.
139 | """
140 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
141 |
142 |
143 | def count_params(model, verbose=False):
144 | total_params = sum(p.numel() for p in model.parameters())
145 | if verbose:
146 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
147 | return total_params
148 |
149 |
150 | def instantiate_from_config(config):
151 |     if "target" not in config:
152 | if config == '__is_first_stage__':
153 | return None
154 | elif config == "__is_unconditional__":
155 | return None
156 | raise KeyError("Expected key `target` to instantiate.")
157 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
158 |
159 |
160 | def get_obj_from_str(string, reload=False):
161 | module, cls = string.rsplit(".", 1)
162 | if reload:
163 | module_imp = importlib.import_module(module)
164 | importlib.reload(module_imp)
165 | return getattr(importlib.import_module(module, package=None), cls)
166 |
167 |
168 | class AdamWwithEMAandWings(optim.Optimizer):
169 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
170 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
171 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
172 | ema_power=1., param_names=()):
173 | """AdamW that saves EMA versions of the parameters."""
174 | if not 0.0 <= lr:
175 | raise ValueError("Invalid learning rate: {}".format(lr))
176 | if not 0.0 <= eps:
177 | raise ValueError("Invalid epsilon value: {}".format(eps))
178 | if not 0.0 <= betas[0] < 1.0:
179 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
180 | if not 0.0 <= betas[1] < 1.0:
181 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
182 | if not 0.0 <= weight_decay:
183 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
184 | if not 0.0 <= ema_decay <= 1.0:
185 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
186 | defaults = dict(lr=lr, betas=betas, eps=eps,
187 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
188 | ema_power=ema_power, param_names=param_names)
189 | super().__init__(params, defaults)
190 |
191 | def __setstate__(self, state):
192 | super().__setstate__(state)
193 | for group in self.param_groups:
194 | group.setdefault('amsgrad', False)
195 |
196 | @torch.no_grad()
197 | def step(self, closure=None):
198 | """Performs a single optimization step.
199 | Args:
200 | closure (callable, optional): A closure that reevaluates the model
201 | and returns the loss.
202 | """
203 | loss = None
204 | if closure is not None:
205 | with torch.enable_grad():
206 | loss = closure()
207 |
208 | for group in self.param_groups:
209 | params_with_grad = []
210 | grads = []
211 | exp_avgs = []
212 | exp_avg_sqs = []
213 | ema_params_with_grad = []
214 | state_sums = []
215 | max_exp_avg_sqs = []
216 | state_steps = []
217 | amsgrad = group['amsgrad']
218 | beta1, beta2 = group['betas']
219 | ema_decay = group['ema_decay']
220 | ema_power = group['ema_power']
221 |
222 | for p in group['params']:
223 | if p.grad is None:
224 | continue
225 | params_with_grad.append(p)
226 | if p.grad.is_sparse:
227 | raise RuntimeError('AdamW does not support sparse gradients')
228 | grads.append(p.grad)
229 |
230 | state = self.state[p]
231 |
232 | # State initialization
233 | if len(state) == 0:
234 | state['step'] = 0
235 | # Exponential moving average of gradient values
236 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
237 | # Exponential moving average of squared gradient values
238 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
239 | if amsgrad:
240 | # Maintains max of all exp. moving avg. of sq. grad. values
241 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
242 | # Exponential moving average of parameter values
243 | state['param_exp_avg'] = p.detach().float().clone()
244 |
245 | exp_avgs.append(state['exp_avg'])
246 | exp_avg_sqs.append(state['exp_avg_sq'])
247 | ema_params_with_grad.append(state['param_exp_avg'])
248 |
249 | if amsgrad:
250 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
251 |
252 | # update the steps for each param group update
253 | state['step'] += 1
254 | # record the step after step update
255 | state_steps.append(state['step'])
256 |
257 | optim._functional.adamw(params_with_grad,
258 | grads,
259 | exp_avgs,
260 | exp_avg_sqs,
261 | max_exp_avg_sqs,
262 | state_steps,
263 | amsgrad=amsgrad,
264 | beta1=beta1,
265 | beta2=beta2,
266 | lr=group['lr'],
267 | weight_decay=group['weight_decay'],
268 | eps=group['eps'],
269 | maximize=False)
270 |
271 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
272 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
273 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
274 |
275 | return loss
--------------------------------------------------------------------------------
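Illustrative usage sketch (not part of the repository): instantiate_from_config resolves the dotted path in "target" via get_obj_from_str and forwards "params" as keyword arguments; the torch.nn.Linear target below is only a stand-in for the model configs used elsewhere in MOVIS.

from ldm.util import instantiate_from_config, get_obj_from_str

config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)                  # equivalent to torch.nn.Linear(8, 4)
assert get_obj_from_str(config["target"]) is type(layer)
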
/MOVIS/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
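Illustrative sketch (not part of the repository): the top-k accuracy that NoisyLatentImageClassifier.compute_top_k reports, written as a standalone function on random logits.

import torch

def top_k_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int) -> float:
    _, top_ks = torch.topk(logits, k, dim=1)                 # (B, k) predicted class indices
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)   # 1.0 where the label is among the top k
    return hits.mean().item()

logits = torch.randn(16, 10)            # 16 samples, 10 classes
labels = torch.randint(0, 10, (16,))
print(top_k_accuracy(logits, labels, k=1), top_k_accuracy(logits, labels, k=5))
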
/MOVIS/ldm/modules/evaluate/torch_frechet_video_distance.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
2 | import os
3 | import numpy as np
4 | import io
5 | import re
6 | import requests
7 | import html
8 | import hashlib
9 | import urllib
10 | import urllib.request
11 | import scipy.linalg
12 | import multiprocessing as mp
13 | import glob
14 |
15 |
16 | from tqdm import tqdm
17 | from typing import Any, List, Tuple, Union, Dict, Callable
18 |
19 | from torchvision.io import read_video
20 | import torch; torch.set_grad_enabled(False)
21 | from einops import rearrange
22 |
23 | from nitro.util import isvideo
24 |
25 | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:
26 | print('Calculate frechet distance...')
27 | m = np.square(mu_sample - mu_ref).sum()
28 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member
29 | fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
30 |
31 | return float(fid)
32 |
33 |
34 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
35 | mu = feats.mean(axis=0) # [d]
36 | sigma = np.cov(feats, rowvar=False) # [d, d]
37 |
38 | return mu, sigma
39 |
40 |
41 | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
42 | """Download the given URL and return a binary-mode file object to access the data."""
43 | assert num_attempts >= 1
44 |
45 | # Doesn't look like an URL scheme so interpret it as a local filename.
46 | if not re.match('^[a-z]+://', url):
47 | return url if return_filename else open(url, "rb")
48 |
49 | # Handle file URLs. This code handles unusual file:// patterns that
50 | # arise on Windows:
51 | #
52 | # file:///c:/foo.txt
53 | #
54 | # which would translate to a local '/c:/foo.txt' filename that's
55 | # invalid. Drop the forward slash for such pathnames.
56 | #
57 | # If you touch this code path, you should test it on both Linux and
58 | # Windows.
59 | #
60 |     # Some internet resources suggest using urllib.request.url2pathname(),
61 |     # but that converts forward slashes to backslashes and this causes
62 | # its own set of problems.
63 | if url.startswith('file://'):
64 | filename = urllib.parse.urlparse(url).path
65 | if re.match(r'^/[a-zA-Z]:', filename):
66 | filename = filename[1:]
67 | return filename if return_filename else open(filename, "rb")
68 |
69 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
70 |
71 | # Download.
72 | url_name = None
73 | url_data = None
74 | with requests.Session() as session:
75 | if verbose:
76 | print("Downloading %s ..." % url, end="", flush=True)
77 | for attempts_left in reversed(range(num_attempts)):
78 | try:
79 | with session.get(url) as res:
80 | res.raise_for_status()
81 | if len(res.content) == 0:
82 | raise IOError("No data received")
83 |
84 | if len(res.content) < 8192:
85 | content_str = res.content.decode("utf-8")
86 | if "download_warning" in res.headers.get("Set-Cookie", ""):
87 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
88 | if len(links) == 1:
89 | url = requests.compat.urljoin(url, links[0])
90 | raise IOError("Google Drive virus checker nag")
91 | if "Google Drive - Quota exceeded" in content_str:
92 | raise IOError("Google Drive download quota exceeded -- please try again later")
93 |
94 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
95 | url_name = match[1] if match else url
96 | url_data = res.content
97 | if verbose:
98 | print(" done")
99 | break
100 | except KeyboardInterrupt:
101 | raise
102 | except:
103 | if not attempts_left:
104 | if verbose:
105 | print(" failed")
106 | raise
107 | if verbose:
108 | print(".", end="", flush=True)
109 |
110 | # Return data as file object.
111 | assert not return_filename
112 | return io.BytesIO(url_data)
113 |
114 | def load_video(ip):
115 | vid, *_ = read_video(ip)
116 | vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8)
117 | return vid
118 |
119 | def get_data_from_str(input_str,nprc = None):
120 | assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory'
121 | vid_filelist = glob.glob(os.path.join(input_str,'*.mp4'))
122 | print(f'Found {len(vid_filelist)} videos in dir {input_str}')
123 |
124 | if nprc is None:
125 | try:
126 | nprc = mp.cpu_count()
127 | except NotImplementedError:
128 |             print('WARNING: cpu_count() not available, using only 1 cpu for video loading')
129 | nprc = 1
130 |
131 | pool = mp.Pool(processes=nprc)
132 |
133 | vids = []
134 | for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'):
135 | vids.append(v)
136 |
137 |
138 | vids = torch.stack(vids,dim=0).float()
139 |
140 | return vids
141 |
142 | def get_stats(stats):
143 | assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'
144 |
145 | print(f'Using precomputed statistics under {stats}')
146 | stats = np.load(stats)
147 | stats = {key: stats[key] for key in stats.files}
148 |
149 | return stats
150 |
151 |
152 |
153 |
154 | @torch.no_grad()
155 | def compute_fvd(ref_input, sample_input, bs=32,
156 | ref_stats=None,
157 | sample_stats=None,
158 | nprc_load=None):
159 |
160 |
161 |
162 | calc_stats = ref_stats is None or sample_stats is None
163 |
164 | if calc_stats:
165 |
166 | only_ref = sample_stats is not None
167 | only_sample = ref_stats is not None
168 |
169 |
170 | if isinstance(ref_input,str) and not only_sample:
171 | ref_input = get_data_from_str(ref_input,nprc_load)
172 |
173 | if isinstance(sample_input, str) and not only_ref:
174 | sample_input = get_data_from_str(sample_input, nprc_load)
175 |
176 | stats = compute_statistics(sample_input,ref_input,
177 | device='cuda' if torch.cuda.is_available() else 'cpu',
178 | bs=bs,
179 | only_ref=only_ref,
180 | only_sample=only_sample)
181 |
182 | if only_ref:
183 | stats.update(get_stats(sample_stats))
184 | elif only_sample:
185 | stats.update(get_stats(ref_stats))
186 |
187 |
188 |
189 | else:
190 | stats = get_stats(sample_stats)
191 | stats.update(get_stats(ref_stats))
192 |
193 | fvd = compute_frechet_distance(**stats)
194 |
195 | return {'FVD' : fvd,}
196 |
197 |
198 | @torch.no_grad()
199 | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict:
200 | detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
201 | detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
202 |
203 | with open_url(detector_url, verbose=False) as f:
204 | detector = torch.jit.load(f).eval().to(device)
205 |
206 |
207 |
208 | assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'
209 |
210 | ref_embed, sample_embed = [], []
211 |
212 | info = f'Computing I3D activations for FVD score with batch size {bs}'
213 |
214 | if only_ref:
215 |
216 | if not isvideo(videos_real):
217 |             # if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
218 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
219 | print(videos_real.shape)
220 |
221 | if videos_real.shape[0] % bs == 0:
222 | n_secs = videos_real.shape[0] // bs
223 | else:
224 | n_secs = videos_real.shape[0] // bs + 1
225 |
226 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
227 |
228 | for ref_v in tqdm(videos_real, total=len(videos_real),desc=info):
229 |
230 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
231 | ref_embed.append(feats_ref)
232 |
233 | elif only_sample:
234 |
235 | if not isvideo(videos_fake):
236 |             # if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
237 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
238 | print(videos_fake.shape)
239 |
240 | if videos_fake.shape[0] % bs == 0:
241 | n_secs = videos_fake.shape[0] // bs
242 | else:
243 | n_secs = videos_fake.shape[0] // bs + 1
244 |
245 |         videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
246 | 
247 |         for sample_v in tqdm(videos_fake, total=len(videos_fake), desc=info):
248 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
249 | sample_embed.append(feats_sample)
250 |
251 |
252 | else:
253 |
254 | if not isvideo(videos_real):
255 |             # if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
256 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
257 |
258 | if not isvideo(videos_fake):
259 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
260 |
261 | if videos_fake.shape[0] % bs == 0:
262 | n_secs = videos_fake.shape[0] // bs
263 | else:
264 | n_secs = videos_fake.shape[0] // bs + 1
265 |
266 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
267 | videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
268 |
269 | for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info):
270 | # print(ref_v.shape)
271 | # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
272 | # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
273 |
274 |
275 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
276 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
277 | sample_embed.append(feats_sample)
278 | ref_embed.append(feats_ref)
279 |
280 | out = dict()
281 | if len(sample_embed) > 0:
282 | sample_embed = np.concatenate(sample_embed,axis=0)
283 | mu_sample, sigma_sample = compute_stats(sample_embed)
284 | out.update({'mu_sample': mu_sample,
285 | 'sigma_sample': sigma_sample})
286 |
287 | if len(ref_embed) > 0:
288 | ref_embed = np.concatenate(ref_embed,axis=0)
289 | mu_ref, sigma_ref = compute_stats(ref_embed)
290 | out.update({'mu_ref': mu_ref,
291 | 'sigma_ref': sigma_ref})
292 |
293 |
294 | return out
295 |
--------------------------------------------------------------------------------
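Illustrative sketch (not part of the repository): once I3D activations have been collected, the FVD is just the Frechet distance between two Gaussians fitted to the feature sets, which is exactly what compute_stats and compute_frechet_distance above do. Random features stand in for detector activations here.

import numpy as np
import scipy.linalg

def frechet_distance(feats_sample: np.ndarray, feats_ref: np.ndarray) -> float:
    mu_s, sigma_s = feats_sample.mean(axis=0), np.cov(feats_sample, rowvar=False)
    mu_r, sigma_r = feats_ref.mean(axis=0), np.cov(feats_ref, rowvar=False)
    m = np.square(mu_s - mu_r).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_s, sigma_r), disp=False)   # matrix square root
    return float(np.real(m + np.trace(sigma_s + sigma_r - s * 2)))

print(frechet_distance(np.random.randn(256, 64), np.random.randn(256, 64)))
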
/MOVIS/ldm/data/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import albumentations
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 | from abc import abstractmethod
9 |
10 |
11 | class CocoBase(Dataset):
12 | """needed for (image, caption, segmentation) pairs"""
13 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
14 | crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None):
15 | self.split = self.get_split()
16 | self.size = size
17 | if crop_size is None:
18 | self.crop_size = size
19 | else:
20 | self.crop_size = crop_size
21 |
22 | assert crop_type in [None, 'random', 'center']
23 | self.crop_type = crop_type
24 |         self.use_segmentation = use_segmentation
25 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot
26 | self.stuffthing = use_stuffthing # include thing in segmentation
27 | if self.onehot and not self.stuffthing:
28 |             raise NotImplementedError("One hot mode is only supported for the "
29 |                                       "stuffthings version because labels are stored "
30 |                                       "a bit differently.")
31 |
32 | data_json = datajson
33 | with open(data_json) as json_file:
34 | self.json_data = json.load(json_file)
35 | self.img_id_to_captions = dict()
36 | self.img_id_to_filepath = dict()
37 | self.img_id_to_segmentation_filepath = dict()
38 |
39 | assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
40 | f"captions_val{self.year()}.json"]
41 | # TODO currently hardcoded paths, would be better to follow logic in
42 | # cocstuff pixelmaps
43 |         if self.use_segmentation:
44 | if self.stuffthing:
45 | self.segmentation_prefix = (
46 | f"data/cocostuffthings/val{self.year()}" if
47 | data_json.endswith(f"captions_val{self.year()}.json") else
48 | f"data/cocostuffthings/train{self.year()}")
49 | else:
50 | self.segmentation_prefix = (
51 | f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
52 | data_json.endswith(f"captions_val{self.year()}.json") else
53 | f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")
54 |
55 | imagedirs = self.json_data["images"]
56 | self.labels = {"image_ids": list()}
57 | for imgdir in tqdm(imagedirs, desc="ImgToPath"):
58 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
59 | self.img_id_to_captions[imgdir["id"]] = list()
60 | pngfilename = imgdir["file_name"].replace("jpg", "png")
61 |             if self.use_segmentation:
62 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
63 | self.segmentation_prefix, pngfilename)
64 | if given_files is not None:
65 | if pngfilename in given_files:
66 | self.labels["image_ids"].append(imgdir["id"])
67 | else:
68 | self.labels["image_ids"].append(imgdir["id"])
69 |
70 | capdirs = self.json_data["annotations"]
71 | for capdir in tqdm(capdirs, desc="ImgToCaptions"):
72 |             # there are on average 5 captions per image
73 | #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
74 | self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])
75 |
76 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
77 | if self.split=="validation":
78 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
79 | else:
80 | # default option for train is random crop
81 | if self.crop_type in [None, 'random']:
82 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
83 | else:
84 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
85 | self.preprocessor = albumentations.Compose(
86 | [self.rescaler, self.cropper],
87 | additional_targets={"segmentation": "image"})
88 | if force_no_crop:
89 | self.rescaler = albumentations.Resize(height=self.size, width=self.size)
90 | self.preprocessor = albumentations.Compose(
91 | [self.rescaler],
92 | additional_targets={"segmentation": "image"})
93 |
94 | @abstractmethod
95 | def year(self):
96 | raise NotImplementedError()
97 |
98 | def __len__(self):
99 | return len(self.labels["image_ids"])
100 |
101 | def preprocess_image(self, image_path, segmentation_path=None):
102 | image = Image.open(image_path)
103 | if not image.mode == "RGB":
104 | image = image.convert("RGB")
105 | image = np.array(image).astype(np.uint8)
106 | if segmentation_path:
107 | segmentation = Image.open(segmentation_path)
108 | if not self.onehot and not segmentation.mode == "RGB":
109 | segmentation = segmentation.convert("RGB")
110 | segmentation = np.array(segmentation).astype(np.uint8)
111 | if self.onehot:
112 | assert self.stuffthing
113 | # stored in caffe format: unlabeled==255. stuff and thing from
114 | # 0-181. to be compatible with the labels in
115 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt
116 | # we shift stuffthing one to the right and put unlabeled in zero
117 | # as long as segmentation is uint8 shifting to right handles the
118 | # latter too
119 | assert segmentation.dtype == np.uint8
120 | segmentation = segmentation + 1
121 |
122 | processed = self.preprocessor(image=image, segmentation=segmentation)
123 |
124 | image, segmentation = processed["image"], processed["segmentation"]
125 | else:
126 | image = self.preprocessor(image=image,)['image']
127 |
128 | image = (image / 127.5 - 1.0).astype(np.float32)
129 | if segmentation_path:
130 | if self.onehot:
131 | assert segmentation.dtype == np.uint8
132 | # make it one hot
133 | n_labels = 183
134 | flatseg = np.ravel(segmentation)
135 |                 onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
136 | onehot[np.arange(flatseg.size), flatseg] = True
137 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
138 | segmentation = onehot
139 | else:
140 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
141 | return image, segmentation
142 | else:
143 | return image
144 |
145 | def __getitem__(self, i):
146 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
147 |         if self.use_segmentation:
148 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
149 | image, segmentation = self.preprocess_image(img_path, seg_path)
150 | else:
151 | image = self.preprocess_image(img_path)
152 | captions = self.img_id_to_captions[self.labels["image_ids"][i]]
153 | # randomly draw one of all available captions per image
154 | caption = captions[np.random.randint(0, len(captions))]
155 | example = {"image": image,
156 | #"caption": [str(caption[0])],
157 | "caption": caption,
158 | "img_path": img_path,
159 | "filename_": img_path.split(os.sep)[-1]
160 | }
161 |         if self.use_segmentation:
162 | example.update({"seg_path": seg_path, 'segmentation': segmentation})
163 | return example
164 |
165 |
166 | class CocoImagesAndCaptionsTrain2017(CocoBase):
167 | """returns a pair of (image, caption)"""
168 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
169 | super().__init__(size=size,
170 | dataroot="data/coco/train2017",
171 | datajson="data/coco/annotations/captions_train2017.json",
172 | onehot_segmentation=onehot_segmentation,
173 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
174 |
175 | def get_split(self):
176 | return "train"
177 |
178 | def year(self):
179 | return '2017'
180 |
181 |
182 | class CocoImagesAndCaptionsValidation2017(CocoBase):
183 | """returns a pair of (image, caption)"""
184 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
185 | given_files=None):
186 | super().__init__(size=size,
187 | dataroot="data/coco/val2017",
188 | datajson="data/coco/annotations/captions_val2017.json",
189 | onehot_segmentation=onehot_segmentation,
190 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
191 | given_files=given_files)
192 |
193 | def get_split(self):
194 | return "validation"
195 |
196 | def year(self):
197 | return '2017'
198 |
199 |
200 |
201 | class CocoImagesAndCaptionsTrain2014(CocoBase):
202 | """returns a pair of (image, caption)"""
203 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
204 | super().__init__(size=size,
205 | dataroot="data/coco/train2014",
206 | datajson="data/coco/annotations2014/annotations/captions_train2014.json",
207 | onehot_segmentation=onehot_segmentation,
208 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
209 | use_segmentation=False,
210 | crop_type=crop_type)
211 |
212 | def get_split(self):
213 | return "train"
214 |
215 | def year(self):
216 | return '2014'
217 |
218 | class CocoImagesAndCaptionsValidation2014(CocoBase):
219 | """returns a pair of (image, caption)"""
220 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
221 | given_files=None,crop_type='center',**kwargs):
222 | super().__init__(size=size,
223 | dataroot="data/coco/val2014",
224 | datajson="data/coco/annotations2014/annotations/captions_val2014.json",
225 | onehot_segmentation=onehot_segmentation,
226 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
227 | given_files=given_files,
228 | use_segmentation=False,
229 | crop_type=crop_type)
230 |
231 | def get_split(self):
232 | return "validation"
233 |
234 | def year(self):
235 | return '2014'
236 |
237 | if __name__ == '__main__':
238 | with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
239 | json_data = json.load(json_file)
240 | capdirs = json_data["annotations"]
241 | import pudb; pudb.set_trace()
242 | #d2 = CocoImagesAndCaptionsTrain2014(size=256)
243 | d2 = CocoImagesAndCaptionsValidation2014(size=256)
244 | print("constructed dataset.")
245 | print(f"length of {d2.__class__.__name__}: {len(d2)}")
246 |
247 | ex2 = d2[0]
248 | # ex3 = d3[0]
249 | # print(ex1["image"].shape)
250 | print(ex2["image"].shape)
251 | # print(ex3["image"].shape)
252 | # print(ex1["segmentation"].shape)
253 | print(ex2["caption"].__class__.__name__)
254 |
--------------------------------------------------------------------------------
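Illustrative sketch (not part of the repository): the image preprocessing that CocoBase.preprocess_image applies, reduced to its essentials (resize the shortest side, crop to a square, rescale to [-1, 1]). A random array stands in for a decoded COCO image.

import albumentations
import numpy as np

size = 256
preprocessor = albumentations.Compose([
    albumentations.SmallestMaxSize(max_size=size),       # resize so the shortest side equals size
    albumentations.CenterCrop(height=size, width=size),
])

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
image = preprocessor(image=image)["image"]
image = (image / 127.5 - 1.0).astype(np.float32)         # same normalization as preprocess_image
print(image.shape, image.min(), image.max())
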
/MOVIS/eval_single.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from functools import partial
3 | import argparse, os, sys, datetime, glob, importlib, csv
4 |
5 | import math
6 | import fire
7 | import gradio as gr
8 | import numpy as np
9 | import torch
10 | import yaml
11 | from einops import rearrange
12 | from ldm.models.diffusion.ddim import DDIMSampler
13 | from omegaconf import OmegaConf
14 | from PIL import Image
15 | from torch import autocast
16 | from torchvision import transforms
17 | from ldm.util import load_and_preprocess, instantiate_from_config
18 | from torch.utils.data import Dataset, DataLoader
19 | from pathlib import Path
20 | import matplotlib.pyplot as plt
21 | import random
22 | from pytorch_lightning.trainer import Trainer
23 | from pytorch_lightning import seed_everything
24 | import pytorch_lightning as pl
25 | from tqdm import tqdm
26 | import torchvision
27 | import torchvision.transforms as transforms
28 | import torchvision.models as models
29 | from omegaconf import DictConfig, ListConfig
30 | import json
31 |
32 | from skimage.metrics import structural_similarity as ssim
33 | from skimage.metrics import peak_signal_noise_ratio as psnr
34 |
35 | import cv2
36 | import einops
37 |
38 | # Create the parser
39 | parser = argparse.ArgumentParser(description="parse argument.")
40 |
41 | # Add arguments
42 | parser.add_argument("--input_image", type=str, help="path to input image")
43 | parser.add_argument("--input_depth", type=str, help="path to input depth")
44 | parser.add_argument("--input_mask", type=str, help="path to input mask")
45 | parser.add_argument("--azimuth", type=int, help="azimuth change")
46 | parser.add_argument("--elevation", type=int, help="elevation change")
47 | parser.add_argument("--output_path", type=str, help="path to output image")
48 | parser.add_argument("--default_dist", type=float, default=1.6)
49 | parser.add_argument("--default_elevation", type=int, default=15)
50 | parser.add_argument("--ddim_steps", type=int, default=50)
51 | parser.add_argument("--n_samples", type=int, default=1)
52 | parser.add_argument("--cfg_scale", type=float, default=3.0)
53 |
54 | # Parse the arguments
55 | args = parser.parse_args()
56 |
57 | from kiui.op import safe_normalize
58 | def look_at(campos, target, opengl=True):
59 | """construct pose rotation matrix by look-at.
60 |
61 | Args:
62 | campos (np.ndarray): camera position, float [3]
63 | target (np.ndarray): look at target, float [3]
64 | opengl (bool, optional): whether use opengl camera convention (forward direction is target --> camera). Defaults to True.
65 |
66 | Returns:
67 | np.ndarray: the camera pose rotation matrix, float [3, 3], normalized.
68 | """
69 |
70 | if not opengl:
71 | # forward is camera --> target
72 | forward_vector = safe_normalize(target - campos)
73 | up_vector = np.array([0, 1, 0], dtype=np.float32)
74 | right_vector = safe_normalize(np.cross(forward_vector, up_vector))
75 | up_vector = safe_normalize(np.cross(right_vector, forward_vector))
76 | else:
77 | # forward is target --> camera
78 | forward_vector = safe_normalize(campos - target)
79 | up_vector = np.array([0, 0, 1], dtype=np.float32)
80 | right_vector = safe_normalize(np.cross(up_vector, forward_vector))
81 | up_vector = safe_normalize(np.cross(forward_vector, right_vector))
82 | R = np.stack([right_vector, up_vector, forward_vector], axis=1)
83 | return R
84 |
85 |
86 | def load_model_from_config(config, ckpt, device, verbose=False):
87 | print(f'Loading model from {ckpt}')
88 | pl_sd = torch.load(ckpt, map_location='cpu')
89 | if 'global_step' in pl_sd:
90 | print(f'Global Step: {pl_sd["global_step"]}')
91 | sd = pl_sd['state_dict']
92 | model = instantiate_from_config(config.model)
93 | m, u = model.load_state_dict(sd, strict=False)
94 | if len(m) > 0 and verbose:
95 | print('missing keys:')
96 | print(m)
97 | if len(u) > 0 and verbose:
98 | print('unexpected keys:')
99 | print(u)
100 |
101 | model.to(device)
102 | model.eval()
103 | return model
104 |
105 | def get_T2(target_RT, cond_RT):
106 | delta = torch.from_numpy(np.linalg.inv(cond_RT) @ target_RT).flatten()
107 | return delta
108 |
109 |
110 | @torch.no_grad()
111 | def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, \
112 | ddim_eta, T, depth1, mask_super):
113 | precision_scope = autocast if precision=="autocast" else nullcontext
114 | with precision_scope("cuda"):
115 | with model.ema_scope():
116 | # c = model.get_learned_conditioning(input_im).tile(1, n_samples,1,1)
117 | # # T = torch.tensor([math.radians(x), math.sin(math.radians(y)), math.cos(math.radians(y)), z])
118 | # T = T[:, None, None, :].repeat(1, n_samples, 1, 1).to(c.device)
119 | c = model.get_learned_conditioning(input_im)
120 | T = T[:, None, :].to(c.device)
121 | T = einops.repeat(T, 'b l n -> b (l k) n', k=c.shape[1])
122 | c = torch.cat([c, T], dim=-1)
123 | c = model.cc_projection(c.float())
124 | cond = {}
125 | cond['c_crossattn'] = [c]
126 | c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
127 |             # reuse the latent computed on the previous line instead of encoding the input image twice
128 |             cond['c_concat'] = [c_concat.repeat(n_samples, 1, 1, 1)]
129 | if depth1 is not None:
130 | cond['depth1'] = [model.encode_first_stage((depth1.float().to(c.device))).mode().detach()]
131 | if mask_super is not None:
132 | cond["mask_cond"] = [model.encode_first_stage((mask_super.float().to(c.device))).mode().detach()]
133 | if scale != 1.0:
134 | uc = {}
135 | uc['c_concat'] = [torch.zeros(cond['c_concat'][0].shape[0], 4, h // 8, w // 8).to(c.device)]
136 | uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
137 | if depth1 is not None:
138 | uc['depth1'] = [torch.zeros(cond['c_concat'][0].shape[0], 4, h // 8, w // 8).to(c.device)]
139 | if mask_super is not None:
140 | uc["mask_cond"] = [torch.zeros(cond['c_concat'][0].shape[0], 4, h // 8, w // 8).to(c.device)]
141 | else:
142 | uc = None
143 | # breakpoint()
144 | shape = [4, h // 8, w // 8]
145 | mask_info = None
146 | if (mask_super is not None):
147 |
148 | samples_ddim, _, mask_info = sampler.sample(S=ddim_steps,
149 | conditioning=cond,
150 | batch_size=cond['c_concat'][0].shape[0],
151 | shape=shape,
152 | verbose=False,
153 | unconditional_guidance_scale=scale,
154 | unconditional_conditioning=uc,
155 | eta=ddim_eta,
156 | x_T=None)
157 |
158 | else:
159 | samples_ddim, _ = sampler.sample(S=ddim_steps,
160 | conditioning=cond,
161 | batch_size=cond['c_concat'][0].shape[0],
162 | shape=shape,
163 | verbose=False,
164 | unconditional_guidance_scale=scale,
165 | unconditional_conditioning=uc,
166 | eta=ddim_eta,
167 | x_T=None)
168 | print(samples_ddim.shape)
169 | # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
170 | x_samples_ddim = model.decode_first_stage(samples_ddim)
171 | return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
172 |
173 | def get_depth2(pth):
174 | depth = np.load(pth)
175 |
176 | # work on a copy so the raw depth map is left untouched
177 | new_depth = depth.copy()
178 | new_depth[new_depth > 1000] = new_depth[new_depth < 1000].max() * 2
179 | pixels = new_depth.ravel()
180 | depth_min = np.percentile(pixels, 2)
181 | depth_max = np.percentile(pixels, 98)
182 | if depth_max - depth_min < 1e-5:
183 | depth_min = new_depth.min()
184 | normalized_depth = ((new_depth - depth_min) / (depth_max - depth_min) - 0.5) * 2.0
185 | normalized_depth = cv2.resize(normalized_depth, (256, 256), interpolation=cv2.INTER_LINEAR)
186 | normalized_depth = np.tile(normalized_depth, (3, 1, 1))
187 | # print(normalized_depth.shape)
188 |
189 | depth_torch = torch.tensor(normalized_depth)
190 | return depth_torch
191 |
192 | def load_mask2(path):
193 | mask = plt.imread(path)
194 | mask = np.uint8(mask * 255.)
195 | mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_LINEAR)
196 |
197 | mask_min = mask.min()
198 | mask_max = mask.max()
199 | normalized_mask = ((mask - mask_min) / (mask_max - mask_min) - 0.5) * 2.0
200 | # normalized_mask = cv2.resize(normalized_mask, (256, 256), interpolation=cv2.INTER_LINEAR)
201 | normalized_mask = np.tile(normalized_mask, (3, 1, 1))
202 | return normalized_mask
203 |
204 |
205 |
206 | device = f"cuda:0"
207 | if not torch.cuda.is_available():
208 | device = "cpu"
209 | cfg_file = 'configs/3d_mix.yaml'
210 | config = OmegaConf.load(cfg_file)
211 | if config.model.params.depth1:
212 | config.model.params.unet_config.params.in_channels += 4
213 | if config.model.params.mask_super:
214 | config.model.params.unet_config.params.in_channels += 4
215 | ckpt = 'last.ckpt'
216 | # Instantiate the model and sampler once up front for efficiency.
217 | models = dict()
218 | print('Instantiating LatentDiffusion...')
219 | models['turncam'] = load_model_from_config(config, ckpt, device=device)
220 | sampler = DDIMSampler(models['turncam'])
221 |
222 | # load all the files needed
223 | input_im = np.array(Image.open(args.input_image))
224 | input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
225 | input_im = input_im * 2 - 1
226 | input_im = transforms.functional.resize(input_im, [256, 256])
227 | mask_cond = torch.from_numpy(load_mask2(args.input_mask)).unsqueeze(0)
228 | depth_cond = get_depth2(args.input_depth).unsqueeze(0)
229 |
230 | # input pose
231 | default_elevation = args.default_elevation
232 | default_dis = args.default_dist
233 | default_azimuth = 0
234 | input_RT = np.zeros((4, 4))
235 | tmp_x = default_dis * math.cos(math.radians(default_elevation)) * math.cos(math.radians(default_azimuth))
236 | tmp_y = default_dis * math.cos(math.radians(default_elevation)) * math.sin(math.radians(default_azimuth))
237 | tmp_z = default_dis * math.sin(math.radians(default_elevation))
238 | campos = np.array([tmp_x, tmp_y, tmp_z])
239 | tar = np.array([0, 0, 0])
240 | input_R = look_at(campos, tar)
241 | input_RT[0:3, 0:3] = input_R
242 | input_RT[0:3, 3] = campos
243 | input_RT[3, 3] = 1
244 | ref_input_RT = input_RT
245 |
246 | # output pose
247 | elevation = args.elevation + default_elevation
248 | azimuth = args.azimuth
249 | tmp_x = default_dis * math.cos(math.radians(elevation)) * math.cos(math.radians(azimuth))
250 | tmp_y = default_dis * math.cos(math.radians(elevation)) * math.sin(math.radians(azimuth))
251 | tmp_z = default_dis * math.sin(math.radians(elevation))
252 | output_RT = np.zeros((4, 4))
253 | campos1 = np.array([tmp_x, tmp_y, tmp_z])
254 | tar1 = np.array([0, 0, 0])
255 | output_R = look_at(campos1, tar1)
256 | output_RT[0:3, 0:3] = output_R
257 | output_RT[0:3, 3] = campos1
258 | output_RT[3, 3] = 1
259 | target_view_RT = output_RT
260 |
261 | T = get_T2(target_view_RT, ref_input_RT).unsqueeze(0)
262 |
263 | # only 256x256 resolution is supported
264 | resolution_x = 256
265 | resolution_y = 256
266 | ddim_steps = args.ddim_steps
267 | n_samples = args.n_samples
268 | cfg_scale = args.cfg_scale
269 |
270 | x_samples_ddim = sample_model(input_im, models['turncam'], sampler, 'fp32', resolution_x, resolution_y, ddim_steps, \
271 | n_samples, cfg_scale, 1.0, T, depth_cond, mask_cond)
272 |
273 | out_img_tmp = x_samples_ddim.permute(0, 2, 3, 1)
274 | out_img_tmp = out_img_tmp.numpy()
275 | out_tmp = (out_img_tmp[0] * 255).astype(np.uint8)
276 | Image.fromarray(out_tmp).save(args.output_path)
--------------------------------------------------------------------------------
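
A minimal, self-contained sketch (NumPy only) of how eval_single.py above builds the relative camera transform T that conditions the diffusion model: both views sit on a sphere around the origin, are oriented with the look-at rotation, and the target pose is expressed relative to the input pose and flattened to 16 values, as in get_T2. The distance/elevation/azimuth values are placeholders, and safe_normalize is a stand-in for kiui.op.safe_normalize.

import math
import numpy as np

def safe_normalize(v, eps=1e-20):
    # stand-in for kiui.op.safe_normalize
    return v / max(np.linalg.norm(v), eps)

def look_at(campos, target):
    # OpenGL convention with z-up, as in eval_single.py: forward points target -> camera
    forward = safe_normalize(campos - target)
    right = safe_normalize(np.cross(np.array([0., 0., 1.]), forward))
    up = safe_normalize(np.cross(forward, right))
    return np.stack([right, up, forward], axis=1)

def pose(dist, elevation_deg, azimuth_deg):
    # camera placed on a sphere of radius `dist`, looking at the origin
    e, a = math.radians(elevation_deg), math.radians(azimuth_deg)
    campos = np.array([dist * math.cos(e) * math.cos(a),
                       dist * math.cos(e) * math.sin(a),
                       dist * math.sin(e)])
    RT = np.eye(4)
    RT[:3, :3] = look_at(campos, np.zeros(3))
    RT[:3, 3] = campos
    return RT

# hypothetical default/target viewpoints (the script reads these from argparse)
input_RT = pose(dist=1.5, elevation_deg=10.0, azimuth_deg=0.0)
target_RT = pose(dist=1.5, elevation_deg=10.0 + 20.0, azimuth_deg=30.0)
T = (np.linalg.inv(input_RT) @ target_RT).flatten()  # relative transform, 16-dim
print(T.shape)  # (16,)
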
/MOVIS/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | if isinstance(conditioning, dict):
86 | ctmp = conditioning[list(conditioning.keys())[0]]
87 | while isinstance(ctmp, list): ctmp = ctmp[0]
88 | cbs = ctmp.shape[0]
89 | if cbs != batch_size:
90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
91 | else:
92 | if conditioning.shape[0] != batch_size:
93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
94 |
95 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
96 | # sampling
97 | C, H, W = shape
98 | size = (batch_size, C, H, W)
99 | print(f'Data shape for PLMS sampling is {size}')
100 |
101 | samples, intermediates = self.plms_sampling(conditioning, size,
102 | callback=callback,
103 | img_callback=img_callback,
104 | quantize_denoised=quantize_x0,
105 | mask=mask, x0=x0,
106 | ddim_use_original_steps=False,
107 | noise_dropout=noise_dropout,
108 | temperature=temperature,
109 | score_corrector=score_corrector,
110 | corrector_kwargs=corrector_kwargs,
111 | x_T=x_T,
112 | log_every_t=log_every_t,
113 | unconditional_guidance_scale=unconditional_guidance_scale,
114 | unconditional_conditioning=unconditional_conditioning,
115 | dynamic_threshold=dynamic_threshold,
116 | )
117 | return samples, intermediates
118 |
119 | @torch.no_grad()
120 | def plms_sampling(self, cond, shape,
121 | x_T=None, ddim_use_original_steps=False,
122 | callback=None, timesteps=None, quantize_denoised=False,
123 | mask=None, x0=None, img_callback=None, log_every_t=100,
124 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
125 | unconditional_guidance_scale=1., unconditional_conditioning=None,
126 | dynamic_threshold=None):
127 | device = self.model.betas.device
128 | b = shape[0]
129 | if x_T is None:
130 | img = torch.randn(shape, device=device)
131 | else:
132 | img = x_T
133 |
134 | if timesteps is None:
135 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
136 | elif timesteps is not None and not ddim_use_original_steps:
137 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
138 | timesteps = self.ddim_timesteps[:subset_end]
139 |
140 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
141 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
142 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
143 | print(f"Running PLMS Sampling with {total_steps} timesteps")
144 |
145 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
146 | old_eps = []
147 |
148 | for i, step in enumerate(iterator):
149 | index = total_steps - i - 1
150 | ts = torch.full((b,), step, device=device, dtype=torch.long)
151 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
152 |
153 | if mask is not None:
154 | assert x0 is not None
155 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
156 | img = img_orig * mask + (1. - mask) * img
157 |
158 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
159 | quantize_denoised=quantize_denoised, temperature=temperature,
160 | noise_dropout=noise_dropout, score_corrector=score_corrector,
161 | corrector_kwargs=corrector_kwargs,
162 | unconditional_guidance_scale=unconditional_guidance_scale,
163 | unconditional_conditioning=unconditional_conditioning,
164 | old_eps=old_eps, t_next=ts_next,
165 | dynamic_threshold=dynamic_threshold)
166 | img, pred_x0, e_t = outs
167 | old_eps.append(e_t)
168 | if len(old_eps) >= 4:
169 | old_eps.pop(0)
170 | if callback: callback(i)
171 | if img_callback: img_callback(pred_x0, i)
172 |
173 | if index % log_every_t == 0 or index == total_steps - 1:
174 | intermediates['x_inter'].append(img)
175 | intermediates['pred_x0'].append(pred_x0)
176 |
177 | return img, intermediates
178 |
179 | @torch.no_grad()
180 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
181 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
182 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
183 | dynamic_threshold=None):
184 | b, *_, device = *x.shape, x.device
185 |
186 | def get_model_output(x, t):
187 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188 | e_t = self.model.apply_model(x, t, c)
189 | else:
190 | x_in = torch.cat([x] * 2)
191 | t_in = torch.cat([t] * 2)
192 | if isinstance(c, dict):
193 | assert isinstance(unconditional_conditioning, dict)
194 | c_in = dict()
195 | for k in c:
196 | if isinstance(c[k], list):
197 | c_in[k] = [torch.cat([
198 | unconditional_conditioning[k][i],
199 | c[k][i]]) for i in range(len(c[k]))]
200 | else:
201 | c_in[k] = torch.cat([
202 | unconditional_conditioning[k],
203 | c[k]])
204 | else:
205 | c_in = torch.cat([unconditional_conditioning, c])
206 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
207 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
208 |
209 | if score_corrector is not None:
210 | assert self.model.parameterization == "eps"
211 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
212 |
213 | return e_t
214 |
215 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
216 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
217 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
218 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
219 |
220 | def get_x_prev_and_pred_x0(e_t, index):
221 | # select parameters corresponding to the currently considered timestep
222 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
223 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
224 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
225 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
226 |
227 | # current prediction for x_0
228 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
229 | if quantize_denoised:
230 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
231 | if dynamic_threshold is not None:
232 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
233 | # direction pointing to x_t
234 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
235 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
236 | if noise_dropout > 0.:
237 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
238 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
239 | return x_prev, pred_x0
240 |
241 | e_t = get_model_output(x, t)
242 | if len(old_eps) == 0:
243 | # Pseudo Improved Euler (2nd order)
244 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
245 | e_t_next = get_model_output(x_prev, t_next)
246 | e_t_prime = (e_t + e_t_next) / 2
247 | elif len(old_eps) == 1:
248 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
249 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
250 | elif len(old_eps) == 2:
251 | # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
252 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
253 | elif len(old_eps) >= 3:
254 | # 4th order Pseudo Linear Multistep (Adams-Bashforth)
255 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
256 |
257 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
258 |
259 | return x_prev, pred_x0, e_t
260 |
--------------------------------------------------------------------------------
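
For reference, a hedged usage sketch of the PLMSSampler above, mirroring the DDIM call in eval_single.py. It assumes `model` is the loaded LatentDiffusion instance and that `cond`/`uc` are conditioning dictionaries built the same way as in sample_model; those names are assumptions, not part of this file. Note that make_schedule enforces eta == 0 for PLMS.

from ldm.models.diffusion.plms import PLMSSampler

sampler = PLMSSampler(model)                     # model: loaded LatentDiffusion (assumed)
samples, intermediates = sampler.sample(
    S=50,                                        # number of PLMS steps
    conditioning=cond,                           # dict with 'c_crossattn' / 'c_concat' etc. (assumed)
    batch_size=cond['c_concat'][0].shape[0],
    shape=[4, 256 // 8, 256 // 8],               # latent (C, H, W)
    verbose=False,
    unconditional_guidance_scale=3.0,
    unconditional_conditioning=uc,               # zeroed conditioning dict (assumed)
    eta=0.0,                                     # PLMS requires eta == 0
    x_T=None,
)
x_samples = model.decode_first_stage(samples)    # back to image space, roughly in [-1, 1]
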
/MOVIS/ldm/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os, yaml, pickle, shutil, tarfile, glob
2 | import cv2
3 | import albumentations
4 | import PIL
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | from omegaconf import OmegaConf
8 | from functools import partial
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, Subset
12 |
13 | import taming.data.utils as tdu
14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15 | from taming.data.imagenet import ImagePaths
16 |
17 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18 |
19 |
20 | def synset2idx(path_to_yaml="data/index_synset.yaml"):
21 | with open(path_to_yaml) as f:
22 | di2s = yaml.safe_load(f)
23 | return dict((v,k) for k,v in di2s.items())
24 |
25 |
26 | class ImageNetBase(Dataset):
27 | def __init__(self, config=None):
28 | self.config = config or OmegaConf.create()
29 | if not type(self.config)==dict:
30 | self.config = OmegaConf.to_container(self.config)
31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33 | self._prepare()
34 | self._prepare_synset_to_human()
35 | self._prepare_idx_to_synset()
36 | self._prepare_human_to_integer_label()
37 | self._load()
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, i):
43 | return self.data[i]
44 |
45 | def _prepare(self):
46 | raise NotImplementedError()
47 |
48 | def _filter_relpaths(self, relpaths):
49 | ignore = set([
50 | "n06596364_9591.JPEG",
51 | ])
52 | relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
53 | if "sub_indices" in self.config:
54 | indices = str_to_indices(self.config["sub_indices"])
55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57 | files = []
58 | for rpath in relpaths:
59 | syn = rpath.split("/")[0]
60 | if syn in synsets:
61 | files.append(rpath)
62 | return files
63 | else:
64 | return relpaths
65 |
66 | def _prepare_synset_to_human(self):
67 | SIZE = 2655750
68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69 | self.human_dict = os.path.join(self.root, "synset_human.txt")
70 | if (not os.path.exists(self.human_dict) or
71 | not os.path.getsize(self.human_dict)==SIZE):
72 | download(URL, self.human_dict)
73 |
74 | def _prepare_idx_to_synset(self):
75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77 | if (not os.path.exists(self.idx2syn)):
78 | download(URL, self.idx2syn)
79 |
80 | def _prepare_human_to_integer_label(self):
81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83 | if (not os.path.exists(self.human2integer)):
84 | download(URL, self.human2integer)
85 | with open(self.human2integer, "r") as f:
86 | lines = f.read().splitlines()
87 | assert len(lines) == 1000
88 | self.human2integer_dict = dict()
89 | for line in lines:
90 | value, key = line.split(":")
91 | self.human2integer_dict[key] = int(value)
92 |
93 | def _load(self):
94 | with open(self.txt_filelist, "r") as f:
95 | self.relpaths = f.read().splitlines()
96 | l1 = len(self.relpaths)
97 | self.relpaths = self._filter_relpaths(self.relpaths)
98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99 |
100 | self.synsets = [p.split("/")[0] for p in self.relpaths]
101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102 |
103 | unique_synsets = np.unique(self.synsets)
104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105 | if not self.keep_orig_class_label:
106 | self.class_labels = [class_dict[s] for s in self.synsets]
107 | else:
108 | self.class_labels = [self.synset2idx[s] for s in self.synsets]
109 |
110 | with open(self.human_dict, "r") as f:
111 | human_dict = f.read().splitlines()
112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113 |
114 | self.human_labels = [human_dict[s] for s in self.synsets]
115 |
116 | labels = {
117 | "relpath": np.array(self.relpaths),
118 | "synsets": np.array(self.synsets),
119 | "class_label": np.array(self.class_labels),
120 | "human_label": np.array(self.human_labels),
121 | }
122 |
123 | if self.process_images:
124 | self.size = retrieve(self.config, "size", default=256)
125 | self.data = ImagePaths(self.abspaths,
126 | labels=labels,
127 | size=self.size,
128 | random_crop=self.random_crop,
129 | )
130 | else:
131 | self.data = self.abspaths
132 |
133 |
134 | class ImageNetTrain(ImageNetBase):
135 | NAME = "ILSVRC2012_train"
136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138 | FILES = [
139 | "ILSVRC2012_img_train.tar",
140 | ]
141 | SIZES = [
142 | 147897477120,
143 | ]
144 |
145 | def __init__(self, process_images=True, data_root=None, **kwargs):
146 | self.process_images = process_images
147 | self.data_root = data_root
148 | super().__init__(**kwargs)
149 |
150 | def _prepare(self):
151 | if self.data_root:
152 | self.root = os.path.join(self.data_root, self.NAME)
153 | else:
154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156 |
157 | self.datadir = os.path.join(self.root, "data")
158 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
159 | self.expected_length = 1281167
160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161 | default=True)
162 | if not tdu.is_prepared(self.root):
163 | # prep
164 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
165 |
166 | datadir = self.datadir
167 | if not os.path.exists(datadir):
168 | path = os.path.join(self.root, self.FILES[0])
169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170 | import academictorrents as at
171 | atpath = at.get(self.AT_HASH, datastore=self.root)
172 | assert atpath == path
173 |
174 | print("Extracting {} to {}".format(path, datadir))
175 | os.makedirs(datadir, exist_ok=True)
176 | with tarfile.open(path, "r:") as tar:
177 | tar.extractall(path=datadir)
178 |
179 | print("Extracting sub-tars.")
180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181 | for subpath in tqdm(subpaths):
182 | subdir = subpath[:-len(".tar")]
183 | os.makedirs(subdir, exist_ok=True)
184 | with tarfile.open(subpath, "r:") as tar:
185 | tar.extractall(path=subdir)
186 |
187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189 | filelist = sorted(filelist)
190 | filelist = "\n".join(filelist)+"\n"
191 | with open(self.txt_filelist, "w") as f:
192 | f.write(filelist)
193 |
194 | tdu.mark_prepared(self.root)
195 |
196 |
197 | class ImageNetValidation(ImageNetBase):
198 | NAME = "ILSVRC2012_validation"
199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202 | FILES = [
203 | "ILSVRC2012_img_val.tar",
204 | "validation_synset.txt",
205 | ]
206 | SIZES = [
207 | 6744924160,
208 | 1950000,
209 | ]
210 |
211 | def __init__(self, process_images=True, data_root=None, **kwargs):
212 | self.data_root = data_root
213 | self.process_images = process_images
214 | super().__init__(**kwargs)
215 |
216 | def _prepare(self):
217 | if self.data_root:
218 | self.root = os.path.join(self.data_root, self.NAME)
219 | else:
220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222 | self.datadir = os.path.join(self.root, "data")
223 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
224 | self.expected_length = 50000
225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226 | default=False)
227 | if not tdu.is_prepared(self.root):
228 | # prep
229 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
230 |
231 | datadir = self.datadir
232 | if not os.path.exists(datadir):
233 | path = os.path.join(self.root, self.FILES[0])
234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235 | import academictorrents as at
236 | atpath = at.get(self.AT_HASH, datastore=self.root)
237 | assert atpath == path
238 |
239 | print("Extracting {} to {}".format(path, datadir))
240 | os.makedirs(datadir, exist_ok=True)
241 | with tarfile.open(path, "r:") as tar:
242 | tar.extractall(path=datadir)
243 |
244 | vspath = os.path.join(self.root, self.FILES[1])
245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246 | download(self.VS_URL, vspath)
247 |
248 | with open(vspath, "r") as f:
249 | synset_dict = f.read().splitlines()
250 | synset_dict = dict(line.split() for line in synset_dict)
251 |
252 | print("Reorganizing into synset folders")
253 | synsets = np.unique(list(synset_dict.values()))
254 | for s in synsets:
255 | os.makedirs(os.path.join(datadir, s), exist_ok=True)
256 | for k, v in synset_dict.items():
257 | src = os.path.join(datadir, k)
258 | dst = os.path.join(datadir, v)
259 | shutil.move(src, dst)
260 |
261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263 | filelist = sorted(filelist)
264 | filelist = "\n".join(filelist)+"\n"
265 | with open(self.txt_filelist, "w") as f:
266 | f.write(filelist)
267 |
268 | tdu.mark_prepared(self.root)
269 |
270 |
271 |
272 | class ImageNetSR(Dataset):
273 | def __init__(self, size=None,
274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275 | random_crop=True):
276 | """
277 | Imagenet Superresolution Dataloader
278 | Performs the following ops in order:
279 | 1. crops a crop of size s from image either as random or center crop
280 | 2. resizes crop to size with cv2.area_interpolation
281 | 3. degrades resized crop with degradation_fn
282 |
283 | :param size: resizing to size after cropping
284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285 | :param downscale_f: Low Resolution Downsample factor
286 | :param min_crop_f: determines crop size s,
287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288 | :param max_crop_f: upper bound for the crop factor c (see min_crop_f)
289 | :param data_root: not a constructor argument here; see ImageNetTrain / ImageNetValidation
290 | :param random_crop: use random crops if True, otherwise center crops
291 | """
292 | self.base = self.get_base()
293 | assert size
294 | assert (size / downscale_f).is_integer()
295 | self.size = size
296 | self.LR_size = int(size / downscale_f)
297 | self.min_crop_f = min_crop_f
298 | self.max_crop_f = max_crop_f
299 | assert(max_crop_f <= 1.)
300 | self.center_crop = not random_crop
301 |
302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303 |
304 | self.pil_interpolation = False # gets reset later if interp_op is from pillow
305 |
306 | if degradation == "bsrgan":
307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308 |
309 | elif degradation == "bsrgan_light":
310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311 |
312 | else:
313 | interpolation_fn = {
314 | "cv_nearest": cv2.INTER_NEAREST,
315 | "cv_bilinear": cv2.INTER_LINEAR,
316 | "cv_bicubic": cv2.INTER_CUBIC,
317 | "cv_area": cv2.INTER_AREA,
318 | "cv_lanczos": cv2.INTER_LANCZOS4,
319 | "pil_nearest": PIL.Image.NEAREST,
320 | "pil_bilinear": PIL.Image.BILINEAR,
321 | "pil_bicubic": PIL.Image.BICUBIC,
322 | "pil_box": PIL.Image.BOX,
323 | "pil_hamming": PIL.Image.HAMMING,
324 | "pil_lanczos": PIL.Image.LANCZOS,
325 | }[degradation]
326 |
327 | self.pil_interpolation = degradation.startswith("pil_")
328 |
329 | if self.pil_interpolation:
330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331 |
332 | else:
333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334 | interpolation=interpolation_fn)
335 |
336 | def __len__(self):
337 | return len(self.base)
338 |
339 | def __getitem__(self, i):
340 | example = self.base[i]
341 | image = Image.open(example["file_path_"])
342 |
343 | if not image.mode == "RGB":
344 | image = image.convert("RGB")
345 |
346 | image = np.array(image).astype(np.uint8)
347 |
348 | min_side_len = min(image.shape[:2])
349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350 | crop_side_len = int(crop_side_len)
351 |
352 | if self.center_crop:
353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354 |
355 | else:
356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357 |
358 | image = self.cropper(image=image)["image"]
359 | image = self.image_rescaler(image=image)["image"]
360 |
361 | if self.pil_interpolation:
362 | image_pil = PIL.Image.fromarray(image)
363 | LR_image = self.degradation_process(image_pil)
364 | LR_image = np.array(LR_image).astype(np.uint8)
365 |
366 | else:
367 | LR_image = self.degradation_process(image=image)["image"]
368 |
369 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371 | example["caption"] = example["human_label"] # dummy caption
372 | return example
373 |
374 |
375 | class ImageNetSRTrain(ImageNetSR):
376 | def __init__(self, **kwargs):
377 | super().__init__(**kwargs)
378 |
379 | def get_base(self):
380 | with open("data/imagenet_train_hr_indices.p", "rb") as f:
381 | indices = pickle.load(f)
382 | dset = ImageNetTrain(process_images=False,)
383 | return Subset(dset, indices)
384 |
385 |
386 | class ImageNetSRValidation(ImageNetSR):
387 | def __init__(self, **kwargs):
388 | super().__init__(**kwargs)
389 |
390 | def get_base(self):
391 | with open("data/imagenet_val_hr_indices.p", "rb") as f:
392 | indices = pickle.load(f)
393 | dset = ImageNetValidation(process_images=False,)
394 | return Subset(dset, indices)
395 |
--------------------------------------------------------------------------------
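
A hedged usage sketch for the ImageNet super-resolution datasets above. It assumes ILSVRC2012 has already been prepared in the cache directory and that data/imagenet_val_hr_indices.p exists, as required by ImageNetSRValidation.get_base; the batch size and worker count are arbitrary.

from torch.utils.data import DataLoader
from ldm.data.imagenet import ImageNetSRValidation

dset = ImageNetSRValidation(size=256, degradation="pil_bicubic",
                            downscale_f=4, min_crop_f=0.5, max_crop_f=1.0,
                            random_crop=False)
loader = DataLoader(dset, batch_size=4, shuffle=False, num_workers=2)

batch = next(iter(loader))
# HR crop and its degraded LR counterpart, both scaled to [-1, 1]
print(batch["image"].shape, batch["LR_image"].shape)
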