├── Adapter ├── Sampling.py ├── extra_condition │ ├── api.py │ ├── model_edge.py │ └── openpose │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── api.cpython-310.pyc │ │ ├── body.cpython-310.pyc │ │ ├── model.cpython-310.pyc │ │ └── util.cpython-310.pyc │ │ ├── api.py │ │ ├── body.py │ │ ├── hand.py │ │ ├── model.py │ │ └── util.py ├── inference_base.py ├── models │ └── adapters.py └── utils.py ├── README.md ├── app.py ├── assets └── logo3.png ├── configs ├── inference │ ├── Adapter-XL-canny.yaml │ ├── Adapter-XL-openpose.yaml │ └── Adapter-XL-sketch.yaml ├── train │ └── Adapter-XL-sketch.yaml └── utils.py ├── dataset ├── dataset_laion.py └── utils.py ├── demo.py ├── examples ├── dog.png ├── people.jpg └── room.jpg ├── models └── unet.py ├── requirements.txt ├── test.py └── train_sketch.py /Adapter/Sampling.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from diffusers import DDPMScheduler, AutoencoderKL 3 | import torch 4 | from pytorch_lightning import seed_everything 5 | import tqdm 6 | import copy 7 | import random 8 | from basicsr.utils import tensor2img 9 | import numpy as np 10 | 11 | from Adapter.utils import import_model_class_from_model_name_or_path 12 | from models.unet import UNet 13 | 14 | class diffusion_inference: 15 | def __init__(self, model_id): 16 | self.device = 'cuda' 17 | self.model_id = model_id 18 | 19 | # load unet model 20 | self.scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") 21 | self.model = UNet.from_pretrained(model_id, subfolder="unet").to(self.device) 22 | try: 23 | self.model.enable_xformers_memory_efficient_attention() 24 | except: 25 | print('The current xformers is not compatible, please reinstall xformers to speed up.') 26 | self.scheduler.set_timesteps(50) 27 | 28 | tokenizer_one = AutoTokenizer.from_pretrained( 29 | self.model_id, subfolder="tokenizer", revision=None, use_fast=False 30 | ) 31 | tokenizer_two = AutoTokenizer.from_pretrained( 32 | self.model_id, subfolder="tokenizer_2", revision=None, use_fast=False 33 | ) 34 | 35 | # import correct text encoder classes 36 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 37 | self.model_id, None 38 | ) 39 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 40 | self.model_id, None, subfolder="text_encoder_2" 41 | ) 42 | 43 | # Load scheduler and models 44 | text_encoder_one = text_encoder_cls_one.from_pretrained( 45 | self.model_id, subfolder="text_encoder", revision=None 46 | ) 47 | text_encoder_two = text_encoder_cls_two.from_pretrained( 48 | self.model_id, subfolder="text_encoder_2", revision=None 49 | ) 50 | # self.text_encoders = [text_encoder_one.to(self.device), text_encoder_two.to(self.device)] 51 | self.text_encoders = [text_encoder_one, text_encoder_two] 52 | self.tokenizers = [tokenizer_one, tokenizer_two] 53 | self.vae = AutoencoderKL.from_pretrained( 54 | self.model_id, 55 | subfolder="vae", 56 | revision=None, 57 | )#.to(self.device) 58 | 59 | def reset_schedule(self, timesteps): 60 | self.scheduler.set_timesteps(timesteps) 61 | 62 | def inference(self, prompt, size, prompt_n='', adapter_features=None, guidance_scale=7.5, seed=-1, steps=50): 63 | prompt_batch = [prompt_n, prompt] 64 | prompt_embeds, unet_added_cond_kwargs = self.compute_embeddings( 65 | prompt_batch=prompt_batch,proportion_empty_prompts=0,text_encoders=self.text_encoders,tokenizers=self.tokenizers,size=size 66 | ) 67 | self.reset_schedule(steps) 
68 | if seed != -1: 69 | seed_everything(seed) 70 | noisy_latents = torch.randn((1, 4, size[0]//8, size[1]//8)).to("cuda") 71 | 72 | with torch.no_grad(): 73 | for t in tqdm.tqdm(self.scheduler.timesteps): 74 | with torch.no_grad(): 75 | input = torch.cat([noisy_latents]*2) 76 | noise_pred = self.model( 77 | input, 78 | t, 79 | encoder_hidden_states=prompt_embeds["prompt_embeds"], 80 | added_cond_kwargs=unet_added_cond_kwargs, 81 | down_block_additional_residuals=copy.deepcopy(adapter_features), 82 | )[0] 83 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 84 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 85 | noisy_latents = self.scheduler.step(noise_pred, t, noisy_latents)[0] 86 | 87 | image = self.vae.decode(noisy_latents.cpu() / self.vae.config.scaling_factor, return_dict=False)[0] 88 | image = (image / 2 + 0.5).clamp(0, 1) 89 | image = tensor2img(image) 90 | 91 | return image 92 | 93 | 94 | def encode_prompt(self, prompt_batch, proportion_empty_prompts, is_train=True): 95 | prompt_embeds_list = [] 96 | 97 | captions = [] 98 | for caption in prompt_batch: 99 | if random.random() < proportion_empty_prompts: 100 | captions.append("") 101 | elif isinstance(caption, str): 102 | captions.append(caption) 103 | elif isinstance(caption, (list, np.ndarray)): 104 | # take a random caption if there are multiple 105 | captions.append(random.choice(caption) if is_train else caption[0]) 106 | 107 | with torch.no_grad(): 108 | for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): 109 | text_inputs = tokenizer( 110 | captions, 111 | padding="max_length", 112 | max_length=tokenizer.model_max_length, 113 | truncation=True, 114 | return_tensors="pt", 115 | ) 116 | text_input_ids = text_inputs.input_ids 117 | prompt_embeds = text_encoder( 118 | text_input_ids.to(text_encoder.device), 119 | output_hidden_states=True, 120 | ) 121 | 122 | # We are only ALWAYS interested in the pooled output of the final text encoder 123 | pooled_prompt_embeds = prompt_embeds[0] 124 | prompt_embeds = prompt_embeds.hidden_states[-2] 125 | bs_embed, seq_len, _ = prompt_embeds.shape 126 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 127 | prompt_embeds_list.append(prompt_embeds) 128 | 129 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 130 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 131 | return prompt_embeds, pooled_prompt_embeds 132 | 133 | def compute_embeddings(self, prompt_batch, proportion_empty_prompts, text_encoders, tokenizers, size, is_train=True): 134 | original_size = size 135 | target_size = size 136 | crops_coords_top_left = (0, 0) 137 | 138 | prompt_embeds, pooled_prompt_embeds = self.encode_prompt( 139 | prompt_batch, proportion_empty_prompts, is_train 140 | ) 141 | add_text_embeds = pooled_prompt_embeds 142 | 143 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 144 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 145 | add_time_ids = torch.tensor([add_time_ids]) 146 | 147 | prompt_embeds = prompt_embeds.to(self.device) 148 | add_text_embeds = add_text_embeds.to(self.device) 149 | add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) 150 | add_time_ids = add_time_ids.to(self.device, dtype=prompt_embeds.dtype) 151 | unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 152 | 153 | return {"prompt_embeds": prompt_embeds}, unet_added_cond_kwargs 154 | 155 | 156 | 
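A minimal usage sketch for the diffusion_inference class above (illustrative only: the checkpoint id, the prompts, and saving with cv2.imwrite are assumptions, not taken from this file; adapter_features would normally be the feature-map list produced by get_adapter_feature in Adapter/extra_condition/api.py below):

import cv2
from Adapter.Sampling import diffusion_inference

sampler = diffusion_inference('stabilityai/stable-diffusion-xl-base-1.0')  # assumed SDXL checkpoint id
image = sampler.inference(
    prompt='a photo of a dog in the garden',
    size=(1024, 1024),               # (H, W); latents are sampled at H/8 x W/8
    prompt_n='lowres, bad anatomy',  # negative prompt for classifier-free guidance
    adapter_features=None,           # or the feature-map list from get_adapter_feature(...)
    guidance_scale=7.5,
    seed=42,
    steps=50,
)
cv2.imwrite('out.png', image)        # tensor2img returns a BGR uint8 array by default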
-------------------------------------------------------------------------------- /Adapter/extra_condition/api.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | import cv2 3 | import torch 4 | from basicsr.utils import img2tensor 5 | from PIL import Image 6 | from torch import autocast 7 | 8 | from Adapter.utils import resize_numpy_image 9 | 10 | @unique 11 | class ExtraCondition(Enum): 12 | sketch = 0 13 | keypose = 1 14 | seg = 2 15 | depth = 3 16 | canny = 4 17 | style = 5 18 | color = 6 19 | openpose = 7 20 | edge = 8 21 | zoedepth = 9 22 | 23 | 24 | def get_cond_model(opt, cond_type: ExtraCondition): 25 | if cond_type == ExtraCondition.sketch: 26 | from Adapter.extra_condition.model_edge import pidinet 27 | model = pidinet() 28 | ckp = torch.load('checkpoints/table5_pidinet.pth', map_location='cpu')['state_dict'] 29 | model.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()}, strict=True) 30 | model.to(opt.device) 31 | return model 32 | elif cond_type == ExtraCondition.seg: 33 | raise NotImplementedError 34 | elif cond_type == ExtraCondition.keypose: 35 | import mmcv 36 | from mmdet.apis import init_detector 37 | from mmpose.apis import init_pose_model 38 | det_config = 'configs/mm/faster_rcnn_r50_fpn_coco.py' 39 | det_checkpoint = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' 40 | pose_config = 'configs/mm/hrnet_w48_coco_256x192.py' 41 | pose_checkpoint = 'checkpoints/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth' 42 | det_config_mmcv = mmcv.Config.fromfile(det_config) 43 | det_model = init_detector(det_config_mmcv, det_checkpoint, device=opt.device) 44 | pose_config_mmcv = mmcv.Config.fromfile(pose_config) 45 | pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=opt.device) 46 | return {'pose_model': pose_model, 'det_model': det_model} 47 | elif cond_type == ExtraCondition.depth: 48 | from Adapter.extra_condition.midas.api import MiDaSInference 49 | model = MiDaSInference(model_type='dpt_hybrid').to(opt.device) 50 | return model 51 | elif cond_type == ExtraCondition.zoedepth: 52 | from handyinfer.depth_estimation import init_depth_estimation_model 53 | model = init_depth_estimation_model('ZoeD_N', device=opt.device) 54 | return model 55 | elif cond_type == ExtraCondition.canny: 56 | return None 57 | elif cond_type == ExtraCondition.style: 58 | from transformers import CLIPProcessor, CLIPVisionModel 59 | version = 'openai/clip-vit-large-patch14' 60 | processor = CLIPProcessor.from_pretrained(version) 61 | clip_vision_model = CLIPVisionModel.from_pretrained(version).to(opt.device) 62 | return {'processor': processor, 'clip_vision_model': clip_vision_model} 63 | elif cond_type == ExtraCondition.color: 64 | return None 65 | elif cond_type == ExtraCondition.openpose: 66 | from Adapter.extra_condition.openpose.api import OpenposeInference 67 | model = OpenposeInference().to(opt.device) 68 | return model 69 | elif cond_type == ExtraCondition.edge: 70 | return None 71 | else: 72 | raise NotImplementedError 73 | 74 | 75 | def get_cond_sketch(opt, cond_image, cond_inp_type, cond_model=None): 76 | if isinstance(cond_image, str): 77 | edge = cv2.imread(cond_image) 78 | else: 79 | # for gradio input, pay attention, it's rgb numpy 80 | edge = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 81 | edge = resize_numpy_image(edge, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 82 | opt.H, opt.W = edge.shape[:2] 83 | if cond_inp_type == 'sketch': 84 | 
edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0) / 255. 85 | edge = edge.to(opt.device) 86 | elif cond_inp_type == 'image': 87 | edge = img2tensor(edge).unsqueeze(0) / 255. 88 | edge = cond_model(edge.to(opt.device))[-1] 89 | else: 90 | raise NotImplementedError 91 | 92 | # edge = 1-edge # for white background 93 | edge = edge > 0.5 94 | edge = edge.float() 95 | 96 | return edge 97 | 98 | 99 | def get_cond_seg(opt, cond_image, cond_inp_type='image', cond_model=None): 100 | if isinstance(cond_image, str): 101 | seg = cv2.imread(cond_image) 102 | else: 103 | seg = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 104 | seg = resize_numpy_image(seg, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 105 | opt.H, opt.W = seg.shape[:2] 106 | if cond_inp_type == 'seg': 107 | seg = img2tensor(seg).unsqueeze(0) / 255. 108 | seg = seg.to(opt.device) 109 | else: 110 | raise NotImplementedError 111 | 112 | return seg 113 | 114 | 115 | def get_cond_keypose(opt, cond_image, cond_inp_type='image', cond_model=None): 116 | if isinstance(cond_image, str): 117 | pose = cv2.imread(cond_image) 118 | else: 119 | pose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 120 | pose = resize_numpy_image(pose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 121 | opt.H, opt.W = pose.shape[:2] 122 | if cond_inp_type == 'keypose': 123 | pose = img2tensor(pose).unsqueeze(0) / 255. 124 | pose = pose.to(opt.device) 125 | elif cond_inp_type == 'image': 126 | from Adapter.extra_condition.utils import imshow_keypoints 127 | from mmdet.apis import inference_detector 128 | from mmpose.apis import (inference_top_down_pose_model, process_mmdet_results) 129 | 130 | # mmpose seems not compatible with autocast fp16 131 | with autocast("cuda", dtype=torch.float32): 132 | mmdet_results = inference_detector(cond_model['det_model'], pose) 133 | # keep the person class bounding boxes. 134 | person_results = process_mmdet_results(mmdet_results, 1) 135 | 136 | # optional 137 | return_heatmap = False 138 | dataset = cond_model['pose_model'].cfg.data['test']['type'] 139 | 140 | # e.g. use ('backbone', ) to return backbone feature 141 | output_layer_names = None 142 | pose_results, returned_outputs = inference_top_down_pose_model( 143 | cond_model['pose_model'], 144 | pose, 145 | person_results, 146 | bbox_thr=0.2, 147 | format='xyxy', 148 | dataset=dataset, 149 | dataset_info=None, 150 | return_heatmap=return_heatmap, 151 | outputs=output_layer_names) 152 | 153 | # show the results 154 | pose = imshow_keypoints(pose, pose_results, radius=2, thickness=2) 155 | pose = img2tensor(pose).unsqueeze(0) / 255. 156 | pose = pose.to(opt.device) 157 | else: 158 | raise NotImplementedError 159 | 160 | return pose 161 | 162 | 163 | def get_cond_depth(opt, cond_image, cond_inp_type='image', cond_model=None): 164 | if isinstance(cond_image, str): 165 | depth = cv2.imread(cond_image) 166 | else: 167 | depth = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 168 | depth = resize_numpy_image(depth, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 169 | opt.H, opt.W = depth.shape[:2] 170 | if cond_inp_type == 'depth': 171 | depth = img2tensor(depth).unsqueeze(0) / 255. 
172 | depth = depth.to(opt.device) 173 | elif cond_inp_type == 'image': 174 | depth = img2tensor(depth).unsqueeze(0) / 127.5 - 1.0 175 | depth = cond_model(depth.to(opt.device)).repeat(1, 3, 1, 1) 176 | depth -= torch.min(depth) 177 | depth /= torch.max(depth) 178 | else: 179 | raise NotImplementedError 180 | 181 | return depth 182 | 183 | 184 | def get_cond_zoedepth(opt, cond_image, cond_inp_type='image', cond_model=None): 185 | if isinstance(cond_image, str): 186 | depth = cv2.imread(cond_image) 187 | else: 188 | depth = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 189 | depth = resize_numpy_image(depth, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 190 | opt.H, opt.W = depth.shape[:2] 191 | if cond_inp_type == 'zoedepth': 192 | depth = img2tensor(depth).unsqueeze(0) / 255. 193 | depth = depth.to(opt.device) 194 | elif cond_inp_type == 'image': 195 | depth = img2tensor(depth).unsqueeze(0) / 255. 196 | 197 | with autocast("cuda", dtype=torch.float32): 198 | depth = cond_model.infer(depth.to(opt.device)) 199 | depth = depth.repeat(1, 3, 1, 1) 200 | else: 201 | raise NotImplementedError 202 | 203 | return depth 204 | 205 | 206 | def get_cond_canny(opt, cond_image, cond_inp_type='image', cond_model=None): 207 | if isinstance(cond_image, str): 208 | canny = cv2.imread(cond_image) 209 | else: 210 | canny = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 211 | canny = resize_numpy_image(canny, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 212 | opt.H, opt.W = canny.shape[:2] 213 | if cond_inp_type == 'canny': 214 | canny = img2tensor(canny)[0:1].unsqueeze(0) / 255. 215 | canny = canny.to(opt.device) 216 | elif cond_inp_type == 'image': 217 | canny = cv2.Canny(canny, 100, 200)[..., None] 218 | canny = img2tensor(canny).unsqueeze(0) / 255. 219 | canny = canny.to(opt.device) 220 | else: 221 | raise NotImplementedError 222 | 223 | return canny 224 | 225 | 226 | def get_cond_style(opt, cond_image, cond_inp_type='image', cond_model=None): 227 | assert cond_inp_type == 'image' 228 | if isinstance(cond_image, str): 229 | style = Image.open(cond_image) 230 | else: 231 | # numpy image to PIL image 232 | style = Image.fromarray(cond_image) 233 | 234 | style_for_clip = cond_model['processor'](images=style, return_tensors="pt")['pixel_values'] 235 | style_feat = cond_model['clip_vision_model'](style_for_clip.to(opt.device))['last_hidden_state'] 236 | 237 | return style_feat 238 | 239 | 240 | def get_cond_color(opt, cond_image, cond_inp_type='image', cond_model=None): 241 | if isinstance(cond_image, str): 242 | color = cv2.imread(cond_image) 243 | else: 244 | color = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 245 | color = resize_numpy_image(color, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 246 | opt.H, opt.W = color.shape[:2] 247 | if cond_inp_type == 'image': 248 | color = cv2.resize(color, (opt.W // 64, opt.H // 64), interpolation=cv2.INTER_CUBIC) 249 | color = cv2.resize(color, (opt.W, opt.H), interpolation=cv2.INTER_NEAREST) 250 | color = img2tensor(color).unsqueeze(0) / 255. 
251 | color = color.to(opt.device) 252 | return color 253 | 254 | 255 | def get_cond_openpose(opt, cond_image, cond_inp_type='image', cond_model=None): 256 | if isinstance(cond_image, str): 257 | openpose_keypose = cv2.imread(cond_image) 258 | else: 259 | openpose_keypose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 260 | openpose_keypose = resize_numpy_image( 261 | openpose_keypose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 262 | opt.H, opt.W = openpose_keypose.shape[:2] 263 | if cond_inp_type == 'openpose': 264 | openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255. 265 | openpose_keypose = openpose_keypose.to(opt.device) 266 | elif cond_inp_type == 'image': 267 | with autocast('cuda', dtype=torch.float32): 268 | w, h = openpose_keypose.shape[:2] 269 | openpose_keypose = cv2.resize(openpose_keypose, (h//2, w//2)) 270 | openpose_keypose = cond_model(openpose_keypose) 271 | openpose_keypose = cv2.resize(openpose_keypose, (h, w))[:,:,::-1] 272 | openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255. 273 | openpose_keypose = openpose_keypose.to(opt.device) 274 | 275 | else: 276 | raise NotImplementedError 277 | 278 | return openpose_keypose 279 | 280 | 281 | def get_cond_edge(opt, cond_image, cond_inp_type='image', cond_model=None): 282 | if isinstance(cond_image, str): 283 | edge = cv2.imread(cond_image) 284 | else: 285 | edge = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) 286 | edge = resize_numpy_image(edge, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge) 287 | opt.H, opt.W = edge.shape[:2] 288 | if cond_inp_type == 'edge': 289 | edge = img2tensor(edge)[0:1].unsqueeze(0) / 255. 290 | edge = edge.to(opt.device) 291 | elif cond_inp_type == 'image': 292 | edge = cv2.cvtColor(edge, cv2.COLOR_RGB2GRAY) / 1.0 293 | blur = cv2.blur(edge, (4, 4)) 294 | edge = ((edge - blur) + 128)[..., None] 295 | edge = img2tensor(edge).unsqueeze(0) / 255. 
296 | edge = edge.to(opt.device) 297 | else: 298 | raise NotImplementedError 299 | 300 | return edge 301 | 302 | 303 | def get_adapter_feature(inputs, adapters): 304 | ret_feat_map = None 305 | ret_feat_seq = None 306 | if not isinstance(inputs, list): 307 | inputs = [inputs] 308 | adapters = [adapters] 309 | 310 | for input, adapter in zip(inputs, adapters): 311 | cur_feature = adapter['model'](input) 312 | if isinstance(cur_feature, list): 313 | if ret_feat_map is None: 314 | ret_feat_map = list(map(lambda x: x * adapter['cond_weight'], cur_feature)) 315 | else: 316 | ret_feat_map = list(map(lambda x, y: x + y * adapter['cond_weight'], ret_feat_map, cur_feature)) 317 | else: 318 | if ret_feat_seq is None: 319 | ret_feat_seq = cur_feature * adapter['cond_weight'] 320 | else: 321 | ret_feat_seq = torch.cat([ret_feat_seq, cur_feature * adapter['cond_weight']], dim=1) 322 | 323 | return ret_feat_map, ret_feat_seq 324 | -------------------------------------------------------------------------------- /Adapter/extra_condition/model_edge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Zhuo Su, Wenzhe Liu 3 | Date: Feb 18, 2021 4 | """ 5 | 6 | import math 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from basicsr.utils import img2tensor 14 | 15 | nets = { 16 | 'baseline': { 17 | 'layer0': 'cv', 18 | 'layer1': 'cv', 19 | 'layer2': 'cv', 20 | 'layer3': 'cv', 21 | 'layer4': 'cv', 22 | 'layer5': 'cv', 23 | 'layer6': 'cv', 24 | 'layer7': 'cv', 25 | 'layer8': 'cv', 26 | 'layer9': 'cv', 27 | 'layer10': 'cv', 28 | 'layer11': 'cv', 29 | 'layer12': 'cv', 30 | 'layer13': 'cv', 31 | 'layer14': 'cv', 32 | 'layer15': 'cv', 33 | }, 34 | 'c-v15': { 35 | 'layer0': 'cd', 36 | 'layer1': 'cv', 37 | 'layer2': 'cv', 38 | 'layer3': 'cv', 39 | 'layer4': 'cv', 40 | 'layer5': 'cv', 41 | 'layer6': 'cv', 42 | 'layer7': 'cv', 43 | 'layer8': 'cv', 44 | 'layer9': 'cv', 45 | 'layer10': 'cv', 46 | 'layer11': 'cv', 47 | 'layer12': 'cv', 48 | 'layer13': 'cv', 49 | 'layer14': 'cv', 50 | 'layer15': 'cv', 51 | }, 52 | 'a-v15': { 53 | 'layer0': 'ad', 54 | 'layer1': 'cv', 55 | 'layer2': 'cv', 56 | 'layer3': 'cv', 57 | 'layer4': 'cv', 58 | 'layer5': 'cv', 59 | 'layer6': 'cv', 60 | 'layer7': 'cv', 61 | 'layer8': 'cv', 62 | 'layer9': 'cv', 63 | 'layer10': 'cv', 64 | 'layer11': 'cv', 65 | 'layer12': 'cv', 66 | 'layer13': 'cv', 67 | 'layer14': 'cv', 68 | 'layer15': 'cv', 69 | }, 70 | 'r-v15': { 71 | 'layer0': 'rd', 72 | 'layer1': 'cv', 73 | 'layer2': 'cv', 74 | 'layer3': 'cv', 75 | 'layer4': 'cv', 76 | 'layer5': 'cv', 77 | 'layer6': 'cv', 78 | 'layer7': 'cv', 79 | 'layer8': 'cv', 80 | 'layer9': 'cv', 81 | 'layer10': 'cv', 82 | 'layer11': 'cv', 83 | 'layer12': 'cv', 84 | 'layer13': 'cv', 85 | 'layer14': 'cv', 86 | 'layer15': 'cv', 87 | }, 88 | 'cvvv4': { 89 | 'layer0': 'cd', 90 | 'layer1': 'cv', 91 | 'layer2': 'cv', 92 | 'layer3': 'cv', 93 | 'layer4': 'cd', 94 | 'layer5': 'cv', 95 | 'layer6': 'cv', 96 | 'layer7': 'cv', 97 | 'layer8': 'cd', 98 | 'layer9': 'cv', 99 | 'layer10': 'cv', 100 | 'layer11': 'cv', 101 | 'layer12': 'cd', 102 | 'layer13': 'cv', 103 | 'layer14': 'cv', 104 | 'layer15': 'cv', 105 | }, 106 | 'avvv4': { 107 | 'layer0': 'ad', 108 | 'layer1': 'cv', 109 | 'layer2': 'cv', 110 | 'layer3': 'cv', 111 | 'layer4': 'ad', 112 | 'layer5': 'cv', 113 | 'layer6': 'cv', 114 | 'layer7': 'cv', 115 | 'layer8': 'ad', 116 | 'layer9': 'cv', 117 | 'layer10': 'cv', 118 | 'layer11': 'cv', 119 | 'layer12': 'ad', 
120 | 'layer13': 'cv', 121 | 'layer14': 'cv', 122 | 'layer15': 'cv', 123 | }, 124 | 'rvvv4': { 125 | 'layer0': 'rd', 126 | 'layer1': 'cv', 127 | 'layer2': 'cv', 128 | 'layer3': 'cv', 129 | 'layer4': 'rd', 130 | 'layer5': 'cv', 131 | 'layer6': 'cv', 132 | 'layer7': 'cv', 133 | 'layer8': 'rd', 134 | 'layer9': 'cv', 135 | 'layer10': 'cv', 136 | 'layer11': 'cv', 137 | 'layer12': 'rd', 138 | 'layer13': 'cv', 139 | 'layer14': 'cv', 140 | 'layer15': 'cv', 141 | }, 142 | 'cccv4': { 143 | 'layer0': 'cd', 144 | 'layer1': 'cd', 145 | 'layer2': 'cd', 146 | 'layer3': 'cv', 147 | 'layer4': 'cd', 148 | 'layer5': 'cd', 149 | 'layer6': 'cd', 150 | 'layer7': 'cv', 151 | 'layer8': 'cd', 152 | 'layer9': 'cd', 153 | 'layer10': 'cd', 154 | 'layer11': 'cv', 155 | 'layer12': 'cd', 156 | 'layer13': 'cd', 157 | 'layer14': 'cd', 158 | 'layer15': 'cv', 159 | }, 160 | 'aaav4': { 161 | 'layer0': 'ad', 162 | 'layer1': 'ad', 163 | 'layer2': 'ad', 164 | 'layer3': 'cv', 165 | 'layer4': 'ad', 166 | 'layer5': 'ad', 167 | 'layer6': 'ad', 168 | 'layer7': 'cv', 169 | 'layer8': 'ad', 170 | 'layer9': 'ad', 171 | 'layer10': 'ad', 172 | 'layer11': 'cv', 173 | 'layer12': 'ad', 174 | 'layer13': 'ad', 175 | 'layer14': 'ad', 176 | 'layer15': 'cv', 177 | }, 178 | 'rrrv4': { 179 | 'layer0': 'rd', 180 | 'layer1': 'rd', 181 | 'layer2': 'rd', 182 | 'layer3': 'cv', 183 | 'layer4': 'rd', 184 | 'layer5': 'rd', 185 | 'layer6': 'rd', 186 | 'layer7': 'cv', 187 | 'layer8': 'rd', 188 | 'layer9': 'rd', 189 | 'layer10': 'rd', 190 | 'layer11': 'cv', 191 | 'layer12': 'rd', 192 | 'layer13': 'rd', 193 | 'layer14': 'rd', 194 | 'layer15': 'cv', 195 | }, 196 | 'c16': { 197 | 'layer0': 'cd', 198 | 'layer1': 'cd', 199 | 'layer2': 'cd', 200 | 'layer3': 'cd', 201 | 'layer4': 'cd', 202 | 'layer5': 'cd', 203 | 'layer6': 'cd', 204 | 'layer7': 'cd', 205 | 'layer8': 'cd', 206 | 'layer9': 'cd', 207 | 'layer10': 'cd', 208 | 'layer11': 'cd', 209 | 'layer12': 'cd', 210 | 'layer13': 'cd', 211 | 'layer14': 'cd', 212 | 'layer15': 'cd', 213 | }, 214 | 'a16': { 215 | 'layer0': 'ad', 216 | 'layer1': 'ad', 217 | 'layer2': 'ad', 218 | 'layer3': 'ad', 219 | 'layer4': 'ad', 220 | 'layer5': 'ad', 221 | 'layer6': 'ad', 222 | 'layer7': 'ad', 223 | 'layer8': 'ad', 224 | 'layer9': 'ad', 225 | 'layer10': 'ad', 226 | 'layer11': 'ad', 227 | 'layer12': 'ad', 228 | 'layer13': 'ad', 229 | 'layer14': 'ad', 230 | 'layer15': 'ad', 231 | }, 232 | 'r16': { 233 | 'layer0': 'rd', 234 | 'layer1': 'rd', 235 | 'layer2': 'rd', 236 | 'layer3': 'rd', 237 | 'layer4': 'rd', 238 | 'layer5': 'rd', 239 | 'layer6': 'rd', 240 | 'layer7': 'rd', 241 | 'layer8': 'rd', 242 | 'layer9': 'rd', 243 | 'layer10': 'rd', 244 | 'layer11': 'rd', 245 | 'layer12': 'rd', 246 | 'layer13': 'rd', 247 | 'layer14': 'rd', 248 | 'layer15': 'rd', 249 | }, 250 | 'carv4': { 251 | 'layer0': 'cd', 252 | 'layer1': 'ad', 253 | 'layer2': 'rd', 254 | 'layer3': 'cv', 255 | 'layer4': 'cd', 256 | 'layer5': 'ad', 257 | 'layer6': 'rd', 258 | 'layer7': 'cv', 259 | 'layer8': 'cd', 260 | 'layer9': 'ad', 261 | 'layer10': 'rd', 262 | 'layer11': 'cv', 263 | 'layer12': 'cd', 264 | 'layer13': 'ad', 265 | 'layer14': 'rd', 266 | 'layer15': 'cv', 267 | }, 268 | } 269 | 270 | def createConvFunc(op_type): 271 | assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) 272 | if op_type == 'cv': 273 | return F.conv2d 274 | 275 | if op_type == 'cd': 276 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 277 | assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' 278 | assert weights.size(2) == 3 
and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' 279 | assert padding == dilation, 'padding for cd_conv set wrong' 280 | 281 | weights_c = weights.sum(dim=[2, 3], keepdim=True) 282 | # print(x.device, weights_c.device) 283 | yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) 284 | y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 285 | return y - yc 286 | return func 287 | elif op_type == 'ad': 288 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 289 | assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' 290 | assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' 291 | assert padding == dilation, 'padding for ad_conv set wrong' 292 | 293 | shape = weights.shape 294 | weights = weights.view(shape[0], shape[1], -1) 295 | weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise 296 | y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 297 | return y 298 | return func 299 | elif op_type == 'rd': 300 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 301 | assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' 302 | assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' 303 | padding = 2 * dilation 304 | 305 | shape = weights.shape 306 | if weights.is_cuda: 307 | buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) 308 | else: 309 | buffer = torch.zeros(shape[0], shape[1], 5 * 5) 310 | buffer = buffer.to(dtype=weights.dtype) 311 | weights = weights.view(shape[0], shape[1], -1) 312 | buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] 313 | buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] 314 | buffer[:, :, 12] = 0 315 | buffer = buffer.view(shape[0], shape[1], 5, 5) 316 | y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 317 | return y 318 | return func 319 | else: 320 | print('impossible to be here unless you force that') 321 | return None 322 | 323 | class Conv2d(nn.Module): 324 | def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False): 325 | super(Conv2d, self).__init__() 326 | if in_channels % groups != 0: 327 | raise ValueError('in_channels must be divisible by groups') 328 | if out_channels % groups != 0: 329 | raise ValueError('out_channels must be divisible by groups') 330 | self.in_channels = in_channels 331 | self.out_channels = out_channels 332 | self.kernel_size = kernel_size 333 | self.stride = stride 334 | self.padding = padding 335 | self.dilation = dilation 336 | self.groups = groups 337 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) 338 | if bias: 339 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 340 | else: 341 | self.register_parameter('bias', None) 342 | self.reset_parameters() 343 | self.pdc = pdc 344 | 345 | def reset_parameters(self): 346 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 347 | if self.bias is not None: 348 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 349 | bound = 1 / math.sqrt(fan_in) 350 | nn.init.uniform_(self.bias, -bound, bound) 351 | 352 | def forward(self, input): 353 | 354 | return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 
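A quick equivalence check one could run against createConvFunc above (a sketch, not part of the repository): the 'cd' pixel-difference convolution is a vanilla 3x3 convolution whose centre tap has the kernel sum subtracted, which is the property PDCBlock_converted later in this file relies on.

import torch
import torch.nn.functional as F
from Adapter.extra_condition.model_edge import createConvFunc

cd = createConvFunc('cd')                 # central-difference convolution

x = torch.randn(1, 3, 16, 16)
w = torch.randn(8, 3, 3, 3)

y_pdc = cd(x, w, padding=1, dilation=1)

# Fold the central difference into the kernel: subtract each (out, in) kernel sum
# from the centre tap, then run a plain 3x3 convolution.
w_conv = w.clone()
w_conv[:, :, 1, 1] -= w.sum(dim=[2, 3])
y_vanilla = F.conv2d(x, w_conv, padding=1)

print(torch.allclose(y_pdc, y_vanilla, atol=1e-5))  # expected: True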
355 | 356 | class CSAM(nn.Module): 357 | """ 358 | Compact Spatial Attention Module 359 | """ 360 | def __init__(self, channels): 361 | super(CSAM, self).__init__() 362 | 363 | mid_channels = 4 364 | self.relu1 = nn.ReLU() 365 | self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) 366 | self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) 367 | self.sigmoid = nn.Sigmoid() 368 | nn.init.constant_(self.conv1.bias, 0) 369 | 370 | def forward(self, x): 371 | y = self.relu1(x) 372 | y = self.conv1(y) 373 | y = self.conv2(y) 374 | y = self.sigmoid(y) 375 | 376 | return x * y 377 | 378 | class CDCM(nn.Module): 379 | """ 380 | Compact Dilation Convolution based Module 381 | """ 382 | def __init__(self, in_channels, out_channels): 383 | super(CDCM, self).__init__() 384 | 385 | self.relu1 = nn.ReLU() 386 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 387 | self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) 388 | self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) 389 | self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) 390 | self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) 391 | nn.init.constant_(self.conv1.bias, 0) 392 | 393 | def forward(self, x): 394 | x = self.relu1(x) 395 | x = self.conv1(x) 396 | x1 = self.conv2_1(x) 397 | x2 = self.conv2_2(x) 398 | x3 = self.conv2_3(x) 399 | x4 = self.conv2_4(x) 400 | return x1 + x2 + x3 + x4 401 | 402 | 403 | class MapReduce(nn.Module): 404 | """ 405 | Reduce feature maps into a single edge map 406 | """ 407 | def __init__(self, channels): 408 | super(MapReduce, self).__init__() 409 | self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) 410 | nn.init.constant_(self.conv.bias, 0) 411 | 412 | def forward(self, x): 413 | return self.conv(x) 414 | 415 | 416 | class PDCBlock(nn.Module): 417 | def __init__(self, pdc, inplane, ouplane, stride=1): 418 | super(PDCBlock, self).__init__() 419 | self.stride=stride 420 | 421 | self.stride=stride 422 | if self.stride > 1: 423 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 424 | self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) 425 | self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) 426 | self.relu2 = nn.ReLU() 427 | self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) 428 | 429 | def forward(self, x): 430 | if self.stride > 1: 431 | x = self.pool(x) 432 | y = self.conv1(x) 433 | y = self.relu2(y) 434 | y = self.conv2(y) 435 | if self.stride > 1: 436 | x = self.shortcut(x) 437 | y = y + x 438 | return y 439 | 440 | class PDCBlock_converted(nn.Module): 441 | """ 442 | CPDC, APDC can be converted to vanilla 3x3 convolution 443 | RPDC can be converted to vanilla 5x5 convolution 444 | """ 445 | def __init__(self, pdc, inplane, ouplane, stride=1): 446 | super(PDCBlock_converted, self).__init__() 447 | self.stride=stride 448 | 449 | if self.stride > 1: 450 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 451 | self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) 452 | if pdc == 'rd': 453 | self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) 454 | else: 455 | self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) 456 | self.relu2 = nn.ReLU() 
457 | self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) 458 | 459 | def forward(self, x): 460 | if self.stride > 1: 461 | x = self.pool(x) 462 | y = self.conv1(x) 463 | y = self.relu2(y) 464 | y = self.conv2(y) 465 | if self.stride > 1: 466 | x = self.shortcut(x) 467 | y = y + x 468 | return y 469 | 470 | class PiDiNet(nn.Module): 471 | def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): 472 | super(PiDiNet, self).__init__() 473 | self.sa = sa 474 | if dil is not None: 475 | assert isinstance(dil, int), 'dil should be an int' 476 | self.dil = dil 477 | 478 | self.fuseplanes = [] 479 | 480 | self.inplane = inplane 481 | if convert: 482 | if pdcs[0] == 'rd': 483 | init_kernel_size = 5 484 | init_padding = 2 485 | else: 486 | init_kernel_size = 3 487 | init_padding = 1 488 | self.init_block = nn.Conv2d(3, self.inplane, 489 | kernel_size=init_kernel_size, padding=init_padding, bias=False) 490 | block_class = PDCBlock_converted 491 | else: 492 | self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) 493 | block_class = PDCBlock 494 | 495 | self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) 496 | self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) 497 | self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) 498 | self.fuseplanes.append(self.inplane) # C 499 | 500 | inplane = self.inplane 501 | self.inplane = self.inplane * 2 502 | self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) 503 | self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) 504 | self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) 505 | self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) 506 | self.fuseplanes.append(self.inplane) # 2C 507 | 508 | inplane = self.inplane 509 | self.inplane = self.inplane * 2 510 | self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) 511 | self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) 512 | self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) 513 | self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) 514 | self.fuseplanes.append(self.inplane) # 4C 515 | 516 | self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) 517 | self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) 518 | self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) 519 | self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) 520 | self.fuseplanes.append(self.inplane) # 4C 521 | 522 | self.conv_reduces = nn.ModuleList() 523 | if self.sa and self.dil is not None: 524 | self.attentions = nn.ModuleList() 525 | self.dilations = nn.ModuleList() 526 | for i in range(4): 527 | self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) 528 | self.attentions.append(CSAM(self.dil)) 529 | self.conv_reduces.append(MapReduce(self.dil)) 530 | elif self.sa: 531 | self.attentions = nn.ModuleList() 532 | for i in range(4): 533 | self.attentions.append(CSAM(self.fuseplanes[i])) 534 | self.conv_reduces.append(MapReduce(self.fuseplanes[i])) 535 | elif self.dil is not None: 536 | self.dilations = nn.ModuleList() 537 | for i in range(4): 538 | self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) 539 | self.conv_reduces.append(MapReduce(self.dil)) 540 | else: 541 | for i in range(4): 542 | self.conv_reduces.append(MapReduce(self.fuseplanes[i])) 543 | 544 | self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias 545 | nn.init.constant_(self.classifier.weight, 0.25) 546 | 
nn.init.constant_(self.classifier.bias, 0) 547 | 548 | # print('initialization done') 549 | 550 | def get_weights(self): 551 | conv_weights = [] 552 | bn_weights = [] 553 | relu_weights = [] 554 | for pname, p in self.named_parameters(): 555 | if 'bn' in pname: 556 | bn_weights.append(p) 557 | elif 'relu' in pname: 558 | relu_weights.append(p) 559 | else: 560 | conv_weights.append(p) 561 | 562 | return conv_weights, bn_weights, relu_weights 563 | 564 | def forward(self, x): 565 | H, W = x.size()[2:] 566 | 567 | x = self.init_block(x) 568 | 569 | x1 = self.block1_1(x) 570 | x1 = self.block1_2(x1) 571 | x1 = self.block1_3(x1) 572 | 573 | x2 = self.block2_1(x1) 574 | x2 = self.block2_2(x2) 575 | x2 = self.block2_3(x2) 576 | x2 = self.block2_4(x2) 577 | 578 | x3 = self.block3_1(x2) 579 | x3 = self.block3_2(x3) 580 | x3 = self.block3_3(x3) 581 | x3 = self.block3_4(x3) 582 | 583 | x4 = self.block4_1(x3) 584 | x4 = self.block4_2(x4) 585 | x4 = self.block4_3(x4) 586 | x4 = self.block4_4(x4) 587 | 588 | x_fuses = [] 589 | if self.sa and self.dil is not None: 590 | for i, xi in enumerate([x1, x2, x3, x4]): 591 | x_fuses.append(self.attentions[i](self.dilations[i](xi))) 592 | elif self.sa: 593 | for i, xi in enumerate([x1, x2, x3, x4]): 594 | x_fuses.append(self.attentions[i](xi)) 595 | elif self.dil is not None: 596 | for i, xi in enumerate([x1, x2, x3, x4]): 597 | x_fuses.append(self.dilations[i](xi)) 598 | else: 599 | x_fuses = [x1, x2, x3, x4] 600 | 601 | e1 = self.conv_reduces[0](x_fuses[0]) 602 | e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) 603 | 604 | e2 = self.conv_reduces[1](x_fuses[1]) 605 | e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) 606 | 607 | e3 = self.conv_reduces[2](x_fuses[2]) 608 | e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) 609 | 610 | e4 = self.conv_reduces[3](x_fuses[3]) 611 | e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) 612 | 613 | outputs = [e1, e2, e3, e4] 614 | 615 | output = self.classifier(torch.cat(outputs, dim=1)) 616 | #if not self.training: 617 | # return torch.sigmoid(output) 618 | 619 | outputs.append(output) 620 | outputs = [torch.sigmoid(r) for r in outputs] 621 | return outputs 622 | 623 | def config_model(model): 624 | model_options = list(nets.keys()) 625 | assert model in model_options, \ 626 | 'unrecognized model, please choose from %s' % str(model_options) 627 | 628 | # print(str(nets[model])) 629 | 630 | pdcs = [] 631 | for i in range(16): 632 | layer_name = 'layer%d' % i 633 | op = nets[model][layer_name] 634 | pdcs.append(createConvFunc(op)) 635 | 636 | return pdcs 637 | 638 | def pidinet(): 639 | pdcs = config_model('carv4') 640 | dil = 24 #if args.dil else None 641 | return PiDiNet(60, pdcs, dil=dil, sa=True) 642 | 643 | 644 | if __name__ == '__main__': 645 | model = pidinet() 646 | ckp = torch.load('table5_pidinet.pth')['state_dict'] 647 | model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) 648 | im = cv2.imread('examples/test_my/cat_v4.png') 649 | im = img2tensor(im).unsqueeze(0)/255. 
650 | res = model(im)[-1] 651 | res = res>0.5 652 | res = res.float() 653 | res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8) 654 | print(res.shape) 655 | cv2.imwrite('edge.png', res) 656 | -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__init__.py -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__pycache__/api.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__pycache__/api.cpython-310.pyc -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__pycache__/body.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__pycache__/body.cpython-310.pyc -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/Adapter/extra_condition/openpose/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/api.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch.nn as nn 4 | 5 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 6 | 7 | import cv2 8 | import torch 9 | 10 | from . 
import util 11 | from .body import Body 12 | 13 | remote_model_path = "https://huggingface.co/TencentARC/T2I-Adapter/blob/main/third-party-models/body_pose_model.pth" 14 | 15 | 16 | class OpenposeInference(nn.Module): 17 | 18 | def __init__(self): 19 | super().__init__() 20 | body_modelpath = os.path.join('checkpoints', "body_pose_model.pth") 21 | 22 | if not os.path.exists(body_modelpath): 23 | from basicsr.utils.download_util import load_file_from_url 24 | load_file_from_url(remote_model_path, model_dir='checkpoints') 25 | 26 | self.body_estimation = Body(body_modelpath) 27 | 28 | def forward(self, x): 29 | x = x[:, :, ::-1].copy() 30 | with torch.no_grad(): 31 | candidate, subset = self.body_estimation(x) 32 | # print(candidate.shape) 33 | # print(subset.shape) 34 | canvas = np.zeros_like(x) 35 | canvas = util.draw_bodypose(canvas, candidate, subset) 36 | canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR) 37 | return canvas 38 | 39 | 40 | class OpenposeInference_count(nn.Module): 41 | 42 | def __init__(self): 43 | super().__init__() 44 | body_modelpath = os.path.join('checkpoints', "body_pose_model.pth") 45 | 46 | if not os.path.exists(body_modelpath): 47 | from basicsr.utils.download_util import load_file_from_url 48 | load_file_from_url(remote_model_path, model_dir='checkpoints') 49 | 50 | self.body_estimation = Body(body_modelpath) 51 | 52 | def forward(self, x): 53 | x = x[:, :, ::-1].copy() 54 | with torch.no_grad(): 55 | candidate, subset = self.body_estimation(x) 56 | return subset.shape[0] -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/body.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import time 7 | import torch 8 | from scipy.ndimage.filters import gaussian_filter 9 | from torchvision import transforms 10 | 11 | from . 
import util 12 | from .model import bodypose_model 13 | 14 | 15 | class Body(object): 16 | 17 | def __init__(self, model_path): 18 | self.model = bodypose_model() 19 | if torch.cuda.is_available(): 20 | self.model = self.model.cuda() 21 | print('cuda') 22 | model_dict = util.transfer(self.model, torch.load(model_path)) 23 | self.model.load_state_dict(model_dict) 24 | self.model.eval() 25 | 26 | def __call__(self, oriImg): 27 | # scale_search = [0.5, 1.0, 1.5, 2.0] 28 | scale_search = [0.5] 29 | boxsize = 368 30 | stride = 8 31 | padValue = 128 32 | thre1 = 0.1 33 | thre2 = 0.05 34 | multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] 35 | heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) 36 | paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) 37 | 38 | for m in range(len(multiplier)): 39 | scale = multiplier[m] 40 | imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) 41 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 42 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 43 | im = np.ascontiguousarray(im) 44 | 45 | data = torch.from_numpy(im).float() 46 | if torch.cuda.is_available(): 47 | data = data.cuda() 48 | # data = data.permute([2, 0, 1]).unsqueeze(0).float() 49 | with torch.no_grad(): 50 | Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) 51 | Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() 52 | Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() 53 | 54 | # extract outputs, resize, and remove padding 55 | # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps 56 | heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps 57 | heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) 58 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 59 | heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) 60 | 61 | # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs 62 | paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs 63 | paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) 64 | paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 65 | paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) 66 | 67 | heatmap_avg += heatmap_avg + heatmap / len(multiplier) 68 | paf_avg += +paf / len(multiplier) 69 | 70 | all_peaks = [] 71 | peak_counter = 0 72 | 73 | for part in range(18): 74 | map_ori = heatmap_avg[:, :, part] 75 | one_heatmap = gaussian_filter(map_ori, sigma=3) 76 | 77 | map_left = np.zeros(one_heatmap.shape) 78 | map_left[1:, :] = one_heatmap[:-1, :] 79 | map_right = np.zeros(one_heatmap.shape) 80 | map_right[:-1, :] = one_heatmap[1:, :] 81 | map_up = np.zeros(one_heatmap.shape) 82 | map_up[:, 1:] = one_heatmap[:, :-1] 83 | map_down = np.zeros(one_heatmap.shape) 84 | map_down[:, :-1] = one_heatmap[:, 1:] 85 | 86 | peaks_binary = np.logical_and.reduce((one_heatmap >= map_left, one_heatmap >= map_right, 87 | one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) 88 | peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse 89 | peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks] 90 | 
peak_id = range(peak_counter, peak_counter + len(peaks)) 91 | peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i], ) for i in range(len(peak_id))] 92 | 93 | all_peaks.append(peaks_with_score_and_id) 94 | peak_counter += len(peaks) 95 | 96 | # find connection in the specified sequence, center 29 is in the position 15 97 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 98 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 99 | [1, 16], [16, 18], [3, 17], [6, 18]] 100 | # the middle joints heatmap correpondence 101 | mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ 102 | [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ 103 | [55, 56], [37, 38], [45, 46]] 104 | 105 | connection_all = [] 106 | special_k = [] 107 | mid_num = 10 108 | 109 | for k in range(len(mapIdx)): 110 | score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] 111 | candA = all_peaks[limbSeq[k][0] - 1] 112 | candB = all_peaks[limbSeq[k][1] - 1] 113 | nA = len(candA) 114 | nB = len(candB) 115 | indexA, indexB = limbSeq[k] 116 | if (nA != 0 and nB != 0): 117 | connection_candidate = [] 118 | for i in range(nA): 119 | for j in range(nB): 120 | vec = np.subtract(candB[j][:2], candA[i][:2]) 121 | norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) 122 | norm = max(0.001, norm) 123 | vec = np.divide(vec, norm) 124 | 125 | startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ 126 | np.linspace(candA[i][1], candB[j][1], num=mid_num))) 127 | 128 | vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ 129 | for I in range(len(startend))]) 130 | vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ 131 | for I in range(len(startend))]) 132 | 133 | score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) 134 | score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( 135 | 0.5 * oriImg.shape[0] / norm - 1, 0) 136 | criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) 137 | criterion2 = score_with_dist_prior > 0 138 | if criterion1 and criterion2: 139 | connection_candidate.append( 140 | [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) 141 | 142 | connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) 143 | connection = np.zeros((0, 5)) 144 | for c in range(len(connection_candidate)): 145 | i, j, s = connection_candidate[c][0:3] 146 | if (i not in connection[:, 3] and j not in connection[:, 4]): 147 | connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) 148 | if (len(connection) >= min(nA, nB)): 149 | break 150 | 151 | connection_all.append(connection) 152 | else: 153 | special_k.append(k) 154 | connection_all.append([]) 155 | 156 | # last number in each row is the total parts number of that person 157 | # the second last number in each row is the score of the overall configuration 158 | subset = -1 * np.ones((0, 20)) 159 | candidate = np.array([item for sublist in all_peaks for item in sublist]) 160 | 161 | for k in range(len(mapIdx)): 162 | if k not in special_k: 163 | partAs = connection_all[k][:, 0] 164 | partBs = connection_all[k][:, 1] 165 | indexA, indexB = np.array(limbSeq[k]) - 1 166 | 167 | for i in range(len(connection_all[k])): # = 1:size(temp,1) 168 | found = 0 169 | subset_idx = [-1, -1] 170 | for j in range(len(subset)): # 1:size(subset,1): 171 | if subset[j][indexA] == 
partAs[i] or subset[j][indexB] == partBs[i]: 172 | subset_idx[found] = j 173 | found += 1 174 | 175 | if found == 1: 176 | j = subset_idx[0] 177 | if subset[j][indexB] != partBs[i]: 178 | subset[j][indexB] = partBs[i] 179 | subset[j][-1] += 1 180 | subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] 181 | elif found == 2: # if found 2 and disjoint, merge them 182 | j1, j2 = subset_idx 183 | membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] 184 | if len(np.nonzero(membership == 2)[0]) == 0: # merge 185 | subset[j1][:-2] += (subset[j2][:-2] + 1) 186 | subset[j1][-2:] += subset[j2][-2:] 187 | subset[j1][-2] += connection_all[k][i][2] 188 | subset = np.delete(subset, j2, 0) 189 | else: # as like found == 1 190 | subset[j1][indexB] = partBs[i] 191 | subset[j1][-1] += 1 192 | subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] 193 | 194 | # if find no partA in the subset, create a new subset 195 | elif not found and k < 17: 196 | row = -1 * np.ones(20) 197 | row[indexA] = partAs[i] 198 | row[indexB] = partBs[i] 199 | row[-1] = 2 200 | row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] 201 | subset = np.vstack([subset, row]) 202 | # delete some rows of subset which has few parts occur 203 | deleteIdx = [] 204 | for i in range(len(subset)): 205 | if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: 206 | deleteIdx.append(i) 207 | subset = np.delete(subset, deleteIdx, axis=0) 208 | 209 | # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts 210 | # candidate: x, y, score, id 211 | return candidate, subset 212 | -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/hand.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import math 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import time 8 | import torch 9 | from scipy.ndimage.filters import gaussian_filter 10 | from skimage.measure import label 11 | 12 | from . 
import util 13 | from .model import handpose_model 14 | 15 | 16 | class Hand(object): 17 | 18 | def __init__(self, model_path): 19 | self.model = handpose_model() 20 | if torch.cuda.is_available(): 21 | self.model = self.model.cuda() 22 | print('cuda') 23 | model_dict = util.transfer(self.model, torch.load(model_path)) 24 | self.model.load_state_dict(model_dict) 25 | self.model.eval() 26 | 27 | def __call__(self, oriImg): 28 | scale_search = [0.5, 1.0, 1.5, 2.0] 29 | # scale_search = [0.5] 30 | boxsize = 368 31 | stride = 8 32 | padValue = 128 33 | thre = 0.05 34 | multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] 35 | heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22)) 36 | # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) 37 | 38 | for m in range(len(multiplier)): 39 | scale = multiplier[m] 40 | imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) 41 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 42 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 43 | im = np.ascontiguousarray(im) 44 | 45 | data = torch.from_numpy(im).float() 46 | if torch.cuda.is_available(): 47 | data = data.cuda() 48 | # data = data.permute([2, 0, 1]).unsqueeze(0).float() 49 | with torch.no_grad(): 50 | output = self.model(data).cpu().numpy() 51 | # output = self.model(data).numpy()q 52 | 53 | # extract outputs, resize, and remove padding 54 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps 55 | heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) 56 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 57 | heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) 58 | 59 | heatmap_avg += heatmap / len(multiplier) 60 | 61 | all_peaks = [] 62 | for part in range(21): 63 | map_ori = heatmap_avg[:, :, part] 64 | one_heatmap = gaussian_filter(map_ori, sigma=3) 65 | binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) 66 | # 全部小于阈值 67 | if np.sum(binary) == 0: 68 | all_peaks.append([0, 0]) 69 | continue 70 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) 71 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 72 | label_img[label_img != max_index] = 0 73 | map_ori[label_img == 0] = 0 74 | 75 | y, x = util.npmax(map_ori) 76 | all_peaks.append([x, y]) 77 | return np.array(all_peaks) 78 | -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | 6 | def make_layers(block, no_relu_layers): 7 | layers = [] 8 | for layer_name, v in block.items(): 9 | if 'pool' in layer_name: 10 | layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) 11 | layers.append((layer_name, layer)) 12 | else: 13 | conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4]) 14 | layers.append((layer_name, conv2d)) 15 | if layer_name not in no_relu_layers: 16 | layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) 17 | 18 | return nn.Sequential(OrderedDict(layers)) 19 | 20 | 21 | class bodypose_model(nn.Module): 22 | 23 | def __init__(self): 24 | 
super(bodypose_model, self).__init__() 25 | 26 | # these layers have no relu layer 27 | no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ 28 | 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ 29 | 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ 30 | 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] 31 | blocks = {} 32 | block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]), ('pool1_stage1', [2, 2, 33 | 0]), 34 | ('conv2_1', [64, 128, 3, 1, 1]), ('conv2_2', [128, 128, 3, 1, 1]), 35 | ('pool2_stage1', [2, 2, 0]), ('conv3_1', [128, 256, 3, 1, 1]), 36 | ('conv3_2', [256, 256, 3, 1, 1]), ('conv3_3', [256, 256, 3, 1, 1]), 37 | ('conv3_4', [256, 256, 3, 1, 1]), ('pool3_stage1', [2, 2, 0]), 38 | ('conv4_1', [256, 512, 3, 1, 1]), ('conv4_2', [512, 512, 3, 1, 1]), 39 | ('conv4_3_CPM', [512, 256, 3, 1, 1]), ('conv4_4_CPM', [256, 128, 3, 1, 1])]) 40 | 41 | # Stage 1 42 | block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), 43 | ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), 44 | ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])]) 45 | 46 | block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), 47 | ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), 48 | ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])]) 49 | blocks['block1_1'] = block1_1 50 | blocks['block1_2'] = block1_2 51 | 52 | self.model0 = make_layers(block0, no_relu_layers) 53 | 54 | # Stages 2 - 6 55 | for i in range(2, 7): 56 | blocks['block%d_1' % i] = OrderedDict([('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), 57 | ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), 58 | ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), 59 | ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), 60 | ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), 61 | ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), 62 | ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])]) 63 | 64 | blocks['block%d_2' % i] = OrderedDict([('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), 65 | ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), 66 | ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), 67 | ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), 68 | ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), 69 | ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), 70 | ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])]) 71 | 72 | for k in blocks.keys(): 73 | blocks[k] = make_layers(blocks[k], no_relu_layers) 74 | 75 | self.model1_1 = blocks['block1_1'] 76 | self.model2_1 = blocks['block2_1'] 77 | self.model3_1 = blocks['block3_1'] 78 | self.model4_1 = blocks['block4_1'] 79 | self.model5_1 = blocks['block5_1'] 80 | self.model6_1 = blocks['block6_1'] 81 | 82 | self.model1_2 = blocks['block1_2'] 83 | self.model2_2 = blocks['block2_2'] 84 | self.model3_2 = blocks['block3_2'] 85 | self.model4_2 = blocks['block4_2'] 86 | self.model5_2 = blocks['block5_2'] 87 | self.model6_2 = blocks['block6_2'] 88 | 89 | def forward(self, x): 90 | 91 | out1 = self.model0(x) 92 | 93 | out1_1 = self.model1_1(out1) 94 | out1_2 = self.model1_2(out1) 95 | out2 = torch.cat([out1_1, out1_2, out1], 1) 96 | 97 | out2_1 = self.model2_1(out2) 98 | out2_2 = self.model2_2(out2) 99 | out3 = torch.cat([out2_1, out2_2, out1], 1) 100 | 101 | out3_1 = self.model3_1(out3) 102 | out3_2 = self.model3_2(out3) 103 | out4 = torch.cat([out3_1, out3_2, out1], 1) 104 | 105 | out4_1 = self.model4_1(out4) 106 | out4_2 = self.model4_2(out4) 107 | 
out5 = torch.cat([out4_1, out4_2, out1], 1) 108 | 109 | out5_1 = self.model5_1(out5) 110 | out5_2 = self.model5_2(out5) 111 | out6 = torch.cat([out5_1, out5_2, out1], 1) 112 | 113 | out6_1 = self.model6_1(out6) 114 | out6_2 = self.model6_2(out6) 115 | 116 | return out6_1, out6_2 117 | 118 | 119 | class handpose_model(nn.Module): 120 | 121 | def __init__(self): 122 | super(handpose_model, self).__init__() 123 | 124 | # these layers have no relu layer 125 | no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ 126 | 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] 127 | # stage 1 128 | block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]), 129 | ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]), 130 | ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]), 131 | ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]), 132 | ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]), 133 | ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]), 134 | ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3', [512, 512, 3, 1, 1]), 135 | ('conv4_4', [512, 512, 3, 1, 1]), ('conv5_1', [512, 512, 3, 1, 1]), 136 | ('conv5_2', [512, 512, 3, 1, 1]), ('conv5_3_CPM', [512, 128, 3, 1, 1])]) 137 | 138 | block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]), ('conv6_2_CPM', [512, 22, 1, 1, 0])]) 139 | 140 | blocks = {} 141 | blocks['block1_0'] = block1_0 142 | blocks['block1_1'] = block1_1 143 | 144 | # stage 2-6 145 | for i in range(2, 7): 146 | blocks['block%d' % i] = OrderedDict([('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), 147 | ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), 148 | ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), 149 | ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), 150 | ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), 151 | ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), 152 | ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])]) 153 | 154 | for k in blocks.keys(): 155 | blocks[k] = make_layers(blocks[k], no_relu_layers) 156 | 157 | self.model1_0 = blocks['block1_0'] 158 | self.model1_1 = blocks['block1_1'] 159 | self.model2 = blocks['block2'] 160 | self.model3 = blocks['block3'] 161 | self.model4 = blocks['block4'] 162 | self.model5 = blocks['block5'] 163 | self.model6 = blocks['block6'] 164 | 165 | def forward(self, x): 166 | out1_0 = self.model1_0(x) 167 | out1_1 = self.model1_1(out1_0) 168 | concat_stage2 = torch.cat([out1_1, out1_0], 1) 169 | out_stage2 = self.model2(concat_stage2) 170 | concat_stage3 = torch.cat([out_stage2, out1_0], 1) 171 | out_stage3 = self.model3(concat_stage3) 172 | concat_stage4 = torch.cat([out_stage3, out1_0], 1) 173 | out_stage4 = self.model4(concat_stage4) 174 | concat_stage5 = torch.cat([out_stage4, out1_0], 1) 175 | out_stage5 = self.model5(concat_stage5) 176 | concat_stage6 = torch.cat([out_stage5, out1_0], 1) 177 | out_stage6 = self.model6(concat_stage6) 178 | return out_stage6 179 | -------------------------------------------------------------------------------- /Adapter/extra_condition/openpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 4 | import matplotlib 5 | import numpy as np 6 | 7 | 8 | def padRightDownCorner(img, stride, padValue): 9 | h = img.shape[0] 10 | w = img.shape[1] 11 | 12 | pad = 4 * [None] 13 | pad[0] = 0 # up 14 | pad[1] = 0 # left 15 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 16 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 17 | 18 | 
img_padded = img 19 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) 20 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 21 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) 22 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 23 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) 24 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 25 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) 26 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 27 | 28 | return img_padded, pad 29 | 30 | 31 | # transfer caffe model to pytorch which will match the layer name 32 | def transfer(model, model_weights): 33 | transfered_model_weights = {} 34 | for weights_name in model.state_dict().keys(): 35 | transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] 36 | return transfered_model_weights 37 | 38 | 39 | # draw the body keypoint and lims 40 | def draw_bodypose(canvas, candidate, subset): 41 | stickwidth = 4 42 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 43 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 44 | [1, 16], [16, 18], [3, 17], [6, 18]] 45 | 46 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 47 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 48 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 49 | for i in range(18): 50 | for n in range(len(subset)): 51 | index = int(subset[n][i]) 52 | if index == -1: 53 | continue 54 | x, y = candidate[index][0:2] 55 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) 56 | for i in range(17): 57 | for n in range(len(subset)): 58 | index = subset[n][np.array(limbSeq[i]) - 1] 59 | if -1 in index: 60 | continue 61 | cur_canvas = canvas.copy() 62 | Y = candidate[index.astype(int), 0] 63 | X = candidate[index.astype(int), 1] 64 | mX = np.mean(X) 65 | mY = np.mean(Y) 66 | length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5 67 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 68 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 69 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 70 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 71 | # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]]) 72 | # plt.imshow(canvas[:, :, [2, 1, 0]]) 73 | return canvas 74 | 75 | 76 | # image drawed by opencv is not good. 
77 | def draw_handpose(canvas, all_hand_peaks, show_number=False): 78 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 79 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 80 | 81 | for peaks in all_hand_peaks: 82 | for ie, e in enumerate(edges): 83 | if np.sum(np.all(peaks[e], axis=1) == 0) == 0: 84 | x1, y1 = peaks[e[0]] 85 | x2, y2 = peaks[e[1]] 86 | cv2.line( 87 | canvas, (x1, y1), (x2, y2), 88 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, 89 | thickness=2) 90 | 91 | for i, keyponit in enumerate(peaks): 92 | x, y = keyponit 93 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) 94 | if show_number: 95 | cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA) 96 | return canvas 97 | 98 | 99 | # detect hand according to body pose keypoints 100 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 101 | def handDetect(candidate, subset, oriImg): 102 | # right hand: wrist 4, elbow 3, shoulder 2 103 | # left hand: wrist 7, elbow 6, shoulder 5 104 | ratioWristElbow = 0.33 105 | detect_result = [] 106 | image_height, image_width = oriImg.shape[0:2] 107 | for person in subset.astype(int): 108 | # if any of three not detected 109 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 110 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 111 | if not (has_left or has_right): 112 | continue 113 | hands = [] 114 | #left hand 115 | if has_left: 116 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 117 | x1, y1 = candidate[left_shoulder_index][:2] 118 | x2, y2 = candidate[left_elbow_index][:2] 119 | x3, y3 = candidate[left_wrist_index][:2] 120 | hands.append([x1, y1, x2, y2, x3, y3, True]) 121 | # right hand 122 | if has_right: 123 | right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] 124 | x1, y1 = candidate[right_shoulder_index][:2] 125 | x2, y2 = candidate[right_elbow_index][:2] 126 | x3, y3 = candidate[right_wrist_index][:2] 127 | hands.append([x1, y1, x2, y2, x3, y3, False]) 128 | 129 | for x1, y1, x2, y2, x3, y3, is_left in hands: 130 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 131 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 132 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 133 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 134 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 135 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 136 | x = x3 + ratioWristElbow * (x3 - x2) 137 | y = y3 + ratioWristElbow * (y3 - y2) 138 | distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2) 139 | distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2) 140 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 141 | # x-y refers to the center --> offset to topLeft point 142 | # handRectangle.x -= handRectangle.width / 2.f; 143 | # handRectangle.y -= handRectangle.height / 2.f; 144 | x -= width / 2 145 | y -= width / 2 # width = height 146 | # overflow the image 147 | if x < 0: x = 0 148 | if y < 0: y = 0 149 | width1 = width 150 | width2 = width 151 | if x + width > image_width: width1 = image_width - x 
152 | if y + width > image_height: width2 = image_height - y 153 | width = min(width1, width2) 154 | # the max hand box value is 20 pixels 155 | if width >= 20: 156 | detect_result.append([int(x), int(y), int(width), is_left]) 157 | ''' 158 | return value: [[x, y, w, True if left hand else False]]. 159 | width=height since the network require squared input. 160 | x, y is the coordinate of top left 161 | ''' 162 | return detect_result 163 | 164 | 165 | # get max index of 2d array 166 | def npmax(array): 167 | arrayindex = array.argmax(1) 168 | arrayvalue = array.max(1) 169 | i = arrayvalue.argmax() 170 | j = arrayindex[i] 171 | return i, j 172 | 173 | 174 | def HWC3(x): 175 | assert x.dtype == np.uint8 176 | if x.ndim == 2: 177 | x = x[:, :, None] 178 | assert x.ndim == 3 179 | H, W, C = x.shape 180 | assert C == 1 or C == 3 or C == 4 181 | if C == 3: 182 | return x 183 | if C == 1: 184 | return np.concatenate([x, x, x], axis=2) 185 | if C == 4: 186 | color = x[:, :, 0:3].astype(np.float32) 187 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 188 | y = color * alpha + 255.0 * (1.0 - alpha) 189 | y = y.clip(0, 255).astype(np.uint8) 190 | return y 191 | 192 | 193 | def resize_image(input_image, resolution): 194 | H, W, C = input_image.shape 195 | H = float(H) 196 | W = float(W) 197 | k = float(resolution) / min(H, W) 198 | H *= k 199 | W *= k 200 | H = int(np.round(H / 64.0)) * 64 201 | W = int(np.round(W / 64.0)) * 64 202 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 203 | return img 204 | -------------------------------------------------------------------------------- /Adapter/inference_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from omegaconf import OmegaConf 4 | 5 | DEFAULT_NEGATIVE_PROMPT = 'extra digit, fewer digits, cropped, worst quality, low quality' 6 | 7 | def get_base_argument_parser() -> argparse.ArgumentParser: 8 | """get the base argument parser for inference scripts""" 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | '--outdir', 12 | type=str, 13 | help='dir to write results to', 14 | default=None, 15 | ) 16 | 17 | parser.add_argument( 18 | '--prompt', 19 | type=str, 20 | default='', 21 | help='positive prompt', 22 | ) 23 | 24 | parser.add_argument( 25 | '--neg_prompt', 26 | type=str, 27 | default=DEFAULT_NEGATIVE_PROMPT, 28 | help='negative prompt', 29 | ) 30 | 31 | parser.add_argument( 32 | '--cond_path', 33 | type=str, 34 | default=None, 35 | help='condition image path', 36 | ) 37 | 38 | parser.add_argument( 39 | '--cond_inp_type', 40 | type=str, 41 | default='image', 42 | help='the type of the input condition image, take depth T2I as example, the input can be raw image, ' 43 | 'which depth will be calculated, or the input can be a directly a depth map image', 44 | ) 45 | 46 | parser.add_argument( 47 | '--sampler', 48 | type=str, 49 | default='ddim', 50 | choices=['ddim', 'plms'], 51 | help='sampling algorithm, currently, only ddim and plms are supported, more are on the way', 52 | ) 53 | 54 | parser.add_argument( 55 | '--steps', 56 | type=int, 57 | default=50, 58 | help='number of sampling steps', 59 | ) 60 | 61 | parser.add_argument( 62 | '--max_resolution', 63 | type=float, 64 | default=1024 * 1024, 65 | help='max image height * width, only for computer with limited vram', 66 | ) 67 | 68 | parser.add_argument( 69 | '--resize_short_edge', 70 | type=int, 71 | default=None, 72 | help='resize short edge of the 
input image, if this arg is set, max_resolution will not be used', 73 | ) 74 | 75 | parser.add_argument( 76 | '--C', 77 | type=int, 78 | default=4, 79 | help='latent channels', 80 | ) 81 | 82 | parser.add_argument( 83 | '--f', 84 | type=int, 85 | default=8, 86 | help='downsampling factor', 87 | ) 88 | 89 | parser.add_argument( 90 | '--scale', 91 | type=float, 92 | default=7.5, 93 | help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))', 94 | ) 95 | 96 | parser.add_argument( 97 | '--cond_tau', 98 | type=float, 99 | default=1.0, 100 | help='timestamp parameter that determines until which step the adapter is applied, ' 101 | 'similar as Prompt-to-Prompt tau', 102 | ) 103 | 104 | parser.add_argument( 105 | '--cond_weight', 106 | type=float, 107 | default=1.0, 108 | help='the adapter features are multiplied by the cond_weight. The larger the cond_weight, the more aligned ' 109 | 'the generated image and condition will be, but the generated quality may be reduced', 110 | ) 111 | 112 | parser.add_argument( 113 | '--seed', 114 | type=int, 115 | default=42, 116 | ) 117 | 118 | parser.add_argument( 119 | '--n_samples', 120 | type=int, 121 | default=4, 122 | help='# of samples to generate', 123 | ) 124 | 125 | return parser 126 | 127 | -------------------------------------------------------------------------------- /Adapter/models/adapters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | 6 | def conv_nd(dims, *args, **kwargs): 7 | """ 8 | Create a 1D, 2D, or 3D convolution module. 9 | """ 10 | if dims == 1: 11 | return nn.Conv1d(*args, **kwargs) 12 | elif dims == 2: 13 | return nn.Conv2d(*args, **kwargs) 14 | elif dims == 3: 15 | return nn.Conv3d(*args, **kwargs) 16 | raise ValueError(f"unsupported dimensions: {dims}") 17 | 18 | 19 | def avg_pool_nd(dims, *args, **kwargs): 20 | """ 21 | Create a 1D, 2D, or 3D average pooling module. 22 | """ 23 | if dims == 1: 24 | return nn.AvgPool1d(*args, **kwargs) 25 | elif dims == 2: 26 | return nn.AvgPool2d(*args, **kwargs) 27 | elif dims == 3: 28 | return nn.AvgPool3d(*args, **kwargs) 29 | raise ValueError(f"unsupported dimensions: {dims}") 30 | 31 | def get_parameter_dtype(parameter: torch.nn.Module): 32 | try: 33 | params = tuple(parameter.parameters()) 34 | if len(params) > 0: 35 | return params[0].dtype 36 | 37 | buffers = tuple(parameter.buffers()) 38 | if len(buffers) > 0: 39 | return buffers[0].dtype 40 | 41 | except StopIteration: 42 | # For torch.nn.DataParallel compatibility in PyTorch 1.5 43 | 44 | def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: 45 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 46 | return tuples 47 | 48 | gen = parameter._named_members(get_members_fn=find_tensor_attributes) 49 | first_tuple = next(gen) 50 | return first_tuple[1].dtype 51 | 52 | class Downsample(nn.Module): 53 | """ 54 | A downsampling layer with an optional convolution. 55 | :param channels: channels in the inputs and outputs. 56 | :param use_conv: a bool determining if a convolution is applied. 57 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 58 | downsampling occurs in the inner-two dimensions. 
59 | """ 60 | 61 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 62 | super().__init__() 63 | self.channels = channels 64 | self.out_channels = out_channels or channels 65 | self.use_conv = use_conv 66 | self.dims = dims 67 | stride = 2 if dims != 3 else (1, 2, 2) 68 | if use_conv: 69 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) 70 | else: 71 | assert self.channels == self.out_channels 72 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 73 | 74 | def forward(self, x): 75 | assert x.shape[1] == self.channels 76 | return self.op(x) 77 | 78 | 79 | class ResnetBlock(nn.Module): 80 | 81 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): 82 | super().__init__() 83 | ps = ksize // 2 84 | if in_c != out_c or sk == False: 85 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) 86 | else: 87 | self.in_conv = None 88 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) 89 | self.act = nn.ReLU() 90 | self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps) 91 | if sk == False: 92 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 93 | else: 94 | self.skep = None 95 | 96 | self.down = down 97 | if self.down == True: 98 | self.down_opt = Downsample(in_c, use_conv=use_conv) 99 | 100 | def forward(self, x): 101 | if self.down == True: 102 | x = self.down_opt(x) 103 | if self.in_conv is not None: # edit 104 | x = self.in_conv(x) 105 | 106 | h = self.block1(x) 107 | h = self.act(h) 108 | h = self.block2(h) 109 | if self.skep is not None: 110 | return h + self.skep(x) 111 | else: 112 | return h + x 113 | 114 | 115 | class Adapter_XL(nn.Module): 116 | 117 | def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): 118 | super(Adapter_XL, self).__init__() 119 | self.unshuffle = nn.PixelUnshuffle(16) 120 | self.channels = channels 121 | self.nums_rb = nums_rb 122 | self.body = [] 123 | for i in range(len(channels)): 124 | for j in range(nums_rb): 125 | if (i == 2) and (j == 0): 126 | self.body.append( 127 | ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) 128 | elif (i == 1) and (j == 0): 129 | self.body.append( 130 | ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 131 | else: 132 | self.body.append( 133 | ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 134 | self.body = nn.ModuleList(self.body) 135 | self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1) 136 | 137 | @property 138 | def dtype(self) -> torch.dtype: 139 | """ 140 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
141 | """ 142 | return get_parameter_dtype(self) 143 | 144 | def forward(self, x): 145 | # unshuffle 146 | x = self.unshuffle(x) 147 | # extract features 148 | features = [] 149 | x = self.conv_in(x) 150 | for i in range(len(self.channels)): 151 | for j in range(self.nums_rb): 152 | idx = i * self.nums_rb + j 153 | x = self.body[idx](x) 154 | features.append(x) 155 | 156 | return features 157 | 158 | -------------------------------------------------------------------------------- /Adapter/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | import numpy as np 3 | import cv2 4 | 5 | def import_model_class_from_model_name_or_path( 6 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 7 | ): 8 | text_encoder_config = PretrainedConfig.from_pretrained( 9 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 10 | ) 11 | model_class = text_encoder_config.architectures[0] 12 | 13 | if model_class == "CLIPTextModel": 14 | from transformers import CLIPTextModel 15 | 16 | return CLIPTextModel 17 | elif model_class == "CLIPTextModelWithProjection": 18 | from transformers import CLIPTextModelWithProjection 19 | 20 | return CLIPTextModelWithProjection 21 | else: 22 | raise ValueError(f"{model_class} is not supported.") 23 | 24 | def resize_numpy_image(image, max_resolution=1024 * 1024, resize_short_edge=None): 25 | h, w = image.shape[:2] 26 | if resize_short_edge is not None: 27 | k = resize_short_edge / min(h, w) 28 | else: 29 | k = max_resolution / (h * w) 30 | k = k**0.5 31 | h = int(np.round(h * k / 64)) * 64 32 | w = int(np.round(w * k / 64)) * 64 33 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 34 | return image -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 | 6 | ### 👉 T2I-Adapter for [SD-1.4/1.5], for [SDXL] 7 | 8 | 
 9 | 10 | [![Huggingface T2I-Adapter-SDXL](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/TencentARC/T2I-Adapter-SDXL) [![Blog T2I-Adapter-SDXL](https://img.shields.io/static/v1?label=Blog&message=HuggingFace&color=orange)](https://huggingface.co/blog/t2i-sdxl-adapters) [![arXiv](https://img.shields.io/badge/arXiv-2302.08453-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2302.08453) 11 | 12 | 
13 | 14 | --- 15 | 16 | Official implementation of **[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453)** based on [Stable Diffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). 17 | 18 | The diffusers team and the T2I-Adapter authors have been collaborating to bring the support of T2I-Adapters for Stable Diffusion XL (SDXL) in diffusers! It achieves impressive results in both performance and efficiency. 19 | 20 | --- 21 | ![image](https://github.com/TencentARC/T2I-Adapter/assets/54032224/d249f699-b6d5-461d-9fdf-f0d009f14f4d) 22 | 23 | ## 🚩 **New Features/Updates** 24 | - ✅ Sep. 8, 2023. We collaborate with the diffusers team to bring the support of T2I-Adapters for Stable Diffusion XL (SDXL) in diffusers! It achieves impressive results in both performance and efficiency. We release T2I-Adapter-SDXL models for [sketch](https://huggingface.co/TencentARC/t2i-adapter-sketch-sdxl-1.0), [canny](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0), [lineart](https://huggingface.co/TencentARC/t2i-adapter-lineart-sdxl-1.0), [openpose](https://huggingface.co/TencentARC/t2i-adapter-openpose-sdxl-1.0), [depth-zoe](https://huggingface.co/TencentARC/t2i-adapter-depth-zoe-sdxl-1.0), and [depth-mid](https://huggingface.co/TencentARC/t2i-adapter-depth-midas-sdxl-1.0). We release two online demos: [![Huggingface T2I-Adapter-SDXL](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/TencentARC/T2I-Adapter-SDXL) and [![Huggingface T2I-Adapter-SDXL Doodle](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Doodly%20Demo&color=orange)](https://huggingface.co/spaces/TencentARC/T2I-Adapter-SDXL-Sketch). 25 | - ✅ Aug. 21, 2023. We release [T2I-Adapter-SDXL](https://github.com/TencentARC/T2I-Adapter/), including sketch, canny, and keypoint. We still use the original recipe (77M parameters, a single inference) to drive [StableDiffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). Due to the limited computing resources, those adapters still need further improvement. We are collaborating with [HuggingFace](https://huggingface.co/), and a more powerful adapter is in the works. 26 | 27 | - ✅ Jul. 13, 2023. [Stability AI](https://stability.ai/) release [Stable Doodle](https://stability.ai/blog/clipdrop-launches-stable-doodle), a groundbreaking sketch-to-image tool based on T2I-Adapter and [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9). It makes drawing easier. 28 | 29 | - ✅ Mar. 16, 2023. We add **CoAdapter** (**Co**mposable **Adapter**). The online Huggingface Gadio has been updated [![Huggingface Gradio (CoAdapter)](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/Adapter/CoAdapter). You can also try the [local gradio demo](app_coadapter.py). 30 | - ✅ Mar. 16, 2023. We have shrunk the git repo with [bfg](https://rtyley.github.io/bfg-repo-cleaner/). If you encounter any issues when pulling or pushing, you can try re-cloning the repository. Sorry for the inconvenience. 31 | - ✅ Mar. 3, 2023. Add a [*color adapter (spatial palette)*](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models), which has only **17M parameters**. 32 | - ✅ Mar. 3, 2023. Add four new adapters [*style, color, openpose and canny*](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models). 
See more info in the **[Adapter Zoo](https://github.com/TencentARC/T2I-Adapter/blob/SD/docs/AdapterZoo.md)**. 33 | - ✅ Feb. 23, 2023. Add the depth adapter [*t2iadapter_depth_sd14v1.pth*](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models). See more info in the **[Adapter Zoo](https://github.com/TencentARC/T2I-Adapter/blob/SD/docs/AdapterZoo.md)**. 34 | - ✅ Feb. 15, 2023. Release [T2I-Adapter](https://github.com/TencentARC/T2I-Adapter/tree/SD). 35 | 36 | --- 37 | 38 | # 🔥🔥🔥 Why T2I-Adapter-SDXL? 39 | ## The Original Recipe Drives Larger SD. 40 | 41 | | | SD-V1.4/1.5 | SD-XL | T2I-Adapter | T2I-Adapter-SDXL | 42 | | --- | --- | --- | --- | --- | 43 | | Parameters | 860M | 2.6B | 77 M | 77/79 M | 44 | 45 | ## Inherit High-quality Generation from SDXL. 46 | 47 | - Lineart-guided 48 | 49 | Model from [TencentARC/t2i-adapter-lineart-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-lineart-sdxl-1.0) 50 | 

51 | 52 |

53 | 54 | - Keypoint-guided 55 | 56 | Model from [openpose_sdxl_1.0](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0) 57 |

58 | 59 |

60 | 61 | - Sketch-guided 62 | 63 | Model from [TencentARC/t2i-adapter-sketch-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-sketch-sdxl-1.0) 64 |

65 | 66 |

67 | 68 | - Canny-guided 69 | Model from [TencentARC/t2i-adapter-canny-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0) 70 |

71 | 72 |

73 | 74 | - Depth-guided 75 | 76 | Depth guided models from [TencentARC/t2i-adapter-depth-midas-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-depth-midas-sdxl-1.0) and [TencentARC/t2i-adapter-depth-zoe-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-depth-zoe-sdxl-1.0) respectively 77 |

78 | 79 |

80 | 81 | # 🔧 Dependencies and Installation 82 | 83 | - Python >= 3.8 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 84 | - [PyTorch >= 2.0.1](https://pytorch.org/) 85 | ```bash 86 | pip install -r requirements.txt 87 | ``` 88 | 89 | # ⏬ Download Models 90 | All models will be automatically downloaded. You can also choose to download manually from this [url](https://huggingface.co/TencentARC). 91 | 92 | # 🔥 How to Train 93 | Here we take sketch guidance as an example, but of course, you can also prepare your own dataset following this method. 94 | ```bash 95 | accelerate launch train_sketch.py --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 --output_dir experiments/adapter_sketch_xl --config configs/train/Adapter-XL-sketch.yaml --mixed_precision="fp16" --resolution=1024 --learning_rate=1e-5 --max_train_steps=60000 --train_batch_size=1 --gradient_accumulation_steps=4 --report_to="wandb" --seed=42 --num_train_epochs 100 96 | ``` 97 | 98 | We train with `FP16` data precision on `4` NVIDIA `A100` GPUs. 99 | 100 | # 💻 How to Test 101 | Inference requires at least `15GB` of GPU memory. 102 | 103 | ## Quick start with [diffusers](https://github.com/huggingface/diffusers) 104 | 105 | To get started, first install the required dependencies: 106 | 107 | ```bash 108 | pip install git+https://github.com/huggingface/diffusers.git@t2iadapterxl # for now 109 | pip install -U controlnet_aux==0.0.7 # for conditioning models and detectors 110 | pip install transformers accelerate safetensors 111 | ``` 112 | 113 | 1. Images are first downloaded into the appropriate *control image* format. 114 | 2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py#L125). 115 | 116 | Let's have a look at a simple example using the [LineArt Adapter](https://huggingface.co/TencentARC/t2i-adapter-lineart-sdxl-1.0). 
117 | 118 | - Dependency 119 | ```py 120 | from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL 121 | from diffusers.utils import load_image, make_image_grid 122 | from controlnet_aux.lineart import LineartDetector 123 | import torch 124 | 125 | # load adapter 126 | adapter = T2IAdapter.from_pretrained( 127 | "TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16, varient="fp16" 128 | ).to("cuda") 129 | 130 | # load euler_a scheduler 131 | model_id = 'stabilityai/stable-diffusion-xl-base-1.0' 132 | euler_a = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") 133 | vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) 134 | pipe = StableDiffusionXLAdapterPipeline.from_pretrained( 135 | model_id, vae=vae, adapter=adapter, scheduler=euler_a, torch_dtype=torch.float16, variant="fp16", 136 | ).to("cuda") 137 | pipe.enable_xformers_memory_efficient_attention() 138 | 139 | line_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to("cuda") 140 | ``` 141 | 142 | - Condition Image 143 | ```py 144 | url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_lin.jpg" 145 | image = load_image(url) 146 | image = line_detector( 147 | image, detect_resolution=384, image_resolution=1024 148 | ) 149 | ``` 150 | 151 | 152 | - Generation 153 | ```py 154 | prompt = "Ice dragon roar, 4k photo" 155 | negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 156 | gen_images = pipe( 157 | prompt=prompt, 158 | negative_prompt=negative_prompt, 159 | image=image, 160 | num_inference_steps=30, 161 | adapter_conditioning_scale=0.8, 162 | guidance_scale=7.5, 163 | ).images[0] 164 | gen_images.save('out_lin.png') 165 | ``` 166 | 167 | 168 | ## Online Demo [![Huggingface T2I-Adapter-SDXL](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/TencentARC/T2I-Adapter-SDXL) 169 | 170 | 171 | ## Online Doodly Demo [![Huggingface T2I-Adapter-SDXL](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/TencentARC/T2I-Adapter-SDXL-Sketch) 172 | 173 | 174 | 175 | 176 | # Tutorials on HuggingFace: 177 | - Sketch: [https://huggingface.co/TencentARC/t2i-adapter-sketch-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-sketch-sdxl-1.0) 178 | - Canny: [https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0) 179 | - Lineart: [https://huggingface.co/TencentARC/t2i-adapter-lineart-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-lineart-sdxl-1.0) 180 | - Openpose: [https://huggingface.co/TencentARC/t2i-adapter-openpose-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-openpose-sdxl-1.0) 181 | - Depth-mid: [https://huggingface.co/TencentARC/t2i-adapter-depth-midas-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-depth-midas-sdxl-1.0) 182 | - Depth-zoe: [https://huggingface.co/TencentARC/t2i-adapter-depth-zoe-sdxl-1.0](https://huggingface.co/TencentARC/t2i-adapter-depth-zoe-sdxl-1.0) 183 | 184 | ... 185 | 186 | # Other Source 187 | Jul. 13, 2023. 
[Stability AI](https://stability.ai/) release [Stable Doodle](https://stability.ai/blog/clipdrop-launches-stable-doodle), a groundbreaking sketch-to-image tool based on T2I-Adapter and [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9). It makes drawing easier. 188 | 189 | https://user-images.githubusercontent.com/73707470/253800159-c7e12362-1ea1-4b20-a44e-bd6c8d546765.mp4 190 | 191 | # 🤗 Acknowledgements 192 | - Thanks to HuggingFace for their support of T2I-Adapter. 193 | - T2I-Adapter is co-hosted by Tencent ARC Lab and Peking University [VILLA](https://villa.jianzhang.tech/). 194 | 195 | # BibTeX 196 | 197 | @article{mou2023t2i, 198 | title={T2i-adapter: Learning adapters to dig out more controllable ability for text-to-image diffusion models}, 199 | author={Mou, Chong and Wang, Xintao and Xie, Liangbin and Wu, Yanze and Zhang, Jian and Qi, Zhongang and Shan, Ying and Qie, Xiaohu}, 200 | journal={arXiv preprint arXiv:2302.08453}, 201 | year={2023} 202 | } 203 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gradio as gr 3 | import torch 4 | from basicsr.utils import tensor2img 5 | import os 6 | from huggingface_hub import hf_hub_url 7 | import subprocess 8 | import shlex 9 | import cv2 10 | from omegaconf import OmegaConf 11 | 12 | from demo import create_demo_sketch, create_demo_canny, create_demo_pose 13 | from Adapter.Sampling import diffusion_inference 14 | from configs.utils import instantiate_from_config 15 | from Adapter.extra_condition.api import get_cond_model, ExtraCondition 16 | from Adapter.extra_condition import api 17 | from Adapter.inference_base import get_base_argument_parser 18 | 19 | torch.set_grad_enabled(False) 20 | 21 | urls = { 22 | 'TencentARC/T2I-Adapter':[ 23 | 'models_XL/adapter-xl-canny.pth', 'models_XL/adapter-xl-sketch.pth', 24 | 'models_XL/adapter-xl-openpose.pth', 'third-party-models/body_pose_model.pth', 25 | 'third-party-models/table5_pidinet.pth' 26 | ] 27 | } 28 | 29 | if os.path.exists('checkpoints') == False: 30 | os.mkdir('checkpoints') 31 | for repo in urls: 32 | files = urls[repo] 33 | for file in files: 34 | url = hf_hub_url(repo, file) 35 | name_ckp = url.split('/')[-1] 36 | save_path = os.path.join('checkpoints',name_ckp) 37 | if os.path.exists(save_path) == False: 38 | subprocess.run(shlex.split(f'wget {url} -O {save_path}')) 39 | 40 | parser = get_base_argument_parser() 41 | global_opt = parser.parse_args() 42 | global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 43 | 44 | DESCRIPTION = '# [T2I-Adapter-XL](https://github.com/TencentARC/T2I-Adapter)' 45 | 46 | DESCRIPTION += f'

Gradio demo for **T2I-Adapter-XL**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter). If T2I-Adapter-XL is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊
' 47 | 48 | # DESCRIPTION += f'
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. Duplicate Space
' 49 | 50 | # diffusion sampler creation 51 | sampler = diffusion_inference('stabilityai/stable-diffusion-xl-base-1.0') 52 | 53 | def run(input_image, in_type, prompt, a_prompt, n_prompt, ddim_steps, scale, seed, cond_name, con_strength): 54 | in_type = in_type.lower() 55 | prompt = prompt+', '+a_prompt 56 | config = OmegaConf.load(f'configs/inference/Adapter-XL-{cond_name}.yaml') 57 | # Adapter creation 58 | adapter_config = config.model.params.adapter_config 59 | adapter = instantiate_from_config(adapter_config).cuda() 60 | adapter.load_state_dict(torch.load(config.model.params.adapter_config.pretrained)) 61 | cond_model = get_cond_model(global_opt, getattr(ExtraCondition, cond_name)) 62 | process_cond_module = getattr(api, f'get_cond_{cond_name}') 63 | 64 | # diffusion generation 65 | cond = process_cond_module( 66 | global_opt, 67 | input_image, 68 | cond_inp_type = in_type, 69 | cond_model = cond_model 70 | ) 71 | with torch.no_grad(): 72 | adapter_features = adapter(cond) 73 | 74 | for i in range(len(adapter_features)): 75 | adapter_features[i] = adapter_features[i]*con_strength 76 | 77 | result = sampler.inference( 78 | prompt = prompt, 79 | prompt_n = n_prompt, 80 | steps = ddim_steps, 81 | adapter_features = copy.deepcopy(adapter_features), 82 | guidance_scale = scale, 83 | size = (cond.shape[-2], cond.shape[-1]), 84 | seed= seed, 85 | ) 86 | im_cond = tensor2img(cond) 87 | 88 | return result[:,:,::-1], im_cond 89 | 90 | with gr.Blocks(css='style.css') as demo: 91 | gr.Markdown(DESCRIPTION) 92 | with gr.Tabs(): 93 | with gr.TabItem('Sketch guided'): 94 | create_demo_sketch(run) 95 | with gr.TabItem('Canny guided'): 96 | create_demo_canny(run) 97 | with gr.TabItem('Keypoint guided'): 98 | create_demo_pose(run) 99 | 100 | demo.queue(concurrency_count=3, max_size=20) 101 | demo.launch(server_name="0.0.0.0") 102 | -------------------------------------------------------------------------------- /assets/logo3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/assets/logo3.png -------------------------------------------------------------------------------- /configs/inference/Adapter-XL-canny.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | adapter_config: 4 | name: canny 5 | target: Adapter.models.adapters.Adapter_XL 6 | params: 7 | cin: 256 8 | channels: [320, 640, 1280, 1280] 9 | nums_rb: 2 10 | ksize: 1 11 | sk: true 12 | use_conv: false 13 | pretrained: checkpoints/adapter-xl-canny.pth -------------------------------------------------------------------------------- /configs/inference/Adapter-XL-openpose.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | adapter_config: 4 | name: openpose 5 | target: Adapter.models.adapters.Adapter_XL 6 | params: 7 | cin: 768 8 | channels: [320, 640, 1280, 1280] 9 | nums_rb: 2 10 | ksize: 1 11 | sk: true 12 | use_conv: false 13 | # pretrained: /group/30042/chongmou/ft_local/Diffusion_part2/T2I-Adapter-XL/experiments/adapter_encoder_mid_openpose_extream_ft/checkpoint-9000/model_00.pth 14 | pretrained: checkpoints/adapter-xl-openpose.pth -------------------------------------------------------------------------------- /configs/inference/Adapter-XL-sketch.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | adapter_config: 4 | name: sketch 
5 | target: Adapter.models.adapters.Adapter_XL 6 | params: 7 | cin: 256 8 | channels: [320, 640, 1280, 1280] 9 | nums_rb: 2 10 | ksize: 1 11 | sk: true 12 | use_conv: false 13 | pretrained: checkpoints/adapter-xl-sketch.pth -------------------------------------------------------------------------------- /configs/train/Adapter-XL-sketch.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | adapter_config: 4 | name: sketch 5 | target: Adapter.models.adapters.Adapter_XL 6 | params: 7 | cin: 256 8 | channels: [320, 640, 1280, 1280] 9 | nums_rb: 2 10 | ksize: 1 11 | sk: true 12 | use_conv: false 13 | pretrained: checkpoints/adapter-xl-sketch.pth 14 | data: 15 | target: dataset.dataset_laion.WebDataModuleFromConfig_Laion_Lexica 16 | params: 17 | tar_base1: "data/LAION_6plus" 18 | tar_base2: "data/WebDataset" 19 | batch_size: 2 20 | num_workers: 8 21 | multinode: True 22 | train: 23 | shards1: 'train_{00000..00006}/{00000..00171}.tar' 24 | shards2: 'lexica-{000000..000099}.tar' 25 | shards1_prob: 0.7 26 | shards2_prob: 0.3 27 | shuffle: 10000 28 | image_key: jpg 29 | image_transforms: 30 | - target: torchvision.transforms.Resize 31 | params: 32 | size: 1024 33 | interpolation: 3 34 | - target: torchvision.transforms.RandomCrop 35 | params: 36 | size: 1024 37 | process: 38 | target: dataset.utils.AddEqual_fp16 -------------------------------------------------------------------------------- /configs/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import math 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | 8 | import os 9 | from safetensors.torch import load_file 10 | 11 | from inspect import isfunction 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | 15 | def log_txt_as_img(wh, xc, size=10): 16 | # wh a tuple of (width, height) 17 | # xc a list of captions to plot 18 | b = len(xc) 19 | txts = list() 20 | for bi in range(b): 21 | txt = Image.new("RGB", wh, color="white") 22 | draw = ImageDraw.Draw(txt) 23 | font = ImageFont.truetype('assets/DejaVuSans.ttf', size=size) 24 | nc = int(40 * (wh[0] / 256)) 25 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 26 | 27 | try: 28 | draw.text((0, 0), lines, fill="black", font=font) 29 | except UnicodeEncodeError: 30 | print("Cant encode string for logging. Skipping.") 31 | 32 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 33 | txts.append(txt) 34 | txts = np.stack(txts) 35 | txts = torch.tensor(txts) 36 | return txts 37 | 38 | 39 | def ismap(x): 40 | if not isinstance(x, torch.Tensor): 41 | return False 42 | return (len(x.shape) == 4) and (x.shape[1] > 3) 43 | 44 | 45 | def isimage(x): 46 | if not isinstance(x, torch.Tensor): 47 | return False 48 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 49 | 50 | 51 | def exists(x): 52 | return x is not None 53 | 54 | 55 | def default(val, d): 56 | if exists(val): 57 | return val 58 | return d() if isfunction(d) else d 59 | 60 | 61 | def mean_flat(tensor): 62 | """ 63 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 64 | Take the mean over all non-batch dimensions. 
65 | """ 66 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 67 | 68 | 69 | def count_params(model, verbose=False): 70 | total_params = sum(p.numel() for p in model.parameters()) 71 | if verbose: 72 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 73 | return total_params 74 | 75 | 76 | def instantiate_from_config(config): 77 | if not "target" in config: 78 | if config == '__is_first_stage__': 79 | return None 80 | elif config == "__is_unconditional__": 81 | return None 82 | raise KeyError("Expected key `target` to instantiate.") 83 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 84 | 85 | 86 | def get_obj_from_str(string, reload=False): 87 | module, cls = string.rsplit(".", 1) 88 | if reload: 89 | module_imp = importlib.import_module(module) 90 | importlib.reload(module_imp) 91 | return getattr(importlib.import_module(module, package=None), cls) 92 | 93 | 94 | checkpoint_dict_replacements = { 95 | 'cond_stage_model.transformer.text_model.embeddings.': 'cond_stage_model.transformer.embeddings.', 96 | 'cond_stage_model.transformer.text_model.encoder.': 'cond_stage_model.transformer.encoder.', 97 | 'cond_stage_model.transformer.text_model.final_layer_norm.': 'cond_stage_model.transformer.final_layer_norm.', 98 | } 99 | 100 | 101 | def transform_checkpoint_dict_key(k): 102 | for text, replacement in checkpoint_dict_replacements.items(): 103 | if k.startswith(text): 104 | k = replacement + k[len(text):] 105 | 106 | return k 107 | 108 | 109 | def get_state_dict_from_checkpoint(pl_sd): 110 | pl_sd = pl_sd.pop("state_dict", pl_sd) 111 | pl_sd.pop("state_dict", None) 112 | 113 | sd = {} 114 | for k, v in pl_sd.items(): 115 | new_key = transform_checkpoint_dict_key(k) 116 | 117 | if new_key is not None: 118 | sd[new_key] = v 119 | 120 | pl_sd.clear() 121 | pl_sd.update(sd) 122 | 123 | return pl_sd 124 | 125 | 126 | def read_state_dict(checkpoint_file, print_global_state=False): 127 | _, extension = os.path.splitext(checkpoint_file) 128 | if extension.lower() == ".safetensors": 129 | pl_sd = load_file(checkpoint_file, device='cpu') 130 | else: 131 | pl_sd = torch.load(checkpoint_file, map_location='cpu') 132 | 133 | if print_global_state and "global_step" in pl_sd: 134 | print(f"Global Step: {pl_sd['global_step']}") 135 | 136 | sd = get_state_dict_from_checkpoint(pl_sd) 137 | return sd 138 | 139 | 140 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 141 | print(f"Loading model from {ckpt}") 142 | sd = read_state_dict(ckpt) 143 | model = instantiate_from_config(config.model) 144 | m, u = model.load_state_dict(sd, strict=False) 145 | if len(m) > 0 and verbose: 146 | print("missing keys:") 147 | print(m) 148 | if len(u) > 0 and verbose: 149 | print("unexpected keys:") 150 | print(u) 151 | 152 | if 'anything' in ckpt.lower() and vae_ckpt is None: 153 | vae_ckpt = 'models/anything-v4.0.vae.pt' 154 | 155 | if vae_ckpt is not None and vae_ckpt != 'None': 156 | print(f"Loading vae model from {vae_ckpt}") 157 | vae_sd = torch.load(vae_ckpt, map_location="cpu") 158 | if "global_step" in vae_sd: 159 | print(f"Global Step: {vae_sd['global_step']}") 160 | sd = vae_sd["state_dict"] 161 | m, u = model.first_stage_model.load_state_dict(sd, strict=False) 162 | if len(m) > 0 and verbose: 163 | print("missing keys:") 164 | print(m) 165 | if len(u) > 0 and verbose: 166 | print("unexpected keys:") 167 | print(u) 168 | 169 | model.cuda() 170 | model.eval() 171 | return model 172 | 
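
The inference YAMLs above and `configs/utils.py` meet in `instantiate_from_config`: the `target` string is resolved to a class by `get_obj_from_str`, and the `params` block is passed as constructor keyword arguments, which is exactly how `app.py` builds its adapter. Below is a minimal sketch of that path for the sketch adapter, assuming the checkpoint listed under `pretrained` in the YAML (`checkpoints/adapter-xl-sketch.pth`) has already been downloaded (as `app.py` does from the TencentARC Hugging Face repo) and that a CUDA device is available; it mirrors the calls in `app.py` rather than introducing any new API.

```py
import torch
from omegaconf import OmegaConf

from configs.utils import instantiate_from_config

# Read the inference config; `model.params.adapter_config` carries the
# `target` class path, the constructor `params`, and the checkpoint path.
config = OmegaConf.load('configs/inference/Adapter-XL-sketch.yaml')
adapter_config = config.model.params.adapter_config

# instantiate_from_config -> get_obj_from_str(target)(**params),
# i.e. Adapter.models.adapters.Adapter_XL(cin=256, channels=[320, 640, 1280, 1280], ...)
adapter = instantiate_from_config(adapter_config).cuda()

# Load the pretrained adapter weights referenced by the YAML, then switch to eval mode.
adapter.load_state_dict(torch.load(adapter_config.pretrained))
adapter.eval()
```

The resulting `adapter` is what `app.py` feeds the condition map to (`adapter_features = adapter(cond)`) before passing those features into `diffusion_inference.inference`.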
-------------------------------------------------------------------------------- /dataset/dataset_laion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import os 5 | import pytorch_lightning as pl 6 | import torch 7 | import webdataset as wds 8 | from torchvision.transforms import transforms 9 | 10 | from configs.utils import instantiate_from_config 11 | 12 | 13 | def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True): 14 | """Take a list of samples (as dictionary) and create a batch, preserving the keys. 15 | If `tensors` is True, `ndarray` objects are combined into 16 | tensor batches. 17 | :param dict samples: list of samples 18 | :param bool tensors: whether to turn lists of ndarrays into a single ndarray 19 | :returns: single sample consisting of a batch 20 | :rtype: dict 21 | """ 22 | keys = set.intersection(*[set(sample.keys()) for sample in samples]) 23 | batched = {key: [] for key in keys} 24 | 25 | for s in samples: 26 | [batched[key].append(s[key]) for key in batched] 27 | 28 | result = {} 29 | for key in batched: 30 | if isinstance(batched[key][0], (int, float)): 31 | if combine_scalars: 32 | result[key] = np.array(list(batched[key])) 33 | elif isinstance(batched[key][0], torch.Tensor): 34 | if combine_tensors: 35 | result[key] = torch.stack(list(batched[key])) 36 | elif isinstance(batched[key][0], np.ndarray): 37 | if combine_tensors: 38 | result[key] = np.array(list(batched[key])) 39 | else: 40 | result[key] = list(batched[key]) 41 | return result 42 | 43 | 44 | class WebDataModuleFromConfig_Laion_Lexica(pl.LightningDataModule): 45 | 46 | def __init__(self, 47 | tar_base1, 48 | tar_base2, 49 | batch_size, 50 | train=None, 51 | validation=None, 52 | test=None, 53 | num_workers=4, 54 | multinode=True, 55 | min_size=None, 56 | max_pwatermark=1.0, 57 | **kwargs): 58 | super().__init__() 59 | print(f'Setting tar base to {tar_base1} and {tar_base2}') 60 | self.tar_base1 = tar_base1 61 | self.tar_base2 = tar_base2 62 | self.batch_size = batch_size 63 | self.num_workers = num_workers 64 | self.train = train 65 | self.validation = validation 66 | self.test = test 67 | self.multinode = multinode 68 | self.min_size = min_size # filter out very small images 69 | self.max_pwatermark = max_pwatermark # filter out watermarked images 70 | 71 | def make_loader(self, dataset_config): 72 | image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms] 73 | image_transforms = transforms.Compose(image_transforms) 74 | 75 | process = instantiate_from_config(dataset_config['process']) 76 | 77 | shuffle = dataset_config.get('shuffle', 0) 78 | shardshuffle = shuffle > 0 79 | 80 | nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only 81 | 82 | # make dataset for laion 83 | tars_1 = os.path.join(self.tar_base1, dataset_config.shards1) 84 | 85 | dset1 = wds.WebDataset( 86 | tars_1, nodesplitter=nodesplitter, shardshuffle=shardshuffle, 87 | handler=wds.warn_and_continue).repeat().shuffle(shuffle) 88 | print(f'Loading webdataset with {len(dset1.pipeline[0].urls)} shards.') 89 | 90 | dset1 = ( 91 | dset1.select(self.filter_keys).decode('pil', 92 | handler=wds.warn_and_continue).select(self.filter_size).map_dict( 93 | jpg=image_transforms, handler=wds.warn_and_continue).map(process)) 94 | dset1 = (dset1.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn)) 95 | 96 | # make dataset for lexica 97 | tars_2 
= os.path.join(self.tar_base2, dataset_config.shards2) 98 | 99 | dset2 = wds.WebDataset( 100 | tars_2, nodesplitter=nodesplitter, shardshuffle=shardshuffle, 101 | handler=wds.warn_and_continue).repeat().shuffle(shuffle) 102 | 103 | dset2 = ( 104 | dset2.decode('pil', 105 | handler=wds.warn_and_continue).map_dict(jpg=image_transforms, 106 | handler=wds.warn_and_continue).map(process)) 107 | dset2 = (dset2.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn)) 108 | 109 | # get the corresponding prob 110 | shards1_prob = dataset_config.get('shards1_prob', 0) 111 | shards2_prob = dataset_config.get('shards2_prob', 0) 112 | dataset = wds.RandomMix([dset1, dset2], [shards1_prob, shards2_prob]) 113 | 114 | loader = wds.WebLoader(dataset, batch_size=None, shuffle=False, num_workers=self.num_workers) 115 | 116 | return loader 117 | 118 | def filter_size(self, x): 119 | if self.min_size is None: 120 | return True 121 | try: 122 | return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size and x[ 123 | 'json']['pwatermark'] <= self.max_pwatermark 124 | except Exception: 125 | return False 126 | 127 | def filter_keys(self, x): 128 | try: 129 | return ("jpg" in x) and ("txt" in x) 130 | except Exception: 131 | return False 132 | 133 | def train_dataloader(self): 134 | return self.make_loader(self.train) 135 | 136 | def val_dataloader(self): 137 | return None 138 | 139 | def test_dataloader(self): 140 | return None -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import cv2 4 | import numpy as np 5 | import random 6 | from torchvision.transforms import transforms 7 | from torchvision.transforms.functional import to_tensor 8 | from transformers import CLIPProcessor 9 | from basicsr.utils import img2tensor 10 | import torch 11 | 12 | 13 | class AddEqual_fp16(object): 14 | 15 | def __init__(self): 16 | print('There is no specific process function') 17 | 18 | def __call__(self, sample): 19 | # sample['jpg'] is PIL image 20 | x = sample['jpg'] 21 | sample['jpg'] = to_tensor(x)#.to(torch.float16) 22 | return sample 23 | 24 | 25 | class AddCannyRandomThreshold(object): 26 | 27 | def __init__(self, low_threshold=100, high_threshold=200, shift_range=50): 28 | self.low_threshold = low_threshold 29 | self.high_threshold = high_threshold 30 | self.threshold_prng = np.random.RandomState() 31 | self.shift_range = shift_range 32 | 33 | def __call__(self, sample): 34 | # sample['jpg'] is PIL image 35 | x = sample['jpg'] 36 | img = cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR) 37 | low_threshold = self.low_threshold + self.threshold_prng.randint(-self.shift_range, self.shift_range) 38 | high_threshold = self.high_threshold + self.threshold_prng.randint(-self.shift_range, self.shift_range) 39 | canny = cv2.Canny(img, low_threshold, high_threshold)[..., None] 40 | sample['canny'] = img2tensor(canny, bgr2rgb=True, float32=True) / 255. 
41 | sample['jpg'] = to_tensor(x) 42 | return sample -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | def create_demo_sketch(run): 4 | cond_name = gr.State(value='sketch') 5 | with gr.Blocks() as demo: 6 | with gr.Row(): 7 | gr.Markdown('## Control Stable Diffusion-XL with Sketch Maps') 8 | with gr.Row(): 9 | with gr.Column(): 10 | input_image = gr.Image(source='upload', type='numpy') 11 | prompt = gr.Textbox(label='Prompt') 12 | run_button = gr.Button(label='Run') 13 | in_type = gr.Radio( 14 | choices=["Image", "Sketch"], 15 | label=f"Input type for Sketch", 16 | interactive=True, 17 | value="Image", 18 | ) 19 | with gr.Accordion('Advanced options', open=False): 20 | con_strength = gr.Slider(label='Control Strength', 21 | minimum=0.0, 22 | maximum=1.0, 23 | value=1.0, 24 | step=0.1) 25 | ddim_steps = gr.Slider(label='Steps', 26 | minimum=1, 27 | maximum=100, 28 | value=20, 29 | step=1) 30 | scale = gr.Slider(label='Guidance Scale', 31 | minimum=0.1, 32 | maximum=30.0, 33 | value=7.5, 34 | step=0.1) 35 | seed = gr.Slider(label='Seed', 36 | minimum=-1, 37 | maximum=2147483647, 38 | step=1, 39 | randomize=True) 40 | a_prompt = gr.Textbox( 41 | label='Added Prompt', 42 | value='in real world, high quality') 43 | n_prompt = gr.Textbox( 44 | label='Negative Prompt', 45 | value='extra digit, fewer digits, cropped, worst quality, low quality' 46 | ) 47 | with gr.Column(): 48 | result_gallery = gr.Gallery(label='Output', 49 | show_label=False, 50 | elem_id='gallery').style( 51 | grid=2, height='auto') 52 | ips = [ 53 | input_image, in_type, prompt, a_prompt, n_prompt, 54 | ddim_steps, scale, seed, cond_name, con_strength 55 | ] 56 | run_button.click(fn=run, 57 | inputs=ips, 58 | outputs=[result_gallery], 59 | api_name='sketch') 60 | return demo 61 | 62 | def create_demo_canny(run): 63 | cond_name = gr.State(value='canny') 64 | with gr.Blocks() as demo: 65 | with gr.Row(): 66 | gr.Markdown('## Control Stable Diffusion-XL with Canny Maps') 67 | with gr.Row(): 68 | with gr.Column(): 69 | input_image = gr.Image(source='upload', type='numpy') 70 | prompt = gr.Textbox(label='Prompt') 71 | run_button = gr.Button(label='Run') 72 | in_type = gr.Radio( 73 | choices=["Image", "Canny"], 74 | label=f"Input type for Canny", 75 | interactive=True, 76 | value="Image", 77 | ) 78 | with gr.Accordion('Advanced options', open=False): 79 | con_strength = gr.Slider(label='Control Strength', 80 | minimum=0.0, 81 | maximum=1.0, 82 | value=1.0, 83 | step=0.1) 84 | ddim_steps = gr.Slider(label='Steps', 85 | minimum=1, 86 | maximum=100, 87 | value=20, 88 | step=1) 89 | scale = gr.Slider(label='Guidance Scale', 90 | minimum=0.1, 91 | maximum=30.0, 92 | value=7.5, 93 | step=0.1) 94 | seed = gr.Slider(label='Seed', 95 | minimum=-1, 96 | maximum=2147483647, 97 | step=1, 98 | randomize=True) 99 | a_prompt = gr.Textbox( 100 | label='Added Prompt', 101 | value='in real world, high quality') 102 | n_prompt = gr.Textbox( 103 | label='Negative Prompt', 104 | value='extra digit, fewer digits, cropped, worst quality, low quality' 105 | ) 106 | with gr.Column(): 107 | result_gallery = gr.Gallery(label='Output', 108 | show_label=False, 109 | elem_id='gallery').style( 110 | grid=2, height='auto') 111 | ips = [ 112 | input_image, in_type, prompt, a_prompt, n_prompt, 113 | ddim_steps, scale, seed, cond_name, con_strength 114 | ] 115 | run_button.click(fn=run, 116 | inputs=ips, 117 | 
outputs=[result_gallery], 118 | api_name='canny') 119 | return demo 120 | 121 | def create_demo_pose(run): 122 | cond_name = gr.State(value='openpose') 123 | in_type = gr.State(value='Image') 124 | with gr.Blocks() as demo: 125 | with gr.Row(): 126 | gr.Markdown('## Control Stable Diffusion-XL with Keypoint Maps') 127 | with gr.Row(): 128 | with gr.Column(): 129 | input_image = gr.Image(source='upload', type='numpy') 130 | prompt = gr.Textbox(label='Prompt') 131 | run_button = gr.Button(label='Run') 132 | with gr.Accordion('Advanced options', open=False): 133 | con_strength = gr.Slider(label='Control Strength', 134 | minimum=0.0, 135 | maximum=1.0, 136 | value=1.0, 137 | step=0.1) 138 | ddim_steps = gr.Slider(label='Steps', 139 | minimum=1, 140 | maximum=100, 141 | value=20, 142 | step=1) 143 | scale = gr.Slider(label='Guidance Scale', 144 | minimum=0.1, 145 | maximum=30.0, 146 | value=7.5, 147 | step=0.1) 148 | seed = gr.Slider(label='Seed', 149 | minimum=-1, 150 | maximum=2147483647, 151 | step=1, 152 | randomize=True) 153 | a_prompt = gr.Textbox( 154 | label='Added Prompt', 155 | value='in real world, high quality') 156 | n_prompt = gr.Textbox( 157 | label='Negative Prompt', 158 | value='extra digit, fewer digits, cropped, worst quality, low quality' 159 | ) 160 | with gr.Column(): 161 | result_gallery = gr.Gallery(label='Output', 162 | show_label=False, 163 | elem_id='gallery').style( 164 | grid=2, height='auto') 165 | ips = [ 166 | input_image, in_type, prompt, a_prompt, n_prompt, 167 | ddim_steps, scale, seed, cond_name, con_strength 168 | ] 169 | run_button.click(fn=run, 170 | inputs=ips, 171 | outputs=[result_gallery], 172 | api_name='openpose') 173 | return demo -------------------------------------------------------------------------------- /examples/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/examples/dog.png -------------------------------------------------------------------------------- /examples/people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/examples/people.jpg -------------------------------------------------------------------------------- /examples/room.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/T2I-Adapter/c408b059c36e3f9ce336b66746bd606edaa5483a/examples/room.jpg -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from diffusers.models.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput, logger 2 | import torch 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | class UNet(UNet2DConditionModel): 6 | def forward( 7 | self, 8 | sample: torch.FloatTensor, 9 | timestep: Union[torch.Tensor, float, int], 10 | encoder_hidden_states: torch.Tensor, 11 | class_labels: Optional[torch.Tensor] = None, 12 | timestep_cond: Optional[torch.Tensor] = None, 13 | attention_mask: Optional[torch.Tensor] = None, 14 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 15 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 16 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 17 | mid_block_additional_residual: 
Optional[torch.Tensor] = None, 18 | encoder_attention_mask: Optional[torch.Tensor] = None, 19 | return_dict: bool = True, 20 | ) -> Union[UNet2DConditionOutput, Tuple]: 21 | r""" 22 | The [`UNet2DConditionModel`] forward method. 23 | 24 | Args: 25 | sample (`torch.FloatTensor`): 26 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 27 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 28 | encoder_hidden_states (`torch.FloatTensor`): 29 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 30 | encoder_attention_mask (`torch.Tensor`): 31 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 32 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 33 | which adds large negative values to the attention scores corresponding to "discard" tokens. 34 | return_dict (`bool`, *optional*, defaults to `True`): 35 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 36 | tuple. 37 | cross_attention_kwargs (`dict`, *optional*): 38 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. 39 | added_cond_kwargs: (`dict`, *optional*): 40 | A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that 41 | are passed along to the UNet blocks. 42 | 43 | Returns: 44 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 45 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise 46 | a `tuple` is returned where the first element is the sample tensor. 47 | """ 48 | # By default samples have to be AT least a multiple of the overall upsampling factor. 49 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 50 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 51 | # on the fly if necessary. 52 | default_overall_up_factor = 2**self.num_upsamplers 53 | 54 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 55 | forward_upsample_size = False 56 | upsample_size = None 57 | 58 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 59 | logger.info("Forward upsample size to force interpolation output size.") 60 | forward_upsample_size = True 61 | 62 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 63 | # expects mask of shape: 64 | # [batch, key_tokens] 65 | # adds singleton query_tokens dimension: 66 | # [batch, 1, key_tokens] 67 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 68 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 69 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 70 | if attention_mask is not None: 71 | # assume that mask is expressed as: 72 | # (1 = keep, 0 = discard) 73 | # convert mask into a bias that can be added to attention scores: 74 | # (keep = +0, discard = -10000.0) 75 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 76 | attention_mask = attention_mask.unsqueeze(1) 77 | 78 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 79 | if encoder_attention_mask is not None: 80 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 81 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 82 | 83 | # 0. center input if necessary 84 | if self.config.center_input_sample: 85 | sample = 2 * sample - 1.0 86 | 87 | # 1. time 88 | timesteps = timestep 89 | if not torch.is_tensor(timesteps): 90 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 91 | # This would be a good case for the `match` statement (Python 3.10+) 92 | is_mps = sample.device.type == "mps" 93 | if isinstance(timestep, float): 94 | dtype = torch.float32 if is_mps else torch.float64 95 | else: 96 | dtype = torch.int32 if is_mps else torch.int64 97 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 98 | elif len(timesteps.shape) == 0: 99 | timesteps = timesteps[None].to(sample.device) 100 | 101 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 102 | timesteps = timesteps.expand(sample.shape[0]) 103 | 104 | t_emb = self.time_proj(timesteps) 105 | 106 | # `Timesteps` does not contain any weights and will always return f32 tensors 107 | # but time_embedding might actually be running in fp16. so we need to cast here. 108 | # there might be better ways to encapsulate this. 109 | t_emb = t_emb.to(dtype=sample.dtype) 110 | 111 | emb = self.time_embedding(t_emb, timestep_cond) 112 | aug_emb = None 113 | 114 | if self.class_embedding is not None: 115 | if class_labels is None: 116 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 117 | 118 | if self.config.class_embed_type == "timestep": 119 | class_labels = self.time_proj(class_labels) 120 | 121 | # `Timesteps` does not contain any weights and will always return f32 tensors 122 | # there might be better ways to encapsulate this. 
123 | class_labels = class_labels.to(dtype=sample.dtype) 124 | 125 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 126 | 127 | if self.config.class_embeddings_concat: 128 | emb = torch.cat([emb, class_emb], dim=-1) 129 | else: 130 | emb = emb + class_emb 131 | 132 | if self.config.addition_embed_type == "text": 133 | aug_emb = self.add_embedding(encoder_hidden_states) 134 | elif self.config.addition_embed_type == "text_image": 135 | # Kandinsky 2.1 - style 136 | if "image_embeds" not in added_cond_kwargs: 137 | raise ValueError( 138 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 139 | ) 140 | 141 | image_embs = added_cond_kwargs.get("image_embeds") 142 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) 143 | aug_emb = self.add_embedding(text_embs, image_embs) 144 | elif self.config.addition_embed_type == "text_time": 145 | # SDXL - style 146 | if "text_embeds" not in added_cond_kwargs: 147 | raise ValueError( 148 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 149 | ) 150 | text_embeds = added_cond_kwargs.get("text_embeds") 151 | if "time_ids" not in added_cond_kwargs: 152 | raise ValueError( 153 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 154 | ) 155 | time_ids = added_cond_kwargs.get("time_ids") 156 | time_embeds = self.add_time_proj(time_ids.flatten()) 157 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 158 | 159 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 160 | add_embeds = add_embeds.to(emb.dtype) 161 | aug_emb = self.add_embedding(add_embeds) 162 | elif self.config.addition_embed_type == "image": 163 | # Kandinsky 2.2 - style 164 | if "image_embeds" not in added_cond_kwargs: 165 | raise ValueError( 166 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 167 | ) 168 | image_embs = added_cond_kwargs.get("image_embeds") 169 | aug_emb = self.add_embedding(image_embs) 170 | elif self.config.addition_embed_type == "image_hint": 171 | # Kandinsky 2.2 - style 172 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: 173 | raise ValueError( 174 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" 175 | ) 176 | image_embs = added_cond_kwargs.get("image_embeds") 177 | hint = added_cond_kwargs.get("hint") 178 | aug_emb, hint = self.add_embedding(image_embs, hint) 179 | sample = torch.cat([sample, hint], dim=1) 180 | 181 | emb = emb + aug_emb if aug_emb is not None else emb 182 | 183 | if self.time_embed_act is not None: 184 | emb = self.time_embed_act(emb) 185 | 186 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": 187 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 188 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": 189 | # Kadinsky 2.1 - style 190 | if "image_embeds" not in added_cond_kwargs: 191 | raise ValueError( 192 | f"{self.__class__} has the config param 
`encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 193 | ) 194 | 195 | image_embeds = added_cond_kwargs.get("image_embeds") 196 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) 197 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": 198 | # Kandinsky 2.2 - style 199 | if "image_embeds" not in added_cond_kwargs: 200 | raise ValueError( 201 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 202 | ) 203 | image_embeds = added_cond_kwargs.get("image_embeds") 204 | encoder_hidden_states = self.encoder_hid_proj(image_embeds) 205 | # 2. pre-process 206 | sample = self.conv_in(sample) 207 | 208 | # 3. down 209 | 210 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 211 | is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None 212 | 213 | down_block_res_samples = (sample,) 214 | for downsample_block in self.down_blocks: 215 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 216 | # For t2i-adapter CrossAttnDownBlock2D 217 | additional_residuals = {} 218 | if is_adapter and len(down_block_additional_residuals) > 0: 219 | additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) 220 | 221 | sample, res_samples = downsample_block( 222 | hidden_states=sample, 223 | temb=emb, 224 | encoder_hidden_states=encoder_hidden_states, 225 | attention_mask=attention_mask, 226 | cross_attention_kwargs=cross_attention_kwargs, 227 | encoder_attention_mask=encoder_attention_mask, 228 | **additional_residuals, 229 | ) 230 | else: 231 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 232 | 233 | if is_adapter and len(down_block_additional_residuals) > 0: 234 | sample += down_block_additional_residuals.pop(0) 235 | down_block_res_samples += res_samples 236 | 237 | if is_controlnet: 238 | new_down_block_res_samples = () 239 | 240 | for down_block_res_sample, down_block_additional_residual in zip( 241 | down_block_res_samples, down_block_additional_residuals 242 | ): 243 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 244 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 245 | 246 | down_block_res_samples = new_down_block_res_samples 247 | 248 | # 4. mid 249 | if self.mid_block is not None: 250 | sample = self.mid_block( 251 | sample, 252 | emb, 253 | encoder_hidden_states=encoder_hidden_states, 254 | attention_mask=attention_mask, 255 | cross_attention_kwargs=cross_attention_kwargs, 256 | encoder_attention_mask=encoder_attention_mask, 257 | ) 258 | # only add this two lines to support T2I-Adapter-XL 259 | if is_adapter and len(down_block_additional_residuals) > 0: 260 | sample += down_block_additional_residuals.pop(0) 261 | 262 | if is_controlnet: 263 | sample = sample + mid_block_additional_residual 264 | 265 | # 5. 
up 266 | for i, upsample_block in enumerate(self.up_blocks): 267 | is_final_block = i == len(self.up_blocks) - 1 268 | 269 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 270 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 271 | 272 | # if we have not reached the final block and need to forward the 273 | # upsample size, we do it here 274 | if not is_final_block and forward_upsample_size: 275 | upsample_size = down_block_res_samples[-1].shape[2:] 276 | 277 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 278 | sample = upsample_block( 279 | hidden_states=sample, 280 | temb=emb, 281 | res_hidden_states_tuple=res_samples, 282 | encoder_hidden_states=encoder_hidden_states, 283 | cross_attention_kwargs=cross_attention_kwargs, 284 | upsample_size=upsample_size, 285 | attention_mask=attention_mask, 286 | encoder_attention_mask=encoder_attention_mask, 287 | ) 288 | else: 289 | sample = upsample_block( 290 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 291 | ) 292 | 293 | # 6. post-process 294 | if self.conv_norm_out: 295 | sample = self.conv_norm_out(sample) 296 | sample = self.conv_act(sample) 297 | sample = self.conv_out(sample) 298 | 299 | if not return_dict: 300 | return (sample,) 301 | 302 | return UNet2DConditionOutput(sample=sample) 303 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | diffusers==0.19.3 3 | omegaconf 4 | transformers 5 | datasets 6 | pytorch_lightning 7 | gradio 8 | accelerate -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL 2 | from diffusers.utils import load_image, make_image_grid 3 | from controlnet_aux.lineart import LineartDetector 4 | import torch 5 | 6 | if __name__ == '__main__': 7 | # load adapter 8 | adapter = T2IAdapter.from_pretrained( 9 | "TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16, variant="fp16" 10 | ).to("cuda") 11 | 12 | # load euler_a scheduler 13 | model_id = 'stabilityai/stable-diffusion-xl-base-1.0' 14 | euler_a = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") 15 | vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) 16 | pipe = StableDiffusionXLAdapterPipeline.from_pretrained( 17 | model_id, vae=vae, adapter=adapter, scheduler=euler_a, torch_dtype=torch.float16, variant="fp16", 18 | ).to("cuda") 19 | pipe.enable_xformers_memory_efficient_attention() 20 | 21 | line_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to("cuda") 22 | 23 | url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_lin.jpg" 24 | image = load_image(url) 25 | image = line_detector( 26 | image, detect_resolution=384, image_resolution=1024 27 | ) 28 | 29 | prompt = "Ice dragon roar, 4k photo" 30 | negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 31 | gen_images = pipe( 32 | prompt=prompt, 33 | negative_prompt=negative_prompt, 34 | image=image, 35 | num_inference_steps=30, 36 | adapter_conditioning_scale=0.8, 37 |
guidance_scale=7.5, 38 | ).images[0] 39 | gen_images.save('out_lin.png') 40 | -------------------------------------------------------------------------------- /train_sketch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import functools 18 | import gc 19 | import logging 20 | import math 21 | import os 22 | import random 23 | import shutil 24 | from pathlib import Path 25 | 26 | import accelerate 27 | import numpy as np 28 | import torch 29 | import torch.nn.functional as F 30 | import torch.utils.checkpoint 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import ProjectConfiguration, set_seed 35 | from datasets import load_dataset 36 | from huggingface_hub import create_repo, upload_folder 37 | from packaging import version 38 | from PIL import Image 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import AutoTokenizer, PretrainedConfig 42 | 43 | import diffusers 44 | from diffusers import ( 45 | AutoencoderKL, 46 | DDPMScheduler, 47 | UNet2DConditionModel, 48 | UniPCMultistepScheduler, 49 | ) 50 | from diffusers.optimization import get_scheduler 51 | from diffusers.utils import check_min_version, is_wandb_available 52 | from diffusers.utils.import_utils import is_xformers_available 53 | 54 | from configs.utils import instantiate_from_config 55 | from omegaconf import OmegaConf 56 | from Adapter.extra_condition.model_edge import pidinet 57 | from models.unet import UNet 58 | from basicsr.utils import tensor2img 59 | import cv2 60 | from huggingface_hub import hf_hub_url 61 | import subprocess 62 | import shlex 63 | 64 | urls = { 65 | 'TencentARC/T2I-Adapter':[ 66 | 'third-party-models/body_pose_model.pth', 'third-party-models/table5_pidinet.pth' 67 | ] 68 | } 69 | 70 | if os.path.exists('checkpoints') == False: 71 | os.mkdir('checkpoints') 72 | for repo in urls: 73 | files = urls[repo] 74 | for file in files: 75 | url = hf_hub_url(repo, file) 76 | name_ckp = url.split('/')[-1] 77 | save_path = os.path.join('checkpoints',name_ckp) 78 | if os.path.exists(save_path) == False: 79 | subprocess.run(shlex.split(f'wget {url} -O {save_path}')) 80 | 81 | 82 | if is_wandb_available(): 83 | import wandb 84 | 85 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
86 | # check_min_version("0.20.0.dev0") 87 | 88 | logger = get_logger(__name__) 89 | 90 | def import_model_class_from_model_name_or_path( 91 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 92 | ): 93 | text_encoder_config = PretrainedConfig.from_pretrained( 94 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 95 | ) 96 | model_class = text_encoder_config.architectures[0] 97 | 98 | if model_class == "CLIPTextModel": 99 | from transformers import CLIPTextModel 100 | 101 | return CLIPTextModel 102 | elif model_class == "CLIPTextModelWithProjection": 103 | from transformers import CLIPTextModelWithProjection 104 | 105 | return CLIPTextModelWithProjection 106 | else: 107 | raise ValueError(f"{model_class} is not supported.") 108 | 109 | 110 | def parse_args(input_args=None): 111 | parser = argparse.ArgumentParser(description="Simple example of a T2I-Adapter training script.") 112 | parser.add_argument( 113 | "--pretrained_model_name_or_path", 114 | type=str, 115 | default=None, 116 | required=True, 117 | help="Path to pretrained model or model identifier from huggingface.co/models.", 118 | ) 119 | parser.add_argument( 120 | "--pretrained_vae_model_name_or_path", 121 | type=str, 122 | default=None, 123 | help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 124 | ) 125 | parser.add_argument( 126 | "--revision", 127 | type=str, 128 | default=None, 129 | required=False, 130 | help=( 131 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" 132 | " float32 precision." 133 | ), 134 | ) 135 | parser.add_argument( 136 | "--tokenizer_name", 137 | type=str, 138 | default=None, 139 | help="Pretrained tokenizer name or path if not the same as model_name", 140 | ) 141 | parser.add_argument( 142 | "--output_dir", 143 | type=str, 144 | default="experiments/adapter_xl_sketch", 145 | help="The output directory where the model predictions and checkpoints will be written.", 146 | ) 147 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 148 | parser.add_argument( 149 | "--resolution", 150 | type=int, 151 | default=1024, 152 | help=( 153 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 154 | " resolution" 155 | ), 156 | ) 157 | parser.add_argument( 158 | "--crops_coords_top_left_h", 159 | type=int, 160 | default=0, 161 | help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), 162 | ) 163 | parser.add_argument( 164 | "--crops_coords_top_left_w", 165 | type=int, 166 | default=0, 167 | help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), 168 | ) 169 | parser.add_argument( 170 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 171 | ) 172 | parser.add_argument("--num_train_epochs", type=int, default=1) 173 | parser.add_argument( 174 | "--max_train_steps", 175 | type=int, 176 | default=None, 177 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 178 | ) 179 | parser.add_argument( 180 | "--checkpointing_steps", 181 | type=int, 182 | default=1000, 183 | help=( 184 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
" 185 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 186 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 187 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 188 | "instructions." 189 | ), 190 | ) 191 | parser.add_argument( 192 | "--gradient_accumulation_steps", 193 | type=int, 194 | default=1, 195 | help="Number of updates steps to accumulate before performing a backward/update pass.", 196 | ) 197 | parser.add_argument( 198 | "--gradient_checkpointing", 199 | action="store_true", 200 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 201 | ) 202 | parser.add_argument( 203 | "--learning_rate", 204 | type=float, 205 | default=5e-6, 206 | help="Initial learning rate (after the potential warmup period) to use.", 207 | ) 208 | parser.add_argument( 209 | "--scale_lr", 210 | action="store_true", 211 | default=False, 212 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 213 | ) 214 | parser.add_argument( 215 | "--config", 216 | type=str, 217 | default="configs/train/Adapter-XL-sketch.yaml", 218 | help=('config to load the train model and dataset'), 219 | ) 220 | parser.add_argument( 221 | "--lr_scheduler", 222 | type=str, 223 | default="constant", 224 | help=( 225 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 226 | ' "constant", "constant_with_warmup"]' 227 | ), 228 | ) 229 | parser.add_argument( 230 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 231 | ) 232 | parser.add_argument( 233 | "--lr_num_cycles", 234 | type=int, 235 | default=1, 236 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 237 | ) 238 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 239 | parser.add_argument( 240 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 241 | ) 242 | parser.add_argument( 243 | "--dataloader_num_workers", 244 | type=int, 245 | default=0, 246 | help=( 247 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 248 | ), 249 | ) 250 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 251 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 252 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 253 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 254 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 255 | parser.add_argument( 256 | "--logging_dir", 257 | type=str, 258 | default="logs", 259 | help=( 260 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 261 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 262 | ), 263 | ) 264 | parser.add_argument( 265 | "--allow_tf32", 266 | action="store_true", 267 | help=( 268 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 269 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 270 | ), 271 | ) 272 | parser.add_argument( 273 | "--report_to", 274 | type=str, 275 | default="tensorboard", 276 | help=( 277 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 278 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 279 | ), 280 | ) 281 | parser.add_argument( 282 | "--mixed_precision", 283 | type=str, 284 | default=None, 285 | choices=["no", "fp16", "bf16"], 286 | help=( 287 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 288 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 289 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 290 | ), 291 | ) 292 | parser.add_argument( 293 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 294 | ) 295 | parser.add_argument( 296 | "--set_grads_to_none", 297 | action="store_true", 298 | help=( 299 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 300 | " behaviors, so disable this argument if it causes any problems. More info:" 301 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 302 | ), 303 | ) 304 | parser.add_argument( 305 | "--proportion_empty_prompts", 306 | type=float, 307 | default=0, 308 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 309 | ) 310 | parser.add_argument( 311 | "--tracker_project_name", 312 | type=str, 313 | default="sd_xl_train_t2i_adapter ", 314 | help=( 315 | "The `project_name` argument passed to Accelerator.init_trackers for" 316 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 317 | ), 318 | ) 319 | 320 | if input_args is not None: 321 | args = parser.parse_args(input_args) 322 | else: 323 | args = parser.parse_args() 324 | 325 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 326 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 327 | 328 | if args.resolution % 8 != 0: 329 | raise ValueError( 330 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and T2I-Adapter." 
331 | ) 332 | 333 | return args 334 | 335 | 336 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 337 | def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): 338 | prompt_embeds_list = [] 339 | 340 | captions = [] 341 | for caption in prompt_batch: 342 | if random.random() < proportion_empty_prompts: 343 | captions.append("") 344 | elif isinstance(caption, str): 345 | captions.append(caption) 346 | elif isinstance(caption, (list, np.ndarray)): 347 | # take a random caption if there are multiple 348 | captions.append(random.choice(caption) if is_train else caption[0]) 349 | 350 | with torch.no_grad(): 351 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 352 | text_inputs = tokenizer( 353 | captions, 354 | padding="max_length", 355 | max_length=tokenizer.model_max_length, 356 | truncation=True, 357 | return_tensors="pt", 358 | ) 359 | text_input_ids = text_inputs.input_ids 360 | prompt_embeds = text_encoder( 361 | text_input_ids.to(text_encoder.device), 362 | output_hidden_states=True, 363 | ) 364 | 365 | # We are only ALWAYS interested in the pooled output of the final text encoder 366 | pooled_prompt_embeds = prompt_embeds[0] 367 | prompt_embeds = prompt_embeds.hidden_states[-2] 368 | bs_embed, seq_len, _ = prompt_embeds.shape 369 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 370 | prompt_embeds_list.append(prompt_embeds) 371 | 372 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 373 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 374 | return prompt_embeds, pooled_prompt_embeds 375 | 376 | 377 | def random_threshold(edge, low_threshold=0.3, high_threshold=0.8): 378 | threshold = round(random.uniform(low_threshold, high_threshold), 1) 379 | edge = edge > threshold 380 | return edge 381 | 382 | 383 | def main(args): 384 | logging_dir = Path(args.output_dir, args.logging_dir) 385 | 386 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 387 | 388 | accelerator = Accelerator( 389 | gradient_accumulation_steps=args.gradient_accumulation_steps, 390 | mixed_precision=args.mixed_precision, 391 | log_with=args.report_to, 392 | project_config=accelerator_project_config, 393 | ) 394 | 395 | # Make one log on every process with the configuration for debugging. 396 | logging.basicConfig( 397 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 398 | datefmt="%m/%d/%Y %H:%M:%S", 399 | level=logging.INFO, 400 | ) 401 | logger.info(accelerator.state, main_process_only=False) 402 | if accelerator.is_local_main_process: 403 | transformers.utils.logging.set_verbosity_warning() 404 | diffusers.utils.logging.set_verbosity_info() 405 | else: 406 | transformers.utils.logging.set_verbosity_error() 407 | diffusers.utils.logging.set_verbosity_error() 408 | 409 | # If passed along, set the training seed now. 
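# set_seed comes from accelerate.utils and seeds Python's random module, NumPy and torch
# (including CUDA) on every process, so a fixed --seed makes runs reproducible.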
410 | if args.seed is not None: 411 | set_seed(args.seed) 412 | 413 | # Handle the repository creation 414 | if accelerator.is_main_process: 415 | if args.output_dir is not None: 416 | os.makedirs(args.output_dir, exist_ok=True) 417 | 418 | # Load the tokenizers 419 | tokenizer_one = AutoTokenizer.from_pretrained( 420 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False 421 | ) 422 | tokenizer_two = AutoTokenizer.from_pretrained( 423 | args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False 424 | ) 425 | 426 | # import correct text encoder classes 427 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 428 | args.pretrained_model_name_or_path, args.revision 429 | ) 430 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 431 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 432 | ) 433 | 434 | # Load scheduler and models 435 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 436 | text_encoder_one = text_encoder_cls_one.from_pretrained( 437 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 438 | ) 439 | text_encoder_two = text_encoder_cls_two.from_pretrained( 440 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision 441 | ) 442 | vae_path = ( 443 | args.pretrained_model_name_or_path 444 | if args.pretrained_vae_model_name_or_path is None 445 | else args.pretrained_vae_model_name_or_path 446 | ) 447 | vae = AutoencoderKL.from_pretrained( 448 | vae_path, 449 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 450 | revision=args.revision, 451 | ) 452 | unet = UNet.from_pretrained( 453 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 454 | ) 455 | 456 | # `accelerate` 0.16.0 will have better support for customized saving 457 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 458 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 459 | def save_model_hook(models, weights, output_dir): 460 | i = len(weights) - 1 461 | 462 | while len(weights) > 0: 463 | weights.pop() 464 | model = models[i] 465 | torch.save(model.state_dict(), os.path.join(output_dir, 'model_%02d.pth'%i)) 466 | i -= 1 467 | 468 | accelerator.register_save_state_pre_hook(save_model_hook) 469 | 470 | vae.requires_grad_(False) 471 | text_encoder_one.requires_grad_(False) 472 | text_encoder_two.requires_grad_(False) 473 | 474 | if args.enable_xformers_memory_efficient_attention: 475 | if is_xformers_available(): 476 | import xformers 477 | 478 | xformers_version = version.parse(xformers.__version__) 479 | if xformers_version == version.parse("0.0.16"): 480 | logger.warn( 481 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 482 | ) 483 | unet.enable_xformers_memory_efficient_attention() 484 | else: 485 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 486 | 487 | if args.gradient_checkpointing: 488 | unet.enable_gradient_checkpointing() 489 | 490 | # Enable TF32 for faster training on Ampere GPUs, 491 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 492 | if args.allow_tf32: 493 | torch.backends.cuda.matmul.allow_tf32 = True 494 | 495 | if args.scale_lr: 496 | args.learning_rate = ( 497 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 498 | ) 499 | 500 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 501 | if args.use_8bit_adam: 502 | try: 503 | import bitsandbytes as bnb 504 | except ImportError: 505 | raise ImportError( 506 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 507 | ) 508 | 509 | optimizer_class = bnb.optim.AdamW8bit 510 | else: 511 | optimizer_class = torch.optim.AdamW 512 | 513 | # configs 514 | config = OmegaConf.load(args.config) 515 | # Optimizer creation 516 | adapter_config = config.model.params.adapter_config 517 | adapter = instantiate_from_config(adapter_config).cuda() 518 | params_to_optimize = adapter.parameters() 519 | optimizer = optimizer_class( 520 | params_to_optimize, 521 | lr=args.learning_rate, 522 | betas=(args.adam_beta1, args.adam_beta2), 523 | weight_decay=args.adam_weight_decay, 524 | eps=args.adam_epsilon, 525 | ) 526 | # load sketch model and freeze it (it is only used to extract edges) 527 | sketch_model = pidinet() 528 | ckp = torch.load('checkpoints/table5_pidinet.pth', map_location='cpu')['state_dict'] 529 | sketch_model.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()}, strict=True) 530 | sketch_model = sketch_model.cuda() 531 | for param in sketch_model.parameters(): 532 | param.requires_grad = False 533 | 534 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 535 | # as these models are only used for inference, keeping weights in full precision is not required. 536 | weight_dtype = torch.float32 537 | if accelerator.mixed_precision == "fp16": 538 | weight_dtype = torch.float16 539 | elif accelerator.mixed_precision == "bf16": 540 | weight_dtype = torch.bfloat16 541 | 542 | # Move vae, unet and text_encoder to device and cast to weight_dtype 543 | # The VAE is in float32 to avoid NaN losses. 544 | if args.pretrained_vae_model_name_or_path is not None: 545 | vae.to(accelerator.device, dtype=weight_dtype) 546 | else: 547 | vae.to(accelerator.device, dtype=torch.float32) 548 | unet.to(accelerator.device, dtype=weight_dtype) 549 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 550 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 551 | 552 | # Here, we compute not just the text embeddings but also the additional embeddings 553 | # needed for the SD XL UNet to operate.
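# SDXL is conditioned not only on the prompt embeddings but also on a pooled text embedding
# ("text_embeds") and six "time_ids" (original size, crop top-left, target size), which the
# custom UNet projects through add_time_proj and fuses with the time embedding (see the
# addition_embed_type == "text_time" branch in models/unet.py). compute_embeddings below
# assembles exactly these extra inputs; e.g. with --resolution 1024 and the default crop
# offsets the time_ids are [1024, 1024, 0, 0, 1024, 1024].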
554 | def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): 555 | original_size = (args.resolution, args.resolution) 556 | target_size = (args.resolution, args.resolution) 557 | crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) 558 | prompt_batch = batch['txt'] 559 | 560 | prompt_embeds, pooled_prompt_embeds = encode_prompt( 561 | prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train 562 | ) 563 | add_text_embeds = pooled_prompt_embeds 564 | 565 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 566 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 567 | add_time_ids = torch.tensor([add_time_ids]) 568 | 569 | prompt_embeds = prompt_embeds.to(accelerator.device) 570 | add_text_embeds = add_text_embeds.to(accelerator.device) 571 | add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) 572 | add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) 573 | unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 574 | 575 | return {"prompt_embeds": prompt_embeds}, unet_added_cond_kwargs#, **unet_added_cond_kwargs} 576 | 577 | # Let's first compute all the embeddings so that we can free up the text encoders 578 | # from memory. 579 | text_encoders = [text_encoder_one, text_encoder_two] 580 | tokenizers = [tokenizer_one, tokenizer_two] 581 | gc.collect() 582 | torch.cuda.empty_cache() 583 | 584 | # data 585 | data = instantiate_from_config(config.data) 586 | train_dataloader = data.train_dataloader() 587 | 588 | # Scheduler and math around the number of training steps. 589 | overrode_max_train_steps = False 590 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 591 | num_update_steps_per_epoch = math.ceil(1e7 / args.gradient_accumulation_steps) 592 | if args.max_train_steps is None: 593 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 594 | overrode_max_train_steps = True 595 | 596 | lr_scheduler = get_scheduler( 597 | args.lr_scheduler, 598 | optimizer=optimizer, 599 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 600 | num_training_steps=args.max_train_steps * accelerator.num_processes, 601 | num_cycles=args.lr_num_cycles, 602 | power=args.lr_power, 603 | ) 604 | 605 | # Prepare everything with our `accelerator`. 606 | adapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 607 | adapter, optimizer, train_dataloader, lr_scheduler 608 | ) 609 | 610 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 611 | # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 612 | # if overrode_max_train_steps: 613 | # args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 614 | # # Afterwards we recalculate our number of training epochs 615 | # args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 616 | 617 | # We need to initialize the trackers we use, and also store our configuration. 618 | # The trackers initializes automatically on the main process. 619 | if accelerator.is_main_process: 620 | tracker_config = dict(vars(args)) 621 | 622 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 623 | 624 | # Train! 
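# The effective batch size below is per-device batch size x number of processes x gradient
# accumulation steps; e.g. --train_batch_size 4 on 8 GPUs with --gradient_accumulation_steps 2
# corresponds to 64 samples per optimizer update.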
625 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 626 | 627 | logger.info("***** Running training *****") 628 | logger.info(f" Num Epochs = {args.num_train_epochs}") 629 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 630 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 631 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 632 | logger.info(f" Total optimization steps = {args.max_train_steps}") 633 | global_step = 0 634 | first_epoch = 0 635 | 636 | initial_global_step = 0 637 | 638 | progress_bar = tqdm( 639 | range(0, args.max_train_steps), 640 | initial=initial_global_step, 641 | desc="Steps", 642 | # Only show the progress bar once on each machine. 643 | disable=not accelerator.is_local_main_process, 644 | ) 645 | 646 | image_logs = None 647 | for epoch in range(first_epoch, args.num_train_epochs): 648 | for step, batch in enumerate(train_dataloader): 649 | with accelerator.accumulate(adapter): 650 | # norm input 651 | batch["jpg"] = batch["jpg"].cuda() 652 | batch["jpg"] = batch["jpg"]*2.-1. 653 | # get sketch 654 | edge = 0.5 * batch['jpg'] + 0.5 655 | edge = sketch_model(edge)[-1] 656 | # add random threshold and random masking 657 | edge = random_threshold(edge).to(dtype=weight_dtype) 658 | 659 | # Convert images to latent space 660 | if args.pretrained_vae_model_name_or_path is not None: 661 | pixel_values = batch["jpg"].to(dtype=weight_dtype) 662 | else: 663 | pixel_values = batch["jpg"] 664 | latents = vae.encode(pixel_values).latent_dist.sample() 665 | latents = latents * vae.config.scaling_factor 666 | if args.pretrained_vae_model_name_or_path is None: 667 | latents = latents.to(weight_dtype) 668 | 669 | # Sample noise that we'll add to the latents 670 | noise = torch.randn_like(latents) 671 | bsz = latents.shape[0] 672 | 673 | # Cubic sampling to sample a random timestep for each image 674 | timesteps = torch.rand((bsz, ), device=latents.device) 675 | timesteps = (1 - timesteps**3) * noise_scheduler.config.num_train_timesteps 676 | timesteps = timesteps.long() 677 | 678 | # Add noise to the latents according to the noise magnitude at each timestep 679 | # (this is the forward diffusion process) 680 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 681 | 682 | # get text embedding 683 | prompt_embeds, unet_added_cond_kwargs = compute_embeddings( 684 | batch=batch,proportion_empty_prompts=0,text_encoders=text_encoders,tokenizers=tokenizers 685 | ) 686 | 687 | # Adapter conditioning. 
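# The adapter turns the thresholded sketch map into a list of multi-scale feature residuals.
# The customized UNet consumes them one by one while any remain: each cross-attention down block
# receives one entry via `additional_residuals`, plain down blocks add one to their output, and a
# final entry is added after the mid block (see the is_adapter branches in models/unet.py).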
688 | down_block_additional_residuals = adapter( 689 | edge 690 | ) 691 | 692 | # Predict the noise residual 693 | model_pred = unet( 694 | noisy_latents, 695 | timesteps, 696 | encoder_hidden_states=prompt_embeds["prompt_embeds"], 697 | added_cond_kwargs=unet_added_cond_kwargs, 698 | down_block_additional_residuals=[ 699 | sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals 700 | ] 701 | # down_block_additional_residuals, 702 | ).sample 703 | 704 | # Get the target for loss depending on the prediction type 705 | if noise_scheduler.config.prediction_type == "epsilon": 706 | target = noise 707 | elif noise_scheduler.config.prediction_type == "v_prediction": 708 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 709 | else: 710 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 711 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 712 | 713 | accelerator.backward(loss) 714 | if accelerator.sync_gradients: 715 | params_to_clip = adapter.parameters() 716 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 717 | optimizer.step() 718 | lr_scheduler.step() 719 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 720 | 721 | 722 | # Checks if the accelerator has performed an optimization step behind the scenes 723 | if accelerator.sync_gradients: 724 | progress_bar.update(1) 725 | global_step += 1 726 | 727 | if accelerator.is_main_process: 728 | if global_step % args.checkpointing_steps == 0: 729 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 730 | accelerator.save_state(save_path) 731 | logger.info(f"Saved state to {save_path}") 732 | 733 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 734 | progress_bar.set_postfix(**logs) 735 | accelerator.log(logs, step=global_step) 736 | 737 | if global_step >= args.max_train_steps: 738 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 739 | accelerator.save_state(save_path) 740 | logger.info(f"Saved state to {save_path}") 741 | break 742 | 743 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 744 | accelerator.save_state(save_path) 745 | logger.info(f"Saved state to {save_path}") 746 | 747 | 748 | if __name__ == "__main__": 749 | args = parse_args() 750 | main(args) 751 | --------------------------------------------------------------------------------
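A minimal sketch of how training might be launched, assuming `accelerate` has already been configured on the machine; every flag below is defined in parse_args() of train_sketch.py, the base model and config paths are the ones used elsewhere in this repository, and the script name launch_train_sketch.py is purely illustrative (adjust batch size and accumulation to your hardware):

# launch_train_sketch.py (illustrative only)
import shlex
import subprocess

# Wrap `accelerate launch` the same way train_sketch.py wraps wget for checkpoint downloads.
subprocess.run(shlex.split(
    "accelerate launch train_sketch.py"
    " --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0"
    " --config configs/train/Adapter-XL-sketch.yaml"
    " --mixed_precision fp16"
    " --train_batch_size 1"
    " --gradient_accumulation_steps 4"
    " --enable_xformers_memory_efficient_attention"
), check=True)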