├── .gitignore ├── JanusPro.py ├── LICENSE ├── README.md ├── __init__.py ├── janus ├── __init__.py ├── models │ ├── __init__.py │ ├── clip_encoder.py │ ├── image_processing_vlm.py │ ├── modeling_vlm.py │ ├── processing_vlm.py │ ├── projector.py │ ├── siglip_vit.py │ └── vq_model.py └── utils │ ├── __init__.py │ ├── conversation.py │ └── io.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/JanusPro.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | import numpy as np
5 | import folder_paths
6 | import time
7 | import re
8 | from PIL import Image
9 | from transformers import AutoConfig, AutoModelForCausalLM
10 |
11 | # Key path handling: add the current directory to the system path
12 | current_dir = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.insert(0, current_dir)  # add the current directory to the Python path
14 |
15 | try:
16 | from janus.models import MultiModalityCausalLM, VLChatProcessor
17 | from janus.utils.io import load_pil_images
18 | except ImportError as e:
19 | print(f"Path debug info:")
20 | print(f"Current directory: {current_dir}")
21 | print(f"Directory contents: {os.listdir(current_dir)}")
22 | print(f"sys.path: {sys.path}")
23 | raise
24 |
25 | # Register the Janus model folder with ComfyUI
26 | current_directory = os.path.dirname(os.path.abspath(__file__))
27 | folder_paths.folder_names_and_paths["Janus"] = ([os.path.join(folder_paths.models_dir, "Janus")], folder_paths.supported_pt_extensions)
28 |
29 | # Helper functions
30 | def tensor2pil(image):
31 | return Image.fromarray(np.clip(255.
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
32 |
33 | def pil2tensor(image):
34 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
35 |
36 | class Janus_ModelLoader:
37 | def __init__(self):
38 | pass
39 |
40 | @classmethod
41 | def INPUT_TYPES(cls):
42 | return {
43 | "required": {
44 | "model_path": ("STRING", {"default": "deepseek-ai/Janus-Pro-7B"}),
45 | }
46 | }
47 |
48 | RETURN_TYPES = ("JANUS_MODEL", "PROCESSOR", "TOKENIZER")
49 | RETURN_NAMES = ("model", "processor", "tokenizer")
50 | FUNCTION = "load_model"
51 | CATEGORY = "🧩Janus"
52 |
53 | def load_model(self, model_path):
54 | # Load the config
55 | config = AutoConfig.from_pretrained(model_path)
56 | language_config = config.language_config
57 | language_config._attn_implementation = 'eager'
58 |
59 | # Load the model
60 | vl_gpt = AutoModelForCausalLM.from_pretrained(
61 | model_path,
62 | language_config=language_config,
63 | trust_remote_code=True
64 | ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)
65 |
66 | if torch.cuda.is_available():
67 | vl_gpt = vl_gpt.cuda()
68 |
69 | # Load the processor
70 | processor = VLChatProcessor.from_pretrained(model_path)
71 | tokenizer = processor.tokenizer
72 |
73 | return (vl_gpt, processor, tokenizer)
74 |
75 | class Janus_MultimodalUnderstanding:
76 | @classmethod
77 | def INPUT_TYPES(cls):
78 | return {
79 | "required": {
80 | "model": ("JANUS_MODEL",),
81 | "processor": ("PROCESSOR",),
82 | "tokenizer": ("TOKENIZER",),
83 | "image": ("IMAGE",),
84 | "question": ("STRING", {"default": "describe the image", "multiline": True}),
85 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
86 | "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.05}),
87 | "temperature": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05}),
88 | },
89 | "optional": {
90 | "max_new_tokens": ("INT", {"default": 512, "min": 16, "max": 2048}),
91 | }
92 | }
93 |
94 | RETURN_TYPES = ("STRING",)
95 | RETURN_NAMES = ("response",)
96 | FUNCTION = "understand"
97 | CATEGORY = "🧩Janus"
98 |
99 | def understand(self, model, processor, tokenizer, image, question, seed, top_p, temperature, max_new_tokens=512):
100 | # Clamp the seed to a valid range
101 | seed = seed % (2**32)
102 |
103 | # Set random seeds (with CUDA synchronization)
104 | torch.manual_seed(seed)
105 | np.random.seed(seed % (2**32 - 1))  # fit numpy's seed range
106 | if torch.cuda.is_available():
107 | torch.cuda.manual_seed_all(seed)
108 | torch.cuda.synchronize()
109 |
110 | try:
111 | # Image preprocessing (with dimension validation)
112 | if isinstance(image, list):
113 | image_tensor = image[0]
114 | else:
115 | image_tensor = image
116 |
117 | pil_image = tensor2pil(image_tensor)
118 | if pil_image.mode != "RGB":
119 | pil_image = pil_image.convert("RGB")
120 |
121 | # Build the conversation (with exception handling)
122 | try:
123 | conversation = [{
124 | "role": "<|User|>",
125 | "content": f"<image_placeholder>\n{question}",
126 | "images": [pil_image],
127 | }, {
128 | "role": "<|Assistant|>",
129 | "content": ""
130 | }]
131 | except Exception as e:
132 | print(f"Conversation construction failed: {e}")
133 | return ("Error: Invalid conversation format",)
134 |
135 | # Process the inputs (with shape debugging)
136 | try:
137 | prepare_inputs = processor(
138 | conversations=conversation,
139 | images=[pil_image],
140 | force_batchify=True
141 | ).to(model.device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
142 |
143 | print(f"Input tensor shape - input_ids: {prepare_inputs.input_ids.shape}")
144 | print(f"Attention mask shape: {prepare_inputs.attention_mask.shape}")
145 | except Exception as e:
146 | print(f"Input processing failed: {e}")
147 | return ("Error: Input
processing failed",) 148 | 149 | # 生成过程(添加参数验证) 150 | try: 151 | inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs) 152 | print(f"输入嵌入形状: {inputs_embeds.shape}") 153 | 154 | generation_config = { 155 | "inputs_embeds": inputs_embeds, 156 | "attention_mask": prepare_inputs.attention_mask, 157 | "pad_token_id": tokenizer.eos_token_id, 158 | "bos_token_id": tokenizer.bos_token_id, 159 | "eos_token_id": tokenizer.eos_token_id, 160 | "max_new_tokens": max_new_tokens, 161 | "do_sample": temperature > 0, 162 | "temperature": temperature if temperature > 0 else 1.0, 163 | "top_p": top_p, 164 | } 165 | 166 | # 执行生成(添加时间监控) 167 | start_time = time.time() 168 | outputs = model.language_model.generate(**generation_config) 169 | print(f"生成耗时: {time.time() - start_time:.2f}秒") 170 | 171 | except Exception as e: 172 | print(f"生成失败: {e}") 173 | return ("Error: Generation failed",) 174 | 175 | # 解码输出(添加异常处理) 176 | try: 177 | full_output = outputs[0].cpu().tolist() 178 | answer = tokenizer.decode(full_output, skip_special_tokens=True) 179 | 180 | # 清理特殊标记 181 | clean_pattern = r'<\|.*?\|>' 182 | clean_answer = re.sub(clean_pattern, '', answer).strip() 183 | 184 | return (clean_answer,) 185 | 186 | except Exception as e: 187 | print(f"解码失败: {e}") 188 | return ("Error: Output decoding failed",) 189 | 190 | except Exception as e: 191 | print(f"处理过程中出现未捕获的异常: {e}") 192 | return ("Error: Unexpected processing error",) 193 | 194 | 195 | class Janus_ImageGeneration: 196 | @classmethod 197 | def INPUT_TYPES(cls): 198 | return { 199 | "required": { 200 | "model": ("JANUS_MODEL",), 201 | "processor": ("PROCESSOR",), 202 | "tokenizer": ("TOKENIZER",), 203 | "prompt": ("STRING", {"multiline": True, "default": "Master shifu racoon wearing drip attire"}), 204 | "seed": ("INT", {"default": 12345, "min": 0, "max": 0xffffffffffffffff}), 205 | "cfg_weight": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 10.0, "step": 0.5}), 206 | "temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}), 207 | } 208 | } 209 | 210 | RETURN_TYPES = ("IMAGE",) 211 | RETURN_NAMES = ("images",) 212 | FUNCTION = "generate" 213 | CATEGORY = "🧩Janus" 214 | 215 | def generate(self, model, processor, tokenizer, prompt, seed, cfg_weight, temperature): 216 | # 清理缓存并设置种子 217 | torch.cuda.empty_cache() 218 | seed = seed % (2**32) 219 | torch.manual_seed(seed) 220 | np.random.seed(seed) 221 | if torch.cuda.is_available(): 222 | torch.cuda.manual_seed_all(seed) 223 | 224 | # 固定参数(与原始代码一致) 225 | width = 384 226 | height = 384 227 | parallel_size = 5 228 | patch_size = 16 229 | image_token_num = 576 230 | 231 | # 构建输入文本 232 | messages = [{'role': '<|User|>', 'content': prompt}, 233 | {'role': '<|Assistant|>', 'content': ''}] 234 | text = processor.apply_sft_template_for_multi_turn_prompts( 235 | conversations=messages, 236 | sft_format=processor.sft_format, 237 | system_prompt='' 238 | ) + processor.image_start_tag 239 | 240 | # 生成输入ID 241 | input_ids = torch.LongTensor(tokenizer.encode(text)).to(model.device) 242 | 243 | # 初始化Tokens(严格保持原始结构) 244 | tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=model.device) 245 | for i in range(parallel_size * 2): 246 | tokens[i, :] = input_ids 247 | if i % 2 != 0: 248 | tokens[i, 1:-1] = processor.pad_id 249 | 250 | # 生成过程(保持原始循环结构) 251 | inputs_embeds = model.language_model.get_input_embeddings()(tokens) 252 | generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int, device=model.device) 253 | 254 | pkv = None 255 | for i in 
range(image_token_num):
256 | with torch.no_grad():
257 | outputs = model.language_model.model(
258 | inputs_embeds=inputs_embeds,
259 | use_cache=True,
260 | past_key_values=pkv
261 | )
262 | pkv = outputs.past_key_values
263 |
264 | # Original classifier-free guidance implementation
265 | logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
266 | logit_cond = logits[0::2, :]
267 | logit_uncond = logits[1::2, :]
268 | logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
269 |
270 | # Sampling logic
271 | probs = torch.softmax(logits / temperature, dim=-1)
272 | next_token = torch.multinomial(probs, num_samples=1)
273 | generated_tokens[:, i] = next_token.squeeze(dim=-1)
274 |
275 | # Prepare the next-step input (keeping the original view operations)
276 | next_token = torch.cat([next_token.unsqueeze(1), next_token.unsqueeze(1)], dim=1).view(-1)
277 | img_embeds = model.prepare_gen_img_embeds(next_token)
278 | inputs_embeds = img_embeds.unsqueeze(dim=1)
279 |
280 | # Image decoding (strictly keeping the original implementation)
281 | patches = model.gen_vision_model.decode_code(
282 | generated_tokens.to(dtype=torch.int),
283 | shape=[parallel_size, 8, width//patch_size, height//patch_size]
284 | )
285 |
286 | # Post-processing (original unpack logic)
287 | dec = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
288 | dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
289 | visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
290 | visual_img[:, :, :] = dec
291 |
292 | # Convert to the ComfyUI image format
293 | output_images = []
294 | for i in range(parallel_size):
295 | pil_img = Image.fromarray(visual_img[i]).resize((768, 768), Image.LANCZOS)
296 | output_images.append(pil2tensor(pil_img))
297 |
298 | return (torch.cat(output_images, dim=0),)
299 |
300 |
301 | NODE_CLASS_MAPPINGS = {
302 | "Janus_ModelLoader": Janus_ModelLoader,
303 | "Janus_MultimodalUnderstanding": Janus_MultimodalUnderstanding,
304 | "Janus_ImageGeneration": Janus_ImageGeneration
305 | }
306 |
307 | NODE_DISPLAY_NAME_MAPPINGS = {
308 | "Janus_ModelLoader": "🧩Janus Model Loader",
309 | "Janus_MultimodalUnderstanding": "🧩Janus Multimodal Understanding",
310 | "Janus_ImageGeneration": "🧩Janus Image Generation"
311 | }
312 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price.
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 
88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 
209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. 
This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ![ComfyUI DeepSeek JanusPro cover image](https://github.com/user-attachments/assets/9f8ef93e-067c-42af-9eaf-d0f0fbd427c4)
3 |
4 |
5 | # ComfyUI-DeepSeek-JanusPro (project description and details are still being polished; the code is already usable)
6 |
7 | [Screenshot 2025-01-30 02 06 33]
8 |
9 | ## Code completed entirely and independently by DeepSeek R1 (meaning: I didn't write it, I didn't study the original project's code, and I didn't review the code)
10 |
11 | DeepSeek R1 successfully wrote a ComfyUI plugin for its own JanusPro (I didn't write a single line!)
12 |
13 | Key point: before, an LLM only assisted me in writing plugins and I still had to understand the code myself; now I can hand things to R1 almost without thinking and it delivers directly
14 |
15 | It worked straight away with no fine-tuning, no human reading or writing code, and high accuracy on the details; ideally the number of interactions can be kept within 3-5 rounds (the bar being that it runs successfully in ComfyUI right away). Subjectively, its detail/accuracy feels better than O1's (this still needs further verification)
16 |
17 |
18 | ## The process in detail
19 |
20 | 1) My role: information relay + judge. I never looked at the JanusPro code and handed everything straight to R1
21 |
22 | 2) Sample for R1 to learn from: the complete code of my own Emu3 plugin (the two architectures are different)
23 |
24 | 3) Gave the official JanusPro demo code to R1
25 |
26 | 4) R1 first split it into 3 core nodes, then wrote the complete code, adding optimizations and compatibility considerations (enhancements), and also provided usage instructions and suggested parameter ranges
27 |
28 | 5) The first run hit one error; after I raised it, R1 completed the fix
29 |
30 | 6) The second run hit two errors, which were resolved, but because the node for the second feature had not been run after those errors, I pointed out that it needed the same changes; R1 made them but missed some key formatting
31 |
32 | 7) After the omissions were filled in, the first feature was complete and ran normally
33 |
34 | 8) For the second feature, R1 overthought and over-complicated things and drifted away from the original code. Once I noticed this, I asked it to check whether it had deviated from the original code; R1 reviewed the earlier errors, corrected the drift, and the second feature was also implemented and ran successfully. The results are shown in the image below
35 |
36 |
37 | ## Screenshots of part of the reasoning process
38 |
39 | [Screenshot 2025-01-30 02 50 57]
40 |
41 | [Screenshot 2025-01-30 02 51 47]
42 |
43 |
44 | ## Usage example:
45 |
46 | [Screenshot 2025-01-30 02 13 08]
47 |
48 |
49 | ## Changelog
50 |
51 | - 20250221 Added a cover image; the project will be merged into a new, larger project: [DeepSeek|All-In-One|ComfyUI](https://github.com/ZHO-ZHO-ZHO/ComfyUI-DeepSeek-All-In-One)
52 |
53 | - 20250130 (the second day of Chinese New Year)
54 |
55 | V1.0 Code completed entirely and independently by DeepSeek R1 (meaning: I didn't write it, I didn't study the original project's code, and I didn't review the code)
56 |
57 | Project created
58 |
59 |
60 | ## Stars
61 |
62 | [![Star History Chart](https://api.star-history.com/svg?repos=ZHO-ZHO-ZHO/ComfyUI-DeepSeek-JanusPro&type=Date)](https://star-history.com/#ZHO-ZHO-ZHO/ComfyUI-DeepSeek-JanusPro&Date)
63 |
64 |
65 | ## About me
66 |
67 | 📬 **Contact me**:
68 | - Email: zhozho3965@gmail.com
69 | - QQ group: 839821928
70 |
71 | 🔗 **Social media**:
72 | - Profile: [-Zho-](https://jike.city/zho)
73 | - Bilibili: [My Bilibili homepage](https://space.bilibili.com/484366804)
74 | - X (Twitter): [My Twitter](https://twitter.com/ZHO_ZHO_ZHO)
75 | - Xiaohongshu: [My Xiaohongshu profile](https://www.xiaohongshu.com/user/profile/63f11530000000001001e0c8?xhsshare=CopyLink&appuid=63f11530000000001001e0c8&apptime=1690528872)
76 |
77 | 💡 **Support me**:
78 | - Bilibili: [Charge me on Bilibili](https://space.bilibili.com/484366804)
79 | - Afdian: [Support me on Afdian](https://afdian.com/a/ZHOZHO)
80 |
81 |
82 | ## Credits
83 |
84 | [Janus](https://github.com/deepseek-ai/Janus/tree/main)
85 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .JanusPro import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
4 |
--------------------------------------------------------------------------------
/janus/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | # check if python version is above 3.10 22 | import sys 23 | 24 | if sys.version_info >= (3, 10): 25 | print("Python version is above 3.10, patching the collections module.") 26 | # Monkey patch collections 27 | import collections 28 | import collections.abc 29 | 30 | for type_name in collections.abc.__all__: 31 | setattr(collections, type_name, getattr(collections.abc, type_name)) 32 | -------------------------------------------------------------------------------- /janus/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from .image_processing_vlm import VLMImageProcessor 21 | from .modeling_vlm import MultiModalityCausalLM 22 | from .processing_vlm import VLChatProcessor 23 | 24 | __all__ = [ 25 | "VLMImageProcessor", 26 | "VLChatProcessor", 27 | "MultiModalityCausalLM", 28 | ] 29 | -------------------------------------------------------------------------------- /janus/models/clip_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Dict, List, Literal, Optional, Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torchvision.transforms 25 | from einops import rearrange 26 | 27 | from janus.models.siglip_vit import create_siglip_vit 28 | 29 | 30 | class CLIPVisionTower(nn.Module): 31 | def __init__( 32 | self, 33 | model_name: str = "siglip_large_patch16_384", 34 | image_size: Union[Tuple[int, int], int] = 336, 35 | select_feature: str = "patch", 36 | select_layer: int = -2, 37 | select_layers: list = None, 38 | ckpt_path: str = "", 39 | pixel_mean: Optional[List[float]] = None, 40 | pixel_std: Optional[List[float]] = None, 41 | **kwargs, 42 | ): 43 | super().__init__() 44 | 45 | self.model_name = model_name 46 | self.select_feature = select_feature 47 | self.select_layer = select_layer 48 | self.select_layers = select_layers 49 | 50 | vision_tower_params = { 51 | "model_name": model_name, 52 | "image_size": image_size, 53 | "ckpt_path": ckpt_path, 54 | "select_layer": select_layer, 55 | } 56 | vision_tower_params.update(kwargs) 57 | self.vision_tower, self.forward_kwargs = self.build_vision_tower( 58 | vision_tower_params 59 | ) 60 | 61 | if pixel_mean is not None and pixel_std is not None: 62 | image_norm = torchvision.transforms.Normalize( 63 | mean=pixel_mean, std=pixel_std 64 | ) 65 | else: 66 | image_norm = None 67 | 68 | self.image_norm = image_norm 69 | 70 | def build_vision_tower(self, vision_tower_params): 71 | if self.model_name.startswith("siglip"): 72 | self.select_feature = "same" 73 | vision_tower = create_siglip_vit(**vision_tower_params) 74 | forward_kwargs = dict() 75 | 76 | elif self.model_name.startswith("sam"): 77 | vision_tower = create_sam_vit(**vision_tower_params) 78 | forward_kwargs = dict() 79 | 80 | else: # huggingface 81 | from transformers import CLIPVisionModel 82 | 83 | vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) 84 | forward_kwargs = dict(output_hidden_states=True) 85 | 86 | return vision_tower, forward_kwargs 87 | 88 | def feature_select(self, image_forward_outs): 89 | if isinstance(image_forward_outs, torch.Tensor): 90 | # the output has been the self.select_layer"s features 91 | image_features = image_forward_outs 92 | else: 93 | image_features = image_forward_outs.hidden_states[self.select_layer] 94 | 95 | if self.select_feature == "patch": 96 | # if the output has cls_token 97 | image_features = image_features[:, 1:] 98 | elif self.select_feature == "cls_patch": 99 | image_features = image_features 100 | elif self.select_feature == "same": 101 | image_features = image_features 102 | 103 | else: 104 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 105 | return image_features 106 | 107 | def forward(self, images): 108 | """ 109 | 110 | Args: 111 | images (torch.Tensor): [b, 3, H, W] 112 | 113 | Returns: 114 | image_features (torch.Tensor): [b, n_patch, d] 115 | """ 116 | 117 | if self.image_norm is not None: 118 | images = self.image_norm(images) 119 | 120 | image_forward_outs = self.vision_tower(images, **self.forward_kwargs) 121 | image_features = self.feature_select(image_forward_outs) 122 | return image_features 123 | -------------------------------------------------------------------------------- 
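A quick usage sketch for the vision tower above: this is not part of the repository, just a minimal illustration of how `CLIPVisionTower` could be exercised on its own, following the constructor defaults and the shapes documented in `forward()`. The batch size, resolution, and normalization constants are illustrative assumptions, and with the default empty `ckpt_path` the underlying SigLIP ViT is randomly initialized rather than pretrained.

```python
import torch

from janus.models.clip_encoder import CLIPVisionTower

# Hypothetical standalone usage; values mirror the defaults shown above.
tower = CLIPVisionTower(
    model_name="siglip_large_patch16_384",
    image_size=384,
    pixel_mean=(0.5, 0.5, 0.5),  # optional in-module normalization
    pixel_std=(0.5, 0.5, 0.5),
)
tower.eval()

with torch.no_grad():
    images = torch.rand(2, 3, 384, 384)  # [b, 3, H, W] in [0, 1]
    features = tower(images)             # [b, n_patch, d]

print(features.shape)  # e.g. torch.Size([2, 576, 1024]) for patch16 at 384px
```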
/janus/models/image_processing_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import List, Tuple, Union 21 | 22 | import numpy as np 23 | import torch 24 | import torchvision 25 | import torchvision.transforms.functional 26 | from PIL import Image 27 | from transformers import AutoImageProcessor, PretrainedConfig 28 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 29 | from transformers.image_utils import to_numpy_array 30 | from transformers.utils import logging 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 35 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) 36 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) 37 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 38 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 39 | 40 | 41 | def expand2square(pil_img, background_color): 42 | width, height = pil_img.size 43 | if width == height: 44 | return pil_img 45 | elif width > height: 46 | result = Image.new(pil_img.mode, (width, width), background_color) 47 | result.paste(pil_img, (0, (width - height) // 2)) 48 | return result 49 | else: 50 | result = Image.new(pil_img.mode, (height, height), background_color) 51 | result.paste(pil_img, ((height - width) // 2, 0)) 52 | return result 53 | 54 | 55 | class VLMImageProcessorConfig(PretrainedConfig): 56 | model_type = "deepseek_vlm" 57 | image_size: int 58 | min_size: int 59 | image_mean: Union[Tuple[float, float, float], List[float]] 60 | image_std: Union[Tuple[float, float, float], List[float]] 61 | rescale_factor: float 62 | do_normalize: bool 63 | 64 | def __init__( 65 | self, 66 | image_size: int, 67 | min_size: int = 14, 68 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 69 | 0.48145466, 70 | 0.4578275, 71 | 0.40821073, 72 | ), 73 | image_std: Union[Tuple[float, float, float], List[float]] = ( 74 | 0.26862954, 75 | 0.26130258, 76 | 0.27577711, 77 | ), 78 | rescale_factor: float = 1.0 / 255.0, 79 | do_normalize: bool = True, 80 | **kwargs, 81 | ): 82 | self.image_size = image_size 83 | self.min_size = min_size 84 | self.image_mean = image_mean 85 | self.image_std = image_std 86 | self.rescale_factor = rescale_factor 87 | self.do_normalize = do_normalize 88 | 89 | super().__init__(**kwargs) 90 | 91 | 92 | class 
VLMImageProcessor(BaseImageProcessor): 93 | model_input_names = ["pixel_values"] 94 | 95 | def __init__( 96 | self, 97 | image_size: int, 98 | min_size: int = 14, 99 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 100 | 0.48145466, 101 | 0.4578275, 102 | 0.40821073, 103 | ), 104 | image_std: Union[Tuple[float, float, float], List[float]] = ( 105 | 0.26862954, 106 | 0.26130258, 107 | 0.27577711, 108 | ), 109 | rescale_factor: float = 1.0 / 255.0, 110 | do_normalize: bool = True, 111 | **kwargs, 112 | ): 113 | super().__init__(**kwargs) 114 | 115 | self.image_size = image_size 116 | self.rescale_factor = rescale_factor 117 | self.image_mean = image_mean 118 | self.image_std = image_std 119 | self.min_size = min_size 120 | self.do_normalize = do_normalize 121 | 122 | if image_mean is None: 123 | self.background_color = (127, 127, 127) 124 | else: 125 | self.background_color = tuple([int(x * 255) for x in image_mean]) 126 | 127 | def resize(self, pil_img: Image) -> np.ndarray: 128 | """ 129 | 130 | Args: 131 | pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB 132 | 133 | Returns: 134 | x (np.ndarray): [3, self.image_size, self.image_size] 135 | """ 136 | 137 | width, height = pil_img.size 138 | max_size = max(width, height) 139 | 140 | size = [ 141 | max(int(height / max_size * self.image_size), self.min_size), 142 | max(int(width / max_size * self.image_size), self.min_size), 143 | ] 144 | 145 | if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: 146 | print(f"orig size = {pil_img.size}, new size = {size}") 147 | raise ValueError("Invalid size!") 148 | 149 | pil_img = torchvision.transforms.functional.resize( 150 | pil_img, 151 | size, 152 | interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, 153 | antialias=True, 154 | ) 155 | 156 | pil_img = expand2square(pil_img, self.background_color) 157 | x = to_numpy_array(pil_img) 158 | 159 | # [H, W, 3] -> [3, H, W] 160 | x = np.transpose(x, (2, 0, 1)) 161 | 162 | return x 163 | 164 | def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: 165 | # resize and pad to [self.image_size, self.image_size] 166 | # then convert from [H, W, 3] to [3, H, W] 167 | images: List[np.ndarray] = [self.resize(image) for image in images] 168 | 169 | # resacle from [0, 255] -> [0, 1] 170 | images = [ 171 | self.rescale( 172 | image=image, 173 | scale=self.rescale_factor, 174 | input_data_format="channels_first", 175 | ) 176 | for image in images 177 | ] 178 | 179 | # normalize 180 | if self.do_normalize: 181 | images = [ 182 | self.normalize( 183 | image=image, 184 | mean=self.image_mean, 185 | std=self.image_std, 186 | input_data_format="channels_first", 187 | ) 188 | for image in images 189 | ] 190 | 191 | data = {"pixel_values": images} 192 | return BatchFeature(data=data, tensor_type=return_tensors) 193 | 194 | @property 195 | def default_shape(self): 196 | return [3, self.image_size, self.image_size] 197 | 198 | 199 | AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) 200 | 201 | 202 | if __name__ == "__main__": 203 | image_processor = VLMImageProcessor( 204 | image_size=1024, 205 | image_mean=IMAGENET_INCEPTION_MEAN, 206 | image_std=IMAGENET_INCEPTION_STD, 207 | do_normalize=True, 208 | ) 209 | -------------------------------------------------------------------------------- /janus/models/modeling_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | from attrdict import AttrDict 22 | from einops import rearrange 23 | from transformers import ( 24 | AutoConfig, 25 | AutoModelForCausalLM, 26 | LlamaConfig, 27 | LlamaForCausalLM, 28 | PreTrainedModel, 29 | ) 30 | from transformers.configuration_utils import PretrainedConfig 31 | 32 | from janus.models.clip_encoder import CLIPVisionTower 33 | from janus.models.projector import MlpProjector 34 | 35 | 36 | class vision_head(torch.nn.Module): 37 | def __init__(self, params): 38 | super().__init__() 39 | self.output_mlp_projector = torch.nn.Linear( 40 | params.n_embed, params.image_token_embed 41 | ) 42 | self.vision_activation = torch.nn.GELU() 43 | self.vision_head = torch.nn.Linear( 44 | params.image_token_embed, params.image_token_size 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.output_mlp_projector(x) 49 | x = self.vision_activation(x) 50 | x = self.vision_head(x) 51 | return x 52 | 53 | 54 | def model_name_to_cls(cls_name): 55 | if "MlpProjector" in cls_name: 56 | cls = MlpProjector 57 | 58 | elif "CLIPVisionTower" in cls_name: 59 | cls = CLIPVisionTower 60 | 61 | elif "VQ" in cls_name: 62 | from janus.models.vq_model import VQ_models 63 | 64 | cls = VQ_models[cls_name] 65 | elif "vision_head" in cls_name: 66 | cls = vision_head 67 | else: 68 | raise ValueError(f"class_name {cls_name} is invalid.") 69 | 70 | return cls 71 | 72 | 73 | class VisionConfig(PretrainedConfig): 74 | model_type = "vision" 75 | cls: str = "" 76 | params: AttrDict = {} 77 | 78 | def __init__(self, **kwargs): 79 | super().__init__(**kwargs) 80 | 81 | self.cls = kwargs.get("cls", "") 82 | if not isinstance(self.cls, str): 83 | self.cls = self.cls.__name__ 84 | 85 | self.params = AttrDict(kwargs.get("params", {})) 86 | 87 | 88 | class AlignerConfig(PretrainedConfig): 89 | model_type = "aligner" 90 | cls: str = "" 91 | params: AttrDict = {} 92 | 93 | def __init__(self, **kwargs): 94 | super().__init__(**kwargs) 95 | 96 | self.cls = kwargs.get("cls", "") 97 | if not isinstance(self.cls, str): 98 | self.cls = self.cls.__name__ 99 | 100 | self.params = AttrDict(kwargs.get("params", {})) 101 | 102 | 103 | class GenVisionConfig(PretrainedConfig): 104 | model_type = "gen_vision" 105 | cls: str = "" 106 | params: AttrDict = {} 107 | 108 | def __init__(self, **kwargs): 109 | super().__init__(**kwargs) 110 | 111 | self.cls = kwargs.get("cls", "") 112 | if not isinstance(self.cls, 
str): 113 | self.cls = self.cls.__name__ 114 | 115 | self.params = AttrDict(kwargs.get("params", {})) 116 | 117 | 118 | class GenAlignerConfig(PretrainedConfig): 119 | model_type = "gen_aligner" 120 | cls: str = "" 121 | params: AttrDict = {} 122 | 123 | def __init__(self, **kwargs): 124 | super().__init__(**kwargs) 125 | 126 | self.cls = kwargs.get("cls", "") 127 | if not isinstance(self.cls, str): 128 | self.cls = self.cls.__name__ 129 | 130 | self.params = AttrDict(kwargs.get("params", {})) 131 | 132 | 133 | class GenHeadConfig(PretrainedConfig): 134 | model_type = "gen_head" 135 | cls: str = "" 136 | params: AttrDict = {} 137 | 138 | def __init__(self, **kwargs): 139 | super().__init__(**kwargs) 140 | 141 | self.cls = kwargs.get("cls", "") 142 | if not isinstance(self.cls, str): 143 | self.cls = self.cls.__name__ 144 | 145 | self.params = AttrDict(kwargs.get("params", {})) 146 | 147 | 148 | class MultiModalityConfig(PretrainedConfig): 149 | model_type = "multi_modality" 150 | vision_config: VisionConfig 151 | aligner_config: AlignerConfig 152 | 153 | gen_vision_config: GenVisionConfig 154 | gen_aligner_config: GenAlignerConfig 155 | gen_head_config: GenHeadConfig 156 | 157 | language_config: LlamaConfig 158 | 159 | def __init__(self, **kwargs): 160 | super().__init__(**kwargs) 161 | vision_config = kwargs.get("vision_config", {}) 162 | self.vision_config = VisionConfig(**vision_config) 163 | 164 | aligner_config = kwargs.get("aligner_config", {}) 165 | self.aligner_config = AlignerConfig(**aligner_config) 166 | 167 | gen_vision_config = kwargs.get("gen_vision_config", {}) 168 | self.gen_vision_config = GenVisionConfig(**gen_vision_config) 169 | 170 | gen_aligner_config = kwargs.get("gen_aligner_config", {}) 171 | self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config) 172 | 173 | gen_head_config = kwargs.get("gen_head_config", {}) 174 | self.gen_head_config = GenHeadConfig(**gen_head_config) 175 | 176 | language_config = kwargs.get("language_config", {}) 177 | if isinstance(language_config, LlamaConfig): 178 | self.language_config = language_config 179 | else: 180 | self.language_config = LlamaConfig(**language_config) 181 | 182 | 183 | class MultiModalityPreTrainedModel(PreTrainedModel): 184 | config_class = MultiModalityConfig 185 | base_model_prefix = "multi_modality" 186 | _no_split_modules = [] 187 | _skip_keys_device_placement = "past_key_values" 188 | 189 | 190 | class MultiModalityCausalLM(MultiModalityPreTrainedModel): 191 | def __init__(self, config: MultiModalityConfig): 192 | super().__init__(config) 193 | 194 | vision_config = config.vision_config 195 | vision_cls = model_name_to_cls(vision_config.cls) 196 | self.vision_model = vision_cls(**vision_config.params) 197 | 198 | aligner_config = config.aligner_config 199 | aligner_cls = model_name_to_cls(aligner_config.cls) 200 | self.aligner = aligner_cls(aligner_config.params) 201 | 202 | gen_vision_config = config.gen_vision_config 203 | gen_vision_cls = model_name_to_cls(gen_vision_config.cls) 204 | self.gen_vision_model = gen_vision_cls() 205 | 206 | gen_aligner_config = config.gen_aligner_config 207 | gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) 208 | self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) 209 | 210 | gen_head_config = config.gen_head_config 211 | gen_head_cls = model_name_to_cls(gen_head_config.cls) 212 | self.gen_head = gen_head_cls(gen_head_config.params) 213 | 214 | self.gen_embed = torch.nn.Embedding( 215 | gen_vision_config.params.image_token_size, 
gen_vision_config.params.n_embed 216 | ) 217 | 218 | language_config = config.language_config 219 | self.language_model = LlamaForCausalLM(language_config) 220 | 221 | def prepare_inputs_embeds( 222 | self, 223 | input_ids: torch.LongTensor, 224 | pixel_values: torch.FloatTensor, 225 | images_seq_mask: torch.LongTensor, 226 | images_emb_mask: torch.LongTensor, 227 | **kwargs, 228 | ): 229 | """ 230 | 231 | Args: 232 | input_ids (torch.LongTensor): [b, T] 233 | pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] 234 | images_seq_mask (torch.BoolTensor): [b, T] 235 | images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] 236 | 237 | assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) 238 | 239 | Returns: 240 | input_embeds (torch.Tensor): [b, T, D] 241 | """ 242 | 243 | bs, n = pixel_values.shape[0:2] 244 | images = rearrange(pixel_values, "b n c h w -> (b n) c h w") 245 | # [b x n, T2, D] 246 | images_embeds = self.aligner(self.vision_model(images)) 247 | 248 | # [b x n, T2, D] -> [b, n x T2, D] 249 | images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) 250 | # [b, n, T2] -> [b, n x T2] 251 | images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") 252 | 253 | # [b, T, D] 254 | input_ids[input_ids < 0] = 0 # ignore the image embeddings 255 | inputs_embeds = self.language_model.get_input_embeddings()(input_ids) 256 | 257 | # replace with the image embeddings 258 | inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] 259 | 260 | return inputs_embeds 261 | 262 | def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): 263 | return self.gen_aligner(self.gen_embed(image_ids)) 264 | 265 | 266 | AutoConfig.register("vision", VisionConfig) 267 | AutoConfig.register("aligner", AlignerConfig) 268 | AutoConfig.register("gen_vision", GenVisionConfig) 269 | AutoConfig.register("gen_aligner", GenAlignerConfig) 270 | AutoConfig.register("gen_head", GenHeadConfig) 271 | AutoConfig.register("multi_modality", MultiModalityConfig) 272 | AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) 273 | -------------------------------------------------------------------------------- /janus/models/processing_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
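# VLChatProcessor (defined below) pairs VLMImageProcessor with a Llama tokenizer:
# it renders conversations with the "deepseek" SFT template, expands each image
# placeholder in the prompt into `num_image_tokens` image-token ids bracketed by the
# image start/end tags, and, in batchify(), left-pads sequences while building the
# attention mask and the images_seq_mask / images_emb_mask tensors that
# MultiModalityCausalLM.prepare_inputs_embeds() consumes.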
19 | 20 | from dataclasses import dataclass 21 | from typing import Dict, List 22 | 23 | import torch 24 | from PIL.Image import Image 25 | from transformers import LlamaTokenizerFast 26 | from transformers.processing_utils import ProcessorMixin 27 | 28 | from janus.models.image_processing_vlm import VLMImageProcessor 29 | from janus.utils.conversation import get_conv_template 30 | 31 | 32 | class DictOutput(object): 33 | def keys(self): 34 | return self.__dict__.keys() 35 | 36 | def __getitem__(self, item): 37 | return self.__dict__[item] 38 | 39 | def __setitem__(self, key, value): 40 | self.__dict__[key] = value 41 | 42 | 43 | @dataclass 44 | class VLChatProcessorOutput(DictOutput): 45 | sft_format: str 46 | input_ids: torch.Tensor 47 | pixel_values: torch.Tensor 48 | num_image_tokens: torch.IntTensor 49 | 50 | def __len__(self): 51 | return len(self.input_ids) 52 | 53 | 54 | @dataclass 55 | class BatchedVLChatProcessorOutput(DictOutput): 56 | sft_format: List[str] 57 | input_ids: torch.Tensor 58 | pixel_values: torch.Tensor 59 | attention_mask: torch.Tensor 60 | images_seq_mask: torch.BoolTensor 61 | images_emb_mask: torch.BoolTensor 62 | 63 | def to(self, device, dtype=torch.bfloat16): 64 | self.input_ids = self.input_ids.to(device) 65 | self.attention_mask = self.attention_mask.to(device) 66 | self.images_seq_mask = self.images_seq_mask.to(device) 67 | self.images_emb_mask = self.images_emb_mask.to(device) 68 | self.pixel_values = self.pixel_values.to(device=device, dtype=dtype) 69 | return self 70 | 71 | 72 | class VLChatProcessor(ProcessorMixin): 73 | image_processor_class = "AutoImageProcessor" 74 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 75 | 76 | attributes = ["image_processor", "tokenizer"] 77 | 78 | system_prompt = ( 79 | "You are a helpful language and vision assistant. " 80 | "You are able to understand the visual content that the user provides, " 81 | "and assist the user with a variety of tasks using natural language." 
82 |     )
83 | 
84 |     def __init__(
85 |         self,
86 |         image_processor: VLMImageProcessor,
87 |         tokenizer: LlamaTokenizerFast,
88 |         image_tag: str = "<image_placeholder>",
89 |         image_start_tag: str = "<begin_of_image>",
90 |         image_end_tag: str = "<end_of_image>",
91 |         pad_tag: str = "<|▁pad▁|>",
92 |         num_image_tokens: int = 576,
93 |         add_special_token: bool = False,
94 |         sft_format: str = "deepseek",
95 |         mask_prompt: bool = True,
96 |         ignore_id: int = -100,
97 |         **kwargs,
98 |     ):
99 |         self.image_processor = image_processor
100 |         self.tokenizer = tokenizer
101 | 
102 |         image_id = self.tokenizer.vocab.get(image_tag)
103 |         if image_id is None:
104 |             special_tokens = [image_tag]
105 |             special_tokens_dict = {"additional_special_tokens": special_tokens}
106 |             self.tokenizer.add_special_tokens(special_tokens_dict)
107 |             print(f"Add image tag = {image_tag} to the tokenizer")
108 | 
109 |         self.image_tag = image_tag
110 |         self.image_start_tag = image_start_tag
111 |         self.image_end_tag = image_end_tag
112 |         self.pad_tag = pad_tag
113 | 
114 |         self.num_image_tokens = num_image_tokens
115 |         self.add_special_token = add_special_token
116 |         self.sft_format = sft_format
117 |         self.mask_prompt = mask_prompt
118 |         self.ignore_id = ignore_id
119 | 
120 |         super().__init__(
121 |             image_processor,
122 |             tokenizer,
123 |             image_tag,
124 |             num_image_tokens,
125 |             add_special_token,
126 |             sft_format,
127 |             mask_prompt,
128 |             ignore_id,
129 |             **kwargs,
130 |         )
131 | 
132 |     def new_chat_template(self):
133 |         conv = get_conv_template(self.sft_format)
134 |         conv.set_system_message(self.system_prompt)
135 |         return conv
136 | 
137 |     def apply_sft_template_for_multi_turn_prompts(
138 |         self,
139 |         conversations: List[Dict[str, str]],
140 |         sft_format: str = "deepseek",
141 |         system_prompt: str = "",
142 |     ):
143 |         """
144 |         Applies the SFT template to conversation.
145 | 
146 |         An example of conversation:
147 |         conversation = [
148 |             {
149 |                 "role": "User",
150 |                 "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
151 |                 "images": [
152 |                     "./multi-images/attribute_comparison_1.png",
153 |                     "./multi-images/attribute_comparison_2.png"
154 |                 ]
155 |             },
156 |             {
157 |                 "role": "Assistant",
158 |                 "content": ""
159 |             }
160 |         ]
161 | 
162 |         Args:
163 |             conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
164 |             sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
165 |             system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
166 | 
167 |         Returns:
168 |             sft_prompt (str): The formatted text.
169 | """ 170 | 171 | conv = get_conv_template(sft_format) 172 | conv.set_system_message(system_prompt) 173 | for message in conversations: 174 | conv.append_message(message["role"], message["content"].strip()) 175 | sft_prompt = conv.get_prompt().strip() 176 | 177 | return sft_prompt 178 | 179 | @property 180 | def image_token(self): 181 | return self.image_tag 182 | 183 | @property 184 | def image_id(self): 185 | image_id = self.tokenizer.vocab.get(self.image_tag) 186 | return image_id 187 | 188 | @property 189 | def image_start_id(self): 190 | image_start_id = self.tokenizer.vocab.get(self.image_start_tag) 191 | return image_start_id 192 | 193 | @property 194 | def image_end_id(self): 195 | image_end_id = self.tokenizer.vocab.get(self.image_end_tag) 196 | return image_end_id 197 | 198 | @property 199 | def image_start_token(self): 200 | return self.image_start_tag 201 | 202 | @property 203 | def image_end_token(self): 204 | return self.image_end_tag 205 | 206 | @property 207 | def pad_id(self): 208 | pad_id = self.tokenizer.vocab.get(self.pad_tag) 209 | # pad_id = self.tokenizer.pad_token_id 210 | # if pad_id is None: 211 | # pad_id = self.tokenizer.eos_token_id 212 | 213 | return pad_id 214 | 215 | def add_image_token( 216 | self, 217 | image_indices: List[int], 218 | input_ids: torch.LongTensor, 219 | ): 220 | """ 221 | 222 | Args: 223 | image_indices (List[int]): [index_0, index_1, ..., index_j] 224 | input_ids (torch.LongTensor): [N] 225 | 226 | Returns: 227 | input_ids (torch.LongTensor): [N + image tokens] 228 | num_image_tokens (torch.IntTensor): [n_images] 229 | """ 230 | 231 | input_slices = [] 232 | 233 | start = 0 234 | for index in image_indices: 235 | if self.add_special_token: 236 | end = index + 1 237 | else: 238 | end = index 239 | 240 | # original text tokens 241 | input_slices.append(input_ids[start:end]) 242 | 243 | # add boi, image tokens, eoi and set the mask as False 244 | input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long)) 245 | input_slices.append( 246 | self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long) 247 | ) 248 | input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long)) 249 | start = index + 1 250 | 251 | # the left part 252 | input_slices.append(input_ids[start:]) 253 | 254 | # concat all slices 255 | input_ids = torch.cat(input_slices, dim=0) 256 | num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices)) 257 | 258 | return input_ids, num_image_tokens 259 | 260 | def process_one( 261 | self, 262 | prompt: str = None, 263 | conversations: List[Dict[str, str]] = None, 264 | images: List[Image] = None, 265 | **kwargs, 266 | ): 267 | """ 268 | 269 | Args: 270 | prompt (str): the formatted prompt; 271 | conversations (List[Dict]): conversations with a list of messages; 272 | images (List[ImageType]): the list of images; 273 | **kwargs: 274 | 275 | Returns: 276 | outputs (BaseProcessorOutput): the output of the processor, 277 | - input_ids (torch.LongTensor): [N + image tokens] 278 | - target_ids (torch.LongTensor): [N + image tokens] 279 | - images (torch.FloatTensor): [n_images, 3, H, W] 280 | - image_id (int): the id of the image token 281 | - num_image_tokens (List[int]): the number of image tokens 282 | """ 283 | 284 | assert ( 285 | prompt is None or conversations is None 286 | ), "prompt and conversations cannot be used at the same time." 
287 | 288 | if prompt is None: 289 | # apply sft format 290 | sft_format = self.apply_sft_template_for_multi_turn_prompts( 291 | conversations=conversations, 292 | sft_format=self.sft_format, 293 | system_prompt=self.system_prompt, 294 | ) 295 | else: 296 | sft_format = prompt 297 | 298 | # tokenize 299 | input_ids = self.tokenizer.encode(sft_format) 300 | input_ids = torch.LongTensor(input_ids) 301 | 302 | # add image tokens to the input_ids 303 | image_token_mask: torch.BoolTensor = input_ids == self.image_id 304 | image_indices = image_token_mask.nonzero() 305 | input_ids, num_image_tokens = self.add_image_token( 306 | image_indices=image_indices, 307 | input_ids=input_ids, 308 | ) 309 | 310 | # load images 311 | images_outputs = self.image_processor(images, return_tensors="pt") 312 | 313 | prepare = VLChatProcessorOutput( 314 | sft_format=sft_format, 315 | input_ids=input_ids, 316 | pixel_values=images_outputs.pixel_values, 317 | num_image_tokens=num_image_tokens, 318 | ) 319 | 320 | return prepare 321 | 322 | def __call__( 323 | self, 324 | *, 325 | prompt: str = None, 326 | conversations: List[Dict[str, str]] = None, 327 | images: List[Image] = None, 328 | force_batchify: bool = True, 329 | **kwargs, 330 | ): 331 | """ 332 | 333 | Args: 334 | prompt (str): the formatted prompt; 335 | conversations (List[Dict]): conversations with a list of messages; 336 | images (List[ImageType]): the list of images; 337 | force_batchify (bool): force batchify the inputs; 338 | **kwargs: 339 | 340 | Returns: 341 | outputs (BaseProcessorOutput): the output of the processor, 342 | - input_ids (torch.LongTensor): [N + image tokens] 343 | - images (torch.FloatTensor): [n_images, 3, H, W] 344 | - image_id (int): the id of the image token 345 | - num_image_tokens (List[int]): the number of image tokens 346 | """ 347 | 348 | prepare = self.process_one( 349 | prompt=prompt, conversations=conversations, images=images 350 | ) 351 | 352 | if force_batchify: 353 | prepare = self.batchify([prepare]) 354 | 355 | return prepare 356 | 357 | def batchify( 358 | self, prepare_list: List[VLChatProcessorOutput] 359 | ) -> BatchedVLChatProcessorOutput: 360 | """ 361 | Preprocesses the inputs for multimodal inference. 362 | 363 | Args: 364 | prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. 365 | 366 | Returns: 367 | BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. 
368 | """ 369 | 370 | batch_size = len(prepare_list) 371 | sft_format = [] 372 | n_images = [] 373 | seq_lens = [] 374 | for prepare in prepare_list: 375 | n_images.append(len(prepare.num_image_tokens)) 376 | seq_lens.append(len(prepare)) 377 | 378 | input_token_max_len = max(seq_lens) 379 | max_n_images = max(1, max(n_images)) 380 | 381 | batched_input_ids = torch.full( 382 | (batch_size, input_token_max_len), self.pad_id 383 | ).long() # FIXME 384 | batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() 385 | batched_pixel_values = torch.zeros( 386 | (batch_size, max_n_images, *self.image_processor.default_shape) 387 | ).float() 388 | batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() 389 | batched_images_emb_mask = torch.zeros( 390 | (batch_size, max_n_images, self.num_image_tokens) 391 | ).bool() 392 | 393 | for i, prepare in enumerate(prepare_list): 394 | input_ids = prepare.input_ids 395 | seq_len = len(prepare) 396 | n_image = len(prepare.num_image_tokens) 397 | # left-padding 398 | batched_attention_mask[i, -seq_len:] = 1 399 | batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) 400 | batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id 401 | 402 | if n_image > 0: 403 | batched_pixel_values[i, :n_image] = prepare.pixel_values 404 | for j, n_image_tokens in enumerate(prepare.num_image_tokens): 405 | batched_images_emb_mask[i, j, :n_image_tokens] = True 406 | 407 | sft_format.append(prepare.sft_format) 408 | 409 | batched_prepares = BatchedVLChatProcessorOutput( 410 | input_ids=batched_input_ids, 411 | attention_mask=batched_attention_mask, 412 | pixel_values=batched_pixel_values, 413 | images_seq_mask=batched_images_seq_mask, 414 | images_emb_mask=batched_images_emb_mask, 415 | sft_format=sft_format, 416 | ) 417 | 418 | return batched_prepares 419 | -------------------------------------------------------------------------------- /janus/models/projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | from typing import Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | from attrdict import AttrDict 25 | 26 | 27 | class MlpProjector(nn.Module): 28 | def __init__(self, cfg): 29 | super().__init__() 30 | 31 | self.cfg = cfg 32 | 33 | if cfg.projector_type == "identity": 34 | modules = nn.Identity() 35 | 36 | elif cfg.projector_type == "linear": 37 | modules = nn.Linear(cfg.input_dim, cfg.n_embed) 38 | 39 | elif cfg.projector_type == "mlp_gelu": 40 | mlp_depth = cfg.get("depth", 1) 41 | modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] 42 | for _ in range(1, mlp_depth): 43 | modules.append(nn.GELU()) 44 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 45 | modules = nn.Sequential(*modules) 46 | 47 | elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 48 | mlp_depth = cfg.get("depth", 1) 49 | self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 50 | self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 51 | 52 | modules = [] 53 | for _ in range(1, mlp_depth): 54 | modules.append(nn.GELU()) 55 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 56 | modules = nn.Sequential(*modules) 57 | 58 | else: 59 | raise ValueError(f"Unknown projector type: {cfg.projector_type}") 60 | 61 | self.layers = modules 62 | 63 | def forward( 64 | self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 65 | ): 66 | """ 67 | 68 | Args: 69 | x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, 70 | then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); 71 | otherwise it is the feature from the single vision encoder. 72 | 73 | Returns: 74 | x (torch.Tensor): [b, s, c] 75 | """ 76 | 77 | if isinstance(x_or_tuple, tuple): 78 | # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 79 | high_x, low_x = x_or_tuple 80 | high_x = self.high_up_proj(high_x) 81 | low_x = self.low_up_proj(low_x) 82 | x = torch.concat([high_x, low_x], dim=-1) 83 | else: 84 | x = x_or_tuple 85 | 86 | return self.layers(x) 87 | 88 | 89 | if __name__ == "__main__": 90 | cfg = AttrDict( 91 | input_dim=1024, 92 | n_embed=2048, 93 | depth=2, 94 | projector_type="low_high_hybrid_split_mlp_gelu", 95 | ) 96 | inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024)) 97 | 98 | m = MlpProjector(cfg) 99 | out = m(inputs) 100 | print(out.shape) 101 | -------------------------------------------------------------------------------- /janus/models/siglip_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 21 | import math 22 | import warnings 23 | from dataclasses import dataclass 24 | from functools import partial 25 | from typing import ( 26 | Callable, 27 | Dict, 28 | Final, 29 | List, 30 | Literal, 31 | Optional, 32 | Sequence, 33 | Set, 34 | Tuple, 35 | Type, 36 | Union, 37 | ) 38 | 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | from timm.layers import ( 43 | AttentionPoolLatent, 44 | DropPath, 45 | LayerType, 46 | Mlp, 47 | PatchDropout, 48 | PatchEmbed, 49 | resample_abs_pos_embed, 50 | ) 51 | from timm.models._manipulate import checkpoint_seq, named_apply 52 | 53 | 54 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 55 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 56 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 57 | def norm_cdf(x): 58 | # Computes standard normal cumulative distribution function 59 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 60 | 61 | if (mean < a - 2 * std) or (mean > b + 2 * std): 62 | warnings.warn( 63 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 64 | "The distribution of values may be incorrect.", 65 | stacklevel=2, 66 | ) 67 | 68 | with torch.no_grad(): 69 | # Values are generated by using a truncated uniform distribution and 70 | # then using the inverse CDF for the normal distribution. 71 | # Get upper and lower cdf values 72 | l = norm_cdf((a - mean) / std) # noqa: E741 73 | u = norm_cdf((b - mean) / std) 74 | 75 | # Uniformly fill tensor with values from [l, u], then translate to 76 | # [2l-1, 2u-1]. 77 | tensor.uniform_(2 * l - 1, 2 * u - 1) 78 | 79 | # Use inverse cdf transform for normal distribution to get truncated 80 | # standard normal 81 | tensor.erfinv_() 82 | 83 | # Transform to proper mean, std 84 | tensor.mul_(std * math.sqrt(2.0)) 85 | tensor.add_(mean) 86 | 87 | # Clamp to ensure it's in the proper range 88 | tensor.clamp_(min=a, max=b) 89 | return tensor 90 | 91 | 92 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 93 | # type: (torch.Tensor, float, float, float, float) -> torch.Tensor 94 | r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first 95 | convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype. 96 | Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn 97 | from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 98 | with values outside :math:`[a, b]` redrawn until they are within 99 | the bounds. The method used for generating the random values works 100 | best when :math:`a \leq \text{mean} \leq b`. 
101 | Args: 102 | tensor: an n-dimensional `torch.Tensor` 103 | mean: the mean of the normal distribution 104 | std: the standard deviation of the normal distribution 105 | a: the minimum cutoff value 106 | b: the maximum cutoff value 107 | Examples: 108 | >>> w = torch.empty(3, 5) 109 | >>> nn.init.trunc_normal_(w) 110 | """ 111 | 112 | with torch.no_grad(): 113 | dtype = tensor.dtype 114 | tensor_fp32 = tensor.float() 115 | tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) 116 | tensor_dtype = tensor_fp32.to(dtype=dtype) 117 | tensor.copy_(tensor_dtype) 118 | 119 | 120 | def init_weights(self): 121 | if self.pos_embed is not None: 122 | trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) 123 | trunc_normal_(self.latent, std=self.latent_dim**-0.5) 124 | 125 | 126 | def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: 127 | """ViT weight initialization, original timm impl (for reproducibility)""" 128 | if isinstance(module, nn.Linear): 129 | trunc_normal_(module.weight, std=0.02) 130 | if module.bias is not None: 131 | nn.init.zeros_(module.bias) 132 | elif hasattr(module, "init_weights"): 133 | module.init_weights() 134 | 135 | 136 | class Attention(nn.Module): 137 | fused_attn: Final[bool] 138 | 139 | def __init__( 140 | self, 141 | dim: int, 142 | num_heads: int = 8, 143 | qkv_bias: bool = False, 144 | qk_norm: bool = False, 145 | attn_drop: float = 0.0, 146 | proj_drop: float = 0.0, 147 | norm_layer: nn.Module = nn.LayerNorm, 148 | ) -> None: 149 | super().__init__() 150 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 151 | self.num_heads = num_heads 152 | self.head_dim = dim // num_heads 153 | self.scale = self.head_dim**-0.5 154 | # self.fused_attn = use_fused_attn() 155 | self.fused_attn = True 156 | 157 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 158 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 159 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 160 | self.attn_drop = nn.Dropout(attn_drop) 161 | self.proj = nn.Linear(dim, dim) 162 | self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() 163 | 164 | def forward(self, x: torch.Tensor) -> torch.Tensor: 165 | B, N, C = x.shape 166 | qkv = ( 167 | self.qkv(x) 168 | .reshape(B, N, 3, self.num_heads, self.head_dim) 169 | .permute(2, 0, 3, 1, 4) 170 | ) 171 | q, k, v = qkv.unbind(0) 172 | q, k = self.q_norm(q), self.k_norm(k) 173 | 174 | if self.fused_attn: 175 | x = F.scaled_dot_product_attention( 176 | q, 177 | k, 178 | v, 179 | dropout_p=self.attn_drop.p if self.training else 0.0, 180 | ) 181 | else: 182 | q = q * self.scale 183 | attn = q @ k.transpose(-2, -1) 184 | attn = attn.softmax(dim=-1) 185 | attn = self.attn_drop(attn) 186 | x = attn @ v 187 | 188 | x = x.transpose(1, 2).reshape(B, N, C) 189 | x = self.proj(x) 190 | x = self.proj_drop(x) 191 | return x 192 | 193 | 194 | class LayerScale(nn.Module): 195 | def __init__( 196 | self, 197 | dim: int, 198 | init_values: float = 1e-5, 199 | inplace: bool = False, 200 | ) -> None: 201 | super().__init__() 202 | self.inplace = inplace 203 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 204 | 205 | def forward(self, x: torch.Tensor) -> torch.Tensor: 206 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 207 | 208 | 209 | class Block(nn.Module): 210 | def __init__( 211 | self, 212 | dim: int, 213 | num_heads: int, 214 | mlp_ratio: float = 4.0, 215 | qkv_bias: bool = False, 216 | qk_norm: bool = False, 217 | 
proj_drop: float = 0.0, 218 | attn_drop: float = 0.0, 219 | init_values: Optional[float] = None, 220 | drop_path: float = 0.0, 221 | act_layer: nn.Module = nn.GELU, 222 | norm_layer: nn.Module = nn.LayerNorm, 223 | mlp_layer: nn.Module = Mlp, 224 | ) -> None: 225 | super().__init__() 226 | self.norm1 = norm_layer(dim) 227 | self.attn = Attention( 228 | dim, 229 | num_heads=num_heads, 230 | qkv_bias=qkv_bias, 231 | qk_norm=qk_norm, 232 | attn_drop=attn_drop, 233 | proj_drop=proj_drop, 234 | norm_layer=norm_layer, 235 | ) 236 | self.ls1 = ( 237 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 238 | ) 239 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 240 | 241 | self.norm2 = norm_layer(dim) 242 | self.mlp = mlp_layer( 243 | in_features=dim, 244 | hidden_features=int(dim * mlp_ratio), 245 | act_layer=act_layer, 246 | drop=proj_drop, 247 | ) 248 | self.ls2 = ( 249 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 250 | ) 251 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 252 | 253 | def forward(self, x: torch.Tensor) -> torch.Tensor: 254 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 255 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 256 | return x 257 | 258 | 259 | class VisionTransformer(nn.Module): 260 | """Vision Transformer 261 | 262 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 263 | - https://arxiv.org/abs/2010.11929 264 | """ 265 | 266 | dynamic_img_size: Final[bool] 267 | 268 | def __init__( 269 | self, 270 | img_size: Union[int, Tuple[int, int]] = 224, 271 | patch_size: Union[int, Tuple[int, int]] = 16, 272 | in_chans: int = 3, 273 | num_classes: int = 1000, 274 | global_pool: Literal["", "avg", "token", "map"] = "token", 275 | embed_dim: int = 768, 276 | depth: int = 12, 277 | num_heads: int = 12, 278 | mlp_ratio: float = 4.0, 279 | qkv_bias: bool = True, 280 | qk_norm: bool = False, 281 | init_values: Optional[float] = None, 282 | class_token: bool = True, 283 | no_embed_class: bool = False, 284 | reg_tokens: int = 0, 285 | pre_norm: bool = False, 286 | fc_norm: Optional[bool] = None, 287 | dynamic_img_size: bool = False, 288 | dynamic_img_pad: bool = False, 289 | drop_rate: float = 0.0, 290 | pos_drop_rate: float = 0.0, 291 | patch_drop_rate: float = 0.0, 292 | proj_drop_rate: float = 0.0, 293 | attn_drop_rate: float = 0.0, 294 | drop_path_rate: float = 0.0, 295 | weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", 296 | embed_layer: Callable = PatchEmbed, 297 | norm_layer: Optional[LayerType] = None, 298 | act_layer: Optional[LayerType] = None, 299 | block_fn: Type[nn.Module] = Block, 300 | mlp_layer: Type[nn.Module] = Mlp, 301 | ignore_head: bool = False, 302 | ) -> None: 303 | """ 304 | Args: 305 | img_size: Input image size. 306 | patch_size: Patch size. 307 | in_chans: Number of image input channels. 308 | num_classes: Mumber of classes for classification head. 309 | global_pool: Type of global pooling for final sequence (default: 'token'). 310 | embed_dim: Transformer embedding dimension. 311 | depth: Depth of transformer. 312 | num_heads: Number of attention heads. 313 | mlp_ratio: Ratio of mlp hidden dim to embedding dim. 314 | qkv_bias: Enable bias for qkv projections if True. 315 | init_values: Layer-scale init values (layer-scale enabled if not None). 316 | class_token: Use class token. 
317 | no_embed_class: Don't include position embeddings for class (or reg) tokens. 318 | reg_tokens: Number of register tokens. 319 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. 320 | drop_rate: Head dropout rate. 321 | pos_drop_rate: Position embedding dropout rate. 322 | attn_drop_rate: Attention dropout rate. 323 | drop_path_rate: Stochastic depth rate. 324 | weight_init: Weight initialization scheme. 325 | embed_layer: Patch embedding layer. 326 | norm_layer: Normalization layer. 327 | act_layer: MLP activation layer. 328 | block_fn: Transformer block layer. 329 | """ 330 | super().__init__() 331 | assert global_pool in ("", "avg", "token", "map") 332 | assert class_token or global_pool != "token" 333 | use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm 334 | # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) 335 | # act_layer = get_act_layer(act_layer) or nn.GELU 336 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 337 | act_layer = nn.GELU 338 | 339 | self.num_classes = num_classes 340 | self.global_pool = global_pool 341 | self.num_features = self.embed_dim = ( 342 | embed_dim # num_features for consistency with other models 343 | ) 344 | self.num_prefix_tokens = 1 if class_token else 0 345 | self.num_prefix_tokens += reg_tokens 346 | self.num_reg_tokens = reg_tokens 347 | self.has_class_token = class_token 348 | self.no_embed_class = ( 349 | no_embed_class # don't embed prefix positions (includes reg) 350 | ) 351 | self.dynamic_img_size = dynamic_img_size 352 | self.grad_checkpointing = False 353 | self.ignore_head = ignore_head 354 | 355 | embed_args = {} 356 | if dynamic_img_size: 357 | # flatten deferred until after pos embed 358 | embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) 359 | self.patch_embed = embed_layer( 360 | img_size=img_size, 361 | patch_size=patch_size, 362 | in_chans=in_chans, 363 | embed_dim=embed_dim, 364 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) 365 | dynamic_img_pad=dynamic_img_pad, 366 | **embed_args, 367 | ) 368 | num_patches = self.patch_embed.num_patches 369 | 370 | self.cls_token = ( 371 | nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 372 | ) 373 | self.reg_token = ( 374 | nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None 375 | ) 376 | embed_len = ( 377 | num_patches if no_embed_class else num_patches + self.num_prefix_tokens 378 | ) 379 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) 380 | self.pos_drop = nn.Dropout(p=pos_drop_rate) 381 | if patch_drop_rate > 0: 382 | self.patch_drop = PatchDropout( 383 | patch_drop_rate, 384 | num_prefix_tokens=self.num_prefix_tokens, 385 | ) 386 | else: 387 | self.patch_drop = nn.Identity() 388 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 389 | 390 | dpr = [ 391 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 392 | ] # stochastic depth decay rule 393 | self.blocks = nn.Sequential( 394 | *[ 395 | block_fn( 396 | dim=embed_dim, 397 | num_heads=num_heads, 398 | mlp_ratio=mlp_ratio, 399 | qkv_bias=qkv_bias, 400 | qk_norm=qk_norm, 401 | init_values=init_values, 402 | proj_drop=proj_drop_rate, 403 | attn_drop=attn_drop_rate, 404 | drop_path=dpr[i], 405 | norm_layer=norm_layer, 406 | act_layer=act_layer, 407 | mlp_layer=mlp_layer, 408 | ) 409 | for i in range(depth) 410 | ] 411 | ) 412 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 413 | 414 | # Classifier Head 415 | if global_pool == "map": 416 | AttentionPoolLatent.init_weights = init_weights 417 | self.attn_pool = AttentionPoolLatent( 418 | self.embed_dim, 419 | num_heads=num_heads, 420 | mlp_ratio=mlp_ratio, 421 | norm_layer=norm_layer, 422 | ) 423 | else: 424 | self.attn_pool = None 425 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 426 | self.head_drop = nn.Dropout(drop_rate) 427 | self.head = ( 428 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 429 | ) 430 | 431 | if weight_init != "skip": 432 | self.init_weights(weight_init) 433 | 434 | def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: 435 | assert mode in ("jax", "jax_nlhb", "moco", "") 436 | # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 437 | trunc_normal_(self.pos_embed, std=0.02) 438 | if self.cls_token is not None: 439 | nn.init.normal_(self.cls_token, std=1e-6) 440 | named_apply(init_weights_vit_timm, self) 441 | 442 | @torch.jit.ignore 443 | def no_weight_decay(self) -> Set: 444 | return {"pos_embed", "cls_token", "dist_token"} 445 | 446 | @torch.jit.ignore 447 | def group_matcher(self, coarse: bool = False) -> Dict: 448 | return dict( 449 | stem=r"^cls_token|pos_embed|patch_embed", # stem and embed 450 | blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], 451 | ) 452 | 453 | @torch.jit.ignore 454 | def set_grad_checkpointing(self, enable: bool = True) -> None: 455 | self.grad_checkpointing = enable 456 | 457 | @torch.jit.ignore 458 | def get_classifier(self) -> nn.Module: 459 | return self.head 460 | 461 | def reset_classifier(self, num_classes: int, global_pool=None) -> None: 462 | self.num_classes = num_classes 463 | if global_pool is not None: 464 | assert global_pool in ("", "avg", "token", "map") 465 | if global_pool == "map" and self.attn_pool is None: 466 | assert ( 467 | False 468 | ), "Cannot currently add attention pooling in reset_classifier()." 
469 | elif global_pool != "map " and self.attn_pool is not None: 470 | self.attn_pool = None # remove attention pooling 471 | self.global_pool = global_pool 472 | self.head = ( 473 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 474 | ) 475 | 476 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: 477 | if self.dynamic_img_size: 478 | B, H, W, C = x.shape 479 | pos_embed = resample_abs_pos_embed( 480 | self.pos_embed, 481 | (H, W), 482 | num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, 483 | ) 484 | x = x.view(B, -1, C) 485 | else: 486 | pos_embed = self.pos_embed 487 | 488 | to_cat = [] 489 | if self.cls_token is not None: 490 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) 491 | if self.reg_token is not None: 492 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) 493 | 494 | if self.no_embed_class: 495 | # deit-3, updated JAX (big vision) 496 | # position embedding does not overlap with class token, add then concat 497 | x = x + pos_embed 498 | if to_cat: 499 | x = torch.cat(to_cat + [x], dim=1) 500 | else: 501 | # original timm, JAX, and deit vit impl 502 | # pos_embed has entry for class token, concat then add 503 | if to_cat: 504 | x = torch.cat(to_cat + [x], dim=1) 505 | x = x + pos_embed 506 | 507 | return self.pos_drop(x) 508 | 509 | def _intermediate_layers( 510 | self, 511 | x: torch.Tensor, 512 | n: Union[int, Sequence] = 1, 513 | ) -> List[torch.Tensor]: 514 | outputs, num_blocks = [], len(self.blocks) 515 | take_indices = set( 516 | range(num_blocks - n, num_blocks) if isinstance(n, int) else n 517 | ) 518 | 519 | # forward pass 520 | x = self.patch_embed(x) 521 | x = self._pos_embed(x) 522 | x = self.patch_drop(x) 523 | x = self.norm_pre(x) 524 | for i, blk in enumerate(self.blocks): 525 | x = blk(x) 526 | if i in take_indices: 527 | outputs.append(x) 528 | 529 | return outputs 530 | 531 | def get_intermediate_layers( 532 | self, 533 | x: torch.Tensor, 534 | n: Union[int, Sequence] = 1, 535 | reshape: bool = False, 536 | return_prefix_tokens: bool = False, 537 | norm: bool = False, 538 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 539 | """Intermediate layer accessor (NOTE: This is a WIP experiment). 
540 | Inspired by DINO / DINOv2 interface 541 | """ 542 | # take last n blocks if n is an int, if in is a sequence, select by matching indices 543 | outputs = self._intermediate_layers(x, n) 544 | if norm: 545 | outputs = [self.norm(out) for out in outputs] 546 | prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] 547 | outputs = [out[:, self.num_prefix_tokens :] for out in outputs] 548 | 549 | if reshape: 550 | grid_size = self.patch_embed.grid_size 551 | outputs = [ 552 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) 553 | .permute(0, 3, 1, 2) 554 | .contiguous() 555 | for out in outputs 556 | ] 557 | 558 | if return_prefix_tokens: 559 | return tuple(zip(outputs, prefix_tokens)) 560 | return tuple(outputs) 561 | 562 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: 563 | x = self.patch_embed(x) 564 | x = self._pos_embed(x) 565 | x = self.patch_drop(x) 566 | x = self.norm_pre(x) 567 | if self.grad_checkpointing and not torch.jit.is_scripting(): 568 | x = checkpoint_seq(self.blocks, x) 569 | else: 570 | x = self.blocks(x) 571 | x = self.norm(x) 572 | return x 573 | 574 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: 575 | if self.attn_pool is not None: 576 | x = self.attn_pool(x) 577 | elif self.global_pool == "avg": 578 | x = x[:, self.num_prefix_tokens :].mean(dim=1) 579 | elif self.global_pool: 580 | x = x[:, 0] # class token 581 | x = self.fc_norm(x) 582 | x = self.head_drop(x) 583 | return x if pre_logits else self.head(x) 584 | 585 | def forward(self, x: torch.Tensor) -> torch.Tensor: 586 | x = self.forward_features(x) 587 | if not self.ignore_head: 588 | x = self.forward_head(x) 589 | return x 590 | 591 | 592 | @dataclass 593 | class SigLIPVisionCfg: 594 | width: int = 1152 595 | layers: Union[Tuple[int, int, int, int], int] = 27 596 | heads: int = 16 597 | patch_size: int = 14 598 | image_size: Union[Tuple[int, int], int] = 336 599 | global_pool: str = "map" 600 | mlp_ratio: float = 3.7362 601 | class_token: bool = False 602 | num_classes: int = 0 603 | use_checkpoint: bool = False 604 | 605 | 606 | SigLIP_MODEL_CONFIG = { 607 | "siglip_so400m_patch14_384": { 608 | "image_size": 336, 609 | "patch_size": 14, 610 | "width": 1152, 611 | "layers": 27, 612 | "heads": 16, 613 | "mlp_ratio": 3.7362, 614 | "global_pool": "map", 615 | "use_checkpoint": False, 616 | }, 617 | "siglip_so400m_patch14_224": { 618 | "image_size": 224, 619 | "patch_size": 14, 620 | "width": 1152, 621 | "layers": 27, 622 | "heads": 16, 623 | "mlp_ratio": 3.7362, 624 | "global_pool": "map", 625 | "use_checkpoint": False, 626 | }, 627 | "siglip_large_patch16_384": { 628 | "image_size": 384, 629 | "patch_size": 16, 630 | "width": 1024, 631 | "layers": 24, 632 | "heads": 16, 633 | "mlp_ratio": 4, 634 | "global_pool": "map", 635 | "use_checkpoint": False, 636 | }, 637 | } 638 | 639 | 640 | def create_siglip_vit( 641 | model_name: str = "siglip_so400m_patch14_384", 642 | image_size: int = 384, 643 | select_layer: int = -1, 644 | ckpt_path: str = "", 645 | **kwargs, 646 | ): 647 | assert ( 648 | model_name in SigLIP_MODEL_CONFIG.keys() 649 | ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" 650 | 651 | vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) 652 | 653 | if select_layer <= 0: 654 | layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) 655 | else: 656 | layers = min(vision_cfg.layers, select_layer) 657 | 658 | model = VisionTransformer( 659 | img_size=image_size, 660 | 
patch_size=vision_cfg.patch_size, 661 | embed_dim=vision_cfg.width, 662 | depth=layers, 663 | num_heads=vision_cfg.heads, 664 | mlp_ratio=vision_cfg.mlp_ratio, 665 | class_token=vision_cfg.class_token, 666 | global_pool=vision_cfg.global_pool, 667 | ignore_head=kwargs.get("ignore_head", True), 668 | weight_init=kwargs.get("weight_init", "skip"), 669 | num_classes=0, 670 | ) 671 | 672 | if ckpt_path: 673 | state_dict = torch.load(ckpt_path, map_location="cpu") 674 | 675 | incompatible_keys = model.load_state_dict(state_dict, strict=False) 676 | print( 677 | f"SigLIP-ViT restores from {ckpt_path},\n" 678 | f"\tincompatible_keys:', {incompatible_keys}." 679 | ) 680 | 681 | return model 682 | -------------------------------------------------------------------------------- /janus/models/vq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | 21 | from dataclasses import dataclass, field 22 | from typing import List 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from functools import partial 29 | 30 | 31 | @dataclass 32 | class ModelArgs: 33 | codebook_size: int = 16384 34 | codebook_embed_dim: int = 8 35 | codebook_l2_norm: bool = True 36 | codebook_show_usage: bool = True 37 | commit_loss_beta: float = 0.25 38 | entropy_loss_ratio: float = 0.0 39 | 40 | encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) 41 | decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) 42 | z_channels: int = 256 43 | dropout_p: float = 0.0 44 | 45 | 46 | class Encoder(nn.Module): 47 | def __init__( 48 | self, 49 | in_channels=3, 50 | ch=128, 51 | ch_mult=(1, 1, 2, 2, 4), 52 | num_res_blocks=2, 53 | norm_type="group", 54 | dropout=0.0, 55 | resamp_with_conv=True, 56 | z_channels=256, 57 | ): 58 | super().__init__() 59 | self.num_resolutions = len(ch_mult) 60 | self.num_res_blocks = num_res_blocks 61 | self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) 62 | 63 | # downsampling 64 | in_ch_mult = (1,) + tuple(ch_mult) 65 | self.conv_blocks = nn.ModuleList() 66 | for i_level in range(self.num_resolutions): 67 | conv_block = nn.Module() 68 | # res & attn 69 | res_block = nn.ModuleList() 70 | attn_block = nn.ModuleList() 71 | block_in = ch * in_ch_mult[i_level] 72 | block_out = ch * ch_mult[i_level] 73 | for _ in range(self.num_res_blocks): 74 | res_block.append( 75 | ResnetBlock( 76 | block_in, block_out, dropout=dropout, norm_type=norm_type 77 | ) 78 | ) 79 | block_in = block_out 80 | if i_level == self.num_resolutions - 1: 81 | attn_block.append(AttnBlock(block_in, norm_type)) 82 | conv_block.res = res_block 83 | conv_block.attn = attn_block 84 | # downsample 85 | if i_level != self.num_resolutions - 1: 86 | conv_block.downsample = Downsample(block_in, resamp_with_conv) 87 | self.conv_blocks.append(conv_block) 88 | 89 | # middle 90 | self.mid = nn.ModuleList() 91 | self.mid.append( 92 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) 93 | ) 94 | self.mid.append(AttnBlock(block_in, norm_type=norm_type)) 95 | self.mid.append( 96 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) 97 | ) 98 | 99 | # end 100 | self.norm_out = Normalize(block_in, norm_type) 101 | self.conv_out = nn.Conv2d( 102 | block_in, z_channels, kernel_size=3, stride=1, padding=1 103 | ) 104 | 105 | def forward(self, x): 106 | h = self.conv_in(x) 107 | # downsampling 108 | for i_level, block in enumerate(self.conv_blocks): 109 | for i_block in range(self.num_res_blocks): 110 | h = block.res[i_block](h) 111 | if len(block.attn) > 0: 112 | h = block.attn[i_block](h) 113 | if i_level != self.num_resolutions - 1: 114 | h = block.downsample(h) 115 | 116 | # middle 117 | for mid_block in self.mid: 118 | h = mid_block(h) 119 | 120 | # end 121 | h = self.norm_out(h) 122 | h = nonlinearity(h) 123 | h = self.conv_out(h) 124 | return h 125 | 126 | 127 | class Decoder(nn.Module): 128 | def __init__( 129 | self, 130 | z_channels=256, 131 | ch=128, 132 | ch_mult=(1, 1, 2, 2, 4), 133 | num_res_blocks=2, 134 | norm_type="group", 135 | dropout=0.0, 136 | resamp_with_conv=True, 137 | out_channels=3, 138 | ): 139 | super().__init__() 140 | self.num_resolutions = len(ch_mult) 141 | self.num_res_blocks = num_res_blocks 142 | 143 | block_in = ch * ch_mult[self.num_resolutions - 1] 144 | # z to block_in 145 | self.conv_in = nn.Conv2d( 
146 | z_channels, block_in, kernel_size=3, stride=1, padding=1 147 | ) 148 | 149 | # middle 150 | self.mid = nn.ModuleList() 151 | self.mid.append( 152 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) 153 | ) 154 | self.mid.append(AttnBlock(block_in, norm_type=norm_type)) 155 | self.mid.append( 156 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) 157 | ) 158 | 159 | # upsampling 160 | self.conv_blocks = nn.ModuleList() 161 | for i_level in reversed(range(self.num_resolutions)): 162 | conv_block = nn.Module() 163 | # res & attn 164 | res_block = nn.ModuleList() 165 | attn_block = nn.ModuleList() 166 | block_out = ch * ch_mult[i_level] 167 | for _ in range(self.num_res_blocks + 1): 168 | res_block.append( 169 | ResnetBlock( 170 | block_in, block_out, dropout=dropout, norm_type=norm_type 171 | ) 172 | ) 173 | block_in = block_out 174 | if i_level == self.num_resolutions - 1: 175 | attn_block.append(AttnBlock(block_in, norm_type)) 176 | conv_block.res = res_block 177 | conv_block.attn = attn_block 178 | # downsample 179 | if i_level != 0: 180 | conv_block.upsample = Upsample(block_in, resamp_with_conv) 181 | self.conv_blocks.append(conv_block) 182 | 183 | # end 184 | self.norm_out = Normalize(block_in, norm_type) 185 | self.conv_out = nn.Conv2d( 186 | block_in, out_channels, kernel_size=3, stride=1, padding=1 187 | ) 188 | 189 | @property 190 | def last_layer(self): 191 | return self.conv_out.weight 192 | 193 | def forward(self, z): 194 | # z to block_in 195 | h = self.conv_in(z) 196 | 197 | # middle 198 | for mid_block in self.mid: 199 | h = mid_block(h) 200 | 201 | # upsampling 202 | for i_level, block in enumerate(self.conv_blocks): 203 | for i_block in range(self.num_res_blocks + 1): 204 | h = block.res[i_block](h) 205 | if len(block.attn) > 0: 206 | h = block.attn[i_block](h) 207 | if i_level != self.num_resolutions - 1: 208 | h = block.upsample(h) 209 | 210 | # end 211 | h = self.norm_out(h) 212 | h = nonlinearity(h) 213 | h = self.conv_out(h) 214 | return h 215 | 216 | 217 | class VectorQuantizer(nn.Module): 218 | def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): 219 | super().__init__() 220 | self.n_e = n_e 221 | self.e_dim = e_dim 222 | self.beta = beta 223 | self.entropy_loss_ratio = entropy_loss_ratio 224 | self.l2_norm = l2_norm 225 | self.show_usage = show_usage 226 | 227 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 228 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 229 | if self.l2_norm: 230 | self.embedding.weight.data = F.normalize( 231 | self.embedding.weight.data, p=2, dim=-1 232 | ) 233 | if self.show_usage: 234 | self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) 235 | 236 | def forward(self, z): 237 | # reshape z -> (batch, height, width, channel) and flatten 238 | z = torch.einsum("b c h w -> b h w c", z).contiguous() 239 | z_flattened = z.view(-1, self.e_dim) 240 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 241 | 242 | if self.l2_norm: 243 | z = F.normalize(z, p=2, dim=-1) 244 | z_flattened = F.normalize(z_flattened, p=2, dim=-1) 245 | embedding = F.normalize(self.embedding.weight, p=2, dim=-1) 246 | else: 247 | embedding = self.embedding.weight 248 | 249 | d = ( 250 | torch.sum(z_flattened**2, dim=1, keepdim=True) 251 | + torch.sum(embedding**2, dim=1) 252 | - 2 253 | * torch.einsum( 254 | "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) 255 | ) 256 | ) 257 | 258 | min_encoding_indices = torch.argmin(d, 
dim=1) 259 | z_q = embedding[min_encoding_indices].view(z.shape) 260 | perplexity = None 261 | min_encodings = None 262 | vq_loss = None 263 | commit_loss = None 264 | entropy_loss = None 265 | 266 | # compute loss for embedding 267 | if self.training: 268 | vq_loss = torch.mean((z_q - z.detach()) ** 2) 269 | commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) 270 | entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) 271 | 272 | # preserve gradients 273 | z_q = z + (z_q - z).detach() 274 | 275 | # reshape back to match original input shape 276 | z_q = torch.einsum("b h w c -> b c h w", z_q) 277 | 278 | return ( 279 | z_q, 280 | (vq_loss, commit_loss, entropy_loss), 281 | (perplexity, min_encodings, min_encoding_indices), 282 | ) 283 | 284 | def get_codebook_entry(self, indices, shape=None, channel_first=True): 285 | # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) 286 | if self.l2_norm: 287 | embedding = F.normalize(self.embedding.weight, p=2, dim=-1) 288 | else: 289 | embedding = self.embedding.weight 290 | z_q = embedding[indices] # (b*h*w, c) 291 | 292 | if shape is not None: 293 | if channel_first: 294 | z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) 295 | # reshape back to match original input shape 296 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 297 | else: 298 | z_q = z_q.view(shape) 299 | return z_q 300 | 301 | 302 | class ResnetBlock(nn.Module): 303 | def __init__( 304 | self, 305 | in_channels, 306 | out_channels=None, 307 | conv_shortcut=False, 308 | dropout=0.0, 309 | norm_type="group", 310 | ): 311 | super().__init__() 312 | self.in_channels = in_channels 313 | out_channels = in_channels if out_channels is None else out_channels 314 | self.out_channels = out_channels 315 | self.use_conv_shortcut = conv_shortcut 316 | 317 | self.norm1 = Normalize(in_channels, norm_type) 318 | self.conv1 = nn.Conv2d( 319 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 320 | ) 321 | self.norm2 = Normalize(out_channels, norm_type) 322 | self.dropout = nn.Dropout(dropout) 323 | self.conv2 = nn.Conv2d( 324 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 325 | ) 326 | 327 | if self.in_channels != self.out_channels: 328 | if self.use_conv_shortcut: 329 | self.conv_shortcut = nn.Conv2d( 330 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 331 | ) 332 | else: 333 | self.nin_shortcut = nn.Conv2d( 334 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 335 | ) 336 | 337 | def forward(self, x): 338 | h = x 339 | h = self.norm1(h) 340 | h = nonlinearity(h) 341 | h = self.conv1(h) 342 | h = self.norm2(h) 343 | h = nonlinearity(h) 344 | h = self.dropout(h) 345 | h = self.conv2(h) 346 | 347 | if self.in_channels != self.out_channels: 348 | if self.use_conv_shortcut: 349 | x = self.conv_shortcut(x) 350 | else: 351 | x = self.nin_shortcut(x) 352 | return x + h 353 | 354 | 355 | class AttnBlock(nn.Module): 356 | def __init__(self, in_channels, norm_type="group"): 357 | super().__init__() 358 | self.norm = Normalize(in_channels, norm_type) 359 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 360 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 361 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 362 | self.proj_out = nn.Conv2d( 363 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 364 | ) 365 | 366 | def forward(self, x): 367 | h_ = x 368 | h_ = self.norm(h_) 
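        # Note: this block is plain (non-causal) single-head self-attention over the
        # spatial grid. q, k, v come from 1x1 convolutions; the weights below are
        # softmax(q @ k^T / sqrt(c)) over all h*w positions, and the attended values
        # are added back to the input x as a residual.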
369 | q = self.q(h_) 370 | k = self.k(h_) 371 | v = self.v(h_) 372 | 373 | # compute attention 374 | b, c, h, w = q.shape 375 | q = q.reshape(b, c, h * w) 376 | q = q.permute(0, 2, 1) # b,hw,c 377 | k = k.reshape(b, c, h * w) # b,c,hw 378 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 379 | w_ = w_ * (int(c) ** (-0.5)) 380 | w_ = F.softmax(w_, dim=2) 381 | 382 | # attend to values 383 | v = v.reshape(b, c, h * w) 384 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 385 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 386 | h_ = h_.reshape(b, c, h, w) 387 | 388 | h_ = self.proj_out(h_) 389 | 390 | return x + h_ 391 | 392 | 393 | def nonlinearity(x): 394 | # swish 395 | return x * torch.sigmoid(x) 396 | 397 | 398 | def Normalize(in_channels, norm_type="group"): 399 | assert norm_type in ["group", "batch"] 400 | if norm_type == "group": 401 | return nn.GroupNorm( 402 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 403 | ) 404 | elif norm_type == "batch": 405 | return nn.SyncBatchNorm(in_channels) 406 | 407 | 408 | class Upsample(nn.Module): 409 | def __init__(self, in_channels, with_conv): 410 | super().__init__() 411 | self.with_conv = with_conv 412 | if self.with_conv: 413 | self.conv = nn.Conv2d( 414 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 415 | ) 416 | 417 | def forward(self, x): 418 | if x.dtype != torch.float32: 419 | x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to( 420 | torch.bfloat16 421 | ) 422 | else: 423 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 424 | 425 | if self.with_conv: 426 | x = self.conv(x) 427 | return x 428 | 429 | 430 | class Downsample(nn.Module): 431 | def __init__(self, in_channels, with_conv): 432 | super().__init__() 433 | self.with_conv = with_conv 434 | if self.with_conv: 435 | # no asymmetric padding in torch conv, must do it ourselves 436 | self.conv = nn.Conv2d( 437 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 438 | ) 439 | 440 | def forward(self, x): 441 | if self.with_conv: 442 | pad = (0, 1, 0, 1) 443 | x = F.pad(x, pad, mode="constant", value=0) 444 | x = self.conv(x) 445 | else: 446 | x = F.avg_pool2d(x, kernel_size=2, stride=2) 447 | return x 448 | 449 | 450 | def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): 451 | flat_affinity = affinity.reshape(-1, affinity.shape[-1]) 452 | flat_affinity /= temperature 453 | probs = F.softmax(flat_affinity, dim=-1) 454 | log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) 455 | if loss_type == "softmax": 456 | target_probs = probs 457 | else: 458 | raise ValueError("Entropy loss {} not supported".format(loss_type)) 459 | avg_probs = torch.mean(target_probs, dim=0) 460 | avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) 461 | sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) 462 | loss = sample_entropy - avg_entropy 463 | return loss 464 | 465 | 466 | class VQModel(nn.Module): 467 | def __init__(self, config: ModelArgs): 468 | super().__init__() 469 | self.config = config 470 | self.encoder = Encoder( 471 | ch_mult=config.encoder_ch_mult, 472 | z_channels=config.z_channels, 473 | dropout=config.dropout_p, 474 | ) 475 | self.decoder = Decoder( 476 | ch_mult=config.decoder_ch_mult, 477 | z_channels=config.z_channels, 478 | dropout=config.dropout_p, 479 | ) 480 | 481 | self.quantize = VectorQuantizer( 482 | config.codebook_size, 483 | config.codebook_embed_dim, 484 | config.commit_loss_beta, 
485 | config.entropy_loss_ratio, 486 | config.codebook_l2_norm, 487 | config.codebook_show_usage, 488 | ) 489 | self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) 490 | self.post_quant_conv = nn.Conv2d( 491 | config.codebook_embed_dim, config.z_channels, 1 492 | ) 493 | 494 | def encode(self, x): 495 | h = self.encoder(x) 496 | h = self.quant_conv(h) 497 | quant, emb_loss, info = self.quantize(h) 498 | return quant, emb_loss, info 499 | 500 | def decode(self, quant): 501 | quant = self.post_quant_conv(quant) 502 | dec = self.decoder(quant) 503 | return dec 504 | 505 | def decode_code(self, code_b, shape=None, channel_first=True): 506 | quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) 507 | dec = self.decode(quant_b) 508 | return dec 509 | 510 | def forward(self, input): 511 | quant, diff, _ = self.encode(input) 512 | dec = self.decode(quant) 513 | return dec, diff 514 | 515 | 516 | ################################################################################# 517 | # VQ Model Configs # 518 | ################################################################################# 519 | def VQ_16(**kwargs): 520 | return VQModel( 521 | ModelArgs( 522 | encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs 523 | ) 524 | ) 525 | 526 | 527 | VQ_models = {"VQ-16": VQ_16} 528 | -------------------------------------------------------------------------------- /janus/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /janus/utils/conversation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | """ 21 | From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py 22 | """ 23 | 24 | import dataclasses 25 | from enum import IntEnum, auto 26 | from typing import Dict, List 27 | 28 | 29 | class SeparatorStyle(IntEnum): 30 | """Separator styles.""" 31 | 32 | ADD_COLON_SINGLE = auto() 33 | ADD_COLON_TWO = auto() 34 | ADD_COLON_SPACE_SINGLE = auto() 35 | NO_COLON_SINGLE = auto() 36 | NO_COLON_TWO = auto() 37 | ADD_NEW_LINE_SINGLE = auto() 38 | LLAMA2 = auto() 39 | CHATGLM = auto() 40 | CHATML = auto() 41 | CHATINTERN = auto() 42 | DOLLY = auto() 43 | RWKV = auto() 44 | PHOENIX = auto() 45 | ROBIN = auto() 46 | DeepSeek = auto() 47 | PLAIN = auto() 48 | ALIGNMENT = auto() 49 | 50 | 51 | @dataclasses.dataclass 52 | class Conversation: 53 | """A class that manages prompt templates and keeps all conversation history.""" 54 | 55 | # The name of this template 56 | name: str 57 | # The template of the system prompt 58 | system_template: str = "{system_message}" 59 | # The system message 60 | system_message: str = "" 61 | # The names of two roles 62 | roles: List[str] = (("USER", "ASSISTANT"),) 63 | # All messages. Each item is (role, message). 
64 | messages: List[List[str]] = () 65 | # The number of few shot examples 66 | offset: int = 0 67 | # The separator style and configurations 68 | sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE 69 | sep: str = "\n" 70 | sep2: str = None 71 | # Stop criteria (the default one is EOS token) 72 | stop_str: str = None 73 | # Stops generation if meeting any token in this list 74 | stop_token_ids: List[int] = None 75 | 76 | def get_prompt(self) -> str: 77 | """Get the prompt for generation.""" 78 | system_prompt = self.system_template.format(system_message=self.system_message) 79 | 80 | if self.sep_style == SeparatorStyle.DeepSeek: 81 | seps = [self.sep, self.sep2] 82 | if system_prompt == "" or system_prompt is None: 83 | ret = "" 84 | else: 85 | ret = system_prompt + seps[0] 86 | for i, (role, message) in enumerate(self.messages): 87 | if message: 88 | ret += role + ": " + message + seps[i % 2] 89 | else: 90 | ret += role + ":" 91 | return ret 92 | elif self.sep_style == SeparatorStyle.LLAMA2: 93 | seps = [self.sep, self.sep2] 94 | if self.system_message: 95 | ret = system_prompt 96 | else: 97 | ret = "[INST] " 98 | for i, (role, message) in enumerate(self.messages): 99 | tag = self.roles[i % 2] 100 | if message: 101 | if type(message) is tuple: # multimodal message 102 | message, _ = message 103 | if i == 0: 104 | ret += message + " " 105 | else: 106 | ret += tag + " " + message + seps[i % 2] 107 | else: 108 | ret += tag 109 | return ret 110 | elif self.sep_style == SeparatorStyle.PLAIN: 111 | seps = [self.sep, self.sep2] 112 | ret = "" 113 | for i, (role, message) in enumerate(self.messages): 114 | if message: 115 | if type(message) is tuple: 116 | message, _, _ = message 117 | if i % 2 == 0: 118 | ret += message + seps[i % 2] 119 | else: 120 | ret += message + seps[i % 2] 121 | else: 122 | ret += "" 123 | return ret 124 | elif self.sep_style == SeparatorStyle.ALIGNMENT: 125 | seps = [self.sep, self.sep2] 126 | ret = "" 127 | for i, (role, message) in enumerate(self.messages): 128 | if message: 129 | if type(message) is tuple: 130 | message, _, _ = message 131 | if i % 2 == 0: 132 | ret += "\n" + seps[i % 2] 133 | else: 134 | ret += message + seps[i % 2] 135 | else: 136 | ret += "" 137 | return ret 138 | else: 139 | raise ValueError(f"Invalid style: {self.sep_style}") 140 | 141 | def get_prompt_for_current_round(self, content=None): 142 | """Get current round formatted question prompt during sft training""" 143 | if self.sep_style == SeparatorStyle.PLAIN: 144 | formatted_question = "\n" 145 | elif self.sep_style == SeparatorStyle.DeepSeek: 146 | formatted_question = ( 147 | f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:" 148 | ) 149 | else: 150 | raise ValueError(f"Unsupported sep_style: {self.sep_style}") 151 | return formatted_question 152 | 153 | def set_system_message(self, system_message: str): 154 | """Set the system message.""" 155 | self.system_message = system_message 156 | 157 | def append_message(self, role: str, message: str): 158 | """Append a new message.""" 159 | self.messages.append([role, message]) 160 | 161 | def reset_message(self): 162 | """Reset a new message.""" 163 | self.messages = [] 164 | 165 | def update_last_message(self, message: str): 166 | """Update the last output. 167 | 168 | The last message is typically set to be None when constructing the prompt, 169 | so we need to update it in-place after getting the response from a model. 
170 | """ 171 | self.messages[-1][1] = message 172 | 173 | def to_gradio_chatbot(self): 174 | """Convert the conversation to gradio chatbot format.""" 175 | ret = [] 176 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 177 | if i % 2 == 0: 178 | ret.append([msg, None]) 179 | else: 180 | ret[-1][-1] = msg 181 | return ret 182 | 183 | def to_openai_api_messages(self): 184 | """Convert the conversation to OpenAI chat completion format.""" 185 | system_prompt = self.system_template.format(system_message=self.system_message) 186 | ret = [{"role": "system", "content": system_prompt}] 187 | 188 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 189 | if i % 2 == 0: 190 | ret.append({"role": "user", "content": msg}) 191 | else: 192 | if msg is not None: 193 | ret.append({"role": "assistant", "content": msg}) 194 | return ret 195 | 196 | def copy(self): 197 | return Conversation( 198 | name=self.name, 199 | system_template=self.system_template, 200 | system_message=self.system_message, 201 | roles=self.roles, 202 | messages=[[x, y] for x, y in self.messages], 203 | offset=self.offset, 204 | sep_style=self.sep_style, 205 | sep=self.sep, 206 | sep2=self.sep2, 207 | stop_str=self.stop_str, 208 | stop_token_ids=self.stop_token_ids, 209 | ) 210 | 211 | def dict(self): 212 | return { 213 | "template_name": self.name, 214 | "system_message": self.system_message, 215 | "roles": self.roles, 216 | "messages": self.messages, 217 | "offset": self.offset, 218 | } 219 | 220 | 221 | # A global registry for all conversation templates 222 | conv_templates: Dict[str, Conversation] = {} 223 | 224 | 225 | def register_conv_template(template: Conversation, override: bool = False): 226 | """Register a new conversation template.""" 227 | if not override: 228 | assert ( 229 | template.name not in conv_templates 230 | ), f"{template.name} has been registered." 231 | 232 | conv_templates[template.name] = template 233 | 234 | 235 | def get_conv_template(name: str) -> Conversation: 236 | """Get a conversation template.""" 237 | return conv_templates[name].copy() 238 | 239 | 240 | # llava_llama2 template 241 | register_conv_template( 242 | Conversation( 243 | name="llava_llama2", 244 | system_message="You are a helpful language and vision assistant. " 245 | "You are able to understand the visual content that the user provides, " 246 | "and assist the user with a variety of tasks using natural language.", 247 | system_template="[INST] <>\n{system_message}\n<>\n\n", 248 | roles=("[INST]", "[/INST]"), 249 | messages=(), 250 | offset=0, 251 | sep_style=SeparatorStyle.LLAMA2, 252 | sep=" ", 253 | sep2=" ", 254 | stop_token_ids=[2], 255 | ) 256 | ) 257 | 258 | # llama2 template 259 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 260 | register_conv_template( 261 | Conversation( 262 | name="llama-2", 263 | system_template="[INST] <>\n{system_message}\n<>\n\n", 264 | roles=("[INST]", "[/INST]"), 265 | messages=(), 266 | offset=0, 267 | sep_style=SeparatorStyle.LLAMA2, 268 | sep=" ", 269 | sep2=" ", 270 | stop_token_ids=[2], 271 | ) 272 | ) 273 | 274 | 275 | # deepseek template 276 | register_conv_template( 277 | Conversation( 278 | name="deepseek_old", 279 | system_template="{system_message}", 280 | # system_message="You are a helpful assistant. 
Please answer truthfully and write out your " 281 | # "thinking step by step to be sure you get the right answer.", 282 | system_message="", 283 | roles=("User", "Assistant"), 284 | messages=(), 285 | offset=0, 286 | sep_style=SeparatorStyle.DeepSeek, 287 | sep="\n\n", 288 | sep2="<|end▁of▁sentence|>", 289 | stop_token_ids=[100001], 290 | stop_str=["User:", "<|end▁of▁sentence|>"], 291 | ) 292 | ) 293 | register_conv_template( 294 | Conversation( 295 | name="deepseek", 296 | system_template="{system_message}", 297 | # system_message="You are a helpful assistant. Please answer truthfully and write out your " 298 | # "thinking step by step to be sure you get the right answer.", 299 | system_message="", 300 | roles=("<|User|>", "<|Assistant|>"), 301 | messages=(), 302 | offset=0, 303 | sep_style=SeparatorStyle.DeepSeek, 304 | sep="\n\n", 305 | sep2="<|end▁of▁sentence|>", 306 | stop_token_ids=[100001], 307 | stop_str=["<|User|>", "<|end▁of▁sentence|>"] 308 | ) 309 | ) 310 | 311 | register_conv_template( 312 | Conversation( 313 | name="plain", 314 | system_template="", 315 | system_message="", 316 | roles=("", ""), 317 | messages=(), 318 | offset=0, 319 | sep_style=SeparatorStyle.PLAIN, 320 | sep="", 321 | sep2="", 322 | stop_token_ids=[2], 323 | stop_str=[""], 324 | ) 325 | ) 326 | 327 | 328 | register_conv_template( 329 | Conversation( 330 | name="alignment", 331 | system_template="", 332 | system_message="", 333 | roles=("", ""), 334 | messages=(), 335 | offset=0, 336 | sep_style=SeparatorStyle.ALIGNMENT, 337 | sep="", 338 | sep2="", 339 | stop_token_ids=[2], 340 | stop_str=[""], 341 | ) 342 | ) 343 | 344 | 345 | if __name__ == "__main__": 346 | # print("Llama-2 template:") 347 | # conv = get_conv_template("llama-2") 348 | # conv.set_system_message("You are a helpful, respectful and honest assistant.") 349 | # conv.append_message(conv.roles[0], "Hello!") 350 | # conv.append_message(conv.roles[1], "Hi!") 351 | # conv.append_message(conv.roles[0], "How are you?") 352 | # conv.append_message(conv.roles[1], None) 353 | # print(conv.get_prompt()) 354 | 355 | # print("\n") 356 | 357 | print("deepseek template:") 358 | conv = get_conv_template("deepseek") 359 | conv.append_message(conv.roles[0], "Hello!") 360 | conv.append_message(conv.roles[1], "Hi! This is Tony.") 361 | conv.append_message(conv.roles[0], "Who are you?") 362 | conv.append_message(conv.roles[1], "I am a helpful assistant.") 363 | conv.append_message(conv.roles[0], "How are you?") 364 | conv.append_message(conv.roles[1], None) 365 | print(conv.get_prompt()) 366 | -------------------------------------------------------------------------------- /janus/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 
12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import json 21 | from typing import Dict, List 22 | 23 | import PIL.Image 24 | import torch 25 | import base64 26 | import io 27 | from transformers import AutoModelForCausalLM 28 | 29 | from janus.models import MultiModalityCausalLM, VLChatProcessor 30 | 31 | 32 | def load_pretrained_model(model_path: str): 33 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) 34 | tokenizer = vl_chat_processor.tokenizer 35 | 36 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( 37 | model_path, trust_remote_code=True 38 | ) 39 | vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() 40 | 41 | return tokenizer, vl_chat_processor, vl_gpt 42 | 43 | 44 | def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: 45 | """ 46 | 47 | Support file path or base64 images. 48 | 49 | Args: 50 | conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 51 | [ 52 | { 53 | "role": "User", 54 | "content": "\nExtract all information from this image and convert them into markdown format.", 55 | "images": ["./examples/table_datasets.png"] 56 | }, 57 | {"role": "Assistant", "content": ""}, 58 | ] 59 | 60 | Returns: 61 | pil_images (List[PIL.Image.Image]): the list of PIL images. 62 | 63 | """ 64 | 65 | pil_images = [] 66 | 67 | for message in conversations: 68 | if "images" not in message: 69 | continue 70 | 71 | for image_data in message["images"]: 72 | if image_data.startswith("data:image"): 73 | # Image data is in base64 format 74 | _, image_data = image_data.split(",", 1) 75 | image_bytes = base64.b64decode(image_data) 76 | pil_img = PIL.Image.open(io.BytesIO(image_bytes)) 77 | else: 78 | # Image data is a file path 79 | pil_img = PIL.Image.open(image_data) 80 | pil_img = pil_img.convert("RGB") 81 | pil_images.append(pil_img) 82 | 83 | return pil_images 84 | 85 | 86 | def load_json(filepath): 87 | with open(filepath, "r") as f: 88 | data = json.load(f) 89 | return data 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrdict 2 | --------------------------------------------------------------------------------
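
Usage note (editorial, not part of the repository): the files above expose a handful of entry points that JanusPro.py builds on: create_siglip_vit constructs the SigLIP vision tower, VQ_models["VQ-16"] constructs the image tokenizer, get_conv_template("deepseek") builds the chat prompt, and janus/utils/io.py handles image loading and checkpoint loading. The sketch below wires these together purely for illustration. It assumes the repository root is on PYTHONPATH and that the listed dependencies (torch, Pillow, and the rest of requirements.txt, e.g. timm and transformers) are installed; the model path and image file are placeholders, and nothing here downloads or loads actual Janus-Pro weights.

# Illustrative sketch only; not shipped with the repository.
import torch
import PIL.Image

from janus.models.siglip_vit import create_siglip_vit
from janus.models.vq_model import VQ_models
from janus.utils.conversation import get_conv_template
from janus.utils.io import load_pil_images, load_pretrained_model

# 1) SigLIP vision tower (ignore_head defaults to True, so patch features are returned).
vision = create_siglip_vit("siglip_so400m_patch14_384", image_size=384).eval()
with torch.no_grad():
    feats = vision(torch.randn(1, 3, 384, 384))  # (1, num_patches, 1152)

# 2) VQ-16 image tokenizer: encode -> quantize against a 16384-entry codebook -> decode.
vq = VQ_models["VQ-16"]().eval()
with torch.no_grad():
    recon, _ = vq(torch.randn(1, 3, 256, 256))   # reconstruction has the input's shape

# 3) Prompt built from the registered "deepseek" conversation template.
conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Describe the attached image.")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())

# 4) Image loading for a multimodal conversation (file paths or base64 data URIs).
PIL.Image.new("RGB", (64, 64), "white").save("example.png")  # placeholder image
conversation = [
    {"role": "<|User|>", "content": "Describe the attached image.", "images": ["example.png"]},
    {"role": "<|Assistant|>", "content": ""},
]
pil_images = load_pil_images(conversation)

# 5) Loading the full Janus-Pro checkpoint needs a downloaded model and a CUDA device:
# tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model("models/Janus/Janus-Pro-7B")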