├── .gitignore
├── Figs
├── banner.png
└── github_results.png
├── LICENSE
├── README.md
├── cldm
├── cldm.py
├── ddim_hacked.py
├── hack.py
├── logger.py
└── model.py
├── config.py
├── control_depth_inpaint.yaml
├── docs
├── installation.md
├── manual.md
└── meshgraphormer.md
├── handrefiner.py
├── ldm
├── data
│ ├── __init__.py
│ ├── control_synthcompositedata.py
│ └── util.py
├── models
│ ├── autoencoder.py
│ └── diffusion
│ │ ├── __init__.py
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ ├── dpm_solver
│ │ ├── __init__.py
│ │ ├── dpm_solver.py
│ │ └── sampler.py
│ │ ├── plms.py
│ │ └── sampling_util.py
├── modules
│ ├── attention.py
│ ├── diffusionmodules
│ │ ├── __init__.py
│ │ ├── model.py
│ │ ├── openaimodel.py
│ │ ├── upscaling.py
│ │ └── util.py
│ ├── distributions
│ │ ├── __init__.py
│ │ └── distributions.py
│ ├── ema.py
│ ├── encoders
│ │ ├── __init__.py
│ │ └── modules.py
│ ├── image_degradation
│ │ ├── __init__.py
│ │ ├── bsrgan.py
│ │ ├── bsrgan_light.py
│ │ ├── utils
│ │ │ └── test.png
│ │ └── utils_image.py
│ └── midas
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── midas
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── blocks.py
│ │ ├── dpt_depth.py
│ │ ├── midas_net.py
│ │ ├── midas_net_custom.py
│ │ ├── transforms.py
│ │ └── vit.py
│ │ └── utils.py
└── util.py
├── preprocessor
├── depth_preprocessor.py
└── meshgraphormer.py
├── requirements.txt
├── scripts
├── _gcnn.py
├── _mano.py
├── config.py
└── download_models.sh
├── test
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg
├── 5.jpg
└── test.json
└── training
├── README.md
├── control_synthcompositedata.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | **/.DS_Store
--------------------------------------------------------------------------------
/Figs/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/Figs/banner.png
--------------------------------------------------------------------------------
/Figs/github_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/Figs/github_results.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Wenquan Lu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting
2 |
3 |
4 |
5 | # News
6 |
7 | **2023.12.1**
8 | The paper is posted on arXiv!
9 |
10 | **2023.12.29**
11 | First code commit released.
12 |
13 | **2024.1.7**
14 | The preprocessor and the finetuned model have been ported to [ComfyUI controlnet](https://github.com/Fannovel16/comfyui_controlnet_aux). The preprocessor has been ported to [sd webui controlnet](https://github.com/Mikubill/sd-webui-controlnet). Thanks for all your great work!
15 |
16 | **2024.1.15**
17 | ⚠️ When using the finetuned ControlNet from this repository or [control_sd15_inpaint_depth_hand](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned), I noticed many users still set the control strength/control weight to 1, which can result in loss of texture. As stated in the paper, we recommend a smaller control strength (e.g. 0.4 - 0.8).
18 |
19 | # Introduction
20 |
21 | This is the official repository of the paper HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting
22 |
23 |
24 |
25 | Figure 1: Stable Diffusion (first two rows) and SDXL (last row) generate malformed hands (left in each pair), e.g., incorrect
26 | number of fingers or irregular shapes, which can be effectively rectified by our HandRefiner (right in each pair).
27 |
28 |
29 |
30 |
31 |
32 |
33 | In this study, we introduce a lightweight post-processing solution called HandRefiner to correct malformed hands in generated images. HandRefiner employs a conditional inpainting
34 | approach to rectify malformed hands while leaving other
35 | parts of the image untouched. We leverage the hand mesh
36 | reconstruction model that consistently adheres to the correct number of fingers and hand shape, while also being
37 | capable of fitting the desired hand pose in the generated
38 | image. Given a generated failed image due to malformed
39 | hands, we utilize ControlNet modules to re-inject such correct hand information. Additionally, we uncover a phase
40 | transition phenomenon within ControlNet as we vary the
41 | control strength. It enables us to take advantage of more
42 | readily available synthetic data without suffering from the
43 | domain gap between realistic and synthetic hands.
44 |
45 | # Visual Results
46 |
47 |
48 |
49 |
50 | # Installation
51 | Check [installation.md](docs/installation.md) for installation instructions.
52 |
53 | # Manual
54 | Check [manual.md](docs/manual.md) for an explanation of the commands used to run HandRefiner.
55 |
56 | # Get Started
57 | For single image rectification:
58 | ```bash
59 | python handrefiner.py --input_img test/1.jpg --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt "a man facing the camera, making a hand gesture, indoor" --seed 1
60 | ```
61 | For multiple image rectifications:
62 | ```bash
63 | python handrefiner.py --input_dir test --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt_file test/test.json --seed 1
64 | ```
65 |
66 |
67 |
68 | # Important Q&A
69 |
70 | What kind of images can be rectified?
71 |
72 | Like any method, this one has its limits. If the original hands are so malformed that they are unrecognisable to the human eye, it is pretty much impossible for a neural network to fit a reasonable mesh. Also, because the method works by fitting a mesh to the existing hand, it does not rectify hand size: a giant malformed hand in the original image will still come back as a giant hand in the rectified image. In short, malformed hands with a roughly hand-like shape and an appropriate size can be rectified.
73 |
74 | Can we use it on SDXL images?
75 |
76 | In the paper, the SDXL images are resized to 512x512 before rectification, because the base model used in this project is SD 1.5.
77 | Solution for SDXL:
78 | It is certainly not difficult to implement this for SDXL, and many implementations already support inpainting with SDXL combined with a depth ControlNet.
79 | What you can do is obtain the depth map and masks from the pipeline of this repository, then pipe them into whatever SDXL implementation you use for inpainting the image.
80 | A caveat is that I have not tested this, and, as mentioned in the paper, a depth ControlNet that is not fine-tuned on hand mesh data may have a high rate of failed inpainting. In that case, you can use the technique from the paper and fine-tune the SDXL depth ControlNet on readily available synthetic data, for example the two datasets here [[1]](https://synthesis.ai/static-gestures-dataset/)[[2]](https://synthesis.ai/animated-gestures-dataset/), and then adjust the control strength to get the desired texture and appearance.
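As a rough illustration, the sketch below exports the depth map and hand mask from this repository's preprocessor for use by an external SDXL inpainting pipeline. It mirrors the calls made in handrefiner.py and assumes the same working directory and sys.path setup; the output file names are placeholders.

```python
# Hedged sketch: export the hand depth map and mask for an external SDXL inpainting pipeline.
# Assumes sys.path is set up as in handrefiner.py so the preprocessor package is importable.
import cv2
from preprocessor.meshgraphormer import MeshGraphormerMediapipe

meshgraphormer = MeshGraphormerMediapipe()
# get_depth(input_dir, file_name, padding) returns (depthmap, mask, info), as used in handrefiner.py.
depthmap, mask, info = meshgraphormer.get_depth("test", "1.jpg", 30)
cv2.imwrite("1_depth.jpg", depthmap)  # depth conditioning image for the SDXL depth ControlNet
cv2.imwrite("1_mask.jpg", mask)       # inpainting mask covering the malformed hands
# Hand 1_depth.jpg and 1_mask.jpg to whichever SDXL inpainting implementation you use.
```

Equivalently, running handrefiner.py with --depth_dir and --mask_dir saves the same maps to disk.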
81 |
82 | What if the generation failed?
83 |
84 | The first thing to check is the depth map; if it is bad, consider using a different mesh reconstruction model to reconstruct the mesh.
85 |
86 | The second thing to check is whether the hand masks fully cover the malformed hands. Some malformed hands have very long fingers that the detected masks may not cover. To fix this:
87 | 1. Consider using greater padding by adjusting the padding parameter (--padding_bbox) in the arguments
88 | 2. Provide a hand-drawn mask
89 |
90 | If all of the previous checks pass, you may need to regenerate several times or try different control strengths; changing the seed can be very helpful.
91 |
92 | Since small hands are a limitation mentioned in the paper, what is the appropriate hand size for the SD v1.5 weights?
93 |
94 | Generally, hands of at least 60px × 60px are recommended for the current weights. To handle smaller hands, consider scaling up the image using a super-resolution method.
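As a minimal sketch (a plain bicubic resize is used here only as a stand-in for a proper super-resolution model, and the file names are placeholders):

```python
# Sketch: upscale an image whose hands are below ~60x60 px before running HandRefiner.
# cv2.INTER_CUBIC is a simple stand-in; a dedicated super-resolution model should give better detail.
import cv2

img = cv2.imread("test/small_hands.jpg")
scale = 2  # chosen so the hand region comfortably exceeds 60x60 px
upscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
cv2.imwrite("test/small_hands_upscaled.jpg", upscaled)
# Then rectify the upscaled image with handrefiner.py as usual.
```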
95 |
96 | How to contribute to this project?
97 |
98 | Over the last decade, the CV community has produced dozens of highly accurate mesh reconstruction models; in this project we use Mesh Graphormer, a recent SOTA model on the FreiHAND benchmark. Contributions that port other models are very welcome; a template parent class for such models is provided under the preprocessor folder, and a minimal sketch of the expected interface is shown below.
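The sketch below only illustrates the interface that handrefiner.py relies on; the actual base class and signatures live in preprocessor/depth_preprocessor.py and may differ, and the class name here is hypothetical.

```python
# Hypothetical sketch of porting another hand mesh reconstruction model.
# handrefiner.py only relies on the two methods below (see its calls to MeshGraphormerMediapipe).

class MyMeshPreprocessor:  # in practice, subclass the template parent class under preprocessor/
    def get_depth(self, input_dir, file_name, padding):
        """Return (depthmap, mask, info): the rendered hand depth map, the hand mask,
        and any state needed later for MPJPE evaluation."""
        raise NotImplementedError

    def eval_mpjpe(self, rectified_image, info):
        """Reconstruct the hands in the rectified image and return the MPJPE against `info`."""
        raise NotImplementedError
```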
99 |
100 | Can I use it for Anime hands or other styles?
101 |
102 | As long as the hand detection model and the mesh reconstruction model can detect the hands and reconstruct meshes, it should work for other styles. However, to my understanding these models are not trained on cartoon or anime images, so there is a good chance that the mesh reconstruction stage will fail.
103 |
104 |
105 |
106 | ## Comments
107 | - Our codebase builds heavily on [stable-diffusion](https://github.com/CompVis/stable-diffusion), [ControlNet](https://github.com/lllyasviel/ControlNet) and [MeshGraphormer](https://github.com/microsoft/MeshGraphormer).
108 |
109 | ## Citation
110 |
111 | If you find HandRefiner helpful, please consider giving this repo a star :star: and citing:
112 |
113 | ```
114 | @inproceedings{10.1145/3664647.3680693,
115 | author = {Lu, Wenquan and Xu, Yufei and Zhang, Jing and Wang, Chaoyue and Tao, Dacheng},
116 | title = {HandRefiner: Refining Malformed Hands in Generated Images by Diffusion-based Conditional Inpainting},
117 | year = {2024},
118 | isbn = {9798400706868},
119 | publisher = {Association for Computing Machinery},
120 | address = {New York, NY, USA},
121 | url = {https://doi.org/10.1145/3664647.3680693},
122 | doi = {10.1145/3664647.3680693},
123 | abstract = {Diffusion models have achieved remarkable success in generating realistic images but suffer from generating accurate human hands, such as incorrect finger counts or irregular shapes. This difficulty arises from the complex task of learning the physical structure and pose of hands from training images, which involves extensive deformations and occlusions. For correct hand generation, our paper introduces a lightweight post-processing solution called HandRefiner. HandRefiner employs a conditional inpainting approach to rectify malformed hands while leaving other parts of the image untouched. We leverage the hand mesh reconstruction model that consistently adheres to the correct number of fingers and hand shape, while also being capable of fitting the desired hand pose in the generated image. Given a generated failed image due to malformed hands, we utilize ControlNet modules to re-inject such correct hand information. Additionally, we uncover a phase transition phenomenon within ControlNet as we vary the control strength. It enables us to take advantage of more readily available synthetic data without suffering from the domain gap between realistic and synthetic hands. Experiments demonstrate that HandRefiner can significantly improve the generation quality quantitatively and qualitatively. The code is available at https://github.com/wenquanlu/HandRefiner.},
124 | booktitle = {Proceedings of the 32nd ACM International Conference on Multimedia},
125 | pages = {7085–7093},
126 | numpages = {9},
127 | keywords = {deep learning, diffusion models, image inpainting},
128 | location = {Melbourne VIC, Australia},
129 | series = {MM '24}
130 | }
131 | ```
132 |
--------------------------------------------------------------------------------
/cldm/hack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 |
4 | import ldm.modules.encoders.modules
5 | import ldm.modules.attention
6 |
7 | from transformers import logging
8 | from ldm.modules.attention import default
9 |
10 |
11 | def disable_verbosity():
12 | logging.set_verbosity_error()
13 | print('logging improved.')
14 | return
15 |
16 |
17 | def enable_sliced_attention():
18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19 | print('Enabled sliced_attention.')
20 | return
21 |
22 |
23 | def hack_everything(clip_skip=0):
24 | disable_verbosity()
25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27 | print('Enabled clip hacks.')
28 | return
29 |
30 |
31 | # Written by Lvmin
32 | def _hacked_clip_forward(self, text):
33 | PAD = self.tokenizer.pad_token_id
34 | EOS = self.tokenizer.eos_token_id
35 | BOS = self.tokenizer.bos_token_id
36 |
37 | def tokenize(t):
38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39 |
40 | def transformer_encode(t):
41 | if self.clip_skip > 1:
42 | rt = self.transformer(input_ids=t, output_hidden_states=True)
43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44 | else:
45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46 |
47 | def split(x):
48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49 |
50 | def pad(x, p, i):
51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52 |
53 | raw_tokens_list = tokenize(text)
54 | tokens_list = []
55 |
56 | for raw_tokens in raw_tokens_list:
57 | raw_tokens_123 = split(raw_tokens)
58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60 | tokens_list.append(raw_tokens_123)
61 |
62 | tokens_list = torch.IntTensor(tokens_list).to(self.device)
63 |
64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65 | y = transformer_encode(feed)
66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67 |
68 | return z
69 |
70 |
71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73 | h = self.heads
74 |
75 | q = self.to_q(x)
76 | context = default(context, x)
77 | k = self.to_k(context)
78 | v = self.to_v(context)
79 | del context, x
80 |
81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82 |
83 | limit = k.shape[0]
84 | att_step = 1
85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88 |
89 | q_chunks.reverse()
90 | k_chunks.reverse()
91 | v_chunks.reverse()
92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93 | del k, q, v
94 | for i in range(0, limit, att_step):
95 | q_buffer = q_chunks.pop()
96 | k_buffer = k_chunks.pop()
97 | v_buffer = v_chunks.pop()
98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99 |
100 | del k_buffer, q_buffer
101 | # attention, what we cannot get enough of, by chunks
102 |
103 | sim_buffer = sim_buffer.softmax(dim=-1)
104 |
105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106 | del v_buffer
107 | sim[i:i + att_step, :, :] = sim_buffer
108 |
109 | del sim_buffer
110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111 | return self.to_out(sim)
112 |
--------------------------------------------------------------------------------
/cldm/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torchvision
6 | from PIL import Image
7 | from pytorch_lightning.callbacks import Callback
8 | from pytorch_lightning.utilities.distributed import rank_zero_only
9 |
10 |
11 | class ImageLogger(Callback):
12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
14 | log_images_kwargs=None):
15 | super().__init__()
16 | self.rescale = rescale
17 | self.batch_freq = batch_frequency
18 | self.max_images = max_images
19 | if not increase_log_steps:
20 | self.log_steps = [self.batch_freq]
21 | self.clamp = clamp
22 | self.disabled = disabled
23 | self.log_on_batch_idx = log_on_batch_idx
24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
25 | self.log_first_step = log_first_step
26 |
27 | @rank_zero_only
28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
29 | root = os.path.join(save_dir, "image_log", split)
30 | for k in images:
31 | grid = torchvision.utils.make_grid(images[k], nrow=4)
32 | if self.rescale:
33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
35 | grid = grid.numpy()
36 | grid = (grid * 255).astype(np.uint8)
37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
38 | path = os.path.join(root, filename)
39 | os.makedirs(os.path.split(path)[0], exist_ok=True)
40 | Image.fromarray(grid).save(path)
41 |
42 | def log_img(self, pl_module, batch, batch_idx, split="train"):
43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
45 | hasattr(pl_module, "log_images") and
46 | callable(pl_module.log_images) and
47 | self.max_images > 0):
48 | logger = type(pl_module.logger)
49 |
50 | is_train = pl_module.training
51 | if is_train:
52 | pl_module.eval()
53 |
54 | with torch.no_grad():
55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56 |
57 | for k in images:
58 | N = min(images[k].shape[0], self.max_images)
59 | images[k] = images[k][:N]
60 | if isinstance(images[k], torch.Tensor):
61 | images[k] = images[k].detach().cpu()
62 | if self.clamp:
63 | images[k] = torch.clamp(images[k], -1., 1.)
64 |
65 | self.log_local(pl_module.logger.save_dir, split, images,
66 | pl_module.global_step, pl_module.current_epoch, batch_idx)
67 |
68 | if is_train:
69 | pl_module.train()
70 |
71 | def check_frequency(self, check_idx):
72 | return check_idx % self.batch_freq == 0
73 |
74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75 | if not self.disabled:
76 | self.log_img(pl_module, batch, batch_idx, split="train")
77 |
--------------------------------------------------------------------------------
/cldm/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from omegaconf import OmegaConf
5 | from ldm.util import instantiate_from_config
6 |
7 |
8 | def get_state_dict(d):
9 | return d.get('state_dict', d)
10 |
11 |
12 | def load_state_dict(ckpt_path, location='cpu'):
13 | _, extension = os.path.splitext(ckpt_path)
14 | if extension.lower() == ".safetensors":
15 | import safetensors.torch
16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17 | else:
18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19 | state_dict = get_state_dict(state_dict)
20 | print(f'Loaded state_dict from [{ckpt_path}]')
21 | return state_dict
22 |
23 |
24 | def create_model(config_path):
25 | config = OmegaConf.load(config_path)
26 | model = instantiate_from_config(config.model).cpu()
27 | print(f'Loaded model config from [{config_path}]')
28 | return model
29 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | save_memory = False
3 | handrefiner_root=str(Path(__file__).parent)
--------------------------------------------------------------------------------
/control_depth_inpaint.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: cldm.cldm.ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: hybrid
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | only_mid_control: False
20 |
21 | control_stage_config:
22 | target: cldm.cldm.ControlNet
23 | params:
24 | image_size: 32 # unused
25 | in_channels: 4
26 | hint_channels: 3
27 | model_channels: 320
28 | attention_resolutions: [ 4, 2, 1 ]
29 | num_res_blocks: 2
30 | channel_mult: [ 1, 2, 4, 4 ]
31 | num_heads: 8
32 | use_spatial_transformer: True
33 | transformer_depth: 1
34 | context_dim: 768
35 | use_checkpoint: True
36 | legacy: False
37 |
38 | unet_config:
39 | target: cldm.cldm.ControlledUnetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 9
43 | out_channels: 4
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
80 |
--------------------------------------------------------------------------------
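For context, the YAML above is the model config consumed by cldm/model.py. A minimal loading sketch, mirroring what handrefiner.py does (the checkpoint path is the one suggested in docs/installation.md):

```python
# Minimal sketch: instantiate the model from control_depth_inpaint.yaml and load weights,
# following the same calls made in handrefiner.py.
from cldm.model import create_model, load_state_dict

model = create_model("control_depth_inpaint.yaml").cpu()
# Finetuned checkpoint from docs/installation.md; alternatively load sd-v1-5-inpainting.ckpt
# and control_v11f1p_sd15_depth.pth separately with strict=False, as handrefiner.py does.
model.load_state_dict(load_state_dict("models/inpaint_depth_control.ckpt", location="cuda"), strict=False)
model = model.to("cuda")
```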
/docs/installation.md:
--------------------------------------------------------------------------------
1 | ## Installation Instructions
2 |
3 | 1. Clone HandRefiner to your local machine.
4 | 2. Install MeshGraphormer to HandRefiner/MeshGraphormer following the instructions in [meshgraphormer.md](meshgraphormer.md). (If you encounter any errors, you can also refer to the original MeshGraphormer documentation.)
5 | Please also comply with Mesh Graphormer's license when using it in this project.
6 | 3. Make sure you are in the 'HandRefiner/' directory for the following steps; refer to [requirements.txt](../requirements.txt) for the packages required by the project.
7 | 4. Install Mediapipe:
8 | ```bash
9 | pip install -q mediapipe==0.10.0
10 | cd preprocessor
11 | wget https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task
12 | ```
13 |
14 | 5. Download weights; two sets of weights can be used:
15 | - Inpainting Stable Diffusion weights [sd-v1-5-inpainting.ckpt](https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt) and depth ControlNet weights [control_v11f1p_sd15_depth.pth](https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11f1p_sd15_depth.pth). Put sd-v1-5-inpainting.ckpt and control_v11f1p_sd15_depth.pth in the HandRefiner/models/ folder. To use these weights, set the --finetuned flag to False when executing HandRefiner.
16 | - Finetuned weights [inpaint_depth_control.ckpt](https://drive.google.com/file/d/1eD2Lnfk0KZols68mVahcVfNx3GnYdHxo/view?usp=sharing) as introduced in the paper. Put inpaint_depth_control.ckpt in the HandRefiner/models/ folder. A control strength of 0.4 - 0.8 is recommended for the finetuned weights; we use 0.55 in the paper's evaluation. Alternatively, adaptive control strength can be used by setting the --adaptive_control flag to True, though the inference time is much longer.
17 |
18 | The finetuned weights adapt better to complex gestures and produce more harmonious inpainting. You can also try the original weights, though the failure rate may be higher.
19 |
20 | 6. Test whether the installation succeeded:
21 |
22 | For single image rectification:
23 | ```bash
24 | python handrefiner.py --input_img test/1.jpg --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt "a man facing the camera, making a hand gesture, indoor" --seed 1
25 | ```
26 | For multiple image rectifications:
27 | ```bash
28 | python handrefiner.py --input_dir test --out_dir output --strength 0.55 --weights models/inpaint_depth_control.ckpt --prompt_file test/test.json --seed 1
29 | ```
30 |
31 |
--------------------------------------------------------------------------------
/docs/manual.md:
--------------------------------------------------------------------------------
1 | ## Manual
2 | Arguments for executing handrefiner.py:
3 |
4 | --input_dir
5 |
6 | input directory containing images to be rectified
7 |
8 | --input_img
9 |
10 | input image to be rectified
11 | --out_dir
12 |
13 | output directory where the rectified images will be saved to
14 |
15 | --log_json
16 |
17 | file where the mpjpe values will be logged to
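Each line of the log is a JSON object as written by handrefiner.py; a small sketch for reading it back (the log path is a placeholder, and note the MPJPE key is spelled "mpjpje" in the current code):

```python
# Sketch: read an MPJPE log produced via --log_json (one JSON object per line).
import json

with open("mpjpe_log.json") as f:
    for line in f:
        entry = json.loads(line)
        # keys as written by handrefiner.py: "img", "strength", "mpjpje"
        print(entry["img"], entry["strength"], entry["mpjpje"])
```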
18 | --strength
19 |
20 | control strength for ControlNet
21 |
22 | --depth_dir
23 |
24 | directory where the depth maps will be saved to. Leaving it empty will disable this function
25 | --mask_dir
26 |
27 | directory where the masks will be saved to. Leaving it empty will disable this function
28 | --eval (True/False)
29 |
30 | whether to evaluate the MPJPE error in fixed control strength mode; currently only works with a batch size of 1.
31 | --finetuned (True/False)
32 |
33 | whether to use the finetuned ControlNet trained on synthetic images, as introduced in the paper
34 | --weights
35 |
36 | path to the SD + ControlNet weights
37 | --num_samples
38 |
39 | batch size
40 | --prompt_file
41 |
42 | prompt file for multi-image rectification
43 | Format for prompt file:
44 | ```
45 | {"img": filename, "txt": prompt}
46 | ```
47 | Example:
48 | ```json
49 | {"img": "img1.jpg", "txt": "a woman making a hand gesture"}
50 | {"img": "img2.jpg", "txt": "a man making a hand gesture"}
51 | {"img": "img3.jpg", "txt": "a man making a thumbs up gesture"}
52 | ```
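If you have many images, such a prompt file can also be written programmatically; a small sketch (the file name and prompts are placeholders):

```python
# Sketch: write a prompt file (one JSON object per line), matching the format above.
import json

prompts = {
    "img1.jpg": "a woman making a hand gesture",
    "img2.jpg": "a man making a hand gesture",
}
with open("test/my_prompts.json", "w") as f:
    for img, txt in prompts.items():
        f.write(json.dumps({"img": img, "txt": txt}) + "\n")
```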
53 |
54 | --prompt
55 |
56 | prompt for single image rectification
57 | --n_iter
58 |
59 | number of generation iterations for each image to be rectified. In general, n_iter x num_samples rectified images will be produced for each input image
60 | --adaptive_control (True/False)
61 |
62 | adaptive control strength as introduced in the paper; currently only works with a batch size of 1. We use fixed control strength by default.
63 | --padding_bbox
64 |
65 | padding controls the size of masks around the hand
66 |
67 | --seed
68 |
69 | set seed to maintain reproducibility
70 |
71 |
--------------------------------------------------------------------------------
/docs/meshgraphormer.md:
--------------------------------------------------------------------------------
1 | # MeshGraphormer Instructions for HandRefiner
2 |
3 | ## Installation
4 |
5 | ### Requirements
6 |
7 |
8 |
9 | Install MeshGraphormer into HandRefiner/MeshGraphormer:
10 |
11 | ```bash
12 | git clone --recursive https://github.com/microsoft/MeshGraphormer.git
13 | cd MeshGraphormer
14 | pip install ./manopth/.
15 | ```
16 |
17 |
18 | ## Download
19 | Make sure you are in the 'HandRefiner/MeshGraphormer' directory for the following steps.
20 | 1. Create a folder to store the pretrained models.
21 | ```bash
22 | mkdir -p models # pre-trained models
23 | ```
24 |
25 | 2. Download the pretrained models and apply some code modifications.
26 |
27 | ```bash
28 | cp ../scripts/download_models.sh scripts/download_models.sh
29 | cp ../scripts/_gcnn.py src/modeling/_gcnn.py
30 | cp ../scripts/_mano.py src/modeling/_mano.py
31 | cp ../scripts/config.py src/modeling/data/config.py
32 | bash scripts/download_models.sh
33 | ```
34 |
35 | The resulting data structure should follow the hierarchy below.
36 | ```
37 | MeshGraphormer
38 | |-- models
39 | | |-- graphormer_release
40 | | | |-- graphormer_hand_state_dict.bin
41 | | |-- hrnet
42 | | | |-- hrnetv2_w64_imagenet_pretrained.pth
43 | | | |-- cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
44 | |-- src
45 | |-- datasets
46 | |-- predictions
47 | |-- README.md
48 | |-- ...
49 | |-- ...
50 | ```
51 |
52 | 3. Download the MANO model from its official website
53 |
54 | - Download `MANO_RIGHT.pkl` from [MANO](https://mano.is.tue.mpg.de/), and place it at `MeshGraphormer/src/modeling/data`.
55 |
56 | Please put the downloaded files under the `MeshGraphormer/src/modeling/data` directory. The data structure should follow the hierarchy below.
57 | ```
58 | MeshGraphormer
59 | |-- src
60 | | |-- modeling
61 | | | |-- data
62 | | | | |-- MANO_RIGHT.pkl
63 | |-- models
64 | |-- datasets
65 | |-- predictions
66 | |-- README.md
67 | |-- ...
68 | |-- ...
69 | ```
70 | 4. Exit the MeshGraphormer directory when finished
71 | ```bash
72 | cd ..
73 | ```
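As a quick, optional sanity check (run from the HandRefiner/ directory; the paths are exactly those listed above):

```python
# Sketch: verify that the MeshGraphormer files described above are in place.
import os

required = [
    "MeshGraphormer/models/graphormer_release/graphormer_hand_state_dict.bin",
    "MeshGraphormer/models/hrnet/hrnetv2_w64_imagenet_pretrained.pth",
    "MeshGraphormer/models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml",
    "MeshGraphormer/src/modeling/data/MANO_RIGHT.pkl",
]
for path in required:
    print(("OK     " if os.path.exists(path) else "MISSING") + " " + path)
```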
--------------------------------------------------------------------------------
/handrefiner.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # STEP 1: Import the necessary modules.
5 | from __future__ import absolute_import, division, print_function
6 | import sys
7 | from config import handrefiner_root
8 | import os
9 |
10 | def load():
11 | paths = [handrefiner_root, os.path.join(handrefiner_root, 'MeshGraphormer'), os.path.join(handrefiner_root, 'preprocessor')]
12 | for p in paths:
13 | sys.path.insert(0, p)
14 |
15 | load()
16 |
17 | import argparse
18 | import json
19 | import torch
20 | import numpy as np
21 | import cv2
22 |
23 | from PIL import Image
24 | from torchvision import transforms
25 | import numpy as np
26 | import cv2
27 |
28 | from pytorch_lightning import seed_everything
29 | from cldm.model import create_model, load_state_dict
30 | from cldm.ddim_hacked import DDIMSampler
31 | import config
32 |
33 | import cv2
34 | import einops
35 | import numpy as np
36 | import torch
37 | import random
38 | from pathlib import Path
39 | from preprocessor.meshgraphormer import MeshGraphormerMediapipe
40 | import ast
41 |
42 | transform = transforms.Compose([
43 | transforms.ToTensor(),
44 | transforms.Normalize(
45 | mean=[0.485, 0.456, 0.406],
46 | std=[0.229, 0.224, 0.225])])
47 |
48 | def parse_args():
49 | parser = argparse.ArgumentParser()
50 |
51 | # input directory containing images to be rectified
52 | parser.add_argument('--input_dir', type=str, default="")
53 |
54 | # input image
55 | parser.add_argument('--input_img', type=str, default="")
56 |
57 | # output directory where the rectified images will be saved to
58 | parser.add_argument('--out_dir', type=str, default="")
59 |
60 | # file where the mpjpe values will be logged to
61 | parser.add_argument('--log_json', type=str, default="")
62 |
63 | # control strength for ControlNet
64 | parser.add_argument('--strength', type=float, default=1.0)
65 |
66 | # directory where the depth maps will be saved to. Leaving it empty will disable this function
67 | parser.add_argument('--depth_dir', type=str, default="")
68 |
69 | # directory where the masks will be saved to. Leaving it empty will disable this function
70 | parser.add_argument('--mask_dir', type=str, default="")
71 |
72 | # whether evaluate the mpjpe error in fixed control strength mode
73 | parser.add_argument('--eval', type=ast.literal_eval, default=False)
74 |
75 | # whether use finetuned ControlNet trained on synthetic images as introduced in the paper
76 | parser.add_argument('--finetuned', type=ast.literal_eval, default=True)
77 |
78 | # path to the SD + ControlNet weights
79 | parser.add_argument('--weights', type=str, default="")
80 |
81 | # batch size
82 | parser.add_argument('--num_samples', type=int, default=1)
83 |
84 | # prompt file for multi-image rectification
85 | # see manual.md for file format
86 | parser.add_argument('--prompt_file', type=str, default="")
87 |
88 | # prompt for single image rectification
89 | parser.add_argument('--prompt', type=str, default="")
90 |
91 | # number of generation iteration for each image to be rectified
92 | # in general, for each input image, n_iter x num_samples number of rectified images will be produced
93 | parser.add_argument('--n_iter', type=int, default=1)
94 |
95 | # adaptive control strength as introduced in paper (we tend to use fixed control strength as default)
96 | parser.add_argument('--adaptive_control', type=ast.literal_eval, default=False)
97 |
98 | # padding controls the size of masks around the hand
99 | parser.add_argument('--padding_bbox', type=int, default=30)
100 |
101 | # set seed
102 | parser.add_argument('--seed', type=int, default=-1)
103 | args = parser.parse_args()
104 | return args
105 |
106 | args = parse_args()
107 |
108 | if (args.prompt_file != "" and args.prompt != "") or (args.prompt_file == "" and args.prompt == ""):
109 | raise Exception("Please specify one and only one of the --prompt and --prompt_file")
110 | if (args.input_dir != "" and args.input_img != "") or (args.input_dir == "" and args.input_img == ""):
111 | raise Exception("Please specify one and only one of the --input_dir and --input_img")
112 |
113 | model = create_model("control_depth_inpaint.yaml").cpu()
114 | if args.finetuned:
115 | model.load_state_dict(load_state_dict(args.weights, location='cuda'), strict=False)
116 | else:
117 | model.load_state_dict(
118 | load_state_dict("models/sd-v1-5-inpainting.ckpt", location="cuda"), strict=False
119 | )
120 | model.load_state_dict(
121 | load_state_dict("models/control_v11f1p_sd15_depth.pth", location="cuda"),
122 | strict=False,
123 | )
124 |
125 | model = model.to("cuda")
126 |
127 | meshgraphormer = MeshGraphormerMediapipe()
128 |
129 | if args.log_json != "":
130 | f_mpjpe = open(args.log_json, 'w')
131 |
132 |
133 | # prompt needs to be same for all pictures in the same batch
134 | if args.input_img != "":
135 | assert args.prompt_file == "", "prompt file should not be used for single image rectification"
136 | inputs = [args.input_img]
137 | else:
138 | if args.prompt_file != "":
139 | f_prompt = open(args.prompt_file)
140 | inputs = f_prompt.readlines()
141 | else:
142 | inputs = os.listdir(args.input_dir)
143 |
144 | for file_info in inputs:
145 | if args.prompt_file != "":
146 | file_info = json.loads(file_info)
147 | file_name = file_info["img"]
148 | prompt = file_info["txt"]
149 | else:
150 | file_name = file_info
151 | prompt = args.prompt
152 |
153 | image_file = os.path.join(args.input_dir, file_name)
154 |
155 | file_name_raw = Path(file_name).stem
156 |
157 | # STEP 3: Load the input image.
158 | image = np.array(Image.open(image_file))
159 |
160 | raw_image = image
161 | H, W, C = raw_image.shape
162 | gen_count = 0
163 | for iteration in range(args.n_iter):
164 |
165 | depthmap, mask, info = meshgraphormer.get_depth(args.input_dir, file_name, args.padding_bbox)
166 |
167 | if args.depth_dir != "":
168 | cv2.imwrite(os.path.join(args.depth_dir, file_name_raw + "_depth.jpg"), depthmap)
169 | if args.mask_dir != "":
170 | cv2.imwrite(os.path.join(args.mask_dir, file_name_raw + "_mask.jpg"), mask)
171 |
172 | control = depthmap
173 |
174 | ddim_sampler = DDIMSampler(model)
175 | num_samples = args.num_samples
176 | ddim_steps = 50
177 | guess_mode = False
178 | strength = args.strength
179 | scale = 9.0
180 | seed = args.seed
181 |
182 | label = file_name[:2]
183 | a_prompt = "realistic, best quality, extremely detailed"
184 | n_prompt = "fake 3D rendered image, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, blue"
185 |
186 | source = raw_image
187 |
188 | source = (source.astype(np.float32) / 127.5) - 1.0
189 | source = source.transpose([2, 0, 1]) # source is c h w
190 |
191 | mask = mask.astype(np.float32) / 255.0
192 | mask = mask[None]
193 | mask[mask < 0.5] = 0
194 | mask[mask >= 0.5] = 1
195 |
196 | hint = control.astype(np.float32) / 255.0
197 |
198 | masked_image = source * (mask < 0.5) # masked image is c h w
199 |
200 | mask = torch.stack([torch.tensor(mask) for _ in range(num_samples)], dim=0).to("cuda")
201 | mask = torch.nn.functional.interpolate(mask, size=(64, 64))
202 |
203 | if seed == -1:
204 | seed = random.randint(0, 65535)
205 | seed_everything(seed)
206 |
207 | if config.save_memory:
208 | model.low_vram_shift(is_diffusing=False)
209 |
210 | masked_image = torch.stack(
211 | [torch.tensor(masked_image) for _ in range(num_samples)], dim=0
212 | ).to("cuda")
213 |
214 | # this should be b,c,h,w
215 | masked_image = model.get_first_stage_encoding(model.encode_first_stage(masked_image))
216 |
217 | x = torch.stack([torch.tensor(source) for _ in range(num_samples)], dim=0).to("cuda")
218 | z = model.get_first_stage_encoding(model.encode_first_stage(x))
219 |
220 | cats = torch.cat([mask, masked_image], dim=1)
221 |
222 | hint = hint[
223 | None,
224 | ].repeat(3, axis=0)
225 |
226 | hint = torch.stack([torch.tensor(hint) for _ in range(num_samples)], dim=0).to("cuda")
227 |
228 | cond = {
229 | "c_concat": [cats],
230 | "c_control": [hint],
231 | "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)],
232 | }
233 | un_cond = {
234 | "c_concat": [cats],
235 | "c_control": [hint],
236 | "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)],
237 | }
238 |
239 |
240 | shape = (4, H // 8, W // 8)
241 |
242 | if config.save_memory:
243 | model.low_vram_shift(is_diffusing=True)
244 |
245 | if not args.adaptive_control:
246 | seed_everything(seed)
247 | model.control_scales = (
248 | [strength * (0.825 ** float(12 - i)) for i in range(13)]
249 | if guess_mode
250 | else ([strength] * 13)
251 | ) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
252 | samples, intermediates = ddim_sampler.sample(
253 | ddim_steps,
254 | num_samples,
255 | shape,
256 | cond,
257 | verbose=False,
258 | unconditional_guidance_scale=scale,
259 | unconditional_conditioning=un_cond,
260 | x0=z,
261 | mask=mask
262 | )
263 | if config.save_memory:
264 | model.low_vram_shift(is_diffusing=False)
265 |
266 | x_samples = model.decode_first_stage(samples)
267 | # print(x_samples.shape)
268 | x_samples = (
269 | (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
270 | .cpu()
271 | .numpy()
272 | .clip(0, 255)
273 | .astype(np.uint8)
274 | )
275 |
276 | if args.eval: # currently only works for batch size of 1
277 | assert args.num_samples == 1, "MPJPE evaluation currently only works for batch size of 1"
278 | mpjpe = meshgraphormer.eval_mpjpe(x_samples[0], info)
279 | print(mpjpe)
280 | if args.log_json != "":
281 | mpjpe_info = {"img": image_file, "strength": strength, "mpjpje": mpjpe}
282 | f_mpjpe.write(json.dumps(mpjpe_info))
283 | f_mpjpe.write("\n")
284 | for i in range(args.num_samples):
285 | cv2.imwrite(
286 | os.path.join(args.out_dir, "{}_{}.jpg".format(file_name_raw, gen_count)), cv2.cvtColor(x_samples[i], cv2.COLOR_RGB2BGR)
287 | )
288 | gen_count += 1
289 | else:
290 | assert args.num_samples == 1, "Adaptive thresholding currently only works for batch size of 1"
291 | strengths = [1.0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
292 | ref_mpjpe = None
293 | chosen_strength = None
294 | final_mpjpe = None
295 | chosen_sample = None
296 | count = 0
297 | for strength in strengths:
298 | seed_everything(seed)
299 | model.control_scales = (
300 | [strength * (0.825 ** float(12 - i)) for i in range(13)]
301 | if guess_mode
302 | else ([strength] * 13)
303 | ) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
304 | samples, intermediates = ddim_sampler.sample(
305 | ddim_steps,
306 | num_samples,
307 | shape,
308 | cond,
309 | verbose=False,
310 | unconditional_guidance_scale=scale,
311 | unconditional_conditioning=un_cond,
312 | x0=z,
313 | mask=mask
314 | )
315 | if config.save_memory:
316 | model.low_vram_shift(is_diffusing=False)
317 |
318 | x_samples = model.decode_first_stage(samples)
319 |
320 | x_samples = (
321 | (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
322 | .cpu()
323 | .numpy()
324 | .clip(0, 255)
325 | .astype(np.uint8)
326 | )
327 | mpjpe = meshgraphormer.eval_mpjpe(x_samples[0], info)
328 | if count == 0:
329 | ref_mpjpe = mpjpe
330 | chosen_sample = x_samples[0]
331 | elif mpjpe < ref_mpjpe * 1.15:
332 | chosen_strength = strength
333 | final_mpjpe = mpjpe
334 | chosen_sample = x_samples[0]
335 | break
336 | elif strength == 0.9:
337 | final_mpjpe = ref_mpjpe
338 | chosen_strength = 1.0
339 | count += 1
340 |
341 | if args.log_json != "":
342 | mpjpe_info = {"img": image_file, "strength": chosen_strength, "mpjpje": final_mpjpe}
343 | f_mpjpe.write(json.dumps(mpjpe_info))
344 | f_mpjpe.write("\n")
345 |
346 | cv2.imwrite(
347 | os.path.join(args.out_dir, "{}_{}.jpg".format(file_name_raw, gen_count)), cv2.cvtColor(x_samples[0], cv2.COLOR_RGB2BGR)
348 | )
349 | gen_count += 1
350 |
351 |
352 |
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/control_synthcompositedata.py:
--------------------------------------------------------------------------------
1 | import json
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | import random
6 |
7 | from torch.utils.data import Dataset
8 |
9 | DATA_PATH_1 = "/raid/wenquanlu/RHD/RHD_published_v2/"
10 | DATA_PATH_2 = "/raid/wenquanlu/synthesisai/"
11 |
12 | abbrev_dict = {"RHD": DATA_PATH_1,
13 | "synthesisai": DATA_PATH_2}
14 |
15 | class Control_composite_Hand_synth_data(Dataset):
16 | def __init__(self):
17 | self.data = []
18 | with open('../RHD/RHD_published_v2/rgb_caption.json', 'rt') as f1:
19 | for line in f1:
20 | item = json.loads(line)
21 | item['dataset'] = 'RHD'
22 | self.data.append(item)
23 | with open('../synthesisai/rgb_caption.json', 'rt') as f2:
24 | for line in f2:
25 | item = json.loads(line)
26 | item['dataset'] = 'synthesisai'
27 | self.data.append(item)
28 | def __len__(self):
29 | return len(self.data)
30 |
31 | def __getitem__(self, idx):
32 | item = self.data[idx]
33 | source_filename = item['jpg']
34 | prompt = item['txt']
35 | dataset = item['dataset']
36 | datapath = abbrev_dict[dataset]
37 | if random.random() < 0.5:
38 | prompt = ""
39 | source = cv2.imread(datapath + "image/" + source_filename)
40 | source = (source.astype(np.float32) / 127.5) - 1.0
41 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
42 |
43 | mask = np.array(Image.open(datapath + "mask/" + source_filename).convert("L"))
44 | mask = mask.astype(np.float32)/255.0
45 | mask = mask[None]
46 | mask[mask < 0.5] = 0
47 | mask[mask >= 0.5] = 1
48 | mask = np.transpose(mask, [1, 2, 0])
49 |
50 | hint = cv2.imread(datapath + "pose/" + source_filename)
51 | hint = cv2.cvtColor(hint, cv2.COLOR_BGR2RGB)
52 |
53 | hint = hint.astype(np.float32) / 255.0
54 |
55 | masked_image = source * (mask < 0.5)
56 | return dict(jpg=source, txt=prompt, hint=hint, mask=mask, masked_image=masked_image)
--------------------------------------------------------------------------------
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
--------------------------------------------------------------------------------
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8 |
9 | from ldm.util import instantiate_from_config
10 | from ldm.modules.ema import LitEma
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | ema_decay=None,
24 | learn_logvar=False
25 | ):
26 | super().__init__()
27 | self.learn_logvar = learn_logvar
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | self.use_ema = ema_decay is not None
43 | if self.use_ema:
44 | self.ema_decay = ema_decay
45 | assert 0. < ema_decay < 1.
46 | self.model_ema = LitEma(self, decay=ema_decay)
47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48 |
49 | if ckpt_path is not None:
50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | sd = torch.load(path, map_location="cpu")["state_dict"]
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def on_train_batch_end(self, *args, **kwargs):
79 | if self.use_ema:
80 | self.model_ema(self)
81 |
82 | def encode(self, x):
83 | h = self.encoder(x)
84 | moments = self.quant_conv(h)
85 | posterior = DiagonalGaussianDistribution(moments)
86 | return posterior
87 |
88 | def decode(self, z):
89 | z = self.post_quant_conv(z)
90 | dec = self.decoder(z)
91 | return dec
92 |
93 | def forward(self, input, sample_posterior=True):
94 | posterior = self.encode(input)
95 | if sample_posterior:
96 | z = posterior.sample()
97 | else:
98 | z = posterior.mode()
99 | dec = self.decode(z)
100 | return dec, posterior
101 |
102 | def get_input(self, batch, k):
103 | x = batch[k]
104 | if len(x.shape) == 3:
105 | x = x[..., None]
106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107 | return x
108 |
109 | def training_step(self, batch, batch_idx, optimizer_idx):
110 | inputs = self.get_input(batch, self.image_key)
111 | reconstructions, posterior = self(inputs)
112 |
113 | if optimizer_idx == 0:
114 | # train encoder+decoder+logvar
115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116 | last_layer=self.get_last_layer(), split="train")
117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119 | return aeloss
120 |
121 | if optimizer_idx == 1:
122 | # train the discriminator
123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124 | last_layer=self.get_last_layer(), split="train")
125 |
126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return discloss
129 |
130 | def validation_step(self, batch, batch_idx):
131 | log_dict = self._validation_step(batch, batch_idx)
132 | with self.ema_scope():
133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134 | return log_dict
135 |
136 | def _validation_step(self, batch, batch_idx, postfix=""):
137 | inputs = self.get_input(batch, self.image_key)
138 | reconstructions, posterior = self(inputs)
139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140 | last_layer=self.get_last_layer(), split="val"+postfix)
141 |
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143 | last_layer=self.get_last_layer(), split="val"+postfix)
144 |
145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146 | self.log_dict(log_dict_ae)
147 | self.log_dict(log_dict_disc)
148 | return self.log_dict
149 |
150 | def configure_optimizers(self):
151 | lr = self.learning_rate
152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154 | if self.learn_logvar:
155 | print(f"{self.__class__.__name__}: Learning logvar")
156 | ae_params_list.append(self.loss.logvar)
157 | opt_ae = torch.optim.Adam(ae_params_list,
158 | lr=lr, betas=(0.5, 0.9))
159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160 | lr=lr, betas=(0.5, 0.9))
161 | return [opt_ae, opt_disc], []
162 |
163 | def get_last_layer(self):
164 | return self.decoder.conv_out.weight
165 |
166 | @torch.no_grad()
167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168 | log = dict()
169 | x = self.get_input(batch, self.image_key)
170 | x = x.to(self.device)
171 | if not only_inputs:
172 | xrec, posterior = self(x)
173 | if x.shape[1] > 3:
174 | # colorize with random projection
175 | assert xrec.shape[1] > 3
176 | x = self.to_rgb(x)
177 | xrec = self.to_rgb(xrec)
178 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179 | log["reconstructions"] = xrec
180 | if log_ema or self.use_ema:
181 | with self.ema_scope():
182 | xrec_ema, posterior_ema = self(x)
183 | if x.shape[1] > 3:
184 | # colorize with random projection
185 | assert xrec_ema.shape[1] > 3
186 | xrec_ema = self.to_rgb(xrec_ema)
187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188 | log["reconstructions_ema"] = xrec_ema
189 | log["inputs"] = x
190 | return log
191 |
192 | def to_rgb(self, x):
193 | assert self.image_key == "segmentation"
194 | if not hasattr(self, "colorize"):
195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196 | x = F.conv2d(x, weight=self.colorize)
197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198 | return x
199 |
200 |
201 | class IdentityFirstStage(torch.nn.Module):
202 | def __init__(self, *args, vq_interface=False, **kwargs):
203 | self.vq_interface = vq_interface
204 | super().__init__()
205 |
206 | def encode(self, x, *args, **kwargs):
207 | return x
208 |
209 | def decode(self, x, *args, **kwargs):
210 | return x
211 |
212 | def quantize(self, x, *args, **kwargs):
213 | if self.vq_interface:
214 | return x, None, [None, None, None]
215 | return x
216 |
217 | def forward(self, x, *args, **kwargs):
218 | return x
219 |
220 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 |
7 | MODEL_TYPES = {
8 | "eps": "noise",
9 | "v": "v"
10 | }
11 |
12 |
13 | class DPMSolverSampler(object):
14 | def __init__(self, model, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != torch.device("cuda"):
23 | attr = attr.to(torch.device("cuda"))
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54 | if cbs != batch_size:
55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56 | else:
57 | if conditioning.shape[0] != batch_size:
58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59 |
60 | # sampling
61 | C, H, W = shape
62 | size = (batch_size, C, H, W)
63 |
64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65 |
66 | device = self.model.betas.device
67 | if x_T is None:
68 | img = torch.randn(size, device=device)
69 | else:
70 | img = x_T
71 |
72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73 |
74 | model_fn = model_wrapper(
75 | lambda x, t, c: self.model.apply_model(x, t, c),
76 | ns,
77 | model_type=MODEL_TYPES[self.model.parameterization],
78 | guidance_type="classifier-free",
79 | condition=conditioning,
80 | unconditional_condition=unconditional_conditioning,
81 | guidance_scale=unconditional_guidance_scale,
82 | )
83 |
84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86 |
87 | return x.to(device), None
--------------------------------------------------------------------------------
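A minimal usage sketch for DPMSolverSampler, assuming `model` is an already-loaded latent diffusion model from this codebase (i.e. it exposes apply_model, betas, alphas_cumprod, parameterization, get_learned_conditioning and decode_first_stage) running on a CUDA device, since register_buffer above moves tensors to cuda. The prompt, guidance scale and shapes below are placeholders.

# Hypothetical sketch -- `model` is assumed to be a loaded LatentDiffusion instance.
import torch
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

sampler = DPMSolverSampler(model)
cond = model.get_learned_conditioning(["a photo of a hand"])   # text conditioning
uncond = model.get_learned_conditioning([""])                  # empty prompt for CFG
samples, _ = sampler.sample(
    S=20,                        # DPM-Solver steps
    batch_size=1,
    shape=(4, 64, 64),           # latent C, H, W for a 512x512 image
    conditioning=cond,
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uncond,
)
images = model.decode_first_stage(samples)   # decode latents back to image space
--------------------------------------------------------------------------------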
/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 |                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | if isinstance(conditioning, dict):
86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87 | if cbs != batch_size:
88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89 | else:
90 | if conditioning.shape[0] != batch_size:
91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92 |
93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94 | # sampling
95 | C, H, W = shape
96 | size = (batch_size, C, H, W)
97 | print(f'Data shape for PLMS sampling is {size}')
98 |
99 | samples, intermediates = self.plms_sampling(conditioning, size,
100 | callback=callback,
101 | img_callback=img_callback,
102 | quantize_denoised=quantize_x0,
103 | mask=mask, x0=x0,
104 | ddim_use_original_steps=False,
105 | noise_dropout=noise_dropout,
106 | temperature=temperature,
107 | score_corrector=score_corrector,
108 | corrector_kwargs=corrector_kwargs,
109 | x_T=x_T,
110 | log_every_t=log_every_t,
111 | unconditional_guidance_scale=unconditional_guidance_scale,
112 | unconditional_conditioning=unconditional_conditioning,
113 | dynamic_threshold=dynamic_threshold,
114 | )
115 | return samples, intermediates
116 |
117 | @torch.no_grad()
118 | def plms_sampling(self, cond, shape,
119 | x_T=None, ddim_use_original_steps=False,
120 | callback=None, timesteps=None, quantize_denoised=False,
121 | mask=None, x0=None, img_callback=None, log_every_t=100,
122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123 | unconditional_guidance_scale=1., unconditional_conditioning=None,
124 | dynamic_threshold=None):
125 | device = self.model.betas.device
126 | b = shape[0]
127 | if x_T is None:
128 | img = torch.randn(shape, device=device)
129 | else:
130 | img = x_T
131 |
132 | if timesteps is None:
133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134 | elif timesteps is not None and not ddim_use_original_steps:
135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136 | timesteps = self.ddim_timesteps[:subset_end]
137 |
138 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141 | print(f"Running PLMS Sampling with {total_steps} timesteps")
142 |
143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144 | old_eps = []
145 |
146 | for i, step in enumerate(iterator):
147 | index = total_steps - i - 1
148 | ts = torch.full((b,), step, device=device, dtype=torch.long)
149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150 |
151 | if mask is not None:
152 | assert x0 is not None
153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154 | img = img_orig * mask + (1. - mask) * img
155 |
156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157 | quantize_denoised=quantize_denoised, temperature=temperature,
158 | noise_dropout=noise_dropout, score_corrector=score_corrector,
159 | corrector_kwargs=corrector_kwargs,
160 | unconditional_guidance_scale=unconditional_guidance_scale,
161 | unconditional_conditioning=unconditional_conditioning,
162 | old_eps=old_eps, t_next=ts_next,
163 | dynamic_threshold=dynamic_threshold)
164 | img, pred_x0, e_t = outs
165 | old_eps.append(e_t)
166 | if len(old_eps) >= 4:
167 | old_eps.pop(0)
168 | if callback: callback(i)
169 | if img_callback: img_callback(pred_x0, i)
170 |
171 | if index % log_every_t == 0 or index == total_steps - 1:
172 | intermediates['x_inter'].append(img)
173 | intermediates['pred_x0'].append(pred_x0)
174 |
175 | return img, intermediates
176 |
177 | @torch.no_grad()
178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181 | dynamic_threshold=None):
182 | b, *_, device = *x.shape, x.device
183 |
184 | def get_model_output(x, t):
185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186 | e_t = self.model.apply_model(x, t, c)
187 | else:
188 | x_in = torch.cat([x] * 2)
189 | t_in = torch.cat([t] * 2)
190 | c_in = torch.cat([unconditional_conditioning, c])
191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193 |
194 | if score_corrector is not None:
195 | assert self.model.parameterization == "eps"
196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197 |
198 | return e_t
199 |
200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204 |
205 | def get_x_prev_and_pred_x0(e_t, index):
206 | # select parameters corresponding to the currently considered timestep
207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211 |
212 | # current prediction for x_0
213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214 | if quantize_denoised:
215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216 | if dynamic_threshold is not None:
217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218 | # direction pointing to x_t
219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221 | if noise_dropout > 0.:
222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224 | return x_prev, pred_x0
225 |
226 | e_t = get_model_output(x, t)
227 | if len(old_eps) == 0:
228 | # Pseudo Improved Euler (2nd order)
229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230 | e_t_next = get_model_output(x_prev, t_next)
231 | e_t_prime = (e_t + e_t_next) / 2
232 | elif len(old_eps) == 1:
233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
235 | elif len(old_eps) == 2:
236 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238 | elif len(old_eps) >= 3:
239 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241 |
242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243 |
244 | return x_prev, pred_x0, e_t
245 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
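A quick, self-contained check of the two helpers above (shapes are arbitrary): append_dims pads trailing singleton dimensions so a per-sample statistic can broadcast against a b x c x h x w batch, and norm_thresholding uses it to rescale any sample whose RMS exceeds the threshold down to that threshold while leaving smaller samples untouched.

import torch
from ldm.models.diffusion.sampling_util import append_dims, norm_thresholding

print(append_dims(torch.ones(2), 4).shape)       # torch.Size([2, 1, 1, 1])

x0 = torch.randn(2, 4, 8, 8) * 3.0               # per-sample RMS is roughly 3
clipped = norm_thresholding(x0, value=1.0)
print(clipped.pow(2).flatten(1).mean(1).sqrt())  # both entries are now ~1.0
--------------------------------------------------------------------------------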
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 | from typing import Optional, Any
8 |
9 | from ldm.modules.diffusionmodules.util import checkpoint
10 |
11 |
12 | try:
13 | import xformers
14 | import xformers.ops
15 |     XFORMERS_IS_AVAILABLE = True
16 | except Exception:
17 |     XFORMERS_IS_AVAILABLE = False
18 |
19 | # CrossAttn precision handling
20 | import os
21 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 |
27 | def uniq(arr):
28 |     return {el: True for el in arr}.keys()
29 |
30 |
31 | def default(val, d):
32 | if exists(val):
33 | return val
34 | return d() if isfunction(d) else d
35 |
36 |
37 | def max_neg_value(t):
38 | return -torch.finfo(t.dtype).max
39 |
40 |
41 | def init_(tensor):
42 | dim = tensor.shape[-1]
43 | std = 1 / math.sqrt(dim)
44 | tensor.uniform_(-std, std)
45 | return tensor
46 |
47 |
48 | # feedforward
49 | class GEGLU(nn.Module):
50 | def __init__(self, dim_in, dim_out):
51 | super().__init__()
52 | self.proj = nn.Linear(dim_in, dim_out * 2)
53 |
54 | def forward(self, x):
55 | x, gate = self.proj(x).chunk(2, dim=-1)
56 | return x * F.gelu(gate)
57 |
58 |
59 | class FeedForward(nn.Module):
60 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61 | super().__init__()
62 | inner_dim = int(dim * mult)
63 | dim_out = default(dim_out, dim)
64 | project_in = nn.Sequential(
65 | nn.Linear(dim, inner_dim),
66 | nn.GELU()
67 | ) if not glu else GEGLU(dim, inner_dim)
68 |
69 | self.net = nn.Sequential(
70 | project_in,
71 | nn.Dropout(dropout),
72 | nn.Linear(inner_dim, dim_out)
73 | )
74 |
75 | def forward(self, x):
76 | return self.net(x)
77 |
78 |
79 | def zero_module(module):
80 | """
81 | Zero out the parameters of a module and return it.
82 | """
83 | for p in module.parameters():
84 | p.detach().zero_()
85 | return module
86 |
87 |
88 | def Normalize(in_channels):
89 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
90 |
91 |
92 | class SpatialSelfAttention(nn.Module):
93 | def __init__(self, in_channels):
94 | super().__init__()
95 | self.in_channels = in_channels
96 |
97 | self.norm = Normalize(in_channels)
98 | self.q = torch.nn.Conv2d(in_channels,
99 | in_channels,
100 | kernel_size=1,
101 | stride=1,
102 | padding=0)
103 | self.k = torch.nn.Conv2d(in_channels,
104 | in_channels,
105 | kernel_size=1,
106 | stride=1,
107 | padding=0)
108 | self.v = torch.nn.Conv2d(in_channels,
109 | in_channels,
110 | kernel_size=1,
111 | stride=1,
112 | padding=0)
113 | self.proj_out = torch.nn.Conv2d(in_channels,
114 | in_channels,
115 | kernel_size=1,
116 | stride=1,
117 | padding=0)
118 |
119 | def forward(self, x):
120 | h_ = x
121 | h_ = self.norm(h_)
122 | q = self.q(h_)
123 | k = self.k(h_)
124 | v = self.v(h_)
125 |
126 | # compute attention
127 | b,c,h,w = q.shape
128 | q = rearrange(q, 'b c h w -> b (h w) c')
129 | k = rearrange(k, 'b c h w -> b c (h w)')
130 | w_ = torch.einsum('bij,bjk->bik', q, k)
131 |
132 | w_ = w_ * (int(c)**(-0.5))
133 | w_ = torch.nn.functional.softmax(w_, dim=2)
134 |
135 | # attend to values
136 | v = rearrange(v, 'b c h w -> b c (h w)')
137 | w_ = rearrange(w_, 'b i j -> b j i')
138 | h_ = torch.einsum('bij,bjk->bik', v, w_)
139 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
140 | h_ = self.proj_out(h_)
141 |
142 | return x+h_
143 |
144 |
145 | class CrossAttention(nn.Module):
146 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147 | super().__init__()
148 | inner_dim = dim_head * heads
149 | context_dim = default(context_dim, query_dim)
150 |
151 | self.scale = dim_head ** -0.5
152 | self.heads = heads
153 |
154 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
155 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
156 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
157 |
158 | self.to_out = nn.Sequential(
159 | nn.Linear(inner_dim, query_dim),
160 | nn.Dropout(dropout)
161 | )
162 |
163 | def forward(self, x, context=None, mask=None):
164 | h = self.heads
165 |
166 | q = self.to_q(x)
167 | context = default(context, x)
168 | k = self.to_k(context)
169 | v = self.to_v(context)
170 |
171 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172 |
173 | # force cast to fp32 to avoid overflowing
174 | if _ATTN_PRECISION =="fp32":
175 | with torch.autocast(enabled=False, device_type = 'cuda'):
176 | q, k = q.float(), k.float()
177 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178 | else:
179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180 |
181 | del q, k
182 |
183 | if exists(mask):
184 | mask = rearrange(mask, 'b ... -> b (...)')
185 | max_neg_value = -torch.finfo(sim.dtype).max
186 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
187 | sim.masked_fill_(~mask, max_neg_value)
188 |
189 | # attention, what we cannot get enough of
190 | sim = sim.softmax(dim=-1)
191 |
192 | out = einsum('b i j, b j d -> b i d', sim, v)
193 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
194 | return self.to_out(out)
195 |
196 |
197 | class MemoryEfficientCrossAttention(nn.Module):
198 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
200 | super().__init__()
201 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
202 | f"{heads} heads.")
203 | inner_dim = dim_head * heads
204 | context_dim = default(context_dim, query_dim)
205 |
206 | self.heads = heads
207 | self.dim_head = dim_head
208 |
209 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
210 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
211 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
212 |
213 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
214 | self.attention_op: Optional[Any] = None
215 |
216 | def forward(self, x, context=None, mask=None):
217 | q = self.to_q(x)
218 | context = default(context, x)
219 | k = self.to_k(context)
220 | v = self.to_v(context)
221 |
222 | b, _, _ = q.shape
223 | q, k, v = map(
224 | lambda t: t.unsqueeze(3)
225 | .reshape(b, t.shape[1], self.heads, self.dim_head)
226 | .permute(0, 2, 1, 3)
227 | .reshape(b * self.heads, t.shape[1], self.dim_head)
228 | .contiguous(),
229 | (q, k, v),
230 | )
231 |
232 | # actually compute the attention, what we cannot get enough of
233 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234 |
235 | if exists(mask):
236 | raise NotImplementedError
237 | out = (
238 | out.unsqueeze(0)
239 | .reshape(b, self.heads, out.shape[1], self.dim_head)
240 | .permute(0, 2, 1, 3)
241 | .reshape(b, out.shape[1], self.heads * self.dim_head)
242 | )
243 | return self.to_out(out)
244 |
245 |
246 | class BasicTransformerBlock(nn.Module):
247 | ATTENTION_MODES = {
248 | "softmax": CrossAttention, # vanilla attention
249 | "softmax-xformers": MemoryEfficientCrossAttention
250 | }
251 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252 | disable_self_attn=False):
253 | super().__init__()
254 |         attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
255 | assert attn_mode in self.ATTENTION_MODES
256 | attn_cls = self.ATTENTION_MODES[attn_mode]
257 | self.disable_self_attn = disable_self_attn
258 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263 | self.norm1 = nn.LayerNorm(dim)
264 | self.norm2 = nn.LayerNorm(dim)
265 | self.norm3 = nn.LayerNorm(dim)
266 | self.checkpoint = checkpoint
267 |
268 | def forward(self, x, context=None):
269 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
270 |
271 | def _forward(self, x, context=None):
272 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
273 | x = self.attn2(self.norm2(x), context=context) + x
274 | x = self.ff(self.norm3(x)) + x
275 | return x
276 |
277 |
278 | class SpatialTransformer(nn.Module):
279 | """
280 | Transformer block for image-like data.
281 | First, project the input (aka embedding)
282 | and reshape to b, t, d.
283 | Then apply standard transformer action.
284 | Finally, reshape to image
285 | NEW: use_linear for more efficiency instead of the 1x1 convs
286 | """
287 | def __init__(self, in_channels, n_heads, d_head,
288 | depth=1, dropout=0., context_dim=None,
289 | disable_self_attn=False, use_linear=False,
290 | use_checkpoint=True):
291 | super().__init__()
292 | if exists(context_dim) and not isinstance(context_dim, list):
293 | context_dim = [context_dim]
294 | self.in_channels = in_channels
295 | inner_dim = n_heads * d_head
296 | self.norm = Normalize(in_channels)
297 | if not use_linear:
298 | self.proj_in = nn.Conv2d(in_channels,
299 | inner_dim,
300 | kernel_size=1,
301 | stride=1,
302 | padding=0)
303 | else:
304 | self.proj_in = nn.Linear(in_channels, inner_dim)
305 |
306 | self.transformer_blocks = nn.ModuleList(
307 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
308 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
309 | for d in range(depth)]
310 | )
311 | if not use_linear:
312 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
313 | in_channels,
314 | kernel_size=1,
315 | stride=1,
316 | padding=0))
317 | else:
318 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
319 | self.use_linear = use_linear
320 |
321 | def forward(self, x, context=None):
322 | # note: if no context is given, cross-attention defaults to self-attention
323 | if not isinstance(context, list):
324 | context = [context]
325 | b, c, h, w = x.shape
326 | x_in = x
327 | x = self.norm(x)
328 | if not self.use_linear:
329 | x = self.proj_in(x)
330 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
331 | if self.use_linear:
332 | x = self.proj_in(x)
333 | for i, block in enumerate(self.transformer_blocks):
334 | x = block(x, context=context[i])
335 | if self.use_linear:
336 | x = self.proj_out(x)
337 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
338 | if not self.use_linear:
339 | x = self.proj_out(x)
340 | return x + x_in
341 |
342 |
--------------------------------------------------------------------------------
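For orientation, a small CPU shape check of the vanilla CrossAttention block above, with arbitrary dimensions: queries come from the image tokens, keys and values from the context sequence, and the output keeps the query sequence length and query_dim. Leaving context_dim as None makes the module a plain self-attention over x, which is how attn1 in BasicTransformerBlock is configured.

import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=64, context_dim=32, heads=4, dim_head=16)
x = torch.randn(2, 100, 64)          # (batch, image tokens, query_dim)
ctx = torch.randn(2, 77, 32)         # (batch, context tokens, context_dim)
print(attn(x, context=ctx).shape)    # torch.Size([2, 100, 64])
--------------------------------------------------------------------------------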
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
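A short stand-alone exercise of ImageConcatWithNoiseAugmentation above, using the same linear-schedule defaults that register_schedule falls back to (the shapes and max_noise_level are arbitrary): the module draws a random noise level per sample and returns the q_sample-noised input together with that level, so a downstream model can be conditioned on it.

import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,
)
x_low = torch.randn(2, 3, 64, 64)      # e.g. a downsampled conditioning image
z, noise_level = aug(x_low)
print(z.shape, noise_level)            # torch.Size([2, 3, 64, 64]) and two random levels < 350
--------------------------------------------------------------------------------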
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126 | "dtype": torch.get_autocast_gpu_dtype(),
127 | "cache_enabled": torch.is_autocast_cache_enabled()}
128 | with torch.no_grad():
129 | output_tensors = ctx.run_function(*ctx.input_tensors)
130 | return output_tensors
131 |
132 | @staticmethod
133 | def backward(ctx, *output_grads):
134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135 | with torch.enable_grad(), \
136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137 | # Fixes a bug where the first op in run_function modifies the
138 | # Tensor storage in place, which is not allowed for detach()'d
139 | # Tensors.
140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141 | output_tensors = ctx.run_function(*shallow_copies)
142 | input_grads = torch.autograd.grad(
143 | output_tensors,
144 | ctx.input_tensors + ctx.input_params,
145 | output_grads,
146 | allow_unused=True,
147 | )
148 | del ctx.input_tensors
149 | del ctx.input_params
150 | del output_tensors
151 | return (None, None) + input_grads
152 |
153 |
154 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
155 | """
156 | Create sinusoidal timestep embeddings.
157 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
158 | These may be fractional.
159 | :param dim: the dimension of the output.
160 | :param max_period: controls the minimum frequency of the embeddings.
161 | :return: an [N x dim] Tensor of positional embeddings.
162 | """
163 | if not repeat_only:
164 | half = dim // 2
165 | freqs = torch.exp(
166 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167 | ).to(device=timesteps.device)
168 | args = timesteps[:, None].float() * freqs[None]
169 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170 | if dim % 2:
171 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172 | else:
173 | embedding = repeat(timesteps, 'b -> b d', d=dim)
174 | return embedding
175 |
176 |
177 | def zero_module(module):
178 | """
179 | Zero out the parameters of a module and return it.
180 | """
181 | for p in module.parameters():
182 | p.detach().zero_()
183 | return module
184 |
185 |
186 | def scale_module(module, scale):
187 | """
188 | Scale the parameters of a module and return it.
189 | """
190 | for p in module.parameters():
191 | p.detach().mul_(scale)
192 | return module
193 |
194 |
195 | def mean_flat(tensor):
196 | """
197 | Take the mean over all non-batch dimensions.
198 | """
199 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
200 |
201 |
202 | def normalization(channels):
203 | """
204 | Make a standard normalization layer.
205 | :param channels: number of input channels.
206 | :return: an nn.Module for normalization.
207 | """
208 | return GroupNorm32(32, channels)
209 |
210 |
211 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
212 | class SiLU(nn.Module):
213 | def forward(self, x):
214 | return x * torch.sigmoid(x)
215 |
216 |
217 | class GroupNorm32(nn.GroupNorm):
218 | def forward(self, x):
219 | return super().forward(x.float()).type(x.dtype)
220 |
221 | def conv_nd(dims, *args, **kwargs):
222 | """
223 | Create a 1D, 2D, or 3D convolution module.
224 | """
225 | if dims == 1:
226 | return nn.Conv1d(*args, **kwargs)
227 | elif dims == 2:
228 | return nn.Conv2d(*args, **kwargs)
229 | elif dims == 3:
230 | return nn.Conv3d(*args, **kwargs)
231 | raise ValueError(f"unsupported dimensions: {dims}")
232 |
233 |
234 | def linear(*args, **kwargs):
235 | """
236 | Create a linear module.
237 | """
238 | return nn.Linear(*args, **kwargs)
239 |
240 |
241 | def avg_pool_nd(dims, *args, **kwargs):
242 | """
243 | Create a 1D, 2D, or 3D average pooling module.
244 | """
245 | if dims == 1:
246 | return nn.AvgPool1d(*args, **kwargs)
247 | elif dims == 2:
248 | return nn.AvgPool2d(*args, **kwargs)
249 | elif dims == 3:
250 | return nn.AvgPool3d(*args, **kwargs)
251 | raise ValueError(f"unsupported dimensions: {dims}")
252 |
253 |
254 | class HybridConditioner(nn.Module):
255 |
256 | def __init__(self, c_concat_config, c_crossattn_config):
257 | super().__init__()
258 | self.concat_conditioner = instantiate_from_config(c_concat_config)
259 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
260 |
261 | def forward(self, c_concat, c_crossattn):
262 | c_concat = self.concat_conditioner(c_concat)
263 | c_crossattn = self.crossattn_conditioner(c_crossattn)
264 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
265 |
266 |
267 | def noise_like(shape, device, repeat=False):
268 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
269 | noise = lambda: torch.randn(shape, device=device)
270 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
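A few of the helpers above exercised in isolation with the default linear settings: make_beta_schedule squares a linspace over sqrt-betas, make_ddim_timesteps subsamples the 1000 DDPM steps and shifts them by one, and timestep_embedding builds the sinusoidal vectors that feed the UNet's time-embedding MLP.

import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, make_ddim_timesteps, timestep_embedding)

betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
print(betas.shape, betas[0], betas[-1])        # (1000,) ~1e-4 ... ~0.02

steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                            num_ddpm_timesteps=1000, verbose=False)
print(steps[:5])                               # [ 1 21 41 61 81]

emb = timestep_embedding(torch.tensor([0, 500, 999]), dim=320)
print(emb.shape)                               # torch.Size([3, 320])
--------------------------------------------------------------------------------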
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
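In the autoencoder, the encoder (plus its quant_conv) emits 2 * z_channels feature maps that DiagonalGaussianDistribution above splits into mean and log-variance along dim=1; a toy example with random parameters:

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 16, 16)      # (batch, 2 * z_channels, h, w)
posterior = DiagonalGaussianDistribution(params)

print(posterior.sample().shape)   # reparameterized draw: torch.Size([2, 4, 16, 16])
print(posterior.mode().shape)     # the mean: torch.Size([2, 4, 16, 16])
print(posterior.kl().shape)       # KL against N(0, I), one value per sample: torch.Size([2])
--------------------------------------------------------------------------------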
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
--------------------------------------------------------------------------------
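LitEma keeps one shadow buffer per trainable parameter and is meant to be called once per optimizer step; the store / copy_to / restore trio then lets validation run with the EMA weights without disturbing training (the pattern that ddpm.py wraps in its ema_scope helper). A toy round-trip with a throwaway module:

import torch
from ldm.modules.ema import LitEma

net = torch.nn.Linear(4, 4)
ema = LitEma(net, decay=0.999)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

for _ in range(10):                                # a few fake training steps
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()
    ema(net)                                       # update the shadow parameters

ema.store(net.parameters())    # stash the live weights
ema.copy_to(net)               # evaluate / save with EMA weights here ...
ema.restore(net.parameters())  # ... then put the live weights back
--------------------------------------------------------------------------------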
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6 |
7 |
8 | from ldm.util import default, count_params
9 |
10 |
11 | class AbstractEncoder(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def encode(self, *args, **kwargs):
16 | raise NotImplementedError
17 |
18 |
19 | class IdentityEncoder(AbstractEncoder):
20 |
21 | def encode(self, x):
22 | return x
23 |
24 |
25 | class ClassEmbedder(nn.Module):
26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27 | super().__init__()
28 | self.key = key
29 | self.embedding = nn.Embedding(n_classes, embed_dim)
30 | self.n_classes = n_classes
31 | self.ucg_rate = ucg_rate
32 |
33 | def forward(self, batch, key=None, disable_dropout=False):
34 | if key is None:
35 | key = self.key
36 | # this is for use in crossattn
37 | c = batch[key][:, None]
38 | if self.ucg_rate > 0. and not disable_dropout:
39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41 | c = c.long()
42 | c = self.embedding(c)
43 | return c
44 |
45 | def get_unconditional_conditioning(self, bs, device="cuda"):
46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47 | uc = torch.ones((bs,), device=device) * uc_class
48 | uc = {self.key: uc}
49 | return uc
50 |
51 |
52 | def disabled_train(self, mode=True):
53 | """Overwrite model.train with this function to make sure train/eval mode
54 | does not change anymore."""
55 | return self
56 |
57 |
58 | class FrozenT5Embedder(AbstractEncoder):
59 | """Uses the T5 transformer encoder for text"""
60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61 | super().__init__()
62 | self.tokenizer = T5Tokenizer.from_pretrained(version)
63 | self.transformer = T5EncoderModel.from_pretrained(version)
64 | self.device = device
65 | self.max_length = max_length # TODO: typical value?
66 | if freeze:
67 | self.freeze()
68 |
69 | def freeze(self):
70 | self.transformer = self.transformer.eval()
71 | #self.train = disabled_train
72 | for param in self.parameters():
73 | param.requires_grad = False
74 |
75 | def forward(self, text):
76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78 | tokens = batch_encoding["input_ids"].to(self.device)
79 | outputs = self.transformer(input_ids=tokens)
80 |
81 | z = outputs.last_hidden_state
82 | return z
83 |
84 | def encode(self, text):
85 | return self(text)
86 |
87 |
88 | class FrozenCLIPEmbedder(AbstractEncoder):
89 | """Uses the CLIP transformer encoder for text (from huggingface)"""
90 | LAYERS = [
91 | "last",
92 | "pooled",
93 | "hidden"
94 | ]
95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97 | super().__init__()
98 | assert layer in self.LAYERS
99 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
100 | self.transformer = CLIPTextModel.from_pretrained(version)
101 | self.device = device
102 | self.max_length = max_length
103 | if freeze:
104 | self.freeze()
105 | self.layer = layer
106 | self.layer_idx = layer_idx
107 | if layer == "hidden":
108 | assert layer_idx is not None
109 | assert 0 <= abs(layer_idx) <= 12
110 |
111 | def freeze(self):
112 | self.transformer = self.transformer.eval()
113 | #self.train = disabled_train
114 | for param in self.parameters():
115 | param.requires_grad = False
116 |
117 | def forward(self, text):
118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120 | tokens = batch_encoding["input_ids"].to("cuda")
121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122 | if self.layer == "last":
123 | z = outputs.last_hidden_state
124 | elif self.layer == "pooled":
125 | z = outputs.pooler_output[:, None, :]
126 | else:
127 | z = outputs.hidden_states[self.layer_idx]
128 | return z
129 |
130 | def encode(self, text):
131 | return self(text)
132 |
133 |
134 |
135 |
136 | class FrozenCLIPT5Encoder(AbstractEncoder):
137 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
138 | clip_max_length=77, t5_max_length=77):
139 | super().__init__()
140 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
141 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
142 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
143 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
144 |
145 | def encode(self, text):
146 | return self(text)
147 |
148 | def forward(self, text):
149 | clip_z = self.clip_encoder.encode(text)
150 | t5_z = self.t5_encoder.encode(text)
151 | return [clip_z, t5_z]
152 |
153 |
154 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
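344 | 
345 | 
346 | # --- Editor's usage sketch (not part of the original file) ---
347 | # Exercises only the purely local building blocks on random tensors;
348 | # _make_encoder is deliberately not called here because it fetches pretrained
349 | # backbones via timm / torch.hub. Run as a module (python -m ...) because of
350 | # the package-relative imports.
351 | if __name__ == "__main__":
352 |     scratch = _make_scratch([32, 48, 136, 384], 64, expand=True)
353 |     feat = scratch.layer1_rn(torch.randn(1, 32, 64, 64))  # -> (1, 64, 64, 64)
354 |     fused = FeatureFusionBlock(64)(feat)                   # upsamples 2x -> (1, 64, 128, 128)
355 |     up = Interpolate(scale_factor=2, mode="bilinear")(fused)
356 |     print(feat.shape, fused.shape, up.shape)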
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # Set to true if you want to train from scratch (uses ImageNet weights)
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
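111 | 
112 | # --- Editor's usage sketch (not part of the original file) ---
113 | # Builds the DPT-Hybrid architecture only: no weights are downloaded because
114 | # _make_encoder is called with use_pretrained=False, but timm must be
115 | # installed so the vit_base_resnet50_384 backbone can be constructed. Real use
116 | # would pass path=<dpt_hybrid checkpoint>. Run as a module (python -m ...)
117 | # because of the package-relative imports.
118 | if __name__ == "__main__":
119 |     model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()
120 |     with torch.no_grad():
121 |         depth = model(torch.randn(1, 3, 384, 384))  # expected shape: (1, 384, 384)
122 |     print(depth.shape)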
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             Note: the encoder backbone is fixed to resnext101_wsl.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
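78 | 
79 | # --- Editor's usage sketch (not part of the original file) ---
80 | # Shape-level illustration only: constructing MidasNet pulls the ResNeXt-101
81 | # WSL backbone via torch.hub, which requires network access and a large
82 | # download. Run as a module (python -m ...) because of the package-relative
83 | # imports.
84 | if __name__ == "__main__":
85 |     net = MidasNet(path=None, features=256).eval()
86 |     with torch.no_grad():
87 |         depth = net(torch.randn(1, 3, 384, 384))  # expected shape: (1, 384, 384)
88 |     print(depth.shape)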
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
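129 | 
130 | 
131 | # --- Editor's usage sketch (not part of the original file) ---
132 | # With path=None, use_pretrained is True and the efficientnet_lite3 backbone is
133 | # fetched via torch.hub (network access assumed). fuse_model above would be
134 | # applied before quantized export; it is not called in this sketch. Run as a
135 | # module (python -m ...) because of the package-relative imports.
136 | if __name__ == "__main__":
137 |     net = MidasNet_small(path=None, features=64, backbone="efficientnet_lite3",
138 |                          exportable=True, non_negative=True, blocks={'expand': True}).eval()
139 |     with torch.no_grad():
140 |         depth = net(torch.randn(1, 3, 256, 256))  # expected shape: (1, 256, 256)
141 |     print(depth.shape)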
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 |                 Output width and height are constrained to be a multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
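236 | 
237 | # --- Editor's usage sketch (not part of the original file) ---
238 | # Chains the three transforms on a random HxWx3 float image, mirroring how
239 | # api.py composes them for the midas_v21 models. resize_target=False because
240 | # this sample carries no mask/disparity/depth targets.
241 | if __name__ == "__main__":
242 |     sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
243 |     sample = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
244 |                     ensure_multiple_of=32, resize_method="upper_bound",
245 |                     image_interpolation_method=cv2.INTER_CUBIC)(sample)
246 |     sample = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(sample)
247 |     sample = PrepareForNet()(sample)
248 |     print(sample["image"].shape)  # channels-first float32: (3, 288, 384)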
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Slice(nn.Module):
10 | def __init__(self, start_index=1):
11 | super(Slice, self).__init__()
12 | self.start_index = start_index
13 |
14 | def forward(self, x):
15 | return x[:, self.start_index :]
16 |
17 |
18 | class AddReadout(nn.Module):
19 | def __init__(self, start_index=1):
20 | super(AddReadout, self).__init__()
21 | self.start_index = start_index
22 |
23 | def forward(self, x):
24 | if self.start_index == 2:
25 | readout = (x[:, 0] + x[:, 1]) / 2
26 | else:
27 | readout = x[:, 0]
28 | return x[:, self.start_index :] + readout.unsqueeze(1)
29 |
30 |
31 | class ProjectReadout(nn.Module):
32 | def __init__(self, in_features, start_index=1):
33 | super(ProjectReadout, self).__init__()
34 | self.start_index = start_index
35 |
36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37 |
38 | def forward(self, x):
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40 | features = torch.cat((x[:, self.start_index :], readout), -1)
41 |
42 | return self.project(features)
43 |
44 |
45 | class Transpose(nn.Module):
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0 = dim0
49 | self.dim1 = dim1
50 |
51 | def forward(self, x):
52 | x = x.transpose(self.dim0, self.dim1)
53 | return x
54 |
55 |
56 | def forward_vit(pretrained, x):
57 | b, c, h, w = x.shape
58 |
59 | glob = pretrained.model.forward_flex(x)
60 |
61 | layer_1 = pretrained.activations["1"]
62 | layer_2 = pretrained.activations["2"]
63 | layer_3 = pretrained.activations["3"]
64 | layer_4 = pretrained.activations["4"]
65 |
66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70 |
71 | unflatten = nn.Sequential(
72 | nn.Unflatten(
73 | 2,
74 | torch.Size(
75 | [
76 | h // pretrained.model.patch_size[1],
77 | w // pretrained.model.patch_size[0],
78 | ]
79 | ),
80 | )
81 | )
82 |
83 | if layer_1.ndim == 3:
84 | layer_1 = unflatten(layer_1)
85 | if layer_2.ndim == 3:
86 | layer_2 = unflatten(layer_2)
87 | if layer_3.ndim == 3:
88 | layer_3 = unflatten(layer_3)
89 | if layer_4.ndim == 3:
90 | layer_4 = unflatten(layer_4)
91 |
92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96 |
97 | return layer_1, layer_2, layer_3, layer_4
98 |
99 |
100 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
101 | posemb_tok, posemb_grid = (
102 | posemb[:, : self.start_index],
103 | posemb[0, self.start_index :],
104 | )
105 |
106 | gs_old = int(math.sqrt(len(posemb_grid)))
107 |
108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111 |
112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113 |
114 | return posemb
115 |
116 |
117 | def forward_flex(self, x):
118 | b, c, h, w = x.shape
119 |
120 | pos_embed = self._resize_pos_embed(
121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122 | )
123 |
124 | B = x.shape[0]
125 |
126 | if hasattr(self.patch_embed, "backbone"):
127 | x = self.patch_embed.backbone(x)
128 | if isinstance(x, (list, tuple)):
129 | x = x[-1] # last feature if backbone outputs list/tuple of features
130 |
131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132 |
133 | if getattr(self, "dist_token", None) is not None:
134 | cls_tokens = self.cls_token.expand(
135 | B, -1, -1
136 | ) # stole cls_tokens impl from Phil Wang, thanks
137 | dist_token = self.dist_token.expand(B, -1, -1)
138 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
139 | else:
140 | cls_tokens = self.cls_token.expand(
141 | B, -1, -1
142 | ) # stole cls_tokens impl from Phil Wang, thanks
143 | x = torch.cat((cls_tokens, x), dim=1)
144 |
145 | x = x + pos_embed
146 | x = self.pos_drop(x)
147 |
148 | for blk in self.blocks:
149 | x = blk(x)
150 |
151 | x = self.norm(x)
152 |
153 | return x
154 |
155 |
156 | activations = {}
157 |
158 |
159 | def get_activation(name):
160 | def hook(model, input, output):
161 | activations[name] = output
162 |
163 | return hook
164 |
165 |
166 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
167 | if use_readout == "ignore":
168 | readout_oper = [Slice(start_index)] * len(features)
169 | elif use_readout == "add":
170 | readout_oper = [AddReadout(start_index)] * len(features)
171 | elif use_readout == "project":
172 | readout_oper = [
173 | ProjectReadout(vit_features, start_index) for out_feat in features
174 | ]
175 | else:
176 | assert (
177 | False
178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179 |
180 | return readout_oper
181 |
182 |
183 | def _make_vit_b16_backbone(
184 | model,
185 | features=[96, 192, 384, 768],
186 | size=[384, 384],
187 | hooks=[2, 5, 8, 11],
188 | vit_features=768,
189 | use_readout="ignore",
190 | start_index=1,
191 | ):
192 | pretrained = nn.Module()
193 |
194 | pretrained.model = model
195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199 |
200 | pretrained.activations = activations
201 |
202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203 |
204 | # 32, 48, 136, 384
205 | pretrained.act_postprocess1 = nn.Sequential(
206 | readout_oper[0],
207 | Transpose(1, 2),
208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209 | nn.Conv2d(
210 | in_channels=vit_features,
211 | out_channels=features[0],
212 | kernel_size=1,
213 | stride=1,
214 | padding=0,
215 | ),
216 | nn.ConvTranspose2d(
217 | in_channels=features[0],
218 | out_channels=features[0],
219 | kernel_size=4,
220 | stride=4,
221 | padding=0,
222 | bias=True,
223 | dilation=1,
224 | groups=1,
225 | ),
226 | )
227 |
228 | pretrained.act_postprocess2 = nn.Sequential(
229 | readout_oper[1],
230 | Transpose(1, 2),
231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232 | nn.Conv2d(
233 | in_channels=vit_features,
234 | out_channels=features[1],
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | ),
239 | nn.ConvTranspose2d(
240 | in_channels=features[1],
241 | out_channels=features[1],
242 | kernel_size=2,
243 | stride=2,
244 | padding=0,
245 | bias=True,
246 | dilation=1,
247 | groups=1,
248 | ),
249 | )
250 |
251 | pretrained.act_postprocess3 = nn.Sequential(
252 | readout_oper[2],
253 | Transpose(1, 2),
254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255 | nn.Conv2d(
256 | in_channels=vit_features,
257 | out_channels=features[2],
258 | kernel_size=1,
259 | stride=1,
260 | padding=0,
261 | ),
262 | )
263 |
264 | pretrained.act_postprocess4 = nn.Sequential(
265 | readout_oper[3],
266 | Transpose(1, 2),
267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268 | nn.Conv2d(
269 | in_channels=vit_features,
270 | out_channels=features[3],
271 | kernel_size=1,
272 | stride=1,
273 | padding=0,
274 | ),
275 | nn.Conv2d(
276 | in_channels=features[3],
277 | out_channels=features[3],
278 | kernel_size=3,
279 | stride=2,
280 | padding=1,
281 | ),
282 | )
283 |
284 | pretrained.model.start_index = start_index
285 | pretrained.model.patch_size = [16, 16]
286 |
287 | # We inject this function into the VisionTransformer instances so that
288 | # we can use it with interpolated position embeddings without modifying the library source.
289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290 | pretrained.model._resize_pos_embed = types.MethodType(
291 | _resize_pos_embed, pretrained.model
292 | )
293 |
294 | return pretrained
295 |
296 |
297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299 |
300 | hooks = [5, 11, 17, 23] if hooks == None else hooks
301 | return _make_vit_b16_backbone(
302 | model,
303 | features=[256, 512, 1024, 1024],
304 | hooks=hooks,
305 | vit_features=1024,
306 | use_readout=use_readout,
307 | )
308 |
309 |
310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312 |
313 | hooks = [2, 5, 8, 11] if hooks == None else hooks
314 | return _make_vit_b16_backbone(
315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316 | )
317 |
318 |
319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321 |
322 | hooks = [2, 5, 8, 11] if hooks == None else hooks
323 | return _make_vit_b16_backbone(
324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325 | )
326 |
327 |
328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329 | model = timm.create_model(
330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331 | )
332 |
333 | hooks = [2, 5, 8, 11] if hooks == None else hooks
334 | return _make_vit_b16_backbone(
335 | model,
336 | features=[96, 192, 384, 768],
337 | hooks=hooks,
338 | use_readout=use_readout,
339 | start_index=2,
340 | )
341 |
342 |
343 | def _make_vit_b_rn50_backbone(
344 | model,
345 | features=[256, 512, 768, 768],
346 | size=[384, 384],
347 | hooks=[0, 1, 8, 11],
348 | vit_features=768,
349 | use_vit_only=False,
350 | use_readout="ignore",
351 | start_index=1,
352 | ):
353 | pretrained = nn.Module()
354 |
355 | pretrained.model = model
356 |
357 | if use_vit_only == True:
358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360 | else:
361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362 | get_activation("1")
363 | )
364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365 | get_activation("2")
366 | )
367 |
368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370 |
371 | pretrained.activations = activations
372 |
373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374 |
375 | if use_vit_only == True:
376 | pretrained.act_postprocess1 = nn.Sequential(
377 | readout_oper[0],
378 | Transpose(1, 2),
379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380 | nn.Conv2d(
381 | in_channels=vit_features,
382 | out_channels=features[0],
383 | kernel_size=1,
384 | stride=1,
385 | padding=0,
386 | ),
387 | nn.ConvTranspose2d(
388 | in_channels=features[0],
389 | out_channels=features[0],
390 | kernel_size=4,
391 | stride=4,
392 | padding=0,
393 | bias=True,
394 | dilation=1,
395 | groups=1,
396 | ),
397 | )
398 |
399 | pretrained.act_postprocess2 = nn.Sequential(
400 | readout_oper[1],
401 | Transpose(1, 2),
402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403 | nn.Conv2d(
404 | in_channels=vit_features,
405 | out_channels=features[1],
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | ),
410 | nn.ConvTranspose2d(
411 | in_channels=features[1],
412 | out_channels=features[1],
413 | kernel_size=2,
414 | stride=2,
415 | padding=0,
416 | bias=True,
417 | dilation=1,
418 | groups=1,
419 | ),
420 | )
421 | else:
422 | pretrained.act_postprocess1 = nn.Sequential(
423 | nn.Identity(), nn.Identity(), nn.Identity()
424 | )
425 | pretrained.act_postprocess2 = nn.Sequential(
426 | nn.Identity(), nn.Identity(), nn.Identity()
427 | )
428 |
429 | pretrained.act_postprocess3 = nn.Sequential(
430 | readout_oper[2],
431 | Transpose(1, 2),
432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433 | nn.Conv2d(
434 | in_channels=vit_features,
435 | out_channels=features[2],
436 | kernel_size=1,
437 | stride=1,
438 | padding=0,
439 | ),
440 | )
441 |
442 | pretrained.act_postprocess4 = nn.Sequential(
443 | readout_oper[3],
444 | Transpose(1, 2),
445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446 | nn.Conv2d(
447 | in_channels=vit_features,
448 | out_channels=features[3],
449 | kernel_size=1,
450 | stride=1,
451 | padding=0,
452 | ),
453 | nn.Conv2d(
454 | in_channels=features[3],
455 | out_channels=features[3],
456 | kernel_size=3,
457 | stride=2,
458 | padding=1,
459 | ),
460 | )
461 |
462 | pretrained.model.start_index = start_index
463 | pretrained.model.patch_size = [16, 16]
464 |
465 | # We inject this function into the VisionTransformer instances so that
466 | # we can use it with interpolated position embeddings without modifying the library source.
467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468 |
469 | # We inject this function into the VisionTransformer instances so that
470 | # we can use it with interpolated position embeddings without modifying the library source.
471 | pretrained.model._resize_pos_embed = types.MethodType(
472 | _resize_pos_embed, pretrained.model
473 | )
474 |
475 | return pretrained
476 |
477 |
478 | def _make_pretrained_vitb_rn50_384(
479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480 | ):
481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482 |
483 | hooks = [0, 1, 8, 11] if hooks == None else hooks
484 | return _make_vit_b_rn50_backbone(
485 | model,
486 | features=[256, 512, 768, 768],
487 | size=[384, 384],
488 | hooks=hooks,
489 | use_vit_only=use_vit_only,
490 | use_readout=use_readout,
491 | )
492 |
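493 | 
494 | # --- Editor's usage sketch (not part of the original file) ---
495 | # Exercises only the local readout/token plumbing on a random token sequence;
496 | # the _make_pretrained_* helpers are not called here because they create timm
497 | # backbones (and may download weights).
498 | if __name__ == "__main__":
499 |     tokens = torch.randn(2, 1 + 24 * 24, 768)      # [cls] token + 24x24 patch tokens
500 |     print(Slice()(tokens).shape)                    # (2, 576, 768): readout token dropped
501 |     print(AddReadout()(tokens).shape)               # (2, 576, 768): readout added to patches
502 |     print(ProjectReadout(768)(tokens).shape)        # (2, 576, 768): readout projected back in
503 |     grid = Transpose(1, 2)(Slice()(tokens))         # (2, 768, 576), ready for nn.Unflatten
504 |     print(grid.shape)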
--------------------------------------------------------------------------------
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())  # encode the header for both branches
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
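191 | 
192 | # --- Editor's usage sketch (not part of the original file) ---
193 | # Round-trips a synthetic depth map through the PFM/PNG helpers; the output
194 | # file prefix "depth_demo" is illustrative and files are written to the
195 | # current directory.
196 | if __name__ == "__main__":
197 |     depth = (np.random.rand(240, 320) * 10.0).astype(np.float32)
198 |     write_depth("depth_demo", depth, bits=2)   # writes depth_demo.pfm and depth_demo.png
199 |     data, scale = read_pfm("depth_demo.pfm")
200 |     print(data.shape, scale, np.allclose(data, depth))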
--------------------------------------------------------------------------------
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def log_txt_as_img(wh, xc, size=10):
12 | # wh a tuple of (width, height)
13 | # xc a list of captions to plot
14 | b = len(xc)
15 | txts = list()
16 | for bi in range(b):
17 | txt = Image.new("RGB", wh, color="white")
18 | draw = ImageDraw.Draw(txt)
19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20 | nc = int(40 * (wh[0] / 256))
21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22 |
23 | try:
24 | draw.text((0, 0), lines, fill="black", font=font)
25 | except UnicodeEncodeError:
26 |             print("Can't encode string for logging. Skipping.")
27 |
28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29 | txts.append(txt)
30 | txts = np.stack(txts)
31 | txts = torch.tensor(txts)
32 | return txts
33 |
34 |
35 | def ismap(x):
36 | if not isinstance(x, torch.Tensor):
37 | return False
38 | return (len(x.shape) == 4) and (x.shape[1] > 3)
39 |
40 |
41 | def isimage(x):
42 | if not isinstance(x,torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45 |
46 |
47 | def exists(x):
48 | return x is not None
49 |
50 |
51 | def default(val, d):
52 | if exists(val):
53 | return val
54 | return d() if isfunction(d) else d
55 |
56 |
57 | def mean_flat(tensor):
58 | """
59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60 | Take the mean over all non-batch dimensions.
61 | """
62 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
63 |
64 |
65 | def count_params(model, verbose=False):
66 | total_params = sum(p.numel() for p in model.parameters())
67 | if verbose:
68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69 | return total_params
70 |
71 |
72 | def instantiate_from_config(config):
73 | if not "target" in config:
74 | if config == '__is_first_stage__':
75 | return None
76 | elif config == "__is_unconditional__":
77 | return None
78 | raise KeyError("Expected key `target` to instantiate.")
79 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
80 |
81 |
82 | def get_obj_from_str(string, reload=False):
83 | module, cls = string.rsplit(".", 1)
84 | if reload:
85 | module_imp = importlib.import_module(module)
86 | importlib.reload(module_imp)
87 | return getattr(importlib.import_module(module, package=None), cls)
88 |
89 |
90 | class AdamWwithEMAandWings(optim.Optimizer):
91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94 | ema_power=1., param_names=()):
95 | """AdamW that saves EMA versions of the parameters."""
96 | if not 0.0 <= lr:
97 | raise ValueError("Invalid learning rate: {}".format(lr))
98 | if not 0.0 <= eps:
99 | raise ValueError("Invalid epsilon value: {}".format(eps))
100 | if not 0.0 <= betas[0] < 1.0:
101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102 | if not 0.0 <= betas[1] < 1.0:
103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104 | if not 0.0 <= weight_decay:
105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106 | if not 0.0 <= ema_decay <= 1.0:
107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108 | defaults = dict(lr=lr, betas=betas, eps=eps,
109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110 | ema_power=ema_power, param_names=param_names)
111 | super().__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super().__setstate__(state)
115 | for group in self.param_groups:
116 | group.setdefault('amsgrad', False)
117 |
118 | @torch.no_grad()
119 | def step(self, closure=None):
120 | """Performs a single optimization step.
121 | Args:
122 | closure (callable, optional): A closure that reevaluates the model
123 | and returns the loss.
124 | """
125 | loss = None
126 | if closure is not None:
127 | with torch.enable_grad():
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 | params_with_grad = []
132 | grads = []
133 | exp_avgs = []
134 | exp_avg_sqs = []
135 | ema_params_with_grad = []
136 | state_sums = []
137 | max_exp_avg_sqs = []
138 | state_steps = []
139 | amsgrad = group['amsgrad']
140 | beta1, beta2 = group['betas']
141 | ema_decay = group['ema_decay']
142 | ema_power = group['ema_power']
143 |
144 | for p in group['params']:
145 | if p.grad is None:
146 | continue
147 | params_with_grad.append(p)
148 | if p.grad.is_sparse:
149 | raise RuntimeError('AdamW does not support sparse gradients')
150 | grads.append(p.grad)
151 |
152 | state = self.state[p]
153 |
154 | # State initialization
155 | if len(state) == 0:
156 | state['step'] = 0
157 | # Exponential moving average of gradient values
158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159 | # Exponential moving average of squared gradient values
160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161 | if amsgrad:
162 | # Maintains max of all exp. moving avg. of sq. grad. values
163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164 | # Exponential moving average of parameter values
165 | state['param_exp_avg'] = p.detach().float().clone()
166 |
167 | exp_avgs.append(state['exp_avg'])
168 | exp_avg_sqs.append(state['exp_avg_sq'])
169 | ema_params_with_grad.append(state['param_exp_avg'])
170 |
171 | if amsgrad:
172 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173 |
174 | # update the steps for each param group update
175 | state['step'] += 1
176 | # record the step after step update
177 | state_steps.append(state['step'])
178 |
179 | optim._functional.adamw(params_with_grad,
180 | grads,
181 | exp_avgs,
182 | exp_avg_sqs,
183 | max_exp_avg_sqs,
184 | state_steps,
185 | amsgrad=amsgrad,
186 | beta1=beta1,
187 | beta2=beta2,
188 | lr=group['lr'],
189 | weight_decay=group['weight_decay'],
190 | eps=group['eps'],
191 | maximize=False)
192 |
193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196 |
197 | return loss
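198 | 
199 | 
200 | # --- Editor's usage sketch (not part of the original file) ---
201 | # instantiate_from_config is the config-driven factory used by the ldm code:
202 | # "target" names a class by its dotted import path and "params" holds the
203 | # constructor kwargs. The torch.nn.Linear target below is purely illustrative.
204 | if __name__ == "__main__":
205 |     cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
206 |     layer = instantiate_from_config(cfg)
207 |     count_params(layer, verbose=True)  # prints "Linear has 0.00 M params."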
--------------------------------------------------------------------------------
/preprocessor/depth_preprocessor.py:
--------------------------------------------------------------------------------
1 | class Preprocessor:
2 | def __init__(self) -> None:
3 | pass
4 |
5 | def get_depth(self, input_dir, file_name):
6 | return
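7 | 
8 | # --- Editor's usage sketch (not part of the original file) ---
9 | # Preprocessor is an interface stub: a concrete preprocessor (e.g. the
10 | # MeshGraphormer-based one elsewhere in this repo) overrides get_depth to
11 | # return a depth map for <input_dir>/<file_name>. The subclass below is a
12 | # trivial illustration only.
13 | if __name__ == "__main__":
14 |     import numpy as np
15 | 
16 |     class ConstantDepth(Preprocessor):
17 |         def get_depth(self, input_dir, file_name):
18 |             return np.zeros((512, 512), dtype=np.float32)
19 | 
20 |     print(ConstantDepth().get_depth("test", "1.jpg").shape)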
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.3.1
2 | azureml==0.2.7
3 | chumpy==0.70
4 | einops==0.7.0
5 | matplotlib==3.7.1
6 | mediapipe
7 | numpy==1.23.5
8 | omegaconf==2.1.1
9 | opencv_contrib_python==4.7.0.72
10 | opencv_python==4.7.0.72
11 | opencv_python_headless==4.7.0.72
12 | Pillow==9.4.0
13 | pytorch_lightning==1.4.2
14 | pytorch_pretrained_bert==0.6.2
15 | safetensors==0.3.3
16 | scipy==1.9.0
17 | timm==0.6.13
18 | torch==2.0.0
19 | torchvision==0.15.1
20 | tqdm==4.65.0
21 | transformers==4.27.4
22 | trimesh[easy]==3.23.5
23 | yacs==0.1.8
24 |
25 | # If you encounter any errors, please see ControlNet and MeshGraphormer for more complete package requirements.
26 |
--------------------------------------------------------------------------------
/scripts/_gcnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import scipy.sparse
6 | import math
7 |
8 | class SparseMM(torch.autograd.Function):
9 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
10 | The builtin matrix multiplication operation does not support backpropagation in some cases.
11 | """
12 | @staticmethod
13 | def forward(ctx, sparse, dense):
14 | ctx.req_grad = dense.requires_grad
15 | ctx.save_for_backward(sparse)
16 | return torch.matmul(sparse, dense)
17 |
18 | @staticmethod
19 | def backward(ctx, grad_output):
20 | grad_input = None
21 | sparse, = ctx.saved_tensors
22 | if ctx.req_grad:
23 | grad_input = torch.matmul(sparse.t(), grad_output)
24 | return None, grad_input
25 |
26 | def spmm(sparse, dense):
27 | return SparseMM.apply(sparse, dense)
28 |
29 |
30 | def gelu(x):
31 | """Implementation of the gelu activation function.
32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
34 | Also see https://arxiv.org/abs/1606.08415
35 | """
36 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
37 |
38 | class BertLayerNorm(torch.nn.Module):
39 | def __init__(self, hidden_size, eps=1e-12):
40 | """Construct a layernorm module in the TF style (epsilon inside the square root).
41 | """
42 | super(BertLayerNorm, self).__init__()
43 | self.weight = torch.nn.Parameter(torch.ones(hidden_size))
44 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
45 | self.variance_epsilon = eps
46 |
47 | def forward(self, x):
48 | u = x.mean(-1, keepdim=True)
49 | s = (x - u).pow(2).mean(-1, keepdim=True)
50 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
51 | return self.weight * x + self.bias
52 |
53 |
54 | class GraphResBlock(torch.nn.Module):
55 | """
56 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet
57 | """
58 | def __init__(self, in_channels, out_channels, mesh_type='body'):
59 | super(GraphResBlock, self).__init__()
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 | self.lin1 = GraphLinear(in_channels, out_channels // 2)
63 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
64 | self.lin2 = GraphLinear(out_channels // 2, out_channels)
65 | self.skip_conv = GraphLinear(in_channels, out_channels)
66 | # print('Use BertLayerNorm in GraphResBlock')
67 | self.pre_norm = BertLayerNorm(in_channels)
68 | self.norm1 = BertLayerNorm(out_channels // 2)
69 | self.norm2 = BertLayerNorm(out_channels // 2)
70 |
71 | def forward(self, x):
72 | trans_y = F.relu(self.pre_norm(x)).transpose(1,2)
73 | y = self.lin1(trans_y).transpose(1,2)
74 |
75 | y = F.relu(self.norm1(y))
76 | y = self.conv(y)
77 |
78 | trans_y = F.relu(self.norm2(y)).transpose(1,2)
79 | y = self.lin2(trans_y).transpose(1,2)
80 |
81 | z = x+y
82 |
83 | return z
84 |
85 | # class GraphResBlock(torch.nn.Module):
86 | # """
87 | # Graph Residual Block similar to the Bottleneck Residual Block in ResNet
88 | # """
89 | # def __init__(self, in_channels, out_channels, mesh_type='body'):
90 | # super(GraphResBlock, self).__init__()
91 | # self.in_channels = in_channels
92 | # self.out_channels = out_channels
93 | # self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
94 | # print('Use BertLayerNorm and GeLU in GraphResBlock')
95 | # self.norm = BertLayerNorm(self.out_channels)
96 | # def forward(self, x):
97 | # y = self.conv(x)
98 | # y = self.norm(y)
99 | # y = gelu(y)
100 | # z = x+y
101 | # return z
102 |
103 | class GraphLinear(torch.nn.Module):
104 | """
105 | Generalization of 1x1 convolutions on Graphs
106 | """
107 | def __init__(self, in_channels, out_channels):
108 | super(GraphLinear, self).__init__()
109 | self.in_channels = in_channels
110 | self.out_channels = out_channels
111 | self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels))
112 | self.b = torch.nn.Parameter(torch.FloatTensor(out_channels))
113 | self.reset_parameters()
114 |
115 | def reset_parameters(self):
116 | w_stdv = 1 / (self.in_channels * self.out_channels)
117 | self.W.data.uniform_(-w_stdv, w_stdv)
118 | self.b.data.uniform_(-w_stdv, w_stdv)
119 |
120 | def forward(self, x):
121 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
122 |
123 | class GraphConvolution(torch.nn.Module):
124 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
125 | def __init__(self, in_features, out_features, mesh='body', bias=True):
126 | super(GraphConvolution, self).__init__()
127 | device=torch.device('cuda')
128 | self.in_features = in_features
129 | self.out_features = out_features
130 |
131 | if mesh=='body':
132 | adj_indices = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_indices.pt')
133 | adj_mat_value = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_values.pt')
134 | adj_mat_size = torch.load('MeshGraphormer/src/modeling/data/smpl_431_adjmat_size.pt')
135 | elif mesh=='hand':
136 | adj_indices = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_indices.pt')
137 | adj_mat_value = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_values.pt')
138 | adj_mat_size = torch.load('MeshGraphormer/src/modeling/data/mano_195_adjmat_size.pt')
139 |
140 | self.adjmat = torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size).to(device)
141 |
142 | self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
143 | if bias:
144 | self.bias = torch.nn.Parameter(torch.FloatTensor(out_features))
145 | else:
146 | self.register_parameter('bias', None)
147 | self.reset_parameters()
148 |
149 | def reset_parameters(self):
150 | # stdv = 1. / math.sqrt(self.weight.size(1))
151 | stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
152 | self.weight.data.uniform_(-stdv, stdv)
153 | if self.bias is not None:
154 | self.bias.data.uniform_(-stdv, stdv)
155 |
156 | def forward(self, x):
157 | if x.ndimension() == 2:
158 | support = torch.matmul(x, self.weight)
159 | output = torch.matmul(self.adjmat, support)
160 | if self.bias is not None:
161 | output = output + self.bias
162 | return output
163 | else:
164 | output = []
165 | for i in range(x.shape[0]):
166 | support = torch.matmul(x[i], self.weight)
167 | # output.append(torch.matmul(self.adjmat, support))
168 | output.append(spmm(self.adjmat, support))
169 | output = torch.stack(output, dim=0)
170 | if self.bias is not None:
171 | output = output + self.bias
172 | return output
173 |
174 | def __repr__(self):
175 | return self.__class__.__name__ + ' (' \
176 | + str(self.in_features) + ' -> ' \
177 | + str(self.out_features) + ')'
--------------------------------------------------------------------------------
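The `spmm` helper exists because, as the `SparseMM` docstring notes, autograd through a sparse-times-dense `torch.matmul` is not supported in all cases, so the backward pass is written out explicitly. A minimal usage sketch with a tiny made-up adjacency matrix (it assumes `spmm` is importable from `scripts/_gcnn.py`):

```python
import torch
from scripts._gcnn import spmm  # assumes the repo root is on PYTHONPATH

# Tiny 3-node graph adjacency matrix in sparse COO form (values are illustrative).
indices = torch.tensor([[0, 0, 1, 2], [0, 1, 1, 2]])
values = torch.tensor([0.5, 0.5, 1.0, 1.0])
adj = torch.sparse_coo_tensor(indices, values, size=(3, 3)).coalesce()

# Dense node features that require gradients.
x = torch.randn(3, 4, requires_grad=True)

y = spmm(adj, x)       # same result as adj @ x, but with a custom backward
y.sum().backward()
print(x.grad.shape)    # torch.Size([3, 4])
```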
/scripts/_mano.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the MANO definition and mesh sampling operations for the MANO mesh.
3 |
4 | Adapted from the open-source projects
5 | MANOPTH (https://github.com/hassony2/manopth)
6 | Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
7 | GraphCMR (https://github.com/nkolot/GraphCMR/)
8 | """
9 |
10 | from __future__ import division
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import os.path as osp
15 | import json
16 | import code
17 | from manopth.manolayer import ManoLayer
18 | import scipy.sparse
19 | import src.modeling.data.config as cfg
20 |
21 | class MANO(nn.Module):
22 | def __init__(self):
23 | super(MANO, self).__init__()
24 |
25 | self.mano_dir = 'MeshGraphormer/src/modeling/data'
26 | self.layer = self.get_layer()
27 | self.vertex_num = 778
28 | self.face = self.layer.th_faces.numpy()
29 | self.joint_regressor = self.layer.th_J_regressor.numpy()
30 |
31 | self.joint_num = 21
32 | self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
33 | self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
34 | self.root_joint_idx = self.joints_name.index('Wrist')
35 |
36 | # add fingertips to joint_regressor
37 | self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand)
38 | thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
39 | indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
40 | middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
41 | ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
42 | pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
43 | self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
44 | self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
45 | joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
46 | self.register_buffer('joint_regressor_torch', joint_regressor_torch)
47 |
48 | def get_layer(self):
49 | return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model
50 |
51 | def get_3d_joints(self, vertices):
52 | """
53 | This method is used to get the joint locations from the MANO mesh
54 | Input:
55 | vertices: size = (B, 778, 3)
56 | Output:
57 | 3D joints: size = (B, 21, 3)
58 | """
59 | joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
60 | return joints
61 |
62 |
63 | class SparseMM(torch.autograd.Function):
64 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
65 | The builtin matrix multiplication operation does not support backpropagation in some cases.
66 | """
67 | @staticmethod
68 | def forward(ctx, sparse, dense):
69 | ctx.req_grad = dense.requires_grad
70 | ctx.save_for_backward(sparse)
71 | return torch.matmul(sparse, dense)
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | grad_input = None
76 | sparse, = ctx.saved_tensors
77 | if ctx.req_grad:
78 | grad_input = torch.matmul(sparse.t(), grad_output)
79 | return None, grad_input
80 |
81 | def spmm(sparse, dense):
82 | return SparseMM.apply(sparse, dense)
83 |
84 |
85 | def scipy_to_pytorch(A, U, D):
86 | """Convert scipy sparse matrices to PyTorch sparse matrices."""
87 | ptU = []
88 | ptD = []
89 |
90 | for i in range(len(U)):
91 | u = scipy.sparse.coo_matrix(U[i])
92 | i = torch.LongTensor(np.array([u.row, u.col]))
93 | v = torch.FloatTensor(u.data)
94 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
95 |
96 | for i in range(len(D)):
97 | d = scipy.sparse.coo_matrix(D[i])
98 | i = torch.LongTensor(np.array([d.row, d.col]))
99 | v = torch.FloatTensor(d.data)
100 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
101 |
102 | return ptU, ptD
103 |
104 |
105 | def adjmat_sparse(adjmat, nsize=1):
106 | """Create row-normalized sparse graph adjacency matrix."""
107 | adjmat = scipy.sparse.csr_matrix(adjmat)
108 | if nsize > 1:
109 | orig_adjmat = adjmat.copy()
110 | for _ in range(1, nsize):
111 | adjmat = adjmat * orig_adjmat
112 | adjmat.data = np.ones_like(adjmat.data)
113 | for i in range(adjmat.shape[0]):
114 | adjmat[i,i] = 1
115 | num_neighbors = np.array(1 / adjmat.sum(axis=-1))
116 | adjmat = adjmat.multiply(num_neighbors)
117 | adjmat = scipy.sparse.coo_matrix(adjmat)
118 | row = adjmat.row
119 | col = adjmat.col
120 | data = adjmat.data
121 | i = torch.LongTensor(np.array([row, col]))
122 | v = torch.from_numpy(data).float()
123 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
124 | return adjmat
125 |
126 | def get_graph_params(filename, nsize=1):
127 | """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
128 | data = np.load(filename, encoding='latin1', allow_pickle=True)
129 | A = data['A']
130 | U = data['U']
131 | D = data['D']
132 | U, D = scipy_to_pytorch(A, U, D)
133 | A = [adjmat_sparse(a, nsize=nsize) for a in A]
134 | return A, U, D
135 |
136 |
137 | class Mesh(object):
138 | """Mesh object that handles mesh downsampling and upsampling via precomputed sparse matrices."""
139 | def __init__(self, filename=cfg.MANO_sampling_matrix,
140 | num_downsampling=1, nsize=1, device=torch.device('cuda')):
141 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
142 | # self._A = [a.to(device) for a in self._A]
143 | self._U = [u.to(device) for u in self._U]
144 | self._D = [d.to(device) for d in self._D]
145 | self.num_downsampling = num_downsampling
146 |
147 | def downsample(self, x, n1=0, n2=None):
148 | """Downsample mesh."""
149 | if n2 is None:
150 | n2 = self.num_downsampling
151 | if x.ndimension() < 3:
152 | for i in range(n1, n2):
153 | x = spmm(self._D[i], x)
154 | elif x.ndimension() == 3:
155 | out = []
156 | for i in range(x.shape[0]):
157 | y = x[i]
158 | for j in range(n1, n2):
159 | y = spmm(self._D[j], y)
160 | out.append(y)
161 | x = torch.stack(out, dim=0)
162 | return x
163 |
164 | def upsample(self, x, n1=1, n2=0):
165 | """Upsample mesh."""
166 | if x.ndimension() < 3:
167 | for i in reversed(range(n2, n1)):
168 | x = spmm(self._U[i], x)
169 | elif x.ndimension() == 3:
170 | out = []
171 | for i in range(x.shape[0]):
172 | y = x[i]
173 | for j in reversed(range(n2, n1)):
174 | y = spmm(self._U[j], y)
175 | out.append(y)
176 | x = torch.stack(out, dim=0)
177 | return x
178 |
--------------------------------------------------------------------------------
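The two pieces above that the rest of the pipeline relies on are the joint regressor (778 MANO vertices to 21 hand joints) and the `Mesh` helper, which moves vertex tensors between mesh resolutions using the precomputed sparse downsampling/upsampling matrices. A minimal usage sketch; like the classes themselves, it assumes MeshGraphormer's data files, `manopth`, and a CUDA device are available:

```python
import torch
from scripts._mano import MANO, Mesh  # assumes the repo root is on PYTHONPATH

mano = MANO()                          # loads the right-hand MANO model via manopth
vertices = torch.randn(2, 778, 3)      # a fake batch of MANO meshes
joints = mano.get_3d_joints(vertices)
print(joints.shape)                    # torch.Size([2, 21, 3])

mesh = Mesh()                              # loads mano_downsampling.npz
coarse = mesh.downsample(vertices.cuda())  # coarse mesh (195 vertices, per the mano_195 files)
fine = mesh.upsample(coarse)               # back to (2, 778, 3)
```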
/scripts/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains definitions of useful data structures and the paths
3 | for the datasets and data files necessary to run the code.
4 |
5 | Adapted from the open-source projects GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
6 |
7 | """
8 |
9 | from os.path import join
10 | folder_path = 'MeshGraphormer/src/modeling/'
11 | JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy'
12 | JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy'
13 | SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
14 | SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
15 | SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
16 | SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz'
17 | MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl'
18 | MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz'
19 |
20 | JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]
21 |
22 |
23 | """
24 | We follow the body joint definition, loss functions, and evaluation metrics from
25 | open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
26 |
27 | Each dataset uses different sets of joints.
28 | We use a superset of 24 joints such that we include all joints from every dataset.
29 | If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
30 | The joints used here are:
31 | """
32 | J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
33 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
34 | H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
35 | 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
36 | J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
37 | H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]
38 |
39 | """
40 | We follow the hand joint definition and mesh topology from
41 | open source project Manopth (https://github.com/hassony2/manopth)
42 |
43 | The hand joints used here are:
44 | """
45 | J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
46 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
47 | ROOT_INDEX = 0
--------------------------------------------------------------------------------
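The index lists above are plain positional mappings: `J24_TO_J14` selects 14 of the 24 superset joints, and `H36M_J17_TO_J14` reorders the Human3.6M joints into the same 14-joint layout. A quick sanity-check sketch that resolves the mappings back into joint names:

```python
# Assumes the repo root is on PYTHONPATH.
from scripts.config import J24_NAME, J24_TO_J14, H36M_J17_NAME, H36M_J17_TO_J14

# Resolve the J24 -> J14 mapping into names.
j14_names = [J24_NAME[i] for i in J24_TO_J14]
print(j14_names)
# ['R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle',
#  'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 'L_Elbow', 'L_Wrist',
#  'Neck', 'Head']

# The H36M mapping lands on the same 14-joint layout.
h36m_j14_names = [H36M_J17_NAME[i] for i in H36M_J17_TO_J14]
assert h36m_j14_names == j14_names
```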
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | # --------------------------------
2 | # Setup
3 | # --------------------------------
4 | export REPO_DIR=$PWD
5 | if [ ! -d $REPO_DIR/models ] ; then
6 | mkdir -p $REPO_DIR/models
7 | fi
8 | BLOB='https://datarelease.blob.core.windows.net/metro'
9 |
10 |
11 | # --------------------------------
12 | # Download our pre-trained models
13 | # --------------------------------
14 | if [ ! -d $REPO_DIR/models/graphormer_release ] ; then
15 | mkdir -p $REPO_DIR/models/graphormer_release
16 | fi
17 |
18 | # Mesh Graphormer for hand mesh reconstruction (trained on FreiHAND)
19 | wget -nc $BLOB/models/graphormer_hand_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_hand_state_dict.bin
20 |
21 |
22 | # --------------------------------
23 | # Download the ImageNet pre-trained HRNet models
24 | # The weights are provided by https://github.com/HRNet/HRNet-Image-Classification
25 | # --------------------------------
26 | if [ ! -d $REPO_DIR/models/hrnet ] ; then
27 | mkdir -p $REPO_DIR/models/hrnet
28 | fi
29 | wget -nc $BLOB/models/hrnetv2_w64_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w64_imagenet_pretrained.pth
30 | wget -nc $BLOB/models/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/test/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/1.jpg
--------------------------------------------------------------------------------
/test/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/2.jpg
--------------------------------------------------------------------------------
/test/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/3.jpg
--------------------------------------------------------------------------------
/test/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/4.jpg
--------------------------------------------------------------------------------
/test/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenquanlu/HandRefiner/f07e196298bafe871064b2952587548e28ccd467/test/5.jpg
--------------------------------------------------------------------------------
/test/test.json:
--------------------------------------------------------------------------------
1 | {"img": "1.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"}
2 | {"img": "2.jpg", "txt": "a woman facing the camera, making a hand gesture, indoor"}
3 | {"img": "3.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"}
4 | {"img": "4.jpg", "txt": "a man facing the camera, making a hand gesture, indoor"}
5 | {"img": "5.jpg", "txt": "a woman facing the camera, making a hand gesture, indoor"}
--------------------------------------------------------------------------------
/training/README.md:
--------------------------------------------------------------------------------
1 | ## Training Script - train.py
2 | The training script should be placed at the same level as the cldm folder.
3 | Some paths need to be set manually:
4 |
5 | L40: path to the SD1.5 inpainting checkpoint
6 | L43: path to the depth ControlNet weights
7 |
8 | ## Data Loader - control_synthcompositedata.py
9 | The loader should be placed in ldm/data/.
10 |
11 | Some paths need to be set manually; they are listed after the dataset layout below.
12 |
13 | The dataset needs to be structured as:
14 | ```bash
15 | |- dataset1
16 | | |- image
17 | | |- mask
18 | | |- pose
19 | | |- prompt.json
20 | ```
21 | Some paths need to be set manually:
22 | L9: path to dataset 1
23 | L10: path to dataset 2
24 | L18: path to dataset 1 prompt json file
25 | L23: path to dataset 2 prompt json file
26 |
27 | Each prompt json file contains one JSON object per line, structured as:
28 | ```json
29 | {"jpg": "image name", "txt": "text prompt", "dataset": "dataset identifier (RHD|synthesisai)"}
30 | ```
31 |
--------------------------------------------------------------------------------
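For reference, the loader in training/control_synthcompositedata.py consumes these prompt files line by line: each line is parsed as one JSON object and then tagged with its dataset identifier. A minimal sketch that writes and reads a two-line prompt file in the same format (file names and captions are invented for illustration):

```python
import json

# Hypothetical prompt file: one JSON object per line, as the loader expects.
records = [
    {"jpg": "00001.png", "txt": "a hand making a fist"},
    {"jpg": "00002.png", "txt": "an open palm facing the camera"},
]
with open("prompt.json", "w") as f:
    for item in records:
        f.write(json.dumps(item) + "\n")

# Equivalent of what the loader does when building its index.
data = []
with open("prompt.json") as f:
    for line in f:
        item = json.loads(line)
        item["dataset"] = "RHD"  # the loader sets this based on which file it reads
        data.append(item)
print(data[0]["txt"], data[0]["dataset"])
```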
/training/control_synthcompositedata.py:
--------------------------------------------------------------------------------
1 | import json
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | import random
6 |
7 | from torch.utils.data import Dataset
8 |
9 | DATA_PATH_1 = "../RHD/RHD_published_v2/"
10 | DATA_PATH_2 = "../synthesisai/"
11 |
12 | abbrev_dict = {"RHD": DATA_PATH_1,
13 | "synthesisai": DATA_PATH_2}
14 |
15 | class Control_composite_Hand_synth_data(Dataset):
16 | def __init__(self):
17 | self.data = []
18 | with open('../RHD/RHD_published_v2/embedded_rgb_caption.json', 'rt') as f1:
19 | for line in f1:
20 | item = json.loads(line)
21 | item['dataset'] = 'RHD'
22 | self.data.append(item)
23 | with open('../synthesisai/embedded_rgb_caption.json', 'rt') as f2:
24 | for line in f2:
25 | item = json.loads(line)
26 | item['dataset'] = 'synthesisai'
27 | self.data.append(item)
28 | def __len__(self):
29 | return len(self.data)
30 |
31 | def __getitem__(self, idx):
32 | item = self.data[idx]
33 | source_filename = item['jpg']
34 | prompt = item['txt']
35 | dataset = item['dataset']
36 | datapath = abbrev_dict[dataset]
37 | if random.random() < 0.5:
38 | prompt = ""
39 | source = cv2.imread(datapath + "image/" + source_filename)
40 | source = (source.astype(np.float32) / 127.5) - 1.0
41 | source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
42 |
43 | mask = np.array(Image.open(datapath + "mask/" + source_filename).convert("L"))
44 | mask = mask.astype(np.float32)/255.0
45 | mask = mask[None]
46 | mask[mask < 0.5] = 0
47 | mask[mask >= 0.5] = 1
48 | mask = np.transpose(mask, [1, 2, 0])
49 |
50 | hint = cv2.imread(datapath + "pose/" + source_filename)
51 | hint = cv2.cvtColor(hint, cv2.COLOR_BGR2RGB)
52 |
53 | hint = hint.astype(np.float32) / 255.0
54 |
55 | masked_image = source * (mask < 0.5)
56 | return dict(jpg=source, txt=prompt, hint=hint, mask=mask, masked_image=masked_image)
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
1 | from ldm.data.control_synthcompositedata import Control_composite_Hand_synth_data
2 | import torch
3 | import pytorch_lightning as pl
4 | from torch.utils.data import DataLoader
5 | from cldm.model import create_model, load_state_dict
6 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
7 | from einops import rearrange
8 | from PIL import Image
9 | import numpy as np
10 | import os
11 | from cldm.logger import ImageLogger
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--devices", default="0", type=str, help="comma delimited list of devices")
16 | parser.add_argument("--batchsize", default=4, type=int)
17 | parser.add_argument("--acc_grad", default=4, type=int)
18 | parser.add_argument("--max_epochs", default=3, type=int)
19 | args = parser.parse_args()
20 | args.devices = [int(n) for n in args.devices.split(",")]
21 |
22 | def get_state_dict(d):
23 | return d.get('state_dict', d)
24 | def load_state_dict(ckpt_path, location='cpu'):
25 | _, extension = os.path.splitext(ckpt_path)
26 | if extension.lower() == ".safetensors":
27 | import safetensors.torch
28 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
29 | else:
30 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
31 | state_dict = get_state_dict(state_dict)
32 | print(f'Loaded state_dict from [{ckpt_path}]')
33 | return state_dict
34 |
35 | learning_rate = 1e-5
36 |
37 | model = create_model("control_depth_inpaint.yaml")
38 |
39 | #### load the SD inpainting weights
40 | states = load_state_dict("./sd-v1-5-inpainting.ckpt", location='cpu')
41 | model.load_state_dict(states, strict=False)
42 |
43 | control_states = load_state_dict("./models/control_v11f1p_sd15_depth.pth")
44 | model.load_state_dict(control_states, strict=False)
45 |
46 |
47 | model.learning_rate = learning_rate
48 | model.sd_locked = True
49 | model.only_mid_control = False
50 |
51 | dataset = Control_composite_Hand_synth_data()
52 |
53 | checkpoint_callback = ModelCheckpoint(save_top_k=-1, monitor="epoch")
54 |
55 | #### start of training; at initialization the model should behave the same as a standalone depth ControlNet + inpainting SD
56 | dataloader = DataLoader(dataset, num_workers=8, batch_size=args.batchsize, shuffle=True)
57 | trainer = pl.Trainer(precision=32, max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices, accumulate_grad_batches=args.acc_grad, callbacks=[ImageLogger(), checkpoint_callback], strategy='ddp')
58 | trainer.fit(model, dataloader)
--------------------------------------------------------------------------------