├── .gitignore ├── Figs ├── banner.png └── github_results.png ├── LICENSE ├── README.md ├── cldm ├── cldm.py ├── ddim_hacked.py ├── hack.py ├── logger.py └── model.py ├── config.py ├── control_depth_inpaint.yaml ├── docs ├── installation.md ├── manual.md └── meshgraphormer.md ├── handrefiner.py ├── ldm ├── data │ ├── __init__.py │ ├── control_synthcompositedata.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── preprocessor ├── depth_preprocessor.py └── meshgraphormer.py ├── requirements.txt ├── scripts ├── _gcnn.py ├── _mano.py ├── config.py └── download_models.sh ├── test ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg └── test.json └── training ├── README.md ├── control_synthcompositedata.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/.DS_Store -------------------------------------------------------------------------------- /Figs/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/Figs/banner.png -------------------------------------------------------------------------------- /Figs/github_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/Figs/github_results.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wenquan Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting

2 |

3 | 4 | 5 | # News 6 | 7 | **2023.12.1** 8 | The paper is posted on arXiv! 9 | 10 | **2023.12.29** 11 | First code commit released. 12 | 13 | **2024.1.7** 14 | The preprocessor and the finetuned model have been ported to [ComfyUI controlnet](https://github.com/Fannovel16/comfyui_controlnet_aux). The preprocessor has been ported to [sd webui controlnet](https://github.com/Mikubill/sd-webui-controlnet). Thanks for all your great work! 15 | 16 | **2024.1.15** 17 | ⚠️ When using the finetuned ControlNet from this repository or [control_sd15_inpaint_depth_hand](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned), I noticed many users still use a control strength/control weight of 1, which can result in loss of texture. As stated in the paper, we recommend using a smaller control strength (e.g. 0.4 - 0.8). 18 | 19 | # Introduction 20 | 21 | This is the official repository of the paper HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting. 22 | 23 |

24 | 25 |
Figure 1: Stable Diffusion (first two rows) and SDXL (last row) generate malformed hands (left in each pair), e.g., incorrect 26 | number of fingers or irregular shapes, which can be effectively rectified by our HandRefiner (right in each pair). 27 |
28 |
29 | 30 |

31 | 32 |

33 | In this study, we introduce a lightweight post-processing solution called HandRefiner to correct malformed hands in generated images. HandRefiner employs a conditional inpainting 34 | approach to rectify malformed hands while leaving other 35 | parts of the image untouched. We leverage the hand mesh 36 | reconstruction model that consistently adheres to the correct number of fingers and hand shape, while also being 37 | capable of fitting the desired hand pose in the generated 38 | image. Given a generated failed image due to malformed 39 | hands, we utilize ControlNet modules to re-inject such correct hand information. Additionally, we uncover a phase 40 | transition phenomenon within ControlNet as we vary the 41 | control strength. It enables us to take advantage of more 42 | readily available synthetic data without suffering from the 43 | domain gap between realistic and synthetic hands. 44 | 45 | # Visual Results 46 |

47 | 48 |
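# Pipeline Sketch

The conditional inpainting loop described in the Introduction can be summarized as below. This is a simplified, hedged sketch rather than the exact API of this repository: `get_hand_depth_and_mask` and `inpaint_with_depth_control` are hypothetical wrappers standing in for the MeshGraphormer/MediaPipe preprocessor (`MeshGraphormerMediapipe.get_depth` in `preprocessor/meshgraphormer.py`) and the depth-conditioned ControlNet inpainting model driven by `handrefiner.py`.

```python
# Hedged sketch of the HandRefiner post-processing step (helper names are illustrative, not the repo API).
def refine_hands(image, prompt, strength=0.55):
    # 1) reconstruct a hand mesh for the malformed hand, render its depth map,
    #    and build an inpainting mask around the hand region
    depth, mask = get_hand_depth_and_mask(image)       # hypothetical wrapper around MeshGraphormer + MediaPipe
    # 2) re-inject the correct hand via depth-conditioned ControlNet inpainting,
    #    leaving the rest of the image untouched
    return inpaint_with_depth_control(                 # hypothetical wrapper around SD inpainting + depth ControlNet
        image=image,
        mask=mask,
        control=depth,
        prompt=prompt,
        control_strength=strength,                     # 0.4 - 0.8 recommended; 1.0 tends to wash out texture
    )
```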
49 | 50 | # Installation 51 | Check [installation.md](docs/installation.md) for installation instructions. 52 | 53 | # Manual 54 | Check [manual.md](docs/manual.md) for an explanation of commands to execute the HandRefiner. 55 | 56 | # Get Started 57 | For single image rectification: 58 | ```bash 59 | python handrefiner.py --input_img test/1.jpg --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt "a man facing the camera, making a hand gesture, indoor" --seed 1 60 | ``` 61 | For multiple image rectifications: 62 | ```bash 63 | python handrefiner.py --input_dir test --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt_file test/test.json --seed 1 64 | ``` 65 | 66 | 67 | 68 | # Important Q&A 69 | 105 | 106 | ## Comments 107 | - Our codebase builds heavily on [stable-diffusion](https://github.com/CompVis/stable-diffusion), [ControlNet](https://github.com/lllyasviel/ControlNet) and [MeshGraphormer](https://github.com/microsoft/MeshGraphormer). 108 | 109 | ## Citation 110 | 111 | If you find HandRefiner helpful, please consider giving this repo a star :star: and citing: 112 | 113 | ``` 114 | @inproceedings{10.1145/3664647.3680693, 115 | author = {Lu, Wenquan and Xu, Yufei and Zhang, Jing and Wang, Chaoyue and Tao, Dacheng}, 116 | title = {HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting}, 117 | year = {2024}, 118 | isbn = {9798400706868}, 119 | publisher = {Association for Computing Machinery}, 120 | address = {New York, NY, USA}, 121 | url = {https://doi.org/10.1145/3664647.3680693}, 122 | doi = {10.1145/3664647.3680693}, 123 | abstract = {Diffusion models have achieved remarkable success in generating realistic images but suffer from generating accurate human hands, such as incorrect finger counts or irregular shapes. This difficulty arises from the complex task of learning the physical structure and pose of hands from training images, which involves extensive deformations and occlusions. For correct hand generation, our paper introduces a lightweight post-processing solution called HandRefiner. HandRefiner employs a conditional inpainting approach to rectify malformed hands while leaving other parts of the image untouched. We leverage the hand mesh reconstruction model that consistently adheres to the correct number of fingers and hand shape, while also being capable of fitting the desired hand pose in the generated image. Given a generated failed image due to malformed hands, we utilize ControlNet modules to re-inject such correct hand information. Additionally, we uncover a phase transition phenomenon within ControlNet as we vary the control strength. It enables us to take advantage of more readily available synthetic data without suffering from the domain gap between realistic and synthetic hands. Experiments demonstrate that HandRefiner can significantly improve the generation quality quantitatively and qualitatively. 
The code is available at https://github.com/wenquanlu/HandRefiner.}, 124 | booktitle = {Proceedings of the 32nd ACM International Conference on Multimedia}, 125 | pages = {7085–7093}, 126 | numpages = {9}, 127 | keywords = {deep learning, diffusion models, image inpainting}, 128 | location = {Melbourne VIC, Australia}, 129 | series = {MM '24} 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = 
torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 
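# tensors are clamped to [-1, 1] here; log_local() (called just below) rescales them to [0, 255] and saves them as PNG grids under <save_dir>/image_log/<split>/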
64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | save_memory = False 3 | handrefiner_root=str(Path(__file__).parent) -------------------------------------------------------------------------------- /control_depth_inpaint.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: cldm.cldm.ControlNet 23 | params: 24 | image_size: 32 # unused 25 | in_channels: 4 26 | hint_channels: 3 27 | model_channels: 320 28 | attention_resolutions: [ 4, 2, 1 ] 29 | num_res_blocks: 2 30 | channel_mult: [ 1, 2, 4, 4 ] 31 | num_heads: 8 32 | use_spatial_transformer: True 33 | transformer_depth: 1 34 | context_dim: 768 35 | use_checkpoint: True 36 | legacy: False 37 | 38 | unet_config: 39 | target: cldm.cldm.ControlledUnetModel 40 | params: 41 | image_size: 32 # unused 42 | in_channels: 9 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 
| ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | ## Installation Instructions 2 | 3 | 1. Clone HandRefiner to your local machine. 4 | 2. Install MeshGraphormer into HandRefiner/MeshGraphormer following the instructions in [meshgraphormer.md](meshgraphormer.md). (If you encounter any errors, you can also refer to the original documentation in MeshGraphormer.) 5 | Please also comply with MeshGraphormer's license when using it in this project. 6 | 3. Make sure you are in the 'HandRefiner/' directory for the following steps; refer to [requirements.txt](../requirements.txt) for the packages required for the project. 7 | 4. Install Mediapipe: 8 | ```bash 9 | pip install -q mediapipe==0.10.0 10 | cd preprocessor 11 | wget https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task 12 | ``` 13 | 14 | 5. Download weights; there are two sets of weights that can be used: 15 | - Inpaint Stable Diffusion weights [sd-v1-5-inpainting.ckpt](https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt) and depth ControlNet weights [control_v11f1p_sd15_depth.pth](https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11f1p_sd15_depth.pth). Put sd-v1-5-inpainting.ckpt and control_v11f1p_sd15_depth.pth in the HandRefiner/models/ folder. To use these weights, set the --finetuned flag to False when executing HandRefiner. 16 | - Finetuned weights [inpaint_depth_control.ckpt](https://drive.google.com/file/d/1eD2Lnfk0KZols68mVahcVfNx3GnYdHxo/view?usp=sharing) as introduced in the paper. Put inpaint_depth_control.ckpt in the HandRefiner/models/ folder. A control strength of 0.4 - 0.8 is recommended for the finetuned weights; we use 0.55 in the evaluation in the paper. Alternatively, an adaptive control strength can be used by setting the --adaptive_control flag to True, though the inference time is much longer. 17 | 18 | The finetuned weights are more adaptable to complex gestures, and their inpainting is more harmonious. You can also attempt to use the original weights, although the failure rate could be higher. 19 | 20 | 6. 
Test whether the installation succeeded: 21 | 22 | For single image rectification: 23 | ```bash 24 | python handrefiner.py --input_img test/1.jpg --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt "a man facing the camera, making a hand gesture, indoor" --seed 1 25 | ``` 26 | For multiple image rectifications: 27 | ```bash 28 | python handrefiner.py --input_dir test --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt_file test/test.json --seed 1 29 | ``` 30 | 31 | -------------------------------------------------------------------------------- /docs/manual.md: -------------------------------------------------------------------------------- 1 | ## Manual 2 | Arguments for executing handrefiner.py: 3 | 71 | -------------------------------------------------------------------------------- /docs/meshgraphormer.md: -------------------------------------------------------------------------------- 1 | # MeshGraphormer Instructions for HandRefiner 2 | 3 | ## Installation 4 | 5 | ### Requirements 6 | 7 | 8 | 9 | Install MeshGraphormer into HandRefiner/MeshGraphormer: 10 | 11 | ```bash 12 | git clone --recursive https://github.com/microsoft/MeshGraphormer.git 13 | cd MeshGraphormer 14 | pip install ./manopth/. 15 | ``` 16 | 17 | 18 | ## Download 19 | Make sure you are in the 'HandRefiner/MeshGraphormer' directory for the following steps. 20 | 1. Create a folder to store the pretrained models. 21 | ```bash 22 | mkdir -p models # pre-trained models 23 | ``` 24 | 25 | 2. Download the pretrained models and apply some code modifications. 26 | 27 | ```bash 28 | cp ../scripts/download_models.sh scripts/download_models.sh 29 | cp ../scripts/_gcnn.py src/modeling/_gcnn.py 30 | cp ../scripts/_mano.py src/modeling/_mano.py 31 | cp ../scripts/config.py src/modeling/data/config.py 32 | bash scripts/download_models.sh 33 | ``` 34 | 35 | The resulting data structure should follow the hierarchy below. 36 | ``` 37 | MeshGraphormer 38 | |-- models 39 | | |-- graphormer_release 40 | | | |-- graphormer_hand_state_dict.bin 41 | | |-- hrnet 42 | | | |-- hrnetv2_w64_imagenet_pretrained.pth 43 | | | |-- cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 44 | |-- src 45 | |-- datasets 46 | |-- predictions 47 | |-- README.md 48 | |-- ... 49 | |-- ... 50 | ``` 51 | 52 | 3. Download the MANO model from its official website 53 | 54 | - Download `MANO_RIGHT.pkl` from [MANO](https://mano.is.tue.mpg.de/), and place it at `MeshGraphormer/src/modeling/data`. 55 | 56 | Please put the downloaded files under the `MeshGraphormer/src/modeling/data` directory. The data structure should follow the hierarchy below. 57 | ``` 58 | MeshGraphormer 59 | |-- src 60 | | |-- modeling 61 | | | |-- data 62 | | | | |-- MANO_RIGHT.pkl 63 | |-- models 64 | |-- datasets 65 | |-- predictions 66 | |-- README.md 67 | |-- ... 68 | |-- ... 69 | ``` 70 | 4. Exit the MeshGraphormer directory when finished: 71 | ```bash 72 | cd .. 73 | ``` -------------------------------------------------------------------------------- /handrefiner.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # STEP 1: Import the necessary modules. 
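# handrefiner.py: rectify malformed hands in generated images by rendering a depth map of a
# reconstructed hand mesh (MeshGraphormer + MediaPipe) and using it as the hint for depth-conditioned
# ControlNet inpainting with Stable Diffusion. Supports single-image mode (--input_img with --prompt)
# and batch mode (--input_dir with --prompt_file); see docs/manual.md for the full argument list.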
5 | from __future__ import absolute_import, division, print_function 6 | import sys 7 | from config import handrefiner_root 8 | import os 9 | 10 | def load(): 11 | paths = [handrefiner_root, os.path.join(handrefiner_root, 'MeshGraphormer'), os.path.join(handrefiner_root, 'preprocessor')] 12 | for p in paths: 13 | sys.path.insert(0, p) 14 | 15 | load() 16 | 17 | import argparse 18 | import json 19 | import torch 20 | import numpy as np 21 | import cv2 22 | 23 | from PIL import Image 24 | from torchvision import transforms 25 | import numpy as np 26 | import cv2 27 | 28 | from pytorch_lightning import seed_everything 29 | from cldm.model import create_model, load_state_dict 30 | from cldm.ddim_hacked import DDIMSampler 31 | import config 32 | 33 | import cv2 34 | import einops 35 | import numpy as np 36 | import torch 37 | import random 38 | from pathlib import Path 39 | from preprocessor.meshgraphormer import MeshGraphormerMediapipe 40 | import ast 41 | 42 | transform = transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize( 45 | mean=[0.485, 0.456, 0.406], 46 | std=[0.229, 0.224, 0.225])]) 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser() 50 | 51 | # input directory containing images to be rectified 52 | parser.add_argument('--input_dir', type=str, default="") 53 | 54 | # input image 55 | parser.add_argument('--input_img', type=str, default="") 56 | 57 | # output directory where the rectified images will be saved to 58 | parser.add_argument('--out_dir', type=str, default="") 59 | 60 | # file where the mpjpe values will be logged to 61 | parser.add_argument('--log_json', type=str, default="") 62 | 63 | # control strength for ControlNet 64 | parser.add_argument('--strength', type=float, default=1.0) 65 | 66 | # directory where the depth maps will be saved to. Leaving it empty will disable this function 67 | parser.add_argument('--depth_dir', type=str, default="") 68 | 69 | # directory where the masks will be saved to. 
Leaving it empty will disable this function 70 | parser.add_argument('--mask_dir', type=str, default="") 71 | 72 | # whether evaluate the mpjpe error in fixed control strength mode 73 | parser.add_argument('--eval', type=ast.literal_eval, default=False) 74 | 75 | # whether use finetuned ControlNet trained on synthetic images as introduced in the paper 76 | parser.add_argument('--finetuned', type=ast.literal_eval, default=True) 77 | 78 | # path to the SD + ControlNet weights 79 | parser.add_argument('--weights', type=str, default="") 80 | 81 | # batch size 82 | parser.add_argument('--num_samples', type=int, default=1) 83 | 84 | # prompt file for multi-image rectification 85 | # see manual.md for file format 86 | parser.add_argument('--prompt_file', type=str, default="") 87 | 88 | # prompt for single image rectification 89 | parser.add_argument('--prompt', type=str, default="") 90 | 91 | # number of generation iteration for each image to be rectified 92 | # in general, for each input image, n_iter x num_samples number of rectified images will be produced 93 | parser.add_argument('--n_iter', type=int, default=1) 94 | 95 | # adaptive control strength as introduced in paper (we tend to use fixed control strength as default) 96 | parser.add_argument('--adaptive_control', type=ast.literal_eval, default=False) 97 | 98 | # padding controls the size of masks around the hand 99 | parser.add_argument('--padding_bbox', type=int, default=30) 100 | 101 | # set seed 102 | parser.add_argument('--seed', type=int, default=-1) 103 | args = parser.parse_args() 104 | return args 105 | 106 | args = parse_args() 107 | 108 | if (args.prompt_file != "" and args.prompt != "") or (args.prompt_file == "" and args.prompt == ""): 109 | raise Exception("Please specify one and only one of the --prompt and --prompt_file") 110 | if (args.input_dir != "" and args.input_img != "") or (args.input_dir == "" and args.input_img == ""): 111 | raise Exception("Please specify one and only one of the --input_dir and --input_img") 112 | 113 | model = create_model("control_depth_inpaint.yaml").cpu() 114 | if args.finetuned: 115 | model.load_state_dict(load_state_dict(args.weights, location='cuda'), strict=False) 116 | else: 117 | model.load_state_dict( 118 | load_state_dict("models/sd-v1-5-inpainting.ckpt", location="cuda"), strict=False 119 | ) 120 | model.load_state_dict( 121 | load_state_dict("models/control_v11f1p_sd15_depth.pth", location="cuda"), 122 | strict=False, 123 | ) 124 | 125 | model = model.to("cuda") 126 | 127 | meshgraphormer = MeshGraphormerMediapipe() 128 | 129 | if args.log_json != "": 130 | f_mpjpe = open(args.log_json, 'w') 131 | 132 | 133 | # prompt needs to be same for all pictures in the same batch 134 | if args.input_img != "": 135 | assert args.prompt_file == "", "prompt file should not be used for single image rectification" 136 | inputs = [args.input_img] 137 | else: 138 | if args.prompt_file != "": 139 | f_prompt = open(args.prompt_file) 140 | inputs = f_prompt.readlines() 141 | else: 142 | inputs = os.listdir(args.input_dir) 143 | 144 | for file_info in inputs: 145 | if args.prompt_file != "": 146 | file_info = json.loads(file_info) 147 | file_name = file_info["img"] 148 | prompt = file_info["txt"] 149 | else: 150 | file_name = file_info 151 | prompt = args.prompt 152 | 153 | image_file = os.path.join(args.input_dir, file_name) 154 | 155 | file_name_raw = Path(file_name).stem 156 | 157 | # STEP 3: Load the input image. 
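# For each input image, the loop below: (1) calls MeshGraphormerMediapipe.get_depth() to obtain the
# rendered hand depth map, an inpainting mask around the hand, and mesh info later used for MPJPE
# evaluation; (2) encodes the source and the masked image into the SD latent space; (3) concatenates the
# downsampled mask with the masked-image latent (the extra channels expected by the inpainting UNet) and
# attaches the depth hint and text prompt as ControlNet/cross-attention conditioning; (4) runs DDIM
# sampling either with a fixed control strength or, with --adaptive_control, sweeps several strengths and
# keeps the first sample whose MPJPE stays within ~15% of the reference.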
158 | image = np.array(Image.open(image_file)) 159 | 160 | raw_image = image 161 | H, W, C = raw_image.shape 162 | gen_count = 0 163 | for iteration in range(args.n_iter): 164 | 165 | depthmap, mask, info = meshgraphormer.get_depth(args.input_dir, file_name, args.padding_bbox) 166 | 167 | if args.depth_dir != "": 168 | cv2.imwrite(os.path.join(args.depth_dir, file_name_raw + "_depth.jpg"), depthmap) 169 | if args.mask_dir != "": 170 | cv2.imwrite(os.path.join(args.mask_dir, file_name_raw + "_mask.jpg"), mask) 171 | 172 | control = depthmap 173 | 174 | ddim_sampler = DDIMSampler(model) 175 | num_samples = args.num_samples 176 | ddim_steps = 50 177 | guess_mode = False 178 | strength = args.strength 179 | scale = 9.0 180 | seed = args.seed 181 | 182 | label = file_name[:2] 183 | a_prompt = "realistic, best quality, extremely detailed" 184 | n_prompt = "fake 3D rendered image, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, blue" 185 | 186 | source = raw_image 187 | 188 | source = (source.astype(np.float32) / 127.5) - 1.0 189 | source = source.transpose([2, 0, 1]) # source is c h w 190 | 191 | mask = mask.astype(np.float32) / 255.0 192 | mask = mask[None] 193 | mask[mask < 0.5] = 0 194 | mask[mask >= 0.5] = 1 195 | 196 | hint = control.astype(np.float32) / 255.0 197 | 198 | masked_image = source * (mask < 0.5) # masked image is c h w 199 | 200 | mask = torch.stack([torch.tensor(mask) for _ in range(num_samples)], dim=0).to("cuda") 201 | mask = torch.nn.functional.interpolate(mask, size=(64, 64)) 202 | 203 | if seed == -1: 204 | seed = random.randint(0, 65535) 205 | seed_everything(seed) 206 | 207 | if config.save_memory: 208 | model.low_vram_shift(is_diffusing=False) 209 | 210 | masked_image = torch.stack( 211 | [torch.tensor(masked_image) for _ in range(num_samples)], dim=0 212 | ).to("cuda") 213 | 214 | # this should be b,c,h,w 215 | masked_image = model.get_first_stage_encoding(model.encode_first_stage(masked_image)) 216 | 217 | x = torch.stack([torch.tensor(source) for _ in range(num_samples)], dim=0).to("cuda") 218 | z = model.get_first_stage_encoding(model.encode_first_stage(x)) 219 | 220 | cats = torch.cat([mask, masked_image], dim=1) 221 | 222 | hint = hint[ 223 | None, 224 | ].repeat(3, axis=0) 225 | 226 | hint = torch.stack([torch.tensor(hint) for _ in range(num_samples)], dim=0).to("cuda") 227 | 228 | cond = { 229 | "c_concat": [cats], 230 | "c_control": [hint], 231 | "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)], 232 | } 233 | un_cond = { 234 | "c_concat": [cats], 235 | "c_control": [hint], 236 | "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)], 237 | } 238 | 239 | 240 | shape = (4, H // 8, W // 8) 241 | 242 | if config.save_memory: 243 | model.low_vram_shift(is_diffusing=True) 244 | 245 | if not args.adaptive_control: 246 | seed_everything(seed) 247 | model.control_scales = ( 248 | [strength * (0.825 ** float(12 - i)) for i in range(13)] 249 | if guess_mode 250 | else ([strength] * 13) 251 | ) # Magic number. IDK why. 
Perhaps because 0.825**12<0.01 but 0.826**12>0.01 252 | samples, intermediates = ddim_sampler.sample( 253 | ddim_steps, 254 | num_samples, 255 | shape, 256 | cond, 257 | verbose=False, 258 | unconditional_guidance_scale=scale, 259 | unconditional_conditioning=un_cond, 260 | x0=z, 261 | mask=mask 262 | ) 263 | if config.save_memory: 264 | model.low_vram_shift(is_diffusing=False) 265 | 266 | x_samples = model.decode_first_stage(samples) 267 | # print(x_samples.shape) 268 | x_samples = ( 269 | (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5) 270 | .cpu() 271 | .numpy() 272 | .clip(0, 255) 273 | .astype(np.uint8) 274 | ) 275 | 276 | if args.eval: # currently only works for batch size of 1 277 | assert args.num_samples == 1, "MPJPE evaluation currently only works for batch size of 1" 278 | mpjpe = meshgraphormer.eval_mpjpe(x_samples[0], info) 279 | print(mpjpe) 280 | if args.log_json != "": 281 | mpjpe_info = {"img": image_file, "strength": strength, "mpjpje": mpjpe} 282 | f_mpjpe.write(json.dumps(mpjpe_info)) 283 | f_mpjpe.write("\n") 284 | for i in range(args.num_samples): 285 | cv2.imwrite( 286 | os.path.join(args.out_dir, "{}_{}.jpg".format(file_name_raw, gen_count)), cv2.cvtColor(x_samples[i], cv2.COLOR_RGB2BGR) 287 | ) 288 | gen_count += 1 289 | else: 290 | assert args.num_samples == 1, "Adaptive thresholding currently only works for batch size of 1" 291 | strengths = [1.0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 292 | ref_mpjpe = None 293 | chosen_strength = None 294 | final_mpjpe = None 295 | chosen_sample = None 296 | count = 0 297 | for strength in strengths: 298 | seed_everything(seed) 299 | model.control_scales = ( 300 | [strength * (0.825 ** float(12 - i)) for i in range(13)] 301 | if guess_mode 302 | else ([strength] * 13) 303 | ) # Magic number. IDK why. 
Perhaps because 0.825**12<0.01 but 0.826**12>0.01 304 | samples, intermediates = ddim_sampler.sample( 305 | ddim_steps, 306 | num_samples, 307 | shape, 308 | cond, 309 | verbose=False, 310 | unconditional_guidance_scale=scale, 311 | unconditional_conditioning=un_cond, 312 | x0=z, 313 | mask=mask 314 | ) 315 | if config.save_memory: 316 | model.low_vram_shift(is_diffusing=False) 317 | 318 | x_samples = model.decode_first_stage(samples) 319 | 320 | x_samples = ( 321 | (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5) 322 | .cpu() 323 | .numpy() 324 | .clip(0, 255) 325 | .astype(np.uint8) 326 | ) 327 | mpjpe = meshgraphormer.eval_mpjpe(x_samples[0], info) 328 | if count == 0: 329 | ref_mpjpe = mpjpe 330 | chosen_sample = x_samples[0] 331 | elif mpjpe < ref_mpjpe * 1.15: 332 | chosen_strength = strength 333 | final_mpjpe = mpjpe 334 | chosen_sample = x_samples[0] 335 | break 336 | elif strength == 0.9: 337 | final_mpjpe = ref_mpjpe 338 | chosen_strength = 1.0 339 | count += 1 340 | 341 | if args.log_json != "": 342 | mpjpe_info = {"img": image_file, "strength": chosen_strength, "mpjpje": final_mpjpe} 343 | f_mpjpe.write(json.dumps(mpjpe_info)) 344 | f_mpjpe.write("\n") 345 | 346 | cv2.imwrite( 347 | os.path.join(args.out_dir, "{}_{}.jpg".format(file_name_raw, gen_count)), cv2.cvtColor(x_samples[0], cv2.COLOR_RGB2BGR) 348 | ) 349 | gen_count += 1 350 | 351 | 352 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/control_synthcompositedata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | import random 6 | 7 | from torch.utils.data import Dataset 8 | 9 | DATA_PATH_1 = "/raid/wenquanlu/RHD/RHD_published_v2/" 10 | DATA_PATH_2 = "/raid/wenquanlu/synthesisai/" 11 | 12 | abbrev_dict = {"RHD": DATA_PATH_1, 13 | "synthesisai": DATA_PATH_2} 14 | 15 | class Control_composite_Hand_synth_data(Dataset): 16 | def __init__(self): 17 | self.data = [] 18 | with open('../RHD/RHD_published_v2/rgb_caption.json', 'rt') as f1: 19 | for line in f1: 20 | item = json.loads(line) 21 | item['dataset'] = 'RHD' 22 | self.data.append(item) 23 | with open('../synthesisai/rgb_caption.json', 'rt') as f2: 24 | for line in f2: 25 | item = json.loads(line) 26 | item['dataset'] = 'synthesisai' 27 | self.data.append(item) 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, idx): 32 | item = self.data[idx] 33 | source_filename = item['jpg'] 34 | prompt = item['txt'] 35 | dataset = item['dataset'] 36 | datapath = abbrev_dict[dataset] 37 | if random.random() < 0.5: 38 | prompt = "" 39 | source = cv2.imread(datapath + "image/" + source_filename) 40 | source = (source.astype(np.float32) / 127.5) - 1.0 41 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB) 42 | 43 | mask = np.array(Image.open(datapath + "mask/" + source_filename).convert("L")) 44 | mask = mask.astype(np.float32)/255.0 45 | mask = mask[None] 46 | mask[mask < 0.5] = 0 47 | mask[mask >= 0.5] = 1 48 | mask = np.transpose(mask, [1, 2, 0]) 49 | 50 | hint = cv2.imread(datapath + "pose/" + source_filename) 51 | hint = cv2.cvtColor(hint, cv2.COLOR_BGR2RGB) 52 | 53 | hint = 
hint.astype(np.float32) / 255.0 54 | 55 | masked_image = source * (mask < 0.5) 56 | return dict(jpg=source, txt=prompt, hint=hint, mask=mask, masked_image=masked_image) -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
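# keep an exponential moving average of the model weights; ema_scope() below temporarily swaps the EMA
# weights in (e.g. for the extra validation/logging pass) and restores the training weights afterwards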
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
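# min-max rescale the randomly projected segmentation map to [-1, 1] so it can be logged like an RGB image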
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps 
else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. - mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters 
corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
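    For example, a tensor of shape (B,) becomes shape (B, 1, 1, 1) when target_dims=4.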
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | from typing import Optional, Any 8 | 9 | from ldm.modules.diffusionmodules.util import checkpoint 10 | 11 | 12 | try: 13 | import xformers 14 | import xformers.ops 15 | XFORMERS_IS_AVAILBLE = True 16 | except: 17 | XFORMERS_IS_AVAILBLE = False 18 | 19 | # CrossAttn precision handling 20 | import os 21 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | 27 | def uniq(arr): 28 | return{el: True for el in arr}.keys() 29 | 30 | 31 | def default(val, d): 32 | if exists(val): 33 | return val 34 | return d() if isfunction(d) else d 35 | 36 | 37 | def max_neg_value(t): 38 | return -torch.finfo(t.dtype).max 39 | 40 | 41 | def init_(tensor): 42 | dim = tensor.shape[-1] 43 | std = 1 / math.sqrt(dim) 44 | tensor.uniform_(-std, std) 45 | return tensor 46 | 47 | 48 | # feedforward 49 | class GEGLU(nn.Module): 50 | def __init__(self, dim_in, dim_out): 51 | super().__init__() 52 | self.proj = nn.Linear(dim_in, dim_out * 2) 53 | 54 | def forward(self, x): 55 | x, gate = self.proj(x).chunk(2, dim=-1) 56 | return x * F.gelu(gate) 57 | 58 | 59 | class FeedForward(nn.Module): 60 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 61 | super().__init__() 62 | inner_dim = int(dim * mult) 63 | dim_out = default(dim_out, dim) 64 | project_in = nn.Sequential( 65 | nn.Linear(dim, inner_dim), 66 | nn.GELU() 67 | ) if not glu else GEGLU(dim, inner_dim) 68 | 69 | self.net = nn.Sequential( 70 | project_in, 71 | nn.Dropout(dropout), 72 | nn.Linear(inner_dim, dim_out) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | def zero_module(module): 80 | """ 81 | Zero out the parameters of a module and return it. 
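    In this file it wraps the output projection of SpatialTransformer, so a freshly
    initialized transformer block starts out as an identity mapping through its
    residual connection.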
82 | """ 83 | for p in module.parameters(): 84 | p.detach().zero_() 85 | return module 86 | 87 | 88 | def Normalize(in_channels): 89 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 90 | 91 | 92 | class SpatialSelfAttention(nn.Module): 93 | def __init__(self, in_channels): 94 | super().__init__() 95 | self.in_channels = in_channels 96 | 97 | self.norm = Normalize(in_channels) 98 | self.q = torch.nn.Conv2d(in_channels, 99 | in_channels, 100 | kernel_size=1, 101 | stride=1, 102 | padding=0) 103 | self.k = torch.nn.Conv2d(in_channels, 104 | in_channels, 105 | kernel_size=1, 106 | stride=1, 107 | padding=0) 108 | self.v = torch.nn.Conv2d(in_channels, 109 | in_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | self.proj_out = torch.nn.Conv2d(in_channels, 114 | in_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | 119 | def forward(self, x): 120 | h_ = x 121 | h_ = self.norm(h_) 122 | q = self.q(h_) 123 | k = self.k(h_) 124 | v = self.v(h_) 125 | 126 | # compute attention 127 | b,c,h,w = q.shape 128 | q = rearrange(q, 'b c h w -> b (h w) c') 129 | k = rearrange(k, 'b c h w -> b c (h w)') 130 | w_ = torch.einsum('bij,bjk->bik', q, k) 131 | 132 | w_ = w_ * (int(c)**(-0.5)) 133 | w_ = torch.nn.functional.softmax(w_, dim=2) 134 | 135 | # attend to values 136 | v = rearrange(v, 'b c h w -> b c (h w)') 137 | w_ = rearrange(w_, 'b i j -> b j i') 138 | h_ = torch.einsum('bij,bjk->bik', v, w_) 139 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 140 | h_ = self.proj_out(h_) 141 | 142 | return x+h_ 143 | 144 | 145 | class CrossAttention(nn.Module): 146 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 147 | super().__init__() 148 | inner_dim = dim_head * heads 149 | context_dim = default(context_dim, query_dim) 150 | 151 | self.scale = dim_head ** -0.5 152 | self.heads = heads 153 | 154 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 155 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 156 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 157 | 158 | self.to_out = nn.Sequential( 159 | nn.Linear(inner_dim, query_dim), 160 | nn.Dropout(dropout) 161 | ) 162 | 163 | def forward(self, x, context=None, mask=None): 164 | h = self.heads 165 | 166 | q = self.to_q(x) 167 | context = default(context, x) 168 | k = self.to_k(context) 169 | v = self.to_v(context) 170 | 171 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 172 | 173 | # force cast to fp32 to avoid overflowing 174 | if _ATTN_PRECISION =="fp32": 175 | with torch.autocast(enabled=False, device_type = 'cuda'): 176 | q, k = q.float(), k.float() 177 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 178 | else: 179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 180 | 181 | del q, k 182 | 183 | if exists(mask): 184 | mask = rearrange(mask, 'b ... 
-> b (...)') 185 | max_neg_value = -torch.finfo(sim.dtype).max 186 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 187 | sim.masked_fill_(~mask, max_neg_value) 188 | 189 | # attention, what we cannot get enough of 190 | sim = sim.softmax(dim=-1) 191 | 192 | out = einsum('b i j, b j d -> b i d', sim, v) 193 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 194 | return self.to_out(out) 195 | 196 | 197 | class MemoryEfficientCrossAttention(nn.Module): 198 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 199 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 200 | super().__init__() 201 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " 202 | f"{heads} heads.") 203 | inner_dim = dim_head * heads 204 | context_dim = default(context_dim, query_dim) 205 | 206 | self.heads = heads 207 | self.dim_head = dim_head 208 | 209 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 210 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 211 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 212 | 213 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 214 | self.attention_op: Optional[Any] = None 215 | 216 | def forward(self, x, context=None, mask=None): 217 | q = self.to_q(x) 218 | context = default(context, x) 219 | k = self.to_k(context) 220 | v = self.to_v(context) 221 | 222 | b, _, _ = q.shape 223 | q, k, v = map( 224 | lambda t: t.unsqueeze(3) 225 | .reshape(b, t.shape[1], self.heads, self.dim_head) 226 | .permute(0, 2, 1, 3) 227 | .reshape(b * self.heads, t.shape[1], self.dim_head) 228 | .contiguous(), 229 | (q, k, v), 230 | ) 231 | 232 | # actually compute the attention, what we cannot get enough of 233 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 234 | 235 | if exists(mask): 236 | raise NotImplementedError 237 | out = ( 238 | out.unsqueeze(0) 239 | .reshape(b, self.heads, out.shape[1], self.dim_head) 240 | .permute(0, 2, 1, 3) 241 | .reshape(b, out.shape[1], self.heads * self.dim_head) 242 | ) 243 | return self.to_out(out) 244 | 245 | 246 | class BasicTransformerBlock(nn.Module): 247 | ATTENTION_MODES = { 248 | "softmax": CrossAttention, # vanilla attention 249 | "softmax-xformers": MemoryEfficientCrossAttention 250 | } 251 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 252 | disable_self_attn=False): 253 | super().__init__() 254 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" 255 | assert attn_mode in self.ATTENTION_MODES 256 | attn_cls = self.ATTENTION_MODES[attn_mode] 257 | self.disable_self_attn = disable_self_attn 258 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 259 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 260 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 261 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 262 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 263 | self.norm1 = nn.LayerNorm(dim) 264 | self.norm2 = nn.LayerNorm(dim) 265 | self.norm3 = nn.LayerNorm(dim) 266 | self.checkpoint = checkpoint 267 | 268 | def forward(self, x, context=None): 269 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 270 | 271 | 
def _forward(self, x, context=None): 272 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 273 | x = self.attn2(self.norm2(x), context=context) + x 274 | x = self.ff(self.norm3(x)) + x 275 | return x 276 | 277 | 278 | class SpatialTransformer(nn.Module): 279 | """ 280 | Transformer block for image-like data. 281 | First, project the input (aka embedding) 282 | and reshape to b, t, d. 283 | Then apply standard transformer action. 284 | Finally, reshape to image 285 | NEW: use_linear for more efficiency instead of the 1x1 convs 286 | """ 287 | def __init__(self, in_channels, n_heads, d_head, 288 | depth=1, dropout=0., context_dim=None, 289 | disable_self_attn=False, use_linear=False, 290 | use_checkpoint=True): 291 | super().__init__() 292 | if exists(context_dim) and not isinstance(context_dim, list): 293 | context_dim = [context_dim] 294 | self.in_channels = in_channels 295 | inner_dim = n_heads * d_head 296 | self.norm = Normalize(in_channels) 297 | if not use_linear: 298 | self.proj_in = nn.Conv2d(in_channels, 299 | inner_dim, 300 | kernel_size=1, 301 | stride=1, 302 | padding=0) 303 | else: 304 | self.proj_in = nn.Linear(in_channels, inner_dim) 305 | 306 | self.transformer_blocks = nn.ModuleList( 307 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 308 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 309 | for d in range(depth)] 310 | ) 311 | if not use_linear: 312 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 313 | in_channels, 314 | kernel_size=1, 315 | stride=1, 316 | padding=0)) 317 | else: 318 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 319 | self.use_linear = use_linear 320 | 321 | def forward(self, x, context=None): 322 | # note: if no context is given, cross-attention defaults to self-attention 323 | if not isinstance(context, list): 324 | context = [context] 325 | b, c, h, w = x.shape 326 | x_in = x 327 | x = self.norm(x) 328 | if not self.use_linear: 329 | x = self.proj_in(x) 330 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 331 | if self.use_linear: 332 | x = self.proj_in(x) 333 | for i, block in enumerate(self.transformer_blocks): 334 | x = block(x, context=context[i]) 335 | if self.use_linear: 336 | x = self.proj_out(x) 337 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 338 | if not self.use_linear: 339 | x = self.proj_out(x) 340 | return x + x_in 341 | 342 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | 
self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
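# Shared diffusion utilities: beta schedules (make_beta_schedule, betas_for_alpha_bar),
# DDIM timestep/sigma selection, gradient checkpointing (checkpoint / CheckpointFunction),
# sinusoidal timestep embeddings, and small nn helpers (zero_module, scale_module,
# conv_nd, avg_pool_nd, GroupNorm32, noise_like).
#
# Illustrative use of the embedding helper (shapes only, not part of the library code):
#   t = torch.randint(0, 1000, (4,))        # one timestep index per batch element
#   emb = timestep_embedding(t, dim=320)    # -> tensor of shape (4, 320)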
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
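    A common choice is the cosine schedule of Nichol & Dhariwal (2021),
    e.g. alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2.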
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 126 | "dtype": torch.get_autocast_gpu_dtype(), 127 | "cache_enabled": torch.is_autocast_cache_enabled()} 128 | with torch.no_grad(): 129 | output_tensors = ctx.run_function(*ctx.input_tensors) 130 | return output_tensors 131 | 132 | @staticmethod 133 | def backward(ctx, *output_grads): 134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 135 | with torch.enable_grad(), \ 136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 137 | # Fixes a bug where the first op in run_function modifies the 138 | # Tensor storage in place, which is not allowed for detach()'d 139 | # Tensors. 140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 141 | output_tensors = ctx.run_function(*shallow_copies) 142 | input_grads = torch.autograd.grad( 143 | output_tensors, 144 | ctx.input_tensors + ctx.input_params, 145 | output_grads, 146 | allow_unused=True, 147 | ) 148 | del ctx.input_tensors 149 | del ctx.input_params 150 | del output_tensors 151 | return (None, None) + input_grads 152 | 153 | 154 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 155 | """ 156 | Create sinusoidal timestep embeddings. 157 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 158 | These may be fractional. 159 | :param dim: the dimension of the output. 160 | :param max_period: controls the minimum frequency of the embeddings. 161 | :return: an [N x dim] Tensor of positional embeddings. 162 | """ 163 | if not repeat_only: 164 | half = dim // 2 165 | freqs = torch.exp( 166 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 167 | ).to(device=timesteps.device) 168 | args = timesteps[:, None].float() * freqs[None] 169 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 170 | if dim % 2: 171 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 172 | else: 173 | embedding = repeat(timesteps, 'b -> b d', d=dim) 174 | return embedding 175 | 176 | 177 | def zero_module(module): 178 | """ 179 | Zero out the parameters of a module and return it. 
180 | """ 181 | for p in module.parameters(): 182 | p.detach().zero_() 183 | return module 184 | 185 | 186 | def scale_module(module, scale): 187 | """ 188 | Scale the parameters of a module and return it. 189 | """ 190 | for p in module.parameters(): 191 | p.detach().mul_(scale) 192 | return module 193 | 194 | 195 | def mean_flat(tensor): 196 | """ 197 | Take the mean over all non-batch dimensions. 198 | """ 199 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 200 | 201 | 202 | def normalization(channels): 203 | """ 204 | Make a standard normalization layer. 205 | :param channels: number of input channels. 206 | :return: an nn.Module for normalization. 207 | """ 208 | return GroupNorm32(32, channels) 209 | 210 | 211 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 212 | class SiLU(nn.Module): 213 | def forward(self, x): 214 | return x * torch.sigmoid(x) 215 | 216 | 217 | class GroupNorm32(nn.GroupNorm): 218 | def forward(self, x): 219 | return super().forward(x.float()).type(x.dtype) 220 | 221 | def conv_nd(dims, *args, **kwargs): 222 | """ 223 | Create a 1D, 2D, or 3D convolution module. 224 | """ 225 | if dims == 1: 226 | return nn.Conv1d(*args, **kwargs) 227 | elif dims == 2: 228 | return nn.Conv2d(*args, **kwargs) 229 | elif dims == 3: 230 | return nn.Conv3d(*args, **kwargs) 231 | raise ValueError(f"unsupported dimensions: {dims}") 232 | 233 | 234 | def linear(*args, **kwargs): 235 | """ 236 | Create a linear module. 237 | """ 238 | return nn.Linear(*args, **kwargs) 239 | 240 | 241 | def avg_pool_nd(dims, *args, **kwargs): 242 | """ 243 | Create a 1D, 2D, or 3D average pooling module. 244 | """ 245 | if dims == 1: 246 | return nn.AvgPool1d(*args, **kwargs) 247 | elif dims == 2: 248 | return nn.AvgPool2d(*args, **kwargs) 249 | elif dims == 3: 250 | return nn.AvgPool3d(*args, **kwargs) 251 | raise ValueError(f"unsupported dimensions: {dims}") 252 | 253 | 254 | class HybridConditioner(nn.Module): 255 | 256 | def __init__(self, c_concat_config, c_crossattn_config): 257 | super().__init__() 258 | self.concat_conditioner = instantiate_from_config(c_concat_config) 259 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 260 | 261 | def forward(self, c_concat, c_crossattn): 262 | c_concat = self.concat_conditioner(c_concat) 263 | c_crossattn = self.crossattn_conditioner(c_crossattn) 264 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 265 | 266 | 267 | def noise_like(shape, device, repeat=False): 268 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 269 | noise = lambda: torch.randn(shape, device=device) 270 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | 
self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
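    # (This coercion is what allows calls such as normal_kl(mean, logvar, 0., 0.),
    # i.e. the elementwise KL divergence against a standard normal.)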
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
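        Typical pattern: store(model.parameters()), copy_to(model) to evaluate with the
        EMA weights, then restore(model.parameters()) to continue training from the
        original parameters.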
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | 8 | from ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to("cuda") 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | 135 | 136 | class FrozenCLIPT5Encoder(AbstractEncoder): 137 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 138 | clip_max_length=77, t5_max_length=77): 139 | super().__init__() 140 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 141 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 142 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 143 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 144 | 145 | def encode(self, text): 146 | return self(text) 147 | 148 | def forward(self, text): 149 | clip_z = self.clip_encoder.encode(text) 150 | t5_z = self.t5_encoder.encode(text) 151 | return [clip_z, t5_z] 152 | 153 | 154 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from 
ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = 
NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
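        Checkpoints that also contain optimizer state are unwrapped via their
        "model" key before the state dict is loaded.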
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 
| def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 
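        Two 3x3 convolutions with optional BatchNorm; the residual addition goes
        through nn.quantized.FloatFunctional so it remains compatible with
        quantization tooling.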
237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class Slice(nn.Module): 10 | def __init__(self, start_index=1): 11 | super(Slice, self).__init__() 12 | self.start_index = start_index 13 | 14 | def forward(self, x): 15 | return x[:, self.start_index :] 16 | 17 | 18 | class AddReadout(nn.Module): 19 | def __init__(self, start_index=1): 20 | super(AddReadout, self).__init__() 21 | self.start_index = start_index 22 | 23 | def forward(self, x): 24 | if self.start_index == 2: 25 | readout = (x[:, 0] + x[:, 1]) / 2 26 | else: 27 | readout = x[:, 0] 28 | return x[:, self.start_index :] + readout.unsqueeze(1) 29 | 30 | 31 | class ProjectReadout(nn.Module): 32 | def __init__(self, in_features, start_index=1): 33 | super(ProjectReadout, self).__init__() 34 | self.start_index = start_index 35 | 36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 37 | 38 | def forward(self, x): 39 | 
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 40 | features = torch.cat((x[:, self.start_index :], readout), -1) 41 | 42 | return self.project(features) 43 | 44 | 45 | class Transpose(nn.Module): 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0 = dim0 49 | self.dim1 = dim1 50 | 51 | def forward(self, x): 52 | x = x.transpose(self.dim0, self.dim1) 53 | return x 54 | 55 | 56 | def forward_vit(pretrained, x): 57 | b, c, h, w = x.shape 58 | 59 | glob = pretrained.model.forward_flex(x) 60 | 61 | layer_1 = pretrained.activations["1"] 62 | layer_2 = pretrained.activations["2"] 63 | layer_3 = pretrained.activations["3"] 64 | layer_4 = pretrained.activations["4"] 65 | 66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 70 | 71 | unflatten = nn.Sequential( 72 | nn.Unflatten( 73 | 2, 74 | torch.Size( 75 | [ 76 | h // pretrained.model.patch_size[1], 77 | w // pretrained.model.patch_size[0], 78 | ] 79 | ), 80 | ) 81 | ) 82 | 83 | if layer_1.ndim == 3: 84 | layer_1 = unflatten(layer_1) 85 | if layer_2.ndim == 3: 86 | layer_2 = unflatten(layer_2) 87 | if layer_3.ndim == 3: 88 | layer_3 = unflatten(layer_3) 89 | if layer_4.ndim == 3: 90 | layer_4 = unflatten(layer_4) 91 | 92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 96 | 97 | return layer_1, layer_2, layer_3, layer_4 98 | 99 | 100 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 101 | posemb_tok, posemb_grid = ( 102 | posemb[:, : self.start_index], 103 | posemb[0, self.start_index :], 104 | ) 105 | 106 | gs_old = int(math.sqrt(len(posemb_grid))) 107 | 108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 111 | 112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 113 | 114 | return posemb 115 | 116 | 117 | def forward_flex(self, x): 118 | b, c, h, w = x.shape 119 | 120 | pos_embed = self._resize_pos_embed( 121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 122 | ) 123 | 124 | B = x.shape[0] 125 | 126 | if hasattr(self.patch_embed, "backbone"): 127 | x = self.patch_embed.backbone(x) 128 | if isinstance(x, (list, tuple)): 129 | x = x[-1] # last feature if backbone outputs list/tuple of features 130 | 131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 132 | 133 | if getattr(self, "dist_token", None) is not None: 134 | cls_tokens = self.cls_token.expand( 135 | B, -1, -1 136 | ) # stole cls_tokens impl from Phil Wang, thanks 137 | dist_token = self.dist_token.expand(B, -1, -1) 138 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 139 | else: 140 | cls_tokens = self.cls_token.expand( 141 | B, -1, -1 142 | ) # stole cls_tokens impl from Phil Wang, thanks 143 | x = torch.cat((cls_tokens, x), dim=1) 144 | 145 | x = x + pos_embed 146 | x = self.pos_drop(x) 147 | 148 | for blk in self.blocks: 149 | x = blk(x) 150 | 151 | x = self.norm(x) 152 | 153 | return x 154 | 155 | 156 | activations = {} 
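# `activations` is shared, module-level state: get_activation(name) below returns a
# forward hook that stores the output of a hooked ViT block in this dict under `name`.
# _make_vit_b16_backbone and _make_vit_b_rn50_backbone register such hooks on four
# transformer blocks (keys "1".."4"), and forward_vit() reads those entries back as the
# multi-scale features fed to the act_postprocess heads.
# Minimal illustrative sketch of the hook pattern (hypothetical names, not part of this module):
#
#   vit = timm.create_model("vit_base_patch16_384", pretrained=False)
#   vit.blocks[11].register_forward_hook(get_activation("4"))
#   _ = vit(torch.randn(1, 3, 384, 384))   # hook fires, fills activations["4"]
#   tokens = activations["4"]              # (B, 577, 768) token sequence from block 11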
157 | 158 | 159 | def get_activation(name): 160 | def hook(model, input, output): 161 | activations[name] = output 162 | 163 | return hook 164 | 165 | 166 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 167 | if use_readout == "ignore": 168 | readout_oper = [Slice(start_index)] * len(features) 169 | elif use_readout == "add": 170 | readout_oper = [AddReadout(start_index)] * len(features) 171 | elif use_readout == "project": 172 | readout_oper = [ 173 | ProjectReadout(vit_features, start_index) for out_feat in features 174 | ] 175 | else: 176 | assert ( 177 | False 178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 179 | 180 | return readout_oper 181 | 182 | 183 | def _make_vit_b16_backbone( 184 | model, 185 | features=[96, 192, 384, 768], 186 | size=[384, 384], 187 | hooks=[2, 5, 8, 11], 188 | vit_features=768, 189 | use_readout="ignore", 190 | start_index=1, 191 | ): 192 | pretrained = nn.Module() 193 | 194 | pretrained.model = model 195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 199 | 200 | pretrained.activations = activations 201 | 202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 203 | 204 | # 32, 48, 136, 384 205 | pretrained.act_postprocess1 = nn.Sequential( 206 | readout_oper[0], 207 | Transpose(1, 2), 208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 209 | nn.Conv2d( 210 | in_channels=vit_features, 211 | out_channels=features[0], 212 | kernel_size=1, 213 | stride=1, 214 | padding=0, 215 | ), 216 | nn.ConvTranspose2d( 217 | in_channels=features[0], 218 | out_channels=features[0], 219 | kernel_size=4, 220 | stride=4, 221 | padding=0, 222 | bias=True, 223 | dilation=1, 224 | groups=1, 225 | ), 226 | ) 227 | 228 | pretrained.act_postprocess2 = nn.Sequential( 229 | readout_oper[1], 230 | Transpose(1, 2), 231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 232 | nn.Conv2d( 233 | in_channels=vit_features, 234 | out_channels=features[1], 235 | kernel_size=1, 236 | stride=1, 237 | padding=0, 238 | ), 239 | nn.ConvTranspose2d( 240 | in_channels=features[1], 241 | out_channels=features[1], 242 | kernel_size=2, 243 | stride=2, 244 | padding=0, 245 | bias=True, 246 | dilation=1, 247 | groups=1, 248 | ), 249 | ) 250 | 251 | pretrained.act_postprocess3 = nn.Sequential( 252 | readout_oper[2], 253 | Transpose(1, 2), 254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 255 | nn.Conv2d( 256 | in_channels=vit_features, 257 | out_channels=features[2], 258 | kernel_size=1, 259 | stride=1, 260 | padding=0, 261 | ), 262 | ) 263 | 264 | pretrained.act_postprocess4 = nn.Sequential( 265 | readout_oper[3], 266 | Transpose(1, 2), 267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 268 | nn.Conv2d( 269 | in_channels=vit_features, 270 | out_channels=features[3], 271 | kernel_size=1, 272 | stride=1, 273 | padding=0, 274 | ), 275 | nn.Conv2d( 276 | in_channels=features[3], 277 | out_channels=features[3], 278 | kernel_size=3, 279 | stride=2, 280 | padding=1, 281 | ), 282 | ) 283 | 284 | pretrained.model.start_index = start_index 285 | pretrained.model.patch_size = [16, 16] 286 | 287 | # We inject this function into the VisionTransformer instances so that 288 | # we 
can use it with interpolated position embeddings without modifying the library source. 289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 290 | pretrained.model._resize_pos_embed = types.MethodType( 291 | _resize_pos_embed, pretrained.model 292 | ) 293 | 294 | return pretrained 295 | 296 | 297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): 298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 299 | 300 | hooks = [5, 11, 17, 23] if hooks == None else hooks 301 | return _make_vit_b16_backbone( 302 | model, 303 | features=[256, 512, 1024, 1024], 304 | hooks=hooks, 305 | vit_features=1024, 306 | use_readout=use_readout, 307 | ) 308 | 309 | 310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): 311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 312 | 313 | hooks = [2, 5, 8, 11] if hooks == None else hooks 314 | return _make_vit_b16_backbone( 315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 316 | ) 317 | 318 | 319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): 320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 321 | 322 | hooks = [2, 5, 8, 11] if hooks == None else hooks 323 | return _make_vit_b16_backbone( 324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 325 | ) 326 | 327 | 328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): 329 | model = timm.create_model( 330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 331 | ) 332 | 333 | hooks = [2, 5, 8, 11] if hooks == None else hooks 334 | return _make_vit_b16_backbone( 335 | model, 336 | features=[96, 192, 384, 768], 337 | hooks=hooks, 338 | use_readout=use_readout, 339 | start_index=2, 340 | ) 341 | 342 | 343 | def _make_vit_b_rn50_backbone( 344 | model, 345 | features=[256, 512, 768, 768], 346 | size=[384, 384], 347 | hooks=[0, 1, 8, 11], 348 | vit_features=768, 349 | use_vit_only=False, 350 | use_readout="ignore", 351 | start_index=1, 352 | ): 353 | pretrained = nn.Module() 354 | 355 | pretrained.model = model 356 | 357 | if use_vit_only == True: 358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 360 | else: 361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 362 | get_activation("1") 363 | ) 364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 365 | get_activation("2") 366 | ) 367 | 368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 370 | 371 | pretrained.activations = activations 372 | 373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 374 | 375 | if use_vit_only == True: 376 | pretrained.act_postprocess1 = nn.Sequential( 377 | readout_oper[0], 378 | Transpose(1, 2), 379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 380 | nn.Conv2d( 381 | in_channels=vit_features, 382 | out_channels=features[0], 383 | kernel_size=1, 384 | stride=1, 385 | padding=0, 386 | ), 387 | nn.ConvTranspose2d( 388 | in_channels=features[0], 389 | out_channels=features[0], 390 | kernel_size=4, 391 | stride=4, 392 | padding=0, 393 | bias=True, 394 | dilation=1, 395 | groups=1, 396 | ), 397 | ) 398 | 399 | 
pretrained.act_postprocess2 = nn.Sequential( 400 | readout_oper[1], 401 | Transpose(1, 2), 402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 403 | nn.Conv2d( 404 | in_channels=vit_features, 405 | out_channels=features[1], 406 | kernel_size=1, 407 | stride=1, 408 | padding=0, 409 | ), 410 | nn.ConvTranspose2d( 411 | in_channels=features[1], 412 | out_channels=features[1], 413 | kernel_size=2, 414 | stride=2, 415 | padding=0, 416 | bias=True, 417 | dilation=1, 418 | groups=1, 419 | ), 420 | ) 421 | else: 422 | pretrained.act_postprocess1 = nn.Sequential( 423 | nn.Identity(), nn.Identity(), nn.Identity() 424 | ) 425 | pretrained.act_postprocess2 = nn.Sequential( 426 | nn.Identity(), nn.Identity(), nn.Identity() 427 | ) 428 | 429 | pretrained.act_postprocess3 = nn.Sequential( 430 | readout_oper[2], 431 | Transpose(1, 2), 432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 433 | nn.Conv2d( 434 | in_channels=vit_features, 435 | out_channels=features[2], 436 | kernel_size=1, 437 | stride=1, 438 | padding=0, 439 | ), 440 | ) 441 | 442 | pretrained.act_postprocess4 = nn.Sequential( 443 | readout_oper[3], 444 | Transpose(1, 2), 445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 446 | nn.Conv2d( 447 | in_channels=vit_features, 448 | out_channels=features[3], 449 | kernel_size=1, 450 | stride=1, 451 | padding=0, 452 | ), 453 | nn.Conv2d( 454 | in_channels=features[3], 455 | out_channels=features[3], 456 | kernel_size=3, 457 | stride=2, 458 | padding=1, 459 | ), 460 | ) 461 | 462 | pretrained.model.start_index = start_index 463 | pretrained.model.patch_size = [16, 16] 464 | 465 | # We inject this function into the VisionTransformer instances so that 466 | # we can use it with interpolated position embeddings without modifying the library source. 467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 468 | 469 | # We inject this function into the VisionTransformer instances so that 470 | # we can use it with interpolated position embeddings without modifying the library source. 471 | pretrained.model._resize_pos_embed = types.MethodType( 472 | _resize_pos_embed, pretrained.model 473 | ) 474 | 475 | return pretrained 476 | 477 | 478 | def _make_pretrained_vitb_rn50_384( 479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False 480 | ): 481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 482 | 483 | hooks = [0, 1, 8, 11] if hooks == None else hooks 484 | return _make_vit_b_rn50_backbone( 485 | model, 486 | features=[256, 512, 768, 768], 487 | size=[384, 384], 488 | hooks=hooks, 489 | use_vit_only=use_vit_only, 490 | use_readout=use_readout, 491 | ) 492 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Cant encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions. 
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /preprocessor/depth_preprocessor.py: -------------------------------------------------------------------------------- 1 | class Preprocessor: 2 | def __init__(self) -> None: 3 | pass 4 | 5 | def get_depth(self, input_dir, file_name): 6 | return -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.1 2 | azureml==0.2.7 3 | chumpy==0.70 4 | einops==0.7.0 5 | matplotlib==3.7.1 6 | mediapipe 7 | numpy==1.23.5 8 | omegaconf==2.1.1 9 | opencv_contrib_python==4.7.0.72 10 | opencv_python==4.7.0.72 11 | opencv_python_headless==4.7.0.72 12 | Pillow==9.4.0 13 | pytorch_lightning==1.4.2 14 | pytorch_pretrained_bert==0.6.2 15 | safetensors==0.3.3 16 | scipy==1.9.0 17 | timm==0.6.13 18 | torch==2.0.0 19 | torchvision==0.15.1 20 | tqdm==4.65.0 21 | transformers==4.27.4 22 | trimesh[easy]==3.23.5 23 | yacs==0.1.8 24 | 25 | # If encounter any error, please see ControlNet and MeshGraphormer for more complete package requirements. 
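# Typical setup (illustrative; assumes a fresh Python virtual environment):
#   pip install -r requirements.txt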
26 | -------------------------------------------------------------------------------- /scripts/_gcnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import scipy.sparse 6 | import math 7 | 8 | class SparseMM(torch.autograd.Function): 9 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 10 | The builtin matrix multiplication operation does not support backpropagation in some cases. 11 | """ 12 | @staticmethod 13 | def forward(ctx, sparse, dense): 14 | ctx.req_grad = dense.requires_grad 15 | ctx.save_for_backward(sparse) 16 | return torch.matmul(sparse, dense) 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | grad_input = None 21 | sparse, = ctx.saved_tensors 22 | if ctx.req_grad: 23 | grad_input = torch.matmul(sparse.t(), grad_output) 24 | return None, grad_input 25 | 26 | def spmm(sparse, dense): 27 | return SparseMM.apply(sparse, dense) 28 | 29 | 30 | def gelu(x): 31 | """Implementation of the gelu activation function. 32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 34 | Also see https://arxiv.org/abs/1606.08415 35 | """ 36 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 37 | 38 | class BertLayerNorm(torch.nn.Module): 39 | def __init__(self, hidden_size, eps=1e-12): 40 | """Construct a layernorm module in the TF style (epsilon inside the square root). 41 | """ 42 | super(BertLayerNorm, self).__init__() 43 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 44 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) 45 | self.variance_epsilon = eps 46 | 47 | def forward(self, x): 48 | u = x.mean(-1, keepdim=True) 49 | s = (x - u).pow(2).mean(-1, keepdim=True) 50 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 51 | return self.weight * x + self.bias 52 | 53 | 54 | class GraphResBlock(torch.nn.Module): 55 | """ 56 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet 57 | """ 58 | def __init__(self, in_channels, out_channels, mesh_type='body'): 59 | super(GraphResBlock, self).__init__() 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.lin1 = GraphLinear(in_channels, out_channels // 2) 63 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type) 64 | self.lin2 = GraphLinear(out_channels // 2, out_channels) 65 | self.skip_conv = GraphLinear(in_channels, out_channels) 66 | # print('Use BertLayerNorm in GraphResBlock') 67 | self.pre_norm = BertLayerNorm(in_channels) 68 | self.norm1 = BertLayerNorm(out_channels // 2) 69 | self.norm2 = BertLayerNorm(out_channels // 2) 70 | 71 | def forward(self, x): 72 | trans_y = F.relu(self.pre_norm(x)).transpose(1,2) 73 | y = self.lin1(trans_y).transpose(1,2) 74 | 75 | y = F.relu(self.norm1(y)) 76 | y = self.conv(y) 77 | 78 | trans_y = F.relu(self.norm2(y)).transpose(1,2) 79 | y = self.lin2(trans_y).transpose(1,2) 80 | 81 | z = x+y 82 | 83 | return z 84 | 85 | # class GraphResBlock(torch.nn.Module): 86 | # """ 87 | # Graph Residual Block similar to the Bottleneck Residual Block in ResNet 88 | # """ 89 | # def __init__(self, in_channels, out_channels, mesh_type='body'): 90 | # super(GraphResBlock, self).__init__() 91 | # self.in_channels = in_channels 92 | # self.out_channels = out_channels 93 | # self.conv = 
GraphConvolution(self.in_channels, self.out_channels, mesh_type) 94 | # print('Use BertLayerNorm and GeLU in GraphResBlock') 95 | # self.norm = BertLayerNorm(self.out_channels) 96 | # def forward(self, x): 97 | # y = self.conv(x) 98 | # y = self.norm(y) 99 | # y = gelu(y) 100 | # z = x+y 101 | # return z 102 | 103 | class GraphLinear(torch.nn.Module): 104 | """ 105 | Generalization of 1x1 convolutions on Graphs 106 | """ 107 | def __init__(self, in_channels, out_channels): 108 | super(GraphLinear, self).__init__() 109 | self.in_channels = in_channels 110 | self.out_channels = out_channels 111 | self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels)) 112 | self.b = torch.nn.Parameter(torch.FloatTensor(out_channels)) 113 | self.reset_parameters() 114 | 115 | def reset_parameters(self): 116 | w_stdv = 1 / (self.in_channels * self.out_channels) 117 | self.W.data.uniform_(-w_stdv, w_stdv) 118 | self.b.data.uniform_(-w_stdv, w_stdv) 119 | 120 | def forward(self, x): 121 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None] 122 | 123 | class GraphConvolution(torch.nn.Module): 124 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" 125 | def __init__(self, in_features, out_features, mesh='body', bias=True): 126 | super(GraphConvolution, self).__init__() 127 | device=torch.device('cuda') 128 | self.in_features = in_features 129 | self.out_features = out_features 130 | 131 | if mesh=='body': 132 | adj_indices = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_indices.pt') 133 | adj_mat_value = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_values.pt') 134 | adj_mat_size = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_size.pt') 135 | elif mesh=='hand': 136 | adj_indices = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_indices.pt') 137 | adj_mat_value = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_values.pt') 138 | adj_mat_size = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_size.pt') 139 | 140 | self.adjmat = torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size).to(device) 141 | 142 | self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features)) 143 | if bias: 144 | self.bias = torch.nn.Parameter(torch.FloatTensor(out_features)) 145 | else: 146 | self.register_parameter('bias', None) 147 | self.reset_parameters() 148 | 149 | def reset_parameters(self): 150 | # stdv = 1. / math.sqrt(self.weight.size(1)) 151 | stdv = 6. 
/ math.sqrt(self.weight.size(0) + self.weight.size(1)) 152 | self.weight.data.uniform_(-stdv, stdv) 153 | if self.bias is not None: 154 | self.bias.data.uniform_(-stdv, stdv) 155 | 156 | def forward(self, x): 157 | if x.ndimension() == 2: 158 | support = torch.matmul(x, self.weight) 159 | output = torch.matmul(self.adjmat, support) 160 | if self.bias is not None: 161 | output = output + self.bias 162 | return output 163 | else: 164 | output = [] 165 | for i in range(x.shape[0]): 166 | support = torch.matmul(x[i], self.weight) 167 | # output.append(torch.matmul(self.adjmat, support)) 168 | output.append(spmm(self.adjmat, support)) 169 | output = torch.stack(output, dim=0) 170 | if self.bias is not None: 171 | output = output + self.bias 172 | return output 173 | 174 | def __repr__(self): 175 | return self.__class__.__name__ + ' (' \ 176 | + str(self.in_features) + ' -> ' \ 177 | + str(self.out_features) + ')' -------------------------------------------------------------------------------- /scripts/_mano.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the MANO defination and mesh sampling operations for MANO mesh 3 | 4 | Adapted from opensource projects 5 | MANOPTH (https://github.com/hassony2/manopth) 6 | Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 7 | GraphCMR (https://github.com/nkolot/GraphCMR/) 8 | """ 9 | 10 | from __future__ import division 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import os.path as osp 15 | import json 16 | import code 17 | from manopth.manolayer import ManoLayer 18 | import scipy.sparse 19 | import src.modeling.data.config as cfg 20 | 21 | class MANO(nn.Module): 22 | def __init__(self): 23 | super(MANO, self).__init__() 24 | 25 | self.mano_dir = 'MeshGraphormer/src/modeling/data' 26 | self.layer = self.get_layer() 27 | self.vertex_num = 778 28 | self.face = self.layer.th_faces.numpy() 29 | self.joint_regressor = self.layer.th_J_regressor.numpy() 30 | 31 | self.joint_num = 21 32 | self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') 33 | self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) 34 | self.root_joint_idx = self.joints_name.index('Wrist') 35 | 36 | # add fingertips to joint_regressor 37 | self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) 38 | thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 39 | indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 40 | middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 41 | ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 42 | pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 43 | self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) 44 | self.joint_regressor = 
self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:] 45 | joint_regressor_torch = torch.from_numpy(self.joint_regressor).float() 46 | self.register_buffer('joint_regressor_torch', joint_regressor_torch) 47 | 48 | def get_layer(self): 49 | return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model 50 | 51 | def get_3d_joints(self, vertices): 52 | """ 53 | This method is used to get the joint locations from the SMPL mesh 54 | Input: 55 | vertices: size = (B, 778, 3) 56 | Output: 57 | 3D joints: size = (B, 21, 3) 58 | """ 59 | joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch]) 60 | return joints 61 | 62 | 63 | class SparseMM(torch.autograd.Function): 64 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 65 | The builtin matrix multiplication operation does not support backpropagation in some cases. 66 | """ 67 | @staticmethod 68 | def forward(ctx, sparse, dense): 69 | ctx.req_grad = dense.requires_grad 70 | ctx.save_for_backward(sparse) 71 | return torch.matmul(sparse, dense) 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | grad_input = None 76 | sparse, = ctx.saved_tensors 77 | if ctx.req_grad: 78 | grad_input = torch.matmul(sparse.t(), grad_output) 79 | return None, grad_input 80 | 81 | def spmm(sparse, dense): 82 | return SparseMM.apply(sparse, dense) 83 | 84 | 85 | def scipy_to_pytorch(A, U, D): 86 | """Convert scipy sparse matrices to pytorch sparse matrix.""" 87 | ptU = [] 88 | ptD = [] 89 | 90 | for i in range(len(U)): 91 | u = scipy.sparse.coo_matrix(U[i]) 92 | i = torch.LongTensor(np.array([u.row, u.col])) 93 | v = torch.FloatTensor(u.data) 94 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) 95 | 96 | for i in range(len(D)): 97 | d = scipy.sparse.coo_matrix(D[i]) 98 | i = torch.LongTensor(np.array([d.row, d.col])) 99 | v = torch.FloatTensor(d.data) 100 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) 101 | 102 | return ptU, ptD 103 | 104 | 105 | def adjmat_sparse(adjmat, nsize=1): 106 | """Create row-normalized sparse graph adjacency matrix.""" 107 | adjmat = scipy.sparse.csr_matrix(adjmat) 108 | if nsize > 1: 109 | orig_adjmat = adjmat.copy() 110 | for _ in range(1, nsize): 111 | adjmat = adjmat * orig_adjmat 112 | adjmat.data = np.ones_like(adjmat.data) 113 | for i in range(adjmat.shape[0]): 114 | adjmat[i,i] = 1 115 | num_neighbors = np.array(1 / adjmat.sum(axis=-1)) 116 | adjmat = adjmat.multiply(num_neighbors) 117 | adjmat = scipy.sparse.coo_matrix(adjmat) 118 | row = adjmat.row 119 | col = adjmat.col 120 | data = adjmat.data 121 | i = torch.LongTensor(np.array([row, col])) 122 | v = torch.from_numpy(data).float() 123 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) 124 | return adjmat 125 | 126 | def get_graph_params(filename, nsize=1): 127 | """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" 128 | data = np.load(filename, encoding='latin1', allow_pickle=True) 129 | A = data['A'] 130 | U = data['U'] 131 | D = data['D'] 132 | U, D = scipy_to_pytorch(A, U, D) 133 | A = [adjmat_sparse(a, nsize=nsize) for a in A] 134 | return A, U, D 135 | 136 | 137 | class Mesh(object): 138 | """Mesh object that is used for handling certain graph operations.""" 139 | def __init__(self, filename=cfg.MANO_sampling_matrix, 140 | num_downsampling=1, nsize=1, device=torch.device('cuda')): 141 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) 142 | # 
self._A = [a.to(device) for a in self._A] 143 | self._U = [u.to(device) for u in self._U] 144 | self._D = [d.to(device) for d in self._D] 145 | self.num_downsampling = num_downsampling 146 | 147 | def downsample(self, x, n1=0, n2=None): 148 | """Downsample mesh.""" 149 | if n2 is None: 150 | n2 = self.num_downsampling 151 | if x.ndimension() < 3: 152 | for i in range(n1, n2): 153 | x = spmm(self._D[i], x) 154 | elif x.ndimension() == 3: 155 | out = [] 156 | for i in range(x.shape[0]): 157 | y = x[i] 158 | for j in range(n1, n2): 159 | y = spmm(self._D[j], y) 160 | out.append(y) 161 | x = torch.stack(out, dim=0) 162 | return x 163 | 164 | def upsample(self, x, n1=1, n2=0): 165 | """Upsample mesh.""" 166 | if x.ndimension() < 3: 167 | for i in reversed(range(n2, n1)): 168 | x = spmm(self._U[i], x) 169 | elif x.ndimension() == 3: 170 | out = [] 171 | for i in range(x.shape[0]): 172 | y = x[i] 173 | for j in reversed(range(n2, n1)): 174 | y = spmm(self._U[j], y) 175 | out.append(y) 176 | x = torch.stack(out, dim=0) 177 | return x 178 | -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of useful data stuctures and the paths 3 | for the datasets and data files necessary to run the code. 4 | 5 | Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 6 | 7 | """ 8 | 9 | from os.path import join 10 | folder_path = 'MeshGraphormer/src/modeling/' 11 | JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy' 12 | JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy' 13 | SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl' 14 | SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl' 15 | SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl' 16 | SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz' 17 | MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl' 18 | MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz' 19 | 20 | JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27] 21 | 22 | 23 | """ 24 | We follow the body joint definition, loss functions, and evaluation metrics from 25 | open source project GraphCMR (https://github.com/nkolot/GraphCMR/) 26 | 27 | Each dataset uses different sets of joints. 28 | We use a superset of 24 joints such that we include all joints from every dataset. 29 | If a dataset doesn't provide annotations for a specific joint, we simply ignore it. 
30 | The joints used here are: 31 | """ 32 | J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 33 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') 34 | H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head', 35 | 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') 36 | J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18] 37 | H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10] 38 | 39 | """ 40 | We follow the hand joint definition and mesh topology from 41 | open source project Manopth (https://github.com/hassony2/manopth) 42 | 43 | The hand joints used here are: 44 | """ 45 | J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 46 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') 47 | ROOT_INDEX = 0 -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! -d $REPO_DIR/models ] ; then 6 | mkdir -p $REPO_DIR/models 7 | fi 8 | BLOB='https://datarelease.blob.core.windows.net/metro' 9 | 10 | 11 | # -------------------------------- 12 | # Download our pre-trained models 13 | # -------------------------------- 14 | if [ ! -d $REPO_DIR/models/graphormer_release ] ; then 15 | mkdir -p $REPO_DIR/models/graphormer_release 16 | fi 17 | 18 | # (3) Mesh Graphormer for hand mesh reconstruction (trained on FreiHAND) 19 | wget -nc $BLOB/models/graphormer_hand_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_hand_state_dict.bin 20 | 21 | 22 | # -------------------------------- 23 | # Download the ImageNet pre-trained HRNet models 24 | # The weights are provided by https://github.com/HRNet/HRNet-Image-Classification 25 | # -------------------------------- 26 | if [ ! 
-d $REPO_DIR/models/hrnet ] ; then 27 | mkdir -p $REPO_DIR/models/hrnet 28 | fi 29 | wget -nc $BLOB/models/hrnetv2_w64_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w64_imagenet_pretrained.pth 30 | wget -nc $BLOB/models/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /test/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/1.jpg -------------------------------------------------------------------------------- /test/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/2.jpg -------------------------------------------------------------------------------- /test/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/3.jpg -------------------------------------------------------------------------------- /test/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/4.jpg -------------------------------------------------------------------------------- /test/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/5.jpg -------------------------------------------------------------------------------- /test/test.json: -------------------------------------------------------------------------------- 1 | {"img": "1.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"} 2 | {"img": "2.jpg", "txt": "a woman facing the camera, making a hand gesture, indoor"} 3 | {"img": "3.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"} 4 | {"img": "4.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"} 5 | {"img": "5.jpg", "txt": "a woman facing the camera, making a hand gesture, indoor"} -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | ## Training Script - train.py 2 | The training script should be placed at the same level of the cldm folder. 3 | Some paths needed to be manually set: 4 | 5 |
  • L40: path to the SD-1.5 inpainting checkpoint (sd-v1-5-inpainting.ckpt) 6 |
  • L43: path to the depth ControlNet weights (control_v11f1p_sd15_depth.pth) 7 | 8 | ## Data Loader - control_synthcompositedata.py 9 | The loader should be placed in ldm/data/ 10 | 11 | Some paths need to be manually set (see the list below). 12 | 13 | The dataset needs to be structured as: 14 | ```bash 15 | |- dataset1 16 | | |- image 17 | | |- mask 18 | | |- pose 19 | | |- prompt.json 20 | ``` 21 | The paths that need to be manually set are: 22 |
  • L9: path to dataset 1 (RHD) 23 |
  • L10: path to dataset 2 (synthesisai) 24 |
  • L18: path to dataset 1 (RHD) prompt json file 25 |
  • L23: path to dataset 2 prompt json file 26 | 27 | Each prompt json file are structured as: 28 | ```json 29 | {"jpg": "image name", "txt": "text prompt", "dataset": "dataset identifier (RHD|synthesisai)"} 30 | ``` 31 | -------------------------------------------------------------------------------- /training/control_synthcompositedata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | import random 6 | 7 | from torch.utils.data import Dataset 8 | 9 | DATA_PATH_1 = "../RHD/RHD_published_v2/" 10 | DATA_PATH_2 = "../synthesisai/" 11 | 12 | abbrev_dict = {"RHD": DATA_PATH_1, 13 | "synthesisai": DATA_PATH_2} 14 | 15 | class Control_composite_Hand_synth_data(Dataset): 16 | def __init__(self): 17 | self.data = [] 18 | with open('../RHD/RHD_published_v2/embedded_rgb_caption.json', 'rt') as f1: 19 | for line in f1: 20 | item = json.loads(line) 21 | item['dataset'] = 'RHD' 22 | self.data.append(item) 23 | with open('../synthesisai/embedded_rgb_caption.json', 'rt') as f2: 24 | for line in f2: 25 | item = json.loads(line) 26 | item['dataset'] = 'synthesisai' 27 | self.data.append(item) 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, idx): 32 | item = self.data[idx] 33 | source_filename = item['jpg'] 34 | prompt = item['txt'] 35 | dataset = item['dataset'] 36 | datapath = abbrev_dict[dataset] 37 | if random.random() < 0.5: 38 | prompt = "" 39 | source = cv2.imread(datapath + "image/" + source_filename) 40 | source = (source.astype(np.float32) / 127.5) - 1.0 41 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB) 42 | 43 | mask = np.array(Image.open(datapath + "mask/" + source_filename).convert("L")) 44 | mask = mask.astype(np.float32)/255.0 45 | mask = mask[None] 46 | mask[mask < 0.5] = 0 47 | mask[mask >= 0.5] = 1 48 | mask = np.transpose(mask, [1, 2, 0]) 49 | 50 | hint = cv2.imread(datapath + "pose/" + source_filename) 51 | hint = cv2.cvtColor(hint, cv2.COLOR_BGR2RGB) 52 | 53 | hint = hint.astype(np.float32) / 255.0 54 | 55 | masked_image = source * (mask < 0.5) 56 | return dict(jpg=source, txt=prompt, hint=hint, mask=mask, masked_image=masked_image) -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | from ldm.data.control_synthcompositedata import Control_composite_Hand_synth_data 2 | import torch 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from cldm.model import create_model, load_state_dict 6 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 7 | from einops import rearrange 8 | from PIL import Image 9 | import numpy as np 10 | import os 11 | from cldm.logger import ImageLogger 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--devices", default="0", type=str, help="comma delimited list of devices") 16 | parser.add_argument("--batchsize", default=4, type=int) 17 | parser.add_argument("--acc_grad", default=4, type=int) 18 | parser.add_argument("--max_epochs", default=3, type=int) 19 | args = parser.parse_args() 20 | args.devices = [int(n) for n in args.devices.split(",")] 21 | 22 | def get_state_dict(d): 23 | return d.get('state_dict', d) 24 | def load_state_dict(ckpt_path, location='cpu'): 25 | _, extension = os.path.splitext(ckpt_path) 26 | if extension.lower() == ".safetensors": 27 | import 
safetensors.torch 28 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 29 | else: 30 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 31 | state_dict = get_state_dict(state_dict) 32 | print(f'Loaded state_dict from [{ckpt_path}]') 33 | return state_dict 34 | 35 | learning_rate = 1e-5 36 | 37 | model = create_model("control_depth_inpaint.yaml") 38 | 39 | #### load the SD inpainting weights 40 | states = load_state_dict("./sd-v1-5-inpainting.ckpt", location='cpu') 41 | model.load_state_dict(states, strict=False) 42 | 43 | control_states = load_state_dict("./models/control_v11f1p_sd15_depth.pth") 44 | model.load_state_dict(control_states, strict=False) 45 | 46 | 47 | model.learning_rate = learning_rate 48 | model.sd_locked = True 49 | model.only_mid_control = False 50 | 51 | dataset = Control_composite_Hand_synth_data() 52 | 53 | checkpoint_callback = ModelCheckpoint(save_top_k=-1, monitor="epoch") 54 | 55 | #### start of the training expectation: the model should behave the same to standalone depth controlnet + inpainting SD 56 | dataloader = DataLoader(dataset, num_workers=8, batch_size=args.batchsize, shuffle=True) 57 | trainer = pl.Trainer(precision=32, max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices, accumulate_grad_batches=args.acc_grad, callbacks=[ImageLogger(), checkpoint_callback], strategy='ddp') 58 | trainer.fit(model, dataloader) --------------------------------------------------------------------------------
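For reference, a minimal launch sketch for the training script above: it assumes train.py has been moved to the repository root (next to the cldm folder, as training/README.md instructs) and that sd-v1-5-inpainting.ckpt and models/control_v11f1p_sd15_depth.pth already exist at the paths hard-coded on lines 40 and 43 of train.py. The flag names and defaults come from the argparse definitions in train.py; the GPU indices are placeholders.

```bash
# Sketch only: the checkpoint paths referenced inside train.py must already be in place.
# --devices is a comma-delimited GPU list; with DDP plus gradient accumulation the effective
# batch size is batchsize x acc_grad x number of devices.
python train.py --devices 0,1 --batchsize 4 --acc_grad 4 --max_epochs 3
```

train.py sets sd_locked = True, so in the standard cldm/ControlNet setup this repository builds on, only the ControlNet branch receives gradient updates while the Stable Diffusion UNet weights stay frozen.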