├── .dockerignore ├── .gitignore ├── README.md ├── cog.yaml ├── custom_workflows └── sdxl_txt2img.json ├── output.png ├── predict.py ├── script └── download-weights └── test.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .cog 3 | 4 | ComfyUI 5 | models 6 | 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI SDXL txt2img Cog model 2 | 3 | This is an implementation of the ComfyUI text2img workflow as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog) 4 | 5 | First, download the pre-trained weights: 6 | 7 | cog run script/download-weights 8 | 9 | Then, you can run predictions: 10 | 11 | cog predict -i input_prompt="beautiful scenery nature glass bottle landscape, pink galaxy bottle" 12 | 13 | The workflow used for this repo is found under: 14 | 15 | custom_workflows/sdxl_txt2img.json 16 | 17 | ## Example: 18 | 19 | "beautiful scenery nature glass bottle landscape, pink galaxy bottle" 20 | 21 | ![alt text](output.png) 22 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | # set to true if your model requires a GPU 3 | gpu: true 4 | 5 | system_packages: 6 | - ffmpeg 7 | 8 | python_version: "3.10.6" 9 | 10 | python_packages: 11 | - torch 12 | - torchvision 13 | - torchaudio 14 | - torchsde 15 | - einops 16 | - transformers>=4.25.1 17 | - safetensors>=0.3.0 18 | - aiohttp 19 | - accelerate 20 | - pyyaml 21 | - Pillow 22 | - scipy 23 | - tqdm 24 | - psutil 25 | - websocket-client==1.6.3 26 | 27 | 28 | # predict.py defines how predictions are run on your model 29 | predict: "predict.py:Predictor" -------------------------------------------------------------------------------- /custom_workflows/sdxl_txt2img.json: -------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": 196429611935343, 5 | "steps": 20, 6 | "cfg": 8, 7 | "sampler_name": "euler", 8 | "scheduler": "normal", 9 | "denoise": 1, 10 | "model": [ 11 | "4", 12 | 0 13 | ], 14 | "positive": [ 15 | "6", 16 | 0 17 | ], 18 | "negative": [ 19 | "7", 20 | 0 21 | ], 22 | "latent_image": [ 23 | "5", 24 | 0 25 | ] 26 | }, 27 | "class_type": "KSampler" 28 | }, 29 | "4": { 30 | "inputs": { 31 | "ckpt_name": "sd_xl_base_1.0.safetensors" 32 | }, 33 | "class_type": "CheckpointLoaderSimple" 34 | }, 35 | "5": { 36 | "inputs": { 37 | "width": 512, 38 | "height": 512, 39 | "batch_size": 1 40 | }, 41 | "class_type": "EmptyLatentImage" 42 | }, 43 | "6": { 44 | "inputs": { 45 | "text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle", 46 | "clip": [ 47 | "4", 48 | 1 49 | ] 50 | }, 51 | "class_type": "CLIPTextEncode" 52 | }, 53 | "7": { 54 | "inputs": { 55 | "text": "text, watermark", 56 | "clip": [ 57 | "4", 58 | 1 59 | ] 60 | }, 61 | "class_type": "CLIPTextEncode" 62 | }, 63 | "8": { 64 | "inputs": { 65 | "samples": [ 66 | "3", 67 | 0 68 | ], 69 | "vae": [ 70 | "4", 71 | 2 72 | ] 73 | }, 74 | "class_type": "VAEDecode" 75 | }, 76 | "9": { 77 | "inputs": { 78 | "filename_prefix": "ComfyUI", 79 | "images": [ 80 | "8", 81 | 0 82 | ] 83 | }, 84 | "class_type": "SaveImage" 85 | } 86 | } -------------------------------------------------------------------------------- /output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-comfyui-sdxl-txt2img/2043e4facca154b261243de60212180135726ca4/output.png -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import threading 3 | import time 4 | from cog import BasePredictor, Input, Path 5 | # from typing import List 6 | import os 7 | import torch 8 | import shutil 9 | import uuid 10 | import json 11 | import urllib 12 | import websocket 13 | from PIL import Image 14 | from urllib.error import URLError 15 | import random 16 | 17 | 18 | class Predictor(BasePredictor): 19 | def setup(self): 20 | # start server 21 | self.server_address = "127.0.0.1:8188" 22 | self.start_server() 23 | 24 | def start_server(self): 25 | server_thread = threading.Thread(target=self.run_server) 26 | server_thread.start() 27 | 28 | while not self.is_server_running(): 29 | time.sleep(1) # Wait for 1 second before checking again 30 | 31 | print("Server is up and running!") 32 | 33 | def run_server(self): 34 | command = "python ./ComfyUI/main.py" 35 | server_process = subprocess.Popen(command, shell=True) 36 | server_process.wait() 37 | 38 | # hacky solution, will fix later 39 | def is_server_running(self): 40 | try: 41 | with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, "123")) as response: 42 | return response.status == 200 43 | except URLError: 44 | return False 45 | 46 | def queue_prompt(self, prompt, client_id): 47 | p = {"prompt": prompt, "client_id": client_id} 48 | data = json.dumps(p).encode('utf-8') 49 | req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) 50 | return json.loads(urllib.request.urlopen(req).read()) 51 | 52 | def get_image(self, filename, subfolder, folder_type): 53 | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} 54 | print(folder_type) 55 | url_values = urllib.parse.urlencode(data) 56 | with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: 57 | return response.read() 58 | 59 | def get_images(self, ws, prompt, client_id): 60 | prompt_id = self.queue_prompt(prompt, client_id)['prompt_id'] 61 | output_images = {} 62 | while True: 63 | out = ws.recv() 64 | if isinstance(out, str): 65 | message = json.loads(out) 66 | if message['type'] == 'executing': 67 | data = message['data'] 68 | if data['node'] is None and data['prompt_id'] == prompt_id: 69 | break #Execution is done 70 | else: 71 | continue #previews are binary data 72 | 73 | history = self.get_history(prompt_id)[prompt_id] 74 | for o in history['outputs']: 75 | for node_id in history['outputs']: 76 | node_output = history['outputs'][node_id] 77 | print("node output: ", node_output) 78 | 79 | if 'images' in node_output: 80 | images_output = [] 81 | for image in node_output['images']: 82 | image_data = self.get_image(image['filename'], image['subfolder'], image['type']) 83 | images_output.append(image_data) 84 | output_images[node_id] = images_output 85 | 86 | return output_images 87 | 88 | def get_history(self, prompt_id): 89 | with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: 90 | return json.loads(response.read()) 91 | 92 | # TODO: add dynamic fields based on the workflow selected 93 | def predict( 94 | self, 95 | input_prompt: str = Input(description="Prompt", default="beautiful scenery nature glass bottle landscape, purple galaxy bottle"), 96 | negative_prompt: str = Input(description="Negative Prompt", default="text, watermark, ugly, blurry"), 97 | steps: int = Input( 98 | description="Steps", 99 | default=30 100 | ), 101 | seed: int = Input(description="Sampling seed, leave Empty for Random", default=None), 102 | ) -> Path: 103 | """Run a single prediction on the model""" 104 | if seed is None: 105 | seed = int.from_bytes(os.urandom(3), "big") 106 | print(f"Using seed: {seed}") 107 | generator = torch.Generator("cuda").manual_seed(seed) 108 | 109 | # queue prompt 110 | img_output_path = self.get_workflow_output( 111 | input_prompt = input_prompt, 112 | negative_prompt = negative_prompt, 113 | steps = steps, 114 | seed = seed 115 | ) 116 | return Path(img_output_path) 117 | 118 | 119 | def get_workflow_output(self, input_prompt, negative_prompt, steps, seed): 120 | # load config 121 | prompt = None 122 | workflow_config = "./custom_workflows/sdxl_txt2img.json" 123 | with open(workflow_config, 'r') as file: 124 | prompt = json.load(file) 125 | 126 | if not prompt: 127 | raise Exception('no workflow config found') 128 | 129 | # set input variables 130 | prompt["6"]["inputs"]["text"] = input_prompt 131 | prompt["7"]["inputs"]["text"] = negative_prompt 132 | 133 | prompt["3"]["inputs"]["seed"] = seed 134 | prompt["3"]["inputs"]["steps"] = steps 135 | 136 | # start the process 137 | client_id = str(uuid.uuid4()) 138 | ws = websocket.WebSocket() 139 | ws.connect("ws://{}/ws?clientId={}".format(self.server_address, client_id)) 140 | images = self.get_images(ws, prompt, client_id) 141 | 142 | for node_id in images: 143 | for image_data in images[node_id]: 144 | from PIL import Image 145 | import io 146 | image = Image.open(io.BytesIO(image_data)) 147 | image.save("out-"+node_id+".png") 148 | return Path("out-"+node_id+".png") 149 | -------------------------------------------------------------------------------- /script/download-weights: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import shutil 6 | 7 | # append project directory to path so predict.py can be imported 8 | # sys.path.append('.') 9 | # from predict import MODEL_NAME, MODEL_CACHE, CONTROL_CACHE 10 | 11 | # Clone ComfyUI repo 12 | os.system("git clone https://github.com/comfyanonymous/ComfyUI.git") 13 | os.system("mkdir -p models/checkpoints") 14 | 15 | # TODO: See if linking from ComfyUI/models/checkpoints to models/checkpoints works 16 | os.system("rm -rf ComfyUI/models/checkpoints") 17 | os.system("ln -s /src/models/checkpoints ComfyUI/models/checkpoints") 18 | 19 | # Download model weights 20 | os.system("wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P models/checkpoints/") 21 | os.system("wget https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P models/checkpoints/") 22 | 23 | # Move all models to ComfyUI/models/checkpoints 24 | # os.system("mv ./models/checkpoints/* ComfyUI/models/checkpoints/") 25 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) 2 | import uuid 3 | import json 4 | import urllib.request 5 | import urllib.parse 6 | 7 | server_address = "192.168.1.241:8188" 8 | 9 | def queue_prompt(prompt, client_id): 10 | p = {"prompt": prompt, "client_id": client_id} 11 | data = json.dumps(p).encode('utf-8') 12 | req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) 13 | return json.loads(urllib.request.urlopen(req).read()) 14 | 15 | def get_image(filename, subfolder, folder_type): 16 | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} 17 | url_values = urllib.parse.urlencode(data) 18 | with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: 19 | return response.read() 20 | 21 | def get_history(prompt_id): 22 | with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: 23 | return json.loads(response.read()) 24 | 25 | def get_images(ws, prompt, client_id): 26 | prompt_id = queue_prompt(prompt, client_id)['prompt_id'] 27 | output_images = {} 28 | while True: 29 | out = ws.recv() 30 | if isinstance(out, str): 31 | message = json.loads(out) 32 | if message['type'] == 'executing': 33 | data = message['data'] 34 | if data['node'] is None and data['prompt_id'] == prompt_id: 35 | break #Execution is done 36 | else: 37 | continue #previews are binary data 38 | 39 | history = get_history(prompt_id)[prompt_id] 40 | for o in history['outputs']: 41 | for node_id in history['outputs']: 42 | node_output = history['outputs'][node_id] 43 | if 'images' in node_output: 44 | images_output = [] 45 | for image in node_output['images']: 46 | image_data = get_image(image['filename'], image['subfolder'], image['type']) 47 | images_output.append(image_data) 48 | output_images[node_id] = images_output 49 | 50 | return output_images 51 | 52 | 53 | workflow_config = "./custom_workflows/sdxl_txt2img.json" 54 | with open(workflow_config, 'r') as file: 55 | prompt = json.load(file) 56 | 57 | 58 | #set the text prompt for our positive CLIPTextEncode 59 | prompt["6"]["inputs"]["text"] = "beautiful scenery nature glass bottle landscape, orange galaxy bottle" 60 | 61 | #set the seed for our KSampler node 62 | prompt["3"]["inputs"]["seed"] = 196429611935343 63 | 64 | client_id = str(uuid.uuid4()) 65 | ws = websocket.WebSocket() 66 | ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) 67 | images = get_images(ws, prompt, client_id) 68 | 69 | #Commented out code to display the output images: 70 | for node_id in images: 71 | for image_data in images[node_id]: 72 | from PIL import Image 73 | import io 74 | image = Image.open(io.BytesIO(image_data)) 75 | image.save("out-"+node_id+".png") --------------------------------------------------------------------------------