├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store ├── teaser.pdf └── teaser.png ├── configs ├── fcb.yaml └── ultrafusion.yaml ├── dataset └── test_dataset.py ├── examples ├── 0052 │ ├── oe.jpg │ └── ue.jpg └── 0072 │ ├── oe.jpg │ └── ue.jpg ├── inference.py ├── inference.sh ├── model ├── V4_CA │ ├── attention.py │ ├── cldm.py │ ├── controlnet.py │ ├── cross_attention.py │ ├── gaussian_diffusion.py │ ├── unet.py │ └── vae.py ├── __init__.py ├── clip.py ├── config.py ├── distributions.py ├── fe │ └── fe_V1_4.py ├── open_clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── model.py │ ├── tokenizer.py │ └── transformer.py ├── raft │ ├── corr.py │ ├── datasets.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py └── util.py ├── pipeline └── V4_CA │ └── pipeline.py ├── requirements.txt ├── utils ├── V4_CA │ └── sampler.py ├── common.py ├── cond_fn.py ├── flow.py └── imf.py └── val_nriqa.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.pth 4 | *.out 5 | /data 6 | /exps 7 | # *.sh 8 | !install_env.sh 9 | /weights 10 | /temp 11 | /results* 12 | .ipynb_checkpoints/ 13 | /TODO.txt 14 | /deprecated 15 | /temp_scripts 16 | /.vscode 17 | /runs 18 | /experiment/ 19 | /ckpts 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | UltraFusion: Ultra High Dynamic Imaging using Exposure Fusion 5 | 6 | (CVPR 2025 Highlight) 7 |

8 |

9 | Zixuan Chen1,3 10 | · 11 | Yujin Wang1 12 | · 13 | Xin Cai2 14 | · 15 | Zhiyuan You2 16 |
17 | Zheming Lu3 18 | · 19 | Fan Zhang1 20 | · 21 | Shi Guo1 22 | · 23 | Tianfan Xue2,1 24 | 25 |
26 | 1Shanghai AI Laboratory, 2The Chinese University of Hong Kong, 27 | 3Zhejiang University 28 |
29 |

30 | Paper PDF 31 | Project Page 32 | 33 |
34 |

35 |

36 | 37 | ![teaser_img](assets/teaser.png) 38 | 39 | ## :mega: News 40 | - **2025.4.23**: Inference codes, benchmark and results are released. 41 | - **2025.4.5**: Our UltraFusion is selected to be presented as a :sparkles:***highlight***:sparkles: in CVPR 2025. 42 | - **2025.2.27**: Accepted by ***CVPR 2025*** :tada::tada::tada:. 43 | - **2025.1.21**: Feel free to try online demos at Hugging Face and OpenXLab :blush:. 44 | 45 | 46 | ## :memo: ToDo List 47 | - [ ] Release training codes. 48 | - [x] Release inference codes and pre-trained model. 49 | - [x] Release UltraFusion benchmark and visual results. 50 | - [x] Release more visual comparisons on our [project page](https://openimaginglab.github.io/UltraFusion/). 51 | 52 | ## :bridge_at_night: Benchmark 53 | We capture 100 challenging real-world HDR scenes for performance evaluation. 54 | Our benchmark and results (including competing methods) are available at [Google Drive](https://drive.google.com/drive/folders/18icr4A_0qGvwqehPhxH29hqJYO8HS6bi?usp=sharing) and [Baidu Disk](). 55 | Moreover, we also provide results of our method and the comparison methods on [RealHDRV](https://github.com/yungsyu99/Real-HDRV) and [MEFB](https://github.com/xingchenzhang/MEFB). 56 | 57 | > *Note: The HDR reconstruction methods perform poorly in some scenes because we follow their setup and retrain a 2-exposure version, while the training set they used only provides ground truth for the middle exposure, limiting the dynamic range. We believe that using training data with a higher dynamic range can improve their performance.* 58 | 59 | ## Quick Start 60 | **Installation** 61 | ```shell 62 | # clone this repo 63 | git clone https://github.com/OpenImagingLab/UltraFusion.git 64 | cd UltraFusion 65 | 66 | # create environment 67 | conda create -n UltraFusion python=3.10 68 | conda activate UltraFusion 69 | pip install -r requirements.txt 70 | ``` 71 | **Prepare Data and Pre-trained Models** 72 | 73 | Download [raft-sintel.pth](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing), [v2-1_512-ema-pruned.ckpt](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.ckpt), [fcb.pt](https://huggingface.co/zxchen00/UltraFusion/blob/main/fcb.pt) and [ultrafusion.pt](https://huggingface.co/zxchen00/UltraFusion/blob/main/ultrafusion.pt), and put them in the ```ckpts``` folder. Download the three benchmarks ([Google Drive](https://drive.google.com/drive/folders/18icr4A_0qGvwqehPhxH29hqJYO8HS6bi?usp=sharing) or [Baidu Disk]()) and put them in the ```data``` folder. A sketch of the expected folder layout is given near the end of this README. 74 | 75 | **Inference** 76 | 77 | Run the following scripts for inference. 78 | ```shell 79 | # UltraFusion Benchmark 80 | python inference.py --dataset UltraFusion --output results --tiled --tile_size 512 --tile_stride 256 --prealign --save_all 81 | # RealHDRV 82 | python inference.py --dataset RealHDRV --output results --tiled --tile_size 512 --tile_stride 256 --prealign --save_all 83 | # MEFB (pre-alignment is disabled for static scenes) 84 | python inference.py --dataset MEFB --output results --tiled --tile_size 512 --tile_stride 256 --save_all 85 | ``` 86 | You can also use ```val_nriqa.py``` for evaluation. 87 | 88 | 89 | 90 | 91 | ## Acknowledgements 92 | This project is developed on the codebase of [DiffBIR](https://github.com/XPixelGroup/DiffBIR). We appreciate their great work!
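**Expected Folder Layout (for reference)**

A rough sketch of how ```ckpts``` and ```data``` are expected to look after the download steps in Quick Start. The ```data``` subfolder names follow ```dataset/test_dataset.py```, and only the benchmarks you actually evaluate need to be present:

```
UltraFusion
├── ckpts
│   ├── fcb.pt
│   ├── raft-sintel.pth
│   ├── ultrafusion.pt
│   └── v2-1_512-ema-pruned.ckpt
└── data
    ├── MEFB
    ├── Real-HDRV-Deghosting-sRGB-Testing
    └── UltraFusionBenchmark
```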
93 | 94 | ## :love_you_gesture: Citation 95 | If you find our paper and repo are helpful for your research, please consider citing: 96 | ```BibTeX 97 | @article{chen2025ultrafusion, 98 | title={UltraFusion: Ultra High Dynamic Imaging using Exposure Fusion}, 99 | author={Chen, Zixuan and Wang, Yujin and Cai, Xin and You, Zhiyuan and Lu, Zheming and Zhang, Fan and Guo, Shi and Xue, Tianfan}, 100 | journal={arXiv preprint arXiv:2501.11515}, 101 | year={2025} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/.DS_Store -------------------------------------------------------------------------------- /assets/teaser.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/teaser.pdf -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/assets/teaser.png -------------------------------------------------------------------------------- /configs/fcb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | fidelity_encoder: 3 | target: model.fe.fe_V1_4.FidelityEncoderV2 4 | params: 5 | fe_config: 6 | double_z: true 7 | z_channels: 4 8 | resolution: 256 9 | in_channels: 3 10 | out_ch: 3 11 | ch: 128 12 | ch_mult: 13 | - 1 14 | - 2 15 | - 4 16 | - 4 17 | num_res_blocks: 2 18 | attn_resolutions: [] 19 | dropout: 0.0 -------------------------------------------------------------------------------- /configs/ultrafusion.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | fidelity_encoder: 3 | target: model.V4_CA.vae.FidelityEncoder 4 | params: 5 | double_z: true 6 | z_channels: 4 7 | resolution: 256 8 | in_channels: 1 9 | out_ch: 3 10 | ch: 128 11 | ch_mult: 12 | - 1 13 | - 2 14 | - 4 15 | - 4 16 | num_res_blocks: 2 17 | attn_resolutions: [] 18 | dropout: 0.0 19 | cldm: 20 | target: model.V4_CA.cldm.ControlLDM 21 | params: 22 | latent_scale_factor: 0.18215 23 | unet_cfg: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | vae_cfg: 39 | embed_dim: 4 40 | ddconfig: 41 | double_z: true 42 | z_channels: 4 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | - 4 52 | num_res_blocks: 2 53 | attn_resolutions: [] 54 | dropout: 0.0 55 | clip_cfg: 56 | embed_dim: 1024 57 | vision_cfg: 58 | image_size: 224 59 | layers: 32 60 | width: 1280 61 | head_width: 80 62 | patch_size: 14 63 | text_cfg: 64 | context_length: 77 65 | vocab_size: 49408 66 | width: 1024 67 | heads: 16 68 | layers: 24 69 | layer: "penultimate" 70 | controlnet_cfg: 71 | use_checkpoint: True 72 | image_size: 32 # unused 
73 | in_channels: 4 74 | hint_channels: 4 75 | model_channels: 320 76 | attention_resolutions: [ 4, 2, 1 ] 77 | num_res_blocks: 2 78 | channel_mult: [ 1, 2, 4, 4 ] 79 | num_head_channels: 64 # need to fix for flash-attn 80 | use_spatial_transformer: True 81 | use_linear_in_transformer: True 82 | transformer_depth: 1 83 | context_dim: 1024 84 | legacy: False 85 | 86 | diffusion: 87 | target: model.V4_CA.gaussian_diffusion.Diffusion 88 | params: 89 | linear_start: 0.00085 90 | linear_end: 0.0120 91 | timesteps: 1000 92 | 93 | dataset: 94 | train1: 95 | target: dataset.mef_dataset.MEFDatasetV5 96 | params: 97 | # training file list path 98 | img_dir: /dev/shm/SICE/Dataset_Part1_2expo_train 99 | motion_img_dir: /ailab/group/pjlab-sail/chenzixuan/vimeo_septuplet/sequences 100 | train2: 101 | target: dataset.mef_dataset.MEFDatasetV5 102 | params: 103 | # training file list path 104 | img_dir: /dev/shm/SICE/Dataset_Part2_2expo 105 | motion_img_dir: /ailab/group/pjlab-sail/chenzixuan/vimeo_septuplet/sequences 106 | 107 | train: 108 | # pretrained sd v2.1 path 109 | sd_path: /ailab/user/chenzixuan/Research/pretrained_models/SDv2.1/v2-1_512-ema-pruned.ckpt 110 | # experiment directory path 111 | exp_dir: ./exps/sd_motion_V4_CA_8gpus 112 | learning_rate: 1e-4 113 | # ImageNet 1k (1.3M images) 114 | # batch size = 192, lr = 1e-4, total training steps = 25k 115 | # Our filtered laion2b-en (15M images) 116 | # batch size = 256, lr = 1e-4 (first 30k), 1e-5 (next 50k), total training steps = 80k 117 | batch_size: 4 118 | num_workers: 8 119 | train_steps: 1000000 120 | log_every: 50 121 | ckpt_every: 4000 122 | image_every: 500 123 | resume: ~ -------------------------------------------------------------------------------- /dataset/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os, random, glob, time 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import torch.utils.data as data 6 | from PIL import Image 7 | from torchvision.transforms import ToTensor 8 | 9 | 10 | def get_color_and_struct(isrgb, input_img: torch.Tensor, ksize, sigmaX, c): #input an RGB image 11 | 12 | input_img = input_img.squeeze().cpu().numpy().transpose(1, 2, 0) 13 | 14 | if isrgb==True: 15 | yuv_img = cv2.cvtColor(input_img, cv2.COLOR_RGB2YUV).astype(np.float32) 16 | y = np.expand_dims(yuv_img[:,:,0], axis=-1).astype(np.float64) 17 | u = np.expand_dims(yuv_img[:,:,1], axis=-1).astype(np.float32) 18 | v = np.expand_dims(yuv_img[:,:,2], axis=-1).astype(np.float32) 19 | else: 20 | y = input_img.astype(np.float64) 21 | #mu = gaussian_filter(y, ksize, ksize/6) 22 | mu = cv2.GaussianBlur(y, (ksize,ksize), sigmaX).astype(np.float64) 23 | mu_sq = mu * mu 24 | sigma = np.sqrt(np.absolute(cv2.GaussianBlur(y*y, (ksize,ksize), sigmaX) - mu_sq)).astype(np.float64) 25 | mu = np.expand_dims(mu, axis=-1) 26 | sigma = np.expand_dims(sigma, axis=-1) 27 | dividend = y.astype(np.float64) - mu 28 | divisor = sigma + c 29 | struct = dividend / divisor 30 | struct = struct.astype(np.float32) 31 | struct_norm = (struct - struct.min()) / (struct.max() - struct.min() + 1e-6) 32 | struct_norm = torch.from_numpy(struct_norm).permute(2, 0, 1) 33 | u = torch.from_numpy(u).permute(2, 0, 1) 34 | v = torch.from_numpy(v).permute(2, 0, 1) 35 | img_uv = torch.cat([u, v], dim=0) 36 | return struct_norm, img_uv 37 | 38 | 39 | class TestDataset(data.Dataset): 40 | def __init__(self, dataset): 41 | super(TestDataset, self).__init__() 42 | self.dataset = dataset 43 | self.img_dir_dict = { 44 | 
'UltraFusion': './data/UltraFusionBenchmark', 45 | 'MEFB': './data/MEFB', 46 | 'RealHDRV': './data/Real-HDRV-Deghosting-sRGB-Testing', 47 | } 48 | self.ldr_list1 = [] 49 | self.ldr_list2 = [] 50 | self.file_name_list = [] 51 | self.to_tensor = ToTensor() 52 | 53 | self.scene_list = os.listdir(self.img_dir_dict[dataset]) 54 | self.scene_list.sort() 55 | for scene in self.scene_list: 56 | if len(os.listdir(os.path.join(self.img_dir_dict[dataset], scene))) > 0: 57 | self.ldr_list1.append(glob.glob(os.path.join(self.img_dir_dict[dataset], scene, '*ue.*'))[0]) 58 | self.ldr_list2.append(glob.glob(os.path.join(self.img_dir_dict[dataset], scene, '*oe.*'))[0]) 59 | self.file_name_list.append('{}_{}'.format(dataset, scene)) 60 | 61 | 62 | def __getitem__(self, index): 63 | ldr1_path = self.ldr_list1[index] 64 | ldr2_path = self.ldr_list2[index] 65 | file_name = self.file_name_list[index] 66 | 67 | ldr1 = Image.open(ldr1_path).convert('RGB') 68 | ldr2 = Image.open(ldr2_path).convert('RGB') 69 | 70 | W, H = ldr1.size 71 | 72 | if W * H >= 6000 * 4000: 73 | ldr1 = ldr1.resize([W // 4, H // 4]) 74 | ldr2 = ldr2.resize([W // 4, H // 4]) 75 | elif W * H >= 2000 *1500: 76 | ldr1 = ldr1.resize([W * 2 // 5, H * 2 // 5]) 77 | ldr2 = ldr2.resize([W * 2 // 5, H * 2 // 5]) 78 | 79 | ldr1 = self.to_tensor(ldr1) 80 | ldr2 = self.to_tensor(ldr2) 81 | 82 | return { 83 | 'ue': ldr1, 84 | 'oe': ldr2, 85 | 'file_name': file_name 86 | } 87 | 88 | def __len__(self): 89 | return len(self.ldr_list1) -------------------------------------------------------------------------------- /examples/0052/oe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0052/oe.jpg -------------------------------------------------------------------------------- /examples/0052/ue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0052/ue.jpg -------------------------------------------------------------------------------- /examples/0072/oe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0072/oe.jpg -------------------------------------------------------------------------------- /examples/0072/ue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/examples/0072/ue.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os, tqdm, math 2 | import torch 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | from collections import OrderedDict 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision.utils import save_image 9 | from torchvision.transforms import ToTensor 10 | from omegaconf import OmegaConf 11 | from accelerate.utils import set_seed 12 | 13 | from dataset.test_dataset import TestDataset, get_color_and_struct 14 | from model.raft.raft import RAFT 15 | from utils.common import instantiate_from_config 16 | from utils.flow import backward_warp, forward_backward_consistency_check, 
IMF 17 | 18 | 19 | def pad_imgv3(x, crop_size, crop_step): 20 | _, _, h, w = x.size() 21 | n_h = max(math.ceil((h - crop_size) / crop_step), 0) 22 | n_w = max(math.ceil((w - crop_size) / crop_step), 0) 23 | h_target = crop_size + n_h * crop_step 24 | w_target = crop_size + n_w * crop_step 25 | mod_pad_h = h_target - h 26 | mod_pad_w = w_target - w 27 | x_np = x.cpu().numpy() 28 | x_np = np.pad(x_np, pad_width=((0,0),(0,0),(0,mod_pad_h),(0,mod_pad_w)), mode='reflect') 29 | res = torch.from_numpy(x_np).cuda() 30 | return res 31 | 32 | 33 | def pad_img(x, patch_size): 34 | _, _, h, w = x.size() 35 | mod_pad_h = (patch_size - h % patch_size) % patch_size 36 | mod_pad_w = (patch_size - w % patch_size) % patch_size 37 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 38 | return x 39 | 40 | 41 | def crop_parallel(img, crop_sz, step): 42 | b, c, h, w = img.shape 43 | h_space = np.arange(0, h - crop_sz + 1, step) 44 | w_space = np.arange(0, w - crop_sz + 1, step) 45 | index = 0 46 | num_h = 0 47 | lr_list=torch.Tensor().to(img.device) 48 | for x in h_space: 49 | num_h += 1 50 | num_w = 0 51 | for y in w_space: 52 | num_w += 1 53 | index += 1 54 | crop_img = img[:, :, x:x + crop_sz, y:y + crop_sz] 55 | lr_list = torch.cat([lr_list, crop_img]) 56 | new_h=x + crop_sz # new height after crop 57 | new_w=y + crop_sz # new width after crop 58 | return lr_list, num_h, num_w, new_h, new_w 59 | 60 | 61 | def combine_parallel_wo_artifact(sr_list, num_h, num_w, new_h, new_w, patch_size, step): 62 | p_size = patch_size 63 | pred_lr_list = sr_list 64 | 65 | pred_full_w_list = [] # rectangle 66 | for i in range(num_h): 67 | pred_full_w = torch.zeros([1, 3, p_size, new_w]).cuda() 68 | pred_full_w[:, :, :, 0 : step] = pred_lr_list[i * num_w][:, :, 0 : step] 69 | # pred_full_w[:, :, :, 0 : patch_size] = pred_lr_list[i * num_w][:, :, 0 : patch_size] 70 | pred_full_w[:, :, :, new_w - step :] = pred_lr_list[i * num_w + num_w - 1][:, :, -step:] 71 | for j in range(1, num_w): 72 | repeat_l = j * step 73 | repeat_r = repeat_l + (p_size - step) 74 | ind = i * num_w + j - 1 75 | 76 | pred_full_w[:, :, :, repeat_r : repeat_r + 2 * step - patch_size] = pred_lr_list[ind + 1][:, :, patch_size - step : step] 77 | 78 | for k in range(repeat_l, repeat_r): 79 | alpha = (k - repeat_l) / (repeat_r - repeat_l) 80 | pred_full_w[:, :, :, k] = pred_lr_list[ind][:, :, step + k - repeat_l] * (1 - alpha) + pred_lr_list[ind + 1][:, :, k - repeat_l] * alpha 81 | # pred_full_w[:, :, :, k] = pred_full_w[:, :, :, k] * (1 - alpha) + pred_lr_list[ind + 1][:, :, k - repeat_l] * alpha 82 | pred_full_w_list.append(pred_full_w) 83 | 84 | pred = torch.zeros([1, 3, new_h, new_w], device=sr_list[0].device) 85 | pred[:, :, 0 : step, :] = pred_full_w_list[0][:, :, 0 : step, :] 86 | # pred[:, :, 0 : patch_size, :] = pred_full_w_list[0][:, :, 0 : patch_size, :] 87 | pred[:, :, -step :, :] = pred_full_w_list[-1][:, :, -step :, :] 88 | for i in range(1, num_h): 89 | repeat_u = i * step 90 | repeat_d = repeat_u + (p_size - step) 91 | for k in range(repeat_u, repeat_d): 92 | alpha = (k - repeat_u) / (repeat_d - repeat_u) 93 | pred[:, :, k, :] = pred_full_w_list[i - 1][:, :, step + k - repeat_u, :] * (1 - alpha) + pred_full_w_list[i][:, :, k - repeat_u, :] * alpha 94 | # pred[:, :, k, :] = pred[:, :, k, :] * (1 - alpha) + pred_full_w_list[i][:, :, k - repeat_u, :] * alpha 95 | return pred, pred_full_w_list 96 | 97 | 98 | def mef(img1, img2, img_name, flow_model, pipe, args, consistent_start=None): 99 | _, _, H, W = img2.shape 100 | img1 = pad_img(img1, 16) 
101 | img2 = pad_img(img2, 16) 102 | img1_light = IMF(img1, img2) 103 | with torch.no_grad(): 104 | _, img12_flow = flow_model(img1_light * 2 - 1, img2 * 2 - 1, iters=20, test_mode=True) 105 | _, img21_flow = flow_model(img2 * 2 - 1, img1_light * 2 - 1, iters=20, test_mode=True) 106 | 107 | img12 = backward_warp(img1, img21_flow) 108 | _, occ_mask = forward_backward_consistency_check(img12_flow, img21_flow) 109 | occ_mask = occ_mask.unsqueeze(dim=1) 110 | img12_mask = img12 * (1. - occ_mask) 111 | 112 | img1 = img1[:, :, :H, :W] 113 | img2 = img2[:, :, :H, :W] 114 | img12 = img12[:, :, :H, :W] 115 | img12_mask = img12_mask[:, :, :H, :W] 116 | occ_mask = occ_mask[:, :, :H, :W] 117 | 118 | if not args.prealign: 119 | img12_mask = img1 # cancel pre-align 120 | 121 | 122 | img2 = pad_imgv3(img2, args.tile_size, args.tile_stride) 123 | img12_mask = pad_imgv3(img12_mask, args.tile_size, args.tile_stride) 124 | 125 | img1_mscn_norm, img1_color = get_color_and_struct(isrgb=True, input_img=img12_mask, ksize=7, sigmaX=0, c=0.0000001) 126 | img1_mscn_norm, img1_color = img1_mscn_norm.unsqueeze(dim=0), img1_color.unsqueeze(dim=0) 127 | img1_mscn_norm, img1_color = img1_mscn_norm.cuda(), img1_color.cuda() 128 | fidelity_input = torch.cat([img2, img1_mscn_norm, img1_color], dim=1) 129 | 130 | img2_patches, num_h, num_w, new_h, new_w = crop_parallel(img2, args.tile_size, args.tile_stride) 131 | img1_struct_patches, num_h, num_w, new_h, new_w = crop_parallel(img1_mscn_norm, args.tile_size, args.tile_stride) 132 | img1_color_patches, num_h, num_w, new_h, new_w = crop_parallel(img1_color, args.tile_size, args.tile_stride) 133 | img2_patches_list = torch.split(img2_patches, 1, dim=0) 134 | img1_struct_patches_list = torch.split(img1_struct_patches, 1, dim=0) 135 | img1_color_patches_list = torch.split(img1_color_patches, 1, dim=0) 136 | fidelity_input_patches, num_h, num_w, new_h, new_w = crop_parallel(fidelity_input, args.tile_size, args.tile_stride) 137 | fidelity_input_patches_list = torch.split(fidelity_input_patches, 1, dim=0) 138 | out_list = [] 139 | for ind, (img2_, img1_struct_, img1_color_, fidelity_input_) in enumerate(zip(img2_patches_list, img1_struct_patches_list, img1_color_patches_list, fidelity_input_patches_list)): 140 | set_seed(args.seed) 141 | out = pipe.run(lq2=img2_, lq1_mscn_norm=img1_struct_, lq1_color=img1_color_, tiled=args.tiled, tile_size=args.tile_size, tile_stride=args.tile_stride, cond_fn=cond_fn, fidelity_input=fidelity_input_, consistent_start=consistent_start) # [-1, 1] 142 | out_list.append(out) 143 | 144 | out_list = torch.cat(out_list, dim=0) 145 | out, _ = combine_parallel_wo_artifact(out_list, num_h, num_w, new_h, new_w, args.tile_size, args.tile_stride) 146 | 147 | out = out[:, :, :H, :W] 148 | img1 = img1[:, :, :H, :W] 149 | img1_light = img1_light[:, :, :H, :W] 150 | img12 = img12[:, :, :H, :W] 151 | img2 = img2[:, :, :H, :W] 152 | occ_mask = occ_mask[:, :H, :W] 153 | img12_mask = img12_mask[:, :, :H, :W] 154 | img1_mscn_norm = img1_mscn_norm[:, :, :H, :W] 155 | img1_color = img1_color[:, :, :H, :W] 156 | 157 | u = torch.zeros_like(out) 158 | v = torch.zeros_like(out) 159 | u[:, 1:, :, :] = img1_color 160 | v[:, :2, :, :] = img1_color 161 | 162 | save_image((out + 1) / 2, '{}/{}_out_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign')) 163 | if args.save_all: 164 | save_image(img1, '{}/{}_ue.png'.format(args.output, img_name)) 165 | save_image(img1_light, '{}/{}_ue_imf.png'.format(args.output, img_name)) 166 | save_image(img12_mask, 
'{}/{}_ue2oe_mask_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign')) 167 | save_image(img12, '{}/{}_ue2oe_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign')) 168 | save_image(img1_mscn_norm, '{}/{}_ue2oe_mask_mscn_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign')) 169 | save_image(img2, '{}/{}_oe.png'.format(args.output, img_name)) 170 | save_image(occ_mask, '{}/{}_occmask_{}.png'.format(args.output, img_name, 'align' if args.prealign else 'noalign')) 171 | save_image(u, '{}/{}_u.png'.format(args.output, img_name)) 172 | save_image(v, '{}/{}_v.png'.format(args.output, img_name)) 173 | 174 | return out 175 | 176 | parser = ArgumentParser() 177 | parser.add_argument("--dataset", type=str, default='MEFB') 178 | parser.add_argument("--output", default='results', type=str) 179 | parser.add_argument("--tiled", action='store_true', default=False) 180 | parser.add_argument("--tile_size", type=int, default=512) 181 | parser.add_argument("--tile_stride", type=int, default=256) 182 | parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"]) 183 | parser.add_argument("--seed", type=int, default=231) 184 | parser.add_argument("--prealign", action='store_true', default=False) 185 | parser.add_argument("--save_all", action='store_true', default=False) 186 | args = parser.parse_args() 187 | 188 | from model.V4_CA.cldm import ControlLDM 189 | from model.V4_CA.gaussian_diffusion import Diffusion 190 | from pipeline.V4_CA.pipeline import UltraFusionPipeline 191 | ### load unet, vae, clip 192 | cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/ultrafusion.yaml").model.cldm) 193 | sd = torch.load('ckpts/v2-1_512-ema-pruned.ckpt', map_location="cpu")["state_dict"] 194 | unused = cldm.load_pretrained_sd(sd) 195 | print(f"strictly load pretrained sd_v2.1, unused weights: {unused}") 196 | ### load controlnet 197 | control_sd = torch.load('ckpts/ultrafusion.pt', map_location="cpu") 198 | cldm.load_controlnet_from_ckpt(control_sd) 199 | print(f"strictly load controlnet weight") 200 | cldm.eval().to(args.device) 201 | ### load fidelity encoder 202 | fidelity_encoder = instantiate_from_config(OmegaConf.load("configs/fcb.yaml").model.fidelity_encoder) 203 | fidelity_encoder_sd = torch.load('ckpts/fcb.pt') 204 | fidelity_encoder.load_state_dict(fidelity_encoder_sd, strict=True) 205 | fidelity_encoder = fidelity_encoder.cuda() 206 | fidelity_encoder.eval() 207 | ### load diffusion 208 | diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/ultrafusion.yaml").model.diffusion) 209 | diffusion.to(args.device) 210 | ### load flow model 211 | flow_state_dict = torch.load('ckpts/raft-sintel.pth', map_location='cpu') 212 | flow_model = RAFT(args).cuda() 213 | new_flow_state_dict = OrderedDict() 214 | for k, v in flow_state_dict.items(): 215 | new_flow_state_dict[k.replace("module.", "")] = v 216 | flow_model.load_state_dict(new_flow_state_dict) 217 | flow_model = flow_model.eval() 218 | cond_fn = None 219 | 220 | pipe = UltraFusionPipeline(cldm=cldm, diffusion=diffusion, fidelity_encoder=fidelity_encoder, device=args.device) 221 | 222 | to_tensor = ToTensor() 223 | 224 | dataset = TestDataset(args.dataset) 225 | dataloader = DataLoader( 226 | dataset, 227 | shuffle=False, 228 | batch_size=1, 229 | num_workers=0 230 | ) 231 | 232 | if not os.path.exists(args.output): 233 | os.mkdir(args.output) 234 | args.output = os.path.join(args.output, 
args.dataset) 235 | if not os.path.exists(args.output): 236 | os.mkdir(args.output) 237 | 238 | for batch in dataloader: 239 | ue = batch['ue'].cuda() 240 | oe = batch['oe'].cuda() 241 | img_name = batch['file_name'][0] 242 | 243 | _ = mef(img1=ue, img2=oe, img_name=img_name, flow_model=flow_model, pipe=pipe, args=args, consistent_start=None) -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | # Final results 2 | CUDA_VISIBLE_DEVICES=0 python inference.py --dataset UltraFusion --output results --tiled --tile_size 512 --tile_stride 256 --save_all -------------------------------------------------------------------------------- /model/V4_CA/attention.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange, repeat 6 | from typing import Optional, Any 7 | 8 | from model.util import ( 9 | checkpoint, zero_module, exists, default 10 | ) 11 | from model.config import Config, AttnMode 12 | 13 | 14 | # CrossAttn precision handling 15 | import os 16 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 17 | 18 | 19 | # feedforward 20 | class GEGLU(nn.Module): 21 | def __init__(self, dim_in, dim_out): 22 | super().__init__() 23 | self.proj = nn.Linear(dim_in, dim_out * 2) 24 | 25 | def forward(self, x): 26 | x, gate = self.proj(x).chunk(2, dim=-1) 27 | return x * F.gelu(gate) 28 | 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 32 | super().__init__() 33 | inner_dim = int(dim * mult) 34 | dim_out = default(dim_out, dim) 35 | project_in = nn.Sequential( 36 | nn.Linear(dim, inner_dim), 37 | nn.GELU() 38 | ) if not glu else GEGLU(dim, inner_dim) 39 | 40 | self.net = nn.Sequential( 41 | project_in, 42 | nn.Dropout(dropout), 43 | nn.Linear(inner_dim, dim_out) 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | def Normalize(in_channels): 51 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 52 | 53 | 54 | class CrossAttention(nn.Module): 55 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 56 | super().__init__() 57 | print(f"Setting up {self.__class__.__name__} (vanilla). 
Query dim is {query_dim}, context_dim is {context_dim} and using " 58 | f"{heads} heads.") 59 | inner_dim = dim_head * heads 60 | context_dim = default(context_dim, query_dim) 61 | 62 | self.scale = dim_head ** -0.5 63 | self.heads = heads 64 | 65 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 66 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 67 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 68 | 69 | self.to_out = nn.Sequential( 70 | nn.Linear(inner_dim, query_dim), 71 | nn.Dropout(dropout) 72 | ) 73 | 74 | def forward(self, x, context=None, mask=None): 75 | h = self.heads 76 | 77 | q = self.to_q(x) 78 | context = default(context, x) 79 | k = self.to_k(context) 80 | v = self.to_v(context) 81 | 82 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 83 | 84 | # force cast to fp32 to avoid overflowing 85 | if _ATTN_PRECISION =="fp32": 86 | # with torch.autocast(enabled=False, device_type = 'cuda'): 87 | with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"): 88 | q, k = q.float(), k.float() 89 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 90 | else: 91 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 92 | 93 | del q, k 94 | 95 | if exists(mask): 96 | mask = rearrange(mask, 'b ... -> b (...)') 97 | max_neg_value = -torch.finfo(sim.dtype).max 98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 99 | sim.masked_fill_(~mask, max_neg_value) 100 | 101 | # attention, what we cannot get enough of 102 | sim = sim.softmax(dim=-1) 103 | 104 | out = einsum('b i j, b j d -> b i d', sim, v) 105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 106 | return self.to_out(out) 107 | 108 | 109 | class MemoryEfficientCrossAttention(nn.Module): 110 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 111 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 112 | super().__init__() 113 | print(f"Setting up {self.__class__.__name__} (xformers). 
Query dim is {query_dim}, context_dim is {context_dim} and using " 114 | f"{heads} heads.") 115 | inner_dim = dim_head * heads 116 | context_dim = default(context_dim, query_dim) 117 | 118 | self.heads = heads 119 | self.dim_head = dim_head 120 | 121 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 122 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 123 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 124 | 125 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 126 | self.attention_op: Optional[Any] = None 127 | 128 | def forward(self, x, context=None, mask=None): 129 | q = self.to_q(x) 130 | context = default(context, x) 131 | k = self.to_k(context) 132 | v = self.to_v(context) 133 | 134 | b, _, _ = q.shape 135 | q, k, v = map( 136 | lambda t: t.unsqueeze(3) 137 | .reshape(b, t.shape[1], self.heads, self.dim_head) 138 | .permute(0, 2, 1, 3) 139 | .reshape(b * self.heads, t.shape[1], self.dim_head) 140 | .contiguous(), 141 | (q, k, v), 142 | ) 143 | 144 | # actually compute the attention, what we cannot get enough of 145 | out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 146 | 147 | if exists(mask): 148 | raise NotImplementedError 149 | out = ( 150 | out.unsqueeze(0) 151 | .reshape(b, self.heads, out.shape[1], self.dim_head) 152 | .permute(0, 2, 1, 3) 153 | .reshape(b, out.shape[1], self.heads * self.dim_head) 154 | ) 155 | return self.to_out(out) 156 | 157 | 158 | class SDPCrossAttention(nn.Module): 159 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 160 | super().__init__() 161 | print(f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using " 162 | f"{heads} heads.") 163 | inner_dim = dim_head * heads 164 | context_dim = default(context_dim, query_dim) 165 | 166 | self.heads = heads 167 | self.dim_head = dim_head 168 | 169 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 170 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 171 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 172 | 173 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 174 | 175 | def forward(self, x, context=None, mask=None): 176 | q = self.to_q(x) 177 | context = default(context, x) 178 | k = self.to_k(context) 179 | v = self.to_v(context) 180 | 181 | b, _, _ = q.shape 182 | q, k, v = map( 183 | lambda t: t.unsqueeze(3) 184 | .reshape(b, t.shape[1], self.heads, self.dim_head) 185 | .permute(0, 2, 1, 3) 186 | .reshape(b * self.heads, t.shape[1], self.dim_head) 187 | .contiguous(), 188 | (q, k, v), 189 | ) 190 | 191 | # actually compute the attention, what we cannot get enough of 192 | out = F.scaled_dot_product_attention(q, k, v) 193 | 194 | if exists(mask): 195 | raise NotImplementedError 196 | out = ( 197 | out.unsqueeze(0) 198 | .reshape(b, self.heads, out.shape[1], self.dim_head) 199 | .permute(0, 2, 1, 3) 200 | .reshape(b, out.shape[1], self.heads * self.dim_head) 201 | ) 202 | return self.to_out(out) 203 | 204 | 205 | class BasicTransformerBlock(nn.Module): 206 | ATTENTION_MODES = { 207 | AttnMode.VANILLA: CrossAttention, # vanilla attention 208 | AttnMode.XFORMERS: MemoryEfficientCrossAttention, 209 | AttnMode.SDP: SDPCrossAttention 210 | } 211 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 212 | disable_self_attn=False): 213 | super().__init__() 214 | attn_cls = 
self.ATTENTION_MODES[Config.attn_mode] 215 | self.disable_self_attn = disable_self_attn 216 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 217 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 218 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 219 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 220 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 221 | self.norm1 = nn.LayerNorm(dim) 222 | self.norm2 = nn.LayerNorm(dim) 223 | self.norm3 = nn.LayerNorm(dim) 224 | self.checkpoint = checkpoint 225 | 226 | def forward(self, x, context=None): 227 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 228 | 229 | def _forward(self, x, context=None): 230 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 231 | x = self.attn2(self.norm2(x), context=context) + x 232 | x = self.ff(self.norm3(x)) + x 233 | return x 234 | 235 | 236 | class SpatialTransformer(nn.Module): 237 | """ 238 | Transformer block for image-like data. 239 | First, project the input (aka embedding) 240 | and reshape to b, t, d. 241 | Then apply standard transformer action. 242 | Finally, reshape to image 243 | NEW: use_linear for more efficiency instead of the 1x1 convs 244 | """ 245 | def __init__(self, in_channels, n_heads, d_head, 246 | depth=1, dropout=0., context_dim=None, 247 | disable_self_attn=False, use_linear=False, 248 | use_checkpoint=True): 249 | super().__init__() 250 | if exists(context_dim) and not isinstance(context_dim, list): 251 | context_dim = [context_dim] 252 | self.in_channels = in_channels 253 | inner_dim = n_heads * d_head 254 | self.norm = Normalize(in_channels) 255 | if not use_linear: 256 | self.proj_in = nn.Conv2d(in_channels, 257 | inner_dim, 258 | kernel_size=1, 259 | stride=1, 260 | padding=0) 261 | else: 262 | self.proj_in = nn.Linear(in_channels, inner_dim) 263 | 264 | self.transformer_blocks = nn.ModuleList( 265 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 266 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 267 | for d in range(depth)] 268 | ) 269 | if not use_linear: 270 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 271 | in_channels, 272 | kernel_size=1, 273 | stride=1, 274 | padding=0)) 275 | else: 276 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 277 | self.use_linear = use_linear 278 | 279 | def forward(self, x, context=None): 280 | # note: if no context is given, cross-attention defaults to self-attention 281 | if not isinstance(context, list): 282 | context = [context] 283 | b, c, h, w = x.shape 284 | x_in = x 285 | x = self.norm(x) 286 | if not self.use_linear: 287 | x = self.proj_in(x) 288 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 289 | if self.use_linear: 290 | x = self.proj_in(x) 291 | for i, block in enumerate(self.transformer_blocks): 292 | x = block(x, context=context[i]) 293 | if self.use_linear: 294 | x = self.proj_out(x) 295 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 296 | if not self.use_linear: 297 | x = self.proj_out(x) 298 | return x + x_in 299 | -------------------------------------------------------------------------------- /model/V4_CA/cldm.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Set, List, Dict 2 | 3 | import torch 4 | from torch 
import nn 5 | 6 | from model.V4_CA.controlnet import ( 7 | ControlledUnetModel, ControlNet, 8 | ) 9 | from model.V4_CA.vae import AutoencoderKL 10 | from model.clip import FrozenOpenCLIPEmbedder 11 | 12 | from utils.common import sliding_windows, count_vram_usage, gaussian_weights 13 | from torchvision.utils import save_image 14 | 15 | 16 | def disabled_train(self: nn.Module) -> nn.Module: 17 | """Overwrite model.train with this function to make sure train/eval mode 18 | does not change anymore.""" 19 | return self 20 | 21 | 22 | class ControlLDM(nn.Module): 23 | 24 | def __init__( 25 | self, 26 | unet_cfg, 27 | vae_cfg, 28 | clip_cfg, 29 | controlnet_cfg, 30 | latent_scale_factor 31 | ): 32 | super().__init__() 33 | self.unet = ControlledUnetModel(**unet_cfg) 34 | self.vae = AutoencoderKL(**vae_cfg) 35 | self.clip = FrozenOpenCLIPEmbedder(**clip_cfg) 36 | self.controlnet = ControlNet(**controlnet_cfg) 37 | self.scale_factor = latent_scale_factor 38 | self.control_scales = [1.0] * 13 39 | 40 | @torch.no_grad() 41 | def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]: 42 | module_map = { 43 | "unet": "model.diffusion_model", 44 | "vae": "first_stage_model", 45 | "clip": "cond_stage_model", 46 | } 47 | modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)] 48 | used = set() 49 | for name, module in modules: 50 | init_sd = {} 51 | scratch_sd = module.state_dict() 52 | for key in scratch_sd: 53 | target_key = ".".join([module_map[name], key]) 54 | init_sd[key] = sd[target_key].clone() 55 | used.add(target_key) 56 | module.load_state_dict(init_sd, strict=True) 57 | unused = set(sd.keys()) - used 58 | # NOTE: this is slightly different from previous version, which haven't switched 59 | # the UNet to eval mode and disabled the requires_grad flag. 
60 | for module in [self.vae, self.clip, self.unet]: 61 | module.eval() 62 | module.train = disabled_train 63 | for p in module.parameters(): 64 | p.requires_grad = False 65 | return unused 66 | 67 | @torch.no_grad() 68 | def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None: 69 | self.controlnet.load_state_dict(sd, strict=True) 70 | 71 | @torch.no_grad() 72 | def load_controlnet_from_unet(self) -> Tuple[Set[str]]: 73 | unet_sd = self.unet.state_dict() 74 | scratch_sd = self.controlnet.state_dict() 75 | init_sd = {} 76 | init_with_new_zero = set() 77 | init_with_scratch = set() 78 | for key in scratch_sd: 79 | if key in unet_sd: 80 | this, target = scratch_sd[key], unet_sd[key] 81 | if this.size() == target.size(): 82 | init_sd[key] = target.clone() 83 | elif this.size(1) > target.size(1): 84 | print(this.size(1), target.size(1)) 85 | d_ic = this.size(1) - target.size(1) 86 | oc, _, h, w = this.size() 87 | zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype) 88 | init_sd[key] = torch.cat((target, zeros), dim=1) 89 | init_with_new_zero.add(key) 90 | else: 91 | print(this.size(1), target.size(1)) 92 | d_ic = this.size(1) 93 | print(target.shape) 94 | init_sd[key] = target[:, :d_ic, :, :] 95 | else: 96 | init_sd[key] = scratch_sd[key].clone() 97 | init_with_scratch.add(key) 98 | self.controlnet.load_state_dict(init_sd, strict=True) 99 | return init_with_new_zero, init_with_scratch 100 | 101 | def vae_encode(self, image: torch.Tensor, sample: bool=True) -> torch.Tensor: 102 | if sample: 103 | return self.vae.encode(image).sample() * self.scale_factor 104 | else: 105 | return self.vae.encode(image).mode() * self.scale_factor 106 | 107 | def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True, fidelity_encoder=None, fidelity_input: torch.Tensor=None) -> torch.Tensor: 108 | bs, _, h, w = image.shape 109 | z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device) 110 | count = torch.zeros_like(z, dtype=torch.float32) 111 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] 112 | weights = torch.tensor(weights, dtype=torch.float32, device=image.device) 113 | tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8) 114 | skip_feats = [] 115 | for hi, hi_end, wi, wi_end in tiles: 116 | tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] 117 | if fidelity_input is not None: 118 | tile_fidelity_input = fidelity_input[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] 119 | z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights 120 | with torch.no_grad(): 121 | # tile_images = torch.cat([tile_image_ue, tile_image], dim=1) 122 | # skip_feat = fidelity_encoder(tile_images) 123 | if fidelity_input is not None: 124 | skip_feat = fidelity_encoder(tile_fidelity_input) 125 | skip_feats.append(skip_feat) 126 | else: 127 | skip_feats.append(None) 128 | count[:, :, hi:hi_end, wi:wi_end] += weights 129 | z.div_(count) 130 | return z, skip_feats 131 | 132 | def vae_decode(self, z: torch.Tensor, skip_feat=None) -> torch.Tensor: 133 | return self.vae.decode(z / self.scale_factor, skip_feat) 134 | 135 | 136 | def calc_mean_std(self, feat, eps=1e-5): 137 | # eps is a small value added to the variance to avoid divide-by-zero. 
138 | size = feat.size() 139 | assert (len(size) == 4) 140 | N, C = size[:2] 141 | feat_var = feat.contiguous().view(N, C, -1).var(dim=2) + eps 142 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 143 | feat_mean = feat.contiguous().view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 144 | return feat_mean, feat_std 145 | 146 | 147 | def adaptive_instance_normalization(self, content_feat, style_feat): 148 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 149 | size = content_feat.size() 150 | # style_mean = torch.tensor([0.6906, 0.6766, 0.6749]).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3).cuda() 151 | # style_std = torch.tensor([0.1955, 0.2096, 0.2236]).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3).cuda() 152 | style_mean, style_std = self.calc_mean_std(style_feat) 153 | content_mean, content_std = self.calc_mean_std(content_feat) 154 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 155 | 156 | k = (style_std) / (content_std) 157 | b = style_mean - content_mean * style_std / content_std 158 | res = normalized_feat * style_std.expand(size) + style_mean.expand(size) 159 | return res 160 | 161 | 162 | @count_vram_usage 163 | def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int, skip_feats=None, consistent_start=None) -> torch.Tensor: 164 | bs, _, h, w = z.shape 165 | image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device) 166 | count = torch.zeros_like(image, dtype=torch.float32) 167 | weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None] 168 | weights = torch.tensor(weights, dtype=torch.float32, device=z.device) 169 | tiles = sliding_windows(h, w, tile_size, tile_stride) 170 | for ind, ((hi, hi_end, wi, wi_end), skip_feat) in enumerate(zip(tiles, skip_feats)): 171 | tile_z = z[:, :, hi:hi_end, wi:wi_end] 172 | tile_z_decoded = self.vae_decode(tile_z, skip_feat) 173 | if consistent_start is not None: 174 | tile_z_decoded = self.adaptive_instance_normalization(tile_z_decoded, consistent_start[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]) 175 | image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += tile_z_decoded * weights 176 | count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights 177 | image.div_(count) 178 | return image 179 | 180 | 181 | def prepare_condition(self, lq2: torch.Tensor, lq1_mscn_norm: torch.Tensor, lq1_color: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]: 182 | # Note the lq_vis should be normalized to [-1, 1] and lq_ifr normalized to [0, 1]!!! 183 | return dict( 184 | c_txt=self.clip.encode(txt), 185 | c_lq2=self.vae_encode(lq2, sample=False), 186 | c_lq1_mscn_norm=lq1_mscn_norm, 187 | c_lq1_color=lq1_color 188 | ) 189 | 190 | 191 | @count_vram_usage 192 | def prepare_condition_tiled(self, lq2: torch.Tensor, lq1_mscn_norm:torch.Tensor, lq1_color: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int, fidelity_encoder=None, fidelity_input: torch.Tensor=None) -> Dict[str, torch.Tensor]: 193 | c_lq2, skip_feats = self.vae_encode_tiled(lq2, tile_size, tile_stride, sample=False, fidelity_encoder=fidelity_encoder, fidelity_input=fidelity_input) # why smaple = False ? 
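# note: sample=False makes vae_encode use the posterior mode rather than a random sample, so the condition latent is deterministic (presumably the intended reason)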
194 | return dict( 195 | c_txt=self.clip.encode(txt), 196 | c_lq2=c_lq2, 197 | c_lq1_mscn_norm = lq1_mscn_norm, 198 | c_lq1_color = lq1_color 199 | ), skip_feats 200 | 201 | def forward(self, x_noisy, t, cond): 202 | c_txt = cond["c_txt"] 203 | 204 | c_lq2 = cond["c_lq2"] 205 | c_lq1_mscn_norm = cond["c_lq1_mscn_norm"] 206 | c_lq1_color = cond["c_lq1_color"] 207 | c_img = [c_lq2, c_lq1_mscn_norm, c_lq1_color] 208 | 209 | control = self.controlnet( 210 | x=x_noisy, hint=c_img, 211 | timesteps=t, context=c_txt 212 | ) 213 | control = [c * scale for c, scale in zip(control, self.control_scales)] 214 | eps = self.unet( 215 | x=x_noisy, timesteps=t, 216 | context=c_txt, control=control, only_mid_control=False 217 | ) 218 | return eps 219 | -------------------------------------------------------------------------------- /model/V4_CA/controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | from model.util import ( 6 | conv_nd, 7 | linear, 8 | zero_module, 9 | timestep_embedding, 10 | exists 11 | ) 12 | from model.V4_CA.attention import SpatialTransformer 13 | from model.V4_CA.unet import ( 14 | TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel 15 | ) 16 | from model.V4_CA.cross_attention import CrossTransformerBlock2D 17 | 18 | 19 | class ControlledUnetModel(UNetModel): 20 | 21 | def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): 22 | hs = [] 23 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 24 | emb = self.time_embed(t_emb) 25 | h = x.type(self.dtype) 26 | for module in self.input_blocks: 27 | h = module(h, emb, context) 28 | hs.append(h) 29 | h = self.middle_block(h, emb, context) 30 | 31 | if control is not None: 32 | h += control.pop() 33 | 34 | for i, module in enumerate(self.output_blocks): 35 | if only_mid_control or control is None: 36 | h = torch.cat([h, hs.pop()], dim=1) 37 | else: 38 | h = torch.cat([h, hs.pop() + control.pop()], dim=1) 39 | h = module(h, emb, context) 40 | 41 | h = h.type(x.dtype) 42 | return self.out(h) 43 | 44 | 45 | class ControlNet(nn.Module): 46 | 47 | def __init__( 48 | self, 49 | image_size, 50 | in_channels, 51 | model_channels, 52 | hint_channels, 53 | num_res_blocks, 54 | attention_resolutions, 55 | dropout=0, 56 | channel_mult=(1, 2, 4, 8), 57 | conv_resample=True, 58 | dims=2, 59 | use_checkpoint=False, 60 | use_fp16=False, 61 | num_heads=-1, 62 | num_head_channels=-1, 63 | num_heads_upsample=-1, 64 | use_scale_shift_norm=False, 65 | resblock_updown=False, 66 | use_new_attention_order=False, 67 | use_spatial_transformer=False, # custom transformer support 68 | transformer_depth=1, # custom transformer support 69 | context_dim=None, # custom transformer support 70 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 71 | legacy=True, 72 | disable_self_attentions=None, 73 | num_attention_blocks=None, 74 | disable_middle_self_attn=False, 75 | use_linear_in_transformer=False, 76 | ): 77 | super().__init__() 78 | if use_spatial_transformer: 79 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 80 | 81 | if context_dim is not None: 82 | assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
83 | from omegaconf.listconfig import ListConfig 84 | if type(context_dim) == ListConfig: 85 | context_dim = list(context_dim) 86 | 87 | if num_heads_upsample == -1: 88 | num_heads_upsample = num_heads 89 | 90 | if num_heads == -1: 91 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 92 | 93 | if num_head_channels == -1: 94 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 95 | 96 | self.dims = dims 97 | self.image_size = image_size 98 | self.in_channels = in_channels 99 | self.model_channels = model_channels 100 | if isinstance(num_res_blocks, int): 101 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 102 | else: 103 | if len(num_res_blocks) != len(channel_mult): 104 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 105 | "as a list/tuple (per-level) with the same length as channel_mult") 106 | self.num_res_blocks = num_res_blocks 107 | if disable_self_attentions is not None: 108 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 109 | assert len(disable_self_attentions) == len(channel_mult) 110 | if num_attention_blocks is not None: 111 | assert len(num_attention_blocks) == len(self.num_res_blocks) 112 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) 113 | print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " 114 | f"This option has LESS priority than attention_resolutions {attention_resolutions}, " 115 | f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " 116 | f"attention will still not be set.") 117 | 118 | self.attention_resolutions = attention_resolutions 119 | self.dropout = dropout 120 | self.channel_mult = channel_mult 121 | self.conv_resample = conv_resample 122 | self.use_checkpoint = use_checkpoint 123 | self.dtype = th.float16 if use_fp16 else th.float32 124 | self.num_heads = num_heads 125 | self.num_head_channels = num_head_channels 126 | self.num_heads_upsample = num_heads_upsample 127 | self.predict_codebook_ids = n_embed is not None 128 | 129 | time_embed_dim = model_channels * 4 130 | self.time_embed = nn.Sequential( 131 | linear(model_channels, time_embed_dim), 132 | nn.SiLU(), 133 | linear(time_embed_dim, time_embed_dim), 134 | ) 135 | 136 | self.input_blocks = nn.ModuleList( 137 | [ 138 | TimestepEmbedSequential( 139 | conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1) 140 | ) 141 | ] 142 | ) 143 | self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) 144 | 145 | self.cross_expo_block = nn.ModuleList( 146 | [ 147 | CrossTransformerBlock2D(dim=model_channels * channel_mult[0], num_heads=5), 148 | CrossTransformerBlock2D(dim=model_channels * channel_mult[0], num_heads=5), 149 | CrossTransformerBlock2D(dim=model_channels * channel_mult[1], num_heads=10), 150 | CrossTransformerBlock2D(dim=model_channels * channel_mult[2], num_heads=20), 151 | CrossTransformerBlock2D(dim=model_channels * channel_mult[3], num_heads=20), 152 | ] 153 | ) 154 | 155 | self.struct_hint_block = TimestepEmbedSequential( 156 | conv_nd(dims, 1, 16, 3, padding=1), 157 | nn.SiLU(), 158 | conv_nd(dims, 16, 16, 3, padding=1), 159 | nn.SiLU(), 160 | conv_nd(dims, 16, 32, 3, padding=1, stride=2), 161 | nn.SiLU(), 162 | conv_nd(dims, 32, 32, 3, padding=1), 163 | nn.SiLU(), 164 | conv_nd(dims, 32, 96, 3, padding=1, stride=2), 165 | nn.SiLU(), 166 | conv_nd(dims, 
96, 96, 3, padding=1), 167 | nn.SiLU(), 168 | conv_nd(dims, 96, 256, 3, padding=1, stride=2), 169 | nn.SiLU(), 170 | zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) 171 | ) 172 | self.struct_hint_block2 = nn.ModuleList( 173 | [ 174 | nn.Identity(), # [B, 320, 64, 64] 175 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1, stride=2), nn.SiLU()), # [B, 320, 32, 32] 176 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[1], 3, padding=1, stride=2), nn.SiLU()), # [B, 640, 16, 16] 177 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[2], 3, padding=1, stride=2), nn.SiLU()), # [B, 1280, 8, 8] 178 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1), nn.SiLU()), # [B, 1280, 8, 8] 179 | ] 180 | ) 181 | 182 | self.struct_hint_block_zero_convs = nn.ModuleList( 183 | [ 184 | nn.Identity(), 185 | zero_module(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1)), # [B, 320, 32, 32] 186 | zero_module(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[1], 3, padding=1)), # [B, 640, 16, 16] 187 | zero_module(conv_nd(dims, model_channels * channel_mult[2], model_channels * channel_mult[2], 3, padding=1)), # [B, 1280, 8, 8] 188 | zero_module(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1)), # [B, 1280, 8, 8] 189 | ] 190 | ) 191 | 192 | self.color_hint_block = TimestepEmbedSequential( 193 | conv_nd(dims, 2, 16, 3, padding=1), 194 | nn.SiLU(), 195 | conv_nd(dims, 16, 16, 3, padding=1), 196 | nn.SiLU(), 197 | conv_nd(dims, 16, 32, 3, padding=1, stride=2), 198 | nn.SiLU(), 199 | conv_nd(dims, 32, 32, 3, padding=1), 200 | nn.SiLU(), 201 | conv_nd(dims, 32, 96, 3, padding=1, stride=2), 202 | nn.SiLU(), 203 | conv_nd(dims, 96, 96, 3, padding=1), 204 | nn.SiLU(), 205 | conv_nd(dims, 96, 256, 3, padding=1, stride=2), 206 | nn.SiLU(), 207 | zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) 208 | ) 209 | self.color_hint_block2 = nn.ModuleList( 210 | [ 211 | nn.Identity(), # [B, 320, 64, 64] 212 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1, stride=2), nn.SiLU()), # [B, 320, 32, 32] 213 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[1], 3, padding=1, stride=2), nn.SiLU()), # [B, 640, 16, 16] 214 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[2], 3, padding=1, stride=2), nn.SiLU()), # [B, 1280, 8, 8] 215 | nn.Sequential(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1), nn.SiLU()), # [B, 1280, 8, 8] 216 | ] 217 | ) 218 | self.color_hint_block_zero_convs = nn.ModuleList( 219 | [ 220 | nn.Identity(), 221 | zero_module(conv_nd(dims, model_channels * channel_mult[0], model_channels * channel_mult[0], 3, padding=1)), # [B, 320, 32, 32] 222 | zero_module(conv_nd(dims, model_channels * channel_mult[1], model_channels * channel_mult[1], 3, padding=1)), # [B, 640, 16, 16] 223 | zero_module(conv_nd(dims, model_channels * channel_mult[2], model_channels * channel_mult[2], 3, padding=1)), # [B, 1280, 8, 8] 224 | zero_module(conv_nd(dims, model_channels * channel_mult[3], model_channels * channel_mult[3], 3, padding=1)), # [B, 1280, 8, 8] 225 | ] 226 | ) 227 | 228 | 229 | 
self._feature_size = model_channels 230 | input_block_chans = [model_channels] 231 | ch = model_channels 232 | ds = 1 233 | for level, mult in enumerate(channel_mult): 234 | for nr in range(self.num_res_blocks[level]): 235 | layers = [ 236 | ResBlock( 237 | ch, 238 | time_embed_dim, 239 | dropout, 240 | out_channels=mult * model_channels, 241 | dims=dims, 242 | use_checkpoint=use_checkpoint, 243 | use_scale_shift_norm=use_scale_shift_norm, 244 | ) 245 | ] 246 | ch = mult * model_channels 247 | if ds in attention_resolutions: 248 | if num_head_channels == -1: 249 | dim_head = ch // num_heads 250 | else: 251 | num_heads = ch // num_head_channels 252 | dim_head = num_head_channels 253 | if legacy: 254 | # num_heads = 1 255 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 256 | if exists(disable_self_attentions): 257 | disabled_sa = disable_self_attentions[level] 258 | else: 259 | disabled_sa = False 260 | 261 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 262 | layers.append( 263 | AttentionBlock( 264 | ch, 265 | use_checkpoint=use_checkpoint, 266 | num_heads=num_heads, 267 | num_head_channels=dim_head, 268 | use_new_attention_order=use_new_attention_order, 269 | ) if not use_spatial_transformer else SpatialTransformer( 270 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 271 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 272 | use_checkpoint=use_checkpoint 273 | ) 274 | ) 275 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 276 | self.zero_convs.append(self.make_zero_conv(ch)) 277 | self._feature_size += ch 278 | input_block_chans.append(ch) 279 | if level != len(channel_mult) - 1: 280 | out_ch = ch 281 | self.input_blocks.append( 282 | TimestepEmbedSequential( 283 | ResBlock( 284 | ch, 285 | time_embed_dim, 286 | dropout, 287 | out_channels=out_ch, 288 | dims=dims, 289 | use_checkpoint=use_checkpoint, 290 | use_scale_shift_norm=use_scale_shift_norm, 291 | down=True, 292 | ) 293 | if resblock_updown 294 | else Downsample( 295 | ch, conv_resample, dims=dims, out_channels=out_ch 296 | ) 297 | ) 298 | ) 299 | ch = out_ch 300 | input_block_chans.append(ch) 301 | self.zero_convs.append(self.make_zero_conv(ch)) 302 | ds *= 2 303 | self._feature_size += ch 304 | 305 | if num_head_channels == -1: 306 | dim_head = ch // num_heads 307 | else: 308 | num_heads = ch // num_head_channels 309 | dim_head = num_head_channels 310 | if legacy: 311 | # num_heads = 1 312 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 313 | self.middle_block = TimestepEmbedSequential( 314 | ResBlock( 315 | ch, 316 | time_embed_dim, 317 | dropout, 318 | dims=dims, 319 | use_checkpoint=use_checkpoint, 320 | use_scale_shift_norm=use_scale_shift_norm, 321 | ), 322 | AttentionBlock( 323 | ch, 324 | use_checkpoint=use_checkpoint, 325 | num_heads=num_heads, 326 | num_head_channels=dim_head, 327 | use_new_attention_order=use_new_attention_order, 328 | ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn 329 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 330 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, 331 | use_checkpoint=use_checkpoint 332 | ), 333 | ResBlock( 334 | ch, 335 | time_embed_dim, 336 | dropout, 337 | dims=dims, 338 | use_checkpoint=use_checkpoint, 339 | use_scale_shift_norm=use_scale_shift_norm, 340 | ), 341 | ) 342 | self.middle_block_out = self.make_zero_conv(ch) 343 | 
self._feature_size += ch 344 | 345 | def make_zero_conv(self, channels): 346 | return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) 347 | 348 | def forward(self, x, hint, timesteps, context, **kwargs): 349 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 350 | emb = self.time_embed(t_emb) 351 | 352 | hint_vis, hint_mscn, hint_color = hint[0], hint[1], hint[2] 353 | x = torch.cat((x, hint_vis), dim=1) 354 | guided_hint_mscn = self.struct_hint_block(hint_mscn, emb, context) 355 | guided_hint_color = self.color_hint_block(hint_color, emb, context) 356 | outs = [] 357 | 358 | h = x.type(self.dtype) 359 | for i, (module, zero_conv) in enumerate(zip(self.input_blocks, self.zero_convs)): 360 | h = module(h, emb, context) 361 | if i in [0, 3, 6, 9]: 362 | guided_hint_mscn = self.struct_hint_block2[i // 3](guided_hint_mscn) 363 | guided_hint_color = self.color_hint_block2[i // 3](guided_hint_color) 364 | h = self.cross_expo_block[i // 3](h, self.struct_hint_block_zero_convs[i // 3](guided_hint_mscn), self.color_hint_block_zero_convs[i // 3](guided_hint_color)) 365 | outs.append(zero_conv(h, emb, context)) 366 | 367 | h = self.middle_block(h, emb, context) 368 | guided_hint_mscn = self.struct_hint_block2[4](guided_hint_mscn) 369 | guided_hint_color = self.color_hint_block2[4](guided_hint_color) 370 | h = self.cross_expo_block[4](h, self.struct_hint_block_zero_convs[4](guided_hint_mscn), self.color_hint_block_zero_convs[4](guided_hint_color)) 371 | outs.append(self.middle_block_out(h, emb, context)) 372 | 373 | return outs 374 | -------------------------------------------------------------------------------- /model/V4_CA/cross_attention.py: -------------------------------------------------------------------------------- 1 | # Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | # Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | # https://arxiv.org/abs/2111.09881 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numbers 10 | from einops import rearrange 11 | 12 | 13 | class Conv2dNormRelu(nn.Module): 14 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, norm=None, activation='leaky_relu'): 15 | super().__init__() 16 | self.conv_fn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups) 17 | 18 | if norm == 'batch_norm': 19 | self.norm_fn = nn.BatchNorm2d(out_channels) 20 | elif norm == 'instance_norm': 21 | self.norm_fn = nn.InstanceNorm2d(out_channels) 22 | elif norm is None: 23 | self.norm_fn = nn.Identity() 24 | else: 25 | raise NotImplementedError('Unknown normalization function: %s' % norm) 26 | 27 | if activation == 'relu': 28 | self.relu_fn = nn.ReLU(inplace=True) 29 | elif activation == 'leaky_relu': 30 | self.relu_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True) 31 | elif activation is None: 32 | self.relu_fn = nn.Identity() 33 | else: 34 | raise NotImplementedError('Unknown activation function: %s' % activation) 35 | 36 | def forward(self, x): 37 | x = self.conv_fn(x) 38 | x = self.norm_fn(x) 39 | x = self.relu_fn(x) 40 | return x 41 | 42 | 43 | def to_3d(x): 44 | if len(x.shape) == 3: 45 | return rearrange(x, 'b c n -> b n c') 46 | else: 47 | return rearrange(x, 'b c h w -> b (h w) c') 48 | 49 | 50 | def to_4d(x, h, w=None): 51 | if w is None: 52 | return rearrange(x, 'b n c -> b c n') 53 | else: 54 | 
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 55 | 56 | 57 | class BiasFree_LayerNorm(nn.Module): 58 | def __init__(self, normalized_shape): 59 | super(BiasFree_LayerNorm, self).__init__() 60 | if isinstance(normalized_shape, numbers.Integral): 61 | normalized_shape = (normalized_shape,) 62 | normalized_shape = torch.Size(normalized_shape) 63 | 64 | assert len(normalized_shape) == 1 65 | 66 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 67 | self.normalized_shape = normalized_shape 68 | 69 | def forward(self, x): 70 | sigma = x.var(-1, keepdim=True, unbiased=False) 71 | return x / torch.sqrt(sigma+1e-5) * self.weight 72 | 73 | 74 | class WithBias_LayerNorm(nn.Module): 75 | def __init__(self, normalized_shape): 76 | super(WithBias_LayerNorm, self).__init__() 77 | if isinstance(normalized_shape, numbers.Integral): 78 | normalized_shape = (normalized_shape,) 79 | normalized_shape = torch.Size(normalized_shape) 80 | 81 | assert len(normalized_shape) == 1 82 | 83 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 84 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 85 | self.normalized_shape = normalized_shape 86 | 87 | def forward(self, x): 88 | mu = x.mean(-1, keepdim=True) 89 | sigma = x.var(-1, keepdim=True, unbiased=False) 90 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 91 | 92 | 93 | class LayerNorm(nn.Module): 94 | def __init__(self, dim, LayerNorm_type): 95 | super(LayerNorm, self).__init__() 96 | if LayerNorm_type == 'BiasFree': 97 | self.body = BiasFree_LayerNorm(dim) 98 | else: 99 | self.body = WithBias_LayerNorm(dim) 100 | 101 | def forward(self, x): 102 | if len(x.shape) == 3: 103 | h = x.shape[-1] 104 | w = None 105 | else: 106 | h, w = x.shape[-2:] 107 | 108 | return to_4d(self.body(to_3d(x)), h, w) 109 | 110 | 111 | class Mutual_Attention2D(nn.Module): 112 | def __init__(self, dim, num_heads, bias): 113 | super(Mutual_Attention2D, self).__init__() 114 | self.num_heads = num_heads 115 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 116 | 117 | self.qkv_dwconv = nn.Conv2d( 118 | dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) 119 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 120 | 121 | def forward(self, x, y): 122 | b, c, h, w = x.shape 123 | 124 | qkv = self.qkv_dwconv(torch.cat((x, y, y), dim=1)) 125 | q, k, v = qkv.chunk(3, dim=1) 126 | 127 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', 128 | head=self.num_heads) 129 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', 130 | head=self.num_heads) 131 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', 132 | head=self.num_heads) 133 | 134 | q = torch.nn.functional.normalize(q, dim=-1) 135 | k = torch.nn.functional.normalize(k, dim=-1) 136 | 137 | attn = (q @ k.transpose(-2, -1)) * self.temperature 138 | attn = attn.softmax(dim=-1) 139 | 140 | out = (attn @ v) 141 | 142 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', 143 | head=self.num_heads, h=h, w=w) 144 | 145 | out = self.project_out(out) 146 | return out 147 | 148 | 149 | class CrossTransformerBlock2D(nn.Module): 150 | def __init__(self, dim, num_heads, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias'): 151 | super(CrossTransformerBlock2D, self).__init__() 152 | 153 | self.mlps = nn.Sequential( 154 | Conv2dNormRelu(dim * 2, dim), 155 | ) 156 | self.norm1x = LayerNorm(dim, LayerNorm_type) 157 | self.norm1y = LayerNorm(dim, LayerNorm_type) 158 | self.attn = Mutual_Attention2D(dim, num_heads, bias) 159 | # self.norm2 
= LayerNorm(dim, LayerNorm_type) 160 | # self.ffn = FeedForward2D(dim, ffn_expansion_factor, bias) 161 | 162 | def forward(self, x, y, z): 163 | assert x.shape == y.shape and x.shape == z.shape, 'shape mismatch: {} {} {}'.format(x.shape, y.shape, z.shape) 164 | y2 = self.mlps(torch.cat([y, z], dim=1)) 165 | x = x + self.attn(self.norm1x(x), self.norm1y(y2)) 166 | # x = x + self.ffn(self.norm2(x)) 167 | 168 | return x -------------------------------------------------------------------------------- /model/V4_CA/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | 9 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 10 | if schedule == "linear": 11 | betas = ( 12 | np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2 13 | ) 14 | 15 | elif schedule == "cosine": 16 | timesteps = ( 17 | np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s 18 | ) 19 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 20 | alphas = np.cos(alphas) ** 2 21 | alphas = alphas / alphas[0] 22 | betas = 1 - alphas[1:] / alphas[:-1] 23 | betas = np.clip(betas, a_min=0, a_max=0.999) 24 | 25 | elif schedule == "sqrt_linear": 26 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) 27 | elif schedule == "sqrt": 28 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5 29 | else: 30 | raise ValueError(f"schedule '{schedule}' unknown.") 31 | return betas 32 | 33 | 34 | def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor: 35 | b, *_ = t.shape 36 | out = a.gather(-1, t) 37 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 38 | 39 | 40 | class Diffusion(nn.Module): 41 | 42 | def __init__( 43 | self, 44 | timesteps=1000, 45 | beta_schedule="linear", 46 | loss_type="l2", 47 | linear_start=1e-4, 48 | linear_end=2e-2, 49 | cosine_s=8e-3, 50 | parameterization="eps" 51 | ): 52 | super().__init__() 53 | self.num_timesteps = timesteps 54 | self.beta_schedule = beta_schedule 55 | self.linear_start = linear_start 56 | self.linear_end = linear_end 57 | self.cosine_s = cosine_s 58 | assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps', 'x0' and 'v'" 59 | self.parameterization = parameterization 60 | self.loss_type = loss_type 61 | 62 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 63 | cosine_s=cosine_s) 64 | alphas = 1. - betas 65 | alphas_cumprod = np.cumprod(alphas, axis=0) 66 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 67 | sqrt_one_minus_alphas_cumprod = np.sqrt(1. 
- alphas_cumprod) 68 | 69 | self.betas = betas 70 | self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod) 71 | self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod) 72 | 73 | def register(self, name: str, value: np.ndarray) -> None: 74 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) 75 | 76 | def q_sample(self, x_start, t, noise): 77 | return ( 78 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 79 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 80 | ) 81 | 82 | def get_v(self, x, noise, t): 83 | return ( 84 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - 85 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x 86 | ) 87 | 88 | def get_loss(self, pred, target, mean=True): 89 | if self.loss_type == 'l1': 90 | loss = (target - pred).abs() 91 | if mean: 92 | loss = loss.mean() 93 | elif self.loss_type == 'l2': 94 | if mean: 95 | loss = torch.nn.functional.mse_loss(target, pred) 96 | else: 97 | loss = torch.nn.functional.mse_loss(target, pred, reduction='none') 98 | else: 99 | raise NotImplementedError(f"unknown loss type '{self.loss_type}'") 100 | 101 | return loss 102 | 103 | def p_losses(self, model, x_start, t, cond): 104 | noise = torch.randn_like(x_start) 105 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 106 | model_output = model(x_noisy, t, cond) 107 | 108 | if self.parameterization == "x0": 109 | target = x_start 110 | elif self.parameterization == "eps": 111 | target = noise 112 | elif self.parameterization == "v": 113 | target = self.get_v(x_start, noise, t) 114 | else: 115 | raise NotImplementedError() 116 | 117 | loss_simple = self.get_loss(model_output, target, mean=False).mean() 118 | return loss_simple 119 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # from . 
import config 2 | 3 | # from .controlnet import ControlledUnetModel, ControlNet 4 | # from .vae import AutoencoderKL 5 | # from .clip import FrozenOpenCLIPEmbedder 6 | 7 | # from .cldm import ControlLDM 8 | # from .cldm_mc import MultiControlLDM 9 | # from .gaussian_diffusion import Diffusion 10 | 11 | # from .swinir import SwinIR 12 | # from .bsrnet import RRDBNet 13 | # from .scunet import SCUNet 14 | -------------------------------------------------------------------------------- /model/clip.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.checkpoint import checkpoint 5 | from model.open_clip import CLIP, tokenize 6 | 7 | ### pretrained model path 8 | # _VITH14 = dict( 9 | # laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 10 | # ) 11 | 12 | class FrozenOpenCLIPEmbedder(nn.Module): 13 | """ 14 | Uses the OpenCLIP transformer encoder for text 15 | """ 16 | LAYERS = [ 17 | #"pooled", 18 | "last", 19 | "penultimate" 20 | ] 21 | def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"): 22 | super().__init__() 23 | assert layer in self.LAYERS 24 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 25 | model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg)) 26 | del model.visual 27 | self.model = model 28 | 29 | self.layer = layer 30 | if self.layer == "last": 31 | self.layer_idx = 0 32 | elif self.layer == "penultimate": 33 | self.layer_idx = 1 34 | else: 35 | raise NotImplementedError() 36 | 37 | def forward(self, tokens): 38 | z = self.encode_with_transformer(tokens) 39 | return z 40 | 41 | def encode_with_transformer(self, text): 42 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 43 | x = x + self.model.positional_embedding 44 | x = x.permute(1, 0, 2) # NLD -> LND 45 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 46 | x = x.permute(1, 0, 2) # LND -> NLD 47 | x = self.model.ln_final(x) 48 | return x 49 | 50 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 51 | for i, r in enumerate(self.model.transformer.resblocks): 52 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 53 | break 54 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 55 | x = checkpoint(r, x, attn_mask) 56 | else: 57 | x = r(x, attn_mask=attn_mask) 58 | return x 59 | 60 | def encode(self, text: List[str]) -> torch.Tensor: 61 | # convert a batch of text to tensor 62 | tokens = tokenize(text) 63 | # move tensor to model device 64 | tokens = tokens.to(next(self.model.parameters()).device) 65 | return self(tokens) 66 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Literal 3 | from types import ModuleType 4 | import enum 5 | from packaging import version 6 | 7 | import torch 8 | 9 | # collect system information 10 | if version.parse(torch.__version__) >= version.parse("2.0.0"): 11 | SDP_IS_AVAILABLE = True 12 | else: 13 | SDP_IS_AVAILABLE = False 14 | 15 | try: 16 | import xformers 17 | import xformers.ops 18 | XFORMERS_IS_AVAILBLE = True 19 | except: 20 | XFORMERS_IS_AVAILBLE = False 21 | 22 | 23 | class AttnMode(enum.Enum): 24 | SDP = 0 25 | XFORMERS = 1 26 | VANILLA = 2 27 | 28 | 29 | class Config: 30 | xformers: 
Optional[ModuleType] = None 31 | attn_mode: AttnMode = AttnMode.VANILLA 32 | 33 | 34 | # initialize attention mode 35 | if SDP_IS_AVAILABLE: 36 | Config.attn_mode = AttnMode.SDP 37 | print(f"use sdp attention as default") 38 | elif XFORMERS_IS_AVAILBLE: 39 | Config.attn_mode = AttnMode.XFORMERS 40 | print(f"use xformers attention as default") 41 | else: 42 | print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default") 43 | 44 | if XFORMERS_IS_AVAILBLE: 45 | Config.xformers = xformers 46 | 47 | 48 | # user-specified attention mode 49 | ATTN_MODE = os.environ.get("ATTN_MODE", None) 50 | if ATTN_MODE is not None: 51 | assert ATTN_MODE in ["vanilla", "sdp", "xformers"] 52 | if ATTN_MODE == "sdp": 53 | assert SDP_IS_AVAILABLE 54 | Config.attn_mode = AttnMode.SDP 55 | elif ATTN_MODE == "xformers": 56 | assert XFORMERS_IS_AVAILBLE 57 | Config.attn_mode = AttnMode.XFORMERS 58 | else: 59 | Config.attn_mode = AttnMode.VANILLA 60 | print(f"set attention mode to {ATTN_MODE}") 61 | else: 62 | print("keep default attention mode") 63 | -------------------------------------------------------------------------------- /model/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 
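    Concretely, for p = N(mean1, exp(logvar1)) and q = N(mean2, exp(logvar2)) the
    returned value is the closed-form Gaussian KL
        0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
               + (mean1 - mean2) ** 2 * exp(-logvar2) - 1),
    which matches the expression computed below.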
71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /model/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CLIP 2 | from .tokenizer import tokenize 3 | 4 | __all__ = ["CLIP", "tokenize"] 5 | -------------------------------------------------------------------------------- /model/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/UltraFusion/b0787c59217fe4877c233d938810a23c8d2c5c73/model/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/open_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer 14 | 15 | 16 | @dataclass 17 | class CLIPVisionCfg: 18 | layers: Union[Tuple[int, int, int, int], int] = 12 19 | width: int = 768 20 | head_width: int = 64 21 | mlp_ratio: float = 4.0 22 | patch_size: int = 16 23 | image_size: Union[Tuple[int, int], int] = 224 24 | 25 | ls_init_value: Optional[float] = None # layer scale initial value 26 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 27 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 28 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 29 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 30 | n_queries: int = 256 # n_queries for attentional pooler 31 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 32 | output_tokens: bool = False 33 | 34 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 35 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 36 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 37 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 38 | timm_proj_bias: bool = False # enable bias final projection 39 | timm_drop: float = 0. 
# head dropout 40 | timm_drop_path: Optional[float] = None # backbone stochastic depth 41 | 42 | 43 | @dataclass 44 | class CLIPTextCfg: 45 | context_length: int = 77 46 | vocab_size: int = 49408 47 | width: int = 512 48 | heads: int = 8 49 | layers: int = 12 50 | ls_init_value: Optional[float] = None # layer scale initial value 51 | hf_model_name: str = None 52 | hf_tokenizer_name: str = None 53 | hf_model_pretrained: bool = True 54 | proj: str = 'mlp' 55 | pooler_type: str = 'mean_pooler' 56 | embed_cls: bool = False 57 | pad_id: int = 0 58 | output_tokens: bool = False 59 | 60 | 61 | def get_cast_dtype(precision: str): 62 | cast_dtype = None 63 | if precision == 'bf16': 64 | cast_dtype = torch.bfloat16 65 | elif precision == 'fp16': 66 | cast_dtype = torch.float16 67 | return cast_dtype 68 | 69 | 70 | def _build_vision_tower( 71 | embed_dim: int, 72 | vision_cfg: CLIPVisionCfg, 73 | quick_gelu: bool = False, 74 | cast_dtype: Optional[torch.dtype] = None 75 | ): 76 | if isinstance(vision_cfg, dict): 77 | vision_cfg = CLIPVisionCfg(**vision_cfg) 78 | 79 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 80 | # memory efficient in recent PyTorch releases (>= 1.10). 81 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 82 | act_layer = QuickGELU if quick_gelu else nn.GELU 83 | 84 | vision_heads = vision_cfg.width // vision_cfg.head_width 85 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 86 | visual = VisionTransformer( 87 | image_size=vision_cfg.image_size, 88 | patch_size=vision_cfg.patch_size, 89 | width=vision_cfg.width, 90 | layers=vision_cfg.layers, 91 | heads=vision_heads, 92 | mlp_ratio=vision_cfg.mlp_ratio, 93 | ls_init_value=vision_cfg.ls_init_value, 94 | patch_dropout=vision_cfg.patch_dropout, 95 | input_patchnorm=vision_cfg.input_patchnorm, 96 | global_average_pool=vision_cfg.global_average_pool, 97 | attentional_pool=vision_cfg.attentional_pool, 98 | n_queries=vision_cfg.n_queries, 99 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 100 | output_tokens=vision_cfg.output_tokens, 101 | output_dim=embed_dim, 102 | act_layer=act_layer, 103 | norm_layer=norm_layer, 104 | ) 105 | 106 | return visual 107 | 108 | 109 | def _build_text_tower( 110 | embed_dim: int, 111 | text_cfg: CLIPTextCfg, 112 | quick_gelu: bool = False, 113 | cast_dtype: Optional[torch.dtype] = None, 114 | ): 115 | if isinstance(text_cfg, dict): 116 | text_cfg = CLIPTextCfg(**text_cfg) 117 | 118 | act_layer = QuickGELU if quick_gelu else nn.GELU 119 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 120 | 121 | text = TextTransformer( 122 | context_length=text_cfg.context_length, 123 | vocab_size=text_cfg.vocab_size, 124 | width=text_cfg.width, 125 | heads=text_cfg.heads, 126 | layers=text_cfg.layers, 127 | ls_init_value=text_cfg.ls_init_value, 128 | output_dim=embed_dim, 129 | embed_cls=text_cfg.embed_cls, 130 | output_tokens=text_cfg.output_tokens, 131 | pad_id=text_cfg.pad_id, 132 | act_layer=act_layer, 133 | norm_layer=norm_layer, 134 | ) 135 | return text 136 | 137 | 138 | class CLIP(nn.Module): 139 | output_dict: torch.jit.Final[bool] 140 | 141 | def __init__( 142 | self, 143 | embed_dim: int, 144 | vision_cfg: CLIPVisionCfg, 145 | text_cfg: CLIPTextCfg, 146 | quick_gelu: bool = False, 147 | cast_dtype: Optional[torch.dtype] = None, 148 | output_dict: bool = False, 149 | ): 150 | super().__init__() 151 | self.output_dict = output_dict 152 | self.visual = 
_build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 153 | 154 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 155 | self.transformer = text.transformer 156 | self.context_length = text.context_length 157 | self.vocab_size = text.vocab_size 158 | self.token_embedding = text.token_embedding 159 | self.positional_embedding = text.positional_embedding 160 | self.ln_final = text.ln_final 161 | self.text_projection = text.text_projection 162 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 163 | 164 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 165 | 166 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 167 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 168 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 169 | 170 | @torch.jit.ignore 171 | def set_grad_checkpointing(self, enable=True): 172 | self.visual.set_grad_checkpointing(enable) 173 | self.transformer.grad_checkpointing = enable 174 | 175 | def encode_image(self, image, normalize: bool = False): 176 | features = self.visual(image) 177 | return F.normalize(features, dim=-1) if normalize else features 178 | 179 | def encode_text(self, text, normalize: bool = False): 180 | cast_dtype = self.transformer.get_cast_dtype() 181 | 182 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 183 | 184 | x = x + self.positional_embedding.to(cast_dtype) 185 | x = x.permute(1, 0, 2) # NLD -> LND 186 | x = self.transformer(x, attn_mask=self.attn_mask) 187 | x = x.permute(1, 0, 2) # LND -> NLD 188 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 189 | # take features from the eot embedding (eot_token is the highest number in each sequence) 190 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 191 | return F.normalize(x, dim=-1) if normalize else x 192 | 193 | def forward( 194 | self, 195 | image: Optional[torch.Tensor] = None, 196 | text: Optional[torch.Tensor] = None, 197 | ): 198 | image_features = self.encode_image(image, normalize=True) if image is not None else None 199 | text_features = self.encode_text(text, normalize=True) if text is not None else None 200 | if self.output_dict: 201 | return { 202 | "image_features": image_features, 203 | "text_features": text_features, 204 | "logit_scale": self.logit_scale.exp() 205 | } 206 | return image_features, text_features, self.logit_scale.exp() 207 | -------------------------------------------------------------------------------- /model/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 
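    # encode() below applies the regex pre-tokenizer (self.pat), maps each token's
    # UTF-8 bytes to printable characters via self.byte_encoder, merges the result
    # with bpe(), and looks the merged subwords up in self.encoder; decode()
    # inverts these steps and turns the end-of-word marker '</w>' back into a space.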
139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all CLIP models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder["<start_of_text>"] 178 | eot_token = _tokenizer.encoder["<end_of_text>"] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str): 195 | from transformers import AutoTokenizer 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def save_pretrained(self, dest): 199 | self.tokenizer.save_pretrained(dest) 200 | 201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 202 | # same cleaning as for default tokenizer, except lowercasing 203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 207 | input_ids = self.tokenizer( 208 | texts, 209 | return_tensors='pt', 210 | max_length=context_length, 211 | padding='max_length', 212 | truncation=True, 213 | ).input_ids 214 | return input_ids 215 | -------------------------------------------------------------------------------- /model/raft/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | 
corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /model/raft/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = 
frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class 
FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': 
args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /model/raft/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = 
nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, 
self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /model/raft/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | def __enter__(self): 20 | pass 21 | def __exit__(self, *args): 22 | pass 23 | 24 | 25 | class RAFT(nn.Module): 26 | def __init__(self, args): 27 | super(RAFT, self).__init__() 28 | self.args = args 29 | 30 | self.hidden_dim = hdim = 128 31 | self.context_dim = cdim = 128 32 | args.corr_levels = 4 33 | args.corr_radius = 4 34 | 35 | if 'dropout' not in self.args: 36 | self.args.dropout = 0 37 | 38 | if 'alternate_corr' not in self.args: 39 | self.args.alternate_corr = False 40 | 41 | # feature network, context network, and update block 42 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 43 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 44 
| self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 45 | 46 | def freeze_bn(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.BatchNorm2d): 49 | m.eval() 50 | 51 | def initialize_flow(self, img): 52 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 53 | N, C, H, W = img.shape 54 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 55 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 56 | 57 | # optical flow computed as difference: flow = coords1 - coords0 58 | return coords0, coords1 59 | 60 | def upsample_flow(self, flow, mask): 61 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 62 | N, _, H, W = flow.shape 63 | mask = mask.view(N, 1, 9, 8, 8, H, W) 64 | mask = torch.softmax(mask, dim=2) 65 | 66 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 67 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 68 | 69 | up_flow = torch.sum(mask * up_flow, dim=2) 70 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 71 | return up_flow.reshape(N, 2, 8*H, 8*W) 72 | 73 | 74 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 75 | """ Estimate optical flow between pair of frames """ 76 | 77 | # image1 = 2 * image1 - 1.0 78 | # image2 = 2 * image2 - 1.0 79 | 80 | image1 = image1.contiguous() 81 | image2 = image2.contiguous() 82 | 83 | hdim = self.hidden_dim 84 | cdim = self.context_dim 85 | 86 | # run the feature network 87 | with autocast(enabled=False): 88 | fmap1, fmap2 = self.fnet([image1, image2]) 89 | 90 | fmap1 = fmap1.float() 91 | fmap2 = fmap2.float() 92 | if self.args.alternate_corr: 93 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 94 | else: 95 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 96 | 97 | # run the context network 98 | with autocast(enabled=False): 99 | cnet = self.cnet(image1) 100 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 101 | net = torch.tanh(net) 102 | inp = torch.relu(inp) 103 | 104 | coords0, coords1 = self.initialize_flow(image1) 105 | 106 | if flow_init is not None: 107 | coords1 = coords1 + flow_init 108 | 109 | flow_predictions = [] 110 | for itr in range(iters): 111 | coords1 = coords1.detach() 112 | corr = corr_fn(coords1) # index correlation volume 113 | 114 | flow = coords1 - coords0 115 | with autocast(enabled=False): 116 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 117 | 118 | # F(t+1) = F(t) + \Delta(t) 119 | coords1 = coords1 + delta_flow 120 | 121 | # upsample predictions 122 | if up_mask is None: 123 | flow_up = upflow8(coords1 - coords0) 124 | else: 125 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 126 | 127 | flow_predictions.append(flow_up) 128 | 129 | if test_mode: 130 | return coords1 - coords0, flow_up 131 | 132 | return flow_predictions 133 | -------------------------------------------------------------------------------- /model/raft/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def 
__init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 
| motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /model/raft/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | 
(self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = 
np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /model/raft/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the 
rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /model/raft/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
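Example (illustrative): writeFlow('pred.flo', flow), with flow an (H, W, 2) float32 array, writes the Middlebury .flo layout used below: the TAG_CHAR magic number, int32 width and height, then row-major interleaved u/v values.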
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /model/raft/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, 
(x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /model/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | from inspect import isfunction 14 | import torch 15 | import torch.nn as nn 16 | import numpy as np 17 | from einops import repeat 18 | 19 | 20 | def exists(val): 21 | return val is not None 22 | 23 | 24 | def default(val, d): 25 | if exists(val): 26 | return val 27 | return d() if isfunction(d) else d 28 | 29 | 30 | def checkpoint(func, inputs, params, flag): 31 | """ 32 | Evaluate a function without caching intermediate activations, allowing for 33 | reduced memory at the expense of extra compute in the backward pass. 34 | :param func: the function to evaluate. 35 | :param inputs: the argument sequence to pass to `func`. 36 | :param params: a sequence of parameters `func` depends on but does not 37 | explicitly take as arguments. 38 | :param flag: if False, disable gradient checkpointing. 
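Example (illustrative, mirroring common Stable Diffusion usage): out = checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint), where self._forward and self.use_checkpoint are assumed attributes of the calling module.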
39 | """ 40 | if flag: 41 | args = tuple(inputs) + tuple(params) 42 | return CheckpointFunction.apply(func, len(inputs), *args) 43 | else: 44 | return func(*inputs) 45 | 46 | 47 | # class CheckpointFunction(torch.autograd.Function): 48 | # @staticmethod 49 | # def forward(ctx, run_function, length, *args): 50 | # ctx.run_function = run_function 51 | # ctx.input_tensors = list(args[:length]) 52 | # ctx.input_params = list(args[length:]) 53 | # ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 54 | # "dtype": torch.get_autocast_gpu_dtype(), 55 | # "cache_enabled": torch.is_autocast_cache_enabled()} 56 | # with torch.no_grad(): 57 | # output_tensors = ctx.run_function(*ctx.input_tensors) 58 | # return output_tensors 59 | 60 | # @staticmethod 61 | # def backward(ctx, *output_grads): 62 | # ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 63 | # with torch.enable_grad(), \ 64 | # torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 65 | # # Fixes a bug where the first op in run_function modifies the 66 | # # Tensor storage in place, which is not allowed for detach()'d 67 | # # Tensors. 68 | # shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 69 | # output_tensors = ctx.run_function(*shallow_copies) 70 | # input_grads = torch.autograd.grad( 71 | # output_tensors, 72 | # ctx.input_tensors + ctx.input_params, 73 | # output_grads, 74 | # allow_unused=True, 75 | # ) 76 | # del ctx.input_tensors 77 | # del ctx.input_params 78 | # del output_tensors 79 | # return (None, None) + input_grads 80 | 81 | 82 | # Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction 83 | # still tries to compute gradient for unet parameters. 84 | # https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6 85 | class CheckpointFunction(torch.autograd.Function): 86 | @staticmethod 87 | def forward(ctx, run_function, length, *args): 88 | ctx.run_function = run_function 89 | ctx.input_tensors = list(args[:length]) 90 | ctx.input_params = list(args[length:]) 91 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 92 | "dtype": torch.get_autocast_gpu_dtype(), 93 | "cache_enabled": torch.is_autocast_cache_enabled()} 94 | with torch.no_grad(): 95 | output_tensors = ctx.run_function(*ctx.input_tensors) 96 | return output_tensors 97 | 98 | @staticmethod 99 | def backward(ctx, *output_grads): 100 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 101 | with torch.enable_grad(), \ 102 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 103 | # Fixes a bug where the first op in run_function modifies the 104 | # Tensor storage in place, which is not allowed for detach()'d 105 | # Tensors. 
106 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 107 | output_tensors = ctx.run_function(*shallow_copies) 108 | grads = torch.autograd.grad( 109 | output_tensors, 110 | ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad], 111 | output_grads, 112 | allow_unused=True, 113 | ) 114 | grads = list(grads) 115 | # Assign gradients to the correct positions, matching None for those that do not require gradients 116 | input_grads = [] 117 | for tensor in ctx.input_tensors + ctx.input_params: 118 | if tensor.requires_grad: 119 | input_grads.append(grads.pop(0)) # Get the next computed gradient 120 | else: 121 | input_grads.append(None) # No gradient required for this tensor 122 | del ctx.input_tensors 123 | del ctx.input_params 124 | del output_tensors 125 | return (None, None) + tuple(input_grads) 126 | 127 | 128 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 129 | """ 130 | Create sinusoidal timestep embeddings. 131 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 132 | These may be fractional. 133 | :param dim: the dimension of the output. 134 | :param max_period: controls the minimum frequency of the embeddings. 135 | :return: an [N x dim] Tensor of positional embeddings. 136 | """ 137 | if not repeat_only: 138 | half = dim // 2 139 | freqs = torch.exp( 140 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 141 | ).to(device=timesteps.device) 142 | args = timesteps[:, None].float() * freqs[None] 143 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 144 | if dim % 2: 145 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 146 | else: 147 | embedding = repeat(timesteps, 'b -> b d', d=dim) 148 | return embedding 149 | 150 | 151 | def zero_module(module): 152 | """ 153 | Zero out the parameters of a module and return it. 154 | """ 155 | for p in module.parameters(): 156 | p.detach().zero_() 157 | return module 158 | 159 | 160 | def scale_module(module, scale): 161 | """ 162 | Scale the parameters of a module and return it. 163 | """ 164 | for p in module.parameters(): 165 | p.detach().mul_(scale) 166 | return module 167 | 168 | 169 | def mean_flat(tensor): 170 | """ 171 | Take the mean over all non-batch dimensions. 172 | """ 173 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 174 | 175 | 176 | def normalization(channels): 177 | """ 178 | Make a standard normalization layer. 179 | :param channels: number of input channels. 180 | :return: an nn.Module for normalization. 181 | """ 182 | return GroupNorm32(32, channels) 183 | 184 | 185 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 186 | class SiLU(nn.Module): 187 | def forward(self, x): 188 | return x * torch.sigmoid(x) 189 | 190 | 191 | class GroupNorm32(nn.GroupNorm): 192 | def forward(self, x): 193 | return super().forward(x.float()).type(x.dtype) 194 | 195 | def conv_nd(dims, *args, **kwargs): 196 | """ 197 | Create a 1D, 2D, or 3D convolution module. 198 | """ 199 | if dims == 1: 200 | return nn.Conv1d(*args, **kwargs) 201 | elif dims == 2: 202 | return nn.Conv2d(*args, **kwargs) 203 | elif dims == 3: 204 | return nn.Conv3d(*args, **kwargs) 205 | raise ValueError(f"unsupported dimensions: {dims}") 206 | 207 | 208 | def linear(*args, **kwargs): 209 | """ 210 | Create a linear module. 211 | """ 212 | return nn.Linear(*args, **kwargs) 213 | 214 | 215 | def avg_pool_nd(dims, *args, **kwargs): 216 | """ 217 | Create a 1D, 2D, or 3D average pooling module. 
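For example, avg_pool_nd(2, kernel_size=2, stride=2) is equivalent to nn.AvgPool2d(kernel_size=2, stride=2).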
218 | """ 219 | if dims == 1: 220 | return nn.AvgPool1d(*args, **kwargs) 221 | elif dims == 2: 222 | return nn.AvgPool2d(*args, **kwargs) 223 | elif dims == 3: 224 | return nn.AvgPool3d(*args, **kwargs) 225 | raise ValueError(f"unsupported dimensions: {dims}") 226 | -------------------------------------------------------------------------------- /pipeline/V4_CA/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import overload, Tuple, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from PIL import Image 8 | from einops import rearrange 9 | 10 | from model.V4_CA.cldm import ControlLDM 11 | from model.V4_CA.gaussian_diffusion import Diffusion 12 | from utils.V4_CA.sampler import SpacedSampler 13 | from utils.cond_fn import Guidance 14 | from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage 15 | 16 | from torchvision.utils import save_image 17 | 18 | 19 | def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray: 20 | pil = Image.fromarray(img) 21 | res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC) 22 | return np.array(res) 23 | 24 | 25 | def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor: 26 | _, _, h, w = imgs.size() 27 | if h == w: 28 | new_h, new_w = size, size 29 | elif h < w: 30 | new_h, new_w = size, int(w * (size / h)) 31 | else: 32 | new_h, new_w = int(h * (size / w)), size 33 | return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True) 34 | 35 | 36 | def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor: 37 | _, _, h, w = imgs.size() 38 | if h % multiple == 0 and w % multiple == 0: 39 | return imgs.clone() 40 | # get_pad = lambda x: (x // multiple + 1) * multiple - x 41 | get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x 42 | ph, pw = get_pad(h), get_pad(w) 43 | return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0) 44 | 45 | 46 | class UltraFusionPipeline: 47 | 48 | def __init__(self, cldm: ControlLDM, diffusion: Diffusion, fidelity_encoder, device: str) -> None: 49 | self.cldm = cldm 50 | self.diffusion = diffusion 51 | self.fidelity_encoder = fidelity_encoder 52 | self.device = device 53 | self.final_size: Tuple[int] = None 54 | 55 | def set_final_size(self, lq: torch.Tensor) -> None: 56 | h, w = lq.shape[2:] 57 | self.final_size = (h, w) 58 | 59 | @count_vram_usage 60 | def run( 61 | self, 62 | lq2, 63 | lq1_mscn_norm, 64 | lq1_color, 65 | steps: int = 50, 66 | strength: float = 1.0, 67 | tiled: bool = False, 68 | tile_size: int = 512, 69 | tile_stride: int = 256, 70 | pos_prompt: str = "", 71 | neg_prompt: str = "low quality, blurry, low-resolution, noisy, unsharp, weird textures", 72 | cfg_scale: float = "4.0", 73 | cond_fn: Guidance = None, 74 | fidelity_input: torch.Tensor = None, 75 | consistent_start: torch.Tensor = None 76 | ) -> torch.Tensor: 77 | ### preprocess 78 | lq1_mscn_norm, lq1_color, lq2 = lq1_mscn_norm.cuda(), lq1_color.cuda(), lq2.cuda() 79 | bs, _, H, W = lq2.shape 80 | if not tiled: 81 | assert H == 512 and W == 512, "The image shape must be equal to 512x512" 82 | 83 | # prepare conditon 84 | lq2 = lq2 * 2 - 1 #[-1, 1] 85 | if not tiled: 86 | cond = self.cldm.prepare_condition(lq2, lq1_mscn_norm, lq1_color, pos_prompt) 87 | else: 88 | cond, skip_feats = self.cldm.prepare_condition_tiled(lq2, lq1_mscn_norm, lq1_color, pos_prompt, tile_size=tile_size, 
tile_stride=tile_stride, fidelity_encoder=self.fidelity_encoder, fidelity_input=fidelity_input) 89 | uncond = None 90 | old_control_scales = self.cldm.control_scales 91 | self.cldm.control_scales = [strength] * 13 92 | x_T = torch.randn((bs, 4, H // 8, W // 8), dtype=torch.float32, device=self.device) 93 | # lq_latent = self.cldm.vae_encode(lq) 94 | # noise = torch.randn(lq_latent.shape, dtype=torch.float32, device=self.device) 95 | # x_T = self.diffusion.q_sample(x_start=lq_latent, t=torch.tensor([999], device=self.device), noise=noise) 96 | ### run sampler 97 | sampler = SpacedSampler(self.diffusion.betas) 98 | z = sampler.sample( 99 | model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, H // 8, W // 8), 100 | cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True, 101 | progress_leave=True, cond_fn=cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 102 | ) 103 | if not tiled: 104 | sample = self.cldm.vae_decode(z) 105 | else: 106 | sample = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8, skip_feats, consistent_start) 107 | 108 | ### postprocess 109 | self.cldm.control_scales = old_control_scales 110 | return sample # [-1 , 1] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | einops==0.8.1 3 | ftfy==6.3.1 4 | numpy==1.24.4 5 | omegaconf==2.3.0 6 | opencv_python==4.9.0.80 7 | opencv_python_headless==4.10.0.84 8 | packaging==25.0 9 | Pillow==11.2.1 10 | pyiqa==0.1.13 11 | regex==2024.5.15 12 | scipy==1.15.2 13 | torch==2.3.0 14 | torchvision==0.18.0 15 | tqdm==4.66.4 16 | transformers==4.37.2 17 | -------------------------------------------------------------------------------- /utils/V4_CA/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Dict 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from model.V4_CA.gaussian_diffusion import extract_into_tensor 9 | from model.V4_CA.cldm import ControlLDM 10 | from utils.cond_fn import Guidance 11 | from utils.common import sliding_windows, gaussian_weights 12 | 13 | 14 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 15 | def space_timesteps(num_timesteps, section_counts): 16 | """ 17 | Create a list of timesteps to use from an original diffusion process, 18 | given the number of timesteps we want to take from equally-sized portions 19 | of the original process. 20 | For example, if there's 300 timesteps and the section counts are [10,15,20] 21 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 22 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 23 | If the stride is a string starting with "ddim", then the fixed striding 24 | from the DDIM paper is used, and only one section is allowed. 25 | :param num_timesteps: the number of diffusion steps in the original 26 | process to divide up. 27 | :param section_counts: either a list of numbers, or a string containing 28 | comma-separated numbers, indicating the step count 29 | per section. As a special case, use "ddimN" where N 30 | is a number of steps to use the striding from the 31 | DDIM paper. 32 | :return: a set of diffusion steps from the original process to use. 
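For instance (illustrative), space_timesteps(1000, "ddim50") returns the 50 evenly strided steps {0, 20, 40, ..., 980}, while space_timesteps(1000, "50") spreads 50 steps over the whole chain with a fractional stride of (1000 - 1) / (50 - 1) ≈ 20.4.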
33 | """ 34 | if isinstance(section_counts, str): 35 | if section_counts.startswith("ddim"): 36 | desired_count = int(section_counts[len("ddim") :]) 37 | for i in range(1, num_timesteps): 38 | if len(range(0, num_timesteps, i)) == desired_count: 39 | return set(range(0, num_timesteps, i)) 40 | raise ValueError( 41 | f"cannot create exactly {num_timesteps} steps with an integer stride" 42 | ) 43 | section_counts = [int(x) for x in section_counts.split(",")] 44 | size_per = num_timesteps // len(section_counts) 45 | extra = num_timesteps % len(section_counts) 46 | start_idx = 0 47 | all_steps = [] 48 | for i, section_count in enumerate(section_counts): 49 | size = size_per + (1 if i < extra else 0) 50 | if size < section_count: 51 | raise ValueError( 52 | f"cannot divide section of {size} steps into {section_count}" 53 | ) 54 | if section_count <= 1: 55 | frac_stride = 1 56 | else: 57 | frac_stride = (size - 1) / (section_count - 1) 58 | cur_idx = 0.0 59 | taken_steps = [] 60 | for _ in range(section_count): 61 | taken_steps.append(start_idx + round(cur_idx)) 62 | cur_idx += frac_stride 63 | all_steps += taken_steps 64 | start_idx += size 65 | return set(all_steps) 66 | 67 | 68 | class SpacedSampler(nn.Module): 69 | """ 70 | Implementation for spaced sampling schedule proposed in IDDPM. This class is designed 71 | for sampling ControlLDM. 72 | 73 | https://arxiv.org/pdf/2102.09672.pdf 74 | """ 75 | 76 | def __init__(self, betas: np.ndarray) -> "SpacedSampler": 77 | super().__init__() 78 | self.num_timesteps = len(betas) 79 | self.original_betas = betas 80 | self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0) 81 | self.context = {} 82 | 83 | def register(self, name: str, value: np.ndarray) -> None: 84 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) 85 | 86 | def make_schedule(self, num_steps: int) -> None: 87 | # calcualte betas for spaced sampling 88 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 89 | used_timesteps = space_timesteps(self.num_timesteps, str(num_steps)) 90 | betas = [] 91 | last_alpha_cumprod = 1.0 92 | for i, alpha_cumprod in enumerate(self.original_alphas_cumprod): 93 | if i in used_timesteps: 94 | # marginal distribution is the same as q(x_{S_t}|x_0) 95 | betas.append(1 - alpha_cumprod / last_alpha_cumprod) 96 | last_alpha_cumprod = alpha_cumprod 97 | assert len(betas) == num_steps 98 | self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...] 99 | 100 | betas = np.array(betas, dtype=np.float64) 101 | alphas = 1.0 - betas 102 | alphas_cumprod = np.cumprod(alphas, axis=0) 103 | # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}") 104 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 105 | sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) 106 | sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) 107 | # calculations for posterior q(x_{t-1} | x_t, x_0) 108 | posterior_variance = ( 109 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 110 | ) 111 | # log calculation clipped because the posterior variance is 0 at the 112 | # beginning of the diffusion chain. 
113 | posterior_log_variance_clipped = np.log( 114 | np.append(posterior_variance[1], posterior_variance[1:]) 115 | ) 116 | posterior_mean_coef1 = ( 117 | betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) 118 | ) 119 | posterior_mean_coef2 = ( 120 | (1.0 - alphas_cumprod_prev) 121 | * np.sqrt(alphas) 122 | / (1.0 - alphas_cumprod) 123 | ) 124 | 125 | self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod) 126 | self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod) 127 | self.register("posterior_variance", posterior_variance) 128 | self.register("posterior_log_variance_clipped", posterior_log_variance_clipped) 129 | self.register("posterior_mean_coef1", posterior_mean_coef1) 130 | self.register("posterior_mean_coef2", posterior_mean_coef2) 131 | 132 | def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor]: 133 | """ 134 | Implement the posterior distribution q(x_{t-1}|x_t, x_0). 135 | 136 | Args: 137 | x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`. 138 | x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`. 139 | t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get 140 | parameters for each timestep. 141 | 142 | Returns: 143 | posterior_mean (torch.Tensor): Mean of the posterior distribution. 144 | posterior_variance (torch.Tensor): Variance of the posterior distribution. 145 | posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution. 146 | """ 147 | posterior_mean = ( 148 | extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 149 | + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 150 | ) 151 | posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) 152 | posterior_log_variance_clipped = extract_into_tensor( 153 | self.posterior_log_variance_clipped, t, x_t.shape 154 | ) 155 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 156 | 157 | def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor: 158 | return ( 159 | extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 160 | - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 161 | ) 162 | 163 | def apply_cond_fn( 164 | self, 165 | model: ControlLDM, 166 | pred_x0: torch.Tensor, 167 | t: torch.Tensor, 168 | index: torch.Tensor, 169 | cond_fn: Guidance 170 | ) -> torch.Tensor: 171 | t_now = int(t[0].item()) + 1 172 | if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start): 173 | # stop guidance 174 | self.context["g_apply"] = False 175 | return pred_x0 176 | grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape) 177 | # apply guidance for multiple times 178 | loss_vals = [] 179 | for _ in range(cond_fn.repeat): 180 | # set target and pred for gradient computation 181 | target, pred = None, None 182 | if cond_fn.space == "latent": 183 | target = model.vae_encode(cond_fn.target) 184 | pred = pred_x0 185 | elif cond_fn.space == "rgb": 186 | # We need to backward gradient to x0 in latent space, so it's required 187 | # to trace the computation graph while decoding the latent. 
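# In effect this is a vector-Jacobian product: pred.backward(delta_pred) below pulls
# the RGB-space gradient back through the VAE decoder, and the latent-space update is
# then read out of pred_x0_rg.grad.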
188 | with torch.enable_grad(): 189 | target = cond_fn.target 190 | pred_x0_rg = pred_x0.detach().clone().requires_grad_(True) 191 | pred = model.vae_decode(pred_x0_rg) 192 | assert pred.requires_grad 193 | else: 194 | raise NotImplementedError(cond_fn.space) 195 | # compute gradient 196 | delta_pred, loss_val = cond_fn(target, pred, t_now) 197 | loss_vals.append(loss_val) 198 | # update pred_x0 w.r.t gradient 199 | if cond_fn.space == "latent": 200 | delta_pred_x0 = delta_pred 201 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale 202 | elif cond_fn.space == "rgb": 203 | pred.backward(delta_pred) 204 | delta_pred_x0 = pred_x0_rg.grad 205 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale 206 | else: 207 | raise NotImplementedError(cond_fn.space) 208 | self.context["g_apply"] = True 209 | self.context["g_loss"] = float(np.mean(loss_vals)) 210 | return pred_x0 211 | 212 | def predict_noise( 213 | self, 214 | model: ControlLDM, 215 | x: torch.Tensor, 216 | t: torch.Tensor, 217 | cond: Dict[str, torch.Tensor], 218 | uncond: Optional[Dict[str, torch.Tensor]], 219 | cfg_scale: float 220 | ) -> torch.Tensor: 221 | if uncond is None or cfg_scale == 1.: 222 | model_output = model(x, t, cond) 223 | else: 224 | # apply classifier-free guidance 225 | model_cond = model(x, t, cond) 226 | model_uncond = model(x, t, uncond) 227 | model_output = model_uncond + cfg_scale * (model_cond - model_uncond) 228 | return model_output 229 | 230 | @torch.no_grad() 231 | def predict_noise_tiled( 232 | self, 233 | model: ControlLDM, 234 | x: torch.Tensor, 235 | t: torch.Tensor, 236 | cond: Dict[str, torch.Tensor], 237 | uncond: Optional[Dict[str, torch.Tensor]], 238 | cfg_scale: float, 239 | tile_size: int, 240 | tile_stride: int, 241 | fix_init_noise: bool 242 | ): 243 | _, _, h, w = x.shape 244 | tiles_latent = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False) 245 | tiles_img = tqdm(sliding_windows(h * 8, w * 8, tile_size, tile_stride), unit="tile", leave=False) 246 | eps = torch.zeros_like(x) 247 | count = torch.zeros_like(x, dtype=torch.float32) 248 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] 249 | weights = torch.tensor(weights, dtype=torch.float32, device=x.device) 250 | for (hi_latent, hi_end_latent, wi_latent, wi_end_latent), (hi_img, hi_end_img, wi_img, wi_end_img) in zip(tiles_latent, tiles_img): 251 | tiles_latent.set_description(f"Process tile ({hi_latent} {hi_end_latent}), ({wi_latent} {wi_end_latent})") 252 | if fix_init_noise: 253 | tile_x = x[:, :, 0:64, 0:64] 254 | else: 255 | tile_x = x[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent] 256 | tile_cond = { 257 | "c_lq2": cond["c_lq2"][:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent], 258 | "c_lq1_mscn_norm": cond["c_lq1_mscn_norm"][:, :, hi_img:hi_end_img, wi_img:wi_end_img], 259 | "c_lq1_color": cond["c_lq1_color"][:, :, hi_img:hi_end_img, wi_img:wi_end_img], 260 | "c_txt": cond["c_txt"] 261 | } 262 | if uncond: 263 | # not implemented 264 | tile_uncond = { 265 | "c_img": uncond["c_img"], 266 | "c_txt": uncond["c_txt"] 267 | } 268 | else: 269 | tile_uncond = None 270 | tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale) 271 | # accumulate noise 272 | eps[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent] += tile_eps * weights 273 | count[:, :, hi_latent:hi_end_latent, wi_latent:wi_end_latent] += weights 274 | # average on noise (score) 275 | eps.div_(count) 276 | return eps 277 | 278 | @torch.no_grad() 279 | def p_sample( 280 | self, 
281 | model: ControlLDM, 282 | x: torch.Tensor, 283 | t: torch.Tensor, 284 | index: torch.Tensor, 285 | cond: Dict[str, torch.Tensor], 286 | uncond: Optional[Dict[str, torch.Tensor]], 287 | cfg_scale: float, 288 | cond_fn: Optional[Guidance], 289 | tiled: bool, 290 | tile_size: int, 291 | tile_stride: int, 292 | fix_init_noise: bool=False 293 | ) -> torch.Tensor: 294 | if tiled: 295 | eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride, fix_init_noise) 296 | else: 297 | eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale) 298 | pred_x0 = self._predict_xstart_from_eps(x, index, eps) 299 | if cond_fn: 300 | # assert not tiled, f"tiled sampling currently doesn't support guidance" 301 | assert x.shape[0] == 1, "guidance sampling currently only support batch size = 1" 302 | pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn) 303 | model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index) 304 | noise = torch.randn_like(x) 305 | nonzero_mask = ( 306 | (index != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 307 | ) 308 | x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise 309 | return x_prev 310 | 311 | @torch.no_grad() 312 | def sample( 313 | self, 314 | model: ControlLDM, 315 | device: str, 316 | steps: int, 317 | batch_size: int, 318 | x_size: Tuple[int], 319 | cond: Dict[str, torch.Tensor], 320 | uncond: Dict[str, torch.Tensor], 321 | cfg_scale: float, 322 | cond_fn: Optional[Guidance]=None, 323 | tiled: bool=False, 324 | tile_size: int=-1, 325 | tile_stride: int=-1, 326 | x_T: Optional[torch.Tensor]=None, 327 | progress: bool=True, 328 | progress_leave: bool=True, 329 | ) -> torch.Tensor: 330 | self.make_schedule(steps) 331 | self.to(device) 332 | if x_T is None: 333 | # TODO: not convert to float32, may trigger an error 334 | img = torch.randn((batch_size, *x_size), device=device) 335 | else: 336 | img = x_T 337 | timesteps = np.flip(self.timesteps) # [1000, 950, 900, ...] 
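        # Iterate from the largest timestep down to 0. `ts` carries the original
        # DDPM timestep value fed to the model, while `index` counts down over the
        # respaced schedule and indexes the precomputed posterior coefficients.
        # `fix_init_noise` is enabled only on the first step so that, in tiled
        # sampling, every tile starts from the same initial noise crop.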
338 | total_steps = len(self.timesteps) 339 | iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress) 340 | for i, step in enumerate(iterator): 341 | ts = torch.full((batch_size,), step, device=device, dtype=torch.long) 342 | index = torch.full_like(ts, fill_value=total_steps - i - 1) 343 | img = self.p_sample( 344 | model, img, ts, index, cond, uncond, cfg_scale, cond_fn, 345 | tiled, tile_size, tile_stride, fix_init_noise=True if i == 0 else False 346 | ) 347 | if cond_fn and self.context["g_apply"]: 348 | loss_val = self.context["g_loss"] 349 | desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}" 350 | else: 351 | desc = "Spaced Sampler" 352 | iterator.set_description(desc) 353 | return img 354 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any, Tuple, Callable 2 | import importlib 3 | import os 4 | from urllib.parse import urlparse 5 | 6 | import torch 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | import numpy as np 10 | 11 | from torch.hub import download_url_to_file, get_dir 12 | 13 | 14 | def get_obj_from_str(string: str, reload: bool=False) -> Any: 15 | module, cls = string.rsplit(".", 1) 16 | if reload: 17 | module_imp = importlib.import_module(module) 18 | importlib.reload(module_imp) 19 | return getattr(importlib.import_module(module, package=None), cls) 20 | 21 | 22 | def instantiate_from_config(config: Mapping[str, Any]) -> Any: 23 | if not "target" in config: 24 | raise KeyError("Expected key `target` to instantiate.") 25 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 26 | 27 | 28 | def wavelet_blur(image: Tensor, radius: int): 29 | """ 30 | Apply wavelet blur to the input tensor. 31 | """ 32 | # input shape: (1, 3, H, W) 33 | # convolution kernel 34 | kernel_vals = [ 35 | [0.0625, 0.125, 0.0625], 36 | [0.125, 0.25, 0.125], 37 | [0.0625, 0.125, 0.0625], 38 | ] 39 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 40 | # add channel dimensions to the kernel to make it a 4D tensor 41 | kernel = kernel[None, None] 42 | # repeat the kernel across all input channels 43 | kernel = kernel.repeat(3, 1, 1, 1) 44 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 45 | # apply convolution 46 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 47 | return output 48 | 49 | 50 | def wavelet_decomposition(image: Tensor, levels=5): 51 | """ 52 | Apply wavelet decomposition to the input tensor. 53 | This function only returns the low frequency & the high frequency. 54 | """ 55 | high_freq = torch.zeros_like(image) 56 | for i in range(levels): 57 | radius = 2 ** i 58 | low_freq = wavelet_blur(image, radius) 59 | high_freq += (image - low_freq) 60 | image = low_freq 61 | 62 | return high_freq, low_freq 63 | 64 | 65 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 66 | """ 67 | Apply wavelet decomposition, so that the content will have the same color as the style. 
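    Concretely, the high-frequency (detail) bands of `content_feat` are combined
    with the low-frequency band of `style_feat`, so the output keeps the content's
    structure while inheriting the style's colors and illumination.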
68 | """ 69 | # calculate the wavelet decomposition of the content feature 70 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 71 | del content_low_freq 72 | # calculate the wavelet decomposition of the style feature 73 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 74 | del style_high_freq 75 | # reconstruct the content feature with the style's high frequency 76 | return content_high_freq + style_low_freq 77 | 78 | 79 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/ 80 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 81 | """Load file form http url, will download models if necessary. 82 | 83 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 84 | 85 | Args: 86 | url (str): URL to be downloaded. 87 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 88 | Default: None. 89 | progress (bool): Whether to show the download progress. Default: True. 90 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 91 | 92 | Returns: 93 | str: The path to the downloaded file. 94 | """ 95 | if model_dir is None: # use the pytorch hub_dir 96 | hub_dir = get_dir() 97 | model_dir = os.path.join(hub_dir, 'checkpoints') 98 | 99 | os.makedirs(model_dir, exist_ok=True) 100 | 101 | parts = urlparse(url) 102 | filename = os.path.basename(parts.path) 103 | if file_name is not None: 104 | filename = file_name 105 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 106 | if not os.path.exists(cached_file): 107 | print(f'Downloading: "{url}" to {cached_file}\n') 108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 109 | return cached_file 110 | 111 | 112 | def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]: 113 | hi_list = list(range(0, h - tile_size + 1, tile_stride)) 114 | if (h - tile_size) % tile_stride != 0: 115 | hi_list.append(h - tile_size) 116 | 117 | wi_list = list(range(0, w - tile_size + 1, tile_stride)) 118 | if (w - tile_size) % tile_stride != 0: 119 | wi_list.append(w - tile_size) 120 | 121 | coords = [] 122 | for hi in hi_list: 123 | for wi in wi_list: 124 | coords.append((hi, hi + tile_size, wi, wi + tile_size)) 125 | return coords 126 | 127 | 128 | # https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503 129 | def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray: 130 | """Generates a gaussian mask of weights for tile contributions""" 131 | latent_width = tile_width 132 | latent_height = tile_height 133 | var = 0.01 134 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 135 | x_probs = [ 136 | np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var) 137 | for x in range(latent_width)] 138 | midpoint = latent_height / 2 139 | y_probs = [ 140 | np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var) 141 | for y in range(latent_height)] 142 | weights = np.outer(y_probs, x_probs) 143 | return weights 144 | 145 | 146 | COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False)) 147 | 148 | def count_vram_usage(func: Callable) -> Callable: 149 | if not COUNT_VRAM: 150 | return func 151 | 152 | def wrapper(*args, **kwargs): 153 | peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3) 154 | ret = 
func(*args, **kwargs) 155 | torch.cuda.synchronize() 156 | peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3) 157 | print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB") 158 | return ret 159 | return wrapper 160 | -------------------------------------------------------------------------------- /utils/cond_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as tvtf 3 | 4 | from typing import overload, Tuple 5 | from torch.nn import functional as F 6 | from torchvision.utils import save_image 7 | 8 | 9 | class Guidance: 10 | 11 | def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> "Guidance": 12 | """ 13 | Initialize restoration guidance. 14 | 15 | Args: 16 | scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale, 17 | the closer the final result will be to the output of the first stage model. 18 | t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling 19 | process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`. 20 | space (str): The data space for computing loss function (rgb or latent). 21 | 22 | Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior). 23 | Thanks for their work! 24 | """ 25 | self.scale = scale * 3000 26 | self.t_start = t_start 27 | self.t_stop = t_stop 28 | self.target = None 29 | self.space = space 30 | self.repeat = repeat 31 | 32 | def load_target(self, target: torch.Tensor) -> None: 33 | self.target = target 34 | 35 | def __call__(self, target_x0, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 36 | # avoid propagating gradient out of this scope 37 | pred_x0 = pred_x0.detach().clone() 38 | tmp1, tmp2 = target_x0 39 | tmp1 = tmp1.detach().clone() 40 | tmp2 = tmp2.detach().clone() 41 | return self._forward([tmp1, tmp2], pred_x0, t) 42 | # return self._forward([tmp1, tmp2], pred_x0, t) 43 | 44 | @overload 45 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 46 | ... 47 | 48 | 49 | class MSEGuidance(Guidance): 50 | 51 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 52 | # inputs: [-1, 1], nchw, rgb 53 | target_y = 0.299 * target_x0[:, 0, :, :] + 0.587 * target_x0[:, 1, :, :] + 0.114 * target_x0[:, 2, :, :] 54 | weight_map = (target_y > 0.99).type(torch.float32) 55 | save_image(weight_map, 'weight_map.png') 56 | with torch.enable_grad(): 57 | pred_x0.requires_grad_(True) 58 | loss = ((pred_x0 - target_x0).pow(2) * (1. 
- weight_map)).mean((1, 2, 3)).sum() 59 | scale = self.scale 60 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale 61 | return g, loss.item() 62 | 63 | 64 | class WeightedMSEGuidance(Guidance): 65 | 66 | def _get_weight(self, target: torch.Tensor) -> torch.Tensor: 67 | # convert RGB to G 68 | rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1) 69 | target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True) 70 | # initialize sobel kernel in x and y axis 71 | G_x = [ 72 | [1, 0, -1], 73 | [2, 0, -2], 74 | [1, 0, -1] 75 | ] 76 | G_y = [ 77 | [1, 2, 1], 78 | [0, 0, 0], 79 | [-1, -2, -1] 80 | ] 81 | G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None] 82 | G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None] 83 | G = torch.stack((G_x, G_y)) 84 | 85 | target = F.pad(target, (1, 1, 1, 1), mode='replicate') # padding = 1 86 | grad = F.conv2d(target, G, stride=1) 87 | mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt() 88 | 89 | n, c, h, w = mag.size() 90 | block_size = 2 91 | blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() 92 | block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() 93 | block_mean = block_mean.view(n, c, h, w) 94 | weight_map = 1 - block_mean 95 | 96 | return weight_map 97 | 98 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 99 | # inputs: [-1, 1], nchw, rgb 100 | with torch.no_grad(): 101 | w = self._get_weight((target_x0 + 1) / 2) 102 | with torch.enable_grad(): 103 | pred_x0.requires_grad_(True) 104 | loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum() 105 | scale = self.scale 106 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale 107 | return g, loss.item() 108 | 109 | 110 | 111 | class L_spa(torch.nn.Module): 112 | 113 | def __init__(self, patch_size): 114 | super(L_spa, self).__init__() 115 | # print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 116 | kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 117 | kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 118 | kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 119 | kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0) 120 | 121 | # kernel_left_up = torch.FloatTensor( [[-1,0,0],[0,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 122 | # kernel_right_up = torch.FloatTensor( [[0,0,-1],[0,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 123 | # kernel_left_down = torch.FloatTensor( [[0,0,0],[0,1, 0 ],[-1,0,0]]).cuda().unsqueeze(0).unsqueeze(0) 124 | # kernel_right_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,0,-1]]).cuda().unsqueeze(0).unsqueeze(0) 125 | 126 | self.weight_left = torch.nn.Parameter(data=kernel_left, requires_grad=False) 127 | self.weight_right = torch.nn.Parameter(data=kernel_right, requires_grad=False) 128 | self.weight_up = torch.nn.Parameter(data=kernel_up, requires_grad=False) 129 | self.weight_down = torch.nn.Parameter(data=kernel_down, requires_grad=False) 130 | # self.weight_left_up = nn.Parameter(data=kernel_left_up, requires_grad=False) 131 | # self.weight_right_up = nn.Parameter(data=kernel_right_up, requires_grad=False) 132 | # self.weight_left_down = nn.Parameter(data=kernel_left_down, 
requires_grad=False) 133 | # self.weight_right_down = nn.Parameter(data=kernel_right_down, requires_grad=False) 134 | self.pool = torch.nn.AvgPool2d(patch_size) 135 | 136 | def forward(self, org , enhance, weight_map): 137 | b,c,h,w = org.shape 138 | 139 | org_mean = torch.mean(org,1,keepdim=True) 140 | enhance_mean = torch.mean(enhance,1,keepdim=True) 141 | 142 | org_pool = self.pool(org_mean) 143 | enhance_pool = self.pool(enhance_mean) 144 | weight_map_pool = self.pool(weight_map) 145 | 146 | D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1) 147 | D_org_right = F.conv2d(org_pool , self.weight_right, padding=1) 148 | D_org_up = F.conv2d(org_pool , self.weight_up, padding=1) 149 | D_org_down = F.conv2d(org_pool , self.weight_down, padding=1) 150 | 151 | # D_org_left_up = F.conv2d(org_pool , self.weight_left_up, padding=1) 152 | # D_org_right_up = F.conv2d(org_pool , self.weight_right_up, padding=1) 153 | # D_org_left_down = F.conv2d(org_pool , self.weight_left_down, padding=1) 154 | # D_org_right_down = F.conv2d(org_pool , self.weight_right_down, padding=1) 155 | 156 | D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1) 157 | D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1) 158 | D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1) 159 | D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1) 160 | 161 | # D_enhance_left_up = F.conv2d(enhance_pool , self.weight_left_up, padding=1) 162 | # D_enhance_right_up = F.conv2d(enhance_pool , self.weight_right_up, padding=1) 163 | # D_enhance_left_down = F.conv2d(enhance_pool , self.weight_left_down, padding=1) 164 | # D_enhance_right_down = F.conv2d(enhance_pool , self.weight_right_down, padding=1) 165 | 166 | D_left = torch.pow(D_org_letf - D_enhance_letf,2) 167 | D_right = torch.pow(D_org_right - D_enhance_right,2) 168 | D_up = torch.pow(D_org_up - D_enhance_up,2) 169 | D_down = torch.pow(D_org_down - D_enhance_down,2) 170 | 171 | # D_left_up = torch.pow(D_org_left_up - D_enhance_left_up, 2) 172 | # D_right_up = torch.pow(D_org_right_up - D_enhance_right_up, 2) 173 | # D_left_down = torch.pow(D_org_left_down - D_enhance_left_down, 2) 174 | # D_right_down = torch.pow(D_org_right_down - D_enhance_right_down, 2) 175 | 176 | E = (D_left + D_right + D_up +D_down) 177 | 178 | return torch.mean(E * weight_map_pool) 179 | 180 | 181 | class L_mscn(torch.nn.Module): 182 | 183 | def __init__(self): 184 | super(L_mscn, self).__init__() 185 | self.l1 = torch.nn.L1Loss() 186 | def forward(self, img1, img2, weight_map): 187 | y1 = 0.299 * img1[:, 0, :, :] + 0.587 * img1[:, 1, :, :] + 0.114 * img1[:, 2, :, :] 188 | y2 = 0.299 * img2[:, 0, :, :] + 0.587 * img2[:, 1, :, :] + 0.114 * img2[:, 2, :, :] 189 | y1 = y1.type(torch.float64) 190 | y2 = y2.type(torch.float64) 191 | mu1 = tvtf.gaussian_blur(y1, (7, 7)) 192 | mu2 = tvtf.gaussian_blur(y2, (7, 7)) 193 | mu1_sq = mu1 * mu1 194 | mu2_sq = mu2 * mu2 195 | sigma1 = torch.sqrt(torch.abs(tvtf.gaussian_blur(y1 * y1, (7, 7)) - mu1_sq)) 196 | sigma2 = torch.sqrt(torch.abs(tvtf.gaussian_blur(y2 * y2, (7, 7)) - mu2_sq)) 197 | dividend1 = y1 - mu1 198 | dividend2 = y2 - mu2 199 | divisor1 = sigma1 + 1e-7 200 | divisor2 = sigma2 + 1e-7 201 | struct1 = (dividend1 / divisor1) 202 | struct2 = (dividend2 / divisor2) 203 | struct1_norm = (struct1 - struct1.min()) / (struct1.max() - struct1.min()) 204 | struct2_norm = (struct2 - struct2.min()) / (struct2.max() - struct2.min()) 205 | 206 | return ((struct1 - struct2).pow(2) * weight_map).mean() 207 | 
# return self.l1(struct1_norm, struct2_norm) 208 | # return (struct1_norm - struct2_norm).mean() 209 | 210 | 211 | class StructureGuidance(Guidance): 212 | 213 | def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> Guidance: 214 | super(StructureGuidance, self).__init__(scale, t_start, t_stop, space, repeat) 215 | # self.spa1_loss = L_spa(1) 216 | # self.spa2_loss = L_spa(2) 217 | self.spa4_loss = L_spa(1) 218 | # self.loss = L_mscn() 219 | 220 | def mscn_torch(self, input_img: torch.Tensor, ksize, c): #input an RGB image 221 | y = 0.299 * input_img[:, 0, :, :] + 0.587 * input_img[:, 1, :, :] + 0.114 * input_img[:, 2, :, :] 222 | y = y.type(torch.float64) 223 | mu = tvtf.gaussian_blur(y, (ksize,ksize)) 224 | mu_sq = mu * mu 225 | sigma = torch.sqrt(torch.abs(tvtf.gaussian_blur(y*y, (ksize,ksize)) - mu_sq)) 226 | dividend = y - mu 227 | divisor = sigma + c 228 | struct = (dividend / divisor).type(torch.float32) 229 | struct_norm = (struct - struct.min()) / (struct.max() - struct.min() + 1e-6) 230 | return struct_norm 231 | 232 | def _forward(self, target_x0, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 233 | # pred_x0: [-1, 1], nchw, rgb 234 | target1, target2 = target_x0 235 | pred_x0 = (pred_x0 + 1) / 2 # [0, 1] 236 | target1_struct = self.mscn_torch(input_img=target1, ksize=7, c=1e-7) 237 | target2_struct = self.mscn_torch(input_img=target2, ksize=7, c=1e-7) 238 | target2_y = 0.299 * target2[:, 0, :, :] + 0.587 * target2[:, 1, :, :] + 0.114 * target2[:, 2, :, :] 239 | weight_map = (target2_y > 0.99).type(torch.float32) 240 | 241 | with torch.enable_grad(): 242 | pred_x0.requires_grad_(True) 243 | # loss = (pred_x0_struct - target_x0_struct).pow(2).mean((1, 2)).sum() 244 | # loss = self.spa_loss(target2, pred_x0, 1. - weight_map) 245 | loss = self.spa4_loss(target2, pred_x0, 1. - weight_map) 246 | # loss = self.loss(target1, pred_x0, 1. - weight_map) 247 | scale = self.scale 248 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale 249 | return g, loss.item() -------------------------------------------------------------------------------- /utils/flow.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/haofeixu/gmflow and https://github.com/liuziyang123/LDRFlow. 
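# This module provides:
#   - backward_warp: warp im2 towards im1 with a dense optical flow via grid_sample.
#   - forward_backward_consistency_check: occlusion masks from the forward/backward
#     flow consistency error (thresholds follow UnFlow).
#   - calculate_imf_map / IMF: per-channel intensity mapping function that
#     histogram-matches the under-exposed image to the over-exposed one.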
2 | 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | def backward_warp(x, flo): 10 | """ 11 | warp an image/tensor (im2) back to im1, according to the optical flow 12 | x: [B, C, H, W] (im2) 13 | flo: [B, 2, H, W] flow 14 | """ 15 | B, C, H, W = x.size() 16 | # mesh grid 17 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 18 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 19 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 20 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 21 | grid = torch.cat((xx, yy), 1).float() 22 | 23 | if x.is_cuda: 24 | grid = grid.cuda() 25 | vgrid = grid + flo 26 | # scale grid to [-1,1] 27 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 28 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 29 | 30 | vgrid = vgrid.permute(0, 2, 3, 1) 31 | output = F.grid_sample(x, vgrid) 32 | # mask = torch.ones(x.size()).to(DEVICE) 33 | # mask = F.grid_sample(mask, vgrid) 34 | 35 | # mask[mask < 0.999] = 0 36 | # mask[mask > 0] = 1 37 | 38 | return output 39 | 40 | 41 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 42 | alpha=0.01, 43 | beta=10 44 | ): 45 | # fwd_flow, bwd_flow: [B, 2, H, W] 46 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 47 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 48 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 49 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 50 | 51 | warped_bwd_flow = backward_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 52 | warped_fwd_flow = backward_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 53 | 54 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 55 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 56 | 57 | threshold = alpha * flow_mag + beta 58 | # threshold = 0 59 | 60 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 61 | bwd_occ = (diff_bwd > threshold).float() 62 | 63 | return fwd_occ, bwd_occ 64 | 65 | 66 | def calculate_imf_map(x, y): 67 | imf_map = torch.zeros(256).cuda() 68 | r = 0 69 | for i in range(256): 70 | if x[i] == 0: 71 | imf_map[i] = -1 72 | else: 73 | p, v, j = x[i], 0, r 74 | while True: 75 | if y[j] < p: 76 | p = p - y[j] 77 | v = v + y[j] * j 78 | j += 1 79 | else: 80 | r = j 81 | y[j] = y[j] - p 82 | v = v + p * j 83 | imf_map[i] = (v / x[i]).round() 84 | break 85 | imf_map = imf_map.unsqueeze(dim=0) 86 | return imf_map 87 | 88 | 89 | def IMF(ue, oe): 90 | B, C, H, W = ue.shape 91 | ue = (ue * 255).round() 92 | oe = (oe * 255).round() 93 | 94 | imf_map = [] 95 | ue_rgb = torch.split(ue, 1, dim=1) 96 | oe_rgb = torch.split(oe, 1, dim=1) 97 | imf_map = [ 98 | calculate_imf_map(torch.histc(x, bins=256, min=0, max=255), torch.histc(y, bins=256, min=0, max=255)) for x, y in zip(ue_rgb, oe_rgb) 99 | ] 100 | imf_map = torch.concat(imf_map, dim=0) 101 | 102 | zeros = torch.zeros([C, 1], dtype=torch.float32).cuda() 103 | imf_map = torch.concat((imf_map, zeros), 1) 104 | 105 | ue_imf = rearrange(ue.squeeze(), 'c h w -> c (h w)') 106 | ue_imf_floor = ue_imf.floor() 107 | for c in range(C): 108 | ind = ue_imf_floor[c].long() 109 | ue_imf[c, :] = (ue_imf[c, :] - ue_imf_floor[c, :]) * (imf_map[c, :][ind + 1] - imf_map[c, :][ind]) + imf_map[c, :][ind] 110 | ue_imf = rearrange(ue_imf, 'c (h w) -> c h w', h=H, w=W).unsqueeze(dim=0) 111 | ue_imf = (ue_imf / 255.).clamp(0, 1) 112 | return ue_imf -------------------------------------------------------------------------------- /utils/imf.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | 5 | def calculate_imf_map(x, y): 6 | imf_map = torch.zeros(256).cuda() 7 | r = 0 8 | for i in range(256): 9 | if x[i] == 0: 10 | imf_map[i] = -1 11 | else: 12 | p, v, j = x[i], 0, r 13 | while True: 14 | if y[j] < p: 15 | p = p - y[j] 16 | v = v + y[j] * j 17 | j += 1 18 | else: 19 | r = j 20 | y[j] = y[j] - p 21 | v = v + p * j 22 | imf_map[i] = (v / x[i]).round() 23 | break 24 | imf_map = imf_map.unsqueeze(dim=0) 25 | return imf_map 26 | 27 | 28 | def IMF2(ue, oe): 29 | B, C, H, W = ue.shape 30 | ue = (ue * 255).round() 31 | oe = (oe * 255).round() 32 | 33 | imf_map = [] 34 | ue_rgb = torch.split(ue, 1, dim=1) 35 | oe_rgb = torch.split(oe, 1, dim=1) 36 | imf_map = [ 37 | calculate_imf_map(torch.histc(x, bins=256, min=0, max=255), torch.histc(y, bins=256, min=0, max=255)) for x, y in zip(ue_rgb, oe_rgb) 38 | ] 39 | imf_map = torch.concat(imf_map, dim=0) 40 | 41 | zeros = torch.zeros([C, 1], dtype=torch.float32).cuda() 42 | imf_map = torch.concat((imf_map, zeros), 1) 43 | 44 | ue_imf = rearrange(ue.squeeze(), 'c h w -> c (h w)') 45 | ue_imf_floor = ue_imf.floor() 46 | for c in range(C): 47 | ind = ue_imf_floor[c].long() 48 | ue_imf[c, :] = (ue_imf[c, :] - ue_imf_floor[c, :]) * (imf_map[c, :][ind + 1] - imf_map[c, :][ind]) + imf_map[c, :][ind] 49 | ue_imf = rearrange(ue_imf, 'c (h w) -> c h w', h=H, w=W).unsqueeze(dim=0) 50 | ue_imf = (ue_imf / 255.).clamp(0, 1) 51 | return ue_imf -------------------------------------------------------------------------------- /val_nriqa.py: -------------------------------------------------------------------------------- 1 | import os,glob 2 | import torch 3 | import pyiqa 4 | from tqdm import tqdm 5 | 6 | 7 | class AverageMeter(object): 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | self.max = -1 17 | self.min = 10000 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | if val > self.max: 25 | self.max = val 26 | if val < self.min: 27 | self.min = val 28 | 29 | def get_max(self): 30 | return self.max 31 | 32 | 33 | metric_list = { 34 | 'musiq': pyiqa.create_metric('musiq', as_loss=False).cuda(), 35 | 'paq2piq': pyiqa.create_metric('paq2piq', as_loss=False).cuda(), 36 | 'hyperiqa': pyiqa.create_metric('hyperiqa', as_loss=False).cuda(), 37 | } 38 | res_list = {} 39 | 40 | for k in metric_list: 41 | print('{} lower better: {}'.format(k, metric_list[k].lower_better)) 42 | res_list[k] = AverageMeter() 43 | 44 | img_list = glob.glob('/ailab/user/chenzixuan/Research/Diff-HDR/cvpr2025_release/MEFB/*out*.png') 45 | 46 | for img_path in tqdm(img_list): 47 | for k in metric_list: 48 | tmp = metric_list[k](img_path) 49 | res_list[k].update(tmp) 50 | 51 | for k in res_list: 52 | print('{}: {}'.format(k, res_list[k].avg)) --------------------------------------------------------------------------------
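Note that `val_nriqa.py` hard-codes the path to the fused results. A minimal, hypothetical variant that takes the results directory and glob pattern from the command line could look like the sketch below; the flag names and defaults here are assumptions for illustration, not part of the released code.

```python
# Hypothetical wrapper around the same pyiqa metrics used in val_nriqa.py.
# Paths, flag names, and defaults here are assumptions for illustration.
import argparse
import glob
import os

import pyiqa
import torch
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--result_dir", required=True, help="directory with fused outputs")
    parser.add_argument("--pattern", default="*out*.png", help="glob pattern for result images")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Same no-reference IQA metrics as val_nriqa.py
    metrics = {
        name: pyiqa.create_metric(name, as_loss=False).to(device)
        for name in ("musiq", "paq2piq", "hyperiqa")
    }

    img_list = sorted(glob.glob(os.path.join(args.result_dir, args.pattern)))
    sums = {name: 0.0 for name in metrics}
    for img_path in tqdm(img_list):
        for name, metric in metrics.items():
            # pyiqa metrics accept image file paths directly
            sums[name] += metric(img_path).item()

    for name in metrics:
        print(f"{name}: {sums[name] / max(len(img_list), 1):.4f}")


if __name__ == "__main__":
    main()
```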