├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── .DS_Store
│   ├── teaser.pdf
│   └── teaser.png
├── configs
│   ├── fcb.yaml
│   └── ultrafusion.yaml
├── dataset
│   └── test_dataset.py
├── examples
│   ├── 0052
│   │   ├── oe.jpg
│   │   └── ue.jpg
│   └── 0072
│       ├── oe.jpg
│       └── ue.jpg
├── inference.py
├── inference.sh
├── model
│   ├── V4_CA
│   │   ├── attention.py
│   │   ├── cldm.py
│   │   ├── controlnet.py
│   │   ├── cross_attention.py
│   │   ├── gaussian_diffusion.py
│   │   ├── unet.py
│   │   └── vae.py
│   ├── __init__.py
│   ├── clip.py
│   ├── config.py
│   ├── distributions.py
│   ├── fe
│   │   └── fe_V1_4.py
│   ├── open_clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── model.py
│   │   ├── tokenizer.py
│   │   └── transformer.py
│   ├── raft
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── extractor.py
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   └── util.py
├── pipeline
│   └── V4_CA
│       └── pipeline.py
├── requirements.txt
├── utils
│   ├── V4_CA
│   │   └── sampler.py
│   ├── common.py
│   ├── cond_fn.py
│   ├── flow.py
│   └── imf.py
└── val_nriqa.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.ckpt
3 | *.pth
4 | *.out
5 | /data
6 | /exps
7 | # *.sh
8 | !install_env.sh
9 | /weights
10 | /temp
11 | /results*
12 | .ipynb_checkpoints/
13 | /TODO.txt
14 | /deprecated
15 | /temp_scripts
16 | /.vscode
17 | /runs
18 | /experiment/
19 | /ckpts
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | UltraFusion: Ultra High Dynamic Imaging using Exposure Fusion
5 |
6 | (CVPR 2025 Highlight)
7 |
8 |
9 | Zixuan Chen1,3
10 | ·
11 | Yujin Wang1
12 | ·
13 | Xin Cai2
14 | ·
15 | Zhiyuan You2
16 |
17 | Zheming Lu3
18 | ·
19 | Fan Zhang1
20 | ·
21 | Shi Guo1
22 | ·
23 | Tianfan Xue2,1
24 |
25 |
26 | 1Shanghai AI Laboratory, 2The Chinese University of Hong Kong,
27 | 3Zhejiang University
28 |
29 |
34 |
35 |
36 |
37 | 
38 |
39 | ## :mega: News
40 | - **2025.4.23**: Inference codes, benchmark and results are released.
41 | - **2025.4.5**: Our UltraFusion is selected to be presented as a :sparkles:***highlight***:sparkles: at CVPR 2025.
42 | - **2025.2.27**: Accepted by ***CVPR 2025*** :tada::tada::tada:.
43 | - **2025.1.21**: Feel free to try online demos at Hugging Face and OpenXLab :blush:.
44 |
45 |
46 | ## :memo: ToDo List
47 | - [ ] Release training codes.
48 | - [x] Release inference codes and pre-trained model.
49 | - [x] Release UltraFusion benchmark and visual results.
50 | - [x] Release more visual comparisons on our [project page](https://openimaginglab.github.io/UltraFusion/).
51 |
52 | ## :bridge_at_night: Benchmark
53 | We captured 100 challenging real-world HDR scenes for performance evaluation.
54 | Our benchmark and results (including competing methods) are available at [Google Drive](https://drive.google.com/drive/folders/18icr4A_0qGvwqehPhxH29hqJYO8HS6bi?usp=sharing) and [Baidu Disk]().
55 | We also provide results of our method and the competing methods on [RealHDRV](https://github.com/yungsyu99/Real-HDRV) and [MEFB](https://github.com/xingchenzhang/MEFB).
56 |
57 | > *Note: The HDR reconstruction methods perform poorly in some scenes because we follow their setup and retrain a 2-exposure version, while the training set they used only provides ground truth for the middle exposure, limiting the dynamic range. We believe that training data with a higher dynamic range would improve their performance.*
58 |
59 | ## Quick Start
60 | **Installation**
61 | ```shell
62 | # clone this repo
63 | git clone https://github.com/OpenImagingLab/UltraFusion.git
64 | cd UltraFusion
65 |
66 | # create environment
67 | conda create -n UltraFusion python=3.10
68 | conda activate UltraFusion
69 | pip install -r requirements.txt
70 | ```
71 | **Prepare Data and Pre-trained Model**
72 |
73 | Download [raft-sintel.pth](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing), [v2-1_512-ema-pruned.ckpt](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.ckpt), [fcb.pt](https://huggingface.co/zxchen00/UltraFusion/blob/main/fcb.pt) and [ultrafusion.pt](https://huggingface.co/zxchen00/UltraFusion/blob/main/ultrafusion.pt), and put them in the ```ckpts``` folder. Download the three benchmarks ([Google Drive](https://drive.google.com/drive/folders/18icr4A_0qGvwqehPhxH29hqJYO8HS6bi?usp=sharing) or [Baidu Disk]()) and put them in the ```data``` folder.
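For reference, `dataset/test_dataset.py` resolves the three benchmarks under `./data` and expects every scene folder to contain an under-exposed `*ue.*` and an over-exposed `*oe.*` image. A rough sketch of the expected layout (folder names are taken from `TestDataset.img_dir_dict`; the scene name `0001` is only a placeholder):
```
ckpts/
├── raft-sintel.pth
├── v2-1_512-ema-pruned.ckpt
├── fcb.pt
└── ultrafusion.pt
data/
├── UltraFusionBenchmark/
│   └── 0001/
│       ├── ue.jpg
│       └── oe.jpg
├── MEFB/
└── Real-HDRV-Deghosting-sRGB-Testing/
```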
74 |
75 | **Inference**
76 |
77 | Run the following commands for inference.
78 | ```shell
79 | # UltraFusion Benchmark
80 | python inference.py --dataset UltraFusion --output results --tiled --tile_size 512 --tile_stride 256 --prealign --save_all
81 | # RealHDRV
82 | python inference.py --dataset RealHDRV --output results --tiled --tile_size 512 --tile_stride 256 --prealign --save_all
83 | # MEFB (cancel pre-alignment for static scenes)
84 | python inference.py --dataset MEFB --output results --tiled --tile_size 512 --tile_stride 256 --save_all
85 | ```
86 | You can also use ```val_nriqa.py``` for evaluation.
87 |
88 |
89 |
90 |
91 | ## Acknowledgements
92 | This project is developed on the codebase of [DiffBIR](https://github.com/XPixelGroup/DiffBIR). We appreciate their great work!
93 |
94 | ## :love_you_gesture: Citation
95 | If you find our paper and repo helpful for your research, please consider citing:
96 | ```BibTeX
97 | @article{chen2025ultrafusion,
98 | title={UltraFusion: Ultra High Dynamic Imaging using Exposure Fusion},
99 | author={Chen, Zixuan and Wang, Yujin and Cai, Xin and You, Zhiyuan and Lu, Zheming and Zhang, Fan and Guo, Shi and Xue, Tianfan},
100 | journal={arXiv preprint arXiv:2501.11515},
101 | year={2025}
102 | }
103 | ```
104 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/teaser.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/teaser.pdf
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/teaser.png
--------------------------------------------------------------------------------
/configs/fcb.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | fidelity_encoder:
3 | target: model.fe.fe_V1_4.FidelityEncoderV2
4 | params:
5 | fe_config:
6 | double_z: true
7 | z_channels: 4
8 | resolution: 256
9 | in_channels: 3
10 | out_ch: 3
11 | ch: 128
12 | ch_mult:
13 | - 1
14 | - 2
15 | - 4
16 | - 4
17 | num_res_blocks: 2
18 | attn_resolutions: []
19 | dropout: 0.0
--------------------------------------------------------------------------------
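The `target`/`params` layout above is consumed by `utils.common.instantiate_from_config`. A minimal loading sketch, mirroring how `inference.py` uses this file (it assumes the repository root is on `PYTHONPATH` and that `ckpts/fcb.pt` from the README has been downloaded):

```python
import torch
from omegaconf import OmegaConf

from utils.common import instantiate_from_config

# "target" names the class to build, "params" are passed to its constructor.
cfg = OmegaConf.load("configs/fcb.yaml")
fidelity_encoder = instantiate_from_config(cfg.model.fidelity_encoder)
fidelity_encoder.load_state_dict(torch.load("ckpts/fcb.pt", map_location="cpu"), strict=True)
fidelity_encoder.eval()
```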
/configs/ultrafusion.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | fidelity_encoder:
3 | target: model.V4_CA.vae.FidelityEncoder
4 | params:
5 | double_z: true
6 | z_channels: 4
7 | resolution: 256
8 | in_channels: 1
9 | out_ch: 3
10 | ch: 128
11 | ch_mult:
12 | - 1
13 | - 2
14 | - 4
15 | - 4
16 | num_res_blocks: 2
17 | attn_resolutions: []
18 | dropout: 0.0
19 | cldm:
20 | target: model.V4_CA.cldm.ControlLDM
21 | params:
22 | latent_scale_factor: 0.18215
23 | unet_cfg:
24 | use_checkpoint: True
25 | image_size: 32 # unused
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: 2
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_head_channels: 64 # need to fix for flash-attn
33 | use_spatial_transformer: True
34 | use_linear_in_transformer: True
35 | transformer_depth: 1
36 | context_dim: 1024
37 | legacy: False
38 | vae_cfg:
39 | embed_dim: 4
40 | ddconfig:
41 | double_z: true
42 | z_channels: 4
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | - 4
52 | num_res_blocks: 2
53 | attn_resolutions: []
54 | dropout: 0.0
55 | clip_cfg:
56 | embed_dim: 1024
57 | vision_cfg:
58 | image_size: 224
59 | layers: 32
60 | width: 1280
61 | head_width: 80
62 | patch_size: 14
63 | text_cfg:
64 | context_length: 77
65 | vocab_size: 49408
66 | width: 1024
67 | heads: 16
68 | layers: 24
69 | layer: "penultimate"
70 | controlnet_cfg:
71 | use_checkpoint: True
72 | image_size: 32 # unused
73 | in_channels: 4
74 | hint_channels: 4
75 | model_channels: 320
76 | attention_resolutions: [ 4, 2, 1 ]
77 | num_res_blocks: 2
78 | channel_mult: [ 1, 2, 4, 4 ]
79 | num_head_channels: 64 # need to fix for flash-attn
80 | use_spatial_transformer: True
81 | use_linear_in_transformer: True
82 | transformer_depth: 1
83 | context_dim: 1024
84 | legacy: False
85 |
86 | diffusion:
87 | target: model.V4_CA.gaussian_diffusion.Diffusion
88 | params:
89 | linear_start: 0.00085
90 | linear_end: 0.0120
91 | timesteps: 1000
92 |
93 | dataset:
94 | train1:
95 | target: dataset.mef_dataset.MEFDatasetV5
96 | params:
97 | # training file list path
98 | img_dir: /dev/shm/SICE/Dataset_Part1_2expo_train
99 | motion_img_dir: /ailab/group/pjlab-sail/chenzixuan/vimeo_septuplet/sequences
100 | train2:
101 | target: dataset.mef_dataset.MEFDatasetV5
102 | params:
103 | # training file list path
104 | img_dir: /dev/shm/SICE/Dataset_Part2_2expo
105 | motion_img_dir: /ailab/group/pjlab-sail/chenzixuan/vimeo_septuplet/sequences
106 |
107 | train:
108 | # pretrained sd v2.1 path
109 | sd_path: /ailab/user/chenzixuan/Research/pretrained_models/SDv2.1/v2-1_512-ema-pruned.ckpt
110 | # experiment directory path
111 | exp_dir: ./exps/sd_motion_V4_CA_8gpus
112 | learning_rate: 1e-4
113 | # ImageNet 1k (1.3M images)
114 | # batch size = 192, lr = 1e-4, total training steps = 25k
115 | # Our filtered laion2b-en (15M images)
116 | # batch size = 256, lr = 1e-4 (first 30k), 1e-5 (next 50k), total training steps = 80k
117 | batch_size: 4
118 | num_workers: 8
119 | train_steps: 1000000
120 | log_every: 50
121 | ckpt_every: 4000
122 | image_every: 500
123 | resume: ~
--------------------------------------------------------------------------------
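A companion sketch for this config, again following the pattern in `inference.py`: the `model.cldm` and `model.diffusion` entries are instantiated directly, while the `dataset` and `train` sections (with their cluster-specific paths) are only used by the training scripts, which are not released yet.

```python
from omegaconf import OmegaConf

from utils.common import instantiate_from_config

cfg = OmegaConf.load("configs/ultrafusion.yaml")
# ControlLDM bundles the SD v2.1 UNet, VAE, OpenCLIP text encoder and the ControlNet branch.
cldm = instantiate_from_config(cfg.model.cldm)
# Linear beta schedule from linear_start to linear_end over 1000 timesteps.
diffusion = instantiate_from_config(cfg.model.diffusion)
```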
/dataset/test_dataset.py:
--------------------------------------------------------------------------------
1 | import os, random, glob, time
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import torch.utils.data as data
6 | from PIL import Image
7 | from torchvision.transforms import ToTensor
8 |
9 |
10 | def get_color_and_struct(isrgb, input_img: torch.Tensor, ksize, sigmaX, c): #input an RGB image
11 |
12 | input_img = input_img.squeeze().cpu().numpy().transpose(1, 2, 0)
13 |
14 | if isrgb==True:
15 | yuv_img = cv2.cvtColor(input_img, cv2.COLOR_RGB2YUV).astype(np.float32)
16 | y = np.expand_dims(yuv_img[:,:,0], axis=-1).astype(np.float64)
17 | u = np.expand_dims(yuv_img[:,:,1], axis=-1).astype(np.float32)
18 | v = np.expand_dims(yuv_img[:,:,2], axis=-1).astype(np.float32)
19 | else:
20 | y = input_img.astype(np.float64)
21 | #mu = gaussian_filter(y, ksize, ksize/6)
22 | mu = cv2.GaussianBlur(y, (ksize,ksize), sigmaX).astype(np.float64)
23 | mu_sq = mu * mu
24 | sigma = np.sqrt(np.absolute(cv2.GaussianBlur(y*y, (ksize,ksize), sigmaX) - mu_sq)).astype(np.float64)
25 | mu = np.expand_dims(mu, axis=-1)
26 | sigma = np.expand_dims(sigma, axis=-1)
27 | dividend = y.astype(np.float64) - mu
28 | divisor = sigma + c
29 | struct = dividend / divisor
30 | struct = struct.astype(np.float32)
31 | struct_norm = (struct - struct.min()) / (struct.max() - struct.min() + 1e-6)
32 | struct_norm = torch.from_numpy(struct_norm).permute(2, 0, 1)
33 | u = torch.from_numpy(u).permute(2, 0, 1)
34 | v = torch.from_numpy(v).permute(2, 0, 1)
35 | img_uv = torch.cat([u, v], dim=0)
36 | return struct_norm, img_uv
37 |
38 |
39 | class TestDataset(data.Dataset):
40 | def __init__(self, dataset):
41 | super(TestDataset, self).__init__()
42 | self.dataset = dataset
43 | self.img_dir_dict = {
44 | 'UltraFusion': './data/UltraFusionBenchmark',
45 | 'MEFB': './data/MEFB',
46 | 'RealHDRV': './data/Real-HDRV-Deghosting-sRGB-Testing',
47 | }
48 | self.ldr_list1 = []
49 | self.ldr_list2 = []
50 | self.file_name_list = []
51 | self.to_tensor = ToTensor()
52 |
53 | self.scene_list = os.listdir(self.img_dir_dict[dataset])
54 | self.scene_list.sort()
55 | for scene in self.scene_list:
56 | if len(os.listdir(os.path.join(self.img_dir_dict[dataset], scene))) > 0:
57 | self.ldr_list1.append(glob.glob(os.path.join(self.img_dir_dict[dataset], scene, '*ue.*'))[0])
58 | self.ldr_list2.append(glob.glob(os.path.join(self.img_dir_dict[dataset], scene, '*oe.*'))[0])
59 | self.file_name_list.append('{}_{}'.format(dataset, scene))
60 |
61 |
62 | def __getitem__(self, index):
63 | ldr1_path = self.ldr_list1[index]
64 | ldr2_path = self.ldr_list2[index]
65 | file_name = self.file_name_list[index]
66 |
67 | ldr1 = Image.open(ldr1_path).convert('RGB')
68 | ldr2 = Image.open(ldr2_path).convert('RGB')
69 |
70 | W, H = ldr1.size
71 |
72 | if W * H >= 6000 * 4000:
73 | ldr1 = ldr1.resize([W // 4, H // 4])
74 | ldr2 = ldr2.resize([W // 4, H // 4])
75 | elif W * H >= 2000 *1500:
76 | ldr1 = ldr1.resize([W * 2 // 5, H * 2 // 5])
77 | ldr2 = ldr2.resize([W * 2 // 5, H * 2 // 5])
78 |
79 | ldr1 = self.to_tensor(ldr1)
80 | ldr2 = self.to_tensor(ldr2)
81 |
82 | return {
83 | 'ue': ldr1,
84 | 'oe': ldr2,
85 | 'file_name': file_name
86 | }
87 |
88 | def __len__(self):
89 | return len(self.ldr_list1)
--------------------------------------------------------------------------------
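A minimal usage sketch for `TestDataset`, assuming the benchmark folders described in the README are in place (`batch_size=1` and `num_workers=0` mirror `inference.py`):

```python
from torch.utils.data import DataLoader

from dataset.test_dataset import TestDataset, get_color_and_struct

dataset = TestDataset("MEFB")
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

batch = next(iter(loader))
ue, oe = batch["ue"], batch["oe"]  # under-/over-exposed LDR tensors in [0, 1], shape [1, 3, H, W]
print(batch["file_name"][0], ue.shape, oe.shape)

# MSCN-style structure map and UV color channels of the under-exposed frame,
# the same hints that inference.py feeds to the ControlNet.
struct, uv = get_color_and_struct(isrgb=True, input_img=ue, ksize=7, sigmaX=0, c=1e-7)
print(struct.shape, uv.shape)  # [1, H, W] and [2, H, W]
```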
/examples/0052/oe.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0052/oe.jpg
--------------------------------------------------------------------------------
/examples/0052/ue.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0052/ue.jpg
--------------------------------------------------------------------------------
/examples/0072/oe.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0072/oe.jpg
--------------------------------------------------------------------------------
/examples/0072/ue.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0072/ue.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os, tqdm, math
2 | import torch
3 | import numpy as np
4 | from argparse import ArgumentParser
5 | from collections import OrderedDict
6 | from torch.nn import functional as F
7 | from torch.utils.data import DataLoader
8 | from torchvision.utils import save_image
9 | from torchvision.transforms import ToTensor
10 | from omegaconf import OmegaConf
11 | from accelerate.utils import set_seed
12 |
13 | from dataset.test_dataset import TestDataset, get_color_and_struct
14 | from model.raft.raft import RAFT
15 | from utils.common import instantiate_from_config
16 | from utils.flow import backward_warp, forward_backward_consistency_check, IMF
17 |
18 |
19 | def pad_imgv3(x, crop_size, crop_step):
20 | _, _, h, w = x.size()
21 | n_h = max(math.ceil((h - crop_size) / crop_step), 0)
22 | n_w = max(math.ceil((w - crop_size) / crop_step), 0)
23 | h_target = crop_size + n_h * crop_step
24 | w_target = crop_size + n_w * crop_step
25 | mod_pad_h = h_target - h
26 | mod_pad_w = w_target - w
27 | x_np = x.cpu().numpy()
28 | x_np = np.pad(x_np, pad_width=((0,0),(0,0),(0,mod_pad_h),(0,mod_pad_w)), mode='reflect')
29 | res = torch.from_numpy(x_np).cuda()
30 | return res
31 |
32 |
33 | def pad_img(x, patch_size):
34 | _, _, h, w = x.size()
35 | mod_pad_h = (patch_size - h % patch_size) % patch_size
36 | mod_pad_w = (patch_size - w % patch_size) % patch_size
37 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
38 | return x
39 |
40 |
41 | def crop_parallel(img, crop_sz, step):
42 | b, c, h, w = img.shape
43 | h_space = np.arange(0, h - crop_sz + 1, step)
44 | w_space = np.arange(0, w - crop_sz + 1, step)
45 | index = 0
46 | num_h = 0
47 | lr_list=torch.Tensor().to(img.device)
48 | for x in h_space:
49 | num_h += 1
50 | num_w = 0
51 | for y in w_space:
52 | num_w += 1
53 | index += 1
54 | crop_img = img[:, :, x:x + crop_sz, y:y + crop_sz]
55 | lr_list = torch.cat([lr_list, crop_img])
56 | new_h=x + crop_sz # new height after crop
57 | new_w=y + crop_sz # new width after crop
58 | return lr_list, num_h, num_w, new_h, new_w
59 |
60 |
61 | def combine_parallel_wo_artifact(sr_list, num_h, num_w, new_h, new_w, patch_size, step):
62 | p_size = patch_size
63 | pred_lr_list = sr_list
64 |
65 | pred_full_w_list = [] # rectangle
66 | for i in range(num_h):
67 | pred_full_w = torch.zeros([1, 3, p_size, new_w]).cuda()
68 | pred_full_w[:, :, :, 0 : step] = pred_lr_list[i * num_w][:, :, 0 : step]
69 | # pred_full_w[:, :, :, 0 : patch_size] = pred_lr_list[i * num_w][:, :, 0 : patch_size]
70 | pred_full_w[:, :, :, new_w - step :] = pred_lr_list[i * num_w + num_w - 1][:, :, -step:]
71 | for j in range(1, num_w):
72 | repeat_l = j * step
73 | repeat_r = repeat_l + (p_size - step)
74 | ind = i * num_w + j - 1
75 |
76 | pred_full_w[:, :, :, repeat_r : repeat_r + 2 * step - patch_size] = pred_lr_list[ind + 1][:, :, patch_size - step : step]
77 |
78 | for k in range(repeat_l, repeat_r):
79 | alpha = (k - repeat_l) / (repeat_r - repeat_l)
80 | pred_full_w[:, :, :, k] = pred_lr_list[ind][:, :, step + k - repeat_l] * (1 - alpha) + pred_lr_list[ind + 1][:, :, k - repeat_l] * alpha
81 | # pred_full_w[:, :, :, k] = pred_full_w[:, :, :, k] * (1 - alpha) + pred_lr_list[ind + 1][:, :, k - repeat_l] * alpha
82 | pred_full_w_list.append(pred_full_w)
83 |
84 | pred = torch.zeros([1, 3, new_h, new_w], device=sr_list[0].device)
85 | pred[:, :, 0 : step, :] = pred_full_w_list[0][:, :, 0 : step, :]
86 | # pred[:, :, 0 : patch_size, :] = pred_full_w_list[0][:, :, 0 : patch_size, :]
87 | pred[:, :, -step :, :] = pred_full_w_list[-1][:, :, -step :, :]
88 | for i in range(1, num_h):
89 | repeat_u = i * step
90 | repeat_d = repeat_u + (p_size - step)
91 | for k in range(repeat_u, repeat_d):
92 | alpha = (k - repeat_u) / (repeat_d - repeat_u)
93 | pred[:, :, k, :] = pred_full_w_list[i - 1][:, :, step + k - repeat_u, :] * (1 - alpha) + pred_full_w_list[i][:, :, k - repeat_u, :] * alpha
94 | # pred[:, :, k, :] = pred[:, :, k, :] * (1 - alpha) + pred_full_w_list[i][:, :, k - repeat_u, :] * alpha
95 | return pred, pred_full_w_list
96 |
97 |
98 | def mef(img1, img2, img_name, flow_model, pipe, args, consistent_start=None):
99 | _, _, H, W = img2.shape
100 | img1 = pad_img(img1, 16)
101 | img2 = pad_img(img2, 16)
102 | img1_light = IMF(img1, img2)
103 | with torch.no_grad():
104 | _, img12_flow = flow_model(img1_light * 2 - 1, img2 * 2 - 1, iters=20, test_mode=True)
105 | _, img21_flow = flow_model(img2 * 2 - 1, img1_light * 2 - 1, iters=20, test_mode=True)
106 |
107 | img12 = backward_warp(img1, img21_flow)
108 | _, occ_mask = forward_backward_consistency_check(img12_flow, img21_flow)
109 | occ_mask = occ_mask.unsqueeze(dim=1)
110 | img12_mask = img12 * (1. - occ_mask)
111 |
112 | img1 = img1[:, :, :H, :W]
113 | img2 = img2[:, :, :H, :W]
114 | img12 = img12[:, :, :H, :W]
115 | img12_mask = img12_mask[:, :, :H, :W]
116 | occ_mask = occ_mask[:, :, :H, :W]
117 |
118 | if not args.prealign:
119 | img12_mask = img1 # cancel pre-align
120 |
121 |
122 | img2 = pad_imgv3(img2, args.tile_size, args.tile_stride)
123 | img12_mask = pad_imgv3(img12_mask, args.tile_size, args.tile_stride)
124 |
125 | img1_mscn_norm, img1_color = get_color_and_struct(isrgb=True, input_img=img12_mask, ksize=7, sigmaX=0, c=0.0000001)
126 | img1_mscn_norm, img1_color = img1_mscn_norm.unsqueeze(dim=0), img1_color.unsqueeze(dim=0)
127 | img1_mscn_norm, img1_color = img1_mscn_norm.cuda(), img1_color.cuda()
128 | fidelity_input = torch.cat([img2, img1_mscn_norm, img1_color], dim=1)
129 |
130 | img2_patches, num_h, num_w, new_h, new_w = crop_parallel(img2, args.tile_size, args.tile_stride)
131 | img1_struct_patches, num_h, num_w, new_h, new_w = crop_parallel(img1_mscn_norm, args.tile_size, args.tile_stride)
132 | img1_color_patches, num_h, num_w, new_h, new_w = crop_parallel(img1_color, args.tile_size, args.tile_stride)
133 | img2_patches_list = torch.split(img2_patches, 1, dim=0)
134 | img1_struct_patches_list = torch.split(img1_struct_patches, 1, dim=0)
135 | img1_color_patches_list = torch.split(img1_color_patches, 1, dim=0)
136 | fidelity_input_patches, num_h, num_w, new_h, new_w = crop_parallel(fidelity_input, args.tile_size, args.tile_stride)
137 | fidelity_input_patches_list = torch.split(fidelity_input_patches, 1, dim=0)
138 | out_list = []
139 | for ind, (img2_, img1_struct_, img1_color_, fidelity_input_) in enumerate(zip(img2_patches_list, img1_struct_patches_list, img1_color_patches_list, fidelity_input_patches_list)):
140 | set_seed(args.seed)
141 | out = pipe.run(lq2=img2_, lq1_mscn_norm=img1_struct_, lq1_color=img1_color_, tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride, cond_fn=cond_fn, fidelity_input=fidelity_input_, consistent_start=consistent_start) # [-1, 1]
142 | out_list.append(out)
143 |
144 | out_list = torch.cat(out_list, dim=0)
145 | out, _ = combine_parallel_wo_artifact(out_list, num_h, num_w, new_h, new_w, args.tile_size, args.tile_stride)
146 |
147 | out = out[:, :, :H, :W]
148 | img1 = img1[:, :, :H, :W]
149 | img1_light = img1_light[:, :, :H, :W]
150 | img12 = img12[:, :, :H, :W]
151 | img2 = img2[:, :, :H, :W]
152 | occ_mask = occ_mask[:, :, :H, :W]
153 | img12_mask = img12_mask[:, :, :H, :W]
154 | img1_mscn_norm = img1_mscn_norm[:, :, :H, :W]
155 | img1_color = img1_color[:, :, :H, :W]
156 |
157 | u = torch.zeros_like(out)
158 | v = torch.zeros_like(out)
159 | u[:, 1:, :, :] = img1_color
160 | v[:, :2, :, :] = img1_color
161 |
162 | save_image((out + 1) / 2, '{}/{}_out_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign'))
163 | if args.save_all:
164 | save_image(img1, '{}/{}_ue.png'.format(args.output, img_name))
165 | save_image(img1_light, '{}/{}_ue_imf.png'.format(args.output, img_name))
166 | save_image(img12_mask, '{}/{}_ue2oe_mask_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign'))
167 | save_image(img12, '{}/{}_ue2oe_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign'))
168 | save_image(img1_mscn_norm, '{}/{}_ue2oe_mask_mscn_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign'))
169 | save_image(img2, '{}/{}_oe.png'.format(args.output, img_name))
170 | save_image(occ_mask, '{}/{}_occmask_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign'))
171 | save_image(u, '{}/{}_u.png'.format(args.output, img_name))
172 | save_image(v, '{}/{}_v.png'.format(args.output, img_name))
173 |
174 | return out
175 |
176 | parser = ArgumentParser()
177 | parser.add_argument("--dataset", type=str, default='MEFB')
178 | parser.add_argument("--output", default='results', type=str)
179 | parser.add_argument("--tiled", action='store_true', default=False)
180 | parser.add_argument("--tile_size", type=int, default=512)
181 | parser.add_argument("--tile_stride", type=int, default=256)
182 | parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
183 | parser.add_argument("--seed", type=int, default=231, choices=["cpu", "cuda", "mps"])
184 | parser.add_argument("--prealign", action='store_true', default=False)
185 | parser.add_argument("--save_all", action='store_true', default=False)
186 | args = parser.parse_args()
187 |
188 | from model.V4_CA.cldm import ControlLDM
189 | from model.V4_CA.gaussian_diffusion import Diffusion
190 | from pipeline.V4_CA.pipeline import UltraFusionPipeline
191 | ### load unet, vae, clip
192 | cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/ultrafusion.yaml").model.cldm)
193 | sd = torch.load('ckpts/v2-1_512-ema-pruned.ckpt', map_location="cpu")["state_dict"]
194 | unused = cldm.load_pretrained_sd(sd)
195 | print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
196 | ### load controlnet
197 | control_sd = torch.load('ckpts/ultrafusion.pt', map_location="cpu")
198 | cldm.load_controlnet_from_ckpt(control_sd)
199 | print(f"strictly load controlnet weight")
200 | cldm.eval().to(args.device)
201 | ### load fidelity encoder
202 | fidelity_encoder = instantiate_from_config(OmegaConf.load("configs/fcb.yaml").model.fidelity_encoder)
203 | fidelity_encoder_sd = torch.load('ckpts/fcb.pt')
204 | fidelity_encoder.load_state_dict(fidelity_encoder_sd, strict=True)
205 | fidelity_encoder = fidelity_encoder.cuda()
206 | fidelity_encoder.eval()
207 | ### load diffusion
208 | diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/ultrafusion.yaml").model.diffusion)
209 | diffusion.to(args.device)
210 | ### load flow model
211 | flow_state_dict = torch.load('ckpts/raft-sintel.pth', map_location='cpu')
212 | flow_model = RAFT(args).cuda()
213 | new_flow_state_dict = OrderedDict()
214 | for k, v in flow_state_dict.items():
215 | new_flow_state_dict[k.replace("module.", "")] = v
216 | flow_model.load_state_dict(new_flow_state_dict)
217 | flow_model = flow_model.eval()
218 | cond_fn = None
219 |
220 | pipe = UltraFusionPipeline(cldm=cldm, diffusion=diffusion, fidelity_encoder=fidelity_encoder, device=args.device)
221 |
222 | to_tensor = ToTensor()
223 |
224 | dataset = TestDataset(args.dataset)
225 | dataloader = DataLoader(
226 | dataset,
227 | shuffle=False,
228 | batch_size=1,
229 | num_workers=0
230 | )
231 |
232 | if not os.path.exists(args.output):
233 | os.mkdir(args.output)
234 | args.output = os.path.join(args.output, args.dataset)
235 | if not os.path.exists(args.output):
236 | os.mkdir(args.output)
237 |
238 | for batch in dataloader:
239 | ue = batch['ue'].cuda()
240 | oe = batch['oe'].cuda()
241 | img_name = batch['file_name'][0]
242 |
243 | _ = mef(img1=ue, img2=oe, img_name=img_name, flow_model=flow_model, pipe=pipe, args=args, consistent_start=None)
--------------------------------------------------------------------------------
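The tiled path above pads each input so that 512×512 tiles with stride 256 cover it exactly, runs the pipeline per tile, and then blends overlapping tiles with linear ramps in `combine_parallel_wo_artifact`. A standalone sketch of just the tile-grid arithmetic used by `pad_imgv3`/`crop_parallel` (the helper name `tile_grid` is illustrative and not part of the repo):

```python
import math

def tile_grid(h, w, tile_size=512, tile_stride=256):
    # How pad_imgv3 chooses the padded size: the smallest size of the form
    # tile_size + k * tile_stride that covers the image.
    n_h = max(math.ceil((h - tile_size) / tile_stride), 0)
    n_w = max(math.ceil((w - tile_size) / tile_stride), 0)
    h_pad = tile_size + n_h * tile_stride
    w_pad = tile_size + n_w * tile_stride
    # crop_parallel then slides a tile_size window with step tile_stride,
    # giving (n_h + 1) x (n_w + 1) overlapping tiles.
    return n_h + 1, n_w + 1, h_pad, w_pad

print(tile_grid(721, 1153))  # (2, 4, 768, 1280): 8 overlapping 512x512 tiles
```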
/inference.sh:
--------------------------------------------------------------------------------
1 | # Final results
2 | CUDA_VISIBLE_DEVICES=0 python inference.py --dataset UltraFusion --output results --tiled --tile_size 512 --tile_stride 256 --save_all
--------------------------------------------------------------------------------
/model/V4_CA/attention.py:
--------------------------------------------------------------------------------
1 | from packaging import version
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn, einsum
5 | from einops import rearrange, repeat
6 | from typing import Optional, Any
7 |
8 | from model.util import (
9 | checkpoint, zero_module, exists, default
10 | )
11 | from model.config import Config, AttnMode
12 |
13 |
14 | # CrossAttn precision handling
15 | import os
16 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
17 |
18 |
19 | # feedforward
20 | class GEGLU(nn.Module):
21 | def __init__(self, dim_in, dim_out):
22 | super().__init__()
23 | self.proj = nn.Linear(dim_in, dim_out * 2)
24 |
25 | def forward(self, x):
26 | x, gate = self.proj(x).chunk(2, dim=-1)
27 | return x * F.gelu(gate)
28 |
29 |
30 | class FeedForward(nn.Module):
31 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
32 | super().__init__()
33 | inner_dim = int(dim * mult)
34 | dim_out = default(dim_out, dim)
35 | project_in = nn.Sequential(
36 | nn.Linear(dim, inner_dim),
37 | nn.GELU()
38 | ) if not glu else GEGLU(dim, inner_dim)
39 |
40 | self.net = nn.Sequential(
41 | project_in,
42 | nn.Dropout(dropout),
43 | nn.Linear(inner_dim, dim_out)
44 | )
45 |
46 | def forward(self, x):
47 | return self.net(x)
48 |
49 |
50 | def Normalize(in_channels):
51 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
52 |
53 |
54 | class CrossAttention(nn.Module):
55 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
56 | super().__init__()
57 | print(f"Setting up {self.__class__.__name__} (vanilla). Query dim is {query_dim}, context_dim is {context_dim} and using "
58 | f"{heads} heads.")
59 | inner_dim = dim_head * heads
60 | context_dim = default(context_dim, query_dim)
61 |
62 | self.scale = dim_head ** -0.5
63 | self.heads = heads
64 |
65 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
66 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
67 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
68 |
69 | self.to_out = nn.Sequential(
70 | nn.Linear(inner_dim, query_dim),
71 | nn.Dropout(dropout)
72 | )
73 |
74 | def forward(self, x, context=None, mask=None):
75 | h = self.heads
76 |
77 | q = self.to_q(x)
78 | context = default(context, x)
79 | k = self.to_k(context)
80 | v = self.to_v(context)
81 |
82 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
83 |
84 | # force cast to fp32 to avoid overflowing
85 | if _ATTN_PRECISION =="fp32":
86 | # with torch.autocast(enabled=False, device_type = 'cuda'):
87 | with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
88 | q, k = q.float(), k.float()
89 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
90 | else:
91 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
92 |
93 | del q, k
94 |
95 | if exists(mask):
96 | mask = rearrange(mask, 'b ... -> b (...)')
97 | max_neg_value = -torch.finfo(sim.dtype).max
98 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
99 | sim.masked_fill_(~mask, max_neg_value)
100 |
101 | # attention, what we cannot get enough of
102 | sim = sim.softmax(dim=-1)
103 |
104 | out = einsum('b i j, b j d -> b i d', sim, v)
105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
106 | return self.to_out(out)
107 |
108 |
109 | class MemoryEfficientCrossAttention(nn.Module):
110 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
111 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
112 | super().__init__()
113 | print(f"Setting up {self.__class__.__name__} (xformers). Query dim is {query_dim}, context_dim is {context_dim} and using "
114 | f"{heads} heads.")
115 | inner_dim = dim_head * heads
116 | context_dim = default(context_dim, query_dim)
117 |
118 | self.heads = heads
119 | self.dim_head = dim_head
120 |
121 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
122 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
123 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
124 |
125 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
126 | self.attention_op: Optional[Any] = None
127 |
128 | def forward(self, x, context=None, mask=None):
129 | q = self.to_q(x)
130 | context = default(context, x)
131 | k = self.to_k(context)
132 | v = self.to_v(context)
133 |
134 | b, _, _ = q.shape
135 | q, k, v = map(
136 | lambda t: t.unsqueeze(3)
137 | .reshape(b, t.shape[1], self.heads, self.dim_head)
138 | .permute(0, 2, 1, 3)
139 | .reshape(b * self.heads, t.shape[1], self.dim_head)
140 | .contiguous(),
141 | (q, k, v),
142 | )
143 |
144 | # actually compute the attention, what we cannot get enough of
145 | out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
146 |
147 | if exists(mask):
148 | raise NotImplementedError
149 | out = (
150 | out.unsqueeze(0)
151 | .reshape(b, self.heads, out.shape[1], self.dim_head)
152 | .permute(0, 2, 1, 3)
153 | .reshape(b, out.shape[1], self.heads * self.dim_head)
154 | )
155 | return self.to_out(out)
156 |
157 |
158 | class SDPCrossAttention(nn.Module):
159 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
160 | super().__init__()
161 | print(f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using "
162 | f"{heads} heads.")
163 | inner_dim = dim_head * heads
164 | context_dim = default(context_dim, query_dim)
165 |
166 | self.heads = heads
167 | self.dim_head = dim_head
168 |
169 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
170 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
171 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
172 |
173 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
174 |
175 | def forward(self, x, context=None, mask=None):
176 | q = self.to_q(x)
177 | context = default(context, x)
178 | k = self.to_k(context)
179 | v = self.to_v(context)
180 |
181 | b, _, _ = q.shape
182 | q, k, v = map(
183 | lambda t: t.unsqueeze(3)
184 | .reshape(b, t.shape[1], self.heads, self.dim_head)
185 | .permute(0, 2, 1, 3)
186 | .reshape(b * self.heads, t.shape[1], self.dim_head)
187 | .contiguous(),
188 | (q, k, v),
189 | )
190 |
191 | # actually compute the attention, what we cannot get enough of
192 | out = F.scaled_dot_product_attention(q, k, v)
193 |
194 | if exists(mask):
195 | raise NotImplementedError
196 | out = (
197 | out.unsqueeze(0)
198 | .reshape(b, self.heads, out.shape[1], self.dim_head)
199 | .permute(0, 2, 1, 3)
200 | .reshape(b, out.shape[1], self.heads * self.dim_head)
201 | )
202 | return self.to_out(out)
203 |
204 |
205 | class BasicTransformerBlock(nn.Module):
206 | ATTENTION_MODES = {
207 | AttnMode.VANILLA: CrossAttention, # vanilla attention
208 | AttnMode.XFORMERS: MemoryEfficientCrossAttention,
209 | AttnMode.SDP: SDPCrossAttention
210 | }
211 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
212 | disable_self_attn=False):
213 | super().__init__()
214 | attn_cls = self.ATTENTION_MODES[Config.attn_mode]
215 | self.disable_self_attn = disable_self_attn
216 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
217 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
218 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
219 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
220 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
221 | self.norm1 = nn.LayerNorm(dim)
222 | self.norm2 = nn.LayerNorm(dim)
223 | self.norm3 = nn.LayerNorm(dim)
224 | self.checkpoint = checkpoint
225 |
226 | def forward(self, x, context=None):
227 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
228 |
229 | def _forward(self, x, context=None):
230 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
231 | x = self.attn2(self.norm2(x), context=context) + x
232 | x = self.ff(self.norm3(x)) + x
233 | return x
234 |
235 |
236 | class SpatialTransformer(nn.Module):
237 | """
238 | Transformer block for image-like data.
239 | First, project the input (aka embedding)
240 | and reshape to b, t, d.
241 | Then apply standard transformer action.
242 | Finally, reshape to image
243 | NEW: use_linear for more efficiency instead of the 1x1 convs
244 | """
245 | def __init__(self, in_channels, n_heads, d_head,
246 | depth=1, dropout=0., context_dim=None,
247 | disable_self_attn=False, use_linear=False,
248 | use_checkpoint=True):
249 | super().__init__()
250 | if exists(context_dim) and not isinstance(context_dim, list):
251 | context_dim = [context_dim]
252 | self.in_channels = in_channels
253 | inner_dim = n_heads * d_head
254 | self.norm = Normalize(in_channels)
255 | if not use_linear:
256 | self.proj_in = nn.Conv2d(in_channels,
257 | inner_dim,
258 | kernel_size=1,
259 | stride=1,
260 | padding=0)
261 | else:
262 | self.proj_in = nn.Linear(in_channels, inner_dim)
263 |
264 | self.transformer_blocks = nn.ModuleList(
265 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
266 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
267 | for d in range(depth)]
268 | )
269 | if not use_linear:
270 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
271 | in_channels,
272 | kernel_size=1,
273 | stride=1,
274 | padding=0))
275 | else:
276 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
277 | self.use_linear = use_linear
278 |
279 | def forward(self, x, context=None):
280 | # note: if no context is given, cross-attention defaults to self-attention
281 | if not isinstance(context, list):
282 | context = [context]
283 | b, c, h, w = x.shape
284 | x_in = x
285 | x = self.norm(x)
286 | if not self.use_linear:
287 | x = self.proj_in(x)
288 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
289 | if self.use_linear:
290 | x = self.proj_in(x)
291 | for i, block in enumerate(self.transformer_blocks):
292 | x = block(x, context=context[i])
293 | if self.use_linear:
294 | x = self.proj_out(x)
295 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
296 | if not self.use_linear:
297 | x = self.proj_out(x)
298 | return x + x_in
299 |
--------------------------------------------------------------------------------
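A minimal shape-check sketch for the vanilla `CrossAttention` branch (used when neither xformers nor SDP attention is selected in `model.config`); it assumes the repository root is on `PYTHONPATH`. The dimensions follow `configs/ultrafusion.yaml` (`model_channels=320`, `num_head_channels=64`, `context_dim=1024`):

```python
import torch

from model.V4_CA.attention import CrossAttention

# A 320-channel feature map attending to a 77-token text embedding.
attn = CrossAttention(query_dim=320, context_dim=1024, heads=5, dim_head=64)
x = torch.randn(1, 64 * 64, 320)    # flattened 64x64 latent feature map
context = torch.randn(1, 77, 1024)  # text conditioning
out = attn(x, context=context)
print(out.shape)  # torch.Size([1, 4096, 320])
```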
/model/V4_CA/cldm.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Set, List, Dict
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from model.V4_CA.controlnet import (
7 | ControlledUnetModel, ControlNet,
8 | )
9 | from model.V4_CA.vae import AutoencoderKL
10 | from model.clip import FrozenOpenCLIPEmbedder
11 |
12 | from utils.common import sliding_windows, count_vram_usage, gaussian_weights
13 | from torchvision.utils import save_image
14 |
15 |
16 | def disabled_train(self: nn.Module) -> nn.Module:
17 | """Overwrite model.train with this function to make sure train/eval mode
18 | does not change anymore."""
19 | return self
20 |
21 |
22 | class ControlLDM(nn.Module):
23 |
24 | def __init__(
25 | self,
26 | unet_cfg,
27 | vae_cfg,
28 | clip_cfg,
29 | controlnet_cfg,
30 | latent_scale_factor
31 | ):
32 | super().__init__()
33 | self.unet = ControlledUnetModel(**unet_cfg)
34 | self.vae = AutoencoderKL(**vae_cfg)
35 | self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
36 | self.controlnet = ControlNet(**controlnet_cfg)
37 | self.scale_factor = latent_scale_factor
38 | self.control_scales = [1.0] * 13
39 |
40 | @torch.no_grad()
41 | def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
42 | module_map = {
43 | "unet": "model.diffusion_model",
44 | "vae": "first_stage_model",
45 | "clip": "cond_stage_model",
46 | }
47 | modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
48 | used = set()
49 | for name, module in modules:
50 | init_sd = {}
51 | scratch_sd = module.state_dict()
52 | for key in scratch_sd:
53 | target_key = ".".join([module_map[name], key])
54 | init_sd[key] = sd[target_key].clone()
55 | used.add(target_key)
56 | module.load_state_dict(init_sd, strict=True)
57 | unused = set(sd.keys()) - used
58 | # NOTE: this is slightly different from the previous version, which didn't switch
59 | # the UNet to eval mode or disable the requires_grad flag.
60 | for module in [self.vae, self.clip, self.unet]:
61 | module.eval()
62 | module.train = disabled_train
63 | for p in module.parameters():
64 | p.requires_grad = False
65 | return unused
66 |
67 | @torch.no_grad()
68 | def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
69 | self.controlnet.load_state_dict(sd, strict=True)
70 |
71 | @torch.no_grad()
72 | def load_controlnet_from_unet(self) -> Tuple[Set[str]]:
73 | unet_sd = self.unet.state_dict()
74 | scratch_sd = self.controlnet.state_dict()
75 | init_sd = {}
76 | init_with_new_zero = set()
77 | init_with_scratch = set()
78 | for key in scratch_sd:
79 | if key in unet_sd:
80 | this, target = scratch_sd[key], unet_sd[key]
81 | if this.size() == target.size():
82 | init_sd[key] = target.clone()
83 | elif this.size(1) > target.size(1):
84 | print(this.size(1), target.size(1))
85 | d_ic = this.size(1) - target.size(1)
86 | oc, _, h, w = this.size()
87 | zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype)
88 | init_sd[key] = torch.cat((target, zeros), dim=1)
89 | init_with_new_zero.add(key)
90 | else:
91 | print(this.size(1), target.size(1))
92 | d_ic = this.size(1)
93 | print(target.shape)
94 | init_sd[key] = target[:, :d_ic, :, :]
95 | else:
96 | init_sd[key] = scratch_sd[key].clone()
97 | init_with_scratch.add(key)
98 | self.controlnet.load_state_dict(init_sd, strict=True)
99 | return init_with_new_zero, init_with_scratch
100 |
101 | def vae_encode(self, image: torch.Tensor, sample: bool=True) -> torch.Tensor:
102 | if sample:
103 | return self.vae.encode(image).sample() * self.scale_factor
104 | else:
105 | return self.vae.encode(image).mode() * self.scale_factor
106 |
107 | def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True, fidelity_encoder=None, fidelity_input: torch.Tensor=None) -> torch.Tensor:
108 | bs, _, h, w = image.shape
109 | z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
110 | count = torch.zeros_like(z, dtype=torch.float32)
111 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
112 | weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
113 | tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
114 | skip_feats = []
115 | for hi, hi_end, wi, wi_end in tiles:
116 | tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
117 | if fidelity_input is not None:
118 | tile_fidelity_input = fidelity_input[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
119 | z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
120 | with torch.no_grad():
121 | # tile_images = torch.cat([tile_image_ue, tile_image], dim=1)
122 | # skip_feat = fidelity_encoder(tile_images)
123 | if fidelity_input is not None:
124 | skip_feat = fidelity_encoder(tile_fidelity_input)
125 | skip_feats.append(skip_feat)
126 | else:
127 | skip_feats.append(None)
128 | count[:, :, hi:hi_end, wi:wi_end] += weights
129 | z.div_(count)
130 | return z, skip_feats
131 |
132 | def vae_decode(self, z: torch.Tensor, skip_feat=None) -> torch.Tensor:
133 | return self.vae.decode(z / self.scale_factor, skip_feat)
134 |
135 |
136 | def calc_mean_std(self, feat, eps=1e-5):
137 | # eps is a small value added to the variance to avoid divide-by-zero.
138 | size = feat.size()
139 | assert (len(size) == 4)
140 | N, C = size[:2]
141 | feat_var = feat.contiguous().view(N, C, -1).var(dim=2) + eps
142 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
143 | feat_mean = feat.contiguous().view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
144 | return feat_mean, feat_std
145 |
146 |
147 | def adaptive_instance_normalization(self, content_feat, style_feat):
148 | assert (content_feat.size()[:2] == style_feat.size()[:2])
149 | size = content_feat.size()
150 | # style_mean = torch.tensor([0.6906, 0.6766, 0.6749]).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3).cuda()
151 | # style_std = torch.tensor([0.1955, 0.2096, 0.2236]).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3).cuda()
152 | style_mean, style_std = self.calc_mean_std(style_feat)
153 | content_mean, content_std = self.calc_mean_std(content_feat)
154 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
155 |
156 | k = (style_std) / (content_std)
157 | b = style_mean - content_mean * style_std / content_std
158 | res = normalized_feat * style_std.expand(size) + style_mean.expand(size)
159 | return res
160 |
161 |
162 | @count_vram_usage
163 | def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int, skip_feats=None, consistent_start=None) -> torch.Tensor:
164 | bs, _, h, w = z.shape
165 | image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
166 | count = torch.zeros_like(image, dtype=torch.float32)
167 | weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
168 | weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
169 | tiles = sliding_windows(h, w, tile_size, tile_stride)
170 | for ind, ((hi, hi_end, wi, wi_end), skip_feat) in enumerate(zip(tiles, skip_feats)):
171 | tile_z = z[:, :, hi:hi_end, wi:wi_end]
172 | tile_z_decoded = self.vae_decode(tile_z, skip_feat)
173 | if consistent_start is not None:
174 | tile_z_decoded = self.adaptive_instance_normalization(tile_z_decoded, consistent_start[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8])
175 | image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += tile_z_decoded * weights
176 | count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
177 | image.div_(count)
178 | return image
179 |
180 |
181 | def prepare_condition(self, lq2: torch.Tensor, lq1_mscn_norm: torch.Tensor, lq1_color: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
182 | # Note the lq_vis should be normalized to [-1, 1] and lq_ifr normalized to [0, 1]!!!
183 | return dict(
184 | c_txt=self.clip.encode(txt),
185 | c_lq2=self.vae_encode(lq2, sample=False),
186 | c_lq1_mscn_norm=lq1_mscn_norm,
187 | c_lq1_color=lq1_color
188 | )
189 |
190 |
191 | @count_vram_usage
192 | def prepare_condition_tiled(self, lq2: torch.Tensor, lq1_mscn_norm:torch.Tensor, lq1_color: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int, fidelity_encoder=None, fidelity_input: torch.Tensor=None) -> Dict[str, torch.Tensor]:
193 | c_lq2, skip_feats = self.vae_encode_tiled(lq2, tile_size, tile_stride, sample=False, fidelity_encoder=fidelity_encoder, fidelity_input=fidelity_input)  # why sample=False?
194 | return dict(
195 | c_txt=self.clip.encode(txt),
196 | c_lq2=c_lq2,
197 | c_lq1_mscn_norm = lq1_mscn_norm,
198 | c_lq1_color = lq1_color
199 | ), skip_feats
200 |
201 | def forward(self, x_noisy, t, cond):
202 | c_txt = cond["c_txt"]
203 |
204 | c_lq2 = cond["c_lq2"]
205 | c_lq1_mscn_norm = cond["c_lq1_mscn_norm"]
206 | c_lq1_color = cond["c_lq1_color"]
207 | c_img = [c_lq2, c_lq1_mscn_norm, c_lq1_color]
208 |
209 | control = self.controlnet(
210 | x=x_noisy, hint=c_img,
211 | timesteps=t, context=c_txt
212 | )
213 | control = [c * scale for c, scale in zip(control, self.control_scales)]
214 | eps = self.unet(
215 | x=x_noisy, timesteps=t,
216 | context=c_txt, control=control, only_mid_control=False
217 | )
218 | return eps
219 |
--------------------------------------------------------------------------------
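`vae_decode_tiled` above can match each decoded tile to `consistent_start` via `adaptive_instance_normalization`, i.e. a per-channel mean/std transfer. A standalone numeric sketch of that statistic transfer (tensor names here are illustrative):

```python
import torch

def calc_mean_std(feat, eps=1e-5):
    # Per-sample, per-channel mean and std over spatial positions, as in ControlLDM.calc_mean_std.
    n, c = feat.shape[:2]
    var = feat.reshape(n, c, -1).var(dim=2) + eps
    return feat.reshape(n, c, -1).mean(dim=2).view(n, c, 1, 1), var.sqrt().view(n, c, 1, 1)

content = torch.rand(1, 3, 64, 64)             # e.g. a freshly decoded tile
style = torch.rand(1, 3, 64, 64) * 0.5 + 0.25  # e.g. the corresponding reference region
c_mean, c_std = calc_mean_std(content)
s_mean, s_std = calc_mean_std(style)
matched = (content - c_mean) / c_std * s_std + s_mean  # per-channel statistics now follow `style`
print(torch.allclose(calc_mean_std(matched)[0], s_mean, atol=1e-5))  # True
```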
/model/V4_CA/controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch as th
3 | import torch.nn as nn
4 |
5 | from model.util import (
6 | conv_nd,
7 | linear,
8 | zero_module,
9 | timestep_embedding,
10 | exists
11 | )
12 | from model.V4_CA.attention import SpatialTransformer
13 | from model.V4_CA.unet import (
14 | TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel
15 | )
16 | from model.V4_CA.cross_attention import CrossTransformerBlock2D
17 |
18 |
19 | class ControlledUnetModel(UNetModel):
20 |
21 | def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
22 | hs = []
23 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
24 | emb = self.time_embed(t_emb)
25 | h = x.type(self.dtype)
26 | for module in self.input_blocks:
27 | h = module(h, emb, context)
28 | hs.append(h)
29 | h = self.middle_block(h, emb, context)
30 |
31 | if control is not None:
32 | h += control.pop()
33 |
34 | for i, module in enumerate(self.output_blocks):
35 | if only_mid_control or control is None:
36 | h = torch.cat([h, hs.pop()], dim=1)
37 | else:
38 | h = torch.cat([h, hs.pop() + control.pop()], dim=1)
39 | h = module(h, emb, context)
40 |
41 | h = h.type(x.dtype)
42 | return self.out(h)
43 |
44 |
45 | class ControlNet(nn.Module):
46 |
47 | def __init__(
48 | self,
49 | image_size,
50 | in_channels,
51 | model_channels,
52 | hint_channels,
53 | num_res_blocks,
54 | attention_resolutions,
55 | dropout=0,
56 | channel_mult=(1, 2, 4, 8),
57 | conv_resample=True,
58 | dims=2,
59 | use_checkpoint=False,
60 | use_fp16=False,
61 | num_heads=-1,
62 | num_head_channels=-1,
63 | num_heads_upsample=-1,
64 | use_scale_shift_norm=False,
65 | resblock_updown=False,
66 | use_new_attention_order=False,
67 | use_spatial_transformer=False, # custom transformer support
68 | transformer_depth=1, # custom transformer support
69 | context_dim=None, # custom transformer support
70 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
71 | legacy=True,
72 | disable_self_attentions=None,
73 | num_attention_blocks=None,
74 | disable_middle_self_attn=False,
75 | use_linear_in_transformer=False,
76 | ):
77 | super().__init__()
78 | if use_spatial_transformer:
79 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
80 |
81 | if context_dim is not None:
82 | assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
83 | from omegaconf.listconfig import ListConfig
84 | if type(context_dim) == ListConfig:
85 | context_dim = list(context_dim)
86 |
87 | if num_heads_upsample == -1:
88 | num_heads_upsample = num_heads
89 |
90 | if num_heads == -1:
91 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
92 |
93 | if num_head_channels == -1:
94 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
95 |
96 | self.dims = dims
97 | self.image_size = image_size
98 | self.in_channels = in_channels
99 | self.model_channels = model_channels
100 | if isinstance(num_res_blocks, int):
101 | self.num_res_blocks = len(channel_mult) * [num_res_blocks]
102 | else:
103 | if len(num_res_blocks) != len(channel_mult):
104 | raise ValueError("provide num_res_blocks either as an int (globally constant) or "
105 | "as a list/tuple (per-level) with the same length as channel_mult")
106 | self.num_res_blocks = num_res_blocks
107 | if disable_self_attentions is not None:
108 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
109 | assert len(disable_self_attentions) == len(channel_mult)
110 | if num_attention_blocks is not None:
111 | assert len(num_attention_blocks) == len(self.num_res_blocks)
112 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
113 | print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
114 | f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
115 | f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
116 | f"attention will still not be set.")
117 |
118 | self.attention_resolutions = attention_resolutions
119 | self.dropout = dropout
120 | self.channel_mult = channel_mult
121 | self.conv_resample = conv_resample
122 | self.use_checkpoint = use_checkpoint
123 | self.dtype = th.float16 if use_fp16 else th.float32
124 | self.num_heads = num_heads
125 | self.num_head_channels = num_head_channels
126 | self.num_heads_upsample = num_heads_upsample
127 | self.predict_codebook_ids = n_embed is not None
128 |
129 | time_embed_dim = model_channels * 4
130 | self.time_embed = nn.Sequential(
131 | linear(model_channels, time_embed_dim),
132 | nn.SiLU(),
133 | linear(time_embed_dim, time_embed_dim),
134 | )
135 |
136 | self.input_blocks = nn.ModuleList(
137 | [
138 | TimestepEmbedSequential(
139 | conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1)
140 | )
141 | ]
142 | )
143 | self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
144 |
145 | self.cross_expo_block = nn.ModuleList(
146 | [
147 | CrossTransformerBlock2D(dim=model_channels * channel_mult[0], num_heads=5),
148 | CrossTransformerBlock2D(dim=model_channels * channel_mult[0], num_heads=5),
149 | CrossTransformerBlock2D(dim=model_channels * channel_mult[1], num_heads=10),
150 | CrossTransformerBlock2D(dim=model_channels * channel_mult[2], num_heads=20),
151 | CrossTransformerBlock2D(dim=model_channels * channel_mult[3], num_heads=20),
152 | ]
153 | )
154 |
155 | self.struct_hint_block = TimestepEmbedSequential(
156 | conv_nd(dims, 1, 16, 3, padding=1),
157 | nn.SiLU(),
158 | conv_nd(dims, 16, 16, 3, padding=1),
159 | nn.SiLU(),
160 | conv_nd(dims, 16, 32, 3, padding=1, stride=2),
161 | nn.SiLU(),
162 | conv_nd(dims, 32, 32, 3, padding=1),
163 | nn.SiLU(),
164 | conv_nd(dims, 32, 96, 3, padding=1, stride=2),
165 | nn.SiLU(),
166 | conv_nd(dims, 96, 96, 3, padding=1),
167 | nn.SiLU(),
168 | conv_nd(dims, 96, 256, 3, padding=1, stride=2),
169 | nn.SiLU(),
170 | zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
171 | )
172 | self.struct_hint_block2 = nn.ModuleList(
173 | [
174 | nn.Identity(), # [B, 320, 64, 64]
175 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1, stride=2), nn.SiLU()), # [B, 320, 32, 32]
176 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[1], 3, padding=1, stride=2), nn.SiLU()), # [B, 640, 16, 16]
177 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[2], 3, padding=1, stride=2), nn.SiLU()), # [B, 1280, 8, 8]
178 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1), nn.SiLU()), # [B, 1280, 8, 8]
179 | ]
180 | )
181 |
182 | self.struct_hint_block_zero_convs = nn.ModuleList(
183 | [
184 | nn.Identity(),
185 | zero_module(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1)), # [B, 320, 32, 32]
186 | zero_module(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[1], 3, padding=1)), # [B, 640, 16, 16]
187 | zero_module(conv_nd(dims, model_channels * channel_mult[2], model_channels * channel_mult[2], 3, padding=1)), # [B, 1280, 8, 8]
188 | zero_module(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1)), # [B, 1280, 8, 8]
189 | ]
190 | )
191 |
192 | self.color_hint_block = TimestepEmbedSequential(
193 | conv_nd(dims, 2, 16, 3, padding=1),
194 | nn.SiLU(),
195 | conv_nd(dims, 16, 16, 3, padding=1),
196 | nn.SiLU(),
197 | conv_nd(dims, 16, 32, 3, padding=1, stride=2),
198 | nn.SiLU(),
199 | conv_nd(dims, 32, 32, 3, padding=1),
200 | nn.SiLU(),
201 | conv_nd(dims, 32, 96, 3, padding=1, stride=2),
202 | nn.SiLU(),
203 | conv_nd(dims, 96, 96, 3, padding=1),
204 | nn.SiLU(),
205 | conv_nd(dims, 96, 256, 3, padding=1, stride=2),
206 | nn.SiLU(),
207 | zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
208 | )
209 | self.color_hint_block2 = nn.ModuleList(
210 | [
211 | nn.Identity(), # [B, 320, 64, 64]
212 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1, stride=2), nn.SiLU()), # [B, 320, 32, 32]
213 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[1], 3, padding=1, stride=2), nn.SiLU()), # [B, 640, 16, 16]
214 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[2], 3, padding=1, stride=2), nn.SiLU()), # [B, 1280, 8, 8]
215 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1), nn.SiLU()), # [B, 1280, 8, 8]
216 | ]
217 | )
218 | self.color_hint_block_zero_convs = nn.ModuleList(
219 | [
220 | nn.Identity(),
221 | zero_module(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1)), # [B, 320, 32, 32]
222 | zero_module(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[1], 3, padding=1)), # [B, 640, 16, 16]
223 | zero_module(conv_nd(dims, model_channels * channel_mult[2], model_channels * channel_mult[2], 3, padding=1)), # [B, 1280, 8, 8]
224 | zero_module(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1)), # [B, 1280, 8, 8]
225 | ]
226 | )
227 |
228 |
229 | self._feature_size = model_channels
230 | input_block_chans = [model_channels]
231 | ch = model_channels
232 | ds = 1
233 | for level, mult in enumerate(channel_mult):
234 | for nr in range(self.num_res_blocks[level]):
235 | layers = [
236 | ResBlock(
237 | ch,
238 | time_embed_dim,
239 | dropout,
240 | out_channels=mult * model_channels,
241 | dims=dims,
242 | use_checkpoint=use_checkpoint,
243 | use_scale_shift_norm=use_scale_shift_norm,
244 | )
245 | ]
246 | ch = mult * model_channels
247 | if ds in attention_resolutions:
248 | if num_head_channels == -1:
249 | dim_head = ch // num_heads
250 | else:
251 | num_heads = ch // num_head_channels
252 | dim_head = num_head_channels
253 | if legacy:
254 | # num_heads = 1
255 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
256 | if exists(disable_self_attentions):
257 | disabled_sa = disable_self_attentions[level]
258 | else:
259 | disabled_sa = False
260 |
261 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
262 | layers.append(
263 | AttentionBlock(
264 | ch,
265 | use_checkpoint=use_checkpoint,
266 | num_heads=num_heads,
267 | num_head_channels=dim_head,
268 | use_new_attention_order=use_new_attention_order,
269 | ) if not use_spatial_transformer else SpatialTransformer(
270 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
271 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
272 | use_checkpoint=use_checkpoint
273 | )
274 | )
275 | self.input_blocks.append(TimestepEmbedSequential(*layers))
276 | self.zero_convs.append(self.make_zero_conv(ch))
277 | self._feature_size += ch
278 | input_block_chans.append(ch)
279 | if level != len(channel_mult) - 1:
280 | out_ch = ch
281 | self.input_blocks.append(
282 | TimestepEmbedSequential(
283 | ResBlock(
284 | ch,
285 | time_embed_dim,
286 | dropout,
287 | out_channels=out_ch,
288 | dims=dims,
289 | use_checkpoint=use_checkpoint,
290 | use_scale_shift_norm=use_scale_shift_norm,
291 | down=True,
292 | )
293 | if resblock_updown
294 | else Downsample(
295 | ch, conv_resample, dims=dims, out_channels=out_ch
296 | )
297 | )
298 | )
299 | ch = out_ch
300 | input_block_chans.append(ch)
301 | self.zero_convs.append(self.make_zero_conv(ch))
302 | ds *= 2
303 | self._feature_size += ch
304 |
305 | if num_head_channels == -1:
306 | dim_head = ch // num_heads
307 | else:
308 | num_heads = ch // num_head_channels
309 | dim_head = num_head_channels
310 | if legacy:
311 | # num_heads = 1
312 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
313 | self.middle_block = TimestepEmbedSequential(
314 | ResBlock(
315 | ch,
316 | time_embed_dim,
317 | dropout,
318 | dims=dims,
319 | use_checkpoint=use_checkpoint,
320 | use_scale_shift_norm=use_scale_shift_norm,
321 | ),
322 | AttentionBlock(
323 | ch,
324 | use_checkpoint=use_checkpoint,
325 | num_heads=num_heads,
326 | num_head_channels=dim_head,
327 | use_new_attention_order=use_new_attention_order,
328 | ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
329 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
330 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
331 | use_checkpoint=use_checkpoint
332 | ),
333 | ResBlock(
334 | ch,
335 | time_embed_dim,
336 | dropout,
337 | dims=dims,
338 | use_checkpoint=use_checkpoint,
339 | use_scale_shift_norm=use_scale_shift_norm,
340 | ),
341 | )
342 | self.middle_block_out = self.make_zero_conv(ch)
343 | self._feature_size += ch
344 |
345 | def make_zero_conv(self, channels):
346 | return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
347 |
348 | def forward(self, x, hint, timesteps, context, **kwargs):
349 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
350 | emb = self.time_embed(t_emb)
351 |
352 | hint_vis, hint_mscn, hint_color = hint[0], hint[1], hint[2]
353 | x = torch.cat((x, hint_vis), dim=1)
354 | guided_hint_mscn = self.struct_hint_block(hint_mscn, emb, context)
355 | guided_hint_color = self.color_hint_block(hint_color, emb, context)
356 | outs = []
357 |
358 | h = x.type(self.dtype)
359 | for i, (module, zero_conv) in enumerate(zip(self.input_blocks, self.zero_convs)):
360 | h = module(h, emb, context)
361 | if i in [0, 3, 6, 9]:
362 | guided_hint_mscn = self.struct_hint_block2[i // 3](guided_hint_mscn)
363 | guided_hint_color = self.color_hint_block2[i // 3](guided_hint_color)
364 | h = self.cross_expo_block[i // 3](h, self.struct_hint_block_zero_convs[i // 3](guided_hint_mscn), self.color_hint_block_zero_convs[i // 3](guided_hint_color))
365 | outs.append(zero_conv(h, emb, context))
366 |
367 | h = self.middle_block(h, emb, context)
368 | guided_hint_mscn = self.struct_hint_block2[4](guided_hint_mscn)
369 | guided_hint_color = self.color_hint_block2[4](guided_hint_color)
370 | h = self.cross_expo_block[4](h, self.struct_hint_block_zero_convs[4](guided_hint_mscn), self.color_hint_block_zero_convs[4](guided_hint_color))
371 | outs.append(self.middle_block_out(h, emb, context))
372 |
373 | return outs
374 |
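A small standalone illustration of the injection schedule implemented in `forward` above (nothing here is part of the model; it only prints which hint stage each UNet block receives, following the shape comments on the hint blocks):

```python
# Hints are injected after input block 0 (the stem, 64x64 latents) and after blocks
# 3, 6, 9 (the downsampling blocks that open the 32x32, 16x16 and 8x8 levels);
# i // 3 maps these indices to stages 0..3 of struct_hint_block2 / color_hint_block2 /
# cross_expo_block, and the middle block then uses stage 4.
for i in range(12):                      # a 4-level SD UNet encoder has 12 input blocks
    if i in [0, 3, 6, 9]:
        print(f"input block {i:2d} -> hint stage {i // 3}")
print("middle block  -> hint stage 4")
```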
--------------------------------------------------------------------------------
/model/V4_CA/cross_attention.py:
--------------------------------------------------------------------------------
1 | # Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | # Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | # https://arxiv.org/abs/2111.09881
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numbers
10 | from einops import rearrange
11 |
12 |
13 | class Conv2dNormRelu(nn.Module):
14 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, norm=None, activation='leaky_relu'):
15 | super().__init__()
16 | self.conv_fn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)
17 |
18 | if norm == 'batch_norm':
19 | self.norm_fn = nn.BatchNorm2d(out_channels)
20 | elif norm == 'instance_norm':
21 | self.norm_fn = nn.InstanceNorm2d(out_channels)
22 | elif norm is None:
23 | self.norm_fn = nn.Identity()
24 | else:
25 | raise NotImplementedError('Unknown normalization function: %s' % norm)
26 |
27 | if activation == 'relu':
28 | self.relu_fn = nn.ReLU(inplace=True)
29 | elif activation == 'leaky_relu':
30 | self.relu_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
31 | elif activation is None:
32 | self.relu_fn = nn.Identity()
33 | else:
34 | raise NotImplementedError('Unknown activation function: %s' % activation)
35 |
36 | def forward(self, x):
37 | x = self.conv_fn(x)
38 | x = self.norm_fn(x)
39 | x = self.relu_fn(x)
40 | return x
41 |
42 |
43 | def to_3d(x):
44 | if len(x.shape) == 3:
45 | return rearrange(x, 'b c n -> b n c')
46 | else:
47 | return rearrange(x, 'b c h w -> b (h w) c')
48 |
49 |
50 | def to_4d(x, h, w=None):
51 | if w is None:
52 | return rearrange(x, 'b n c -> b c n')
53 | else:
54 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
55 |
56 |
57 | class BiasFree_LayerNorm(nn.Module):
58 | def __init__(self, normalized_shape):
59 | super(BiasFree_LayerNorm, self).__init__()
60 | if isinstance(normalized_shape, numbers.Integral):
61 | normalized_shape = (normalized_shape,)
62 | normalized_shape = torch.Size(normalized_shape)
63 |
64 | assert len(normalized_shape) == 1
65 |
66 | self.weight = nn.Parameter(torch.ones(normalized_shape))
67 | self.normalized_shape = normalized_shape
68 |
69 | def forward(self, x):
70 | sigma = x.var(-1, keepdim=True, unbiased=False)
71 | return x / torch.sqrt(sigma+1e-5) * self.weight
72 |
73 |
74 | class WithBias_LayerNorm(nn.Module):
75 | def __init__(self, normalized_shape):
76 | super(WithBias_LayerNorm, self).__init__()
77 | if isinstance(normalized_shape, numbers.Integral):
78 | normalized_shape = (normalized_shape,)
79 | normalized_shape = torch.Size(normalized_shape)
80 |
81 | assert len(normalized_shape) == 1
82 |
83 | self.weight = nn.Parameter(torch.ones(normalized_shape))
84 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
85 | self.normalized_shape = normalized_shape
86 |
87 | def forward(self, x):
88 | mu = x.mean(-1, keepdim=True)
89 | sigma = x.var(-1, keepdim=True, unbiased=False)
90 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
91 |
92 |
93 | class LayerNorm(nn.Module):
94 | def __init__(self, dim, LayerNorm_type):
95 | super(LayerNorm, self).__init__()
96 | if LayerNorm_type == 'BiasFree':
97 | self.body = BiasFree_LayerNorm(dim)
98 | else:
99 | self.body = WithBias_LayerNorm(dim)
100 |
101 | def forward(self, x):
102 | if len(x.shape) == 3:
103 | h = x.shape[-1]
104 | w = None
105 | else:
106 | h, w = x.shape[-2:]
107 |
108 | return to_4d(self.body(to_3d(x)), h, w)
109 |
110 |
111 | class Mutual_Attention2D(nn.Module):
112 | def __init__(self, dim, num_heads, bias):
113 | super(Mutual_Attention2D, self).__init__()
114 | self.num_heads = num_heads
115 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
116 |
117 | self.qkv_dwconv = nn.Conv2d(
118 | dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
119 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
120 |
121 | def forward(self, x, y):
122 | b, c, h, w = x.shape
123 |
124 | qkv = self.qkv_dwconv(torch.cat((x, y, y), dim=1))
125 | q, k, v = qkv.chunk(3, dim=1)
126 |
127 | q = rearrange(q, 'b (head c) h w -> b head c (h w)',
128 | head=self.num_heads)
129 | k = rearrange(k, 'b (head c) h w -> b head c (h w)',
130 | head=self.num_heads)
131 | v = rearrange(v, 'b (head c) h w -> b head c (h w)',
132 | head=self.num_heads)
133 |
134 | q = torch.nn.functional.normalize(q, dim=-1)
135 | k = torch.nn.functional.normalize(k, dim=-1)
136 |
137 | attn = (q @ k.transpose(-2, -1)) * self.temperature
138 | attn = attn.softmax(dim=-1)
139 |
140 | out = (attn @ v)
141 |
142 | out = rearrange(out, 'b head c (h w) -> b (head c) h w',
143 | head=self.num_heads, h=h, w=w)
144 |
145 | out = self.project_out(out)
146 | return out
147 |
148 |
149 | class CrossTransformerBlock2D(nn.Module):
150 | def __init__(self, dim, num_heads, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias'):
151 | super(CrossTransformerBlock2D, self).__init__()
152 |
153 | self.mlps = nn.Sequential(
154 | Conv2dNormRelu(dim * 2, dim),
155 | )
156 | self.norm1x = LayerNorm(dim, LayerNorm_type)
157 | self.norm1y = LayerNorm(dim, LayerNorm_type)
158 | self.attn = Mutual_Attention2D(dim, num_heads, bias)
159 | # self.norm2 = LayerNorm(dim, LayerNorm_type)
160 | # self.ffn = FeedForward2D(dim, ffn_expansion_factor, bias)
161 |
162 | def forward(self, x, y, z):
163 | assert x.shape == y.shape and x.shape == z.shape, 'wrong shape!, {} {} {}'.format(x.shape, y.shape, z.shape)
164 | y2 = self.mlps(torch.cat([y, z], dim=1))
165 | x = x + self.attn(self.norm1x(x), self.norm1y(y2))
166 | # x = x + self.ffn(self.norm2(x))
167 |
168 | return x
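A minimal smoke test for the block above, with assumed shapes (any `dim` divisible by `num_heads` works): `x` plays the role of the diffusion feature, `y` and `z` the structure and color hints already projected to the same channel count.

```python
import torch

block = CrossTransformerBlock2D(dim=64, num_heads=4)
x = torch.randn(1, 64, 32, 32)   # diffusion feature map
y = torch.randn(1, 64, 32, 32)   # structure hint, same shape as x
z = torch.randn(1, 64, 32, 32)   # color hint, same shape as x
out = block(x, y, z)             # cat(y, z) -> 1x1 conv -> cross-attention into x (residual)
print(out.shape)                 # torch.Size([1, 64, 32, 32])
```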
--------------------------------------------------------------------------------
/model/V4_CA/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch import nn
6 | import numpy as np
7 |
8 |
9 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
10 | if schedule == "linear":
11 | betas = (
12 | np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2
13 | )
14 |
15 | elif schedule == "cosine":
16 | timesteps = (
17 | np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s
18 | )
19 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
20 | alphas = np.cos(alphas) ** 2
21 | alphas = alphas / alphas[0]
22 | betas = 1 - alphas[1:] / alphas[:-1]
23 | betas = np.clip(betas, a_min=0, a_max=0.999)
24 |
25 | elif schedule == "sqrt_linear":
26 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
27 | elif schedule == "sqrt":
28 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5
29 | else:
30 | raise ValueError(f"schedule '{schedule}' unknown.")
31 | return betas
32 |
33 |
34 | def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor:
35 | b, *_ = t.shape
36 | out = a.gather(-1, t)
37 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
38 |
39 |
40 | class Diffusion(nn.Module):
41 |
42 | def __init__(
43 | self,
44 | timesteps=1000,
45 | beta_schedule="linear",
46 | loss_type="l2",
47 | linear_start=1e-4,
48 | linear_end=2e-2,
49 | cosine_s=8e-3,
50 | parameterization="eps"
51 | ):
52 | super().__init__()
53 | self.num_timesteps = timesteps
54 | self.beta_schedule = beta_schedule
55 | self.linear_start = linear_start
56 | self.linear_end = linear_end
57 | self.cosine_s = cosine_s
58 | assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'"
59 | self.parameterization = parameterization
60 | self.loss_type = loss_type
61 |
62 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
63 | cosine_s=cosine_s)
64 | alphas = 1. - betas
65 | alphas_cumprod = np.cumprod(alphas, axis=0)
66 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
67 | sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
68 |
69 | self.betas = betas
70 | self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod)
71 | self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod)
72 |
73 | def register(self, name: str, value: np.ndarray) -> None:
74 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32))
75 |
76 | def q_sample(self, x_start, t, noise):
77 | return (
78 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
79 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
80 | )
81 |
82 | def get_v(self, x, noise, t):
83 | return (
84 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
85 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
86 | )
87 |
88 | def get_loss(self, pred, target, mean=True):
89 | if self.loss_type == 'l1':
90 | loss = (target - pred).abs()
91 | if mean:
92 | loss = loss.mean()
93 | elif self.loss_type == 'l2':
94 | if mean:
95 | loss = torch.nn.functional.mse_loss(target, pred)
96 | else:
97 | loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
98 | else:
99 | raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
100 |
101 | return loss
102 |
103 | def p_losses(self, model, x_start, t, cond):
104 | noise = torch.randn_like(x_start)
105 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
106 | model_output = model(x_noisy, t, cond)
107 |
108 | if self.parameterization == "x0":
109 | target = x_start
110 | elif self.parameterization == "eps":
111 | target = noise
112 | elif self.parameterization == "v":
113 | target = self.get_v(x_start, noise, t)
114 | else:
115 | raise NotImplementedError()
116 |
117 | loss_simple = self.get_loss(model_output, target, mean=False).mean()
118 | return loss_simple
119 |
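A minimal sketch of the forward (noising) process defined above; the latent shape and timestep count are assumptions chosen for illustration.

```python
import torch

diffusion = Diffusion(timesteps=1000, beta_schedule="linear", parameterization="v")
x0 = torch.randn(2, 4, 64, 64)                        # e.g. a batch of VAE latents
t = torch.randint(0, diffusion.num_timesteps, (2,))   # one timestep per sample
noise = torch.randn_like(x0)
xt = diffusion.q_sample(x0, t, noise)                 # q(x_t | x_0): sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps
v_target = diffusion.get_v(x0, noise, t)              # regression target when parameterization == "v"
print(xt.shape, v_target.shape)                       # both torch.Size([2, 4, 64, 64])
```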
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | # from . import config
2 |
3 | # from .controlnet import ControlledUnetModel, ControlNet
4 | # from .vae import AutoencoderKL
5 | # from .clip import FrozenOpenCLIPEmbedder
6 |
7 | # from .cldm import ControlLDM
8 | # from .cldm_mc import MultiControlLDM
9 | # from .gaussian_diffusion import Diffusion
10 |
11 | # from .swinir import SwinIR
12 | # from .bsrnet import RRDBNet
13 | # from .scunet import SCUNet
14 |
--------------------------------------------------------------------------------
/model/clip.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.checkpoint import checkpoint
5 | from model.open_clip import CLIP, tokenize
6 |
7 | ### pretrained model path
8 | # _VITH14 = dict(
9 | # laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
10 | # )
11 |
12 | class FrozenOpenCLIPEmbedder(nn.Module):
13 | """
14 | Uses the OpenCLIP transformer encoder for text
15 | """
16 | LAYERS = [
17 | #"pooled",
18 | "last",
19 | "penultimate"
20 | ]
21 | def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"):
22 | super().__init__()
23 | assert layer in self.LAYERS
24 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
25 | model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg))
26 | del model.visual
27 | self.model = model
28 |
29 | self.layer = layer
30 | if self.layer == "last":
31 | self.layer_idx = 0
32 | elif self.layer == "penultimate":
33 | self.layer_idx = 1
34 | else:
35 | raise NotImplementedError()
36 |
37 | def forward(self, tokens):
38 | z = self.encode_with_transformer(tokens)
39 | return z
40 |
41 | def encode_with_transformer(self, text):
42 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
43 | x = x + self.model.positional_embedding
44 | x = x.permute(1, 0, 2) # NLD -> LND
45 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
46 | x = x.permute(1, 0, 2) # LND -> NLD
47 | x = self.model.ln_final(x)
48 | return x
49 |
50 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
51 | for i, r in enumerate(self.model.transformer.resblocks):
52 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
53 | break
54 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
55 | x = checkpoint(r, x, attn_mask)
56 | else:
57 | x = r(x, attn_mask=attn_mask)
58 | return x
59 |
60 | def encode(self, text: List[str]) -> torch.Tensor:
61 | # convert a batch of text to tensor
62 | tokens = tokenize(text)
63 | # move tensor to model device
64 | tokens = tokens.to(next(self.model.parameters()).device)
65 | return self(tokens)
66 |
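A hedged usage sketch of the frozen text encoder. In the pipeline the actual `embed_dim`, `vision_cfg` and `text_cfg` come from the model config; the empty dicts below simply fall back to the open_clip defaults so the sketch is self-contained.

```python
import torch

embedder = FrozenOpenCLIPEmbedder(
    embed_dim=512,
    vision_cfg={},           # CLIPVisionCfg defaults; the vision tower is deleted anyway
    text_cfg={},             # CLIPTextCfg defaults: width 512, 12 layers, 77-token context
    layer="penultimate",
)
with torch.no_grad():
    z = embedder.encode(["a photo with very high dynamic range"])
print(z.shape)               # [1, 77, 512] with the default text width
```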
--------------------------------------------------------------------------------
/model/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Literal
3 | from types import ModuleType
4 | import enum
5 | from packaging import version
6 |
7 | import torch
8 |
9 | # collect system information
10 | if version.parse(torch.__version__) >= version.parse("2.0.0"):
11 | SDP_IS_AVAILABLE = True
12 | else:
13 | SDP_IS_AVAILABLE = False
14 |
15 | try:
16 | import xformers
17 | import xformers.ops
18 | XFORMERS_IS_AVAILBLE = True
19 | except:
20 | XFORMERS_IS_AVAILBLE = False
21 |
22 |
23 | class AttnMode(enum.Enum):
24 | SDP = 0
25 | XFORMERS = 1
26 | VANILLA = 2
27 |
28 |
29 | class Config:
30 | xformers: Optional[ModuleType] = None
31 | attn_mode: AttnMode = AttnMode.VANILLA
32 |
33 |
34 | # initialize attention mode
35 | if SDP_IS_AVAILABLE:
36 | Config.attn_mode = AttnMode.SDP
37 | print(f"use sdp attention as default")
38 | elif XFORMERS_IS_AVAILBLE:
39 | Config.attn_mode = AttnMode.XFORMERS
40 | print(f"use xformers attention as default")
41 | else:
42 | print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
43 |
44 | if XFORMERS_IS_AVAILBLE:
45 | Config.xformers = xformers
46 |
47 |
48 | # user-specified attention mode
49 | ATTN_MODE = os.environ.get("ATTN_MODE", None)
50 | if ATTN_MODE is not None:
51 | assert ATTN_MODE in ["vanilla", "sdp", "xformers"]
52 | if ATTN_MODE == "sdp":
53 | assert SDP_IS_AVAILABLE
54 | Config.attn_mode = AttnMode.SDP
55 | elif ATTN_MODE == "xformers":
56 | assert XFORMERS_IS_AVAILBLE
57 | Config.attn_mode = AttnMode.XFORMERS
58 | else:
59 | Config.attn_mode = AttnMode.VANILLA
60 | print(f"set attention mode to {ATTN_MODE}")
61 | else:
62 | print("keep default attention mode")
63 |
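Because the module above resolves `ATTN_MODE` at import time, the environment variable has to be set before `model.config` is imported for the first time. A minimal sketch, assuming PyTorch >= 2.0 so that SDP attention is available:

```python
import os
os.environ["ATTN_MODE"] = "sdp"      # one of "vanilla", "sdp", "xformers"

from model import config             # prints which attention mode was selected
print(config.Config.attn_mode)       # AttnMode.SDP on PyTorch >= 2.0
```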
--------------------------------------------------------------------------------
/model/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
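A minimal sketch of the diagonal Gaussian used for VAE latents: the channel dimension of `parameters` is split in half into mean and log-variance.

```python
import torch

params = torch.randn(2, 8, 16, 16)        # first 4 channels = mean, last 4 = logvar
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                         # mean + std * eps
kl = dist.kl()                            # KL to a standard normal, summed per sample
print(z.shape, kl.shape)                  # torch.Size([2, 4, 16, 16]) torch.Size([2])
```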
--------------------------------------------------------------------------------
/model/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import CLIP
2 | from .tokenizer import tokenize
3 |
4 | __all__ = ["CLIP", "tokenize"]
5 |
--------------------------------------------------------------------------------
/model/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/model/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/model/open_clip/model.py:
--------------------------------------------------------------------------------
1 | """ CLIP Model
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer
14 |
15 |
16 | @dataclass
17 | class CLIPVisionCfg:
18 | layers: Union[Tuple[int, int, int, int], int] = 12
19 | width: int = 768
20 | head_width: int = 64
21 | mlp_ratio: float = 4.0
22 | patch_size: int = 16
23 | image_size: Union[Tuple[int, int], int] = 224
24 |
25 | ls_init_value: Optional[float] = None # layer scale initial value
26 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
27 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
28 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
29 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
30 | n_queries: int = 256 # n_queries for attentional pooler
31 | attn_pooler_heads: int = 8 # n heads for attentional_pooling
32 | output_tokens: bool = False
33 |
34 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
35 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
36 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
37 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
38 | timm_proj_bias: bool = False # enable bias final projection
39 | timm_drop: float = 0. # head dropout
40 | timm_drop_path: Optional[float] = None # backbone stochastic depth
41 |
42 |
43 | @dataclass
44 | class CLIPTextCfg:
45 | context_length: int = 77
46 | vocab_size: int = 49408
47 | width: int = 512
48 | heads: int = 8
49 | layers: int = 12
50 | ls_init_value: Optional[float] = None # layer scale initial value
51 | hf_model_name: str = None
52 | hf_tokenizer_name: str = None
53 | hf_model_pretrained: bool = True
54 | proj: str = 'mlp'
55 | pooler_type: str = 'mean_pooler'
56 | embed_cls: bool = False
57 | pad_id: int = 0
58 | output_tokens: bool = False
59 |
60 |
61 | def get_cast_dtype(precision: str):
62 | cast_dtype = None
63 | if precision == 'bf16':
64 | cast_dtype = torch.bfloat16
65 | elif precision == 'fp16':
66 | cast_dtype = torch.float16
67 | return cast_dtype
68 |
69 |
70 | def _build_vision_tower(
71 | embed_dim: int,
72 | vision_cfg: CLIPVisionCfg,
73 | quick_gelu: bool = False,
74 | cast_dtype: Optional[torch.dtype] = None
75 | ):
76 | if isinstance(vision_cfg, dict):
77 | vision_cfg = CLIPVisionCfg(**vision_cfg)
78 |
79 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
80 | # memory efficient in recent PyTorch releases (>= 1.10).
81 | # NOTE: timm models always use native GELU regardless of quick_gelu flag.
82 | act_layer = QuickGELU if quick_gelu else nn.GELU
83 |
84 | vision_heads = vision_cfg.width // vision_cfg.head_width
85 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
86 | visual = VisionTransformer(
87 | image_size=vision_cfg.image_size,
88 | patch_size=vision_cfg.patch_size,
89 | width=vision_cfg.width,
90 | layers=vision_cfg.layers,
91 | heads=vision_heads,
92 | mlp_ratio=vision_cfg.mlp_ratio,
93 | ls_init_value=vision_cfg.ls_init_value,
94 | patch_dropout=vision_cfg.patch_dropout,
95 | input_patchnorm=vision_cfg.input_patchnorm,
96 | global_average_pool=vision_cfg.global_average_pool,
97 | attentional_pool=vision_cfg.attentional_pool,
98 | n_queries=vision_cfg.n_queries,
99 | attn_pooler_heads=vision_cfg.attn_pooler_heads,
100 | output_tokens=vision_cfg.output_tokens,
101 | output_dim=embed_dim,
102 | act_layer=act_layer,
103 | norm_layer=norm_layer,
104 | )
105 |
106 | return visual
107 |
108 |
109 | def _build_text_tower(
110 | embed_dim: int,
111 | text_cfg: CLIPTextCfg,
112 | quick_gelu: bool = False,
113 | cast_dtype: Optional[torch.dtype] = None,
114 | ):
115 | if isinstance(text_cfg, dict):
116 | text_cfg = CLIPTextCfg(**text_cfg)
117 |
118 | act_layer = QuickGELU if quick_gelu else nn.GELU
119 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
120 |
121 | text = TextTransformer(
122 | context_length=text_cfg.context_length,
123 | vocab_size=text_cfg.vocab_size,
124 | width=text_cfg.width,
125 | heads=text_cfg.heads,
126 | layers=text_cfg.layers,
127 | ls_init_value=text_cfg.ls_init_value,
128 | output_dim=embed_dim,
129 | embed_cls=text_cfg.embed_cls,
130 | output_tokens=text_cfg.output_tokens,
131 | pad_id=text_cfg.pad_id,
132 | act_layer=act_layer,
133 | norm_layer=norm_layer,
134 | )
135 | return text
136 |
137 |
138 | class CLIP(nn.Module):
139 | output_dict: torch.jit.Final[bool]
140 |
141 | def __init__(
142 | self,
143 | embed_dim: int,
144 | vision_cfg: CLIPVisionCfg,
145 | text_cfg: CLIPTextCfg,
146 | quick_gelu: bool = False,
147 | cast_dtype: Optional[torch.dtype] = None,
148 | output_dict: bool = False,
149 | ):
150 | super().__init__()
151 | self.output_dict = output_dict
152 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
153 |
154 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
155 | self.transformer = text.transformer
156 | self.context_length = text.context_length
157 | self.vocab_size = text.vocab_size
158 | self.token_embedding = text.token_embedding
159 | self.positional_embedding = text.positional_embedding
160 | self.ln_final = text.ln_final
161 | self.text_projection = text.text_projection
162 | self.register_buffer('attn_mask', text.attn_mask, persistent=False)
163 |
164 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
165 |
166 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
167 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
168 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
169 |
170 | @torch.jit.ignore
171 | def set_grad_checkpointing(self, enable=True):
172 | self.visual.set_grad_checkpointing(enable)
173 | self.transformer.grad_checkpointing = enable
174 |
175 | def encode_image(self, image, normalize: bool = False):
176 | features = self.visual(image)
177 | return F.normalize(features, dim=-1) if normalize else features
178 |
179 | def encode_text(self, text, normalize: bool = False):
180 | cast_dtype = self.transformer.get_cast_dtype()
181 |
182 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
183 |
184 | x = x + self.positional_embedding.to(cast_dtype)
185 | x = x.permute(1, 0, 2) # NLD -> LND
186 | x = self.transformer(x, attn_mask=self.attn_mask)
187 | x = x.permute(1, 0, 2) # LND -> NLD
188 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
189 | # take features from the eot embedding (eot_token is the highest number in each sequence)
190 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
191 | return F.normalize(x, dim=-1) if normalize else x
192 |
193 | def forward(
194 | self,
195 | image: Optional[torch.Tensor] = None,
196 | text: Optional[torch.Tensor] = None,
197 | ):
198 | image_features = self.encode_image(image, normalize=True) if image is not None else None
199 | text_features = self.encode_text(text, normalize=True) if text is not None else None
200 | if self.output_dict:
201 | return {
202 | "image_features": image_features,
203 | "text_features": text_features,
204 | "logit_scale": self.logit_scale.exp()
205 | }
206 | return image_features, text_features, self.logit_scale.exp()
207 |
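A sketch of building just the text tower with the default `CLIPTextCfg`. It assumes the standard open_clip `TextTransformer.forward`, which returns pooled features of size `embed_dim` when `output_tokens` is off.

```python
import torch

text_tower = _build_text_tower(embed_dim=512, text_cfg={})   # defaults: width 512, 12 layers
dummy_ids = torch.zeros(2, 77, dtype=torch.long)              # already-tokenized input
with torch.no_grad():
    features = text_tower(dummy_ids)
print(features.shape)                                         # torch.Size([2, 512])
```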
--------------------------------------------------------------------------------
/model/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 | Returns list of utf-8 byte and a corresponding list of unicode strings.
29 | The reversible bpe codes work on unicode strings.
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 | This is a significant percentage of your normal, say, 32K bpe vocab.
33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 | And avoids mapping to whitespace/control characters the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | text = ftfy.fix_text(text)
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 |
66 | def whitespace_clean(text):
67 | text = re.sub(r'\s+', ' ', text)
68 | text = text.strip()
69 | return text
70 |
71 |
72 | class SimpleTokenizer(object):
73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 | self.byte_encoder = bytes_to_unicode()
75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 | merges = merges[1:49152-256-2+1]
78 | merges = [tuple(merge.split()) for merge in merges]
79 | vocab = list(bytes_to_unicode().values())
80 | vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 | if not special_tokens:
84 | special_tokens = ['<start_of_text>', '<end_of_text>']
85 | else:
86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 | vocab.extend(special_tokens)
88 | self.encoder = dict(zip(vocab, range(len(vocab))))
89 | self.decoder = {v: k for k, v in self.encoder.items()}
90 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 | self.cache = {t:t for t in special_tokens}
92 | special = "|".join(special_tokens)
93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 |
95 | self.vocab_size = len(self.encoder)
96 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 |
98 | def bpe(self, token):
99 | if token in self.cache:
100 | return self.cache[token]
101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 | pairs = get_pairs(word)
103 |
104 | if not pairs:
105 | return token+'</w>'
106 |
107 | while True:
108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 | if bigram not in self.bpe_ranks:
110 | break
111 | first, second = bigram
112 | new_word = []
113 | i = 0
114 | while i < len(word):
115 | try:
116 | j = word.index(first, i)
117 | new_word.extend(word[i:j])
118 | i = j
119 | except:
120 | new_word.extend(word[i:])
121 | break
122 |
123 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 | new_word.append(first+second)
125 | i += 2
126 | else:
127 | new_word.append(word[i])
128 | i += 1
129 | new_word = tuple(new_word)
130 | word = new_word
131 | if len(word) == 1:
132 | break
133 | else:
134 | pairs = get_pairs(word)
135 | word = ' '.join(word)
136 | self.cache[token] = word
137 | return word
138 |
139 | def encode(self, text):
140 | bpe_tokens = []
141 | text = whitespace_clean(basic_clean(text)).lower()
142 | for token in re.findall(self.pat, text):
143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 | return bpe_tokens
146 |
147 | def decode(self, tokens):
148 | text = ''.join([self.decoder[token] for token in tokens])
149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 | return text
151 |
152 |
153 | _tokenizer = SimpleTokenizer()
154 |
155 | def decode(output_ids: torch.Tensor):
156 | output_ids = output_ids.cpu().numpy()
157 | return _tokenizer.decode(output_ids)
158 |
159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
160 | """
161 | Returns the tokenized representation of given input string(s)
162 |
163 | Parameters
164 | ----------
165 | texts : Union[str, List[str]]
166 | An input string or a list of input strings to tokenize
167 | context_length : int
168 | The context length to use; all CLIP models use 77 as the context length
169 |
170 | Returns
171 | -------
172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
173 | """
174 | if isinstance(texts, str):
175 | texts = [texts]
176 |
177 | sot_token = _tokenizer.encoder["<start_of_text>"]
178 | eot_token = _tokenizer.encoder["<end_of_text>"]
179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
181 |
182 | for i, tokens in enumerate(all_tokens):
183 | if len(tokens) > context_length:
184 | tokens = tokens[:context_length] # Truncate
185 | tokens[-1] = eot_token
186 | result[i, :len(tokens)] = torch.tensor(tokens)
187 |
188 | return result
189 |
190 |
191 | class HFTokenizer:
192 | """HuggingFace tokenizer wrapper"""
193 |
194 | def __init__(self, tokenizer_name: str):
195 | from transformers import AutoTokenizer
196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
197 |
198 | def save_pretrained(self, dest):
199 | self.tokenizer.save_pretrained(dest)
200 |
201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
202 | # same cleaning as for default tokenizer, except lowercasing
203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
204 | if isinstance(texts, str):
205 | texts = [texts]
206 | texts = [whitespace_clean(basic_clean(text)) for text in texts]
207 | input_ids = self.tokenizer(
208 | texts,
209 | return_tensors='pt',
210 | max_length=context_length,
211 | padding='max_length',
212 | truncation=True,
213 | ).input_ids
214 | return input_ids
215 |
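A minimal round-trip through the tokenizer defined above (the prompt string is arbitrary); padded positions are zero, so the decode call only covers the non-zero tokens.

```python
tokens = tokenize(["ultra high dynamic range photograph"])   # LongTensor, shape [1, 77]
n = int((tokens[0] != 0).sum())                              # real token count incl. SOT/EOT
print(tokens.shape)
print(decode(tokens[0][:n]))                                 # start/end markers plus the words
```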
--------------------------------------------------------------------------------
/model/raft/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
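A minimal sketch of querying the correlation pyramid at the identity flow; the feature shape (256 channels at 1/8 resolution) mirrors RAFT's encoder output and is an assumption here.

```python
import torch

fmap1 = torch.randn(1, 256, 40, 40)
fmap2 = torch.randn(1, 256, 40, 40)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

ys, xs = torch.meshgrid(torch.arange(40.0), torch.arange(40.0), indexing="ij")
coords = torch.stack([xs, ys], dim=0)[None]    # [1, 2, H, W], x-coordinate first
corr = corr_fn(coords)                         # [1, num_levels * (2r+1)**2, H, W]
print(corr.shape)                              # torch.Size([1, 324, 40, 40])
```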
--------------------------------------------------------------------------------
/model/raft/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from .utils import frame_utils
15 | from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 | """ Create the data loader for the corresponding trainign set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
/model/raft/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
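Both encoders downsample the input by 8x (a stride-2 stem convolution followed by two stride-2 stages) and, when given a list of images, concatenate them along the batch dimension and split the features again on the way out. A minimal shape check, not part of the repository, assuming the file is importable as `model.raft.extractor` and using an arbitrary 480x640 input:

```python
# Sketch: feature maps come out at 1/8 of the input resolution.
import torch
from model.raft.extractor import BasicEncoder

img1 = torch.randn(1, 3, 480, 640)
img2 = torch.randn(1, 3, 480, 640)

fnet = BasicEncoder(output_dim=256, norm_fn='instance')
with torch.no_grad():
    fmap1, fmap2 = fnet([img1, img2])   # list input -> one feature map per image
print(fmap1.shape)                       # torch.Size([1, 256, 60, 80])
```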
--------------------------------------------------------------------------------
/model/raft/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 |
12 | try:
13 | autocast = torch.cuda.amp.autocast
14 | except:
15 | # dummy autocast for PyTorch < 1.6
16 | class autocast:
17 | def __init__(self, enabled):
18 | pass
19 | def __enter__(self):
20 | pass
21 | def __exit__(self, *args):
22 | pass
23 |
24 |
25 | class RAFT(nn.Module):
26 | def __init__(self, args):
27 | super(RAFT, self).__init__()
28 | self.args = args
29 |
30 | self.hidden_dim = hdim = 128
31 | self.context_dim = cdim = 128
32 | args.corr_levels = 4
33 | args.corr_radius = 4
34 |
35 | if 'dropout' not in self.args:
36 | self.args.dropout = 0
37 |
38 | if 'alternate_corr' not in self.args:
39 | self.args.alternate_corr = False
40 |
41 | # feature network, context network, and update block
42 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
43 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
44 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
45 |
46 | def freeze_bn(self):
47 | for m in self.modules():
48 | if isinstance(m, nn.BatchNorm2d):
49 | m.eval()
50 |
51 | def initialize_flow(self, img):
52 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
53 | N, C, H, W = img.shape
54 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
55 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
56 |
57 | # optical flow computed as difference: flow = coords1 - coords0
58 | return coords0, coords1
59 |
60 | def upsample_flow(self, flow, mask):
61 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
62 | N, _, H, W = flow.shape
63 | mask = mask.view(N, 1, 9, 8, 8, H, W)
64 | mask = torch.softmax(mask, dim=2)
65 |
66 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
67 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
68 |
69 | up_flow = torch.sum(mask * up_flow, dim=2)
70 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
71 | return up_flow.reshape(N, 2, 8*H, 8*W)
72 |
73 |
74 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
75 | """ Estimate optical flow between pair of frames """
76 |
77 | # image1 = 2 * image1 - 1.0
78 | # image2 = 2 * image2 - 1.0
79 |
80 | image1 = image1.contiguous()
81 | image2 = image2.contiguous()
82 |
83 | hdim = self.hidden_dim
84 | cdim = self.context_dim
85 |
86 | # run the feature network
87 | with autocast(enabled=False):
88 | fmap1, fmap2 = self.fnet([image1, image2])
89 |
90 | fmap1 = fmap1.float()
91 | fmap2 = fmap2.float()
92 | if self.args.alternate_corr:
93 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
94 | else:
95 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
96 |
97 | # run the context network
98 | with autocast(enabled=False):
99 | cnet = self.cnet(image1)
100 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
101 | net = torch.tanh(net)
102 | inp = torch.relu(inp)
103 |
104 | coords0, coords1 = self.initialize_flow(image1)
105 |
106 | if flow_init is not None:
107 | coords1 = coords1 + flow_init
108 |
109 | flow_predictions = []
110 | for itr in range(iters):
111 | coords1 = coords1.detach()
112 | corr = corr_fn(coords1) # index correlation volume
113 |
114 | flow = coords1 - coords0
115 | with autocast(enabled=False):
116 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
117 |
118 | # F(t+1) = F(t) + \Delta(t)
119 | coords1 = coords1 + delta_flow
120 |
121 | # upsample predictions
122 | if up_mask is None:
123 | flow_up = upflow8(coords1 - coords0)
124 | else:
125 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
126 |
127 | flow_predictions.append(flow_up)
128 |
129 | if test_mode:
130 | return coords1 - coords0, flow_up
131 |
132 | return flow_predictions
133 |
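`raft.py` wires the pieces together: the feature encoder feeds a correlation volume, the context encoder is split into a GRU hidden state (`tanh`) and a context input (`relu`), and the update block is unrolled for `iters` steps, re-indexing the correlation volume at the current `coords1` each time. In `test_mode` the forward pass returns the final 1/8-resolution flow and its convex upsampling to full resolution. A hypothetical standalone driver, not part of the repository: weights are random, image sizes must be divisible by 8, and input normalization is left to the caller since the usual `2*img-1` step is commented out above.

```python
import argparse
import torch
from model.raft.raft import RAFT

args = argparse.Namespace()              # dropout / alternate_corr get defaults in __init__
model = RAFT(args).eval()

image1 = torch.randn(1, 3, 480, 640)     # H and W divisible by 8
image2 = torch.randn(1, 3, 480, 640)

with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
print(flow_low.shape, flow_up.shape)     # (1, 2, 60, 80) and (1, 2, 480, 640)
```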
--------------------------------------------------------------------------------
/model/raft/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
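The magic numbers in `BasicMotionEncoder` follow from the correlation pyramid: with `corr_levels=4` and `corr_radius=4` (the values RAFT sets internally) the correlation input has 4*(2*4+1)^2 = 324 channels, the fused motion feature is 126+2 = 128 channels, and the upsampling mask has 64*9 channels for the 8x convex upsampling. A standalone shape sketch, not from the repo, for one update step:

```python
# One GRU update step at 1/8 resolution; only the tensor shapes matter here.
import argparse
import torch
from model.raft.update import BasicUpdateBlock

args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

n, h, w = 1, 60, 80
net  = torch.randn(n, 128, h, w)          # GRU hidden state
inp  = torch.randn(n, 128, h, w)          # context features
corr = torch.randn(n, 324, h, w)          # corr_levels * (2*corr_radius + 1)**2
flow = torch.zeros(n, 2, h, w)

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)    # (1, 576, 60, 80), (1, 2, 60, 80)
```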
--------------------------------------------------------------------------------
/model/raft/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
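`FlowAugmentor` expects uint8 HxWx3 frames and an HxWx2 float flow; it applies photometric jitter, random erasing on the second frame, random scale/stretch/flip, and finally a crop of `crop_size`. A toy call, not from the repo; the sizes are arbitrary but must exceed the crop after the minimum rescale:

```python
import numpy as np
from model.raft.utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 496))
img1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
flow = np.random.randn(480, 640, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)   # (368, 496, 3), (368, 496, 2)
```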
--------------------------------------------------------------------------------
/model/raft/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
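`flow_to_image` normalizes the flow by its maximum magnitude, then maps angle and radius through the Middlebury color wheel, giving a uint8 HxWx3 visualization. A quick sketch, not from the repo:

```python
import numpy as np
from model.raft.utils.flow_viz import flow_to_image

flow = np.random.randn(120, 160, 2).astype(np.float32) * 5.0
rgb = flow_to_image(flow)
print(rgb.dtype, rgb.shape)     # uint8, (120, 160, 3); usable with PIL.Image.fromarray(rgb)
```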
--------------------------------------------------------------------------------
/model/raft/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
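`writeFlow`/`readFlow` implement the Middlebury `.flo` layout: the 202021.25 magic tag, width and height as int32, then interleaved u/v float32 values. A round-trip sketch, not from the repo; the temporary path is arbitrary:

```python
import numpy as np
from model.raft.utils.frame_utils import writeFlow, readFlow

flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow("/tmp/example.flo", flow)                       # header + interleaved u/v values
print(np.allclose(flow, readFlow("/tmp/example.flo")))    # True
```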
--------------------------------------------------------------------------------
/model/raft/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
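`InputPadder` is the usual way to feed arbitrary-sized frames to RAFT: pad both images so H and W are multiples of 8 (symmetrically in 'sintel' mode), run the model, and unpad the returned flow. A short sketch in which the model call is only indicated:

```python
import torch
from model.raft.utils.utils import InputPadder

image1 = torch.randn(1, 3, 436, 1024)    # 436 is not divisible by 8
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)       # mode='sintel' pads top/bottom symmetrically
image1, image2 = padder.pad(image1, image2)
print(image1.shape)                      # torch.Size([1, 3, 440, 1024])

# flow_up = model(image1, image2, ...)   # e.g. RAFT in test_mode
flow_up = torch.randn(1, 2, 440, 1024)   # placeholder for the model output
print(padder.unpad(flow_up).shape)       # torch.Size([1, 2, 436, 1024])
```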
--------------------------------------------------------------------------------
/model/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | from inspect import isfunction
14 | import torch
15 | import torch.nn as nn
16 | import numpy as np
17 | from einops import repeat
18 |
19 |
20 | def exists(val):
21 | return val is not None
22 |
23 |
24 | def default(val, d):
25 | if exists(val):
26 | return val
27 | return d() if isfunction(d) else d
28 |
29 |
30 | def checkpoint(func, inputs, params, flag):
31 | """
32 | Evaluate a function without caching intermediate activations, allowing for
33 | reduced memory at the expense of extra compute in the backward pass.
34 | :param func: the function to evaluate.
35 | :param inputs: the argument sequence to pass to `func`.
36 | :param params: a sequence of parameters `func` depends on but does not
37 | explicitly take as arguments.
38 | :param flag: if False, disable gradient checkpointing.
39 | """
40 | if flag:
41 | args = tuple(inputs) + tuple(params)
42 | return CheckpointFunction.apply(func, len(inputs), *args)
43 | else:
44 | return func(*inputs)
45 |
46 |
47 | # class CheckpointFunction(torch.autograd.Function):
48 | # @staticmethod
49 | # def forward(ctx, run_function, length, *args):
50 | # ctx.run_function = run_function
51 | # ctx.input_tensors = list(args[:length])
52 | # ctx.input_params = list(args[length:])
53 | # ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
54 | # "dtype": torch.get_autocast_gpu_dtype(),
55 | # "cache_enabled": torch.is_autocast_cache_enabled()}
56 | # with torch.no_grad():
57 | # output_tensors = ctx.run_function(*ctx.input_tensors)
58 | # return output_tensors
59 |
60 | # @staticmethod
61 | # def backward(ctx, *output_grads):
62 | # ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
63 | # with torch.enable_grad(), \
64 | # torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
65 | # # Fixes a bug where the first op in run_function modifies the
66 | # # Tensor storage in place, which is not allowed for detach()'d
67 | # # Tensors.
68 | # shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
69 | # output_tensors = ctx.run_function(*shallow_copies)
70 | # input_grads = torch.autograd.grad(
71 | # output_tensors,
72 | # ctx.input_tensors + ctx.input_params,
73 | # output_grads,
74 | # allow_unused=True,
75 | # )
76 | # del ctx.input_tensors
77 | # del ctx.input_params
78 | # del output_tensors
79 | # return (None, None) + input_grads
80 |
81 |
82 | # Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction
83 | # still tries to compute gradient for unet parameters.
84 | # https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6
85 | class CheckpointFunction(torch.autograd.Function):
86 | @staticmethod
87 | def forward(ctx, run_function, length, *args):
88 | ctx.run_function = run_function
89 | ctx.input_tensors = list(args[:length])
90 | ctx.input_params = list(args[length:])
91 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
92 | "dtype": torch.get_autocast_gpu_dtype(),
93 | "cache_enabled": torch.is_autocast_cache_enabled()}
94 | with torch.no_grad():
95 | output_tensors = ctx.run_function(*ctx.input_tensors)
96 | return output_tensors
97 |
98 | @staticmethod
99 | def backward(ctx, *output_grads):
100 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
101 | with torch.enable_grad(), \
102 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
103 | # Fixes a bug where the first op in run_function modifies the
104 | # Tensor storage in place, which is not allowed for detach()'d
105 | # Tensors.
106 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
107 | output_tensors = ctx.run_function(*shallow_copies)
108 | grads = torch.autograd.grad(
109 | output_tensors,
110 | ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad],
111 | output_grads,
112 | allow_unused=True,
113 | )
114 | grads = list(grads)
115 | # Assign gradients to the correct positions, matching None for those that do not require gradients
116 | input_grads = []
117 | for tensor in ctx.input_tensors + ctx.input_params:
118 | if tensor.requires_grad:
119 | input_grads.append(grads.pop(0)) # Get the next computed gradient
120 | else:
121 | input_grads.append(None) # No gradient required for this tensor
122 | del ctx.input_tensors
123 | del ctx.input_params
124 | del output_tensors
125 | return (None, None) + tuple(input_grads)
126 |
127 |
128 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
129 | """
130 | Create sinusoidal timestep embeddings.
131 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
132 | These may be fractional.
133 | :param dim: the dimension of the output.
134 | :param max_period: controls the minimum frequency of the embeddings.
135 | :return: an [N x dim] Tensor of positional embeddings.
136 | """
137 | if not repeat_only:
138 | half = dim // 2
139 | freqs = torch.exp(
140 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
141 | ).to(device=timesteps.device)
142 | args = timesteps[:, None].float() * freqs[None]
143 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
144 | if dim % 2:
145 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
146 | else:
147 | embedding = repeat(timesteps, 'b -> b d', d=dim)
148 | return embedding
149 |
150 |
151 | def zero_module(module):
152 | """
153 | Zero out the parameters of a module and return it.
154 | """
155 | for p in module.parameters():
156 | p.detach().zero_()
157 | return module
158 |
159 |
160 | def scale_module(module, scale):
161 | """
162 | Scale the parameters of a module and return it.
163 | """
164 | for p in module.parameters():
165 | p.detach().mul_(scale)
166 | return module
167 |
168 |
169 | def mean_flat(tensor):
170 | """
171 | Take the mean over all non-batch dimensions.
172 | """
173 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
174 |
175 |
176 | def normalization(channels):
177 | """
178 | Make a standard normalization layer.
179 | :param channels: number of input channels.
180 | :return: an nn.Module for normalization.
181 | """
182 | return GroupNorm32(32, channels)
183 |
184 |
185 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
186 | class SiLU(nn.Module):
187 | def forward(self, x):
188 | return x * torch.sigmoid(x)
189 |
190 |
191 | class GroupNorm32(nn.GroupNorm):
192 | def forward(self, x):
193 | return super().forward(x.float()).type(x.dtype)
194 |
195 | def conv_nd(dims, *args, **kwargs):
196 | """
197 | Create a 1D, 2D, or 3D convolution module.
198 | """
199 | if dims == 1:
200 | return nn.Conv1d(*args, **kwargs)
201 | elif dims == 2:
202 | return nn.Conv2d(*args, **kwargs)
203 | elif dims == 3:
204 | return nn.Conv3d(*args, **kwargs)
205 | raise ValueError(f"unsupported dimensions: {dims}")
206 |
207 |
208 | def linear(*args, **kwargs):
209 | """
210 | Create a linear module.
211 | """
212 | return nn.Linear(*args, **kwargs)
213 |
214 |
215 | def avg_pool_nd(dims, *args, **kwargs):
216 | """
217 | Create a 1D, 2D, or 3D average pooling module.
218 | """
219 | if dims == 1:
220 | return nn.AvgPool1d(*args, **kwargs)
221 | elif dims == 2:
222 | return nn.AvgPool2d(*args, **kwargs)
223 | elif dims == 3:
224 | return nn.AvgPool3d(*args, **kwargs)
225 | raise ValueError(f"unsupported dimensions: {dims}")
226 |
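`timestep_embedding` produces the standard sinusoidal embedding used by the diffusion UNet: the first `dim//2` channels are cosines and the rest sines over geometrically spaced frequencies. A minimal check, not from the repo; the dimension 320 is illustrative:

```python
import torch
from model.util import timestep_embedding

t = torch.tensor([0, 250, 999])          # one timestep index per batch element
emb = timestep_embedding(t, dim=320)
print(emb.shape)                          # torch.Size([3, 320])
```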
--------------------------------------------------------------------------------
/pipeline/V4_CA/pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import overload, Tuple, Optional
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import numpy as np
7 | from PIL import Image
8 | from einops import rearrange
9 |
10 | from model.V4_CA.cldm import ControlLDM
11 | from model.V4_CA.gaussian_diffusion import Diffusion
12 | from utils.V4_CA.sampler import SpacedSampler
13 | from utils.cond_fn import Guidance
14 | from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage
15 |
16 | from torchvision.utils import save_image
17 |
18 |
19 | def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
20 | pil = Image.fromarray(img)
21 | res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
22 | return np.array(res)
23 |
24 |
25 | def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
26 | _, _, h, w = imgs.size()
27 | if h == w:
28 | new_h, new_w = size, size
29 | elif h < w:
30 | new_h, new_w = size, int(w * (size / h))
31 | else:
32 | new_h, new_w = int(h * (size / w)), size
33 | return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)
34 |
35 |
36 | def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
37 | _, _, h, w = imgs.size()
38 | if h % multiple == 0 and w % multiple == 0:
39 | return imgs.clone()
40 | # get_pad = lambda x: (x // multiple + 1) * multiple - x
41 | get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
42 | ph, pw = get_pad(h), get_pad(w)
43 | return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)
44 |
45 |
46 | class UltraFusionPipeline:
47 |
48 | def __init__(self, cldm: ControlLDM, diffusion: Diffusion, fidelity_encoder, device: str) -> None:
49 | self.cldm = cldm
50 | self.diffusion = diffusion
51 | self.fidelity_encoder = fidelity_encoder
52 | self.device = device
53 | self.final_size: Tuple[int] = None
54 |
55 | def set_final_size(self, lq: torch.Tensor) -> None:
56 | h, w = lq.shape[2:]
57 | self.final_size = (h, w)
58 |
59 | @count_vram_usage
60 | def run(
61 | self,
62 | lq2,
63 | lq1_mscn_norm,
64 | lq1_color,
65 | steps: int = 50,
66 | strength: float = 1.0,
67 | tiled: bool = False,
68 | tile_size: int = 512,
69 | tile_stride: int = 256,
70 | pos_prompt: str = "",
71 | neg_prompt: str = "low quality, blurry, low-resolution, noisy, unsharp, weird textures",
72 |         cfg_scale: float = 4.0,
73 | cond_fn: Guidance = None,
74 | fidelity_input: torch.Tensor = None,
75 | consistent_start: torch.Tensor = None
76 | ) -> torch.Tensor:
77 | ### preprocess
78 | lq1_mscn_norm, lq1_color, lq2 = lq1_mscn_norm.cuda(), lq1_color.cuda(), lq2.cuda()
79 | bs, _, H, W = lq2.shape
80 | if not tiled:
81 | assert H == 512 and W == 512, "The image shape must be equal to 512x512"
82 |
83 |         # prepare condition
84 | lq2 = lq2 * 2 - 1 #[-1, 1]
85 | if not tiled:
86 | cond = self.cldm.prepare_condition(lq2, lq1_mscn_norm, lq1_color, pos_prompt)
87 | else:
88 | cond, skip_feats = self.cldm.prepare_condition_tiled(lq2, lq1_mscn_norm, lq1_color, pos_prompt, tile_size=tile_size, tile_stride=tile_stride, fidelity_encoder=self.fidelity_encoder, fidelity_input=fidelity_input)
89 | uncond = None
90 | old_control_scales = self.cldm.control_scales
91 | self.cldm.control_scales = [strength] * 13
92 | x_T = torch.randn((bs, 4, H // 8, W // 8), dtype=torch.float32, device=self.device)
93 | # lq_latent = self.cldm.vae_encode(lq)
94 | # noise = torch.randn(lq_latent.shape, dtype=torch.float32, device=self.device)
95 | # x_T = self.diffusion.q_sample(x_start=lq_latent, t=torch.tensor([999], device=self.device), noise=noise)
96 | ### run sampler
97 | sampler = SpacedSampler(self.diffusion.betas)
98 | z = sampler.sample(
99 | model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, H // 8, W // 8),
100 | cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
101 | progress_leave=True, cond_fn=cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
102 | )
103 | if not tiled:
104 | sample = self.cldm.vae_decode(z)
105 | else:
106 | sample = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8, skip_feats, consistent_start)
107 |
108 | ### postprocess
109 | self.cldm.control_scales = old_control_scales
110 | return sample # [-1 , 1]
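`run()` itself needs the full ControlLDM/Diffusion/fidelity-encoder stack built by `inference.py`, so it is not sketched here; the resize/pad helpers at the top of this file are standalone, though, and a quick check (not from the repo) shows how they prepare inputs:

```python
import torch
from pipeline.V4_CA.pipeline import resize_short_edge_to, pad_to_multiples_of

x = torch.randn(1, 3, 700, 500)
x = resize_short_edge_to(x, 512)          # short edge -> 512, aspect ratio preserved
print(x.shape)                             # torch.Size([1, 3, 716, 512])
x = pad_to_multiples_of(x, 64)             # zero-pad bottom/right to multiples of 64
print(x.shape)                             # torch.Size([1, 3, 768, 512])
```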
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | einops==0.8.1
3 | ftfy==6.3.1
4 | numpy==1.24.4
5 | omegaconf==2.3.0
6 | opencv_python==4.9.0.80
7 | opencv_python_headless==4.10.0.84
8 | packaging==25.0
9 | Pillow==11.2.1
10 | pyiqa==0.1.13
11 | regex==2024.5.15
12 | scipy==1.15.2
13 | torch==2.3.0
14 | torchvision==0.18.0
15 | tqdm==4.66.4
16 | transformers==4.37.2
17 |
--------------------------------------------------------------------------------
/utils/V4_CA/sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Dict
2 |
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from model.V4_CA.gaussian_diffusion import extract_into_tensor
9 | from model.V4_CA.cldm import ControlLDM
10 | from utils.cond_fn import Guidance
11 | from utils.common import sliding_windows, gaussian_weights
12 |
13 |
14 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
15 | def space_timesteps(num_timesteps, section_counts):
16 | """
17 | Create a list of timesteps to use from an original diffusion process,
18 | given the number of timesteps we want to take from equally-sized portions
19 | of the original process.
20 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
21 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
22 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
23 | If the stride is a string starting with "ddim", then the fixed striding
24 | from the DDIM paper is used, and only one section is allowed.
25 | :param num_timesteps: the number of diffusion steps in the original
26 | process to divide up.
27 | :param section_counts: either a list of numbers, or a string containing
28 | comma-separated numbers, indicating the step count
29 | per section. As a special case, use "ddimN" where N
30 | is a number of steps to use the striding from the
31 | DDIM paper.
32 | :return: a set of diffusion steps from the original process to use.
33 | """
34 | if isinstance(section_counts, str):
35 | if section_counts.startswith("ddim"):
36 | desired_count = int(section_counts[len("ddim") :])
37 | for i in range(1, num_timesteps):
38 | if len(range(0, num_timesteps, i)) == desired_count:
39 | return set(range(0, num_timesteps, i))
40 | raise ValueError(
41 |                 f"cannot create exactly {desired_count} steps with an integer stride"
42 | )
43 | section_counts = [int(x) for x in section_counts.split(",")]
44 | size_per = num_timesteps // len(section_counts)
45 | extra = num_timesteps % len(section_counts)
46 | start_idx = 0
47 | all_steps = []
48 | for i, section_count in enumerate(section_counts):
49 | size = size_per + (1 if i < extra else 0)
50 | if size < section_count:
51 | raise ValueError(
52 | f"cannot divide section of {size} steps into {section_count}"
53 | )
54 | if section_count <= 1:
55 | frac_stride = 1
56 | else:
57 | frac_stride = (size - 1) / (section_count - 1)
58 | cur_idx = 0.0
59 | taken_steps = []
60 | for _ in range(section_count):
61 | taken_steps.append(start_idx + round(cur_idx))
62 | cur_idx += frac_stride
63 | all_steps += taken_steps
64 | start_idx += size
65 | return set(all_steps)
66 |
67 |
68 | class SpacedSampler(nn.Module):
69 | """
70 | Implementation for spaced sampling schedule proposed in IDDPM. This class is designed
71 | for sampling ControlLDM.
72 |
73 | https://arxiv.org/pdf/2102.09672.pdf
74 | """
75 |
76 | def __init__(self, betas: np.ndarray) -> "SpacedSampler":
77 | super().__init__()
78 | self.num_timesteps = len(betas)
79 | self.original_betas = betas
80 | self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
81 | self.context = {}
82 |
83 | def register(self, name: str, value: np.ndarray) -> None:
84 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32))
85 |
86 | def make_schedule(self, num_steps: int) -> None:
87 |         # calculate betas for spaced sampling
88 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
89 | used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
90 | betas = []
91 | last_alpha_cumprod = 1.0
92 | for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
93 | if i in used_timesteps:
94 | # marginal distribution is the same as q(x_{S_t}|x_0)
95 | betas.append(1 - alpha_cumprod / last_alpha_cumprod)
96 | last_alpha_cumprod = alpha_cumprod
97 | assert len(betas) == num_steps
98 | self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...]
99 |
100 | betas = np.array(betas, dtype=np.float64)
101 | alphas = 1.0 - betas
102 | alphas_cumprod = np.cumprod(alphas, axis=0)
103 | # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}")
104 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
105 | sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
106 | sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
107 | # calculations for posterior q(x_{t-1} | x_t, x_0)
108 | posterior_variance = (
109 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
110 | )
111 | # log calculation clipped because the posterior variance is 0 at the
112 | # beginning of the diffusion chain.
113 | posterior_log_variance_clipped = np.log(
114 | np.append(posterior_variance[1], posterior_variance[1:])
115 | )
116 | posterior_mean_coef1 = (
117 | betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
118 | )
119 | posterior_mean_coef2 = (
120 | (1.0 - alphas_cumprod_prev)
121 | * np.sqrt(alphas)
122 | / (1.0 - alphas_cumprod)
123 | )
124 |
125 | self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
126 | self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
127 | self.register("posterior_variance", posterior_variance)
128 | self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
129 | self.register("posterior_mean_coef1", posterior_mean_coef1)
130 | self.register("posterior_mean_coef2", posterior_mean_coef2)
131 |
132 | def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor]:
133 | """
134 | Implement the posterior distribution q(x_{t-1}|x_t, x_0).
135 |
136 | Args:
137 | x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
138 | x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
139 | t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
140 | parameters for each timestep.
141 |
142 | Returns:
143 | posterior_mean (torch.Tensor): Mean of the posterior distribution.
144 | posterior_variance (torch.Tensor): Variance of the posterior distribution.
145 | posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
146 | """
147 | posterior_mean = (
148 | extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
149 | + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
150 | )
151 | posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
152 | posterior_log_variance_clipped = extract_into_tensor(
153 | self.posterior_log_variance_clipped, t, x_t.shape
154 | )
155 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
156 |
157 | def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
158 | return (
159 | extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
160 | - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
161 | )
162 |
163 | def apply_cond_fn(
164 | self,
165 | model: ControlLDM,
166 | pred_x0: torch.Tensor,
167 | t: torch.Tensor,
168 | index: torch.Tensor,
169 | cond_fn: Guidance
170 | ) -> torch.Tensor:
171 | t_now = int(t[0].item()) + 1
172 | if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start):
173 | # stop guidance
174 | self.context["g_apply"] = False
175 | return pred_x0
176 | grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
177 | # apply guidance for multiple times
178 | loss_vals = []
179 | for _ in range(cond_fn.repeat):
180 | # set target and pred for gradient computation
181 | target, pred = None, None
182 | if cond_fn.space == "latent":
183 | target = model.vae_encode(cond_fn.target)
184 | pred = pred_x0
185 | elif cond_fn.space == "rgb":
186 | # We need to backward gradient to x0 in latent space, so it's required
187 | # to trace the computation graph while decoding the latent.
188 | with torch.enable_grad():
189 | target = cond_fn.target
190 | pred_x0_rg = pred_x0.detach().clone().requires_grad_(True)
191 | pred = model.vae_decode(pred_x0_rg)
192 | assert pred.requires_grad
193 | else:
194 | raise NotImplementedError(cond_fn.space)
195 | # compute gradient
196 | delta_pred, loss_val = cond_fn(target, pred, t_now)
197 | loss_vals.append(loss_val)
198 | # update pred_x0 w.r.t gradient
199 | if cond_fn.space == "latent":
200 | delta_pred_x0 = delta_pred
201 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
202 | elif cond_fn.space == "rgb":
203 | pred.backward(delta_pred)
204 | delta_pred_x0 = pred_x0_rg.grad
205 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
206 | else:
207 | raise NotImplementedError(cond_fn.space)
208 | self.context["g_apply"] = True
209 | self.context["g_loss"] = float(np.mean(loss_vals))
210 | return pred_x0
211 |
212 | def predict_noise(
213 | self,
214 | model: ControlLDM,
215 | x: torch.Tensor,
216 | t: torch.Tensor,
217 | cond: Dict[str, torch.Tensor],
218 | uncond: Optional[Dict[str, torch.Tensor]],
219 | cfg_scale: float
220 | ) -> torch.Tensor:
221 | if uncond is None or cfg_scale == 1.:
222 | model_output = model(x, t, cond)
223 | else:
224 | # apply classifier-free guidance
225 | model_cond = model(x, t, cond)
226 | model_uncond = model(x, t, uncond)
227 | model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
228 | return model_output
229 |
230 | @torch.no_grad()
231 | def predict_noise_tiled(
232 | self,
233 | model: ControlLDM,
234 | x: torch.Tensor,
235 | t: torch.Tensor,
236 | cond: Dict[str, torch.Tensor],
237 | uncond: Optional[Dict[str, torch.Tensor]],
238 | cfg_scale: float,
239 | tile_size: int,
240 | tile_stride: int,
241 | fix_init_noise: bool
242 | ):
243 | _, _, h, w = x.shape
244 | tiles_latent = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False)
245 | tiles_img = tqdm(sliding_windows(h * 8, w * 8, tile_size, tile_stride), unit="tile", leave=False)
246 | eps = torch.zeros_like(x)
247 | count = torch.zeros_like(x, dtype=torch.float32)
248 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
249 | weights = torch.tensor(weights, dtype=torch.float32, device=x.device)
250 | for (hi_latent, hi_end_latent, wi_latent, wi_end_latent), (hi_img, hi_end_img, wi_img, wi_end_img) in zip(tiles_latent, tiles_img):
251 | tiles_latent.set_description(f"Process tile ({hi_latent} {hi_end_latent}), ({wi_latent} {wi_end_latent})")
252 | if fix_init_noise:
253 |                 tile_x = x[:, :, 0:tile_size // 8, 0:tile_size // 8]  # reuse the same initial-noise crop for every tile
254 | else:
255 | tile_x = x[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent]
256 | tile_cond = {
257 | "c_lq2": cond["c_lq2"][:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent],
258 | "c_lq1_mscn_norm": cond["c_lq1_mscn_norm"][:, :, hi_img:hi_end_img, wi_img:wi_end_img],
259 | "c_lq1_color": cond["c_lq1_color"][:, :, hi_img:hi_end_img, wi_img:wi_end_img],
260 | "c_txt": cond["c_txt"]
261 | }
262 | if uncond:
263 | # not implemented
264 | tile_uncond = {
265 | "c_img": uncond["c_img"],
266 | "c_txt": uncond["c_txt"]
267 | }
268 | else:
269 | tile_uncond = None
270 | tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
271 | # accumulate noise
272 | eps[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent] += tile_eps * weights
273 | count[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent] += weights
274 | # average on noise (score)
275 | eps.div_(count)
276 | return eps
277 |
278 | @torch.no_grad()
279 | def p_sample(
280 | self,
281 | model: ControlLDM,
282 | x: torch.Tensor,
283 | t: torch.Tensor,
284 | index: torch.Tensor,
285 | cond: Dict[str, torch.Tensor],
286 | uncond: Optional[Dict[str, torch.Tensor]],
287 | cfg_scale: float,
288 | cond_fn: Optional[Guidance],
289 | tiled: bool,
290 | tile_size: int,
291 | tile_stride: int,
292 | fix_init_noise: bool=False
293 | ) -> torch.Tensor:
294 | if tiled:
295 | eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride, fix_init_noise)
296 | else:
297 | eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale)
298 | pred_x0 = self._predict_xstart_from_eps(x, index, eps)
299 | if cond_fn:
300 | # assert not tiled, f"tiled sampling currently doesn't support guidance"
301 | assert x.shape[0] == 1, "guidance sampling currently only support batch size = 1"
302 | pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn)
303 | model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index)
304 | noise = torch.randn_like(x)
305 | nonzero_mask = (
306 | (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
307 | )
308 | x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
309 | return x_prev
310 |
311 | @torch.no_grad()
312 | def sample(
313 | self,
314 | model: ControlLDM,
315 | device: str,
316 | steps: int,
317 | batch_size: int,
318 | x_size: Tuple[int],
319 | cond: Dict[str, torch.Tensor],
320 | uncond: Dict[str, torch.Tensor],
321 | cfg_scale: float,
322 | cond_fn: Optional[Guidance]=None,
323 | tiled: bool=False,
324 | tile_size: int=-1,
325 | tile_stride: int=-1,
326 | x_T: Optional[torch.Tensor]=None,
327 | progress: bool=True,
328 | progress_leave: bool=True,
329 | ) -> torch.Tensor:
330 | self.make_schedule(steps)
331 | self.to(device)
332 | if x_T is None:
333 |             # TODO: the noise is not converted to float32 here, which may trigger an error
334 | img = torch.randn((batch_size, *x_size), device=device)
335 | else:
336 | img = x_T
337 | timesteps = np.flip(self.timesteps) # [1000, 950, 900, ...]
338 | total_steps = len(self.timesteps)
339 | iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress)
340 | for i, step in enumerate(iterator):
341 | ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
342 | index = torch.full_like(ts, fill_value=total_steps - i - 1)
343 | img = self.p_sample(
344 | model, img, ts, index, cond, uncond, cfg_scale, cond_fn,
345 |                 tiled, tile_size, tile_stride, fix_init_noise=(i == 0)
346 | )
347 | if cond_fn and self.context["g_apply"]:
348 | loss_val = self.context["g_loss"]
349 | desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}"
350 | else:
351 | desc = "Spaced Sampler"
352 | iterator.set_description(desc)
353 | return img
354 |
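# Usage sketch (not part of the original file; the concrete values are illustrative
# assumptions): `sampler` is an instance of the class above, `model` is the loaded
# ControlLDM, `cond` is the conditioning dict built by the inference pipeline, and
# `h`, `w` are the target image height and width.
#
#   z = sampler.sample(
#       model, device="cuda", steps=50, batch_size=1,
#       x_size=(4, h // 8, w // 8), cond=cond, uncond=None, cfg_scale=1.0,
#       cond_fn=None, tiled=True, tile_size=512, tile_stride=256,
#   )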
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Any, Tuple, Callable, List
2 | import importlib
3 | import os
4 | from urllib.parse import urlparse
5 |
6 | import torch
7 | from torch import Tensor
8 | from torch.nn import functional as F
9 | import numpy as np
10 |
11 | from torch.hub import download_url_to_file, get_dir
12 |
13 |
14 | def get_obj_from_str(string: str, reload: bool=False) -> Any:
15 | module, cls = string.rsplit(".", 1)
16 | if reload:
17 | module_imp = importlib.import_module(module)
18 | importlib.reload(module_imp)
19 | return getattr(importlib.import_module(module, package=None), cls)
20 |
21 |
22 | def instantiate_from_config(config: Mapping[str, Any]) -> Any:
23 |     if "target" not in config:
24 | raise KeyError("Expected key `target` to instantiate.")
25 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
26 |
27 |
28 | def wavelet_blur(image: Tensor, radius: int):
29 | """
30 | Apply wavelet blur to the input tensor.
31 | """
32 | # input shape: (1, 3, H, W)
33 | # convolution kernel
34 | kernel_vals = [
35 | [0.0625, 0.125, 0.0625],
36 | [0.125, 0.25, 0.125],
37 | [0.0625, 0.125, 0.0625],
38 | ]
39 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
40 | # add channel dimensions to the kernel to make it a 4D tensor
41 | kernel = kernel[None, None]
42 | # repeat the kernel across all input channels
43 | kernel = kernel.repeat(3, 1, 1, 1)
44 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
45 | # apply convolution
46 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
47 | return output
48 |
49 |
50 | def wavelet_decomposition(image: Tensor, levels=5):
51 | """
52 | Apply wavelet decomposition to the input tensor.
53 |     This function returns only the accumulated high frequency and the final low frequency.
54 | """
55 | high_freq = torch.zeros_like(image)
56 | for i in range(levels):
57 | radius = 2 ** i
58 | low_freq = wavelet_blur(image, radius)
59 | high_freq += (image - low_freq)
60 | image = low_freq
61 |
62 | return high_freq, low_freq
63 |
64 |
65 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
66 | """
67 |     Combine the content's high frequency with the style's low frequency, so that the content keeps its detail but takes on the style's color.
68 | """
69 | # calculate the wavelet decomposition of the content feature
70 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
71 | del content_low_freq
72 | # calculate the wavelet decomposition of the style feature
73 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
74 | del style_high_freq
75 | # reconstruct the content feature with the style's high frequency
76 | return content_high_freq + style_low_freq
77 |
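# Example (illustrative, not from the original file): keep the high-frequency detail of
# `content` and adopt the low-frequency color/illumination of `style`; both are assumed
# to be (1, 3, H, W) tensors on the same device.
#
#   fused = wavelet_reconstruction(content_feat=content, style_feat=style)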
78 |
79 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
80 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
81 |     """Load a file from an HTTP URL, downloading the model if necessary.
82 |
83 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
84 |
85 | Args:
86 | url (str): URL to be downloaded.
87 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
88 | Default: None.
89 | progress (bool): Whether to show the download progress. Default: True.
90 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
91 |
92 | Returns:
93 | str: The path to the downloaded file.
94 | """
95 | if model_dir is None: # use the pytorch hub_dir
96 | hub_dir = get_dir()
97 | model_dir = os.path.join(hub_dir, 'checkpoints')
98 |
99 | os.makedirs(model_dir, exist_ok=True)
100 |
101 | parts = urlparse(url)
102 | filename = os.path.basename(parts.path)
103 | if file_name is not None:
104 | filename = file_name
105 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
106 | if not os.path.exists(cached_file):
107 | print(f'Downloading: "{url}" to {cached_file}\n')
108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109 | return cached_file
110 |
111 |
112 | def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
113 | hi_list = list(range(0, h - tile_size + 1, tile_stride))
114 | if (h - tile_size) % tile_stride != 0:
115 | hi_list.append(h - tile_size)
116 |
117 | wi_list = list(range(0, w - tile_size + 1, tile_stride))
118 | if (w - tile_size) % tile_stride != 0:
119 | wi_list.append(w - tile_size)
120 |
121 | coords = []
122 | for hi in hi_list:
123 | for wi in wi_list:
124 | coords.append((hi, hi + tile_size, wi, wi + tile_size))
125 | return coords
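# Worked example: sliding_windows(96, 96, tile_size=64, tile_stride=32) returns
# [(0, 64, 0, 64), (0, 64, 32, 96), (32, 96, 0, 64), (32, 96, 32, 96)], i.e. the
# (hi, hi_end, wi, wi_end) coordinates of four overlapping 64x64 tiles.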
126 |
127 |
128 | # https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
129 | def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
130 | """Generates a gaussian mask of weights for tile contributions"""
131 | latent_width = tile_width
132 | latent_height = tile_height
133 | var = 0.01
134 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
135 | x_probs = [
136 | np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
137 | for x in range(latent_width)]
138 |     midpoint = (latent_height - 1) / 2  # -1 because index goes from 0 to latent_height - 1
139 | y_probs = [
140 | np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
141 | for y in range(latent_height)]
142 | weights = np.outer(y_probs, x_probs)
143 | return weights
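# The returned (tile_height, tile_width) map peaks at the tile centre and decays towards
# the borders; the tiled sampler multiplies every tile prediction by it and later divides
# by the accumulated weights, which feathers the seams between overlapping tiles.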
144 |
145 |
146 | COUNT_VRAM = os.environ.get("COUNT_VRAM", "0").lower() in ("1", "true", "yes")
147 |
148 | def count_vram_usage(func: Callable) -> Callable:
149 | if not COUNT_VRAM:
150 | return func
151 |
152 | def wrapper(*args, **kwargs):
153 | peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
154 | ret = func(*args, **kwargs)
155 | torch.cuda.synchronize()
156 | peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
157 | print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
158 | return ret
159 | return wrapper
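# Usage sketch (the decorated function name is hypothetical): run with COUNT_VRAM=1 to
# print the peak VRAM around the call; with COUNT_VRAM unset the decorator is a no-op.
#
#   @count_vram_usage
#   def vae_decode_tiled(z):
#       ...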
160 |
--------------------------------------------------------------------------------
/utils/cond_fn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as tvtf
3 |
4 | from typing import overload, Tuple
5 | from torch.nn import functional as F
6 | from torchvision.utils import save_image
7 |
8 |
9 | class Guidance:
10 |
11 |     def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> None:
12 | """
13 | Initialize restoration guidance.
14 |
15 | Args:
16 | scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
17 | the closer the final result will be to the output of the first stage model.
18 | t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling
19 | process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`.
20 | space (str): The data space for computing loss function (rgb or latent).
21 |
22 | Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
23 | Thanks for their work!
24 | """
25 | self.scale = scale * 3000
26 | self.t_start = t_start
27 | self.t_stop = t_stop
28 | self.target = None
29 | self.space = space
30 | self.repeat = repeat
31 |
32 | def load_target(self, target: torch.Tensor) -> None:
33 | self.target = target
34 |
35 |     def __call__(self, target_x0: Tuple[torch.Tensor, torch.Tensor], pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
36 | # avoid propagating gradient out of this scope
37 | pred_x0 = pred_x0.detach().clone()
38 | tmp1, tmp2 = target_x0
39 | tmp1 = tmp1.detach().clone()
40 | tmp2 = tmp2.detach().clone()
41 | return self._forward([tmp1, tmp2], pred_x0, t)
43 |
44 | @overload
45 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
46 | ...
47 |
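# Usage sketch (illustrative values, not the project's defaults): build a concrete
# guidance object, register the reference image(s), and pass it to the sampler as
# `cond_fn`. Note that `__call__` above unpacks the target into two tensors.
#
#   cond_fn = StructureGuidance(scale=1.0, t_start=601, t_stop=-1, space="rgb", repeat=1)
#   cond_fn.load_target([ue_img, oe_img])   # two aligned reference frames
#   samples = sampler.sample(..., cond_fn=cond_fn)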
48 |
49 | class MSEGuidance(Guidance):
50 |
51 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
52 | # inputs: [-1, 1], nchw, rgb
53 | target_y = 0.299 * target_x0[:, 0, :, :] + 0.587 * target_x0[:, 1, :, :] + 0.114 * target_x0[:, 2, :, :]
54 | weight_map = (target_y > 0.99).type(torch.float32)
55 | save_image(weight_map, 'weight_map.png')
56 | with torch.enable_grad():
57 | pred_x0.requires_grad_(True)
58 | loss = ((pred_x0 - target_x0).pow(2) * (1. - weight_map)).mean((1, 2, 3)).sum()
59 | scale = self.scale
60 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale
61 | return g, loss.item()
62 |
63 |
64 | class WeightedMSEGuidance(Guidance):
65 |
66 | def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
67 | # convert RGB to G
68 | rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
69 | target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True)
70 | # initialize sobel kernel in x and y axis
71 | G_x = [
72 | [1, 0, -1],
73 | [2, 0, -2],
74 | [1, 0, -1]
75 | ]
76 | G_y = [
77 | [1, 2, 1],
78 | [0, 0, 0],
79 | [-1, -2, -1]
80 | ]
81 | G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
82 | G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
83 | G = torch.stack((G_x, G_y))
84 |
85 | target = F.pad(target, (1, 1, 1, 1), mode='replicate') # padding = 1
86 | grad = F.conv2d(target, G, stride=1)
87 | mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()
88 |
89 | n, c, h, w = mag.size()
90 | block_size = 2
91 | blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
92 | block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
93 | block_mean = block_mean.view(n, c, h, w)
94 | weight_map = 1 - block_mean
95 |
96 | return weight_map
97 |
98 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
99 | # inputs: [-1, 1], nchw, rgb
100 | with torch.no_grad():
101 | w = self._get_weight((target_x0 + 1) / 2)
102 | with torch.enable_grad():
103 | pred_x0.requires_grad_(True)
104 | loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
105 | scale = self.scale
106 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale
107 | return g, loss.item()
108 |
109 |
110 |
111 | class L_spa(torch.nn.Module):
112 |
113 | def __init__(self, patch_size):
114 | super(L_spa, self).__init__()
116 |         kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
117 |         kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
118 |         kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
119 |         kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).cuda().unsqueeze(0).unsqueeze(0)
120 |
121 | # kernel_left_up = torch.FloatTensor( [[-1,0,0],[0,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
122 | # kernel_right_up = torch.FloatTensor( [[0,0,-1],[0,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
123 | # kernel_left_down = torch.FloatTensor( [[0,0,0],[0,1, 0 ],[-1,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
124 | # kernel_right_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,0,-1]]).cuda().unsqueeze(0).unsqueeze(0)
125 |
126 | self.weight_left = torch.nn.Parameter(data=kernel_left, requires_grad=False)
127 | self.weight_right = torch.nn.Parameter(data=kernel_right, requires_grad=False)
128 | self.weight_up = torch.nn.Parameter(data=kernel_up, requires_grad=False)
129 | self.weight_down = torch.nn.Parameter(data=kernel_down, requires_grad=False)
130 | # self.weight_left_up = nn.Parameter(data=kernel_left_up, requires_grad=False)
131 | # self.weight_right_up = nn.Parameter(data=kernel_right_up, requires_grad=False)
132 | # self.weight_left_down = nn.Parameter(data=kernel_left_down, requires_grad=False)
133 | # self.weight_right_down = nn.Parameter(data=kernel_right_down, requires_grad=False)
134 | self.pool = torch.nn.AvgPool2d(patch_size)
135 |
136 | def forward(self, org , enhance, weight_map):
137 | b,c,h,w = org.shape
138 |
139 | org_mean = torch.mean(org,1,keepdim=True)
140 | enhance_mean = torch.mean(enhance,1,keepdim=True)
141 |
142 | org_pool = self.pool(org_mean)
143 | enhance_pool = self.pool(enhance_mean)
144 | weight_map_pool = self.pool(weight_map)
145 |
146 |         D_org_left = F.conv2d(org_pool, self.weight_left, padding=1)
147 |         D_org_right = F.conv2d(org_pool, self.weight_right, padding=1)
148 |         D_org_up = F.conv2d(org_pool, self.weight_up, padding=1)
149 |         D_org_down = F.conv2d(org_pool, self.weight_down, padding=1)
150 |
151 | # D_org_left_up = F.conv2d(org_pool , self.weight_left_up, padding=1)
152 | # D_org_right_up = F.conv2d(org_pool , self.weight_right_up, padding=1)
153 | # D_org_left_down = F.conv2d(org_pool , self.weight_left_down, padding=1)
154 | # D_org_right_down = F.conv2d(org_pool , self.weight_right_down, padding=1)
155 |
156 |         D_enhance_left = F.conv2d(enhance_pool, self.weight_left, padding=1)
157 |         D_enhance_right = F.conv2d(enhance_pool, self.weight_right, padding=1)
158 |         D_enhance_up = F.conv2d(enhance_pool, self.weight_up, padding=1)
159 |         D_enhance_down = F.conv2d(enhance_pool, self.weight_down, padding=1)
160 |
161 | # D_enhance_left_up = F.conv2d(enhance_pool , self.weight_left_up, padding=1)
162 | # D_enhance_right_up = F.conv2d(enhance_pool , self.weight_right_up, padding=1)
163 | # D_enhance_left_down = F.conv2d(enhance_pool , self.weight_left_down, padding=1)
164 | # D_enhance_right_down = F.conv2d(enhance_pool , self.weight_right_down, padding=1)
165 |
166 |         D_left = torch.pow(D_org_left - D_enhance_left, 2)
167 |         D_right = torch.pow(D_org_right - D_enhance_right, 2)
168 |         D_up = torch.pow(D_org_up - D_enhance_up, 2)
169 |         D_down = torch.pow(D_org_down - D_enhance_down, 2)
170 |
171 | # D_left_up = torch.pow(D_org_left_up - D_enhance_left_up, 2)
172 | # D_right_up = torch.pow(D_org_right_up - D_enhance_right_up, 2)
173 | # D_left_down = torch.pow(D_org_left_down - D_enhance_left_down, 2)
174 | # D_right_down = torch.pow(D_org_right_down - D_enhance_right_down, 2)
175 |
176 |         E = D_left + D_right + D_up + D_down
177 |
178 | return torch.mean(E * weight_map_pool)
179 |
180 |
181 | class L_mscn(torch.nn.Module):
182 |
183 | def __init__(self):
184 | super(L_mscn, self).__init__()
185 | self.l1 = torch.nn.L1Loss()
186 | def forward(self, img1, img2, weight_map):
187 | y1 = 0.299 * img1[:, 0, :, :] + 0.587 * img1[:, 1, :, :] + 0.114 * img1[:, 2, :, :]
188 | y2 = 0.299 * img2[:, 0, :, :] + 0.587 * img2[:, 1, :, :] + 0.114 * img2[:, 2, :, :]
189 | y1 = y1.type(torch.float64)
190 | y2 = y2.type(torch.float64)
191 | mu1 = tvtf.gaussian_blur(y1, (7, 7))
192 | mu2 = tvtf.gaussian_blur(y2, (7, 7))
193 | mu1_sq = mu1 * mu1
194 | mu2_sq = mu2 * mu2
195 | sigma1 = torch.sqrt(torch.abs(tvtf.gaussian_blur(y1 * y1, (7, 7)) - mu1_sq))
196 | sigma2 = torch.sqrt(torch.abs(tvtf.gaussian_blur(y2 * y2, (7, 7)) - mu2_sq))
197 | dividend1 = y1 - mu1
198 | dividend2 = y2 - mu2
199 | divisor1 = sigma1 + 1e-7
200 | divisor2 = sigma2 + 1e-7
201 | struct1 = (dividend1 / divisor1)
202 | struct2 = (dividend2 / divisor2)
203 | struct1_norm = (struct1 - struct1.min()) / (struct1.max() - struct1.min())
204 | struct2_norm = (struct2 - struct2.min()) / (struct2.max() - struct2.min())
205 |
206 | return ((struct1 - struct2).pow(2) * weight_map).mean()
207 | # return self.l1(struct1_norm, struct2_norm)
208 | # return (struct1_norm - struct2_norm).mean()
209 |
210 |
211 | class StructureGuidance(Guidance):
212 |
213 |     def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> None:
214 | super(StructureGuidance, self).__init__(scale, t_start, t_stop, space, repeat)
215 | # self.spa1_loss = L_spa(1)
216 | # self.spa2_loss = L_spa(2)
217 | self.spa4_loss = L_spa(1)
218 | # self.loss = L_mscn()
219 |
220 |     def mscn_torch(self, input_img: torch.Tensor, ksize: int, c: float):  # input: an RGB image tensor (N, 3, H, W)
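        # MSCN (mean-subtracted, contrast-normalized) coefficients of the luma channel:
        # (y - mu) / (sigma + c), followed by min-max normalization to [0, 1].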
221 | y = 0.299 * input_img[:, 0, :, :] + 0.587 * input_img[:, 1, :, :] + 0.114 * input_img[:, 2, :, :]
222 | y = y.type(torch.float64)
223 | mu = tvtf.gaussian_blur(y, (ksize,ksize))
224 | mu_sq = mu * mu
225 | sigma = torch.sqrt(torch.abs(tvtf.gaussian_blur(y*y, (ksize,ksize)) - mu_sq))
226 | dividend = y - mu
227 | divisor = sigma + c
228 | struct = (dividend / divisor).type(torch.float32)
229 | struct_norm = (struct - struct.min()) / (struct.max() - struct.min() + 1e-6)
230 | return struct_norm
231 |
232 | def _forward(self, target_x0, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
233 | # pred_x0: [-1, 1], nchw, rgb
234 | target1, target2 = target_x0
235 | pred_x0 = (pred_x0 + 1) / 2 # [0, 1]
236 | target1_struct = self.mscn_torch(input_img=target1, ksize=7, c=1e-7)
237 | target2_struct = self.mscn_torch(input_img=target2, ksize=7, c=1e-7)
238 | target2_y = 0.299 * target2[:, 0, :, :] + 0.587 * target2[:, 1, :, :] + 0.114 * target2[:, 2, :, :]
239 | weight_map = (target2_y > 0.99).type(torch.float32)
240 |
241 | with torch.enable_grad():
242 | pred_x0.requires_grad_(True)
243 | # loss = (pred_x0_struct - target_x0_struct).pow(2).mean((1, 2)).sum()
244 | # loss = self.spa_loss(target2, pred_x0, 1. - weight_map)
245 | loss = self.spa4_loss(target2, pred_x0, 1. - weight_map)
246 | # loss = self.loss(target1, pred_x0, 1. - weight_map)
247 | scale = self.scale
248 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale
249 | return g, loss.item()
--------------------------------------------------------------------------------
/utils/flow.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/haofeixu/gmflow and https://github.com/liuziyang123/LDRFlow.
2 |
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | def backward_warp(x, flo):
10 | """
11 | warp an image/tensor (im2) back to im1, according to the optical flow
12 | x: [B, C, H, W] (im2)
13 | flo: [B, 2, H, W] flow
14 | """
15 | B, C, H, W = x.size()
16 | # mesh grid
17 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
18 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
19 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
20 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
21 | grid = torch.cat((xx, yy), 1).float()
22 |
23 | if x.is_cuda:
24 | grid = grid.cuda()
25 | vgrid = grid + flo
26 | # scale grid to [-1,1]
27 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
28 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
29 |
30 | vgrid = vgrid.permute(0, 2, 3, 1)
31 |     output = F.grid_sample(x, vgrid, align_corners=True)  # the grid is normalized with (W - 1) and (H - 1), i.e. the align_corners=True convention
32 | # mask = torch.ones(x.size()).to(DEVICE)
33 | # mask = F.grid_sample(mask, vgrid)
34 |
35 | # mask[mask < 0.999] = 0
36 | # mask[mask > 0] = 1
37 |
38 | return output
39 |
40 |
41 | def forward_backward_consistency_check(fwd_flow, bwd_flow,
42 | alpha=0.01,
43 | beta=10
44 | ):
45 | # fwd_flow, bwd_flow: [B, 2, H, W]
46 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
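    # A pixel is flagged as occluded when the forward flow and the back-warped backward
    # flow do not cancel out: |f_fwd + warp(f_bwd)| > alpha * (|f_fwd| + |f_bwd|) + beta.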
47 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
48 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
49 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
50 |
51 | warped_bwd_flow = backward_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
52 | warped_fwd_flow = backward_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
53 |
54 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
55 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
56 |
57 | threshold = alpha * flow_mag + beta
58 | # threshold = 0
59 |
60 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
61 | bwd_occ = (diff_bwd > threshold).float()
62 |
63 | return fwd_occ, bwd_occ
64 |
65 |
66 | def calculate_imf_map(x, y):
67 | imf_map = torch.zeros(256).cuda()
68 | r = 0
69 | for i in range(256):
70 | if x[i] == 0:
71 | imf_map[i] = -1
72 | else:
73 | p, v, j = x[i], 0, r
74 | while True:
75 | if y[j] < p:
76 | p = p - y[j]
77 | v = v + y[j] * j
78 | j += 1
79 | else:
80 | r = j
81 | y[j] = y[j] - p
82 | v = v + p * j
83 | imf_map[i] = (v / x[i]).round()
84 | break
85 | imf_map = imf_map.unsqueeze(dim=0)
86 | return imf_map
87 |
88 |
89 | def IMF(ue, oe):
90 | B, C, H, W = ue.shape
91 | ue = (ue * 255).round()
92 | oe = (oe * 255).round()
93 |
94 | imf_map = []
95 | ue_rgb = torch.split(ue, 1, dim=1)
96 | oe_rgb = torch.split(oe, 1, dim=1)
97 | imf_map = [
98 | calculate_imf_map(torch.histc(x, bins=256, min=0, max=255), torch.histc(y, bins=256, min=0, max=255)) for x, y in zip(ue_rgb, oe_rgb)
99 | ]
100 | imf_map = torch.concat(imf_map, dim=0)
101 |
102 | zeros = torch.zeros([C, 1], dtype=torch.float32).cuda()
103 | imf_map = torch.concat((imf_map, zeros), 1)
104 |
105 | ue_imf = rearrange(ue.squeeze(), 'c h w -> c (h w)')
106 | ue_imf_floor = ue_imf.floor()
107 | for c in range(C):
108 | ind = ue_imf_floor[c].long()
109 | ue_imf[c, :] = (ue_imf[c, :] - ue_imf_floor[c, :]) * (imf_map[c, :][ind + 1] - imf_map[c, :][ind]) + imf_map[c, :][ind]
110 | ue_imf = rearrange(ue_imf, 'c (h w) -> c h w', h=H, w=W).unsqueeze(dim=0)
111 | ue_imf = (ue_imf / 255.).clamp(0, 1)
112 | return ue_imf
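# Example (illustrative): map the under-exposed frame onto the brightness distribution of
# the over-exposed frame; both inputs are assumed to be (1, 3, H, W) CUDA tensors in [0, 1].
#
#   ue_matched = IMF(ue, oe)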
--------------------------------------------------------------------------------
/utils/imf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 |
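# Intensity Mapping Function (IMF): per-channel histogram matching that maps the
# under-exposed image onto the intensity distribution of the over-exposed image.
# The routine mirrors the IMF implementation in utils/flow.py.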
5 | def calculate_imf_map(x, y):
6 | imf_map = torch.zeros(256).cuda()
7 | r = 0
8 | for i in range(256):
9 | if x[i] == 0:
10 | imf_map[i] = -1
11 | else:
12 | p, v, j = x[i], 0, r
13 | while True:
14 | if y[j] < p:
15 | p = p - y[j]
16 | v = v + y[j] * j
17 | j += 1
18 | else:
19 | r = j
20 | y[j] = y[j] - p
21 | v = v + p * j
22 | imf_map[i] = (v / x[i]).round()
23 | break
24 | imf_map = imf_map.unsqueeze(dim=0)
25 | return imf_map
26 |
27 |
28 | def IMF2(ue, oe):
29 | B, C, H, W = ue.shape
30 | ue = (ue * 255).round()
31 | oe = (oe * 255).round()
32 |
33 | imf_map = []
34 | ue_rgb = torch.split(ue, 1, dim=1)
35 | oe_rgb = torch.split(oe, 1, dim=1)
36 | imf_map = [
37 | calculate_imf_map(torch.histc(x, bins=256, min=0, max=255), torch.histc(y, bins=256, min=0, max=255)) for x, y in zip(ue_rgb, oe_rgb)
38 | ]
39 | imf_map = torch.concat(imf_map, dim=0)
40 |
41 | zeros = torch.zeros([C, 1], dtype=torch.float32).cuda()
42 | imf_map = torch.concat((imf_map, zeros), 1)
43 |
44 | ue_imf = rearrange(ue.squeeze(), 'c h w -> c (h w)')
45 | ue_imf_floor = ue_imf.floor()
46 | for c in range(C):
47 | ind = ue_imf_floor[c].long()
48 | ue_imf[c, :] = (ue_imf[c, :] - ue_imf_floor[c, :]) * (imf_map[c, :][ind + 1] - imf_map[c, :][ind]) + imf_map[c, :][ind]
49 | ue_imf = rearrange(ue_imf, 'c (h w) -> c h w', h=H, w=W).unsqueeze(dim=0)
50 | ue_imf = (ue_imf / 255.).clamp(0, 1)
51 | return ue_imf
--------------------------------------------------------------------------------
/val_nriqa.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch
3 | import pyiqa
4 | from tqdm import tqdm
5 |
6 |
7 | class AverageMeter(object):
8 | def __init__(self):
9 | self.reset()
10 |
11 | def reset(self):
12 | self.val = 0
13 | self.avg = 0
14 | self.sum = 0
15 | self.count = 0
16 | self.max = -1
17 | self.min = 10000
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 | if val > self.max:
25 | self.max = val
26 | if val < self.min:
27 | self.min = val
28 |
29 | def get_max(self):
30 | return self.max
31 |
32 |
33 | metric_list = {
34 | 'musiq': pyiqa.create_metric('musiq', as_loss=False).cuda(),
35 | 'paq2piq': pyiqa.create_metric('paq2piq', as_loss=False).cuda(),
36 | 'hyperiqa': pyiqa.create_metric('hyperiqa', as_loss=False).cuda(),
37 | }
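# All three are no-reference IQA metrics for which higher scores are better;
# `lower_better` is printed below as a sanity check.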
38 | res_list = {}
39 |
40 | for k in metric_list:
41 | print('{} lower better: {}'.format(k, metric_list[k].lower_better))
42 | res_list[k] = AverageMeter()
43 |
44 | img_list = glob.glob('/ailab/user/chenzixuan/Research/Diff-HDR/cvpr2025_release/MEFB/*out*.png')
45 |
46 | for img_path in tqdm(img_list):
47 | for k in metric_list:
48 |         tmp = metric_list[k](img_path).item()  # scalar score for this image
49 | res_list[k].update(tmp)
50 |
51 | for k in res_list:
52 | print('{}: {}'.format(k, res_list[k].avg))
--------------------------------------------------------------------------------