├── LICENSE ├── README.md ├── annotator ├── canny │ └── __init__.py ├── ckpts │ └── ckpts.txt ├── hed │ └── __init__.py ├── midas │ ├── __init__.py │ ├── api.py │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ └── utils.py ├── mlsd │ ├── __init__.py │ ├── models │ │ ├── mbv2_mlsd_large.py │ │ └── mbv2_mlsd_tiny.py │ └── utils.py ├── openpose │ ├── __init__.py │ ├── body.py │ ├── hand.py │ ├── model.py │ └── util.py ├── uniformer │ ├── __init__.py │ └── configs │ │ └── _base_ │ │ ├── default_runtime.py │ │ └── schedules │ │ ├── schedule_160k.py │ │ ├── schedule_20k.py │ │ ├── schedule_40k.py │ │ └── schedule_80k.py └── util.py ├── cldm ├── cldm.py ├── hack.py ├── logger.py └── model.py ├── config.py ├── cycleNet ├── cycleNet.py ├── cycleNet_fast.py ├── ddim_hacked.py ├── logger.py └── model.py ├── docs └── train.md ├── environment.yaml ├── gradio_cycleNet.py ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── models ├── cldm_v15.yaml ├── cldm_v21.yaml ├── cycle_v21.yaml └── fastcycle_v21.yaml ├── share.py └── tool_add_cycle_sd21.py /README.md: -------------------------------------------------------------------------------- 1 | # CycleNet 2 | **This repository is still actively under construction!** 3 | 4 | This is the official implementation of: 5 | 6 | CycleNet: Rethinking Cycle Consistency in Text-Guided Diffusion for Image Manipulation 7 | Sihan Xu*, Ziqiao Ma*, Yidong Huang, Honglak Lee, Joyce Chai 8 | University of Michigan, LG AI Research 9 | [NeurIPS 2023](https://neurips.cc/virtual/2023/poster/69913) 10 | 11 | ### [Project Page](https://cyclenetweb.github.io) | [Paper](http://arxiv.org/abs/2310.13165) | [ManiCups Dataset](https://huggingface.co/datasets/sled-umich/ManiCups) | [Models](https://huggingface.co/sled-umich/CycleNet) 12 | 13 | ## Conda Environment 14 | 15 | ``` 16 | conda env create -f environment.yaml 17 | conda activate cycle 18 | ``` 19 | 20 | ## Training 21 | 22 | This implementation builds upon [ControlNet](https://github.com/lllyasviel/ControlNet). 23 | Please refer to this document on how to [train the model](./docs/train.md).
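Until `docs/train.md` is finalized, the snippet below sketches how a CycleNet checkpoint is typically loaded in ControlNet-style code bases. It is only a hedged illustration: the `create_model`/`load_state_dict` helpers and the `DDIMSampler` import are assumptions carried over from ControlNet's conventions (this repository ships `cycleNet/model.py` and `cycleNet/ddim_hacked.py`, but their exact interfaces may differ), and the checkpoint filename is a placeholder for whatever is downloaded from the Models link above; `models/cycle_v21.yaml` is taken from the file tree.

```
# Hedged sketch, not the official API: assumes cycleNet/model.py keeps
# ControlNet's create_model()/load_state_dict() helpers and that a trained
# checkpoint was saved as ./models/cyclenet_v21.ckpt (placeholder name).
from cycleNet.model import create_model, load_state_dict   # assumed helpers
from cycleNet.ddim_hacked import DDIMSampler                # assumed sampler class

model = create_model('./models/cycle_v21.yaml').cpu()       # config file present in this repo
model.load_state_dict(load_state_dict('./models/cyclenet_v21.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
```

For the actual end-to-end wiring, see `gradio_cycleNet.py` in the repository root.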
24 | 25 | ## Citation 26 | 27 | ``` 28 | @inproceedings{xu2023cyclenet, 29 | title = "CycleNet: Rethinking Cycle Consistency in Text-Guided Diffusion for Image Manipulation", 30 | author = "Xu, Sihan and Ma, Ziqiao and Huang, Yidong and Lee, Honglak and Chai, Joyce", 31 | booktitle = "Advances in Neural Information Processing Systems (NeurIPS)", 32 | year = "2023", 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | class CannyDetector: 5 | def __call__(self, img, low_threshold, high_threshold): 6 | return cv2.Canny(img, low_threshold, high_threshold) 7 | -------------------------------------------------------------------------------- /annotator/ckpts/ckpts.txt: -------------------------------------------------------------------------------- 1 | Weights here. -------------------------------------------------------------------------------- /annotator/hed/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import torch 5 | from einops import rearrange 6 | from annotator.util import annotator_ckpts_path 7 | 8 | 9 | class Network(torch.nn.Module): 10 | def __init__(self, model_path): 11 | super().__init__() 12 | 13 | self.netVggOne = torch.nn.Sequential( 14 | torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), 15 | torch.nn.ReLU(inplace=False), 16 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 17 | torch.nn.ReLU(inplace=False) 18 | ) 19 | 20 | self.netVggTwo = torch.nn.Sequential( 21 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 22 | torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 23 | torch.nn.ReLU(inplace=False), 24 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 25 | torch.nn.ReLU(inplace=False) 26 | ) 27 | 28 | self.netVggThr = torch.nn.Sequential( 29 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 30 | torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 31 | torch.nn.ReLU(inplace=False), 32 | torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 33 | torch.nn.ReLU(inplace=False), 34 | torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 35 | torch.nn.ReLU(inplace=False) 36 | ) 37 | 38 | self.netVggFou = torch.nn.Sequential( 39 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 40 | torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), 41 | torch.nn.ReLU(inplace=False), 42 | torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 43 | torch.nn.ReLU(inplace=False), 44 | torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 45 | torch.nn.ReLU(inplace=False) 46 | ) 47 | 48 | self.netVggFiv = torch.nn.Sequential( 49 | torch.nn.MaxPool2d(kernel_size=2, stride=2), 50 | torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 51 | torch.nn.ReLU(inplace=False), 52 | torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 53 | torch.nn.ReLU(inplace=False), 54 | torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 55 | torch.nn.ReLU(inplace=False) 56 | ) 57 | 58 |
self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0) 59 | self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) 60 | self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0) 61 | self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) 62 | self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) 63 | 64 | self.netCombine = torch.nn.Sequential( 65 | torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), 66 | torch.nn.Sigmoid() 67 | ) 68 | 69 | self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()}) 70 | 71 | def forward(self, tenInput): 72 | tenInput = tenInput * 255.0 73 | tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1) 74 | 75 | tenVggOne = self.netVggOne(tenInput) 76 | tenVggTwo = self.netVggTwo(tenVggOne) 77 | tenVggThr = self.netVggThr(tenVggTwo) 78 | tenVggFou = self.netVggFou(tenVggThr) 79 | tenVggFiv = self.netVggFiv(tenVggFou) 80 | 81 | tenScoreOne = self.netScoreOne(tenVggOne) 82 | tenScoreTwo = self.netScoreTwo(tenVggTwo) 83 | tenScoreThr = self.netScoreThr(tenVggThr) 84 | tenScoreFou = self.netScoreFou(tenVggFou) 85 | tenScoreFiv = self.netScoreFiv(tenVggFiv) 86 | 87 | tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) 88 | tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) 89 | tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) 90 | tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) 91 | tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) 92 | 93 | return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1)) 94 | 95 | 96 | class HEDdetector: 97 | def __init__(self): 98 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth" 99 | modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth") 100 | if not os.path.exists(modelpath): 101 | from basicsr.utils.download_util import load_file_from_url 102 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 103 | self.netNetwork = Network(modelpath).cuda().eval() 104 | 105 | def __call__(self, input_image): 106 | assert input_image.ndim == 3 107 | input_image = input_image[:, :, ::-1].copy() 108 | with torch.no_grad(): 109 | image_hed = torch.from_numpy(input_image).float().cuda() 110 | image_hed = image_hed / 255.0 111 | image_hed = rearrange(image_hed, 'h w c -> 1 c h w') 112 | edge = self.netNetwork(image_hed)[0] 113 | edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8) 114 | return edge[0] 115 | 116 | 117 | def nms(x, t, s): 118 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 119 | 120 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 121 | f2 = np.array([[0, 
1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 122 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 123 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 124 | 125 | y = np.zeros_like(x) 126 | 127 | for f in [f1, f2, f3, f4]: 128 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 129 | 130 | z = np.zeros_like(y, dtype=np.uint8) 131 | z[y > t] = 255 132 | return z 133 | -------------------------------------------------------------------------------- /annotator/midas/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from einops import rearrange 6 | from .api import MiDaSInference 7 | 8 | 9 | class MidasDetector: 10 | def __init__(self): 11 | self.model = MiDaSInference(model_type="dpt_hybrid").cuda() 12 | 13 | def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): 14 | assert input_image.ndim == 3 15 | image_depth = input_image 16 | with torch.no_grad(): 17 | image_depth = torch.from_numpy(image_depth).float().cuda() 18 | image_depth = image_depth / 127.5 - 1.0 19 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 20 | depth = self.model(image_depth)[0] 21 | 22 | depth_pt = depth.clone() 23 | depth_pt -= torch.min(depth_pt) 24 | depth_pt /= torch.max(depth_pt) 25 | depth_pt = depth_pt.cpu().numpy() 26 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 27 | 28 | depth_np = depth.cpu().numpy() 29 | x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) 30 | y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) 31 | z = np.ones_like(x) * a 32 | x[depth_pt < bg_th] = 0 33 | y[depth_pt < bg_th] = 0 34 | normal = np.stack([x, y, z], axis=2) 35 | normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 36 | normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) 37 | 38 | return depth_image, normal_image 39 | -------------------------------------------------------------------------------- /annotator/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.transforms import Compose 8 | 9 | from .midas.dpt_depth import DPTDepthModel 10 | from .midas.midas_net import MidasNet 11 | from .midas.midas_net_custom import MidasNet_small 12 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet 13 | from annotator.util import annotator_ckpts_path 14 | 15 | 16 | ISL_PATHS = { 17 | "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), 18 | "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), 19 | "midas_v21": "", 20 | "midas_v21_small": "", 21 | } 22 | 23 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" 24 | 25 | 26 | def disabled_train(self, mode=True): 27 | """Overwrite model.train with this function to make sure train/eval mode 28 | does not change anymore.""" 29 | return self 30 | 31 | 32 | def load_midas_transform(model_type): 33 | # https://github.com/isl-org/MiDaS/blob/master/run.py 34 | # load transform only 35 | if model_type == "dpt_large": # DPT-Large 36 | net_w, net_h = 384, 384 37 | resize_mode = "minimal" 38 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 39 | 40 | elif model_type == "dpt_hybrid": # DPT-Hybrid 41 | net_w, net_h = 384, 384 42 | resize_mode = "minimal" 43 
| normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 44 | 45 | elif model_type == "midas_v21": 46 | net_w, net_h = 384, 384 47 | resize_mode = "upper_bound" 48 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 49 | 50 | elif model_type == "midas_v21_small": 51 | net_w, net_h = 256, 256 52 | resize_mode = "upper_bound" 53 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 54 | 55 | else: 56 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 57 | 58 | transform = Compose( 59 | [ 60 | Resize( 61 | net_w, 62 | net_h, 63 | resize_target=None, 64 | keep_aspect_ratio=True, 65 | ensure_multiple_of=32, 66 | resize_method=resize_mode, 67 | image_interpolation_method=cv2.INTER_CUBIC, 68 | ), 69 | normalization, 70 | PrepareForNet(), 71 | ] 72 | ) 73 | 74 | return transform 75 | 76 | 77 | def load_model(model_type): 78 | # https://github.com/isl-org/MiDaS/blob/master/run.py 79 | # load network 80 | model_path = ISL_PATHS[model_type] 81 | if model_type == "dpt_large": # DPT-Large 82 | model = DPTDepthModel( 83 | path=model_path, 84 | backbone="vitl16_384", 85 | non_negative=True, 86 | ) 87 | net_w, net_h = 384, 384 88 | resize_mode = "minimal" 89 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 90 | 91 | elif model_type == "dpt_hybrid": # DPT-Hybrid 92 | if not os.path.exists(model_path): 93 | from basicsr.utils.download_util import load_file_from_url 94 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 95 | 96 | model = DPTDepthModel( 97 | path=model_path, 98 | backbone="vitb_rn50_384", 99 | non_negative=True, 100 | ) 101 | net_w, net_h = 384, 384 102 | resize_mode = "minimal" 103 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 104 | 105 | elif model_type == "midas_v21": 106 | model = MidasNet(model_path, non_negative=True) 107 | net_w, net_h = 384, 384 108 | resize_mode = "upper_bound" 109 | normalization = NormalizeImage( 110 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 111 | ) 112 | 113 | elif model_type == "midas_v21_small": 114 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 115 | non_negative=True, blocks={'expand': True}) 116 | net_w, net_h = 256, 256 117 | resize_mode = "upper_bound" 118 | normalization = NormalizeImage( 119 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 120 | ) 121 | 122 | else: 123 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 124 | assert False 125 | 126 | transform = Compose( 127 | [ 128 | Resize( 129 | net_w, 130 | net_h, 131 | resize_target=None, 132 | keep_aspect_ratio=True, 133 | ensure_multiple_of=32, 134 | resize_method=resize_mode, 135 | image_interpolation_method=cv2.INTER_CUBIC, 136 | ), 137 | normalization, 138 | PrepareForNet(), 139 | ] 140 | ) 141 | 142 | return model.eval(), transform 143 | 144 | 145 | class MiDaSInference(nn.Module): 146 | MODEL_TYPES_TORCH_HUB = [ 147 | "DPT_Large", 148 | "DPT_Hybrid", 149 | "MiDaS_small" 150 | ] 151 | MODEL_TYPES_ISL = [ 152 | "dpt_large", 153 | "dpt_hybrid", 154 | "midas_v21", 155 | "midas_v21_small", 156 | ] 157 | 158 | def __init__(self, model_type): 159 | super().__init__() 160 | assert (model_type in self.MODEL_TYPES_ISL) 161 | model, _ = load_model(model_type) 162 | self.model = model 163 | self.model.train = disabled_train 164 | 165 | def forward(self, x): 166 | with torch.no_grad(): 167 | prediction = 
self.model(x) 168 | return prediction 169 | 170 | -------------------------------------------------------------------------------- /annotator/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/annotator/midas/midas/__init__.py -------------------------------------------------------------------------------- /annotator/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /annotator/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, 
kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 
200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /annotator/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /annotator/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /annotator/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /annotator/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /annotator/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /annotator/mlsd/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | from einops import rearrange 7 | from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 8 | from .models.mbv2_mlsd_large import MobileV2_MLSD_Large 9 | from .utils import pred_lines 10 | 11 | from annotator.util import annotator_ckpts_path 12 | 13 | 14 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth" 15 | 16 | 17 | class MLSDdetector: 18 | def __init__(self): 19 | model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") 20 | if not os.path.exists(model_path): 21 | from basicsr.utils.download_util import load_file_from_url 22 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 23 | model = MobileV2_MLSD_Large() 24 | model.load_state_dict(torch.load(model_path), strict=True) 25 | self.model = model.cuda().eval() 26 | 27 | def __call__(self, input_image, thr_v, thr_d): 28 | assert input_image.ndim == 3 29 | img = input_image 30 | img_output = np.zeros_like(img) 31 | try: 32 | with torch.no_grad(): 33 | lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) 34 | for line in lines: 35 | x_start, y_start, x_end, y_end = [int(val) for val in line] 36 | cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) 37 | except Exception as e: 38 | pass 39 | return img_output[:, :, 0] 40 | -------------------------------------------------------------------------------- /annotator/mlsd/models/mbv2_mlsd_tiny.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | 9 | class BlockTypeA(nn.Module): 10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): 11 | super(BlockTypeA, self).__init__() 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(in_c2, out_c2, kernel_size=1), 14 | nn.BatchNorm2d(out_c2), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 19 | nn.BatchNorm2d(out_c1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.upscale = upscale
23 | 24 | def forward(self, a, b): 25 | b = self.conv1(b) 26 | a = self.conv2(a) 27 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 28 | return torch.cat((a, b), dim=1) 29 | 30 | 31 | class BlockTypeB(nn.Module): 32 | def __init__(self, in_c, out_c): 33 | super(BlockTypeB, self).__init__() 34 | self.conv1 = nn.Sequential( 35 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(in_c), 37 | nn.ReLU() 38 | ) 39 | self.conv2 = nn.Sequential( 40 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(out_c), 42 | nn.ReLU() 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.conv1(x) + x 47 | x = self.conv2(x) 48 | return x 49 | 50 | class BlockTypeC(nn.Module): 51 | def __init__(self, in_c, out_c): 52 | super(BlockTypeC, self).__init__() 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 55 | nn.BatchNorm2d(in_c), 56 | nn.ReLU() 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 60 | nn.BatchNorm2d(in_c), 61 | nn.ReLU() 62 | ) 63 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.conv2(x) 68 | x = self.conv3(x) 69 | return x 70 | 71 | def _make_divisible(v, divisor, min_value=None): 72 | """ 73 | This function is taken from the original tf repo. 74 | It ensures that all layers have a channel number that is divisible by 8 75 | It can be seen here: 76 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 77 | :param v: 78 | :param divisor: 79 | :param min_value: 80 | :return: 81 | """ 82 | if min_value is None: 83 | min_value = divisor 84 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 85 | # Make sure that round down does not go down by more than 10%. 
86 | if new_v < 0.9 * v: 87 | new_v += divisor 88 | return new_v 89 | 90 | 91 | class ConvBNReLU(nn.Sequential): 92 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 93 | self.channel_pad = out_planes - in_planes 94 | self.stride = stride 95 | #padding = (kernel_size - 1) // 2 96 | 97 | # TFLite uses slightly different padding than PyTorch 98 | if stride == 2: 99 | padding = 0 100 | else: 101 | padding = (kernel_size - 1) // 2 102 | 103 | super(ConvBNReLU, self).__init__( 104 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 105 | nn.BatchNorm2d(out_planes), 106 | nn.ReLU6(inplace=True) 107 | ) 108 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 109 | 110 | 111 | def forward(self, x): 112 | # TFLite uses different padding 113 | if self.stride == 2: 114 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 115 | #print(x.shape) 116 | 117 | for module in self: 118 | if not isinstance(module, nn.MaxPool2d): 119 | x = module(x) 120 | return x 121 | 122 | 123 | class InvertedResidual(nn.Module): 124 | def __init__(self, inp, oup, stride, expand_ratio): 125 | super(InvertedResidual, self).__init__() 126 | self.stride = stride 127 | assert stride in [1, 2] 128 | 129 | hidden_dim = int(round(inp * expand_ratio)) 130 | self.use_res_connect = self.stride == 1 and inp == oup 131 | 132 | layers = [] 133 | if expand_ratio != 1: 134 | # pw 135 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 136 | layers.extend([ 137 | # dw 138 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 139 | # pw-linear 140 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 141 | nn.BatchNorm2d(oup), 142 | ]) 143 | self.conv = nn.Sequential(*layers) 144 | 145 | def forward(self, x): 146 | if self.use_res_connect: 147 | return x + self.conv(x) 148 | else: 149 | return self.conv(x) 150 | 151 | 152 | class MobileNetV2(nn.Module): 153 | def __init__(self, pretrained=True): 154 | """ 155 | MobileNet V2 main class 156 | Args: 157 | num_classes (int): Number of classes 158 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 159 | inverted_residual_setting: Network structure 160 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 161 | Set to 1 to turn off rounding 162 | block: Module specifying inverted residual building block for mobilenet 163 | """ 164 | super(MobileNetV2, self).__init__() 165 | 166 | block = InvertedResidual 167 | input_channel = 32 168 | last_channel = 1280 169 | width_mult = 1.0 170 | round_nearest = 8 171 | 172 | inverted_residual_setting = [ 173 | # t, c, n, s 174 | [1, 16, 1, 1], 175 | [6, 24, 2, 2], 176 | [6, 32, 3, 2], 177 | [6, 64, 4, 2], 178 | #[6, 96, 3, 1], 179 | #[6, 160, 3, 2], 180 | #[6, 320, 1, 1], 181 | ] 182 | 183 | # only check the first element, assuming user knows t,c,n,s are required 184 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 185 | raise ValueError("inverted_residual_setting should be non-empty " 186 | "or a 4-element list, got {}".format(inverted_residual_setting)) 187 | 188 | # building first layer 189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 191 | features = [ConvBNReLU(4, input_channel, stride=2)] 192 | # building inverted residual blocks 193 | for t, c, n, s in inverted_residual_setting: 194 | output_channel = 
_make_divisible(c * width_mult, round_nearest) 195 | for i in range(n): 196 | stride = s if i == 0 else 1 197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 198 | input_channel = output_channel 199 | self.features = nn.Sequential(*features) 200 | 201 | self.fpn_selected = [3, 6, 10] 202 | # weight initialization 203 | for m in self.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 206 | if m.bias is not None: 207 | nn.init.zeros_(m.bias) 208 | elif isinstance(m, nn.BatchNorm2d): 209 | nn.init.ones_(m.weight) 210 | nn.init.zeros_(m.bias) 211 | elif isinstance(m, nn.Linear): 212 | nn.init.normal_(m.weight, 0, 0.01) 213 | nn.init.zeros_(m.bias) 214 | 215 | #if pretrained: 216 | # self._load_pretrained_model() 217 | 218 | def _forward_impl(self, x): 219 | # This exists since TorchScript doesn't support inheritance, so the superclass method 220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 221 | fpn_features = [] 222 | for i, f in enumerate(self.features): 223 | if i > self.fpn_selected[-1]: 224 | break 225 | x = f(x) 226 | if i in self.fpn_selected: 227 | fpn_features.append(x) 228 | 229 | c2, c3, c4 = fpn_features 230 | return c2, c3, c4 231 | 232 | 233 | def forward(self, x): 234 | return self._forward_impl(x) 235 | 236 | def _load_pretrained_model(self): 237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 238 | model_dict = {} 239 | state_dict = self.state_dict() 240 | for k, v in pretrain_dict.items(): 241 | if k in state_dict: 242 | model_dict[k] = v 243 | state_dict.update(model_dict) 244 | self.load_state_dict(state_dict) 245 | 246 | 247 | class MobileV2_MLSD_Tiny(nn.Module): 248 | def __init__(self): 249 | super(MobileV2_MLSD_Tiny, self).__init__() 250 | 251 | self.backbone = MobileNetV2(pretrained=True) 252 | 253 | self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, 254 | out_c1= 64, out_c2=64) 255 | self.block13 = BlockTypeB(128, 64) 256 | 257 | self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, 258 | out_c1= 32, out_c2= 32) 259 | self.block15 = BlockTypeB(64, 64) 260 | 261 | self.block16 = BlockTypeC(64, 16) 262 | 263 | def forward(self, x): 264 | c2, c3, c4 = self.backbone(x) 265 | 266 | x = self.block12(c3, c4) 267 | x = self.block13(x) 268 | x = self.block14(c2, x) 269 | x = self.block15(x) 270 | x = self.block16(x) 271 | x = x[:, 7:, :, :] 272 | #print(x.shape) 273 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 274 | 275 | return x -------------------------------------------------------------------------------- /annotator/openpose/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 3 | 4 | import torch 5 | import numpy as np 6 | from . 
import util 7 | from .body import Body 8 | from .hand import Hand 9 | from annotator.util import annotator_ckpts_path 10 | 11 | 12 | body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth" 13 | hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth" 14 | 15 | 16 | class OpenposeDetector: 17 | def __init__(self): 18 | body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth") 19 | hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth") 20 | 21 | if not os.path.exists(hand_modelpath): 22 | from basicsr.utils.download_util import load_file_from_url 23 | load_file_from_url(body_model_path, model_dir=annotator_ckpts_path) 24 | load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path) 25 | 26 | self.body_estimation = Body(body_modelpath) 27 | self.hand_estimation = Hand(hand_modelpath) 28 | 29 | def __call__(self, oriImg, hand=False): 30 | oriImg = oriImg[:, :, ::-1].copy() 31 | with torch.no_grad(): 32 | candidate, subset = self.body_estimation(oriImg) 33 | canvas = np.zeros_like(oriImg) 34 | canvas = util.draw_bodypose(canvas, candidate, subset) 35 | if hand: 36 | hands_list = util.handDetect(candidate, subset, oriImg) 37 | all_hand_peaks = [] 38 | for x, y, w, is_left in hands_list: 39 | peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]) 40 | peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x) 41 | peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y) 42 | all_hand_peaks.append(peaks) 43 | canvas = util.draw_handpose(canvas, all_hand_peaks) 44 | return canvas, dict(candidate=candidate.tolist(), subset=subset.tolist()) 45 | -------------------------------------------------------------------------------- /annotator/openpose/hand.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import numpy as np 4 | import math 5 | import time 6 | from scipy.ndimage.filters import gaussian_filter 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | import torch 10 | from skimage.measure import label 11 | 12 | from .model import handpose_model 13 | from . 
import util 14 | 15 | class Hand(object): 16 | def __init__(self, model_path): 17 | self.model = handpose_model() 18 | if torch.cuda.is_available(): 19 | self.model = self.model.cuda() 20 | print('cuda') 21 | model_dict = util.transfer(self.model, torch.load(model_path)) 22 | self.model.load_state_dict(model_dict) 23 | self.model.eval() 24 | 25 | def __call__(self, oriImg): 26 | scale_search = [0.5, 1.0, 1.5, 2.0] 27 | # scale_search = [0.5] 28 | boxsize = 368 29 | stride = 8 30 | padValue = 128 31 | thre = 0.05 32 | multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] 33 | heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22)) 34 | # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) 35 | 36 | for m in range(len(multiplier)): 37 | scale = multiplier[m] 38 | imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) 39 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 40 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 41 | im = np.ascontiguousarray(im) 42 | 43 | data = torch.from_numpy(im).float() 44 | if torch.cuda.is_available(): 45 | data = data.cuda() 46 | # data = data.permute([2, 0, 1]).unsqueeze(0).float() 47 | with torch.no_grad(): 48 | output = self.model(data).cpu().numpy() 49 | # output = self.model(data).numpy()q 50 | 51 | # extract outputs, resize, and remove padding 52 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps 53 | heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) 54 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 55 | heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) 56 | 57 | heatmap_avg += heatmap / len(multiplier) 58 | 59 | all_peaks = [] 60 | for part in range(21): 61 | map_ori = heatmap_avg[:, :, part] 62 | one_heatmap = gaussian_filter(map_ori, sigma=3) 63 | binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) 64 | # 全部小于阈值 65 | if np.sum(binary) == 0: 66 | all_peaks.append([0, 0]) 67 | continue 68 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) 69 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 70 | label_img[label_img != max_index] = 0 71 | map_ori[label_img == 0] = 0 72 | 73 | y, x = util.npmax(map_ori) 74 | all_peaks.append([x, y]) 75 | return np.array(all_peaks) 76 | 77 | if __name__ == "__main__": 78 | hand_estimation = Hand('../model/hand_pose_model.pth') 79 | 80 | # test_image = '../images/hand.jpg' 81 | test_image = '../images/hand.jpg' 82 | oriImg = cv2.imread(test_image) # B,G,R order 83 | peaks = hand_estimation(oriImg) 84 | canvas = util.draw_handpose(oriImg, peaks, True) 85 | cv2.imshow('', canvas) 86 | cv2.waitKey(0) -------------------------------------------------------------------------------- /annotator/openpose/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def make_layers(block, no_relu_layers): 8 | layers = [] 9 | for layer_name, v in block.items(): 10 | if 'pool' in layer_name: 11 | layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], 12 | padding=v[2]) 13 | layers.append((layer_name, layer)) 14 | else: 15 | conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], 16 | 
kernel_size=v[2], stride=v[3], 17 | padding=v[4]) 18 | layers.append((layer_name, conv2d)) 19 | if layer_name not in no_relu_layers: 20 | layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) 21 | 22 | return nn.Sequential(OrderedDict(layers)) 23 | 24 | class bodypose_model(nn.Module): 25 | def __init__(self): 26 | super(bodypose_model, self).__init__() 27 | 28 | # these layers have no relu layer 29 | no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ 30 | 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ 31 | 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ 32 | 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] 33 | blocks = {} 34 | block0 = OrderedDict([ 35 | ('conv1_1', [3, 64, 3, 1, 1]), 36 | ('conv1_2', [64, 64, 3, 1, 1]), 37 | ('pool1_stage1', [2, 2, 0]), 38 | ('conv2_1', [64, 128, 3, 1, 1]), 39 | ('conv2_2', [128, 128, 3, 1, 1]), 40 | ('pool2_stage1', [2, 2, 0]), 41 | ('conv3_1', [128, 256, 3, 1, 1]), 42 | ('conv3_2', [256, 256, 3, 1, 1]), 43 | ('conv3_3', [256, 256, 3, 1, 1]), 44 | ('conv3_4', [256, 256, 3, 1, 1]), 45 | ('pool3_stage1', [2, 2, 0]), 46 | ('conv4_1', [256, 512, 3, 1, 1]), 47 | ('conv4_2', [512, 512, 3, 1, 1]), 48 | ('conv4_3_CPM', [512, 256, 3, 1, 1]), 49 | ('conv4_4_CPM', [256, 128, 3, 1, 1]) 50 | ]) 51 | 52 | 53 | # Stage 1 54 | block1_1 = OrderedDict([ 55 | ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), 56 | ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), 57 | ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), 58 | ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), 59 | ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) 60 | ]) 61 | 62 | block1_2 = OrderedDict([ 63 | ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), 64 | ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), 65 | ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), 66 | ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), 67 | ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) 68 | ]) 69 | blocks['block1_1'] = block1_1 70 | blocks['block1_2'] = block1_2 71 | 72 | self.model0 = make_layers(block0, no_relu_layers) 73 | 74 | # Stages 2 - 6 75 | for i in range(2, 7): 76 | blocks['block%d_1' % i] = OrderedDict([ 77 | ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), 78 | ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), 79 | ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), 80 | ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), 81 | ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), 82 | ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), 83 | ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) 84 | ]) 85 | 86 | blocks['block%d_2' % i] = OrderedDict([ 87 | ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), 88 | ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), 89 | ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), 90 | ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), 91 | ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), 92 | ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), 93 | ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) 94 | ]) 95 | 96 | for k in blocks.keys(): 97 | blocks[k] = make_layers(blocks[k], no_relu_layers) 98 | 99 | self.model1_1 = blocks['block1_1'] 100 | self.model2_1 = blocks['block2_1'] 101 | self.model3_1 = blocks['block3_1'] 102 | self.model4_1 = blocks['block4_1'] 103 | self.model5_1 = blocks['block5_1'] 104 | self.model6_1 = blocks['block6_1'] 105 | 106 | self.model1_2 = blocks['block1_2'] 107 | self.model2_2 = blocks['block2_2'] 108 | self.model3_2 = blocks['block3_2'] 109 | self.model4_2 = blocks['block4_2'] 110 | self.model5_2 = blocks['block5_2'] 111 | self.model6_2 = blocks['block6_2'] 112 | 113 | 114 | def forward(self, x): 115 | 116 | out1 = 
self.model0(x) 117 | 118 | out1_1 = self.model1_1(out1) 119 | out1_2 = self.model1_2(out1) 120 | out2 = torch.cat([out1_1, out1_2, out1], 1) 121 | 122 | out2_1 = self.model2_1(out2) 123 | out2_2 = self.model2_2(out2) 124 | out3 = torch.cat([out2_1, out2_2, out1], 1) 125 | 126 | out3_1 = self.model3_1(out3) 127 | out3_2 = self.model3_2(out3) 128 | out4 = torch.cat([out3_1, out3_2, out1], 1) 129 | 130 | out4_1 = self.model4_1(out4) 131 | out4_2 = self.model4_2(out4) 132 | out5 = torch.cat([out4_1, out4_2, out1], 1) 133 | 134 | out5_1 = self.model5_1(out5) 135 | out5_2 = self.model5_2(out5) 136 | out6 = torch.cat([out5_1, out5_2, out1], 1) 137 | 138 | out6_1 = self.model6_1(out6) 139 | out6_2 = self.model6_2(out6) 140 | 141 | return out6_1, out6_2 142 | 143 | class handpose_model(nn.Module): 144 | def __init__(self): 145 | super(handpose_model, self).__init__() 146 | 147 | # these layers have no relu layer 148 | no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ 149 | 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] 150 | # stage 1 151 | block1_0 = OrderedDict([ 152 | ('conv1_1', [3, 64, 3, 1, 1]), 153 | ('conv1_2', [64, 64, 3, 1, 1]), 154 | ('pool1_stage1', [2, 2, 0]), 155 | ('conv2_1', [64, 128, 3, 1, 1]), 156 | ('conv2_2', [128, 128, 3, 1, 1]), 157 | ('pool2_stage1', [2, 2, 0]), 158 | ('conv3_1', [128, 256, 3, 1, 1]), 159 | ('conv3_2', [256, 256, 3, 1, 1]), 160 | ('conv3_3', [256, 256, 3, 1, 1]), 161 | ('conv3_4', [256, 256, 3, 1, 1]), 162 | ('pool3_stage1', [2, 2, 0]), 163 | ('conv4_1', [256, 512, 3, 1, 1]), 164 | ('conv4_2', [512, 512, 3, 1, 1]), 165 | ('conv4_3', [512, 512, 3, 1, 1]), 166 | ('conv4_4', [512, 512, 3, 1, 1]), 167 | ('conv5_1', [512, 512, 3, 1, 1]), 168 | ('conv5_2', [512, 512, 3, 1, 1]), 169 | ('conv5_3_CPM', [512, 128, 3, 1, 1]) 170 | ]) 171 | 172 | block1_1 = OrderedDict([ 173 | ('conv6_1_CPM', [128, 512, 1, 1, 0]), 174 | ('conv6_2_CPM', [512, 22, 1, 1, 0]) 175 | ]) 176 | 177 | blocks = {} 178 | blocks['block1_0'] = block1_0 179 | blocks['block1_1'] = block1_1 180 | 181 | # stage 2-6 182 | for i in range(2, 7): 183 | blocks['block%d' % i] = OrderedDict([ 184 | ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), 185 | ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), 186 | ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), 187 | ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), 188 | ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), 189 | ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), 190 | ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) 191 | ]) 192 | 193 | for k in blocks.keys(): 194 | blocks[k] = make_layers(blocks[k], no_relu_layers) 195 | 196 | self.model1_0 = blocks['block1_0'] 197 | self.model1_1 = blocks['block1_1'] 198 | self.model2 = blocks['block2'] 199 | self.model3 = blocks['block3'] 200 | self.model4 = blocks['block4'] 201 | self.model5 = blocks['block5'] 202 | self.model6 = blocks['block6'] 203 | 204 | def forward(self, x): 205 | out1_0 = self.model1_0(x) 206 | out1_1 = self.model1_1(out1_0) 207 | concat_stage2 = torch.cat([out1_1, out1_0], 1) 208 | out_stage2 = self.model2(concat_stage2) 209 | concat_stage3 = torch.cat([out_stage2, out1_0], 1) 210 | out_stage3 = self.model3(concat_stage3) 211 | concat_stage4 = torch.cat([out_stage3, out1_0], 1) 212 | out_stage4 = self.model4(concat_stage4) 213 | concat_stage5 = torch.cat([out_stage4, out1_0], 1) 214 | out_stage5 = self.model5(concat_stage5) 215 | concat_stage6 = torch.cat([out_stage5, out1_0], 1) 216 | out_stage6 = self.model6(concat_stage6) 217 | return out_stage6 218 | 219 | 220 | 
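if __name__ == "__main__":
    # A minimal shape check, assuming a 368x368 input (the boxsize used in hand.py):
    # bodypose_model returns 38-channel part-affinity fields and 19-channel keypoint heatmaps,
    # handpose_model returns 22-channel hand keypoint heatmaps, all at 1/8 of the input resolution.
    x = torch.zeros(1, 3, 368, 368)
    with torch.no_grad():
        pafs, heatmaps = bodypose_model().eval()(x)
        hand_maps = handpose_model().eval()(x)
    print(pafs.shape, heatmaps.shape, hand_maps.shape)
    # -> torch.Size([1, 38, 46, 46]) torch.Size([1, 19, 46, 46]) torch.Size([1, 22, 46, 46])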
-------------------------------------------------------------------------------- /annotator/openpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | def padRightDownCorner(img, stride, padValue): 8 | h = img.shape[0] 9 | w = img.shape[1] 10 | 11 | pad = 4 * [None] 12 | pad[0] = 0 # up 13 | pad[1] = 0 # left 14 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 15 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 16 | 17 | img_padded = img 18 | pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) 19 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 20 | pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) 21 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 22 | pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) 23 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 24 | pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) 25 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 26 | 27 | return img_padded, pad 28 | 29 | # transfer caffe model to pytorch which will match the layer name 30 | def transfer(model, model_weights): 31 | transfered_model_weights = {} 32 | for weights_name in model.state_dict().keys(): 33 | transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] 34 | return transfered_model_weights 35 | 36 | # draw the body keypoint and lims 37 | def draw_bodypose(canvas, candidate, subset): 38 | stickwidth = 4 39 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 40 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 41 | [1, 16], [16, 18], [3, 17], [6, 18]] 42 | 43 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 44 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 45 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 46 | for i in range(18): 47 | for n in range(len(subset)): 48 | index = int(subset[n][i]) 49 | if index == -1: 50 | continue 51 | x, y = candidate[index][0:2] 52 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) 53 | for i in range(17): 54 | for n in range(len(subset)): 55 | index = subset[n][np.array(limbSeq[i]) - 1] 56 | if -1 in index: 57 | continue 58 | cur_canvas = canvas.copy() 59 | Y = candidate[index.astype(int), 0] 60 | X = candidate[index.astype(int), 1] 61 | mX = np.mean(X) 62 | mY = np.mean(Y) 63 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 64 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 65 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 66 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 67 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 68 | # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]]) 69 | # plt.imshow(canvas[:, :, [2, 1, 0]]) 70 | return canvas 71 | 72 | 73 | # image drawed by opencv is not good. 
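# all_hand_peaks is a list of (21, 2) arrays of (x, y) keypoints, one array per detected hand (as returned by Hand.__call__); a peak of [0, 0] marks an undetected joint, and edges touching such a joint are skipped.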
74 | def draw_handpose(canvas, all_hand_peaks, show_number=False): 75 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 76 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 77 | 78 | for peaks in all_hand_peaks: 79 | for ie, e in enumerate(edges): 80 | if np.sum(np.all(peaks[e], axis=1)==0)==0: 81 | x1, y1 = peaks[e[0]] 82 | x2, y2 = peaks[e[1]] 83 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie/float(len(edges)), 1.0, 1.0])*255, thickness=2) 84 | 85 | for i, keyponit in enumerate(peaks): 86 | x, y = keyponit 87 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) 88 | if show_number: 89 | cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA) 90 | return canvas 91 | 92 | # detect hand according to body pose keypoints 93 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 94 | def handDetect(candidate, subset, oriImg): 95 | # right hand: wrist 4, elbow 3, shoulder 2 96 | # left hand: wrist 7, elbow 6, shoulder 5 97 | ratioWristElbow = 0.33 98 | detect_result = [] 99 | image_height, image_width = oriImg.shape[0:2] 100 | for person in subset.astype(int): 101 | # if any of three not detected 102 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 103 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 104 | if not (has_left or has_right): 105 | continue 106 | hands = [] 107 | #left hand 108 | if has_left: 109 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 110 | x1, y1 = candidate[left_shoulder_index][:2] 111 | x2, y2 = candidate[left_elbow_index][:2] 112 | x3, y3 = candidate[left_wrist_index][:2] 113 | hands.append([x1, y1, x2, y2, x3, y3, True]) 114 | # right hand 115 | if has_right: 116 | right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] 117 | x1, y1 = candidate[right_shoulder_index][:2] 118 | x2, y2 = candidate[right_elbow_index][:2] 119 | x3, y3 = candidate[right_wrist_index][:2] 120 | hands.append([x1, y1, x2, y2, x3, y3, False]) 121 | 122 | for x1, y1, x2, y2, x3, y3, is_left in hands: 123 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 124 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 125 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 126 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 127 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 128 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 129 | x = x3 + ratioWristElbow * (x3 - x2) 130 | y = y3 + ratioWristElbow * (y3 - y2) 131 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) 132 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) 133 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 134 | # x-y refers to the center --> offset to topLeft point 135 | # handRectangle.x -= handRectangle.width / 2.f; 136 | # handRectangle.y -= handRectangle.height / 2.f; 137 | x -= width / 2 138 | y -= width / 2 # width = height 139 | # overflow the image 140 | if x < 0: x = 0 141 | if y < 0: y = 0 142 | width1 = width 143 | width2 = width 144 | if x + width > image_width: width1 = image_width - x 145 | if y + width > 
image_height: width2 = image_height - y 146 | width = min(width1, width2) 147 | # the max hand box value is 20 pixels 148 | if width >= 20: 149 | detect_result.append([int(x), int(y), int(width), is_left]) 150 | 151 | ''' 152 | return value: [[x, y, w, True if left hand else False]]. 153 | width=height since the network require squared input. 154 | x, y is the coordinate of top left 155 | ''' 156 | return detect_result 157 | 158 | # get max index of 2d array 159 | def npmax(array): 160 | arrayindex = array.argmax(1) 161 | arrayvalue = array.max(1) 162 | i = arrayvalue.argmax() 163 | j = arrayindex[i] 164 | return i, j 165 | -------------------------------------------------------------------------------- /annotator/uniformer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot 4 | from annotator.uniformer.mmseg.core.evaluation import get_palette 5 | from annotator.util import annotator_ckpts_path 6 | 7 | 8 | checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth" 9 | 10 | 11 | class UniformerDetector: 12 | def __init__(self): 13 | modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth") 14 | if not os.path.exists(modelpath): 15 | from basicsr.utils.download_util import load_file_from_url 16 | load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path) 17 | config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py") 18 | self.model = init_segmentor(config_file, modelpath).cuda() 19 | 20 | def __call__(self, img): 21 | result = inference_segmentor(self.model, img) 22 | res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1) 23 | return res_img 24 | -------------------------------------------------------------------------------- /annotator/uniformer/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /annotator/uniformer/configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /annotator/uniformer/configs/_base_/schedules/schedule_20k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', 
power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=20000) 8 | checkpoint_config = dict(by_epoch=False, interval=2000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /annotator/uniformer/configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /annotator/uniformer/configs/_base_/schedules/schedule_80k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = 
_hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def 
__init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 
64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | save_memory = False 2 | -------------------------------------------------------------------------------- /cycleNet/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, every_n_train_steps=1000, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def 
log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cycleNet/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /docs/train.md: -------------------------------------------------------------------------------- 1 | # Train a CycleNet 2 | 3 | ## Step 1 - Download the dataset 4 | 5 | You can download the [CycleFill50k dataset](https://huggingface.co/datasets/sihanxu/fill50k/tree/main), and put it into the following dir: 6 | 7 | ``` 8 | CycleNet/training/cfill50k/prompt.json 9 | CycleNet/training/cfill50k/target/X.png 10 | ``` 11 | 12 | In the folder "fill50k/target", you will have 50k images of filled circles. 13 | 14 | ![image](https://user-images.githubusercontent.com/103425287/221340033-6efdb02e-712f-495c-a88c-f0046432a0bb.png) 15 | 16 | In the "cfill50k/prompt.json", you will have their filenames with their condition prompts and uncondition prompts. 
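Each line of "prompt.json" is a standalone JSON object; the field names below are the ones the dataset script in Step 2 reads, while the concrete filename and prompt values are only illustrative placeholders:

```python
import json

# A hypothetical prompt.json line: the image filename plus its condition (target) and
# uncondition (source) prompts. Real values come from the downloaded dataset.
line = '{"image": "target/0.png", "source": "<uncondition prompt>", "target": "<condition prompt>"}'
item = json.loads(line)
print(item['image'], item['source'], item['target'])
```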
17 | 18 | ![image](https://user-images.githubusercontent.com/103425287/221340135-92d88e10-465a-4273-8e0d-7cb856c717db.png) 19 | 20 | ## Step 2 - Load the dataset 21 | 22 | Then you can write a script to load the dataset as follows (named "tutorial_dataset.py"): 23 | 24 | ```python 25 | import json 26 | import cv2 27 | import numpy as np 28 | 29 | from torch.utils.data import Dataset 30 | 31 | 32 | class MyDataset(Dataset): 33 | def __init__(self): 34 | self.data = [] 35 | with open('./training/cfill50k/prompt.json', 'rt') as f: 36 | for line in f: 37 | self.data.append(json.loads(line)) 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, idx): 43 | item = self.data[idx] 44 | 45 | image_filename = item['image'] 46 | source = item['source'] 47 | target = item['target'] 48 | 49 | image = cv2.imread('./training/cfill50k/' + image_filename) 50 | 51 | # Do not forget that OpenCV reads images in BGR order. 52 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 53 | 54 | image = (image.astype(np.float32) / 127.5) - 1.0 55 | 56 | return dict(jpg=image, source=source, txt=target) 57 | ``` 58 | 59 | And you can use the following script to test: 60 | 61 | ```python 62 | from tutorial_dataset import MyDataset 63 | 64 | dataset = MyDataset() 65 | print(len(dataset)) 66 | 67 | item = dataset[16] 68 | image = item['jpg'] 69 | source = item['source'] 70 | target = item['txt'] 71 | print(image.shape) 72 | print(source) 73 | print(target) 74 | ``` 75 | 76 | The outputs of this simple test on my machine are 77 | 78 | ``` 79 | 80 | ``` 81 | 82 | Do not ask us why we use these three names as mentioned in ControlNet - this is related to the dark history of a library called LDM. 83 | 84 | ## Step 3 - Download the pretrained SD model 85 | 86 | Then you can go to the [official page of Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main), and download ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main). 87 | 88 | Then you need to attach the CycleNet branch to the SD model, which can be done with the script adapted from ControlNet. For example, if your SD filename is "./models/v2-1_512-ema-pruned.ckpt" and you want the script to save the processed model (SD+CycleNet) at location "./models/cycle_sd21_ini.ckpt", run: 89 | 90 | ``` 91 | python tool_add_cycle_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/cycle_sd21_ini.ckpt 92 | ``` 93 | 94 | You may also use other filenames as long as the command is "python tool_add_cycle_sd21.py input_path output_path". 95 | 96 | The output should be like: 97 | 98 | ![t5](https://user-images.githubusercontent.com/103425287/221340617-dbbf606d-5c79-4934-a168-4c7aca743fa1.png) 99 | 100 | ## Step 4 - Train the CycleNet 101 | 102 | With PyTorch Lightning, the training is very simple. 103 | 104 | You can use the following code to train on the dataset we built before: 105 | 106 | ```python 107 | from share import * 108 | 109 | import pytorch_lightning as pl 110 | from torch.utils.data import DataLoader 111 | from tutorial_dataset import MyDataset 112 | from cycleNet.logger import ImageLogger 113 | from cycleNet.model import create_model, load_state_dict 114 | 115 | 116 | # Configs 117 | resume_path = './models/cycle_sd21_ini.ckpt' 118 | log_path = './logs' 119 | batch_size_per_gpu = 4 120 | gpus = 1 121 | logger_freq = 300 122 | learning_rate = 1e-5 123 | sd_locked = False 124 | only_mid_control = False 125 | 126 | if __name__ == "__main__": 127 | 128 | # First use cpu to load models.
Pytorch Lightning will automatically move it to GPUs. 129 | model = create_model('./models/cycle_v21.yaml').cpu() 130 | model.load_state_dict(load_state_dict(resume_path, location='cpu')) 131 | model.learning_rate = learning_rate 132 | model.sd_locked = sd_locked 133 | model.only_mid_control = only_mid_control 134 | 135 | 136 | # Misc 137 | dataset = MyDataset() 138 | dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size_per_gpu, shuffle=True) 139 | 140 | logger = ImageLogger(batch_frequency=logger_freq, every_n_train_steps=logger_freq) 141 | trainer = pl.Trainer(accelerator="gpu", devices=gpus, precision=32, callbacks=[logger], default_root_dir=log_path) 142 | trainer.fit(model, dataloader) 143 | ``` 144 | 145 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: cycle 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.16.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 -------------------------------------------------------------------------------- /gradio_cycleNet.py: -------------------------------------------------------------------------------- 1 | from share import * 2 | import config 3 | 4 | import cv2 5 | import einops 6 | import gradio as gr 7 | import numpy as np 8 | import torch 9 | import random 10 | 11 | from pytorch_lightning import seed_everything 12 | from annotator.util import resize_image, HWC3 13 | from cycleNet.model import create_model, load_state_dict 14 | from cycleNet.ddim_hacked import DDIMSampler 15 | 16 | 17 | model_name = 'CycleNet' 18 | model = create_model(f'./models/cycle_v21.yaml').cpu() 19 | model.load_state_dict(load_state_dict('./models/cycle_sd21_ini.ckpt', location='cuda'), strict=False) 20 | model = model.cuda() 21 | ddim_sampler = DDIMSampler(model) 22 | 23 | 24 | def process(input_image, target_prompt, source_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, denoise_strength): 25 | global preprocessor 26 | 27 | with torch.no_grad(): 28 | input_image = HWC3(input_image) 29 | detected_map = input_image.copy() 30 | 31 | img = resize_image(input_image, image_resolution) 32 | H, W, C = img.shape 33 | 34 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 35 | 36 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 37 | control = torch.stack([control for _ in range(num_samples)], dim=0) 38 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 39 | 40 | img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0 41 | img = torch.stack([img for _ in range(num_samples)], dim=0) 42 | img = einops.rearrange(img, 'b h w c -> b c h w').clone() 43 | 44 | if seed == -1: 45 | seed = random.randint(0, 65535) 46 | 
seed_everything(seed) 47 | 48 | if config.save_memory: 49 | model.low_vram_shift(is_diffusing=False) 50 | 51 | cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([target_prompt] * num_samples)], "uc_crossattn": [model.get_learned_conditioning([source_prompt] * num_samples)]} 52 | un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([source_prompt] * num_samples)], "uc_crossattn": [model.get_learned_conditioning([source_prompt] * num_samples)]} 53 | 54 | if config.save_memory: 55 | model.low_vram_shift(is_diffusing=False) 56 | 57 | ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True) 58 | t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1) 59 | z = model.get_first_stage_encoding(model.encode_first_stage(img)) 60 | z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device)) 61 | 62 | if config.save_memory: 63 | model.low_vram_shift(is_diffusing=True) 64 | 65 | model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) 66 | # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 67 | 68 | samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) 69 | 70 | if config.save_memory: 71 | model.low_vram_shift(is_diffusing=False) 72 | 73 | x_samples = model.decode_first_stage(samples) 74 | x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 75 | 76 | results = [x_samples[i] for i in range(num_samples)] 77 | return [input_image] + results 78 | 79 | 80 | block = gr.Blocks().queue() 81 | with block: 82 | with gr.Row(): 83 | gr.Markdown("## CycleNet") 84 | with gr.Row(): 85 | with gr.Column(): 86 | input_image = gr.Image(source='upload', type="numpy") 87 | target_prompt = gr.Textbox(label="Target") 88 | source_prompt = gr.Textbox(label="Source") 89 | run_button = gr.Button(label="Run") 90 | num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) 91 | seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=12345) 92 | det = gr.Radio(choices=["None"], type="value", value="None", label="Preprocessor") 93 | denoise_strength = gr.Slider(label="Denoising Strength", minimum=0.1, maximum=1.0, value=0.5, step=0.01) 94 | with gr.Accordion("Advanced options", open=False): 95 | image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=2048, value=512, step=64) 96 | strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) 97 | guess_mode = gr.Checkbox(label='Guess Mode', value=False) 98 | ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=100, step=1) 99 | scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) 100 | eta = gr.Slider(label="DDIM ETA", minimum=0.0, maximum=1.0, value=1.0, step=0.01) 101 | with gr.Column(): 102 | result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') 103 | ips = [input_image, target_prompt, source_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, denoise_strength] 104 | run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) 105 | 106 | 107 | block.launch(server_name='127.0.0.1') 108 | -------------------------------------------------------------------------------- 
/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
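        # min-max rescale of the randomly projected segmentation map to [-1, 1] so it can be logged like an RGB image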
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
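               # remaining kwargs are accepted here but are not used by this sampler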
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, 
value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
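    # what follows is the closed-form KL between the diagonal Gaussians N(mean1, exp(logvar1)) and N(mean2, exp(logvar2)), evaluated element-wise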
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
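        # the default of 77 matches the CLIP context length used by the other embedders in this file; T5 itself is not limited to 77 tokens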
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, text): 169 | tokens = open_clip.tokenize(text) 170 | z = self.encode_with_transformer(tokens.to(self.device)) 171 | return z 172 | 173 | def encode_with_transformer(self, text): 174 | x = 
self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 175 | x = x + self.model.positional_embedding 176 | x = x.permute(1, 0, 2) # NLD -> LND 177 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 178 | x = x.permute(1, 0, 2) # LND -> NLD 179 | x = self.model.ln_final(x) 180 | return x 181 | 182 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 183 | for i, r in enumerate(self.model.transformer.resblocks): 184 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 185 | break 186 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 187 | x = checkpoint(r, x, attn_mask) 188 | else: 189 | x = r(x, attn_mask=attn_mask) 190 | return x 191 | 192 | def encode(self, text): 193 | return self(text) 194 | 195 | 196 | class FrozenCLIPT5Encoder(AbstractEncoder): 197 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 198 | clip_max_length=77, t5_max_length=77): 199 | super().__init__() 200 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 201 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 202 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 203 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 204 | 205 | def encode(self, text): 206 | return self(text) 207 | 208 | def forward(self, text): 209 | clip_z = self.clip_encoder.encode(text) 210 | t5_z = self.t5_encoder.encode(text) 211 | return [clip_z, t5_z] 212 | 213 | 214 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": 
"midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class 
MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sled-group/CycleNet/a1c370aa9a0146c376c1f44d303d5534c1ba4743/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, 
expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 
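    Applies ReLU -> 3x3 conv -> ReLU -> 3x3 conv and adds the input back as an identity skip connection.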
157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 
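        Fuses the incoming decoder feature map with an optional skip input via two custom residual conv units, upsamples by a factor of 2 and projects with a 1x1 convolution.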
297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, 
layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
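        Runs the four-stage ResNeXt-101 WSL encoder and fuses the resulting feature maps top-down with the RefineNet-style decoder before the output head.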
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 64. 23 | backbone (str, optional): Backbone network for encoder.
Defaults to efficientnet_lite3 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass.
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
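    Optionally keeps the aspect ratio and constrains the output shape to a multiple of ensure_multiple_of; see __init__ for the available resize_method options.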
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normalize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file.
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /models/cldm_v15.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: cldm.cldm.ControlNet 23 | params: 24 | image_size: 32 # unused 25 | in_channels: 4 26 | hint_channels: 3 27 | model_channels: 320 28 | attention_resolutions: [ 4, 2, 1 ] 29 | num_res_blocks: 2 30 | channel_mult: [ 1, 2, 4, 4 ] 31 | num_heads: 8 32 | use_spatial_transformer: True 33 | transformer_depth: 1 34 | context_dim: 768 35 | use_checkpoint: True 36 | legacy: False 37 | 38 | unet_config: 39 | target: cldm.cldm.ControlledUnetModel 40 | params: 41 | image_size: 32 # unused 
42 | in_channels: 4 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /models/cldm_v21.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | 21 | control_stage_config: 22 | target: cldm.cldm.ControlNet 23 | params: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | hint_channels: 3 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | unet_config: 40 | target: cldm.cldm.ControlledUnetModel 41 | params: 42 | use_checkpoint: True 43 | image_size: 32 # unused 44 | in_channels: 4 45 | out_channels: 4 46 | model_channels: 320 47 | attention_resolutions: [ 4, 2, 1 ] 48 | num_res_blocks: 2 49 | channel_mult: [ 1, 2, 4, 4 ] 50 | num_head_channels: 64 # need to fix for flash-attn 51 | use_spatial_transformer: True 52 | use_linear_in_transformer: True 53 | transformer_depth: 1 54 | context_dim: 1024 55 | legacy: False 56 | 57 | first_stage_config: 58 | target: ldm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | #attn_type: "vanilla-xformers" 64 | double_z: true 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 | num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | cond_stage_config: 82 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 83 | params: 84 | freeze: True 85 | layer: "penultimate" 86 | -------------------------------------------------------------------------------- /models/cycle_v21.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cycleNet.cycleNet.CycleLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | uncond_stage_key: "source" 12 | image_size: 64 
13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | recon_weight: 1 21 | cycle_weight: 0.01 22 | disc_weight: 0.1 23 | disc_mode: eps 24 | consis_weight: 0.1 25 | 26 | control_stage_config: 27 | target: cldm.cldm.ControlNet 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | hint_channels: 3 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | unet_config: 45 | target: cycleNet.cycleNet.ControlledUnetModel 46 | params: 47 | use_checkpoint: True 48 | image_size: 32 # unused 49 | in_channels: 4 50 | out_channels: 4 51 | model_channels: 320 52 | attention_resolutions: [ 4, 2, 1 ] 53 | num_res_blocks: 2 54 | channel_mult: [ 1, 2, 4, 4 ] 55 | num_head_channels: 64 # need to fix for flash-attn 56 | use_spatial_transformer: True 57 | use_linear_in_transformer: True 58 | transformer_depth: 1 59 | context_dim: 1024 60 | legacy: False 61 | 62 | first_stage_config: 63 | target: ldm.models.autoencoder.AutoencoderKL 64 | params: 65 | embed_dim: 4 66 | monitor: val/rec_loss 67 | ddconfig: 68 | #attn_type: "vanilla-xformers" 69 | double_z: true 70 | z_channels: 4 71 | resolution: 256 72 | in_channels: 3 73 | out_ch: 3 74 | ch: 128 75 | ch_mult: 76 | - 1 77 | - 2 78 | - 4 79 | - 4 80 | num_res_blocks: 2 81 | attn_resolutions: [] 82 | dropout: 0.0 83 | lossconfig: 84 | target: torch.nn.Identity 85 | 86 | cond_stage_config: 87 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 88 | params: 89 | freeze: True 90 | layer: "penultimate" 91 | -------------------------------------------------------------------------------- /models/fastcycle_v21.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cycleNet.cycleNet_fast.CycleLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | uncond_stage_key: "source" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | only_mid_control: False 20 | recon_weight: 1 21 | disc_weight: 0.1 22 | disc_mode: eps 23 | 24 | control_stage_config: 25 | target: cldm.cldm.ControlNet 26 | params: 27 | use_checkpoint: True 28 | image_size: 32 # unused 29 | in_channels: 4 30 | hint_channels: 3 31 | model_channels: 320 32 | attention_resolutions: [ 4, 2, 1 ] 33 | num_res_blocks: 2 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | num_head_channels: 64 # need to fix for flash-attn 36 | use_spatial_transformer: True 37 | use_linear_in_transformer: True 38 | transformer_depth: 1 39 | context_dim: 1024 40 | legacy: False 41 | 42 | unet_config: 43 | target: cycleNet.cycleNet.ControlledUnetModel 44 | params: 45 | use_checkpoint: True 46 | image_size: 32 # unused 47 | in_channels: 4 48 | out_channels: 4 49 | model_channels: 320 50 | attention_resolutions: [ 4, 2, 1 ] 51 | num_res_blocks: 2 52 | channel_mult: [ 1, 2, 4, 4 ] 53 | num_head_channels: 64 # need to fix for flash-attn 54 | use_spatial_transformer: True 55 | 
use_linear_in_transformer: True 56 | transformer_depth: 1 57 | context_dim: 1024 58 | legacy: False 59 | 60 | first_stage_config: 61 | target: ldm.models.autoencoder.AutoencoderKL 62 | params: 63 | embed_dim: 4 64 | monitor: val/rec_loss 65 | ddconfig: 66 | #attn_type: "vanilla-xformers" 67 | double_z: true 68 | z_channels: 4 69 | resolution: 256 70 | in_channels: 3 71 | out_ch: 3 72 | ch: 128 73 | ch_mult: 74 | - 1 75 | - 2 76 | - 4 77 | - 4 78 | num_res_blocks: 2 79 | attn_resolutions: [] 80 | dropout: 0.0 81 | lossconfig: 82 | target: torch.nn.Identity 83 | 84 | cond_stage_config: 85 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 86 | params: 87 | freeze: True 88 | layer: "penultimate" 89 | -------------------------------------------------------------------------------- /share.py: -------------------------------------------------------------------------------- 1 | import config 2 | from cldm.hack import disable_verbosity, enable_sliced_attention 3 | 4 | 5 | disable_verbosity() 6 | 7 | if config.save_memory: 8 | enable_sliced_attention() 9 | -------------------------------------------------------------------------------- /tool_add_cycle_sd21.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | assert len(sys.argv) == 3, 'Args are wrong.' 5 | 6 | input_path = sys.argv[1] 7 | output_path = sys.argv[2] 8 | 9 | assert os.path.exists(input_path), 'Input model does not exist.' 10 | assert not os.path.exists(output_path), 'Output filename already exists.' 11 | assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.' 12 | 13 | import torch 14 | from share import * 15 | from cycleNet.model import create_model 16 | 17 | 18 | def get_node_name(name, parent_name): 19 | if len(name) <= len(parent_name): 20 | return False, '' 21 | p = name[:len(parent_name)] 22 | if p != parent_name: 23 | return False, '' 24 | return True, name[len(parent_name):] 25 | 26 | 27 | model = create_model(config_path='./models/cycle_v21.yaml') 28 | 29 | pretrained_weights = torch.load(input_path) 30 | if 'state_dict' in pretrained_weights: 31 | pretrained_weights = pretrained_weights['state_dict'] 32 | 33 | scratch_dict = model.state_dict() 34 | 35 | target_dict = {} 36 | for k in scratch_dict.keys(): 37 | is_control, name = get_node_name(k, 'control_') 38 | if is_control: 39 | copy_k = 'model.diffusion_' + name 40 | else: 41 | copy_k = k 42 | if copy_k in pretrained_weights: 43 | target_dict[k] = pretrained_weights[copy_k].clone() 44 | else: 45 | target_dict[k] = scratch_dict[k].clone() 46 | print(f'These weights are newly added: {k}') 47 | 48 | model.load_state_dict(target_dict, strict=True) 49 | torch.save(model.state_dict(), output_path) 50 | print('Done.') 51 | --------------------------------------------------------------------------------
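
A minimal usage sketch for `tool_add_cycle_sd21.py` above: it builds a CycleNet model from `./models/cycle_v21.yaml`, copies every matching weight from a Stable Diffusion 2.1 checkpoint (each `control_model.*` key is filled from the corresponding `model.diffusion_model.*` key, and keys missing from the checkpoint keep their fresh initialization), and saves the merged state dict to the output path. The checkpoint filenames below are placeholders, not files shipped with this repository:

```
# Hypothetical paths for illustration; point these at your own SD 2.1 base checkpoint and desired output file.
python tool_add_cycle_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/cycle_sd21_ini.ckpt
```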