├── Comfyui_CustomNet_node.py ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── config_customnet.yaml └── config_customnet_inpaint.yaml ├── custom_net ├── __init__.py ├── attention.py ├── attentioni.py ├── autoencoder.py ├── classifier.py ├── customnet.py ├── customnet_inpaint.py ├── customnet_util.py ├── dddim.py ├── ddim.py ├── ddpm.py ├── distributions.py ├── ema.py ├── model.py ├── modules.py ├── openaimodel.py ├── openaimodeli.py ├── plms.py ├── sampling_util.py ├── util.py └── x_transformer.py ├── example ├── inpainting.png ├── normal.png ├── polar.png ├── position.png ├── workflow.json └── zaimuth.png ├── gradio_utils.py ├── lr_scheduler.py ├── pyproject.toml └── requirements.txt /Comfyui_CustomNet_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import sys 4 | import os 5 | import numpy as np 6 | import torch 7 | from omegaconf import OmegaConf 8 | from pytorch_lightning import seed_everything 9 | import copy 10 | from .custom_net.ddim import DDIMSampler 11 | from einops import rearrange 12 | import math 13 | from PIL import Image, ImageDraw 14 | from .custom_net.customnet_util import instantiate_from_config, img2tensor 15 | from .gradio_utils import load_preprocess_model, preprocess_image 16 | from comfy.utils import common_upscale 17 | import folder_paths 18 | from folder_paths import base_path 19 | sys.path.append(os.path.join(base_path,"custom_nodes","ComfyUI_CustomNet","custom_net")) 20 | cur_path = os.path.dirname(os.path.abspath(__file__)) 21 | 22 | def get_instance_path(path): 23 | os_path = os.path.normpath(path) 24 | if sys.platform.startswith('win32'): 25 | os_path = os_path.replace('\\', "/") 26 | return os_path 27 | 28 | def tensor_to_image( tensor): 29 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 30 | image = Image.fromarray(image_np, mode='RGB') 31 | return image 32 | 33 | def upscale_to_pil(img_tensor, width, height): 34 | samples = img_tensor.movedim(-1, 1) 35 | img = common_upscale(samples, width, height, "nearest-exact", "center") 36 | samples = img.movedim(1, -1) 37 | img_pil = tensor_to_image(samples) 38 | return img_pil 39 | 40 | def preprocess_input(preprocess_model, input_image): # reg background 41 | processed_image = preprocess_image(preprocess_model, input_image) 42 | # input_img = (processed_image / 255.0).astype(np.float32) 43 | return processed_image 44 | # return processed_image, processed_image 45 | 46 | 47 | def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, prompt,bg_image,use_inpaint): 48 | 49 | # if input_image.size[0] != 256 or input_image.size[1] != 256: 50 | # input_image = input_image.resize((256, 256)) 51 | # input_image = np.array(input_image) 52 | # img_cond = img2tensor(input_image, bgr2rgb=False, float32=True).unsqueeze(0) / 255. 53 | input_image = np.array(input_image) 54 | img_cond = img2tensor(input_image, bgr2rgb=False, float32=True).unsqueeze(0) / 255. 
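    # img2tensor (from custom_net.customnet_util) converts the HWC numpy image to a CHW float
    # tensor; unsqueeze(0) adds the batch dim, and the /255 together with the "* 2 - 1" below
    # rescales values to [-1, 1], the image-value convention used throughout this node
    # (decode_first_stage output is mapped back with (x + 1) / 2 before display).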
55 | img_cond = img_cond * 2 - 1 56 | 57 | img_location = copy.deepcopy(img_cond) 58 | input_im_padding = torch.ones_like(img_location) 59 | 60 | x_0 = min(x0, x1) 61 | x_1 = max(x0, x1) 62 | y_0 = min(y0, y1) 63 | y_1 = max(y0, y1) 64 | 65 | # print(x0, y0, x1, y1) 66 | # print(x_0, y_0, x_1, y_1) 67 | img_location = torch.nn.functional.interpolate(img_location, (y_1 - y_0, x_1 - x_0), mode="bilinear") 68 | input_im_padding[:, :, y_0:y_1, x_0:x_1] = img_location 69 | img_location = input_im_padding 70 | 71 | if use_inpaint: 72 | bg_image = np.array(bg_image) 73 | bg_cond=img2tensor( bg_image, bgr2rgb=False, float32=True) / 255. 74 | bg_cond =bg_cond* 2 - 1 75 | bg_cond =bg_cond.unsqueeze(0) 76 | bg_cond[:, :, y_0:y_1, x_0:x_1] = 1 77 | else: 78 | bg_cond=img_cond #no use 79 | 80 | T = torch.tensor( 81 | [[math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), 0.0]]).unsqueeze(1) 82 | 83 | 84 | batch = { 85 | "image_cond": img_cond.to(device), 86 | "image_location": img_location.to(device), 87 | "bg_cond":bg_cond.to(device), 88 | 'T': T.to(device), 89 | 'text': [prompt], 90 | } 91 | return batch 92 | 93 | 94 | device = torch.device("cuda") 95 | class CustomNet_LoadModel: 96 | def __init__(self): 97 | pass 98 | 99 | @classmethod 100 | def INPUT_TYPES(cls): 101 | return { 102 | "required": { 103 | "ckpt_name": (["none"] + folder_paths.get_filename_list("checkpoints"),), 104 | } 105 | } 106 | 107 | RETURN_TYPES = ("MODEL","DICT",) 108 | RETURN_NAMES = ("model","info") 109 | FUNCTION = "main_loader" 110 | CATEGORY = "CustomNet_Plus" 111 | 112 | def main_loader(self,ckpt_name,): 113 | # load model 114 | ckpt = folder_paths.get_full_path("checkpoints", ckpt_name) 115 | if "inpaint"in ckpt_name: 116 | path_yaml = os.path.join(cur_path, "configs", "config_customnet_inpaint.yaml") 117 | use_inpaint=True 118 | else: 119 | path_yaml = os.path.join(cur_path, "configs", "config_customnet.yaml") 120 | use_inpaint = False 121 | path_yaml = get_instance_path(path_yaml) 122 | config = OmegaConf.load(path_yaml) 123 | model = instantiate_from_config(config.model) 124 | ckpt_load = torch.load(ckpt, map_location="cpu") 125 | model.load_state_dict(ckpt_load, strict=False) 126 | del ckpt_load 127 | model = model.to(device) 128 | info={"inpaint":use_inpaint,} 129 | return (model,info,) 130 | 131 | class CustomNet_Sampler: 132 | 133 | def __init__(self): 134 | pass 135 | 136 | @classmethod 137 | def INPUT_TYPES(cls): 138 | return { 139 | "required": { 140 | "model":("MODEL",), 141 | "info":("DICT",), 142 | "image": ("IMAGE",), 143 | "prompt": ("STRING", {"multiline": True, "default": "on the seaside"}), 144 | "neg_prompt": ("STRING", {"multiline": True,"default": ""}), 145 | "steps": ("INT", {"default": 50, "min": 1, "max": 1024,"step": 1,"display": "number"}), 146 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 147 | "width": ("INT", {"default": 256, "min": 128, "max": 512, "step": 64, "display": "number"}), 148 | "height": ("INT", {"default": 256, "min": 128, "max": 512, "step": 64, "display": "number"}), 149 | "obj_x": ("INT", {"default": 50, "min": 0, "max": 256, "step": 1, "display": "number"}), 150 | "obj_y": ("INT", {"default": 50, "min": 0, "max": 256, "step": 1, "display": "number"}), 151 | "bg_x": ("INT", {"default": 200, "min": 0, "max": 256, "step": 1, "display": "number"}), 152 | "bg_y": ("INT", {"default": 200, "min": 0, "max": 256, "step": 1, "display": "number"}), 153 | "polar": ("FLOAT", {"default": 0, "min": -30, "max": 30, "step": -0.5, 
"display": "number"}), 154 | "azimuth": ("FLOAT", {"default": 0, "min": -60, "max": 30, "step": -0.5, "display": "number"}), 155 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "display": "number"}), }, 156 | "optional": { 157 | "bg_image": ("IMAGE",), } 158 | 159 | } 160 | 161 | RETURN_TYPES = ("IMAGE",) 162 | RETURN_NAMES = ("output_image", ) 163 | FUNCTION = "customnet_main" 164 | CATEGORY = "CustomNet_Plus" 165 | 166 | 167 | @torch.no_grad() 168 | def customnet_main(self,model,info,image, prompt,neg_prompt,steps,seed,width,height,obj_x, obj_y, bg_x, bg_y, polar, azimuth,batch_size,**kwargs): 169 | preprocess_model = load_preprocess_model() 170 | sampler = DDIMSampler(model, device=device) 171 | use_inpaint=info["inpaint"] 172 | print(f"inpaint is {use_inpaint}") 173 | input_image = upscale_to_pil(image, width, height) # comfy upscale tensor2pil 174 | if use_inpaint: 175 | bg_image=kwargs.get("bg_image") 176 | bg_image=upscale_to_pil(bg_image, width, height) 177 | else: 178 | bg_image=input_image 179 | 180 | input_image = preprocess_input(preprocess_model, input_image) # using interface reg img 181 | seed_everything(seed) 182 | 183 | batch = prepare_data(device, input_image, obj_x, obj_y, bg_x, bg_y, polar, azimuth, prompt,bg_image,use_inpaint) 184 | 185 | c = model.get_learned_conditioning(batch["image_cond"]) 186 | c = torch.cat([c, batch["T"]], dim=-1) 187 | c = model.cc_projection(c) 188 | if use_inpaint: 189 | bg_concat = model.encode_first_stage(batch["bg_cond"]).mode().detach() 190 | 191 | ## condition 192 | cond = {} 193 | cond['c_concat'] = [model.encode_first_stage((batch["image_location"])).mode().detach()] 194 | cond['c_crossattn'] = [c] 195 | text_embedding = model.text_encoder(batch["text"]) 196 | cond["c_crossattn"].append(text_embedding) 197 | if use_inpaint: 198 | cond['c_concat'].append(bg_concat) 199 | 200 | ## null-condition 201 | uc = {} 202 | uc['c_concat'] = [torch.zeros(1, 4, 32, 32).to(c.device)] 203 | uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)] 204 | uc_text_embedding = model.text_encoder([neg_prompt]) 205 | uc['c_crossattn'].append(uc_text_embedding) 206 | if use_inpaint: 207 | uc["c_concat"].append(bg_concat) 208 | 209 | 210 | ## sample 211 | shape = [4, 32, 32] 212 | samples_latents, _ = sampler.sample( 213 | S=steps, 214 | batch_size=batch_size, 215 | shape=shape, 216 | verbose=False, 217 | unconditional_guidance_scale=999, # useless 218 | conditioning=cond, 219 | unconditional_conditioning=uc, 220 | cfg_type=0, 221 | cfg_scale_dict={"img": 0., "text": 0., "all": 3.0} 222 | ) 223 | 224 | x_samples = model.decode_first_stage(samples_latents) 225 | x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0).cpu().numpy() 226 | x_samples = rearrange(255.0 * x_samples[0], 'c h w -> h w c').astype(np.uint8) 227 | image = Image.fromarray(x_samples) 228 | image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 229 | return (image,) 230 | 231 | 232 | NODE_CLASS_MAPPINGS = { 233 | "CustomNet_LoadModel":CustomNet_LoadModel, 234 | "CustomNet_Sampler": CustomNet_Sampler 235 | } 236 | 237 | NODE_DISPLAY_NAME_MAPPINGS = { 238 | "CustomNet_LoadModel":"CustomNet_LoadModel", 239 | "CustomNet_Sampler": "CustomNet_Sampler" 240 | } 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND 
CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A CustomNet node for ComfyUI 2 | A CustomNet node for ComfyUI 3 | 4 | CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models. 5 | CustomNet From: [CustomNet](https://github.com/TencentARC/CustomNet) 6 | 7 | Update 8 | ---- 9 | 2024/08/11 10 | --同步官方的内绘模型及代码,优化模型加载方式,现在模型跟常规的SD模型在一个地方,优化模型加载方式, 11 | 12 | 1.Installation 13 | ----- 14 | In the .\ComfyUI \ custom_node directory, run the following: 15 | 16 | ``` python 17 | git clone https://github.com/smthemex/ComfyUI_CustomNet.git 18 | ``` 19 | 2.requirements 20 | ---- 21 | 每个人的环境不同,但是carvekit-colab是必须装的,是内置的脱底工具包,懒得去掉了,你可以先用其他sam节点处理物体图。首次运行,会安装carvekit-colab的模型文件,无梯子的注意。 22 | need carvekit-colab==4.1.0 23 | 24 | 3 Download the model 25 | ---- 26 | 3.1 normal: 27 | 下载customnet_v1.pth模型,并放在ComfyUI/models/checkpoints/目录下: 28 | Download the weights of Customnet “customnet_v1.pth” and put it to “ComfyUI/models/checkpoints/” [link](https://huggingface.co/TencentARC/CustomNet/tree/main) 29 | ``` 30 | └── ComfyUI/models/checkpoints/ 31 | ├── customnet_v1.pth 32 | ``` 33 | 3.2 inpainting: 34 | 下载customnet_inpaint_v1.pt模型,并放在ComfyUI/models/checkpoints/目录下: 35 | Download the weights of Customnet “customnet_inpaint_v1.pt” and put it to “ComfyUI/models/checkpoints/” [link](https://huggingface.co/TencentARC/CustomNet/tree/main) 36 | ``` 37 | └── ComfyUI/models/checkpoints/ 38 | ├── customnet_inpaint_v1.pt 39 | ``` 40 | 3.3 clip and carvekit: 41 | 首次使用会下载3个的模型文件,须连外网:,分别是 42 | clip:文件目录一般在C:/User/你的用户名/.cache/clip/ViT-L-14.pt 43 | carvekit的2个脱底模型: 44 | 目录C:/User/你的用户名/.cache/carvekit/checkpoints/fba/fba_matting.pth 45 | 目录C:/User/你的用户名/.cache/carvekit/checkpoints/tracer_b7/tracer_b7.pth 46 | 47 | 6 Tips 48 | ---- 49 | ---白底的物体图得到最好的效果; 50 | ---底模训练就是256的,所以没法做大图,除非腾讯把大图的模型放出来。 51 | ---The object image with a white background achieves the best effect; 52 | 53 | 5 Example 54 | ----- 55 | normal 常规脱底置于提示测的背景前面,最新的演示; Latest Presentation 56 | ![](https://github.com/smthemex/ComfyUI_CustomNet/blob/main/example/normal.png) 57 | 58 | inpainting 内绘模型,最新的演示; Latest Presentation 59 | ![](https://github.com/smthemex/ComfyUI_CustomNet/blob/main/example/inpainting.png) 60 | 61 | polar 主体上下视角 既往的演示, 
Previous demonstrations 62 | ![](https://github.com/smthemex/ComfyUI_CustomNet/blob/main/example/polar.png) 63 | 64 | zaimuth 主体左右视角 既往的演示, Previous demonstrations 65 | ![](https://github.com/smthemex/ComfyUI_CustomNet/blob/main/example/zaimuth.png) 66 | 67 | position X0 Y0 主体在背景中的位置 既往的演示, Previous demonstrations 68 | ![](https://github.com/smthemex/ComfyUI_CustomNet/blob/main/example/position.png) 69 | 70 | 71 | 6 Citation 72 | ------ 73 | 74 | ``` python 75 | @misc{yuan2023customnet, 76 | title={CustomNet: Zero-shot Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models}, 77 | author={Ziyang Yuan and Mingdeng Cao and Xintao Wang and Zhongang Qi and Chun Yuan and Ying Shan}, 78 | year={2023}, 79 | eprint={2310.19784}, 80 | archivePrefix={arXiv}, 81 | primaryClass={cs.CV} 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from .custom_net.customnet import CustomNet 3 | from .Comfyui_CustomNet_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 4 | 5 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 6 | -------------------------------------------------------------------------------- /configs/config_customnet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-05 3 | target: custom_net.customnet.CustomNet 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: false 19 | use_cond_concat: true 20 | use_bbox_mask: false 21 | use_bg_inpainting: false 22 | learning_rate_scale: 10 23 | ucg_training: 24 | txt: 0.15 25 | 26 | sd_15_ckpt: #"v1-5-pruned-emaonly.ckpt" 27 | 28 | unet_config: 29 | target: custom_net.openaimodeli.UNetModel 30 | params: 31 | image_size: 32 # unused 32 | in_channels: 8 33 | out_channels: 4 34 | model_channels: 320 35 | attention_resolutions: [ 4, 2, 1 ] 36 | num_res_blocks: 2 37 | channel_mult: [ 1, 2, 4, 4 ] 38 | num_heads: 8 39 | use_spatial_transformer: True 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: True 43 | legacy: False 44 | 45 | 46 | first_stage_config: 47 | target: custom_net.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: custom_net.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | text_encoder_config: 74 | target: custom_net.modules.FrozenCLIPEmbedder 75 | params: 76 | version: openai/clip-vit-large-patch14 77 | 78 | 79 | ## this is a template dataset 80 | train_data: 81 | target: data.dataset.Dataset 82 | params: 83 | image_size: 256 84 | root: examples/dataset/ 85 | 86 | 87 | train_dataloader: 88 | batch_size: 12 89 | num_workers: 8 90 | 91 | 92 | 93 | 94 | lightning: 95 | find_unused_parameters: false 96 | 
metrics_over_trainsteps_checkpoint: True 97 | modelcheckpoint: 98 | params: 99 | every_n_train_steps: 10000 100 | save_top_k: -1 101 | monitor: null 102 | callbacks: 103 | image_logger: 104 | target: main.ImageLogger 105 | params: 106 | batch_frequency: 2500 107 | max_images: 32 108 | increase_log_steps: False 109 | log_first_step: True 110 | log_images_kwargs: 111 | use_ema_scope: False 112 | inpaint: False 113 | plot_progressive_rows: False 114 | plot_diffusion_rows: False 115 | N: 32 116 | unconditional_guidance_scale: 3.0 117 | unconditional_guidance_label: [""] 118 | 119 | trainer: 120 | benchmark: True 121 | limit_val_batches: 0 122 | num_sanity_val_steps: 0 123 | accumulate_grad_batches: 1 124 | -------------------------------------------------------------------------------- /configs/config_customnet_inpaint.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: custom_net.customnet_inpaint.CustomNet 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: false 19 | use_cond_concat: true 20 | use_bbox_mask: false 21 | use_bg_inpainting: true 22 | learning_rate_scale: 10 23 | 24 | ucg_training: 25 | txt: 0.5 26 | 27 | sd_15_ckpt: #"v1-5-pruned-emaonly.ckpt" 28 | 29 | unet_config: 30 | target: custom_net.openaimodeli.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 12 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | 47 | 48 | first_stage_config: 49 | target: custom_net.autoencoder.AutoencoderKL 50 | params: 51 | embed_dim: 4 52 | monitor: val/rec_loss 53 | ddconfig: 54 | double_z: true 55 | z_channels: 4 56 | resolution: 256 57 | in_channels: 3 58 | out_ch: 3 59 | ch: 128 60 | ch_mult: 61 | - 1 62 | - 2 63 | - 4 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: [] 67 | dropout: 0.0 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | 72 | cond_stage_config: 73 | target: custom_net.modules.FrozenCLIPImageEmbedder 74 | 75 | 76 | text_encoder_config: 77 | target: custom_net.modules.FrozenCLIPEmbedder 78 | params: 79 | version: openai/clip-vit-large-patch14 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | ## this is a template dataset, NOTE: Need to add a background condition 90 | train_data: 91 | target: data.dataset_inpaint.Dataset 92 | params: 93 | image_size: 256 94 | root: examples/dataset/ 95 | 96 | 97 | 98 | 99 | train_dataloader: 100 | batch_size: 12 101 | num_workers: 8 102 | 103 | 104 | 105 | 106 | 107 | 108 | lightning: 109 | find_unused_parameters: false 110 | metrics_over_trainsteps_checkpoint: True 111 | modelcheckpoint: 112 | params: 113 | every_n_train_steps: 10000 114 | save_top_k: -1 115 | monitor: null 116 | callbacks: 117 | image_logger: 118 | target: main.ImageLogger 119 | params: 120 | batch_frequency: 2500 121 | max_images: 32 122 | increase_log_steps: False 123 | log_first_step: True 124 | log_images_kwargs: 125 | use_ema_scope: False 126 | 
inpaint: False 127 | plot_progressive_rows: False 128 | plot_diffusion_rows: False 129 | N: 32 130 | unconditional_guidance_scale: 3.0 131 | unconditional_guidance_label: [""] 132 | 133 | trainer: 134 | benchmark: True 135 | limit_val_batches: 0 136 | num_sanity_val_steps: 0 137 | accumulate_grad_batches: 1 138 | -------------------------------------------------------------------------------- /custom_net/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /custom_net/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from .util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
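    # In this file zero_module is applied to SpatialTransformer.proj_out, so each transformer
    # block starts out as an identity mapping (forward returns x + x_in with a zeroed projection).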
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 198 | disable_self_attn=False): 199 | super().__init__() 200 | self.disable_self_attn = disable_self_attn 201 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 202 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 203 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 204 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 205 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 206 | self.norm1 = nn.LayerNorm(dim) 207 | self.norm2 = nn.LayerNorm(dim) 208 | self.norm3 = nn.LayerNorm(dim) 209 | self.checkpoint = checkpoint 210 | 211 | def forward(self, x, context=None): 212 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 213 | 214 | def _forward(self, x, context=None): 215 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 216 | x = self.attn2(self.norm2(x), context=context) + x 217 | x = self.ff(self.norm3(x)) + x 218 | return x 219 | 220 | 221 | class SpatialTransformer(nn.Module): 222 | """ 223 | Transformer block for image-like data. 224 | First, project the input (aka embedding) 225 | and reshape to b, t, d. 226 | Then apply standard transformer action. 
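    (each BasicTransformerBlock: self-attention, then cross-attention to `context`, then feed-forward).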
227 | Finally, reshape to image 228 | """ 229 | def __init__(self, in_channels, n_heads, d_head, 230 | depth=1, dropout=0., context_dim=None, 231 | disable_self_attn=False): 232 | super().__init__() 233 | self.in_channels = in_channels 234 | inner_dim = n_heads * d_head 235 | self.norm = Normalize(in_channels) 236 | 237 | self.proj_in = nn.Conv2d(in_channels, 238 | inner_dim, 239 | kernel_size=1, 240 | stride=1, 241 | padding=0) 242 | 243 | self.transformer_blocks = nn.ModuleList( 244 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, 245 | disable_self_attn=disable_self_attn) 246 | for d in range(depth)] 247 | ) 248 | 249 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 250 | in_channels, 251 | kernel_size=1, 252 | stride=1, 253 | padding=0)) 254 | 255 | def forward(self, x, context=None): 256 | # note: if no context is given, cross-attention defaults to self-attention 257 | b, c, h, w = x.shape 258 | x_in = x 259 | x = self.norm(x) 260 | x = self.proj_in(x) 261 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 262 | for block in self.transformer_blocks: 263 | x = block(x, context=context) 264 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 265 | x = self.proj_out(x) 266 | return x + x_in 267 | -------------------------------------------------------------------------------- /custom_net/attentioni.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from .util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
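    # As in attention.py, zero_module zero-initialises SpatialTransformer.proj_out so every
    # block begins as an identity mapping.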
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | ## Dual Cross-Attntion 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., new_text_att=False,): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | self.new_text_att = new_text_att 171 | if new_text_att: 172 | self.to_k_text = nn.Linear(context_dim, inner_dim, bias=False) 173 | self.to_v_text = nn.Linear(context_dim, inner_dim, bias=False) 174 | self.text_scale = 1.0 175 | 176 | 177 | 178 | def forward(self, x, context=None, mask=None): 179 | if self.new_text_att: 180 | context_text = context[:,1:,:] ## text embedding 181 | context = context[:,:1,:] ## image embedding 182 | 183 | h = self.heads 184 | 185 | q = self.to_q(x) 186 | context = default(context, x) 187 | k = 
self.to_k(context) 188 | v = self.to_v(context) 189 | 190 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 191 | 192 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 193 | 194 | if exists(mask): 195 | mask = rearrange(mask, 'b ... -> b (...)') 196 | max_neg_value = -torch.finfo(sim.dtype).max 197 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 198 | sim.masked_fill_(~mask, max_neg_value) 199 | 200 | # attention, what we cannot get enough of 201 | attn = sim.softmax(dim=-1) 202 | 203 | out = einsum('b i j, b j d -> b i d', attn, v) 204 | 205 | ## text att 206 | if self.new_text_att: 207 | k_text = self.to_k_text(context_text) 208 | v_text = self.to_v_text(context_text) 209 | k_text, v_text = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_text, v_text)) 210 | 211 | sim_text = einsum('b i d, b j d -> b i j', q, k_text) * self.scale 212 | 213 | attn_text = sim_text.softmax(dim=-1) 214 | out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) 215 | 216 | out = out + out_text * self.text_scale 217 | 218 | # print("no text attn") 219 | 220 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 221 | return self.to_out(out) 222 | 223 | 224 | class BasicTransformerBlock(nn.Module): 225 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 226 | disable_self_attn=False): 227 | super().__init__() 228 | self.disable_self_attn = disable_self_attn 229 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 230 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 231 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 232 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 233 | heads=n_heads, dim_head=d_head, dropout=dropout, 234 | new_text_att=True, 235 | ) # is self-attn if context is none 236 | self.norm1 = nn.LayerNorm(dim) 237 | self.norm2 = nn.LayerNorm(dim) 238 | self.norm3 = nn.LayerNorm(dim) 239 | self.checkpoint = checkpoint 240 | 241 | def forward(self, x, context=None): 242 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 243 | 244 | def _forward(self, x, context=None): 245 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 246 | x = self.attn2(self.norm2(x), context=context) + x 247 | x = self.ff(self.norm3(x)) + x 248 | return x 249 | 250 | 251 | class SpatialTransformer(nn.Module): 252 | """ 253 | Transformer block for image-like data. 254 | First, project the input (aka embedding) 255 | and reshape to b, t, d. 256 | Then apply standard transformer action. 
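    (here attn2 is the dual cross-attention above: the first context token is the image
    embedding, and the remaining tokens are the text embedding handled by to_k_text/to_v_text).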
257 | Finally, reshape to image 258 | """ 259 | def __init__(self, in_channels, n_heads, d_head, 260 | depth=1, dropout=0., context_dim=None, 261 | disable_self_attn=False): 262 | super().__init__() 263 | self.in_channels = in_channels 264 | inner_dim = n_heads * d_head 265 | self.norm = Normalize(in_channels) 266 | 267 | self.proj_in = nn.Conv2d(in_channels, 268 | inner_dim, 269 | kernel_size=1, 270 | stride=1, 271 | padding=0) 272 | 273 | self.transformer_blocks = nn.ModuleList( 274 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, 275 | disable_self_attn=disable_self_attn) 276 | for d in range(depth)] 277 | ) 278 | 279 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 280 | in_channels, 281 | kernel_size=1, 282 | stride=1, 283 | padding=0)) 284 | 285 | def forward(self, x, context=None): 286 | # note: if no context is given, cross-attention defaults to self-attention 287 | b, c, h, w = x.shape 288 | x_in = x 289 | x = self.norm(x) 290 | x = self.proj_in(x) 291 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 292 | for block in self.transformer_blocks: 293 | x = block(x, context=context) 294 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 295 | x = self.proj_out(x) 296 | return x + x_in 297 | -------------------------------------------------------------------------------- /custom_net/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from taming.modules.vqvae.quantize import VectorQuantizer 7 | 8 | from .model import Encoder, Decoder 9 | from .distributions import DiagonalGaussianDistribution 10 | 11 | from .customnet_util import instantiate_from_config 12 | 13 | 14 | class VQModel(pl.LightningModule): 15 | def __init__(self, 16 | ddconfig, 17 | lossconfig, 18 | n_embed, 19 | embed_dim, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | image_key="image", 23 | colorize_nlabels=None, 24 | monitor=None, 25 | batch_resize_range=None, 26 | scheduler_config=None, 27 | lr_g_factor=1.0, 28 | remap=None, 29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 30 | use_ema=False 31 | ): 32 | super().__init__() 33 | self.embed_dim = embed_dim 34 | self.n_embed = n_embed 35 | self.image_key = image_key 36 | self.encoder = Encoder(**ddconfig) 37 | self.decoder = Decoder(**ddconfig) 38 | self.loss = instantiate_from_config(lossconfig) 39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 40 | remap=remap, 41 | sane_index_shape=sane_index_shape) 42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 44 | if colorize_nlabels is not None: 45 | assert type(colorize_nlabels)==int 46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 47 | if monitor is not None: 48 | self.monitor = monitor 49 | self.batch_resize_range = batch_resize_range 50 | if self.batch_resize_range is not None: 51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") 52 | 53 | self.use_ema = use_ema 54 | if self.use_ema: 55 | self.model_ema = LitEma(self) 56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 57 | 58 | if ckpt_path is not None: 59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 60 | self.scheduler_config = scheduler_config 61 | self.lr_g_factor = 
lr_g_factor 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def init_from_ckpt(self, path, ignore_keys=list()): 79 | sd = torch.load(path, map_location="cpu")["state_dict"] 80 | keys = list(sd.keys()) 81 | for k in keys: 82 | for ik in ignore_keys: 83 | if k.startswith(ik): 84 | print("Deleting key {} from state_dict.".format(k)) 85 | del sd[k] 86 | missing, unexpected = self.load_state_dict(sd, strict=False) 87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 88 | if len(missing) > 0: 89 | print(f"Missing Keys: {missing}") 90 | print(f"Unexpected Keys: {unexpected}") 91 | 92 | def on_train_batch_end(self, *args, **kwargs): 93 | if self.use_ema: 94 | self.model_ema(self) 95 | 96 | def encode(self, x): 97 | h = self.encoder(x) 98 | h = self.quant_conv(h) 99 | quant, emb_loss, info = self.quantize(h) 100 | return quant, emb_loss, info 101 | 102 | def encode_to_prequant(self, x): 103 | h = self.encoder(x) 104 | h = self.quant_conv(h) 105 | return h 106 | 107 | def decode(self, quant): 108 | quant = self.post_quant_conv(quant) 109 | dec = self.decoder(quant) 110 | return dec 111 | 112 | def decode_code(self, code_b): 113 | quant_b = self.quantize.embed_code(code_b) 114 | dec = self.decode(quant_b) 115 | return dec 116 | 117 | def forward(self, input, return_pred_indices=False): 118 | quant, diff, (_,_,ind) = self.encode(input) 119 | dec = self.decode(quant) 120 | if return_pred_indices: 121 | return dec, diff, ind 122 | return dec, diff 123 | 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 129 | if self.batch_resize_range is not None: 130 | lower_size = self.batch_resize_range[0] 131 | upper_size = self.batch_resize_range[1] 132 | if self.global_step <= 4: 133 | # do the first few batches with max size to avoid later oom 134 | new_resize = upper_size 135 | else: 136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) 137 | if new_resize != x.shape[2]: 138 | x = F.interpolate(x, size=new_resize, mode="bicubic") 139 | x = x.detach() 140 | return x 141 | 142 | def training_step(self, batch, batch_idx, optimizer_idx): 143 | # https://github.com/pytorch/pytorch/issues/37142 144 | # try not to fool the heuristics 145 | x = self.get_input(batch, self.image_key) 146 | xrec, qloss, ind = self(x, return_pred_indices=True) 147 | 148 | if optimizer_idx == 0: 149 | # autoencode 150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 151 | last_layer=self.get_last_layer(), split="train", 152 | predicted_indices=ind) 153 | 154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 155 | return aeloss 156 | 157 | if optimizer_idx == 1: 158 | # discriminator 159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 160 | last_layer=self.get_last_layer(), split="train") 161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 162 | return discloss 163 | 164 | def 
validation_step(self, batch, batch_idx): 165 | log_dict = self._validation_step(batch, batch_idx) 166 | with self.ema_scope(): 167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") 168 | return log_dict 169 | 170 | def _validation_step(self, batch, batch_idx, suffix=""): 171 | x = self.get_input(batch, self.image_key) 172 | xrec, qloss, ind = self(x, return_pred_indices=True) 173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, 174 | self.global_step, 175 | last_layer=self.get_last_layer(), 176 | split="val"+suffix, 177 | predicted_indices=ind 178 | ) 179 | 180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, 181 | self.global_step, 182 | last_layer=self.get_last_layer(), 183 | split="val"+suffix, 184 | predicted_indices=ind 185 | ) 186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] 187 | self.log(f"val{suffix}/rec_loss", rec_loss, 188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 189 | self.log(f"val{suffix}/aeloss", aeloss, 190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 191 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 192 | del log_dict_ae[f"val{suffix}/rec_loss"] 193 | self.log_dict(log_dict_ae) 194 | self.log_dict(log_dict_disc) 195 | return self.log_dict 196 | 197 | def configure_optimizers(self): 198 | lr_d = self.learning_rate 199 | lr_g = self.lr_g_factor*self.learning_rate 200 | print("lr_d", lr_d) 201 | print("lr_g", lr_g) 202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 203 | list(self.decoder.parameters())+ 204 | list(self.quantize.parameters())+ 205 | list(self.quant_conv.parameters())+ 206 | list(self.post_quant_conv.parameters()), 207 | lr=lr_g, betas=(0.5, 0.9)) 208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 209 | lr=lr_d, betas=(0.5, 0.9)) 210 | 211 | if self.scheduler_config is not None: 212 | scheduler = instantiate_from_config(self.scheduler_config) 213 | 214 | print("Setting up LambdaLR scheduler...") 215 | scheduler = [ 216 | { 217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 218 | 'interval': 'step', 219 | 'frequency': 1 220 | }, 221 | { 222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 223 | 'interval': 'step', 224 | 'frequency': 1 225 | }, 226 | ] 227 | return [opt_ae, opt_disc], scheduler 228 | return [opt_ae, opt_disc], [] 229 | 230 | def get_last_layer(self): 231 | return self.decoder.conv_out.weight 232 | 233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): 234 | log = dict() 235 | x = self.get_input(batch, self.image_key) 236 | x = x.to(self.device) 237 | if only_inputs: 238 | log["inputs"] = x 239 | return log 240 | xrec, _ = self(x) 241 | if x.shape[1] > 3: 242 | # colorize with random projection 243 | assert xrec.shape[1] > 3 244 | x = self.to_rgb(x) 245 | xrec = self.to_rgb(xrec) 246 | log["inputs"] = x 247 | log["reconstructions"] = xrec 248 | if plot_ema: 249 | with self.ema_scope(): 250 | xrec_ema, _ = self(x) 251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) 252 | log["reconstructions_ema"] = xrec_ema 253 | return log 254 | 255 | def to_rgb(self, x): 256 | assert self.image_key == "segmentation" 257 | if not hasattr(self, "colorize"): 258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 259 | x = F.conv2d(x, weight=self.colorize) 260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
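        # min/max rescale of the randomly-projected (colorize) map to [-1, 1] so it can be logged like an ordinary image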
261 | return x 262 | 263 | 264 | class VQModelInterface(VQModel): 265 | def __init__(self, embed_dim, *args, **kwargs): 266 | super().__init__(embed_dim=embed_dim, *args, **kwargs) 267 | self.embed_dim = embed_dim 268 | 269 | def encode(self, x): 270 | h = self.encoder(x) 271 | h = self.quant_conv(h) 272 | return h 273 | 274 | def decode(self, h, force_not_quantize=False): 275 | # also go through quantization layer 276 | if not force_not_quantize: 277 | quant, emb_loss, info = self.quantize(h) 278 | else: 279 | quant = h 280 | quant = self.post_quant_conv(quant) 281 | dec = self.decoder(quant) 282 | return dec 283 | 284 | 285 | class AutoencoderKL(pl.LightningModule): 286 | def __init__(self, 287 | ddconfig, 288 | lossconfig, 289 | embed_dim, 290 | ckpt_path=None, 291 | ignore_keys=[], 292 | image_key="image", 293 | colorize_nlabels=None, 294 | monitor=None, 295 | ): 296 | super().__init__() 297 | self.image_key = image_key 298 | self.encoder = Encoder(**ddconfig) 299 | self.decoder = Decoder(**ddconfig) 300 | self.loss = instantiate_from_config(lossconfig) 301 | assert ddconfig["double_z"] 302 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 303 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 304 | self.embed_dim = embed_dim 305 | if colorize_nlabels is not None: 306 | assert type(colorize_nlabels)==int 307 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 308 | if monitor is not None: 309 | self.monitor = monitor 310 | if ckpt_path is not None: 311 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 312 | 313 | def init_from_ckpt(self, path, ignore_keys=list()): 314 | sd = torch.load(path, map_location="cpu")["state_dict"] 315 | keys = list(sd.keys()) 316 | for k in keys: 317 | for ik in ignore_keys: 318 | if k.startswith(ik): 319 | print("Deleting key {} from state_dict.".format(k)) 320 | del sd[k] 321 | self.load_state_dict(sd, strict=False) 322 | print(f"Restored from {path}") 323 | 324 | def encode(self, x): 325 | h = self.encoder(x) 326 | moments = self.quant_conv(h) 327 | posterior = DiagonalGaussianDistribution(moments) 328 | return posterior 329 | 330 | def decode(self, z): 331 | z = self.post_quant_conv(z) 332 | dec = self.decoder(z) 333 | return dec 334 | 335 | def forward(self, input, sample_posterior=True): 336 | posterior = self.encode(input) 337 | if sample_posterior: 338 | z = posterior.sample() 339 | else: 340 | z = posterior.mode() 341 | dec = self.decode(z) 342 | return dec, posterior 343 | 344 | def get_input(self, batch, k): 345 | x = batch[k] 346 | if len(x.shape) == 3: 347 | x = x[..., None] 348 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 349 | return x 350 | 351 | def training_step(self, batch, batch_idx, optimizer_idx): 352 | inputs = self.get_input(batch, self.image_key) 353 | reconstructions, posterior = self(inputs) 354 | 355 | if optimizer_idx == 0: 356 | # train encoder+decoder+logvar 357 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 358 | last_layer=self.get_last_layer(), split="train") 359 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 360 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 361 | return aeloss 362 | 363 | if optimizer_idx == 1: 364 | # train the discriminator 365 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 366 | 
last_layer=self.get_last_layer(), split="train") 367 | 368 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 369 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 370 | return discloss 371 | 372 | def validation_step(self, batch, batch_idx): 373 | inputs = self.get_input(batch, self.image_key) 374 | reconstructions, posterior = self(inputs) 375 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 376 | last_layer=self.get_last_layer(), split="val") 377 | 378 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 379 | last_layer=self.get_last_layer(), split="val") 380 | 381 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 382 | self.log_dict(log_dict_ae) 383 | self.log_dict(log_dict_disc) 384 | return self.log_dict 385 | 386 | def configure_optimizers(self): 387 | lr = self.learning_rate 388 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 389 | list(self.decoder.parameters())+ 390 | list(self.quant_conv.parameters())+ 391 | list(self.post_quant_conv.parameters()), 392 | lr=lr, betas=(0.5, 0.9)) 393 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 394 | lr=lr, betas=(0.5, 0.9)) 395 | return [opt_ae, opt_disc], [] 396 | 397 | def get_last_layer(self): 398 | return self.decoder.conv_out.weight 399 | 400 | @torch.no_grad() 401 | def log_images(self, batch, only_inputs=False, **kwargs): 402 | log = dict() 403 | x = self.get_input(batch, self.image_key) 404 | x = x.to(self.device) 405 | if not only_inputs: 406 | xrec, posterior = self(x) 407 | if x.shape[1] > 3: 408 | # colorize with random projection 409 | assert xrec.shape[1] > 3 410 | x = self.to_rgb(x) 411 | xrec = self.to_rgb(xrec) 412 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 413 | log["reconstructions"] = xrec 414 | log["inputs"] = x 415 | return log 416 | 417 | def to_rgb(self, x): 418 | assert self.image_key == "segmentation" 419 | if not hasattr(self, "colorize"): 420 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 421 | x = F.conv2d(x, weight=self.colorize) 422 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
423 | return x 424 | 425 | 426 | class IdentityFirstStage(torch.nn.Module): 427 | def __init__(self, *args, vq_interface=False, **kwargs): 428 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 429 | super().__init__() 430 | 431 | def encode(self, x, *args, **kwargs): 432 | return x 433 | 434 | def decode(self, x, *args, **kwargs): 435 | return x 436 | 437 | def quantize(self, x, *args, **kwargs): 438 | if self.vq_interface: 439 | return x, None, [None, None, None] 440 | return x 441 | 442 | def forward(self, x, *args, **kwargs): 443 | return x 444 | -------------------------------------------------------------------------------- /custom_net/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from .openaimodel import EncoderUNetModel, UNetModel 14 | from .customnet_util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else 
self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | 
) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | 
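A quick way to sanity-check the top-k accuracy metric that NoisyLatentImageClassifier.compute_top_k implements above is to run the same torch.topk reduction on a toy batch. The snippet below is an illustrative, standalone sketch (the logits, labels and class count are made up, and the function is a copy of the in-class logic rather than an import from this repository):

import torch

def compute_top_k(logits, labels, k, reduction="mean"):
    # Same rule as NoisyLatentImageClassifier.compute_top_k: a sample counts as
    # correct if its true label appears among the k largest logits.
    _, top_ks = torch.topk(logits, k, dim=1)
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "mean":
        return hits.mean().item()
    return hits  # reduction == "none": per-sample 0/1 hits

logits = torch.tensor([[0.1, 2.0, 0.3],   # argmax -> class 1
                       [1.5, 0.2, 1.0],   # argmax -> class 0
                       [0.0, 0.1, 0.2]])  # argmax -> class 2
labels = torch.tensor([1, 2, 2])

print(compute_top_k(logits, labels, k=1))  # 2 of 3 labels hit at top-1 -> ~0.667
print(compute_top_k(logits, labels, k=2))  # label 2 also enters the top-2 of row 1 -> 1.0

write_logs then pushes these values into self.log_dict as the acc@1 and acc@5 entries for the current train/ or val/ prefix.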
-------------------------------------------------------------------------------- /custom_net/customnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from folder_paths import base_path 4 | sys.path.append(os.path.join(base_path,"custom_nodes","ComfyUI_CustomNet")) 5 | 6 | import einops 7 | import torch 8 | import torch as th 9 | import torch.nn as nn 10 | import cv2 11 | 12 | try: 13 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 14 | except: 15 | try: 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | except: 18 | raise "import pytorch_lightning rank_zero_only error" 19 | import numpy as np 20 | from torch.optim.lr_scheduler import LambdaLR 21 | from einops import rearrange, repeat 22 | from torchvision.utils import make_grid 23 | 24 | from .util import ( 25 | conv_nd, 26 | linear, 27 | zero_module, 28 | timestep_embedding, 29 | ) 30 | from .attention import SpatialTransformer 31 | from .openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock 32 | 33 | from .ddpm import LatentDiffusion 34 | from .customnet_util import log_txt_as_img, exists, instantiate_from_config 35 | 36 | from .dddim import DDIMSampler 37 | from .customnet_util import load_state_dict 38 | 39 | 40 | 41 | 42 | 43 | 44 | class CustomNet(LatentDiffusion): 45 | def __init__(self, 46 | text_encoder_config, 47 | sd_15_ckpt=None, 48 | use_cond_concat=False, 49 | use_bbox_mask=False, 50 | use_bg_inpainting=False, 51 | learning_rate_scale=10, 52 | *args, **kwargs): 53 | super().__init__(*args, **kwargs) 54 | 55 | 56 | self.text_encoder = instantiate_from_config(text_encoder_config) 57 | 58 | if sd_15_ckpt is not None: 59 | self.load_model_from_ckpt(ckpt=sd_15_ckpt) 60 | 61 | self.use_cond_concat = use_cond_concat 62 | self.use_bbox_mask = use_bbox_mask 63 | self.use_bg_inpainting = use_bg_inpainting 64 | self.learning_rate_scale = learning_rate_scale 65 | 66 | 67 | def load_model_from_ckpt(self, ckpt, verbose=True): 68 | print(" =========================== init Stable Diffusion pretrained checkpoint =========================== ") 69 | print(f"Loading model from {ckpt}") 70 | pl_sd = torch.load(ckpt, map_location="cpu") 71 | if "global_step" in pl_sd: 72 | print(f"Global Step: {pl_sd['global_step']}") 73 | sd = pl_sd["state_dict"] 74 | sd_keys = sd.keys() 75 | 76 | 77 | missing = [] 78 | text_encoder_sd = self.text_encoder.state_dict() 79 | for k in text_encoder_sd.keys(): 80 | sd_k = "cond_stage_model."+ k 81 | if sd_k in sd_keys: 82 | text_encoder_sd[k] = sd[sd_k] 83 | else: 84 | missing.append(k) 85 | 86 | self.text_encoder.load_state_dict(text_encoder_sd) 87 | 88 | 89 | 90 | 91 | def configure_optimizers(self): 92 | lr = self.learning_rate 93 | params = [] 94 | params += list(self.cc_projection.parameters()) 95 | 96 | 97 | params_dualattn = [] 98 | for k, v in self.model.named_parameters(): 99 | if "to_k_text" in k or "to_v_text" in k: 100 | params_dualattn.append(v) 101 | print("training weight: ", k) 102 | else: 103 | params.append(v) 104 | 105 | 106 | opt = torch.optim.AdamW([ 107 | {'params':params_dualattn, 'lr': lr*self.learning_rate_scale}, 108 | {'params': params, 'lr': lr} 109 | ]) 110 | 111 | 112 | if self.use_scheduler: 113 | assert 'target' in self.scheduler_config 114 | scheduler = instantiate_from_config(self.scheduler_config) 115 | 116 | print("Setting up LambdaLR scheduler...") 117 | scheduler = [ 118 | { 119 | 'scheduler': LambdaLR(opt, 
lr_lambda=scheduler.schedule), 120 | 'interval': 'step', 121 | 'frequency': 1 122 | }] 123 | return [opt], scheduler 124 | return opt 125 | 126 | def training_step(self, batch, batch_idx): 127 | loss, loss_dict = self.shared_step(batch) 128 | 129 | self.log_dict(loss_dict, prog_bar=True, 130 | logger=True, on_step=True, on_epoch=True) 131 | 132 | self.log("global_step", self.global_step, 133 | prog_bar=True, logger=True, on_step=True, on_epoch=False) 134 | 135 | if self.use_scheduler: 136 | lr = self.optimizers().param_groups[0]['lr'] 137 | self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) 138 | 139 | return loss 140 | 141 | 142 | 143 | 144 | def shared_step(self, batch, **kwargs): 145 | 146 | if 'txt' in self.ucg_training: 147 | k = 'txt' 148 | p = self.ucg_training[k] 149 | for i in range(len(batch[k])): 150 | if self.ucg_prng.choice(2, p=[1 - p, p]): 151 | if isinstance(batch[k], list): 152 | batch[k][i] = "" 153 | 154 | with torch.no_grad(): 155 | text = batch['txt'] 156 | text_embedding = self.text_encoder(text) 157 | 158 | 159 | x, c = self.get_input(batch, self.first_stage_key) 160 | 161 | c["c_crossattn"].append(text_embedding) 162 | loss = self(x, c,) 163 | return loss 164 | 165 | 166 | def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs): 167 | 168 | if isinstance(cond, dict): 169 | # hybrid case, cond is exptected to be a dict 170 | pass 171 | else: 172 | if not isinstance(cond, list): 173 | cond = [cond] 174 | key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' 175 | cond = {key: cond} 176 | 177 | x_recon = self.model(x_noisy, t, **cond) 178 | 179 | if isinstance(x_recon, tuple) and not return_ids: 180 | return x_recon[0] 181 | else: 182 | return x_recon 183 | -------------------------------------------------------------------------------- /custom_net/customnet_inpaint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import einops 5 | import torch 6 | import torch as th 7 | import torch.nn as nn 8 | import cv2 9 | from folder_paths import base_path 10 | sys.path.append(os.path.join(base_path,"custom_nodes","ComfyUI_CustomNet")) 11 | try: 12 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 13 | except: 14 | try: 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | except: 17 | raise "import pytorch_lightning rank_zero_only error" 18 | import numpy as np 19 | from torch.optim.lr_scheduler import LambdaLR 20 | 21 | from .util import ( 22 | conv_nd, 23 | linear, 24 | zero_module, 25 | timestep_embedding, 26 | ) 27 | 28 | from einops import rearrange, repeat 29 | from torchvision.utils import make_grid 30 | from .attention import SpatialTransformer 31 | from .openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock 32 | from .ddpm import LatentDiffusion 33 | from .customnet_util import log_txt_as_img, exists, instantiate_from_config 34 | 35 | from .dddim import DDIMSampler 36 | from .customnet_util import load_state_dict 37 | 38 | 39 | class CustomNet(LatentDiffusion): 40 | def __init__(self, 41 | text_encoder_config, 42 | sd_15_ckpt=None, 43 | use_cond_concat=False, 44 | use_bbox_mask=False, 45 | use_bg_inpainting=False, 46 | learning_rate_scale=10, 47 | *args, **kwargs): 48 | super().__init__(*args, **kwargs) 49 | 50 | 51 | self.text_encoder = instantiate_from_config(text_encoder_config) 52 | 53 | if sd_15_ckpt is not None: 54 | 
self.load_model_from_ckpt(ckpt=sd_15_ckpt) 55 | 56 | self.use_cond_concat = use_cond_concat 57 | self.use_bbox_mask = use_bbox_mask 58 | self.use_bg_inpainting = use_bg_inpainting 59 | self.learning_rate_scale = learning_rate_scale 60 | 61 | 62 | def load_model_from_ckpt(self, ckpt, verbose=True): 63 | print(" =========================== init Stable Diffusion pretrained checkpoint =========================== ") 64 | print(f"Loading model from {ckpt}") 65 | pl_sd = torch.load(ckpt, map_location="cpu") 66 | if "global_step" in pl_sd: 67 | print(f"Global Step: {pl_sd['global_step']}") 68 | sd = pl_sd["state_dict"] 69 | sd_keys = sd.keys() 70 | 71 | 72 | missing = [] 73 | text_encoder_sd = self.text_encoder.state_dict() 74 | for k in text_encoder_sd.keys(): 75 | sd_k = "cond_stage_model."+ k 76 | if sd_k in sd_keys: 77 | text_encoder_sd[k] = sd[sd_k] 78 | else: 79 | missing.append(k) 80 | 81 | self.text_encoder.load_state_dict(text_encoder_sd) 82 | 83 | 84 | 85 | 86 | def configure_optimizers(self): 87 | lr = self.learning_rate 88 | params = [] 89 | params += list(self.cc_projection.parameters()) 90 | 91 | 92 | params_dualattn = [] 93 | for k, v in self.model.named_parameters(): 94 | if "to_k_text" in k or "to_v_text" in k: 95 | params_dualattn.append(v) 96 | print("training weight: ", k) 97 | else: 98 | params.append(v) 99 | 100 | 101 | opt = torch.optim.AdamW([ 102 | {'params':params_dualattn, 'lr': lr*self.learning_rate_scale}, 103 | {'params': params, 'lr': lr} 104 | ]) 105 | 106 | 107 | if self.use_scheduler: 108 | assert 'target' in self.scheduler_config 109 | scheduler = instantiate_from_config(self.scheduler_config) 110 | 111 | print("Setting up LambdaLR scheduler...") 112 | scheduler = [ 113 | { 114 | 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 115 | 'interval': 'step', 116 | 'frequency': 1 117 | }] 118 | return [opt], scheduler 119 | return opt 120 | 121 | def training_step(self, batch, batch_idx): 122 | loss, loss_dict = self.shared_step(batch) 123 | 124 | self.log_dict(loss_dict, prog_bar=True, 125 | logger=True, on_step=True, on_epoch=True) 126 | 127 | self.log("global_step", self.global_step, 128 | prog_bar=True, logger=True, on_step=True, on_epoch=False) 129 | 130 | if self.use_scheduler: 131 | lr = self.optimizers().param_groups[0]['lr'] 132 | self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) 133 | 134 | return loss 135 | 136 | 137 | 138 | 139 | def shared_step(self, batch, **kwargs): 140 | 141 | if 'txt' in self.ucg_training: 142 | k = 'txt' 143 | p = self.ucg_training[k] 144 | for i in range(len(batch[k])): 145 | if self.ucg_prng.choice(2, p=[1 - p, p]): 146 | if isinstance(batch[k], list): 147 | batch[k][i] = "" 148 | 149 | with torch.no_grad(): 150 | text = batch['txt'] 151 | text_embedding = self.text_encoder(text) 152 | 153 | 154 | x, c = self.get_input(batch, self.first_stage_key) 155 | 156 | c["c_crossattn"].append(text_embedding) 157 | loss = self(x, c,) 158 | return loss 159 | 160 | 161 | def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs): 162 | 163 | if isinstance(cond, dict): 164 | # hybrid case, cond is exptected to be a dict 165 | pass 166 | else: 167 | if not isinstance(cond, list): 168 | cond = [cond] 169 | key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' 170 | cond = {key: cond} 171 | 172 | x_recon = self.model(x_noisy, t, **cond) 173 | 174 | if isinstance(x_recon, tuple) and not return_ids: 175 | return x_recon[0] 176 | else: 177 | return x_recon 
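Both customnet.py and the inpainting variant above use the same two-group optimizer setup in configure_optimizers: parameters whose names contain "to_k_text" or "to_v_text" (the text-branch key/value projections used by the dual attention) are trained at lr * learning_rate_scale, while cc_projection and every other UNet weight stay at the base learning rate. The following is a minimal, self-contained sketch of just that grouping rule, using a made-up stand-in module rather than the real UNet:

import torch
import torch.nn as nn

# Illustrative stand-in whose parameter names mimic the dual-attention projections.
model = nn.ModuleDict({
    "to_k_text": nn.Linear(8, 8),
    "to_v_text": nn.Linear(8, 8),
    "to_q": nn.Linear(8, 8),
})

lr = 1.0e-5
learning_rate_scale = 10

params_dualattn, params = [], []
for name, p in model.named_parameters():
    # Same split as CustomNet.configure_optimizers: text K/V projections go into
    # the high-lr group, everything else keeps the base learning rate.
    if "to_k_text" in name or "to_v_text" in name:
        params_dualattn.append(p)
    else:
        params.append(p)

opt = torch.optim.AdamW([
    {"params": params_dualattn, "lr": lr * learning_rate_scale},
    {"params": params, "lr": lr},
])

for group in opt.param_groups:
    print(len(group["params"]), "params at lr", group["lr"])

When use_scheduler is set, the real method additionally wraps this optimizer in a LambdaLR whose schedule comes from scheduler_config, stepping once per training step.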
-------------------------------------------------------------------------------- /custom_net/customnet_util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torchvision 4 | from torch import optim 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | import os 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import torch 12 | import time 13 | import cv2 14 | from carvekit.api.high import HiInterface 15 | import PIL 16 | import sys 17 | import os 18 | 19 | path = os.path.dirname(os.path.abspath(__file__)) 20 | from folder_paths import base_path 21 | paths = os.path.join(base_path,"custom_nodes","ComfyUI_CustomNet") 22 | # print(paths) 23 | sys.path.append(paths) 24 | 25 | def pil_rectangle_crop(im): 26 | width, height = im.size # Get dimensions 27 | 28 | if width <= height: 29 | left = 0 30 | right = width 31 | top = (height - width)/2 32 | bottom = (height + width)/2 33 | else: 34 | 35 | top = 0 36 | bottom = height 37 | left = (width - height) / 2 38 | bottom = (width + height) / 2 39 | 40 | # Crop the center of the image 41 | im = im.crop((left, top, right, bottom)) 42 | return im 43 | 44 | def add_margin(pil_img, color, size=256): 45 | width, height = pil_img.size 46 | result = Image.new(pil_img.mode, (size, size), color) 47 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) 48 | return result 49 | 50 | 51 | def create_carvekit_interface(): 52 | # Check doc strings for more information 53 | interface = HiInterface(object_type="object", # Can be "object" or "hairs-like". 54 | batch_size_seg=5, 55 | batch_size_matting=1, 56 | device='cuda' if torch.cuda.is_available() else 'cpu', 57 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net 58 | matting_mask_size=2048, 59 | trimap_prob_threshold=231, 60 | trimap_dilation=30, 61 | trimap_erosion_iters=5, 62 | fp16=False) 63 | 64 | return interface 65 | 66 | 67 | def load_and_preprocess(interface, input_im): 68 | ''' 69 | :param input_im (PIL Image). 70 | :return image (H, W, 3) array in [0, 1]. 71 | ''' 72 | # See https://github.com/Ir1d/image-background-remove-tool 73 | image = input_im.convert('RGB') 74 | 75 | image_without_background = interface([image])[0] 76 | image_without_background = np.array(image_without_background) 77 | est_seg = image_without_background > 127 78 | image = np.array(image) 79 | foreground = est_seg[:, : , -1].astype(np.bool_) 80 | image[~foreground] = [255., 255., 255.] 81 | x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) 82 | image = image[y:y+h, x:x+w, :] 83 | image = PIL.Image.fromarray(np.array(image)) 84 | # resize image such that long edge is 512 85 | image.thumbnail([200, 200], Image.Resampling.LANCZOS) 86 | image = add_margin(image, (255, 255, 255), size=256) 87 | #image = np.array(image) 88 | return image 89 | 90 | 91 | def log_txt_as_img(wh, xc, size=10): 92 | # wh a tuple of (width, height) 93 | # xc a list of captions to plot 94 | b = len(xc) 95 | txts = list() 96 | for bi in range(b): 97 | txt = Image.new("RGB", wh, color="white") 98 | draw = ImageDraw.Draw(txt) 99 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 100 | nc = int(40 * (wh[0] / 256)) 101 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 102 | 103 | try: 104 | draw.text((0, 0), lines, fill="black", font=font) 105 | except UnicodeEncodeError: 106 | print("Cant encode string for logging. 
Skipping.") 107 | 108 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 109 | txts.append(txt) 110 | txts = np.stack(txts) 111 | txts = torch.tensor(txts) 112 | return txts 113 | 114 | 115 | def ismap(x): 116 | if not isinstance(x, torch.Tensor): 117 | return False 118 | return (len(x.shape) == 4) and (x.shape[1] > 3) 119 | 120 | 121 | def isimage(x): 122 | if not isinstance(x,torch.Tensor): 123 | return False 124 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 125 | 126 | 127 | def exists(x): 128 | return x is not None 129 | 130 | 131 | def default(val, d): 132 | if exists(val): 133 | return val 134 | return d() if isfunction(d) else d 135 | 136 | 137 | def mean_flat(tensor): 138 | """ 139 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 140 | Take the mean over all non-batch dimensions. 141 | """ 142 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 143 | 144 | 145 | def count_params(model, verbose=False): 146 | total_params = sum(p.numel() for p in model.parameters()) 147 | if verbose: 148 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 149 | return total_params 150 | 151 | 152 | def instantiate_from_config(config): 153 | if not "target" in config: 154 | if config == '__is_first_stage__': 155 | return None 156 | elif config == "__is_unconditional__": 157 | return None 158 | raise KeyError("Expected key `target` to instantiate.") 159 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 160 | 161 | 162 | def get_obj_from_str(string, reload=False): 163 | module, cls = string.rsplit(".", 1) 164 | print(module,cls) 165 | if reload: 166 | module_imp = importlib.import_module(module) 167 | importlib.reload(module_imp) 168 | return getattr(importlib.import_module(module, package=None), cls) 169 | 170 | 171 | class AdamWwithEMAandWings(optim.Optimizer): 172 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 173 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 174 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 175 | ema_power=1., param_names=()): 176 | """AdamW that saves EMA versions of the parameters.""" 177 | if not 0.0 <= lr: 178 | raise ValueError("Invalid learning rate: {}".format(lr)) 179 | if not 0.0 <= eps: 180 | raise ValueError("Invalid epsilon value: {}".format(eps)) 181 | if not 0.0 <= betas[0] < 1.0: 182 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 183 | if not 0.0 <= betas[1] < 1.0: 184 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 185 | if not 0.0 <= weight_decay: 186 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 187 | if not 0.0 <= ema_decay <= 1.0: 188 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 189 | defaults = dict(lr=lr, betas=betas, eps=eps, 190 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 191 | ema_power=ema_power, param_names=param_names) 192 | super().__init__(params, defaults) 193 | 194 | def __setstate__(self, state): 195 | super().__setstate__(state) 196 | for group in self.param_groups: 197 | group.setdefault('amsgrad', False) 198 | 199 | @torch.no_grad() 200 | def step(self, closure=None): 201 | """Performs a single optimization step. 
202 | Args: 203 | closure (callable, optional): A closure that reevaluates the model 204 | and returns the loss. 205 | """ 206 | loss = None 207 | if closure is not None: 208 | with torch.enable_grad(): 209 | loss = closure() 210 | 211 | for group in self.param_groups: 212 | params_with_grad = [] 213 | grads = [] 214 | exp_avgs = [] 215 | exp_avg_sqs = [] 216 | ema_params_with_grad = [] 217 | state_sums = [] 218 | max_exp_avg_sqs = [] 219 | state_steps = [] 220 | amsgrad = group['amsgrad'] 221 | beta1, beta2 = group['betas'] 222 | ema_decay = group['ema_decay'] 223 | ema_power = group['ema_power'] 224 | 225 | for p in group['params']: 226 | if p.grad is None: 227 | continue 228 | params_with_grad.append(p) 229 | if p.grad.is_sparse: 230 | raise RuntimeError('AdamW does not support sparse gradients') 231 | grads.append(p.grad) 232 | 233 | state = self.state[p] 234 | 235 | # State initialization 236 | if len(state) == 0: 237 | state['step'] = 0 238 | # Exponential moving average of gradient values 239 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 240 | # Exponential moving average of squared gradient values 241 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 242 | if amsgrad: 243 | # Maintains max of all exp. moving avg. of sq. grad. values 244 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 245 | # Exponential moving average of parameter values 246 | state['param_exp_avg'] = p.detach().float().clone() 247 | 248 | exp_avgs.append(state['exp_avg']) 249 | exp_avg_sqs.append(state['exp_avg_sq']) 250 | ema_params_with_grad.append(state['param_exp_avg']) 251 | 252 | if amsgrad: 253 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 254 | 255 | # update the steps for each param group update 256 | state['step'] += 1 257 | # record the step after step update 258 | state_steps.append(state['step']) 259 | 260 | optim._functional.adamw(params_with_grad, 261 | grads, 262 | exp_avgs, 263 | exp_avg_sqs, 264 | max_exp_avg_sqs, 265 | state_steps, 266 | amsgrad=amsgrad, 267 | beta1=beta1, 268 | beta2=beta2, 269 | lr=group['lr'], 270 | weight_decay=group['weight_decay'], 271 | eps=group['eps'], 272 | maximize=False) 273 | 274 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 275 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 276 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 277 | 278 | return loss 279 | 280 | def get_state_dict(d): 281 | return d.get('state_dict', d) 282 | 283 | def load_state_dict(ckpt_path, location='cpu'): 284 | _, extension = os.path.splitext(ckpt_path) 285 | if extension.lower() == ".safetensors": 286 | import safetensors.torch 287 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 288 | else: 289 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 290 | state_dict = get_state_dict(state_dict) 291 | print(f'Loaded state_dict from [{ckpt_path}]') 292 | return state_dict 293 | 294 | 295 | def img2tensor(imgs, bgr2rgb=True, float32=True): 296 | """Numpy array to tensor. 297 | 298 | Args: 299 | imgs (list[ndarray] | ndarray): Input images. 300 | bgr2rgb (bool): Whether to change bgr to rgb. 301 | float32 (bool): Whether to change to float32. 302 | 303 | Returns: 304 | list[tensor] | tensor: Tensor images. If returned results only have 305 | one element, just return tensor. 
306 | """ 307 | 308 | def _totensor(img, bgr2rgb, float32): 309 | if img.shape[2] == 3 and bgr2rgb: 310 | if img.dtype == 'float64': 311 | img = img.astype('float32') 312 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 313 | img = torch.from_numpy(img.transpose(2, 0, 1)) 314 | if float32: 315 | img = img.float() 316 | return img 317 | 318 | if isinstance(imgs, list): 319 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 320 | else: 321 | return _totensor(imgs, bgr2rgb, float32) 322 | -------------------------------------------------------------------------------- /custom_net/dddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | from einops import rearrange 8 | 9 | from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor 10 | from .sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding 11 | 12 | 13 | class DDIMSampler(object): 14 | def __init__(self, model, schedule="linear", **kwargs): 15 | super().__init__() 16 | self.model = model 17 | self.ddpm_num_timesteps = model.num_timesteps 18 | self.schedule = schedule 19 | 20 | def to(self, device): 21 | """Same as to in torch module 22 | Don't really underestand why this isn't a module in the first place""" 23 | for k, v in self.__dict__.items(): 24 | if isinstance(v, torch.Tensor): 25 | new_v = getattr(self, k).to(device) 26 | setattr(self, k, new_v) 27 | 28 | 29 | def register_buffer(self, name, attr): 30 | if type(attr) == torch.Tensor: 31 | if attr.device != torch.device("cuda"): 32 | attr = attr.to(torch.device("cuda")) 33 | setattr(self, name, attr) 34 | 35 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 36 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 37 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 38 | alphas_cumprod = self.model.alphas_cumprod 39 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 40 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 41 | 42 | self.register_buffer('betas', to_torch(self.model.betas)) 43 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 44 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 45 | 46 | # calculations for diffusion q(x_t | x_{t-1}) and others 47 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 48 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 49 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 50 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 51 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 52 | 53 | # ddim sampling parameters 54 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 55 | ddim_timesteps=self.ddim_timesteps, 56 | eta=ddim_eta,verbose=verbose) 57 | self.register_buffer('ddim_sigmas', ddim_sigmas) 58 | self.register_buffer('ddim_alphas', ddim_alphas) 59 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 60 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 61 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 62 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 63 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 64 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 65 | 66 | @torch.no_grad() 67 | def sample(self, 68 | S, 69 | batch_size, 70 | shape, 71 | conditioning=None, 72 | callback=None, 73 | normals_sequence=None, 74 | img_callback=None, 75 | quantize_x0=False, 76 | eta=0., 77 | mask=None, 78 | x0=None, 79 | temperature=1., 80 | noise_dropout=0., 81 | score_corrector=None, 82 | corrector_kwargs=None, 83 | verbose=True, 84 | x_T=None, 85 | log_every_t=100, 86 | unconditional_guidance_scale=1., 87 | unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 88 | dynamic_threshold=None, 89 | **kwargs 90 | ): 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | ctmp = conditioning[list(conditioning.keys())[0]] 94 | while isinstance(ctmp, list): ctmp = ctmp[0] 95 | cbs = ctmp.shape[0] 96 | if cbs != batch_size: 97 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 98 | 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 104 | # sampling 105 | C, H, W = shape 106 | size = (batch_size, C, H, W) 107 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 108 | 109 | samples, intermediates = self.ddim_sampling(conditioning, size, 110 | callback=callback, 111 | img_callback=img_callback, 112 | quantize_denoised=quantize_x0, 113 | mask=mask, x0=x0, 114 | ddim_use_original_steps=False, 115 | noise_dropout=noise_dropout, 116 | temperature=temperature, 117 | score_corrector=score_corrector, 118 | corrector_kwargs=corrector_kwargs, 119 | x_T=x_T, 120 | log_every_t=log_every_t, 121 | unconditional_guidance_scale=unconditional_guidance_scale, 122 | unconditional_conditioning=unconditional_conditioning, 123 | dynamic_threshold=dynamic_threshold, 124 | ) 125 | return samples, intermediates 126 | 127 | @torch.no_grad() 128 | def ddim_sampling(self, cond, shape, 129 | x_T=None, ddim_use_original_steps=False, 130 | callback=None, timesteps=None, quantize_denoised=False, 131 | mask=None, x0=None, img_callback=None, log_every_t=100, 132 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 133 | unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, 134 | t_start=-1): 135 | device = self.model.betas.device 136 | b = shape[0] 137 | if x_T is None: 138 | img = torch.randn(shape, device=device) 139 | else: 140 | img = x_T 141 | 142 | if timesteps is None: 143 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 144 | elif timesteps is not None and not ddim_use_original_steps: 145 | 
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 146 | timesteps = self.ddim_timesteps[:subset_end] 147 | 148 | timesteps = timesteps[:t_start] 149 | 150 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 151 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 152 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 153 | print(f"Running DDIM Sampling with {total_steps} timesteps") 154 | 155 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 156 | 157 | for i, step in enumerate(iterator): 158 | index = total_steps - i - 1 159 | ts = torch.full((b,), step, device=device, dtype=torch.long) 160 | 161 | if mask is not None: 162 | assert x0 is not None 163 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 164 | img = img_orig * mask + (1. - mask) * img 165 | 166 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 167 | quantize_denoised=quantize_denoised, temperature=temperature, 168 | noise_dropout=noise_dropout, score_corrector=score_corrector, 169 | corrector_kwargs=corrector_kwargs, 170 | unconditional_guidance_scale=unconditional_guidance_scale, 171 | unconditional_conditioning=unconditional_conditioning, 172 | dynamic_threshold=dynamic_threshold) 173 | img, pred_x0 = outs 174 | if callback: 175 | img = callback(i, img, pred_x0) 176 | if img_callback: img_callback(pred_x0, i) 177 | 178 | if index % log_every_t == 0 or index == total_steps - 1: 179 | intermediates['x_inter'].append(img) 180 | intermediates['pred_x0'].append(pred_x0) 181 | 182 | return img, intermediates 183 | 184 | @torch.no_grad() 185 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 186 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 187 | unconditional_guidance_scale=1., unconditional_conditioning=None, 188 | dynamic_threshold=None, **kwargs): 189 | b, *_, device = *x.shape, x.device 190 | 191 | 192 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 193 | e_t = self.model.apply_model(x, t, c) 194 | else: 195 | x_in = torch.cat([x] * 2) 196 | t_in = torch.cat([t] * 2) 197 | if isinstance(c, dict): 198 | assert isinstance(unconditional_conditioning, dict) 199 | c_in = dict() 200 | for k in c: 201 | if isinstance(c[k], list): 202 | c_in[k] = [torch.cat([ 203 | unconditional_conditioning[k][i], 204 | c[k][i]]) for i in range(len(c[k]))] 205 | else: 206 | c_in[k] = torch.cat([ 207 | unconditional_conditioning[k], 208 | c[k]]) 209 | else: 210 | c_in = torch.cat([unconditional_conditioning, c]) 211 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 212 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 213 | 214 | if score_corrector is not None: 215 | assert self.model.parameterization == "eps" 216 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 217 | 218 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 219 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 220 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 221 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 222 | # select parameters corresponding to the currently 
considered timestep 223 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 224 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 225 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 226 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 227 | 228 | # current prediction for x_0 229 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 230 | if quantize_denoised: 231 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 232 | 233 | if dynamic_threshold is not None: 234 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 235 | 236 | # direction pointing to x_t 237 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 238 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 239 | if noise_dropout > 0.: 240 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 241 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 242 | return x_prev, pred_x0 243 | 244 | @torch.no_grad() 245 | def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, 246 | unconditional_guidance_scale=1.0, unconditional_conditioning=None): 247 | num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] 248 | 249 | assert t_enc <= num_reference_steps 250 | num_steps = t_enc 251 | 252 | if use_original_steps: 253 | alphas_next = self.alphas_cumprod[:num_steps] 254 | alphas = self.alphas_cumprod_prev[:num_steps] 255 | else: 256 | alphas_next = self.ddim_alphas[:num_steps] 257 | alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) 258 | 259 | x_next = x0 260 | intermediates = [] 261 | inter_steps = [] 262 | for i in tqdm(range(num_steps), desc='Encoding Image'): 263 | t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) 264 | if unconditional_guidance_scale == 1.: 265 | noise_pred = self.model.apply_model(x_next, t, c) 266 | else: 267 | assert unconditional_conditioning is not None 268 | e_t_uncond, noise_pred = torch.chunk( 269 | self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), 270 | torch.cat((unconditional_conditioning, c))), 2) 271 | noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) 272 | 273 | xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next 274 | weighted_noise_pred = alphas_next[i].sqrt() * ( 275 | (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred 276 | x_next = xt_weighted + weighted_noise_pred 277 | if return_intermediates and i % ( 278 | num_steps // return_intermediates) == 0 and i < num_steps - 1: 279 | intermediates.append(x_next) 280 | inter_steps.append(i) 281 | elif return_intermediates and i >= num_steps - 2: 282 | intermediates.append(x_next) 283 | inter_steps.append(i) 284 | 285 | out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} 286 | if return_intermediates: 287 | out.update({'intermediates': intermediates}) 288 | return x_next, out 289 | 290 | @torch.no_grad() 291 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 292 | # fast, but does not allow for exact reconstruction 293 | # t serves as an index to gather the correct alphas 294 | if use_original_steps: 295 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 296 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 297 | else: 298 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 299 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 300 | 301 | if noise is None: 302 | 
noise = torch.randn_like(x0) 303 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 304 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 305 | 306 | @torch.no_grad() 307 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 308 | use_original_steps=False): 309 | 310 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 311 | timesteps = timesteps[:t_start] 312 | 313 | time_range = np.flip(timesteps) 314 | total_steps = timesteps.shape[0] 315 | print(f"Running DDIM Sampling with {total_steps} timesteps") 316 | 317 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 318 | x_dec = x_latent 319 | for i, step in enumerate(iterator): 320 | index = total_steps - i - 1 321 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 322 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 323 | unconditional_guidance_scale=unconditional_guidance_scale, 324 | unconditional_conditioning=unconditional_conditioning) 325 | return x_dec -------------------------------------------------------------------------------- /custom_net/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | from einops import rearrange 8 | 9 | from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor 10 | from .sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding 11 | 12 | 13 | class DDIMSampler(object): 14 | def __init__(self, model, schedule="linear", **kwargs): 15 | super().__init__() 16 | self.model = model 17 | self.ddpm_num_timesteps = model.num_timesteps 18 | self.schedule = schedule 19 | 20 | def to(self, device): 21 | """Same as to in torch module 22 | Don't really underestand why this isn't a module in the first place""" 23 | for k, v in self.__dict__.items(): 24 | if isinstance(v, torch.Tensor): 25 | new_v = getattr(self, k).to(device) 26 | setattr(self, k, new_v) 27 | 28 | 29 | def register_buffer(self, name, attr): 30 | if type(attr) == torch.Tensor: 31 | if attr.device != torch.device("cuda"): 32 | attr = attr.to(torch.device("cuda")) 33 | setattr(self, name, attr) 34 | 35 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 36 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 37 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 38 | alphas_cumprod = self.model.alphas_cumprod 39 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 40 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 41 | 42 | self.register_buffer('betas', to_torch(self.model.betas)) 43 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 44 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 45 | 46 | # calculations for diffusion q(x_t | x_{t-1}) and others 47 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 48 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod.cpu()))) 49 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 50 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 51 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 52 | 53 | # ddim sampling parameters 54 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 55 | ddim_timesteps=self.ddim_timesteps, 56 | eta=ddim_eta,verbose=verbose) 57 | self.register_buffer('ddim_sigmas', ddim_sigmas) 58 | self.register_buffer('ddim_alphas', ddim_alphas) 59 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 60 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 61 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 62 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 63 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 64 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 65 | 66 | @torch.no_grad() 67 | def sample(self, 68 | S, 69 | batch_size, 70 | shape, 71 | conditioning=None, 72 | callback=None, 73 | normals_sequence=None, 74 | img_callback=None, 75 | quantize_x0=False, 76 | eta=0., 77 | mask=None, 78 | x0=None, 79 | temperature=1., 80 | noise_dropout=0., 81 | score_corrector=None, 82 | corrector_kwargs=None, 83 | verbose=True, 84 | x_T=None, 85 | log_every_t=100, 86 | unconditional_guidance_scale=1., 87 | unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 88 | dynamic_threshold=None, 89 | **kwargs 90 | ): 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | ctmp = conditioning[list(conditioning.keys())[0]] 94 | while isinstance(ctmp, list): ctmp = ctmp[0] 95 | cbs = ctmp.shape[0] 96 | if cbs != batch_size: 97 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 98 | 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 104 | # sampling 105 | C, H, W = shape 106 | size = (batch_size, C, H, W) 107 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 108 | 109 | samples, intermediates = self.ddim_sampling(conditioning, size, 110 | callback=callback, 111 | img_callback=img_callback, 112 | quantize_denoised=quantize_x0, 113 | mask=mask, x0=x0, 114 | ddim_use_original_steps=False, 115 | noise_dropout=noise_dropout, 116 | temperature=temperature, 117 | score_corrector=score_corrector, 118 | corrector_kwargs=corrector_kwargs, 119 | x_T=x_T, 120 | log_every_t=log_every_t, 121 | unconditional_guidance_scale=unconditional_guidance_scale, 122 | unconditional_conditioning=unconditional_conditioning, 123 | dynamic_threshold=dynamic_threshold, 124 | **kwargs) 125 | return samples, intermediates 126 | 127 | @torch.no_grad() 128 | def ddim_sampling(self, cond, shape, 129 | x_T=None, ddim_use_original_steps=False, 130 | callback=None, timesteps=None, quantize_denoised=False, 131 | mask=None, x0=None, img_callback=None, log_every_t=100, 132 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 133 | unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, 134 | t_start=-1, 135 | **kwargs): 136 | device = 
self.model.betas.device 137 | b = shape[0] 138 | if x_T is None: 139 | img = torch.randn(shape, device=device) 140 | else: 141 | img = x_T 142 | 143 | if timesteps is None: 144 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 145 | elif timesteps is not None and not ddim_use_original_steps: 146 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 147 | timesteps = self.ddim_timesteps[:subset_end] 148 | 149 | timesteps = timesteps[:t_start] 150 | 151 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 152 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 153 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 154 | print(f"Running DDIM Sampling with {total_steps} timesteps") 155 | 156 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 157 | 158 | for i, step in enumerate(iterator): 159 | index = total_steps - i - 1 160 | ts = torch.full((b,), step, device=device, dtype=torch.long) 161 | 162 | if mask is not None: 163 | assert x0 is not None 164 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 165 | img = img_orig * mask + (1. - mask) * img 166 | 167 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 168 | quantize_denoised=quantize_denoised, temperature=temperature, 169 | noise_dropout=noise_dropout, score_corrector=score_corrector, 170 | corrector_kwargs=corrector_kwargs, 171 | unconditional_guidance_scale=unconditional_guidance_scale, 172 | unconditional_conditioning=unconditional_conditioning, 173 | dynamic_threshold=dynamic_threshold, 174 | **kwargs) 175 | img, pred_x0 = outs 176 | if callback: 177 | img = callback(i, img, pred_x0) 178 | if img_callback: img_callback(pred_x0, i) 179 | 180 | if index % log_every_t == 0 or index == total_steps - 1: 181 | intermediates['x_inter'].append(img) 182 | intermediates['pred_x0'].append(pred_x0) 183 | 184 | return img, intermediates 185 | 186 | @torch.no_grad() 187 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 188 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 189 | unconditional_guidance_scale=1., unconditional_conditioning=None, 190 | dynamic_threshold=None, **kwargs): 191 | b, *_, device = *x.shape, x.device 192 | 193 | 194 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 195 | e_t = self.model.apply_model(x, t, c) 196 | else: 197 | cfg_type = kwargs.get("cfg_type", 0) 198 | cfg = kwargs.get("cfg_scale_dict", {"img": 0., "text":0., "all": 3.0 }) 199 | truncate_step = kwargs.get("cfg_type", 250) ## 250 -1 200 | 201 | if cfg_type == 0 : 202 | model_t = self.model.apply_model(x, t, c) 203 | 204 | model_uc = self.model.apply_model(x, t, unconditional_conditioning) 205 | 206 | e_t = model_uc + cfg["all"]*(model_t - model_uc) 207 | 208 | elif cfg_type==1: 209 | 210 | model_t = self.model.apply_model(x, t, c) 211 | 212 | c_text = {"c_concat": unconditional_conditioning["c_concat"], 213 | "c_crossattn": [unconditional_conditioning["c_crossattn"][0], 214 | c["c_crossattn"][1],]} 215 | 216 | c_img = {"c_concat": c["c_concat"], 217 | "c_crossattn": [c["c_crossattn"][0], 218 | unconditional_conditioning["c_crossattn"][1],]} 219 | 220 | model_t_text = self.model.apply_model(x, t, c_text) 221 | model_t_img = self.model.apply_model(x, t, c_img) 222 | model_uc = 
self.model.apply_model(x, t, unconditional_conditioning) 223 | 224 | e_t = model_uc + cfg["all"]*(model_t - model_uc) 225 | if t[0] > truncate_step: 226 | print(1, t[0]) 227 | e_t += cfg["img"] * (model_t_img - model_uc) 228 | e_t += cfg["text"] * (model_t_text - model_uc) 229 | 230 | 231 | if score_corrector is not None: 232 | assert self.model.parameterization == "eps" 233 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 234 | 235 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 236 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 237 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 238 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 239 | # select parameters corresponding to the currently considered timestep 240 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 241 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 242 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 243 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 244 | 245 | # current prediction for x_0 246 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 247 | if quantize_denoised: 248 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 249 | 250 | if dynamic_threshold is not None: 251 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 252 | 253 | # direction pointing to x_t 254 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 255 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 256 | if noise_dropout > 0.: 257 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 258 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 259 | return x_prev, pred_x0 260 | 261 | @torch.no_grad() 262 | def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, 263 | unconditional_guidance_scale=1.0, unconditional_conditioning=None): 264 | num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] 265 | 266 | assert t_enc <= num_reference_steps 267 | num_steps = t_enc 268 | 269 | if use_original_steps: 270 | alphas_next = self.alphas_cumprod[:num_steps] 271 | alphas = self.alphas_cumprod_prev[:num_steps] 272 | else: 273 | alphas_next = self.ddim_alphas[:num_steps] 274 | alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) 275 | 276 | x_next = x0 277 | intermediates = [] 278 | inter_steps = [] 279 | for i in tqdm(range(num_steps), desc='Encoding Image'): 280 | t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) 281 | if unconditional_guidance_scale == 1.: 282 | noise_pred = self.model.apply_model(x_next, t, c) 283 | else: 284 | assert unconditional_conditioning is not None 285 | e_t_uncond, noise_pred = torch.chunk( 286 | self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), 287 | torch.cat((unconditional_conditioning, c))), 2) 288 | noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) 289 | 290 | xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next 291 | weighted_noise_pred = alphas_next[i].sqrt() * ( 292 | (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred 293 | x_next = xt_weighted + weighted_noise_pred 294 | if return_intermediates and i % ( 295 | num_steps // return_intermediates) == 
0 and i < num_steps - 1: 296 | intermediates.append(x_next) 297 | inter_steps.append(i) 298 | elif return_intermediates and i >= num_steps - 2: 299 | intermediates.append(x_next) 300 | inter_steps.append(i) 301 | 302 | out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} 303 | if return_intermediates: 304 | out.update({'intermediates': intermediates}) 305 | return x_next, out 306 | 307 | @torch.no_grad() 308 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 309 | # fast, but does not allow for exact reconstruction 310 | # t serves as an index to gather the correct alphas 311 | if use_original_steps: 312 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 313 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 314 | else: 315 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 316 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 317 | 318 | if noise is None: 319 | noise = torch.randn_like(x0) 320 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 321 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 322 | 323 | @torch.no_grad() 324 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 325 | use_original_steps=False): 326 | 327 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 328 | timesteps = timesteps[:t_start] 329 | 330 | time_range = np.flip(timesteps) 331 | total_steps = timesteps.shape[0] 332 | print(f"Running DDIM Sampling with {total_steps} timesteps") 333 | 334 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 335 | x_dec = x_latent 336 | for i, step in enumerate(iterator): 337 | index = total_steps - i - 1 338 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 339 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 340 | unconditional_guidance_scale=unconditional_guidance_scale, 341 | unconditional_conditioning=unconditional_conditioning) 342 | return x_dec 343 | 344 | -------------------------------------------------------------------------------- /custom_net/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * 
torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /custom_net/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 
| assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /custom_net/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from functools import partial 6 | import kornia 7 | from .x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 8 | from transformers import CLIPTokenizer, CLIPTextModel 9 | import torch.nn.functional as F 10 | from torchvision import transforms 11 | import random 12 | from .customnet_util import default, instantiate_from_config 13 | from .util import make_beta_schedule, extract_into_tensor, noise_like 14 | import clip 15 | 16 | class AbstractEncoder(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def encode(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | 24 | 25 | def disabled_train(self, mode=True): 26 | """Overwrite model.train with this function to make sure train/eval mode 27 | does not change anymore.""" 28 | return self 29 | 30 | 31 | class FrozenCLIPEmbedder(AbstractEncoder): 32 | """Uses the CLIP transformer encoder for text (from huggingface)""" 33 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 34 | super().__init__() 35 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 36 | self.transformer = CLIPTextModel.from_pretrained(version) 37 | self.device = device 38 | self.max_length = max_length # TODO: typical value? 39 | self.freeze() 40 | 41 | def freeze(self): 42 | self.transformer = self.transformer.eval() 43 | #self.train = disabled_train 44 | for param in self.parameters(): 45 | param.requires_grad = False 46 | 47 | def forward(self, text, return_pool=False): 48 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 49 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 50 | tokens = batch_encoding["input_ids"].to(self.device) 51 | outputs = self.transformer(input_ids=tokens) 52 | 53 | z = outputs.last_hidden_state 54 | if return_pool: 55 | return z, outputs.pooler_output 56 | else: 57 | return z 58 | 59 | def encode(self, text): 60 | return self(text) 61 | 62 | 63 | class FrozenCLIPImageEmbedder(AbstractEncoder): 64 | """ 65 | Uses the CLIP image encoder. 66 | Not actually frozen... 
If you want that set cond_stage_trainable=False in cfg 67 | """ 68 | def __init__( 69 | self, 70 | model='ViT-L/14', 71 | jit=False, 72 | device='cpu', 73 | antialias=False, 74 | ): 75 | super().__init__() 76 | self.model, _ = clip.load(name=model, device=device, jit=jit,) 77 | # We don't use the text part so delete it 78 | del self.model.transformer 79 | self.antialias = antialias 80 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 81 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 82 | 83 | def preprocess(self, x): 84 | # Expects inputs in the range -1, 1 85 | # x = kornia.geometry.resize(x, (224, 224), 86 | # interpolation='bicubic',align_corners=True, 87 | # antialias=self.antialias) 88 | 89 | x = kornia.geometry.resize(x, (224, 224), 90 | interpolation='bicubic',align_corners=True) 91 | 92 | x = (x + 1.) / 2. 93 | # renormalize according to clip 94 | x = kornia.enhance.normalize(x, self.mean, self.std) 95 | return x 96 | 97 | def forward(self, x): 98 | # x is assumed to be in range [-1,1] 99 | if isinstance(x, list): 100 | # [""] denotes condition dropout for ucg 101 | device = self.model.visual.conv1.weight.device 102 | return torch.zeros(1, 768, device=device) 103 | return self.model.encode_image(self.preprocess(x)).float() 104 | 105 | def encode(self, im): 106 | return self(im).unsqueeze(1) 107 | 108 | -------------------------------------------------------------------------------- /custom_net/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from .sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | ctmp = conditioning[list(conditioning.keys())[0]] 87 | while isinstance(ctmp, list): ctmp = ctmp[0] 88 | cbs = ctmp.shape[0] 89 | if cbs != batch_size: 90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 91 | else: 92 | if conditioning.shape[0] != batch_size: 93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 94 | 95 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 96 | # sampling 97 | C, H, W = shape 98 | size = (batch_size, C, H, W) 99 | print(f'Data shape for PLMS sampling is {size}') 100 | 101 | samples, intermediates = self.plms_sampling(conditioning, size, 102 | callback=callback, 103 | img_callback=img_callback, 104 | quantize_denoised=quantize_x0, 105 | mask=mask, x0=x0, 106 | ddim_use_original_steps=False, 107 | noise_dropout=noise_dropout, 108 | temperature=temperature, 109 | score_corrector=score_corrector, 110 | corrector_kwargs=corrector_kwargs, 111 | x_T=x_T, 112 | log_every_t=log_every_t, 113 | unconditional_guidance_scale=unconditional_guidance_scale, 114 | unconditional_conditioning=unconditional_conditioning, 115 | dynamic_threshold=dynamic_threshold, 116 | ) 117 | return samples, intermediates 118 | 119 | @torch.no_grad() 120 | def plms_sampling(self, cond, shape, 121 | x_T=None, ddim_use_original_steps=False, 122 | callback=None, timesteps=None, quantize_denoised=False, 123 | mask=None, x0=None, img_callback=None, log_every_t=100, 124 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 125 | unconditional_guidance_scale=1., unconditional_conditioning=None, 126 | dynamic_threshold=None): 127 | device = self.model.betas.device 128 | b = shape[0] 129 | if x_T is None: 130 | img = torch.randn(shape, device=device) 131 | else: 132 | img = x_T 133 | 134 | if timesteps is None: 
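# If no explicit `timesteps` value is given, fall back to the full DDPM schedule or the
# precomputed DDIM subsequence; if an integer step count is given while original steps are
# not used, the elif branch below trims the DDIM schedule to at most that many entries.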
135 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 136 | elif timesteps is not None and not ddim_use_original_steps: 137 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 138 | timesteps = self.ddim_timesteps[:subset_end] 139 | 140 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 141 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 142 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 143 | print(f"Running PLMS Sampling with {total_steps} timesteps") 144 | 145 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 146 | old_eps = [] 147 | 148 | for i, step in enumerate(iterator): 149 | index = total_steps - i - 1 150 | ts = torch.full((b,), step, device=device, dtype=torch.long) 151 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 152 | 153 | if mask is not None: 154 | assert x0 is not None 155 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 156 | img = img_orig * mask + (1. - mask) * img 157 | 158 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 159 | quantize_denoised=quantize_denoised, temperature=temperature, 160 | noise_dropout=noise_dropout, score_corrector=score_corrector, 161 | corrector_kwargs=corrector_kwargs, 162 | unconditional_guidance_scale=unconditional_guidance_scale, 163 | unconditional_conditioning=unconditional_conditioning, 164 | old_eps=old_eps, t_next=ts_next, 165 | dynamic_threshold=dynamic_threshold) 166 | img, pred_x0, e_t = outs 167 | old_eps.append(e_t) 168 | if len(old_eps) >= 4: 169 | old_eps.pop(0) 170 | if callback: callback(i) 171 | if img_callback: img_callback(pred_x0, i) 172 | 173 | if index % log_every_t == 0 or index == total_steps - 1: 174 | intermediates['x_inter'].append(img) 175 | intermediates['pred_x0'].append(pred_x0) 176 | 177 | return img, intermediates 178 | 179 | @torch.no_grad() 180 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 181 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 182 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 183 | dynamic_threshold=None): 184 | b, *_, device = *x.shape, x.device 185 | 186 | def get_model_output(x, t): 187 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 188 | e_t = self.model.apply_model(x, t, c) 189 | else: 190 | x_in = torch.cat([x] * 2) 191 | t_in = torch.cat([t] * 2) 192 | if isinstance(c, dict): 193 | assert isinstance(unconditional_conditioning, dict) 194 | c_in = dict() 195 | for k in c: 196 | if isinstance(c[k], list): 197 | c_in[k] = [torch.cat([ 198 | unconditional_conditioning[k][i], 199 | c[k][i]]) for i in range(len(c[k]))] 200 | else: 201 | c_in[k] = torch.cat([ 202 | unconditional_conditioning[k], 203 | c[k]]) 204 | else: 205 | c_in = torch.cat([unconditional_conditioning, c]) 206 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 207 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 208 | 209 | if score_corrector is not None: 210 | assert self.model.parameterization == "eps" 211 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 212 | 213 | return e_t 214 | 215 | alphas = self.model.alphas_cumprod if 
use_original_steps else self.ddim_alphas 216 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 217 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 218 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 219 | 220 | def get_x_prev_and_pred_x0(e_t, index): 221 | # select parameters corresponding to the currently considered timestep 222 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 223 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 224 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 225 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 226 | 227 | # current prediction for x_0 228 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 229 | if quantize_denoised: 230 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 231 | if dynamic_threshold is not None: 232 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 233 | # direction pointing to x_t 234 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 235 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 236 | if noise_dropout > 0.: 237 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 238 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 239 | return x_prev, pred_x0 240 | 241 | e_t = get_model_output(x, t) 242 | if len(old_eps) == 0: 243 | # Pseudo Improved Euler (2nd order) 244 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 245 | e_t_next = get_model_output(x_prev, t_next) 246 | e_t_prime = (e_t + e_t_next) / 2 247 | elif len(old_eps) == 1: 248 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 249 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 250 | elif len(old_eps) == 2: 251 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 252 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 253 | elif len(old_eps) >= 3: 254 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 255 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 256 | 257 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 258 | 259 | return x_prev, pred_x0, e_t 260 | -------------------------------------------------------------------------------- /custom_net/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... 
-> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /custom_net/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | import importlib 18 | 19 | def get_obj_from_str(string, reload=False): 20 | module, cls = string.rsplit(".", 1) 21 | if reload: 22 | module_imp = importlib.import_module(module) 23 | importlib.reload(module_imp) 24 | return getattr(importlib.import_module(module, package=None), cls) 25 | 26 | def instantiate_from_config(config): 27 | if not "target" in config: 28 | if config == '__is_first_stage__': 29 | return None 30 | elif config == "__is_unconditional__": 31 | return None 32 | raise KeyError("Expected key `target` to instantiate.") 33 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 34 | 35 | 36 | 37 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 38 | if schedule == "linear": 39 | betas = ( 40 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 41 | ) 42 | 43 | elif schedule == "cosine": 44 | timesteps = ( 45 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 46 | ) 47 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 48 | alphas = torch.cos(alphas).pow(2) 49 | alphas = alphas / alphas[0] 50 | betas = 1 - alphas[1:] / alphas[:-1] 51 | betas = np.clip(betas, a_min=0, a_max=0.999) 52 | 53 | elif schedule == "sqrt_linear": 54 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 55 | elif schedule == "sqrt": 56 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 57 | else: 58 | raise ValueError(f"schedule '{schedule}' unknown.") 59 | return betas.numpy() 60 | 61 | 62 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 63 | if ddim_discr_method == 'uniform': 64 | c = num_ddpm_timesteps // num_ddim_timesteps 65 | ddim_timesteps = 
np.asarray(list(range(0, num_ddpm_timesteps, c))) 66 | elif ddim_discr_method == 'quad': 67 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | alphas = alphacums[ddim_timesteps] 82 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 83 | 84 | # according the the formula provided in https://arxiv.org/abs/2010.02502 85 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 86 | if verbose: 87 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 88 | print(f'For the chosen value of eta, which is {eta}, ' 89 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 90 | return sigmas, alphas, alphas_prev 91 | 92 | 93 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 94 | """ 95 | Create a beta schedule that discretizes the given alpha_t_bar function, 96 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 97 | :param num_diffusion_timesteps: the number of betas to produce. 98 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 99 | produces the cumulative product of (1-beta) up to that 100 | part of the diffusion process. 101 | :param max_beta: the maximum beta to use; use values lower than 1 to 102 | prevent singularities. 103 | """ 104 | betas = [] 105 | for i in range(num_diffusion_timesteps): 106 | t1 = i / num_diffusion_timesteps 107 | t2 = (i + 1) / num_diffusion_timesteps 108 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 109 | return np.array(betas) 110 | 111 | 112 | def extract_into_tensor(a, t, x_shape): 113 | b, *_ = t.shape 114 | out = a.gather(-1, t) 115 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 116 | 117 | 118 | def checkpoint(func, inputs, params, flag): 119 | """ 120 | Evaluate a function without caching intermediate activations, allowing for 121 | reduced memory at the expense of extra compute in the backward pass. 122 | :param func: the function to evaluate. 123 | :param inputs: the argument sequence to pass to `func`. 124 | :param params: a sequence of parameters `func` depends on but does not 125 | explicitly take as arguments. 126 | :param flag: if False, disable gradient checkpointing. 
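    Example (illustrative only; `block` stands for any hypothetical nn.Module):
        y = checkpoint(block.forward, (x,), block.parameters(), True)
    With flag set, the forward pass runs under torch.no_grad() and the inputs are
    re-evaluated with gradients enabled during the backward pass.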
127 | """ 128 | if flag: 129 | args = tuple(inputs) + tuple(params) 130 | return CheckpointFunction.apply(func, len(inputs), *args) 131 | else: 132 | return func(*inputs) 133 | 134 | 135 | class CheckpointFunction(torch.autograd.Function): 136 | @staticmethod 137 | def forward(ctx, run_function, length, *args): 138 | ctx.run_function = run_function 139 | ctx.input_tensors = list(args[:length]) 140 | ctx.input_params = list(args[length:]) 141 | 142 | with torch.no_grad(): 143 | output_tensors = ctx.run_function(*ctx.input_tensors) 144 | return output_tensors 145 | 146 | @staticmethod 147 | def backward(ctx, *output_grads): 148 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 149 | with torch.enable_grad(): 150 | # Fixes a bug where the first op in run_function modifies the 151 | # Tensor storage in place, which is not allowed for detach()'d 152 | # Tensors. 153 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 154 | output_tensors = ctx.run_function(*shallow_copies) 155 | input_grads = torch.autograd.grad( 156 | output_tensors, 157 | ctx.input_tensors + ctx.input_params, 158 | output_grads, 159 | allow_unused=True, 160 | ) 161 | del ctx.input_tensors 162 | del ctx.input_params 163 | del output_tensors 164 | return (None, None) + input_grads 165 | 166 | 167 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 168 | """ 169 | Create sinusoidal timestep embeddings. 170 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 171 | These may be fractional. 172 | :param dim: the dimension of the output. 173 | :param max_period: controls the minimum frequency of the embeddings. 174 | :return: an [N x dim] Tensor of positional embeddings. 175 | """ 176 | if not repeat_only: 177 | half = dim // 2 178 | freqs = torch.exp( 179 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 180 | ).to(device=timesteps.device) 181 | args = timesteps[:, None].float() * freqs[None] 182 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 183 | if dim % 2: 184 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 185 | else: 186 | embedding = repeat(timesteps, 'b -> b d', d=dim) 187 | return embedding 188 | 189 | 190 | def zero_module(module): 191 | """ 192 | Zero out the parameters of a module and return it. 193 | """ 194 | for p in module.parameters(): 195 | p.detach().zero_() 196 | return module 197 | 198 | 199 | def scale_module(module, scale): 200 | """ 201 | Scale the parameters of a module and return it. 202 | """ 203 | for p in module.parameters(): 204 | p.detach().mul_(scale) 205 | return module 206 | 207 | 208 | def mean_flat(tensor): 209 | """ 210 | Take the mean over all non-batch dimensions. 211 | """ 212 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 213 | 214 | 215 | def normalization(channels): 216 | """ 217 | Make a standard normalization layer. 218 | :param channels: number of input channels. 219 | :return: an nn.Module for normalization. 220 | """ 221 | return GroupNorm32(32, channels) 222 | 223 | 224 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 225 | class SiLU(nn.Module): 226 | def forward(self, x): 227 | return x * torch.sigmoid(x) 228 | 229 | 230 | class GroupNorm32(nn.GroupNorm): 231 | def forward(self, x): 232 | return super().forward(x.float()).type(x.dtype) 233 | 234 | def conv_nd(dims, *args, **kwargs): 235 | """ 236 | Create a 1D, 2D, or 3D convolution module. 
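    :param dims: 1, 2, or 3, selecting nn.Conv1d, nn.Conv2d, or nn.Conv3d.
    All remaining positional and keyword arguments are forwarded unchanged,
    e.g. conv_nd(2, in_channels, out_channels, 3, padding=1) builds a 3x3 Conv2d.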
237 | """ 238 | if dims == 1: 239 | return nn.Conv1d(*args, **kwargs) 240 | elif dims == 2: 241 | return nn.Conv2d(*args, **kwargs) 242 | elif dims == 3: 243 | return nn.Conv3d(*args, **kwargs) 244 | raise ValueError(f"unsupported dimensions: {dims}") 245 | 246 | 247 | def linear(*args, **kwargs): 248 | """ 249 | Create a linear module. 250 | """ 251 | return nn.Linear(*args, **kwargs) 252 | 253 | 254 | def avg_pool_nd(dims, *args, **kwargs): 255 | """ 256 | Create a 1D, 2D, or 3D average pooling module. 257 | """ 258 | if dims == 1: 259 | return nn.AvgPool1d(*args, **kwargs) 260 | elif dims == 2: 261 | return nn.AvgPool2d(*args, **kwargs) 262 | elif dims == 3: 263 | return nn.AvgPool3d(*args, **kwargs) 264 | raise ValueError(f"unsupported dimensions: {dims}") 265 | 266 | 267 | class HybridConditioner(nn.Module): 268 | 269 | def __init__(self, c_concat_config, c_crossattn_config): 270 | super().__init__() 271 | self.concat_conditioner = instantiate_from_config(c_concat_config) 272 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 273 | 274 | def forward(self, c_concat, c_crossattn): 275 | c_concat = self.concat_conditioner(c_concat) 276 | c_crossattn = self.crossattn_conditioner(c_crossattn) 277 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 278 | 279 | 280 | def noise_like(shape, device, repeat=False): 281 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 282 | noise = lambda: torch.randn(shape, device=device) 283 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /custom_net/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | from functools import partial 6 | from inspect import isfunction 7 | from collections import namedtuple 8 | from einops import rearrange, repeat, reduce 9 | 10 | # constants 11 | 12 | DEFAULT_DIM_HEAD = 64 13 | 14 | Intermediates = namedtuple('Intermediates', [ 15 | 'pre_softmax_attn', 16 | 'post_softmax_attn' 17 | ]) 18 | 19 | LayerIntermediates = namedtuple('Intermediates', [ 20 | 'hiddens', 21 | 'attn_intermediates' 22 | ]) 23 | 24 | 25 | class AbsolutePositionalEmbedding(nn.Module): 26 | def __init__(self, dim, max_seq_len): 27 | super().__init__() 28 | self.emb = nn.Embedding(max_seq_len, dim) 29 | self.init_() 30 | 31 | def init_(self): 32 | nn.init.normal_(self.emb.weight, std=0.02) 33 | 34 | def forward(self, x): 35 | n = torch.arange(x.shape[1], device=x.device) 36 | return self.emb(n)[None, :, :] 37 | 38 | 39 | class FixedPositionalEmbedding(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 43 | self.register_buffer('inv_freq', inv_freq) 44 | 45 | def forward(self, x, seq_dim=1, offset=0): 46 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 47 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 48 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 49 | return emb[None, :, :] 50 | 51 | 52 | # helpers 53 | 54 | def exists(val): 55 | return val is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def always(val): 65 | def inner(*args, **kwargs): 66 | return val 67 | return inner 68 | 69 | 70 | def not_equals(val): 71 | def inner(x): 72 | return x != val 73 | return inner 74 | 75 | 76 | def equals(val): 77 | def inner(x): 78 | return x == val 79 | return inner 80 | 81 | 82 | def max_neg_value(tensor): 83 | return -torch.finfo(tensor.dtype).max 84 | 85 | 86 | # keyword argument helpers 87 | 88 | def pick_and_pop(keys, d): 89 | values = list(map(lambda key: d.pop(key), keys)) 90 | return dict(zip(keys, values)) 91 | 92 | 93 | def group_dict_by_key(cond, d): 94 | return_val = [dict(), dict()] 95 | for key in d.keys(): 96 | match = bool(cond(key)) 97 | ind = int(not match) 98 | return_val[ind][key] = d[key] 99 | return (*return_val,) 100 | 101 | 102 | def string_begins_with(prefix, str): 103 | return str.startswith(prefix) 104 | 105 | 106 | def group_by_key_prefix(prefix, d): 107 | return group_dict_by_key(partial(string_begins_with, prefix), d) 108 | 109 | 110 | def groupby_prefix_and_trim(prefix, d): 111 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 112 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 113 | return kwargs_without_prefix, kwargs 114 | 115 | 116 | # classes 117 | class Scale(nn.Module): 118 | def __init__(self, value, fn): 119 | super().__init__() 120 | self.value = value 121 | self.fn = fn 122 | 123 | def forward(self, x, **kwargs): 124 | x, *rest = self.fn(x, **kwargs) 125 | return (x * self.value, *rest) 126 | 127 | 128 | class Rezero(nn.Module): 129 | def __init__(self, fn): 130 | super().__init__() 131 | self.fn = fn 132 | self.g = nn.Parameter(torch.zeros(1)) 133 | 134 | def forward(self, x, **kwargs): 135 | x, *rest = self.fn(x, **kwargs) 136 | return (x * self.g, *rest) 137 | 138 | 139 | class ScaleNorm(nn.Module): 140 | def __init__(self, dim, eps=1e-5): 141 | super().__init__() 142 | self.scale = dim ** -0.5 143 | self.eps = eps 144 | self.g = nn.Parameter(torch.ones(1)) 145 | 146 | def forward(self, x): 147 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 148 | return x / norm.clamp(min=self.eps) * self.g 149 | 150 | 151 | class RMSNorm(nn.Module): 152 | def __init__(self, dim, eps=1e-8): 153 | super().__init__() 154 | self.scale = dim ** -0.5 155 | self.eps = eps 156 | self.g = nn.Parameter(torch.ones(dim)) 157 | 158 | def forward(self, x): 159 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 160 | return x / norm.clamp(min=self.eps) * self.g 161 | 162 | 163 | class Residual(nn.Module): 164 | def forward(self, x, residual): 165 | return x + residual 166 | 167 | 168 | class GRUGating(nn.Module): 169 | def __init__(self, dim): 170 | super().__init__() 171 | self.gru = nn.GRUCell(dim, dim) 172 | 173 | def forward(self, x, residual): 174 | gated_output = self.gru( 175 | rearrange(x, 'b n d -> (b n) d'), 176 | rearrange(residual, 'b n d -> 
(b n) d') 177 | ) 178 | 179 | return gated_output.reshape_as(x) 180 | 181 | 182 | # feedforward 183 | 184 | class GEGLU(nn.Module): 185 | def __init__(self, dim_in, dim_out): 186 | super().__init__() 187 | self.proj = nn.Linear(dim_in, dim_out * 2) 188 | 189 | def forward(self, x): 190 | x, gate = self.proj(x).chunk(2, dim=-1) 191 | return x * F.gelu(gate) 192 | 193 | 194 | class FeedForward(nn.Module): 195 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 196 | super().__init__() 197 | inner_dim = int(dim * mult) 198 | dim_out = default(dim_out, dim) 199 | project_in = nn.Sequential( 200 | nn.Linear(dim, inner_dim), 201 | nn.GELU() 202 | ) if not glu else GEGLU(dim, inner_dim) 203 | 204 | self.net = nn.Sequential( 205 | project_in, 206 | nn.Dropout(dropout), 207 | nn.Linear(inner_dim, dim_out) 208 | ) 209 | 210 | def forward(self, x): 211 | return self.net(x) 212 | 213 | 214 | # attention. 215 | class Attention(nn.Module): 216 | def __init__( 217 | self, 218 | dim, 219 | dim_head=DEFAULT_DIM_HEAD, 220 | heads=8, 221 | causal=False, 222 | mask=None, 223 | talking_heads=False, 224 | sparse_topk=None, 225 | use_entmax15=False, 226 | num_mem_kv=0, 227 | dropout=0., 228 | on_attn=False 229 | ): 230 | super().__init__() 231 | if use_entmax15: 232 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 233 | self.scale = dim_head ** -0.5 234 | self.heads = heads 235 | self.causal = causal 236 | self.mask = mask 237 | 238 | inner_dim = dim_head * heads 239 | 240 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 242 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 243 | self.dropout = nn.Dropout(dropout) 244 | 245 | # talking heads 246 | self.talking_heads = talking_heads 247 | if talking_heads: 248 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 250 | 251 | # explicit topk sparse attention 252 | self.sparse_topk = sparse_topk 253 | 254 | # entmax 255 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 256 | self.attn_fn = F.softmax 257 | 258 | # add memory key / values 259 | self.num_mem_kv = num_mem_kv 260 | if num_mem_kv > 0: 261 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 263 | 264 | # attention on attention 265 | self.attn_on_attn = on_attn 266 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 267 | 268 | def forward( 269 | self, 270 | x, 271 | context=None, 272 | mask=None, 273 | context_mask=None, 274 | rel_pos=None, 275 | sinusoidal_emb=None, 276 | prev_attn=None, 277 | mem=None 278 | ): 279 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 280 | kv_input = default(context, x) 281 | 282 | q_input = x 283 | k_input = kv_input 284 | v_input = kv_input 285 | 286 | if exists(mem): 287 | k_input = torch.cat((mem, k_input), dim=-2) 288 | v_input = torch.cat((mem, v_input), dim=-2) 289 | 290 | if exists(sinusoidal_emb): 291 | # in shortformer, the query would start at a position offset depending on the past cached memory 292 | offset = k_input.shape[-2] - q_input.shape[-2] 293 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 294 | k_input = k_input + sinusoidal_emb(k_input) 295 | 296 | q = self.to_q(q_input) 297 | k = self.to_k(k_input) 298 | v = self.to_v(v_input) 299 | 300 | 
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 301 | 302 | input_mask = None 303 | if any(map(exists, (mask, context_mask))): 304 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 305 | k_mask = q_mask if not exists(context) else context_mask 306 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 307 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 308 | k_mask = rearrange(k_mask, 'b j -> b () () j') 309 | input_mask = q_mask * k_mask 310 | 311 | if self.num_mem_kv > 0: 312 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 313 | k = torch.cat((mem_k, k), dim=-2) 314 | v = torch.cat((mem_v, v), dim=-2) 315 | if exists(input_mask): 316 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 317 | 318 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 319 | mask_value = max_neg_value(dots) 320 | 321 | if exists(prev_attn): 322 | dots = dots + prev_attn 323 | 324 | pre_softmax_attn = dots 325 | 326 | if talking_heads: 327 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 328 | 329 | if exists(rel_pos): 330 | dots = rel_pos(dots) 331 | 332 | if exists(input_mask): 333 | dots.masked_fill_(~input_mask, mask_value) 334 | del input_mask 335 | 336 | if self.causal: 337 | i, j = dots.shape[-2:] 338 | r = torch.arange(i, device=device) 339 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 340 | mask = F.pad(mask, (j - i, 0), value=False) 341 | dots.masked_fill_(mask, mask_value) 342 | del mask 343 | 344 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 345 | top, _ = dots.topk(self.sparse_topk, dim=-1) 346 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 347 | mask = dots < vk 348 | dots.masked_fill_(mask, mask_value) 349 | del mask 350 | 351 | attn = self.attn_fn(dots, dim=-1) 352 | post_softmax_attn = attn 353 | 354 | attn = self.dropout(attn) 355 | 356 | if talking_heads: 357 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 358 | 359 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 360 | out = rearrange(out, 'b h n d -> b n (h d)') 361 | 362 | intermediates = Intermediates( 363 | pre_softmax_attn=pre_softmax_attn, 364 | post_softmax_attn=post_softmax_attn 365 | ) 366 | 367 | return self.to_out(out), intermediates 368 | 369 | 370 | class AttentionLayers(nn.Module): 371 | def __init__( 372 | self, 373 | dim, 374 | depth, 375 | heads=8, 376 | causal=False, 377 | cross_attend=False, 378 | only_cross=False, 379 | use_scalenorm=False, 380 | use_rmsnorm=False, 381 | use_rezero=False, 382 | rel_pos_num_buckets=32, 383 | rel_pos_max_distance=128, 384 | position_infused_attn=False, 385 | custom_layers=None, 386 | sandwich_coef=None, 387 | par_ratio=None, 388 | residual_attn=False, 389 | cross_residual_attn=False, 390 | macaron=False, 391 | pre_norm=True, 392 | gate_residual=False, 393 | **kwargs 394 | ): 395 | super().__init__() 396 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 397 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 398 | 399 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 400 | 401 | self.dim = dim 402 | self.depth = depth 403 | self.layers = nn.ModuleList([]) 404 | 405 | self.has_pos_emb = position_infused_attn 406 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 407 | self.rotary_pos_emb = always(None) 408 | 409 | assert rel_pos_num_buckets 
<= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 410 | self.rel_pos = None 411 | 412 | self.pre_norm = pre_norm 413 | 414 | self.residual_attn = residual_attn 415 | self.cross_residual_attn = cross_residual_attn 416 | 417 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 418 | norm_class = RMSNorm if use_rmsnorm else norm_class 419 | norm_fn = partial(norm_class, dim) 420 | 421 | norm_fn = nn.Identity if use_rezero else norm_fn 422 | branch_fn = Rezero if use_rezero else None 423 | 424 | if cross_attend and not only_cross: 425 | default_block = ('a', 'c', 'f') 426 | elif cross_attend and only_cross: 427 | default_block = ('c', 'f') 428 | else: 429 | default_block = ('a', 'f') 430 | 431 | if macaron: 432 | default_block = ('f',) + default_block 433 | 434 | if exists(custom_layers): 435 | layer_types = custom_layers 436 | elif exists(par_ratio): 437 | par_depth = depth * len(default_block) 438 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 439 | default_block = tuple(filter(not_equals('f'), default_block)) 440 | par_attn = par_depth // par_ratio 441 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 442 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 443 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 444 | par_block = default_block + ('f',) * (par_width - len(default_block)) 445 | par_head = par_block * par_attn 446 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 447 | elif exists(sandwich_coef): 448 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 449 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 450 | else: 451 | layer_types = default_block * depth 452 | 453 | self.layer_types = layer_types 454 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 455 | 456 | for layer_type in self.layer_types: 457 | if layer_type == 'a': 458 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 459 | elif layer_type == 'c': 460 | layer = Attention(dim, heads=heads, **attn_kwargs) 461 | elif layer_type == 'f': 462 | layer = FeedForward(dim, **ff_kwargs) 463 | layer = layer if not macaron else Scale(0.5, layer) 464 | else: 465 | raise Exception(f'invalid layer type {layer_type}') 466 | 467 | if isinstance(layer, Attention) and exists(branch_fn): 468 | layer = branch_fn(layer) 469 | 470 | if gate_residual: 471 | residual_fn = GRUGating(dim) 472 | else: 473 | residual_fn = Residual() 474 | 475 | self.layers.append(nn.ModuleList([ 476 | norm_fn(), 477 | layer, 478 | residual_fn 479 | ])) 480 | 481 | def forward( 482 | self, 483 | x, 484 | context=None, 485 | mask=None, 486 | context_mask=None, 487 | mems=None, 488 | return_hiddens=False 489 | ): 490 | hiddens = [] 491 | intermediates = [] 492 | prev_attn = None 493 | prev_cross_attn = None 494 | 495 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 496 | 497 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 498 | is_last = ind == (len(self.layers) - 1) 499 | 500 | if layer_type == 'a': 501 | hiddens.append(x) 502 | layer_mem = mems.pop(0) 503 | 504 | residual = x 505 | 506 | if self.pre_norm: 507 | x = norm(x) 508 | 509 | if layer_type == 'a': 510 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 511 | 
prev_attn=prev_attn, mem=layer_mem) 512 | elif layer_type == 'c': 513 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 514 | elif layer_type == 'f': 515 | out = block(x) 516 | 517 | x = residual_fn(out, residual) 518 | 519 | if layer_type in ('a', 'c'): 520 | intermediates.append(inter) 521 | 522 | if layer_type == 'a' and self.residual_attn: 523 | prev_attn = inter.pre_softmax_attn 524 | elif layer_type == 'c' and self.cross_residual_attn: 525 | prev_cross_attn = inter.pre_softmax_attn 526 | 527 | if not self.pre_norm and not is_last: 528 | x = norm(x) 529 | 530 | if return_hiddens: 531 | intermediates = LayerIntermediates( 532 | hiddens=hiddens, 533 | attn_intermediates=intermediates 534 | ) 535 | 536 | return x, intermediates 537 | 538 | return x 539 | 540 | 541 | class Encoder(AttentionLayers): 542 | def __init__(self, **kwargs): 543 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 544 | super().__init__(causal=False, **kwargs) 545 | 546 | 547 | 548 | class TransformerWrapper(nn.Module): 549 | def __init__( 550 | self, 551 | *, 552 | num_tokens, 553 | max_seq_len, 554 | attn_layers, 555 | emb_dim=None, 556 | max_mem_len=0., 557 | emb_dropout=0., 558 | num_memory_tokens=None, 559 | tie_embedding=False, 560 | use_pos_emb=True 561 | ): 562 | super().__init__() 563 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 564 | 565 | dim = attn_layers.dim 566 | emb_dim = default(emb_dim, dim) 567 | 568 | self.max_seq_len = max_seq_len 569 | self.max_mem_len = max_mem_len 570 | self.num_tokens = num_tokens 571 | 572 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 573 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 574 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 575 | self.emb_dropout = nn.Dropout(emb_dropout) 576 | 577 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 578 | self.attn_layers = attn_layers 579 | self.norm = nn.LayerNorm(dim) 580 | 581 | self.init_() 582 | 583 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 584 | 585 | # memory tokens (like [cls]) from Memory Transformers paper 586 | num_memory_tokens = default(num_memory_tokens, 0) 587 | self.num_memory_tokens = num_memory_tokens 588 | if num_memory_tokens > 0: 589 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 590 | 591 | # let funnel encoder know number of memory tokens, if specified 592 | if hasattr(attn_layers, 'num_memory_tokens'): 593 | attn_layers.num_memory_tokens = num_memory_tokens 594 | 595 | def init_(self): 596 | nn.init.normal_(self.token_emb.weight, std=0.02) 597 | 598 | def forward( 599 | self, 600 | x, 601 | return_embeddings=False, 602 | mask=None, 603 | return_mems=False, 604 | return_attn=False, 605 | mems=None, 606 | **kwargs 607 | ): 608 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 609 | x = self.token_emb(x) 610 | x += self.pos_emb(x) 611 | x = self.emb_dropout(x) 612 | 613 | x = self.project_emb(x) 614 | 615 | if num_mem > 0: 616 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 617 | x = torch.cat((mem, x), dim=1) 618 | 619 | # auto-handle masking after appending memory tokens 620 | if exists(mask): 621 | mask = F.pad(mask, (num_mem, 0), value=True) 622 | 623 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 624 | x = self.norm(x) 625 | 626 | 
mem, x = x[:, :num_mem], x[:, num_mem:] 627 | 628 | out = self.to_logits(x) if not return_embeddings else x 629 | 630 | if return_mems: 631 | hiddens = intermediates.hiddens 632 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 633 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 634 | return out, new_mems 635 | 636 | if return_attn: 637 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 638 | return out, attn_maps 639 | 640 | return out 641 | 642 | -------------------------------------------------------------------------------- /example/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_CustomNet/49eb0d40ff3b46d0f04c781f61a8b6678729a417/example/inpainting.png -------------------------------------------------------------------------------- /example/normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_CustomNet/49eb0d40ff3b46d0f04c781f61a8b6678729a417/example/normal.png -------------------------------------------------------------------------------- /example/polar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_CustomNet/49eb0d40ff3b46d0f04c781f61a8b6678729a417/example/polar.png -------------------------------------------------------------------------------- /example/position.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_CustomNet/49eb0d40ff3b46d0f04c781f61a8b6678729a417/example/position.png -------------------------------------------------------------------------------- /example/workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 27, 3 | "last_link_id": 29, 4 | "nodes": [ 5 | { 6 | "id": 18, 7 | "type": "SaveImage", 8 | "pos": [ 9 | 1833, 10 | -122 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 270 15 | }, 16 | "flags": {}, 17 | "order": 7, 18 | "mode": 4, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 18 24 | } 25 | ], 26 | "properties": {}, 27 | "widgets_values": [ 28 | "ComfyUI" 29 | ] 30 | }, 31 | { 32 | "id": 17, 33 | "type": "CustomNet_Sampler", 34 | "pos": [ 35 | 1393, 36 | -127 37 | ], 38 | "size": { 39 | "0": 412.7192687988281, 40 | "1": 474 41 | }, 42 | "flags": {}, 43 | "order": 5, 44 | "mode": 4, 45 | "inputs": [ 46 | { 47 | "name": "model", 48 | "type": "MODEL", 49 | "link": 16 50 | }, 51 | { 52 | "name": "info", 53 | "type": "DICT", 54 | "link": 17 55 | }, 56 | { 57 | "name": "image", 58 | "type": "IMAGE", 59 | "link": 19, 60 | "slot_index": 2 61 | }, 62 | { 63 | "name": "bg_image", 64 | "type": "IMAGE", 65 | "link": 24, 66 | "slot_index": 3 67 | } 68 | ], 69 | "outputs": [ 70 | { 71 | "name": "output_image", 72 | "type": "IMAGE", 73 | "links": [ 74 | 18 75 | ], 76 | "shape": 3, 77 | "slot_index": 0 78 | } 79 | ], 80 | "properties": { 81 | "Node name for S&R": "CustomNet_Sampler" 82 | }, 83 | "widgets_values": [ 84 | "a pig at forest", 85 | "", 86 | 50, 87 | 860008325477834, 88 | "randomize", 89 | 256, 90 | 256, 91 | 125, 92 | 125, 93 | 256, 94 | 256, 95 | 0, 96 | 0, 97 | 1 98 | ] 99 | }, 100 | { 101 | "id": 22, 102 | "type": "LoadImage", 103 | "pos": [ 104 | 1042, 105 | 43 106 | 
], 107 | "size": { 108 | "0": 315, 109 | "1": 314 110 | }, 111 | "flags": {}, 112 | "order": 0, 113 | "mode": 4, 114 | "outputs": [ 115 | { 116 | "name": "IMAGE", 117 | "type": "IMAGE", 118 | "links": [ 119 | 24 120 | ], 121 | "shape": 3 122 | }, 123 | { 124 | "name": "MASK", 125 | "type": "MASK", 126 | "links": null, 127 | "shape": 3 128 | } 129 | ], 130 | "properties": { 131 | "Node name for S&R": "LoadImage" 132 | }, 133 | "widgets_values": [ 134 | "Collections.jpg", 135 | "image" 136 | ] 137 | }, 138 | { 139 | "id": 16, 140 | "type": "CustomNet_LoadModel", 141 | "pos": [ 142 | 1023, 143 | -206 144 | ], 145 | "size": { 146 | "0": 315, 147 | "1": 78 148 | }, 149 | "flags": {}, 150 | "order": 1, 151 | "mode": 4, 152 | "outputs": [ 153 | { 154 | "name": "model", 155 | "type": "MODEL", 156 | "links": [ 157 | 16 158 | ], 159 | "shape": 3, 160 | "slot_index": 0 161 | }, 162 | { 163 | "name": "info", 164 | "type": "DICT", 165 | "links": [ 166 | 17 167 | ], 168 | "shape": 3, 169 | "slot_index": 1 170 | } 171 | ], 172 | "properties": { 173 | "Node name for S&R": "CustomNet_LoadModel" 174 | }, 175 | "widgets_values": [ 176 | "1SD1.5\\customnet_inpaint_v1.pt" 177 | ] 178 | }, 179 | { 180 | "id": 19, 181 | "type": "LoadImage", 182 | "pos": [ 183 | 695, 184 | -85 185 | ], 186 | "size": { 187 | "0": 315, 188 | "1": 314 189 | }, 190 | "flags": {}, 191 | "order": 2, 192 | "mode": 4, 193 | "outputs": [ 194 | { 195 | "name": "IMAGE", 196 | "type": "IMAGE", 197 | "links": [ 198 | 19 199 | ], 200 | "shape": 3, 201 | "slot_index": 0 202 | }, 203 | { 204 | "name": "MASK", 205 | "type": "MASK", 206 | "links": null, 207 | "shape": 3 208 | } 209 | ], 210 | "properties": { 211 | "Node name for S&R": "LoadImage" 212 | }, 213 | "widgets_values": [ 214 | "123.png", 215 | "image" 216 | ] 217 | }, 218 | { 219 | "id": 25, 220 | "type": "SaveImage", 221 | "pos": [ 222 | 3134.298971203125, 223 | -83.55231904687496 224 | ], 225 | "size": { 226 | "0": 315, 227 | "1": 270 228 | }, 229 | "flags": {}, 230 | "order": 8, 231 | "mode": 0, 232 | "inputs": [ 233 | { 234 | "name": "images", 235 | "type": "IMAGE", 236 | "link": 29 237 | } 238 | ], 239 | "properties": {}, 240 | "widgets_values": [ 241 | "ComfyUI" 242 | ] 243 | }, 244 | { 245 | "id": 24, 246 | "type": "CustomNet_Sampler", 247 | "pos": [ 248 | 2694.298971203125, 249 | -93.55231904687494 250 | ], 251 | "size": { 252 | "0": 412.7192687988281, 253 | "1": 474 254 | }, 255 | "flags": {}, 256 | "order": 6, 257 | "mode": 0, 258 | "inputs": [ 259 | { 260 | "name": "model", 261 | "type": "MODEL", 262 | "link": 25 263 | }, 264 | { 265 | "name": "info", 266 | "type": "DICT", 267 | "link": 26 268 | }, 269 | { 270 | "name": "image", 271 | "type": "IMAGE", 272 | "link": 27, 273 | "slot_index": 2 274 | }, 275 | { 276 | "name": "bg_image", 277 | "type": "IMAGE", 278 | "link": null, 279 | "slot_index": 3 280 | } 281 | ], 282 | "outputs": [ 283 | { 284 | "name": "output_image", 285 | "type": "IMAGE", 286 | "links": [ 287 | 29 288 | ], 289 | "shape": 3, 290 | "slot_index": 0 291 | } 292 | ], 293 | "properties": { 294 | "Node name for S&R": "CustomNet_Sampler" 295 | }, 296 | "widgets_values": [ 297 | "a pig at forest", 298 | "", 299 | 50, 300 | 608725700076689, 301 | "randomize", 302 | 256, 303 | 256, 304 | 125, 305 | 125, 306 | 256, 307 | 256, 308 | 0, 309 | 0, 310 | 1 311 | ] 312 | }, 313 | { 314 | "id": 23, 315 | "type": "CustomNet_LoadModel", 316 | "pos": [ 317 | 2324.298971203125, 318 | -173.55231904687506 319 | ], 320 | "size": { 321 | "0": 315, 322 | "1": 78 323 | }, 324 | 
"flags": {}, 325 | "order": 3, 326 | "mode": 0, 327 | "outputs": [ 328 | { 329 | "name": "model", 330 | "type": "MODEL", 331 | "links": [ 332 | 25 333 | ], 334 | "shape": 3, 335 | "slot_index": 0 336 | }, 337 | { 338 | "name": "info", 339 | "type": "DICT", 340 | "links": [ 341 | 26 342 | ], 343 | "shape": 3, 344 | "slot_index": 1 345 | } 346 | ], 347 | "properties": { 348 | "Node name for S&R": "CustomNet_LoadModel" 349 | }, 350 | "widgets_values": [ 351 | "1SD1.5\\customnet_v1.pt" 352 | ] 353 | }, 354 | { 355 | "id": 26, 356 | "type": "LoadImage", 357 | "pos": [ 358 | 2286.298971203125, 359 | 3.4476809531250003 360 | ], 361 | "size": [ 362 | 315, 363 | 314.00000381469727 364 | ], 365 | "flags": {}, 366 | "order": 4, 367 | "mode": 0, 368 | "outputs": [ 369 | { 370 | "name": "IMAGE", 371 | "type": "IMAGE", 372 | "links": [ 373 | 27 374 | ], 375 | "shape": 3, 376 | "slot_index": 0 377 | }, 378 | { 379 | "name": "MASK", 380 | "type": "MASK", 381 | "links": null, 382 | "shape": 3 383 | } 384 | ], 385 | "properties": { 386 | "Node name for S&R": "LoadImage" 387 | }, 388 | "widgets_values": [ 389 | "123.png", 390 | "image" 391 | ] 392 | } 393 | ], 394 | "links": [ 395 | [ 396 | 16, 397 | 16, 398 | 0, 399 | 17, 400 | 0, 401 | "MODEL" 402 | ], 403 | [ 404 | 17, 405 | 16, 406 | 1, 407 | 17, 408 | 1, 409 | "DICT" 410 | ], 411 | [ 412 | 18, 413 | 17, 414 | 0, 415 | 18, 416 | 0, 417 | "IMAGE" 418 | ], 419 | [ 420 | 19, 421 | 19, 422 | 0, 423 | 17, 424 | 2, 425 | "IMAGE" 426 | ], 427 | [ 428 | 24, 429 | 22, 430 | 0, 431 | 17, 432 | 3, 433 | "IMAGE" 434 | ], 435 | [ 436 | 25, 437 | 23, 438 | 0, 439 | 24, 440 | 0, 441 | "MODEL" 442 | ], 443 | [ 444 | 26, 445 | 23, 446 | 1, 447 | 24, 448 | 1, 449 | "DICT" 450 | ], 451 | [ 452 | 27, 453 | 26, 454 | 0, 455 | 24, 456 | 2, 457 | "IMAGE" 458 | ], 459 | [ 460 | 29, 461 | 24, 462 | 0, 463 | 25, 464 | 0, 465 | "IMAGE" 466 | ] 467 | ], 468 | "groups": [ 469 | { 470 | "title": "Group", 471 | "bounding": [ 472 | 692, 473 | -395, 474 | 1523, 475 | 892 476 | ], 477 | "color": "#3f789e", 478 | "font_size": 24, 479 | "locked": false 480 | }, 481 | { 482 | "title": "Group", 483 | "bounding": [ 484 | 2228, 485 | -391, 486 | 1279, 487 | 870 488 | ], 489 | "color": "#3f789e", 490 | "font_size": 24, 491 | "locked": false 492 | } 493 | ], 494 | "config": {}, 495 | "extra": { 496 | "ds": { 497 | "scale": 0.6209213230591553, 498 | "offset": [ 499 | -790.6783701887839, 500 | 587.7944958574166 501 | ] 502 | } 503 | }, 504 | "version": 0.4 505 | } -------------------------------------------------------------------------------- /example/zaimuth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_CustomNet/49eb0d40ff3b46d0f04c781f61a8b6678729a417/example/zaimuth.png -------------------------------------------------------------------------------- /gradio_utils.py: -------------------------------------------------------------------------------- 1 | from .custom_net.customnet_util import create_carvekit_interface, load_and_preprocess 2 | 3 | def load_preprocess_model(): 4 | carvekit = create_carvekit_interface() 5 | return carvekit 6 | 7 | def preprocess_image(models, input_im): 8 | ''' 9 | :param input_im (PIL Image). 10 | :return input_im (H, W, 3) array. 
11 | ''' 12 | input_im = load_and_preprocess(models, input_im) 13 | return input_im 14 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_customnet" 3 | description = "You can use CustomNet in ComfyUI" 4 | version = "1.1.0" 5 | license = { file = "LICENSE" } 6 | dependencies = ["carvekit-colab==4.1.0", "torch-fidelity==0.3.0", "lovely-numpy>=0.2.8", "lovely-tensors>=0.1.14", "-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/smthemex/ComfyUI_CustomNet" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "smthemex" 14 | DisplayName = "ComfyUI_CustomNet" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | carvekit-colab==4.1.0 2 | torch-fidelity==0.3.0 3 | lovely-numpy>=0.2.8 4 | lovely-tensors>=0.1.14 5 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 6 | --------------------------------------------------------------------------------
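
Usage note (illustrative, not part of the repository): the classes in lr_scheduler.py return a learning-rate multiplier for a given global step rather than an absolute learning rate (hence the "use with a base_lr of 1.0" note in their docstrings), so they are typically wrapped in torch.optim.lr_scheduler.LambdaLR, letting the optimizer's lr act as the base value that the multiplier scales. The sketch below shows one way this could be wired up; the import path, hyperparameter values, and training loop are assumptions for illustration only, not the repository's own training code.

import torch
# Hypothetical import path; inside the node package this would be a relative import.
from lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(4, 4)                               # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # base lr that the multiplier scales

# One cycle of 10_000 steps: 1_000 warm-up steps ramping the multiplier from
# 1e-6 up to 1.0, then a linear decay toward 0.0 (all values illustrative).
multiplier = LambdaLinearScheduler(
    warm_up_steps=[1_000],
    f_min=[0.0],
    f_max=[1.0],
    f_start=[1e-6],
    cycle_lengths=[10_000],
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=multiplier)

for step in range(100):       # training-loop stub; loss computation omitted
    optimizer.step()
    scheduler.step()          # effective lr = 1e-4 * multiplier(step)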