├── .gitignore
├── classifiers
│   ├── __init__.py
│   ├── aes-B32-v0.safetensors
│   ├── laion-sac-logos-ava-v2.safetensors
│   ├── cafe_waifu.py
│   ├── cafe_aesthetic.py
│   ├── aesthetic.py
│   └── laion.py
├── README.md
├── __init__.py
└── nodes_model_merging.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__

--------------------------------------------------------------------------------
/classifiers/__init__.py:
--------------------------------------------------------------------------------
__all__ = ["laion", "aesthetic", "cafe_waifu", "cafe_aesthetic"]

--------------------------------------------------------------------------------
/classifiers/aes-B32-v0.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szhublox/ambw_comfyui/HEAD/classifiers/aes-B32-v0.safetensors

--------------------------------------------------------------------------------
/classifiers/laion-sac-logos-ava-v2.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szhublox/ambw_comfyui/HEAD/classifiers/laion-sac-logos-ava-v2.safetensors

--------------------------------------------------------------------------------
/classifiers/cafe_waifu.py:
--------------------------------------------------------------------------------
import transformers

# the pipeline is created lazily and reused; score() is called once per generated sample
pipe = None

def score(image):
    global pipe
    if pipe is None:
        pipe = transformers.pipeline("image-classification",
                                     model="cafeai/cafe_waifu")
    result = pipe(image, top_k=5)
    for data in result:
        if data['label'] == "waifu":
            return data['score']

--------------------------------------------------------------------------------
/classifiers/cafe_aesthetic.py:
--------------------------------------------------------------------------------
import transformers

# the pipeline is created lazily and reused; score() is called once per generated sample
pipe = None

def score(image):
    global pipe
    if pipe is None:
        pipe = transformers.pipeline("image-classification",
                                     model="cafeai/cafe_aesthetic")
    result = pipe(image, top_k=2)
    for data in result:
        if data['label'] == "aesthetic":
            return data['score']

--------------------------------------------------------------------------------
/classifiers/aesthetic.py:
--------------------------------------------------------------------------------
import pathlib

import numpy as np
import torch
import safetensors.torch
from transformers import CLIPModel, CLIPProcessor

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

def image_embeddings_direct(image, model, processor):
    inputs = processor(images=image, return_tensors='pt')['pixel_values']
    inputs = inputs.to(device)
    result = model.get_image_features(pixel_values=inputs).cpu().detach().numpy()
    return (result / np.linalg.norm(result)).squeeze(axis=0)

# binary classifier that consumes CLIP embeddings
class Classifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size//2)
        self.fc3 = torch.nn.Linear(hidden_size//2, output_size)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x

dirname = pathlib.Path(__file__).parent
aesthetic_path = dirname.joinpath("aes-B32-v0.safetensors")
clip_name = 'openai/clip-vit-base-patch32'
clipprocessor = CLIPProcessor.from_pretrained(clip_name)
clipmodel = CLIPModel.from_pretrained(clip_name).to(device).eval()
aes_model = Classifier(512, 256, 1).to(device)
aes_model.load_state_dict(safetensors.torch.load_file(aesthetic_path))

def score(image):
    image_embeds = image_embeddings_direct(image, clipmodel, clipprocessor)
    prediction = aes_model(torch.from_numpy(image_embeds).float().to(device))
    return prediction.item()
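Each module in `classifiers/` exposes the same minimal interface: a `score(image)` function that takes a PIL image and returns a float, which is how the node looks the chosen classifier up by name. A standalone usage sketch (not part of the repo; it assumes the dependencies and bundled weights are available, that it runs from the repo root, and "sample.png" is a placeholder path):

```python
# Rate one image with a bundled classifier module.
from PIL import Image

from classifiers import aesthetic  # or laion, cafe_waifu, cafe_aesthetic

image = Image.open("sample.png").convert("RGB")
print(aesthetic.score(image))  # higher means the classifier likes it more
```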
--------------------------------------------------------------------------------
/classifiers/laion.py:
--------------------------------------------------------------------------------
import pathlib

import numpy as np
import clip
import torch
import safetensors.torch

use_cuda = torch.cuda.is_available()

# (unused in this module; kept for parity with aesthetic.py)
def image_embeddings_direct(image, model, processor):
    inputs = processor(images=image, return_tensors='pt')['pixel_values']
    if use_cuda:
        inputs = inputs.to('cuda')
    result = model.get_image_features(pixel_values=inputs).cpu().detach().numpy()
    return (result / np.linalg.norm(result)).squeeze(axis=0)

def normalized(a, axis=-1, order=2):
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

def image_embeddings_direct_laion(pil_image):
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    im_emb_arr = normalized(image_features.cpu().detach().numpy())
    return im_emb_arr

class MLP(torch.nn.Module):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 1024),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, 16),
            torch.nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)

dirname = pathlib.Path(__file__).parent
aesthetic_path = dirname.joinpath("laion-sac-logos-ava-v2.safetensors")
aes_model = MLP(768).to(device).eval()
aes_model.load_state_dict(safetensors.torch.load_file(aesthetic_path))

def score(image):
    image_embeds = image_embeddings_direct_laion(image)
    prediction = aes_model(torch.from_numpy(image_embeds).float().to(device))
    return prediction.item()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Auto-MBW for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), loosely based on [sdweb-auto-MBW](https://github.com/Xerxemi/sdweb-auto-MBW)

### Purpose
This node ("advanced > auto merge block weighted") takes two models, merges individual blocks together at various ratios, automatically rates each merge, and keeps the ratio with the highest score. Whether this is a good idea or not is anyone's guess. In practice it produces models that make images the classifier says are good; you would probably disagree with the classifier's decisions fairly often.

### Settings
- Prompt: used to generate the sample images that get rated
- Sample Count: number of samples to generate per ratio per block
- Search Depth: number of branches to take while choosing ratios to test
- Classifier: model used to rate the images

### Search Depth
To calculate the ratios to test, the node branches out from 0.5 in steps that are powers of 0.5:

- A depth of 2 will examine 0.0, 0.5, 1.0
- A depth of 4 will examine 0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0
- A depth of 6 will examine 33 different ratios

There are 25 blocks to examine. If you use a depth of 4 and create 2 samples each, `25 * 9 * 2 = 450` images will be generated. The sketch below enumerates the candidate ratios for a given depth.
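For reference, this self-contained sketch (not part of the node) enumerates the candidate ratios for a given depth by mirroring the recursion in `__init__.py`; the `candidate_ratios` name exists only for this example:

```python
import math

def candidate_ratios(search_depth, current=0.5, start=0.5, depth=1, found=None):
    # mirror the pruning in AutoMBW.search: stop past the depth limit or outside [0, 1]
    found = set() if found is None else found
    if depth > search_depth or current > 1 or current < 0:
        return found
    found.add(current)
    step = math.pow(start, depth)
    for delta in (-step, step):
        candidate_ratios(search_depth, current + delta, start, depth + 1, found)
    return found

print(sorted(candidate_ratios(2)))   # [0.0, 0.5, 1.0]
print(sorted(candidate_ratios(4)))   # 0.0, 0.125, ..., 1.0 (9 ratios)
print(len(candidate_ratios(6)))      # 33
```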
### Classifier
The classifier models have been taken from the sdweb-auto-MBW repo.

- [Laion Aesthetic Predictor](https://huggingface.co/spaces/Geonmo/laion-aesthetic-predictor)
- [Waifu Diffusion 1.4 aesthetic model](https://huggingface.co/hakurei/waifu-diffusion-v1-4)
- [Cafe Waifu](https://huggingface.co/cafeai/cafe_waifu) and [Cafe Aesthetic](https://huggingface.co/cafeai/cafe_aesthetic)

### Notes
- many hardcoded settings are arbitrary, such as the seed, sampler, and block processing order
- generated images are not saved
- the resulting model will contain the text encoder and VAE sent to the node

### Bugs
- the merging process doesn't use the comfy ModelPatcher method and takes hundreds of milliseconds
  - as a result, the `--highvram` flag is recommended: both models are kept in VRAM and the process is much faster
- the unet will (probably) be fp16 and the rest fp32; that's how they're sent to the node
  - see: `model_management.should_use_fp16()`
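The per-block merge in `__init__.py` below is a plain linear interpolation of every tensor in the chosen block, with a backup kept so a trial merge can be undone. A minimal, self-contained illustration of that formula (toy tensors, not real UNet weights):

```python
import torch

ratio = 0.25
w1 = torch.tensor([1.0, 2.0, 3.0])    # stand-in for a tensor from model1
w2 = torch.tensor([5.0, 6.0, 7.0])    # the same key in model2
backup = w1.clone()                   # saved so the trial merge can be reverted

w1.copy_(w1 * (1 - ratio) + w2 * ratio)
print(w1)           # tensor([2., 3., 4.])

w1.copy_(backup)    # "unmerge": restore the original weights
print(w1)           # tensor([1., 2., 3.])
```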
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
import importlib
import math
import pathlib
import sys
import warnings

import numpy as np
from PIL import Image
import torch
import tqdm

import folder_paths
import model_management
import nodes

sys.path.append(str(pathlib.Path(__file__).parent))
import classifiers
from nodes_model_merging import CheckpointSave

# process the middle block first, then work outward,
# alternating between input and output blocks
BLOCK_ORDER = [12, 11, 13, 10, 14, 9, 15, 8, 16, 7, 17, 6, 18,
               5, 19, 4, 20, 3, 21, 2, 22, 1, 23, 0, 24]

class AutoMBW(CheckpointSave):
    def __init__(self):
        self.type = "output"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
                "prompt": ("STRING", {
                    "multiline": True,
                    "default": "masterpiece girl"
                }),
                "negative": ("STRING", {
                    "multiline": True,
                    "default": "worst quality"
                }),
                "search_depth": ("INT", {"default": 4, "min": 2}),
                "sample_count": ("INT", {"default": 1, "min": 1}),
                "classifier": (classifiers.__all__,),
                "filename_prefix": ("STRING", {"multiline": False, "default": "ambw"}),
            }}

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    FUNCTION = "ambw"
    CATEGORY = "advanced"

    @torch.no_grad()
    def merge(self, block, ratio):
        sd1 = self.model1.model.state_dict()
        sd2 = self.model2.model.state_dict()

        self.blocks_backup = {}
        for key in self.blocks[block]:
            self.blocks_backup[key] = sd1[key].clone()
            sd1[key].copy_(sd1[key] * (1 - ratio) + sd2[key] * ratio)

    @torch.no_grad()
    def unmerge(self):
        sd1 = self.model1.model.state_dict()

        for key in self.blocks_backup:
            sd1[key].copy_(self.blocks_backup[key])

    def rate_model(self):
        rating = 0
        for i in range(self.sample_count):
            latent = nodes.common_ksampler(
                self.model1, i, 20, 7.0, "ddim", "normal", self.prompt,
                self.negative, {"samples": torch.zeros([1, 4, 64, 64])},
                denoise=1.0)
            decoded = self.vae.decode(latent[0]["samples"])
            image = Image.fromarray(
                np.clip(255. * decoded.cpu().numpy().squeeze(),
                        0, 255).astype(np.uint8))
            with warnings.catch_warnings():
                # several possible transformers nags
                warnings.filterwarnings('ignore')
                rating += self.classifier(image)
        return rating

    def search(self, block, current, start, depth, maximum):
        if depth > self.search_depth or current > 1 or current < 0:
            return maximum

        self.merge(block, current)
        score = self.rate_model()
        self.unmerge()
        if score > maximum[1]:
            maximum = (current, score)

        step = math.pow(start, depth)
        for test_step in (-step, step):
            score = self.search(block, current + test_step, start, depth + 1,
                                maximum)
            if score[1] > maximum[1]:
                maximum = score
        return maximum
    def ambw(self, model1, model2, clip, vae, prompt, negative, search_depth,
             sample_count, classifier, filename_prefix):
        # python setup
        self.output_dir = folder_paths.get_output_directory()
        self.model1 = model1
        self.model2 = model2
        self.vae = vae
        self.prompt = [[clip.encode(prompt), {}]]
        self.negative = [[clip.encode(negative), {}]]
        self.search_depth = search_depth
        self.sample_count = sample_count
        self.classifier = importlib.import_module(
            "." + classifier, "classifiers").score

        # model setup
        if model_management.vram_state == model_management.VRAMState.HIGH_VRAM:
            model1.model.to(model_management.get_torch_device())
            model2.model.to(model_management.get_torch_device())

        self.ratios = [None] * 25
        self.blocks = [None] * 25
        sd1 = model1.model.state_dict()
        self.blocks[12] = [key for key in sd1 if "middle_block" in key]
        for index in range(12):
            self.blocks[index] = \
                [key for key in sd1 if f"input_blocks.{index}." in key]
            self.blocks[index + 13] = \
                [key for key in sd1 if f"output_blocks.{index}." in key]

        # number of candidate ratios examined per block for a given search depth
        def tqdm_steps(depth):
            if depth < 3:
                return 3
            return math.pow(2, depth - 2) + tqdm_steps(depth - 1)

        for block in tqdm.tqdm(BLOCK_ORDER, desc='automerge', unit='block',
                               position=int(tqdm_steps(
                                   search_depth) * sample_count)):
            self.ratios[block] = self.search(block, 0.5, 0.5, 1, (0.5, 0))[0]
            self.merge(block, self.ratios[block])
        print(self.ratios)

        self.save(self.model1, clip, vae, filename_prefix)

        return ()


NODE_CLASS_MAPPINGS = {
    "Auto Merge Block Weighted": AutoMBW
}

--------------------------------------------------------------------------------
/nodes_model_merging.py:
--------------------------------------------------------------------------------
import comfy.sd
import comfy.utils
import comfy.model_base

import folder_paths
import json
import os

from comfy.cli_args import args

class ModelMergeSimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )

class CLIPMergeSimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip1": ("CLIP",),
                              "clip2": ("CLIP",),
                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        m = clip1.clone()
        kp = clip2.get_key_patches()
        for k in kp:
            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
                continue
            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )

class ModelMergeBlocks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        default_ratio = next(iter(kwargs.values()))

        for k in kp:
            ratio = default_ratio
            k_unet = k[len("diffusion_model."):]

            last_arg_size = 0
            for arg in kwargs:
                if k_unet.startswith(arg) and last_arg_size < len(arg):
                    ratio = kwargs[arg]
                    last_arg_size = len(arg)

            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )

class CheckpointSave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "clip": ("CLIP",),
                              "vae": ("VAE",),
                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        prompt_info = ""
        if prompt is not None:
            prompt_info = json.dumps(prompt)

        metadata = {}

        enable_modelspec = True
        if isinstance(model.model, comfy.model_base.SDXL):
            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
        elif isinstance(model.model, comfy.model_base.SDXLRefiner):
            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
        else:
            enable_modelspec = False

        if enable_modelspec:
            metadata["modelspec.sai_model_spec"] = "1.0.0"
            metadata["modelspec.implementation"] = "sgm"
            metadata["modelspec.title"] = "{} {}".format(filename, counter)

        #TODO:
        # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
        # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
        # "v2-inpainting"

        if model.model.model_type == comfy.model_base.ModelType.EPS:
            metadata["modelspec.predict_key"] = "epsilon"
        elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
            metadata["modelspec.predict_key"] = "v"

        if not args.disable_metadata:
            metadata["prompt"] = prompt_info
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
        return {}


NODE_CLASS_MAPPINGS = {
    "ModelMergeSimple": ModelMergeSimple,
    "ModelMergeBlocks": ModelMergeBlocks,
    "CheckpointSave": CheckpointSave,
    "CLIPMergeSimple": CLIPMergeSimple,
}

--------------------------------------------------------------------------------
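For reference, a small standalone sketch (not from the repo) of how `ModelMergeBlocks.merge` above resolves the per-key ratio: each keyword argument acts as a UNet key prefix, the longest matching prefix wins, and the first argument's value is the fallback. The `pick_ratio` helper is purely illustrative:

```python
def pick_ratio(key_unet, **kwargs):
    ratio = next(iter(kwargs.values()))   # fallback: the first argument's value
    last_arg_size = 0
    for arg, value in kwargs.items():
        if key_unet.startswith(arg) and last_arg_size < len(arg):
            ratio = value
            last_arg_size = len(arg)
    return ratio

print(pick_ratio("input_blocks.3.1.proj_in.weight", input=1.0, middle=0.5, out=0.0))  # 1.0
print(pick_ratio("middle_block.1.norm.weight", input=1.0, middle=0.5, out=0.0))       # 0.5
print(pick_ratio("time_embed.0.weight", input=1.0, middle=0.5, out=0.0))              # 1.0 (fallback)
```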