├── LoRAcaption.py ├── README.md └── __init__.py /LoRAcaption.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from PIL import Image 4 | from PIL import ImageOps 5 | import numpy as np 6 | import torch 7 | import comfy 8 | 9 | class LoRACaptionSave: 10 | def __init__(self): 11 | pass 12 | 13 | @classmethod 14 | def INPUT_TYPES(cls): 15 | return { 16 | "required": { 17 | "namelist": ("STRING", {"forceInput": True}), 18 | "path": ("STRING", {"forceInput": True}), 19 | "text": ("STRING", {"forceInput": True}), 20 | }, 21 | "optional": { 22 | "prefix": ("STRING", {"default": " "}), 23 | } 24 | } 25 | 26 | OUTPUT_NODE = True 27 | RETURN_TYPES = () 28 | FUNCTION = "save_text_file" 29 | CATEGORY = "LJRE/LORA" 30 | 31 | def save_text_file(self, text, path, namelist, prefix): 32 | 33 | if not os.path.exists(path): 34 | cstr(f"The path `{path}` doesn't exist! Creating it...").warning.print() 35 | try: 36 | os.makedirs(path, exist_ok=True) 37 | except OSError as e: 38 | cstr(f"The path `{path}` could not be created! Is there write access?\n{e}").error.print() 39 | 40 | if text.strip() == '': 41 | cstr(f"There is no text specified to save! Text is empty.").error.print() 42 | 43 | namelistsplit = namelist.splitlines() 44 | namelistsplit = [i[:-4] for i in namelistsplit] 45 | 46 | 47 | if prefix.endswith(","): 48 | prefix += " " 49 | elif not prefix.endswith(", "): 50 | prefix+= ", " 51 | 52 | file_extension = '.txt' 53 | filename = self.generate_filename(path, namelistsplit, file_extension) 54 | 55 | file_path = os.path.join(path, filename) 56 | self.writeTextFile(file_path, text, prefix) 57 | 58 | return (text, { "ui": { "string": text } } ) 59 | 60 | def generate_filename(self, path, namelistsplit, extension): 61 | counter = 1 62 | filename = f"{namelistsplit[counter-1]}{extension}" 63 | while os.path.exists(os.path.join(path, filename)): 64 | counter += 1 65 | filename = f"{namelistsplit[counter-1]}{extension}" 66 | 67 | return filename 68 | 69 | def writeTextFile(self, file, content, prefix): 70 | try: 71 | with open(file, 'w', encoding='utf-8', newline='\n') as f: 72 | content= prefix + content 73 | f.write(content) 74 | except OSError: 75 | cstr(f"Unable to save file `{file}`").error.print() 76 | 77 | def io_file_list(dir='',pattern='*.txt'): 78 | res=[] 79 | for filename in glob.glob(os.path.join(dir,pattern)): 80 | res.append(filename) 81 | return res 82 | 83 | 84 | class LoRACaptionLoad: 85 | def __init__(self): 86 | pass 87 | 88 | @classmethod 89 | def INPUT_TYPES(s): 90 | return { 91 | "required": { 92 | "path": ("STRING", {"default":""}), 93 | }, 94 | } 95 | 96 | RETURN_TYPES = ("STRING", "STRING", "IMAGE",) 97 | RETURN_NAMES = ("Name list", "path", "Image list",) 98 | 99 | FUNCTION = "captionload" 100 | 101 | #OUTPUT_NODE = False 102 | 103 | CATEGORY = "LJRE/LORA" 104 | 105 | def captionload(self, path, pattern='*.png'): 106 | text=io_file_list(path,pattern) 107 | text=list(map(os.path.basename,text)) 108 | text='\n'.join(text) 109 | 110 | #image loader 111 | if not os.path.isdir(path): 112 | raise FileNotFoundError(f"path '{path} cannot be found.'") 113 | dir_files = os.listdir(path) 114 | if len(dir_files) == 0: 115 | raise FileNotFoundError(f"No files in path '{path}'.") 116 | 117 | # Filter files by extension 118 | valid_extensions = ['.png'] 119 | dir_files = [f for f in dir_files if any(f.lower().endswith(ext) for ext in valid_extensions)] 120 | 121 | dir_files = [os.path.join(path, x) for x in dir_files] 122 | 123 | images = [] 124 | image_count = 0 125 | 126 | for image_path in dir_files: 127 | if os.path.isdir(image_path) and os.path.ex: 128 | continue 129 | i = Image.open(image_path) 130 | i = ImageOps.exif_transpose(i) 131 | image = i.convert("RGB") 132 | image = np.array(image).astype(np.float32) / 255.0 133 | image = torch.from_numpy(image)[None,] 134 | images.append(image) 135 | image_count += 1 136 | 137 | if len(images) == 1: 138 | return (images[0], 1) 139 | elif len(images) > 1: 140 | image1 = images[0] 141 | for image2 in images[1:]: 142 | if image1.shape[1:] != image2.shape[1:]: 143 | image2 = comfy.utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1) 144 | image1 = torch.cat((image1, image2), dim=0) 145 | 146 | 147 | 148 | return text, path, image1, len(images) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a custom node pack for ComfyUI. 2 | The LoRA Caption custom nodes, just like their name suggests, allow you to caption images so they are ready for LoRA training. You can find them by right-clicking and looking for the LJRE category, or you can double-click on an empty space and search for "caption". 3 | They were made to work with WD14 Tagger. 4 | 5 | 6 | 7 | Here is the workflow: 8 | ![Capture](https://github.com/LarryJane491/Image-Captioning-in-ComfyUI/assets/156431112/89a33c73-32c2-470d-9494-89349ccd76ee) 9 | 10 | 11 | 12 | Simple but elegant x) 13 | This workflow shows both nodes in this pack: LoRA Caption Load and LoRA Caption Save. 14 | 15 | The other custom nodes used here are: 16 | 17 | WD 1.4 Tagger (mandatory) 18 | 19 | Jjk custom nodes (optional) 20 | 21 | 22 | The Tagger is mandatory as this is the one that actually does the captioning. You also have to download a model, check out the github of that node for more information. My custom nodes are built as a complement for this one. 23 | 24 | Jjk is optional, it just lets you see that the software does extract the names of the files. 25 | 26 | 27 | 28 | Here is how it works: 29 | 30 | Gather the images for your LoRA database, in a single folder. Make sure the images are all in png (this requirement will be changed in a new version). 31 | 32 | Copy that folder’s path and write it down in the widget of the Load node. 33 | 34 | Plug the image output of the Load node into the Tagger, and the other two outputs in the inputs of the Save node. Plug the Tagger output into the Save node too. 35 | 36 | And that’s it! Just launch the workflow now. 37 | 38 | 39 | 40 | The Load node has two jobs: feed the images to the tagger and get the names of every image file in that folder. The name list and the captions are then fed to the Save node, which creates text files with the image name as its own name and the description of the image as its content (in other words: it creates the caption files). 41 | 42 | Once the files are done, your database is ready for LoRA training! The next big step is LoRa Training, which is possible from withing ComfyUI with another custom node of my own creation. 43 | 44 | 45 | 46 | 47 | 48 | Notes: 49 | 50 | The WD 1.4 Tagger is for anime images, so I don’t know how good it is for realistic images. I don’t see why it wouldn’t work though! At least for anime it is extremely impressive imo. 51 | 52 | If the text files already exist, Comfy will throw the Out of Range error. I could easily fix that, but I don’t see the point: just make sure the text files don’t exist already. If you want to change them, just delete them and relaunch the workflow. 53 | 54 | The widget lets you write a common prefix. It’s useful for creating trigger words for your LoRA. If you use the widget, make sure it ends with a comma. Again, it’s something I could easily fix, but I'm a little lazy x). 55 | 56 | 57 | 58 | I would like to thank the creators of Inspire Pack and YMC Suite Node, as my functions are heavily inspired by theirs. In fact, I had a workflow working with them, without my custom nodes at all. My project is just a rewrite of some of their functions, as a way to train myself for making my own nodes. 59 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .LoRAcaption import LoRACaptionSave, LoRACaptionLoad 2 | 3 | NODE_CLASS_MAPPINGS = { 4 | "LoRA Caption Save": LoRACaptionSave, 5 | "LoRA Caption Load": LoRACaptionLoad, 6 | } 7 | 8 | NODE_DISPLAY_NAME_MAPPINGS = { 9 | 10 | } 11 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] --------------------------------------------------------------------------------