├── __init__.py ├── README.md ├── LICENSE └── nodes.py /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gridswapper 2 | Gridswapper takes a batch of latents and spreads them over the necessary number of grids. It then automatically shuffles the images in the grids for each step. 3 | So, a batch of 12 latents for a 2x2 grid will generate 3 grid images in each step. It will then shuffle around the images for the next step. This makes it possible for all images to influence the others during the denoising process. 4 | This approach works well for generating 2-4 grids. 5 | 6 | 7 | To improve convergence, especially when generating many grids, consider: 8 | 9 | * Increasing the number of steps (for example 4*batch size) - this makes it more likely that each image can influence all others during the shuffling process 10 | * Using ancestral samplers (like euler_a) - the added noise at each step seems to help with consistency 11 | * Training a LoRA on 2x2 grids of the type of pictures you want to generate. 
12 | 13 | --- 14 | 15 | Screenshot 2024-10-26 at 20 57 21 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 kinfolk0117 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# -----------------------------------------------------------------------------
# /nodes.py
# -----------------------------------------------------------------------------
import comfy.samplers
from typing import List
import random

import comfy.sample
import latent_preview
import torch
import comfy.utils


class Noise_EmptyNoise:
    """Noise source that produces no noise: an all-zero tensor.

    Used on every step after the first, where the latents already carry
    their noise and the sampler is told to skip adding more.
    """

    def __init__(self):
        self.seed = 0

    def generate_noise(self, input_latent):
        """Return a CPU zero tensor shaped/typed like input_latent["samples"]."""
        latent_image = input_latent["samples"]
        return torch.zeros(
            latent_image.shape,
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            device="cpu",
        )


class Noise_RandomNoise:
    """Seeded random-noise source backed by comfy.sample.prepare_noise."""

    def __init__(self, seed):
        self.seed = seed

    def generate_noise(self, input_latent):
        """Return reproducible noise for input_latent["samples"].

        Honors an optional "batch_index" key so per-image noise matches
        ComfyUI's usual batch semantics.
        """
        latent_image = input_latent["samples"]
        batch_inds = (
            input_latent["batch_index"] if "batch_index" in input_latent else None
        )
        return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)


class GridSwapper:
    """ComfyUI node: denoise a batch of latents as shuffled grid composites.

    Each step, the batch is partitioned into rows*cols cells per grid, the
    cells are tiled into one large latent, a single denoising step is run on
    each grid, and the grids are split back apart. A different (diverse)
    permutation of the batch is used every step so that every image can
    influence every other over the course of sampling.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    "MODEL",
                    {"tooltip": "The model used for denoising the input latent."},
                ),
                "seed": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
                "steps": (
                    "INT",
                    {
                        "default": 20,
                        "min": 1,
                        "max": 10000,
                        "tooltip": "The number of steps used in the denoising process.",
                    },
                ),
                "cfg": (
                    "FLOAT",
                    {
                        "default": 8.0,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                        "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt.",
                    },
                ),
                "sampler_name": (
                    comfy.samplers.KSampler.SAMPLERS,
                    {
                        "tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."
                    },
                ),
                "scheduler": (
                    comfy.samplers.KSampler.SCHEDULERS,
                    {
                        "tooltip": "The scheduler controls how noise is gradually removed to form the image."
                    },
                ),
                "positive": (
                    "CONDITIONING",
                    {
                        "tooltip": "The conditioning describing the attributes you want to include in the image."
                    },
                ),
                "negative": (
                    "CONDITIONING",
                    {
                        "tooltip": "The conditioning describing the attributes you want to exclude from the image."
                    },
                ),
                "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
                "denoise": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling.",
                    },
                ),
                "rows": (
                    "INT",
                    {
                        "default": 2,
                        "min": 1,
                        "max": 64,
                        "tooltip": "Number of grid rows.",
                    },
                ),
                "cols": (
                    "INT",
                    {
                        "default": 2,
                        "min": 1,
                        "max": 64,
                        "tooltip": "Number of grid columns.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    OUTPUT_TOOLTIPS = ("The denoised latent.",)
    FUNCTION = "sample"

    CATEGORY = "sampling/custom_sampling"  # TODO other category

    def kendall_tau_distance(self, perm1: List[int], perm2: List[int]) -> int:
        """Return the Kendall tau distance between two permutations.

        The distance is the number of discordant pairs, i.e. the number of
        inversions of perm1 expressed in perm2's index space. Counted with a
        merge-sort pass in O(n log n) (the previous implementation claimed a
        merge-sort approach but actually used a naive O(n^2) double loop).

        Args:
            perm1: A permutation of some set of values.
            perm2: A permutation of the same values.

        Returns:
            Number of pairwise order disagreements between the permutations.
        """
        # Position lookup for the second permutation, then express perm1
        # relative to it; inversions of this sequence = Kendall tau distance.
        pos2 = {val: idx for idx, val in enumerate(perm2)}
        relative_pos = [pos2[val] for val in perm1]
        return self._count_inversions(relative_pos)

    @staticmethod
    def _count_inversions(seq: List[int]) -> int:
        """Count pairs (i, j) with i < j and seq[i] > seq[j] via merge sort."""

        def merge_count(a: List[int]):
            if len(a) <= 1:
                return a, 0
            mid = len(a) // 2
            left, inv_left = merge_count(a[:mid])
            right, inv_right = merge_count(a[mid:])
            merged = []
            inversions = inv_left + inv_right
            i = j = 0
            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    merged.append(left[i])
                    i += 1
                else:
                    # right[j] jumps ahead of every remaining left element.
                    merged.append(right[j])
                    j += 1
                    inversions += len(left) - i
            merged.extend(left[i:])
            merged.extend(right[j:])
            return merged, inversions

        return merge_count(seq)[1]

    def get_diverse_permutations_fast(
        self, n: int, k: int, num_candidates: int = 100
    ) -> List[List[int]]:
        """Get k diverse permutations using a fast approximate method.

        Greedy max-min construction: start from the identity permutation,
        then repeatedly pick, from a pool of random candidates, the one whose
        minimum Kendall tau distance to the already-chosen permutations is
        largest.

        Args:
            n: Length of array.
            k: Number of permutations to return.
            num_candidates: Number of random candidates considered per pick.

        Returns:
            List of k diverse permutations (the first is always identity).
        """
        if k <= 0:
            return []

        # Start with first permutation (identity).
        result = [list(range(n))]

        def random_perm():
            perm = list(range(n))
            random.shuffle(perm)
            return perm

        # For each additional permutation needed, keep the candidate that
        # maximizes the minimum distance to every permutation chosen so far.
        for _ in range(k - 1):
            candidates = [random_perm() for _ in range(num_candidates)]

            max_min_distance = -1
            best_candidate = None

            for candidate in candidates:
                min_distance = min(
                    self.kendall_tau_distance(candidate, existing)
                    for existing in result
                )

                if min_distance > max_min_distance:
                    max_min_distance = min_distance
                    best_candidate = candidate

            result.append(best_candidate)

        return result

    def combine_latents(self, samples, rows, cols):
        """Tile rows*cols latents (N, C, H, W) into one (1, C, H*rows, W*cols).

        Cell i*cols + j of the batch lands at grid position (row i, col j).

        Raises:
            ValueError: If the batch size is not exactly rows*cols.
        """
        x = samples
        cell_count = rows * cols

        if x.shape[0] != cell_count:
            raise ValueError(f"Expected {cell_count} latent images, got {x.shape[0]}")

        dim = x.shape[1]
        h = x.shape[2]
        w = x.shape[3]
        combined_h = h * rows
        combined_w = w * cols
        combined = torch.zeros(
            (1, dim, combined_h, combined_w), device=x.device, dtype=x.dtype
        )
        for i in range(rows):
            for j in range(cols):
                row_start = i * h
                row_end = row_start + h
                col_start = j * w
                col_end = col_start + w
                index = i * cols + j
                combined[0, :, row_start:row_end, col_start:col_end] = x[index]
        return combined

    def split_latents(self, combined, rows, cols):
        """Inverse of combine_latents: split (1, C, H*rows, W*cols) into cells.

        Returns a (rows*cols, C, H, W) tensor with the same cell ordering
        combine_latents uses (row-major).

        Raises:
            ValueError: If the input batch dimension is not 1.
        """
        x = combined
        cell_count = rows * cols
        dim = x.shape[1]

        if x.shape[0] != 1:
            raise ValueError(f"Expected 1 latent image, got {x.shape[0]}")

        combined_h = x.shape[2]
        combined_w = x.shape[3]
        h = combined_h // rows
        w = combined_w // cols

        split = torch.zeros((cell_count, dim, h, w), device=x.device, dtype=x.dtype)
        for i in range(rows):
            for j in range(cols):
                index = i * cols + j
                row_start = i * h
                row_end = row_start + h
                col_start = j * w
                col_end = col_start + w
                split[index] = x[0, :, row_start:row_end, col_start:col_end]
        return split

    def sample(
        self,
        model,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        latent_image,
        denoise=1.0,
        rows=2,
        cols=2,
    ):
        """Denoise the latent batch one step at a time as shuffled grids.

        Each step i: take permutation perms[i], group the batch into chunks
        of rows*cols, tile each chunk into a grid, run a single sampler step
        on the grid, split it back, and write the cells into their batch
        slots. Real noise is only injected on the first step; afterwards the
        zero-noise source is used with disable_noise=True.

        Returns:
            (LATENT,) tuple whose "samples" is the denoised batch.

        Raises:
            ValueError: If the batch size is not a multiple of rows*cols.
        """
        cells = rows * cols

        latent = latent_image
        latent_image = latent["samples"]

        latent = latent.copy()
        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
        latent["samples"] = latent_image

        latent_image = latent_image.clone()

        no_latents = latent_image.shape[0]

        if no_latents % cells != 0:
            raise ValueError(f"Number of latents ({no_latents}) is not a multiple of cells ({cells}), latents need to be divisible by cells. With {rows} rows x {cols} cols = {cells} cells, this means for example {1*cells}, {2*cells}, {3*cells}, ... latents are supported.")

        # One permutation of the batch per step.
        perms = self.get_diverse_permutations_fast(no_latents, steps)
        no_combined = latent_image.shape[0] // cells

        # Build a template combined latent purely to size the noise tensors.
        selected_latents = latent_image[range(cells)]
        samples_a = self.combine_latents(selected_latents, rows, cols)
        clatent = latent.copy()
        clatent["samples"] = samples_a

        # Distinct seeded noise per grid for the first step only.
        noise = []
        for i in range(no_combined):
            noise.append(Noise_RandomNoise(seed + i).generate_noise(clatent))

        empty_noise = Noise_EmptyNoise().generate_noise(clatent)
        noise_mask = latent.get("noise_mask", None)
        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        callback = latent_preview.prepare_callback(model, steps)

        for i in range(steps):
            print(f"Step {i+1}/{steps}")
            p = perms[i]
            for j in range(0, len(p), cells):
                s_noise = noise[j // cells] if i == 0 else empty_noise
                selected_latents = latent_image[p[j : j + cells]]
                samples_a = self.combine_latents(selected_latents, rows, cols)

                # NOTE(review): start_step = i+1 looks off by one — step 0 is
                # never sampled and the final iteration (start_step == steps)
                # should be a no-op inside comfy.sample.sample, with
                # force_full_denoise firing on the second-to-last iteration
                # instead. start_step = i / last_step = i+1 would cover all
                # steps; left unchanged to preserve the published behavior —
                # confirm against comfy's KSampler step-slicing.
                start_step = i + 1
                last_step = start_step + 1
                force_full_denoise = last_step == steps

                samples_a = comfy.sample.sample(
                    model,
                    s_noise,
                    steps,
                    cfg,
                    sampler_name,
                    scheduler,
                    positive,
                    negative,
                    samples_a,
                    denoise=denoise,
                    disable_noise=(i > 0),
                    start_step=start_step,
                    last_step=last_step,
                    force_full_denoise=force_full_denoise,
                    noise_mask=noise_mask,
                    callback=callback,
                    disable_pbar=disable_pbar,
                    seed=seed,
                )

                # Scatter the denoised cells back into their batch positions.
                split_a = self.split_latents(samples_a, rows, cols)
                latent_image[p[j : j + cells]] = split_a

        out = latent.copy()
        out["samples"] = latent_image
        return (out,)


NODE_CLASS_MAPPINGS = {
    "GridSwapper": GridSwapper,
}

NODE_DISPLAY_NAME_MAPPINGS = {}