├── README.md
├── __init__.py
├── assets
│   └── node.png
├── models
│   └── joycaption
│       └── folder_for_joycaption_image_adapter
├── nodes.py
├── pyproject.toml
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# Joycaption for ComfyUI
Integration of the Joycaption model for better tagging of images. Use it in combination with models that support rich text encoders, such as Flux or SDXL.

![node.png](assets/node.png)

All credits for the model go to https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha

## Nodes
This documentation provides an overview of the Joycaption node.

### Joycaption
To get Joycaption working, the node needs the following models:
- an LLM (here a LLaMA 3.1 8B Instruct model quantized to INT4), approx. 6 GB VRAM
- a CLIP model (google/siglip-so400m-patch14-384), which is used to encode the images
- the image adapter model, which can be found at the link above (it is downloaded automatically on first use; a manual alternative is sketched below)
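The node fetches the adapter weights on first use (see `acquire_model` in `nodes.py`). If you would rather download them yourself, for example on a machine where ComfyUI has no outbound network access, a sketch along the following lines should work. It assumes it is run from the root of this custom node; the URL and the target folder `models/joycaption/wpkklhc6` are the same ones used in `nodes.py`.

```python
# Optional manual download of the Joycaption image adapter.
# The node normally does this itself via acquire_model().
import os
import requests

URL = ("https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha"
       "/resolve/main/wpkklhc6/image_adapter.pt")
TARGET = os.path.join("models", "joycaption", "wpkklhc6", "image_adapter.pt")

os.makedirs(os.path.dirname(TARGET), exist_ok=True)
with requests.get(URL, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(TARGET, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)
print(f"Saved image adapter to {TARGET}")
```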
For low-VRAM environments (< 12 GB), it is recommended to shift the CLIP model to the CPU. Note, however, that this slows down captioning significantly.

**Parameters**
- `image`: The input image(s) to be tagged.
- `llm_device`: The device to use for the LLM ('cuda' or 'cpu').
- `clip_device`: The device to use for the CLIP model ('cuda' or 'cpu').
- `instruction`: The instruction (prompt) passed to the LLM.
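These parameters map directly onto the node's Python entry point, so for a quick test outside of ComfyUI the node class can also be driven directly. Below is a minimal sketch, assuming it is run inside this custom node's directory, that a CUDA GPU is available, and that `example.jpg` is a placeholder name for a local RGB test image. ComfyUI hands images to nodes as float tensors of shape `(batch, height, width, channels)` with values in `[0, 1]`, which the snippet reproduces.

```python
# Minimal standalone sketch: caption one image with the JoyCaptioning node class.
import numpy as np
import torch
from PIL import Image

from nodes import JoyCaptioning  # nodes.py from this repository

# Build a ComfyUI-style IMAGE tensor: (1, H, W, 3), float32, values in [0, 1].
img = Image.open("example.jpg").convert("RGB")
image_tensor = torch.from_numpy(np.asarray(img).astype(np.float32) / 255.0).unsqueeze(0)

node = JoyCaptioning()
(caption,) = node.generate_joycaption(
    image_tensor,
    llm_device="cuda",   # LLM on the GPU (approx. 6 GB VRAM)
    clip_device="cpu",   # CLIP on the CPU, as recommended for low-VRAM setups
    instruction="A descriptive caption for this image",
)
print(caption)
```

Expect the first run to download the LLM, the CLIP model, and the image adapter, which takes a while and several gigabytes of disk space.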
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .nodes import JoyCaptioning

version_code = [0, 1]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI Joycaption ({version_str})")

NODE_CLASS_MAPPINGS = {
    "Joycaption": JoyCaptioning,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Joycaption": "Joycaption",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
--------------------------------------------------------------------------------
/assets/node.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/comfyui-joycaption/dbda9cb28c935599d2cbadd3a4f6525e57f3fe48/assets/node.png
--------------------------------------------------------------------------------
/models/joycaption/folder_for_joycaption_image_adapter:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/comfyui-joycaption/dbda9cb28c935599d2cbadd3a4f6525e57f3fe48/models/joycaption/folder_for_joycaption_image_adapter
--------------------------------------------------------------------------------
/nodes.py:
--------------------------------------------------------------------------------
import gc
import os
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from transformers import (AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                          PreTrainedTokenizer, PreTrainedTokenizerFast)


class JoyCaptioning:
    def __init__(self):
        self.CLIP_PATH = "google/siglip-so400m-patch14-384"
        self.LLM_ID = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
        self.CHECKPOINT_PATH = os.path.join(Path(__file__).parent, "models", "joycaption", "wpkklhc6")

        self.image_adapter = None
        self.clip_model = None
        self.clip_processor = None
        self.tokenizer = None
        self.text_model = None

        self.acquire_model()

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "llm_device": (["cuda", "cpu"], {"default": "cuda"}),
                "clip_device": (["cuda", "cpu"], {"default": "cpu"}),
                "instruction": ("STRING", {"default": "A descriptive caption for this image", "multiline": True}),
            },
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "generate_joycaption"
    CATEGORY = "Tagging"

    def generate_joycaption(self, image: torch.Tensor, llm_device='cuda', clip_device='cpu',
                            instruction="A descriptive caption for this image"):
        torch.cuda.empty_cache()

        self.load_model(llm_device, clip_device)

        # Convert the ComfyUI tensor (1, H, W, C), float in [0, 1], to a PIL image
        image_np = image.cpu().numpy().squeeze()  # remove the batch dimension
        image_np = (image_np * 255).astype(np.uint8)
        image = Image.fromarray(image_np, mode="RGB")

        # Resize so the longest side does not exceed the default of 768 pixels
        image = resize_image(image)

        # Tokenize the prompt
        prompt = self.tokenizer.encode(instruction + ":\n", return_tensors='pt',
                                       padding=False, truncation=False, add_special_tokens=False)

        # Preprocess the image
        pixel_values = self.clip_processor(images=image, return_tensors='pt').pixel_values

        # Embed the image with CLIP and project it into the LLM embedding space
        vision_outputs = self.clip_model(pixel_values=pixel_values.to(clip_device), output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2].to(llm_device)
        embedded_images = self.image_adapter(image_features)

        # Embed the prompt
        prompt_embeds = self.text_model.model.embed_tokens(prompt.to(llm_device))
        assert prompt_embeds.shape == (1, prompt.shape[1], self.text_model.config.hidden_size), \
            f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], self.text_model.config.hidden_size)}"
        embedded_bos = self.text_model.model.embed_tokens(
            torch.tensor([[self.tokenizer.bos_token_id]], device=self.text_model.device, dtype=torch.int64))

        # Construct the input embeddings: <BOS> <image tokens> <prompt tokens>
        inputs_embeds = torch.cat([
            embedded_bos.expand(embedded_images.shape[0], -1, -1),
            embedded_images.to(dtype=embedded_bos.dtype),
            prompt_embeds.expand(embedded_images.shape[0], -1, -1),
        ], dim=1)

        # Matching input ids (the image positions are placeholders; only the embeddings matter)
        input_ids = torch.cat([
            torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
            prompt,
        ], dim=1).to(llm_device)
        attention_mask = torch.ones_like(input_ids)

        generate_ids = self.text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                                max_new_tokens=255, do_sample=True, top_k=10, temperature=0.5,
                                                suppress_tokens=None)

        # Trim off the prompt
        generate_ids = generate_ids[:, input_ids.shape[1]:]
        if generate_ids[0][-1] == self.tokenizer.eos_token_id:
            generate_ids = generate_ids[:, :-1]

        caption = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=False,
                                              clean_up_tokenization_spaces=False)[0]

        self.cleanup()

        return (caption.strip(),)

    def load_model(self, llm_device, clip_device):
        # LLM
        print("Loading LLM")
        self.text_model = AutoModelForCausalLM.from_pretrained(self.LLM_ID, device_map=llm_device)
        self.text_model.eval()

        # LLM tokenizer
        print("Loading tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(self.LLM_ID, use_fast=False)
        assert isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), \
            f"Tokenizer is of type {type(self.tokenizer)}"

        # CLIP
        print("Loading CLIP")
        self.clip_processor = AutoProcessor.from_pretrained(self.CLIP_PATH)
        self.clip_model = AutoModel.from_pretrained(self.CLIP_PATH, device_map=clip_device)
        self.clip_model = self.clip_model.vision_model
        self.clip_model.eval()
        self.clip_model.requires_grad_(False)

        # Image adapter
        print("Loading image adapter")
        self.acquire_model()
        self.image_adapter = ImageAdapter(self.clip_model.config.hidden_size, self.text_model.config.hidden_size)
        self.image_adapter.load_state_dict(
            torch.load(os.path.join(self.CHECKPOINT_PATH, "image_adapter.pt"), map_location="cpu"))
        self.image_adapter.eval()
        self.image_adapter.to(llm_device)

    def acquire_model(self):
        # Download the image adapter weights if they are not present yet
        if not os.path.exists(os.path.join(self.CHECKPOINT_PATH, "image_adapter.pt")):
            os.makedirs(self.CHECKPOINT_PATH, exist_ok=True)
            url = "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
            download(url, os.path.join(self.CHECKPOINT_PATH, "image_adapter.pt"))

    def cleanup(self):
        self.text_model.cpu()
        self.clip_model.cpu()
        self.image_adapter.cpu()
        del self.text_model
        del self.tokenizer
        del self.clip_model
        del self.clip_processor
        del self.image_adapter
        gc.collect()
        torch.cuda.empty_cache()


class ImageAdapter(torch.nn.Module):
    """Two-layer MLP that projects CLIP vision features to the LLM hidden size."""

    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_features, output_features)
        self.activation = torch.nn.GELU()
        self.linear2 = torch.nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x


def resize_image(img, max_size=768):
    # Get the current dimensions of the image
    width, height = img.size

    # Determine the scaling factor so that the longest side equals max_size pixels
    if width > height:
        new_width = max_size
        new_height = int((max_size / width) * height)
    else:
        new_height = max_size
        new_width = int((max_size / height) * width)

    # Resize the image
    return img.resize((new_width, new_height))


def download(url, file_path):
    r = requests.get(url, timeout=60)
    r.raise_for_status()
    # Write the file to disk
    with open(file_path, "wb") as f:
        f.write(r.content)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "comfyui-joycaption"
description = "An extension for using Joycaption (pre-alpha) in ComfyUI"
version = "1.0.0"
license = "LICENSE"

[project.urls]
Repository = "https://github.com/larsupb/comfyui-joycaption.git"
# Used by the Comfy Registry, https://comfyregistry.org

[tool.comfy]
PublisherId = "larsupb"
DisplayName = "ComfyUI-Joycaption"
Icon = ""
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
huggingface-hub
tokenizers==0.19.1
transformers==4.43.4
optimum==1.21.4
auto-gptq==0.7.1
--------------------------------------------------------------------------------