├── README.md
├── evolve.py
├── images
│   ├── after_evolution_sample1.png
│   ├── after_evolution_sample2.png
│   ├── after_evolution_sample3.png
│   ├── before_evolution_sample1.png
│   ├── before_evolution_sample2.png
│   └── before_evolution_sample3.png
├── llava_util.py
├── merge.py
├── requirements.txt
├── scripts
│   └── eval-model.py
├── sd3-evolve.py
└── sdxl-evolve.py

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Diffusion Evolver
3 | 
4 | This project is an evolutionary framework for optimizing Stable Diffusion XL models. It evolves a population of models through crossover, mutation, and selection, with fitness judged by a VLM (vision-language model).
5 | 
6 | ## Examples
7 | 
8 | | After Evolution | Before Evolution | Model |
9 | |-----------------|------------------|-------|
10 | | ![After evolution, sample 1](images/after_evolution_sample1.png) | ![Before evolution, sample 1](images/before_evolution_sample1.png) | Model A * |
11 | | ![After evolution, sample 2](images/after_evolution_sample2.png) | ![Before evolution, sample 2](images/before_evolution_sample2.png) | Model B * |
12 | | ![After evolution, sample 3](images/after_evolution_sample3.png) | ![Before evolution, sample 3](images/before_evolution_sample3.png) | Model C * |
13 | 
14 | After 10 cycles with the prompt "T-Rex wearing aviator sunglasses, posing in front of a diffusion-generated Jurassic landscape, 80s vaporwave style, 4K"
15 | 
16 | \* Images are paired with the closest model output from the initial population. All images share the same seed and diffusion settings. The VLM was not shown the prompt, and default settings were used.
17 | 
18 | \* Demo models are available at [https://huggingface.co/martyn/sdxl-evolved-demo-models](https://huggingface.co/martyn/sdxl-evolved-demo-models)
19 | 
20 | ## Models
21 | 
22 | More models are available on Hugging Face: [https://huggingface.co/collections/martyn/evolved-sdxl-models-660b9185df88d3dbac68c052](https://huggingface.co/collections/martyn/evolved-sdxl-models-660b9185df88d3dbac68c052)
23 | 
24 | and on Civitai: [https://civitai.com/user/chandro/models](https://civitai.com/user/chandro/models)
25 | 
26 | ## Installation
27 | 
28 | 1. Clone the repository:
29 | ```
30 | git clone https://github.com/255BITS/diffusion-evolver
31 | ```
32 | 
33 | 2. Install the required dependencies:
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 | 
38 | 3. Set up the necessary API credentials:
39 |    - Obtain an API key from Anthropic.
40 |    - Set the `ANTHROPIC_API_KEY` environment variable with your API key.
41 | 
42 | ## System requirements
43 | 
44 | ### Minimum
45 | 
46 | * A system capable of running inference on Stable Diffusion XL
47 | * Hard disk space: roughly 7 GB × (population_size + initial_population)
48 | 
49 | ### Recommended
50 | 
51 | * A modern GPU
52 | * A lot of RAM (>128 GB); more RAM allows for a larger population
53 | * A ramdisk to store the evolved population
54 | 
55 | ## Usage
56 | 
57 | To run the evolutionary framework, use the following command:
58 | 
59 | ```
60 | python sdxl-evolve.py model_list.yml [options]
61 | ```
62 | 
63 | The `model_list.yml` file should contain a list of initial model candidates in YAML format (an example is shown at the end of this README).
64 | 
65 | Available options:
66 | - `-seed`: Random seed for reproducibility.
67 | - `-cycles`: Number of evolutionary cycles to run (default: 10).
68 | - `-elite`: Number of elite candidates to keep in each iteration (default: 5).
69 | - `-parents`: Number of parents for each child (default: 2).
70 | - `-population`: Size of the population (default: 20).
71 | - `-mutation`: Chance of mutation (default: 0.05).
72 | - `-output_path`: Directory to save the results (default: "evolve_output").
73 | - `-eval_file`: Text file containing prompts for evaluation (default: "evals.txt").
74 | - `-eval_samples`: Number of samples to evaluate between candidates (default: 2).
75 | - `-vlm`: The VLM to use: `claude` (default) or `llava`.
76 | - `-append_prompt`: Text appended to the end of each eval prompt.
77 | - `-negative_prompt`: Negative prompt used during diffusion sampling.
78 | - `-guidance_scale`: Guidance scale for diffusion sampling.
79 | - `-diffusion_steps`: Number of diffusion steps to run for each candidate during evaluation.
80 | - `-width`: Generation width.
81 | - `-height`: Generation height.
82 | - `-resize_width`: Width to resize generated images to before they are sent to the VLM.
83 | - `-resize_height`: Height to resize generated images to before they are sent to the VLM.
84 | 
85 | ## Documentation
86 | 
87 | The framework consists of the following main components:
88 | 
89 | - `evolve.py`: Defines the core evolutionary algorithm, including candidate representation, selection, crossover, mutation, and population management.
90 | - `merge.py`: Provides functions for merging SafeTensors files into new model candidates using DARE.
91 | - `sdxl-evolve.py`: The main script that orchestrates the evolutionary process, including image generation, evaluation, and comparison using the VLM.
92 | 
93 | ### Details
94 | 
95 | - `Candidate`: Represents a model candidate with its file path, initial-population flag, p value, and lambda value.
96 | - `selection`: Selects a subset of candidates as parents for breeding.
97 | - `mutation`: Applies random mutations to an offspring candidate.
98 | - `breed`: Performs crossover and mutation to create a new offspring candidate.
99 | - `evolve`: Grows the population back to its target size by selecting parents and breeding offspring.
100 | - `run_evolution`: Runs one evolutionary cycle: breeding, VLM-ranked sorting, and culling down to the elite set.
101 | - `load_candidates`: Loads initial model candidates from a YAML file.
102 | - `write_yaml`: Writes the population to a YAML file.
103 | - `generate_images`: Generates images with a Stable Diffusion XL pipeline for evaluation.
104 | - `vlm_judge`: Uses the VLM to compare and judge the quality of generated images.
105 | 
106 | ## References
107 | 
108 | - sakana.ai, Evolving New Foundation Models: Unleashing the Power of Automating Model Development: [https://sakana.ai/evolutionary-model-merge/](https://sakana.ai/evolutionary-model-merge/)
109 | - DARE, Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch: [https://arxiv.org/pdf/2311.03099.pdf](https://arxiv.org/pdf/2311.03099.pdf)
110 | - Stable Diffusion XL: [https://github.com/Stability-AI/generative-models](https://github.com/Stability-AI/generative-models)
111 | - Anthropic API: [https://www.anthropic.com](https://www.anthropic.com)
112 | - SafeTensors: [https://github.com/huggingface/safetensors](https://github.com/huggingface/safetensors)
113 | 
114 | Feel free to contribute, report issues, or suggest improvements!
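
## Example configuration

A minimal sketch of the two input files, assuming the checkpoint paths below are placeholders for your own SDXL `.safetensors` files. In `model_list.yml`, only `model` is required; `p` and `lambda` are randomized when omitted, and `generation` defaults to 0.

```
# model_list.yml (checkpoint paths are placeholders)
models:
  - model: /models/checkpoint-a.safetensors
  - model: /models/checkpoint-b.safetensors
    p: 0.3       # optional merge parameter, randomized when omitted
    lambda: 1.5  # optional merge scaling factor, randomized when omitted
```

The eval file is a plain newline-delimited list of prompts, one per line:

```
T-Rex wearing aviator sunglasses, 80s vaporwave style, 4K
a watercolor painting of a lighthouse at dawn
```

A typical run using the options documented above:

```
python sdxl-evolve.py model_list.yml -cycles 10 -population 20 -elite 5 -eval_samples 2
```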
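
The `merge.py` component described above can also be used on its own. The sketch below (checkpoint paths are placeholders) performs a single DARE merge of two checkpoints with fixed `p` and `lambda` values, which is the same call `evolve.breed` makes when combining two parents:

```
from safetensors.torch import save_file

import merge

# Merge parameters that evolve.py would normally sample at random.
p, lambda_val = 0.3, 1.5

# Apply DARE: mask the weight deltas between the two checkpoints, rescale them,
# and add the scaled result back onto the first checkpoint's weights.
tensor_map = merge.merge_safetensors(
    "/models/checkpoint-a.safetensors",
    "/models/checkpoint-b.safetensors",
    p,
    lambda_val,
)
save_file(tensor_map, "merged-checkpoint.safetensors")
```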
115 | 
--------------------------------------------------------------------------------
/evolve.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import merge
3 | import numpy as np
4 | import os
5 | import random
6 | import torch
7 | import uuid
8 | import yaml
9 | import sys
10 | from safetensors.torch import save_file, safe_open
11 | 
12 | from pathlib import Path
13 | 
14 | class Candidate:
15 |     def __init__(self, file_path, p, lambda_val, initial_population=False, generation=0, seed=None):
16 |         self.file_path = file_path
17 |         self.initial_population = initial_population
18 |         self.p = p
19 |         self.lambda_val = lambda_val
20 |         self.generation = generation
21 |         # Keep the provided seed; only fall back to a random one when none is given
22 |         self.seed = seed if seed is not None else random.randint(-sys.maxsize-1, sys.maxsize)
23 |         if initial_population:
24 |             rand_point = torch.randn(4)
25 |             self.location = rand_point / torch.norm(rand_point)
26 | 
27 |     def to_dict(self):
28 |         return {
29 |             "model": self.file_path,
30 |             "p": self.p,
31 |             "lambda": self.lambda_val,
32 |             "generation": self.generation,
33 |             "seed": self.seed
34 |         }
35 | 
36 | def random_p():
37 |     return (random.random() / 2.0)+0.1
38 | 
39 | def random_lambda():
40 |     return random.random()*2.5+0.5
41 | 
42 | def calculate_diversity_scores(candidates):
43 |     locations = torch.stack([candidate.location for candidate in candidates])
44 |     centroid = torch.mean(locations, dim=0)
45 |     distances = torch.norm(locations - centroid, dim=1)
46 |     return distances
47 | 
48 | def adjust_selection_probabilities(distances):
49 |     # Normalize distances to a base probability without letting diversity dominate selection
50 |     min_dist, max_dist = distances.min(), distances.max()
51 |     if max_dist - min_dist < 0.001:
52 |         return [1.0/len(distances) for d in distances]
53 |     # Simple linear transformation to ensure max 30% selection chance increase
54 |     adjusted_probs = (distances - min_dist) / (max_dist - min_dist) * 0.3 + 0.7
55 |     adjusted_probs /= adjusted_probs.sum()  # Normalize to ensure it sums to 1
56 |     return adjusted_probs
57 | 
58 | def selection(population, num_parents):
59 |     logging.info("Selecting candidates.")
60 |     distances = calculate_diversity_scores(population)
61 |     adjusted_probs = adjust_selection_probabilities(distances)
62 |     selected_indices = np.random.choice(range(len(population)), size=num_parents, replace=False, p=adjusted_probs)
63 | 
64 |     return [population[i] for i in selected_indices]
65 | 
66 | def perturb_tensor_map(tensor_map):
67 |     for key, value in tensor_map.items():
68 |         if 'diffusion_model' in key:
69 |             tensor_map[key] = value + torch.normal(torch.zeros_like(value), value.std() * 0.01)
70 |     return tensor_map
71 | 
72 | def perturb(candidate):
73 |     tensor_map = {}
74 |     with safe_open(candidate.file_path, framework="pt", device="cpu") as f1:
75 |         for key in f1.keys():
76 |             if 'diffusion_model' in key:
77 |                 v = f1.get_tensor(key)
78 |                 tensor_map[key] = v + torch.normal(torch.zeros_like(v), v.std() * 0.01)
79 |             else:
80 |                 tensor_map[key] = f1.get_tensor(key)
81 |     return tensor_map
82 | 
83 | def mutation(offspring):
84 |     offspring.p = random_p()
85 |     offspring.lambda_val = random_lambda()
86 | 
87 | def breed(parents, mutation_rate, output_path):
88 |     logging.info("Crossover and mutation...")
89 |     file_path = str(Path(output_path) / (str(uuid.uuid4())+".safetensors"))
90 |     offspring = Candidate(file_path, parents[0].p, parents[0].lambda_val)
91 |     mutation_event = random.random() <= mutation_rate
92 |     if mutation_event:
93 |         mutation(offspring)
94 |     tensor_map = 
merge.merge_safetensors(parents[0].file_path, parents[1].file_path, offspring.p, offspring.lambda_val) 95 | mutation_event = random.random() <= mutation_rate 96 | if mutation_event: 97 | tensor_map = perturb_tensor_map(tensor_map) 98 | 99 | 100 | for parent in parents[2:]: 101 | tensor_map = merge.merge_safetensors(offspring.file_path, parent.file_path, offspring.p, offspring.lambda_val) 102 | 103 | offspring.generation = max([parent.generation for parent in parents]) + 1 104 | offspring.location = torch.mean(torch.stack([parent.location for parent in parents]), dim=0) 105 | 106 | logging.info(f"Saving to {offspring.file_path}, from {','.join([p.file_path for p in parents])} p={offspring.p} λ={offspring.lambda_val} gen={offspring.generation}") 107 | save_file(tensor_map, offspring.file_path) 108 | del tensor_map 109 | return offspring 110 | 111 | def evolve(population, population_size, num_parents, mutation_rate, output_path, children_count=1): 112 | seed_population = list(population) 113 | while len(population) < population_size: 114 | parents = selection(seed_population, num_parents) 115 | for i in range(min(children_count, population_size - len(population))): 116 | offspring = breed(parents, mutation_rate, output_path) 117 | population.append(offspring) 118 | 119 | return population 120 | 121 | async def correct_insert_element(item, sorted_list, compare, top_k): 122 | if not sorted_list: 123 | return [item] 124 | # find a place for insertion 125 | insert_pos = await find_insertion_point(item, sorted_list, compare, top_k) 126 | # insert item tentatively 127 | sorted_list.insert(insert_pos, item) 128 | return sorted_list 129 | 130 | async def find_insertion_point(item, sorted_list, compare, top_k): 131 | # binary search variant that accounts for potential comparison errors 132 | low, high = 0, len(sorted_list) - 1 133 | while low <= high: 134 | if low > top_k and top_k > 0: 135 | return low 136 | mid = (low + high) // 2 137 | result = await compare(item, sorted_list[mid]) 138 | # adjust binary search based on comparison, considering potential inaccuracies 139 | if result == 1: 140 | high = mid - 1 141 | else: 142 | low = mid + 1 143 | return low 144 | 145 | async def sort_with_correction(buffer, compare, top_k=-1): 146 | sorted_list = [] 147 | for item in buffer: 148 | sorted_list = await correct_insert_element(item, sorted_list, compare, top_k) 149 | # correction mechanism here 150 | sorted_list = await correction_pass(sorted_list) 151 | return sorted_list 152 | 153 | async def correction_pass(sorted_list): 154 | # implement a correction pass, potentially re-comparing elements 155 | # this could involve heuristic-based swaps or reinsertions 156 | return sorted_list 157 | 158 | def choose_first_occurrence(s, opta, optb): 159 | # find the index of a and b 160 | index_a = s.find(opta) 161 | index_b = s.find(optb) 162 | # check if both a and b are found 163 | if index_a != -1 and index_b != -1: 164 | # return the one that occurs first 165 | if index_a < index_b: 166 | return opta 167 | else: 168 | return optb 169 | elif index_a != -1: 170 | # only a is found 171 | return opta 172 | elif index_b != -1: 173 | # only b is found 174 | return optb 175 | else: 176 | # neither a nor b is found 177 | return none 178 | 179 | async def run_evolution(population, elite_size, num_parents, population_size, mutation_rate, output_path, evaluation_criteria): 180 | logging.info("Before evolve") 181 | log_candidates(population) 182 | population = evolve(population, population_size, num_parents, mutation_rate, 
output_path) 183 | 184 | logging.info("Before sorting") 185 | log_candidates(population) 186 | population = await sort_with_correction(population, evaluation_criteria) 187 | logging.info("After sorting") 188 | log_candidates(population) 189 | for tokill in population[elite_size:]: 190 | if not tokill.initial_population: 191 | os.remove(tokill.file_path) 192 | return population[:elite_size] 193 | 194 | def log_candidates(population): 195 | format_str = "{{0}}. {{1:<24}}".format() 196 | for index, candidate in enumerate(population, start=1): 197 | logging.info(format_str.format(index, candidate.file_path)) 198 | 199 | def load_candidates(file_path): 200 | candidates = [] 201 | with open(file_path, 'r') as file: 202 | data = yaml.safe_load(file) 203 | for candidate_data in data["models"]: 204 | p = candidate_data.get('p', random_p()) 205 | lambda_val = candidate_data.get("lambda", random_lambda()) 206 | generation = candidate_data.get("generation", 0) 207 | seed = candidate_data.get("seed", None) 208 | candidate = Candidate(candidate_data['model'], p=p, lambda_val=lambda_val, initial_population=True, generation=generation, seed=seed) 209 | candidates.append(candidate) 210 | return candidates 211 | 212 | def write_yaml(population, path): 213 | yaml_str = yaml.dump({"models": [c.to_dict() for c in population]}, sort_keys=False) 214 | 215 | # Write the YAML string to a file 216 | with open(path, "w") as file: 217 | file.write(yaml_str) 218 | -------------------------------------------------------------------------------- /images/after_evolution_sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample1.png -------------------------------------------------------------------------------- /images/after_evolution_sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample2.png -------------------------------------------------------------------------------- /images/after_evolution_sample3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample3.png -------------------------------------------------------------------------------- /images/before_evolution_sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample1.png -------------------------------------------------------------------------------- /images/before_evolution_sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample2.png -------------------------------------------------------------------------------- /images/before_evolution_sample3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample3.png 
-------------------------------------------------------------------------------- /llava_util.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is modified from https://github.com/haotian-liu/LLaVA/ 3 | 4 | import argparse 5 | import torch 6 | 7 | from transformers import TextStreamer 8 | 9 | from llava.constants import ( 10 | IMAGE_TOKEN_INDEX, 11 | DEFAULT_IMAGE_TOKEN, 12 | DEFAULT_IM_START_TOKEN, 13 | DEFAULT_IM_END_TOKEN, 14 | IMAGE_PLACEHOLDER, 15 | ) 16 | from llava.conversation import conv_templates, SeparatorStyle 17 | from llava.model.builder import load_pretrained_model 18 | from llava.utils import disable_torch_init 19 | from llava.mm_utils import ( 20 | process_images, 21 | tokenizer_image_token, 22 | get_model_name_from_path, 23 | KeywordsStoppingCriteria, 24 | ) 25 | 26 | import requests 27 | from PIL import Image 28 | from io import BytesIO 29 | import re 30 | 31 | 32 | def run_llava(model_path, conv_mode, query, images, sep=",", temperature=0.2, top_p=None, num_beams=1, max_new_tokens=512, model_base=None, device="cuda:0"): 33 | args = argparse.Namespace( 34 | model_path=model_path, 35 | model_base=model_base, 36 | temperature=temperature, 37 | top_p=top_p, 38 | num_beams=num_beams, 39 | max_new_tokens=max_new_tokens, 40 | conv_mode=conv_mode, 41 | query=query, 42 | images=images, 43 | sep=sep, 44 | device=device 45 | ) 46 | 47 | return eval_model(args) 48 | 49 | def load_image(image_file): 50 | if image_file.startswith("http") or image_file.startswith("https"): 51 | response = requests.get(image_file) 52 | image = Image.open(BytesIO(response.content)).convert("RGB") 53 | else: 54 | image = Image.open(image_file).convert("RGB") 55 | return image 56 | 57 | def load_images(image_files): 58 | out = [] 59 | for image_file in image_files: 60 | image = load_image(image_file) 61 | out.append(image) 62 | return out 63 | 64 | llavamodel = None 65 | 66 | def eval_model(args): 67 | # Model 68 | disable_torch_init() 69 | global llavamodel 70 | 71 | model_name = get_model_name_from_path(args.model_path) 72 | if llavamodel is None: 73 | llavamodel = load_pretrained_model( 74 | args.model_path, args.model_base, model_name, device=args.device 75 | ) 76 | tokenizer, model, image_processor, context_len = llavamodel 77 | 78 | qs = args.query 79 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 80 | if IMAGE_PLACEHOLDER in qs: 81 | if model.config.mm_use_im_start_end: 82 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 83 | else: 84 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 85 | else: 86 | if DEFAULT_IMAGE_TOKEN not in qs: 87 | if model.config.mm_use_im_start_end: 88 | qs = image_token_se + "\n" + qs 89 | else: 90 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 91 | 92 | if "llama-2" in model_name.lower(): 93 | conv_mode = "llava_llama_2" 94 | elif "v1" in model_name.lower(): 95 | conv_mode = "llava_v1" 96 | elif "mpt" in model_name.lower(): 97 | conv_mode = "mpt" 98 | else: 99 | conv_mode = "llava_v0" 100 | 101 | if args.conv_mode is not None and conv_mode != args.conv_mode: 102 | pass 103 | #print( 104 | # "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 105 | # conv_mode, args.conv_mode, args.conv_mode 106 | # ) 107 | #) 108 | else: 109 | args.conv_mode = conv_mode 110 | 111 | conv = conv_templates[args.conv_mode].copy() 112 | conv.append_message(conv.roles[0], qs) 113 | conv.append_message(conv.roles[1], None) 114 | prompt = 
conv.get_prompt() 115 | 116 | images_tensor = process_images( 117 | args.images, 118 | image_processor, 119 | model.config 120 | ).to(args.device, dtype=torch.float16) 121 | 122 | input_ids = ( 123 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 124 | .unsqueeze(0) 125 | .to(args.device) 126 | ) 127 | 128 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 129 | keywords = [stop_str] 130 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 131 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 132 | 133 | with torch.inference_mode(): 134 | output_ids = model.generate( 135 | input_ids, 136 | images=images_tensor, 137 | do_sample=True if args.temperature > 0 else False, 138 | temperature=args.temperature, 139 | #top_p=args.top_p, 140 | #num_beams=args.num_beams, 141 | max_new_tokens=args.max_new_tokens, 142 | use_cache=True, 143 | #stopping_criteria=[stopping_criteria], 144 | streamer=streamer, 145 | image_sizes=[image.size for image in args.images] 146 | ) 147 | 148 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 149 | 150 | return outputs 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 156 | parser.add_argument("--model-base", type=str, default=None) 157 | parser.add_argument("--image-file", type=str, required=True) 158 | parser.add_argument("--query", type=str, required=True) 159 | parser.add_argument("--conv-mode", type=str, default=None) 160 | parser.add_argument("--sep", type=str, default=",") 161 | parser.add_argument("--temperature", type=float, default=0.2) 162 | parser.add_argument("--top_p", type=float, default=None) 163 | parser.add_argument("--num_beams", type=int, default=1) 164 | parser.add_argument("--max_new_tokens", type=int, default=512) 165 | args = parser.parse_args() 166 | 167 | eval_model(args) 168 | -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import torch 5 | import torch.nn.functional as F 6 | from safetensors.torch import safe_open, save_file 7 | 8 | def merge_tensors(tensor1, tensor2, p): 9 | # Calculate the delta of the weights 10 | delta = tensor2 - tensor1 11 | # Generate the mask m^t from Bernoulli distribution 12 | m = torch.bernoulli(torch.full(delta.shape, p)).to(delta.dtype) 13 | # Apply the mask to the delta to get δ̃^t 14 | delta_tilde = m * delta 15 | # Scale the masked delta by the dropout rate to get δ̂^t 16 | delta_hat = delta_tilde / (1 - p) 17 | return delta_hat 18 | 19 | def merge_safetensors(file_path1, file_path2, p, lambda_val): 20 | merged_tensors = {} 21 | 22 | with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2: 23 | keys1 = set(f1.keys()) 24 | keys2 = set(f2.keys()) 25 | common_keys = keys1.intersection(keys2) 26 | 27 | for key in common_keys: 28 | tensor1 = f1.get_tensor(key) 29 | tensor2 = f2.get_tensor(key) 30 | merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) 31 | 32 | return merged_tensors 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
accelerate==0.28.0 2 | annotated-types==0.6.0 3 | anthropic==0.21.3 4 | anyio==4.3.0 5 | certifi==2024.2.2 6 | charset-normalizer==3.3.2 7 | diffusers==0.27.2 8 | distro==1.9.0 9 | filelock==3.13.3 10 | fsspec==2024.3.1 11 | h11==0.14.0 12 | httpcore==1.0.5 13 | httpx==0.27.0 14 | huggingface-hub==0.22.1 15 | idna==3.6 16 | importlib_metadata==7.1.0 17 | Jinja2==3.1.3 18 | MarkupSafe==2.1.5 19 | mpmath==1.3.0 20 | networkx==3.2.1 21 | numpy==1.26.4 22 | nvidia-cublas-cu12==12.1.3.1 23 | nvidia-cuda-cupti-cu12==12.1.105 24 | nvidia-cuda-nvrtc-cu12==12.1.105 25 | nvidia-cuda-runtime-cu12==12.1.105 26 | nvidia-cudnn-cu12==8.9.2.26 27 | nvidia-cufft-cu12==11.0.2.54 28 | nvidia-curand-cu12==10.3.2.106 29 | nvidia-cusolver-cu12==11.4.5.107 30 | nvidia-cusparse-cu12==12.1.0.106 31 | nvidia-nccl-cu12==2.19.3 32 | nvidia-nvjitlink-cu12==12.4.99 33 | nvidia-nvtx-cu12==12.1.105 34 | packaging==24.0 35 | peft==0.10.0 36 | pillow==10.2.0 37 | psutil==5.9.8 38 | pydantic==2.6.4 39 | pydantic_core==2.16.3 40 | PyYAML==6.0.1 41 | regex==2023.12.25 42 | requests==2.31.0 43 | safetensors==0.4.2 44 | setuptools==69.2.0 45 | sniffio==1.3.1 46 | sympy==1.12 47 | tokenizers==0.15.2 48 | torch==2.2.1 49 | tqdm==4.66.2 50 | transformers==4.39.1 51 | typing_extensions==4.10.0 52 | urllib3==2.2.1 53 | wheel==0.43.0 54 | zipp==3.18.1 55 | -------------------------------------------------------------------------------- /scripts/eval-model.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import io 4 | import json 5 | import random 6 | import re 7 | import requests 8 | import sys 9 | import time 10 | import torch 11 | import os 12 | 13 | from PIL import Image, PngImagePlugin 14 | 15 | NUM_IMAGES=128 16 | random.seed(238) 17 | 18 | def load_random_evals(file_path, count): 19 | evals = [] 20 | 21 | with open(file_path, 'r') as file: 22 | lines = file.readlines() 23 | 24 | count = min(count, len(lines)) 25 | selected_lines = random.sample(lines, count) 26 | 27 | for line in selected_lines: 28 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)}) 29 | 30 | return evals 31 | 32 | def load_file_and_return_random_line(file_path): 33 | with open(file_path, 'r') as file: 34 | file_map = file.read().split('\n') 35 | 36 | line = random.choice(file_map) 37 | return line 38 | 39 | def wildcard_replace(s, directory): 40 | if directory is None: 41 | return s 42 | wildcards = re.findall(r'__(.*?)__', s) 43 | replaced = [load_file_and_return_random_line(directory+"/"+w+".txt") for w in wildcards] 44 | replacements = dict(zip(wildcards, replaced)) 45 | for wildcard, replacement in replacements.items(): 46 | s = s.replace('__{}__'.format(wildcard), replacement) 47 | 48 | return s 49 | 50 | def generate_image(prompt, negative_prompt, config_file=None, fname=None): 51 | seed = random.randint(0, 2**32-1) 52 | if config_file is None: 53 | with open("txt2img/sdwebui_config.json", 'r') as file: 54 | config_file = json.load(file) 55 | headers = {"Content-Type": "application/json"} 56 | data = dict(config_file) 57 | 58 | prompt = wildcard_replace(prompt, "wildcards") 59 | data["prompt"]=prompt 60 | data["negative_prompt"]=negative_prompt 61 | 62 | data["seed"]=seed 63 | url = data["sd_webui_url"] 64 | del data["sd_webui_url"] 65 | 66 | response = requests.post(url, headers=headers, data=json.dumps(data)) 67 | 68 | if response.status_code == 200: 69 | r = response.json() 70 | image = 
Image.open(io.BytesIO(base64.b64decode(r['images'][0].split(",",1)[0]))) 71 | jsoninfo = json.loads(r['info']) 72 | #print(jsoninfo["infotexts"][0]) 73 | png_payload = { 74 | "image": "data:image/png;base64," + r['images'][0] 75 | } 76 | response2 = requests.post(url=url.replace("txt2img", "png-info"), json=png_payload) 77 | 78 | pnginfo = PngImagePlugin.PngInfo() 79 | pnginfo.add_text("parameters", response2.json().get("info")) 80 | 81 | if fname is None: 82 | fname= random_fname() 83 | image.save(fname, pnginfo=pnginfo) 84 | return fname 85 | 86 | else: 87 | print(f"Request failed with status code {response.status_code}") 88 | return generate_image(prompt, negative_prompt, config_file, fname) 89 | 90 | def run_eval(working_dir): 91 | for i in range(NUM_IMAGES): 92 | config = { 93 | "sd_webui_url": "http://localhost:3000/sdapi/v1/txt2img", 94 | "height": 1152, 95 | "width": 896, 96 | "sampler_name": "Euler", 97 | "scheduler": "SGM Uniform", 98 | "cfg_scale": 1, 99 | "steps": 8 100 | } 101 | prompt = f"__person__, __sdprompt__, __bg__" 102 | img = generate_image(prompt, "nsfw", config, working_dir+"/"+str(i)+".png") 103 | 104 | async def main(model, working_dir="."): 105 | subdir = os.path.join(working_dir, "evals", model.split("/")[-1].split(".")[0], "imgs") 106 | os.makedirs(subdir, exist_ok=True) 107 | run_eval(subdir) 108 | 109 | 110 | if __name__ == "__main__": 111 | if len(sys.argv) > 1: 112 | models = sys.argv[1:] 113 | else: 114 | print("Usage: script_name.py +") 115 | sys.exit(1) 116 | for model in models: 117 | asyncio.run(main(model)) 118 | 119 | -------------------------------------------------------------------------------- /sd3-evolve.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | import argparse 3 | import asyncio 4 | import base64 5 | import evolve 6 | import logging 7 | import time 8 | import os 9 | import random 10 | import torch 11 | import uuid 12 | 13 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 14 | 15 | from PIL import Image 16 | from dataclasses import dataclass 17 | from diffusers import StableDiffusion3Pipeline 18 | from huggingface_hub import hf_hub_download 19 | from io import BytesIO 20 | from pathlib import Path 21 | from tqdm.asyncio import tqdm 22 | 23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 24 | 25 | def parse_arguments(): 26 | parser = argparse.ArgumentParser(description="Evolutionary merge") 27 | parser.add_argument('model_list', type=str, help='yml file containing list of models in the initial population') 28 | parser.add_argument('-seed', dest='seed', type=int, default=None, help='Random seed') 29 | parser.add_argument('-cycles', dest='cycles', type=int, default=10, help='Number of evolutionary cycles') 30 | parser.add_argument('-elite', dest='elite', type=int, default=5, help='Number of elite candidates to keep every iteration') 31 | parser.add_argument('-parents', dest='parents', type=int, default=2, help='Number of parents for each child') 32 | parser.add_argument('-population', dest='population', type=int, default=20, help='Size of population') 33 | parser.add_argument('-mutation', dest='mutation', type=float, default=0.05, help='Chance of mutation') 34 | parser.add_argument('-output_path', dest='output_path', type=str, default="evolve_output", help='Directory to save results') 35 | parser.add_argument('-criteria', dest='criteria', type=str, default='Which candidate generated 
more colorful images?', help='Criteria for decision making in the VLM.') 36 | parser.add_argument('-eval_file', dest='eval_file', type=str, default='evals.txt', help='A txt file containing a newline delimited list of prompts to evaluation against') 37 | parser.add_argument('-eval_samples', dest='eval_samples', type=int, default=2, help='The number of samples to evaluate between candidates') 38 | parser.add_argument('-device', dest='device', type=str, default="cuda:0", help='The device to run on') 39 | parser.add_argument('-reintroduction_threshold', dest='reintroduction_threshold', type=float, default=0, help='The chance to reintroduce an initial model back into the elite population. Can help with solution diversity.') 40 | parser.add_argument('-vlm', dest='vlm', type=str, default="claude", help='The vlm to use. claude or llava') 41 | parser.add_argument('-append_prompt', dest='append_prompt', type=str, default="", help='Appends to the prompt') 42 | parser.add_argument('-negative_prompt', dest='negative_prompt', type=str, default="", help='Set the negative prompt') 43 | parser.add_argument('-guidance_scale', dest='guidance_scale', type=float, default=1, help='The guidance scale to use') 44 | parser.add_argument('-scheduler', dest='scheduler', type=str, default="sgm_uniform", help='The diffusion scheduler to use') 45 | parser.add_argument('-diffusion_steps', dest='diffusion_steps', type=int, default=8, help='The number of diffusion steps to run') 46 | parser.add_argument('-diffusion_prompt_change', dest='diffusion_prompt_change', type=str, choices=["every_cycle", "never"], default="cycle", help='The type of generation cache to use. Controls when vlm image prompts are changed. Choices: never, every_cycle') 47 | parser.add_argument("-width", dest='width', type=int, default=1024, help='Width of diffusion samples to generate') 48 | parser.add_argument("-height", dest='height', type=int, default=1024, help='Height of diffusion samples to generate') 49 | parser.add_argument("-resize_width", dest='resize_width', type=int, default=512, help='Width to resized diffusion samples before sending to the VLM') 50 | parser.add_argument("-resize_height", dest='resize_height', type=int, default=512, help='Height to resize diffusion samples before sending to the VLM') 51 | parser.add_argument("-vae", dest='vae', type=str, default=None, help='Custom VAE to use during sampling') 52 | parser.add_argument("-perturb_seed_population", dest='perturb_seed_population', type=int, default=0, help='Build this many children of the seed population') 53 | return parser.parse_args() 54 | 55 | def generate_images(file_path, evals, device, cache, settings): 56 | if file_path in cache: 57 | return cache[file_path] 58 | images = [] 59 | logging.info(f"Loading {file_path}") 60 | 61 | dtype = torch.bfloat16 62 | model_id = "stabilityai/stable-diffusion-3-medium-diffusers" 63 | 64 | pipe = StableDiffusion3Pipeline.from_single_file(file_path, text_encoder = None, text_encoder_2 = None, text_encoder_3 = None, device=device) 65 | pipe.text_encoder = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder") 66 | pipe.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2") 67 | #pipe.text_encoder_3 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_3") 68 | pipe.tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") 69 | pipe.tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") 70 | #pipe.tokenizer_3 = 
T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_3") 71 | pipe = pipe.to(dtype).to(device) 72 | 73 | for i, evl in enumerate(evals): 74 | image = pipe(evl['prompt']+settings.append_prompt, width=settings.width, height=settings.height, negative_prompt=settings.negative_prompt, num_inference_steps=settings.diffusion_steps, guidance_scale=settings.guidance_scale, generator=torch.manual_seed(evl['seed'])).images[0] 75 | if settings.resize_width != settings.width: 76 | image = image.resize((settings.resize_width, settings.resize_height)) 77 | image.save(f"output-evolve-{file_path.split('/')[-1]}-{i}.png") 78 | images.append(image) 79 | 80 | del pipe 81 | cache[file_path] = images 82 | return images 83 | 84 | def generate_b64_images(*args): 85 | # Assuming generate_images() returns a list of PIL Image objects 86 | images = generate_images(*args) 87 | b64_images = [] 88 | 89 | for img in images: 90 | buffered = BytesIO() 91 | img.save(buffered, format="PNG") 92 | img_str = base64.b64encode(buffered.getvalue()).decode() 93 | b64_images.append(img_str) 94 | 95 | return b64_images 96 | 97 | def load_random_evals(file_path, count): 98 | evals = [] 99 | 100 | with open(file_path, 'r') as file: 101 | lines = file.readlines() 102 | 103 | count = min(count, len(lines)) 104 | selected_lines = random.sample(lines, count) 105 | 106 | for line in selected_lines: 107 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)}) 108 | 109 | return evals 110 | 111 | def claude_vlm_judge(criteria, prompts, b64_images_a, b64_images_b): 112 | client = anthropic.Anthropic() 113 | media_type = "image/png" 114 | prompts_list = "\n".join(["{0}. {1}".format(i, prompt) for i, prompt in enumerate(prompts)]) 115 | # begin_text = f""" 116 | #Here are the prompts for the generations: 117 | #``` 118 | #{prompts_list} 119 | #``` 120 | # 121 | #Each candidate will be given these prompts to generate images. First you will receive candidate 1 generations based on these prompts, then candidate 2. 122 | # 123 | begin_text = f"""You will first see both candidates images then judge which did better based on the following criteria: 124 | ``` 125 | Criteria: {criteria} 126 | ``` 127 | 128 | Candidate 1 generations: 129 | """.strip() 130 | end_text = """ 131 | Which candidate won based on the criteria? If candidate 1, output '1'. If candidate 2, output '2'. This is automated, the first 1 or 2 you output will be the winner. 132 | """.strip() 133 | messages = [ 134 | { 135 | "role": "user", 136 | "content": [ 137 | { 138 | "type": "text", 139 | "text": begin_text 140 | }, 141 | *[ 142 | { 143 | "type": "image", 144 | "source": { 145 | "type": "base64", 146 | "media_type": media_type, 147 | "data": b64_image, 148 | }, 149 | } for b64_image in b64_images_a], 150 | { 151 | "type": "text", 152 | "text": "Candidate 2 generations:" 153 | }, 154 | *[ 155 | { 156 | "type": "image", 157 | "source": { 158 | "type": "base64", 159 | "media_type": media_type, 160 | "data": b64_image, 161 | }, 162 | } for b64_image in b64_images_b], 163 | { 164 | "type": "text", 165 | "text": end_text 166 | } 167 | ], 168 | } 169 | ] 170 | 171 | model = "claude-3-haiku-20240307" 172 | message = client.messages.create( 173 | model=model, 174 | max_tokens=128, 175 | system="You are diffusion evolver AI, a judge for an image generation contest. You will be presented images from two models with the same prompt and seed. 
At the end you will give your judgement based on a specified criteria.", 176 | messages=messages, 177 | ) 178 | text = message.content[0].text 179 | for i, ch in enumerate(text): 180 | if ch == "1" or ch == "2": 181 | return int(ch) 182 | logging.info("wtf bad output", text) 183 | raise "error" 184 | 185 | def claude_vlm_judge_with_retry(*args, max_retries=3, initial_wait=1, max_wait=10): 186 | for attempt in range(max_retries): 187 | try: 188 | return claude_vlm_judge(*args) 189 | except Exception as e: 190 | wait_time = min(max_wait, initial_wait * 2 ** attempt) 191 | wait_time += random.uniform(0, wait_time * 0.2) # Adding random jitter 192 | if attempt < max_retries - 1: 193 | # Log the full stack trace before retrying 194 | logging.exception(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...") 195 | time.sleep(wait_time) 196 | else: 197 | # Log the full stack trace before raising the exception after all retries have failed 198 | logging.exception("All attempts failed. Raising exception.") 199 | raise 200 | 201 | def combine_pil(a, b): 202 | total_width = a.width + b.width 203 | max_height = max(a.height, b.height) 204 | 205 | combined_image = Image.new('RGB', (total_width, max_height)) 206 | 207 | combined_image.paste(a, (0, 0)) 208 | combined_image.paste(b, (a.width, 0)) 209 | 210 | return combined_image 211 | 212 | def llava_vlm_decide(prompt, img, device): 213 | import llava_util 214 | model = "liuhaotian/llava-v1.6-mistral-7b" 215 | text = llava_util.run_llava(model, None, prompt, [img], device=device, max_new_tokens=128) 216 | for i, ch in enumerate(text): 217 | if ch == "1" or ch == "2": 218 | return int(ch) 219 | logging.info("wtf bad output", text) 220 | raise "error" 221 | 222 | def llava_vlm_judge(criteria, prompts, b64_images_a, b64_images_b, device): 223 | for (prompt, img_a, img_b) in zip(prompts, b64_images_a, b64_images_b): 224 | img_combine = combine_pil(img_a, img_b) 225 | prompt = f"You are a judge in an image generation contest. {criteria} '1' for the image on the left, '2' for the image on the right. Answer only '1'(left) or '2'(right). This is automated and the first number in your answer will be chosen." 226 | return llava_vlm_decide(prompt, img_combine, device) 227 | 228 | def llava_vlm_judge_with_retry(*args, max_retries=3): 229 | for i in range(max_retries): 230 | try: 231 | return llava_vlm_judge(*args) 232 | except Exception as e: 233 | if i < max_retries: 234 | logging.exception("Llava did not give output. 
Retrying...") 235 | else: 236 | logging.exception("Llava failed!") 237 | raise 238 | 239 | def compare(cache, criteria, device, evals, metrics, vlm, settings): 240 | async def vlm_compare(a: evolve.Candidate, b:evolve.Candidate): 241 | cache_key = 'compare:'+a.file_path+'.'+b.file_path 242 | if cache_key in cache: 243 | return cache[cache_key] 244 | reverse = random.random() > 0.5 245 | prompts = [evl["prompt"] for evl in evals] 246 | if reverse: 247 | a, b = b, a 248 | 249 | if vlm == 'claude': 250 | b64_images_a = generate_b64_images(a.file_path, evals, device, cache, settings) 251 | b64_images_b = generate_b64_images(b.file_path, evals, device, cache, settings) 252 | judgement = claude_vlm_judge_with_retry(criteria, prompts, b64_images_a, b64_images_b) 253 | elif vlm == 'llava': 254 | images_a = generate_images(a.file_path, evals, device, cache, settings) 255 | images_b = generate_images(b.file_path, evals, device, cache, settings) 256 | judgement = llava_vlm_judge_with_retry(criteria, prompts, images_a, images_b, device) 257 | else: 258 | raise "vlm not supported:" + vlm 259 | 260 | if reverse: 261 | judgement = (1 if judgement == 2 else 2) 262 | metrics.total += 1 263 | 264 | if judgement == 1: 265 | metrics.yays += 1 266 | else: 267 | metrics.nays += 1 268 | logging.info(f"Number of comparisons Total: {metrics.total} Yay: {metrics.yays} Nay: {metrics.nays}") 269 | 270 | 271 | if judgement == 1: 272 | cache[cache_key] = 1 273 | return 1 274 | cache[cache_key] = -1 275 | return -1 276 | return vlm_compare 277 | 278 | @dataclass 279 | class Metrics: 280 | total: int = 0 281 | yays: int = 0 282 | nays: int = 0 283 | 284 | @dataclass 285 | class DiffusionSettings: 286 | guidance_scale: int 287 | negative_prompt: str 288 | append_prompt: str 289 | diffusion_steps: int 290 | width: int 291 | height: int 292 | resize_width: int 293 | resize_height: int 294 | scheduler: str 295 | vae: str 296 | 297 | async def main(): 298 | # Parse command-line arguments 299 | args = parse_arguments() 300 | if args.seed is not None: 301 | torch.manual_seed(args.seed) 302 | os.makedirs(args.output_path, exist_ok=True) 303 | metrics = Metrics() 304 | cache = {} 305 | evals = load_random_evals(args.eval_file, args.eval_samples) 306 | settings = DiffusionSettings( 307 | append_prompt = args.append_prompt, 308 | diffusion_steps = args.diffusion_steps, 309 | guidance_scale = args.guidance_scale, 310 | height = args.height, 311 | negative_prompt = args.negative_prompt, 312 | resize_height = args.resize_height, 313 | resize_width = args.resize_width, 314 | width = args.width, 315 | vae = args.vae, 316 | scheduler = args.scheduler 317 | ) 318 | initial_population = evolve.load_candidates(args.model_list) 319 | initial_populiation_count = len(initial_population) 320 | while(len(initial_population) < args.perturb_seed_population): 321 | parent = random.choice(initial_population[:initial_populiation_count]) 322 | file_path = str(Path(args.output_path) / (str(uuid.uuid4())+".safetensors")) 323 | offspring = evolve.Candidate(file_path, parent.p, parent.lambda_val, initial_population=True) 324 | offspring.generation = parent.generation + 1 325 | print("perturbing from(clone)", parent.file_path) 326 | tensor_map = evolve.perturb(parent) 327 | print("saving", offspring.file_path) 328 | evolve.save_file(tensor_map, offspring.file_path) 329 | del tensor_map 330 | initial_population.append(offspring) 331 | print("--", initial_population) 332 | 333 | population = list(initial_population) 334 | evolve.write_yaml(population, 
Path(args.output_path) / "initial.yaml") 335 | logging.info("Beginning evolution") 336 | 337 | async for i in tqdm(range(args.cycles), desc='Evolving'): 338 | if args.diffusion_prompt_change == "every_cycle": 339 | evals = load_random_evals(args.eval_file, args.eval_samples) 340 | cache = {} 341 | comparator = compare(cache, args.criteria, args.device, evals, metrics, args.vlm, settings) 342 | population = await evolve.run_evolution(population, args.elite, args.parents, args.population, args.mutation, args.output_path, comparator) 343 | evolve.write_yaml(population, Path(args.output_path) / f"step-{i}.yaml") 344 | if random.random() < args.reintroduction_threshold: 345 | population.append(random.choice(initial_population)) 346 | 347 | logging.info("Resulting population:") 348 | evolve.log_candidates(population) 349 | if __name__ == "__main__": 350 | asyncio.run(main()) 351 | -------------------------------------------------------------------------------- /sdxl-evolve.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | import argparse 3 | import asyncio 4 | import base64 5 | import evolve 6 | import logging 7 | import time 8 | import os 9 | import random 10 | import torch 11 | 12 | from PIL import Image 13 | from dataclasses import dataclass 14 | from diffusers import AutoencoderKL 15 | from diffusers import EulerDiscreteScheduler 16 | from diffusers import StableDiffusionXLPipeline 17 | from huggingface_hub import hf_hub_download 18 | from io import BytesIO 19 | from pathlib import Path 20 | from tqdm.asyncio import tqdm 21 | 22 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 23 | 24 | def parse_arguments(): 25 | parser = argparse.ArgumentParser(description="Evolutionary merge") 26 | parser.add_argument('model_list', type=str, help='yml file containing list of models in the initial population') 27 | parser.add_argument('-seed', dest='seed', type=int, default=None, help='Random seed') 28 | parser.add_argument('-cycles', dest='cycles', type=int, default=10, help='Number of evolutionary cycles') 29 | parser.add_argument('-elite', dest='elite', type=int, default=5, help='Number of elite candidates to keep every iteration') 30 | parser.add_argument('-parents', dest='parents', type=int, default=2, help='Number of parents for each child') 31 | parser.add_argument('-population', dest='population', type=int, default=20, help='Size of population') 32 | parser.add_argument('-mutation', dest='mutation', type=float, default=0.05, help='Chance of mutation') 33 | parser.add_argument('-output_path', dest='output_path', type=str, default="evolve_output", help='Directory to save results') 34 | parser.add_argument('-criteria', dest='criteria', type=str, default='Which candidate generated more colorful images?', help='Criteria for decision making in the VLM.') 35 | parser.add_argument('-eval_file', dest='eval_file', type=str, default='evals.txt', help='A txt file containing a newline delimited list of prompts to evaluation against') 36 | parser.add_argument('-eval_samples', dest='eval_samples', type=int, default=2, help='The number of samples to evaluate between candidates') 37 | parser.add_argument('-device', dest='device', type=str, default="cuda:0", help='The device to run on') 38 | parser.add_argument('-reintroduction_threshold', dest='reintroduction_threshold', type=float, default=0, help='The chance to reintroduce an initial model back into the elite population. 
Can help with solution diversity.') 39 | parser.add_argument('-vlm', dest='vlm', type=str, default="claude", help='The vlm to use. claude or llava') 40 | parser.add_argument('-append_prompt', dest='append_prompt', type=str, default="", help='Appends to the prompt') 41 | parser.add_argument('-negative_prompt', dest='negative_prompt', type=str, default="", help='Set the negative prompt') 42 | parser.add_argument('-guidance_scale', dest='guidance_scale', type=float, default=1, help='The guidance scale to use') 43 | parser.add_argument('-scheduler', dest='scheduler', type=str, default="sgm_uniform", help='The diffusion scheduler to use') 44 | parser.add_argument('-diffusion_steps', dest='diffusion_steps', type=int, default=8, help='The number of diffusion steps to run') 45 | parser.add_argument('-diffusion_prompt_change', dest='diffusion_prompt_change', type=str, choices=["every_cycle", "never"], default="cycle", help='The type of generation cache to use. Controls when vlm image prompts are changed. Choices: never, every_cycle') 46 | parser.add_argument("-width", dest='width', type=int, default=1024, help='Width of diffusion samples to generate') 47 | parser.add_argument("-height", dest='height', type=int, default=1024, help='Height of diffusion samples to generate') 48 | parser.add_argument("-resize_width", dest='resize_width', type=int, default=512, help='Width to resized diffusion samples before sending to the VLM') 49 | parser.add_argument("-resize_height", dest='resize_height', type=int, default=512, help='Height to resize diffusion samples before sending to the VLM') 50 | parser.add_argument("-vae", dest='vae', type=str, default=None, help='Custom VAE to use during sampling') 51 | return parser.parse_args() 52 | 53 | def generate_images(file_path, evals, device, cache, settings): 54 | if file_path in cache: 55 | return cache[file_path] 56 | images = [] 57 | logging.info(f"Loading {file_path}") 58 | 59 | pipe = StableDiffusionXLPipeline.from_single_file(file_path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device) 60 | if settings.vae: 61 | pipe.vae = AutoencoderKL.from_single_file(settings.vae, torch_dtype=torch.float16).to(pipe.device) 62 | if settings.scheduler == "sgm_uniform": 63 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 64 | else: 65 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) 66 | 67 | for i, evl in enumerate(evals): 68 | image = pipe(evl['prompt']+settings.append_prompt, width=settings.width, height=settings.height, negative_prompt=settings.negative_prompt, num_inference_steps=settings.diffusion_steps, guidance_scale=settings.guidance_scale, generator=torch.manual_seed(evl['seed'])).images[0] 69 | if settings.resize_width != settings.width: 70 | image = image.resize((settings.resize_width, settings.resize_height)) 71 | image.save(f"output-evolve-{file_path.split('/')[-1]}-{i}.png") 72 | images.append(image) 73 | 74 | del pipe 75 | cache[file_path] = images 76 | return images 77 | 78 | def generate_b64_images(*args): 79 | # Assuming generate_images() returns a list of PIL Image objects 80 | images = generate_images(*args) 81 | b64_images = [] 82 | 83 | for img in images: 84 | buffered = BytesIO() 85 | img.save(buffered, format="PNG") 86 | img_str = base64.b64encode(buffered.getvalue()).decode() 87 | b64_images.append(img_str) 88 | 89 | return b64_images 90 | 91 | def load_random_evals(file_path, count): 92 | evals = [] 93 | 94 | with open(file_path, 'r') as 
file: 95 | lines = file.readlines() 96 | 97 | count = min(count, len(lines)) 98 | selected_lines = random.sample(lines, count) 99 | 100 | for line in selected_lines: 101 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)}) 102 | 103 | return evals 104 | 105 | def claude_vlm_judge(criteria, prompts, b64_images_a, b64_images_b): 106 | client = anthropic.Anthropic() 107 | media_type = "image/png" 108 | prompts_list = "\n".join(["{0}. {1}".format(i, prompt) for i, prompt in enumerate(prompts)]) 109 | # begin_text = f""" 110 | #Here are the prompts for the generations: 111 | #``` 112 | #{prompts_list} 113 | #``` 114 | # 115 | #Each candidate will be given these prompts to generate images. First you will receive candidate 1 generations based on these prompts, then candidate 2. 116 | # 117 | begin_text = f"""You will first see both candidates images then judge which did better based on the following criteria: 118 | ``` 119 | Criteria: {criteria} 120 | ``` 121 | 122 | Candidate 1 generations: 123 | """.strip() 124 | end_text = """ 125 | Which candidate won based on the criteria? If candidate 1, output '1'. If candidate 2, output '2'. This is automated, the first 1 or 2 you output will be the winner. 126 | """.strip() 127 | messages = [ 128 | { 129 | "role": "user", 130 | "content": [ 131 | { 132 | "type": "text", 133 | "text": begin_text 134 | }, 135 | *[ 136 | { 137 | "type": "image", 138 | "source": { 139 | "type": "base64", 140 | "media_type": media_type, 141 | "data": b64_image, 142 | }, 143 | } for b64_image in b64_images_a], 144 | { 145 | "type": "text", 146 | "text": "Candidate 2 generations:" 147 | }, 148 | *[ 149 | { 150 | "type": "image", 151 | "source": { 152 | "type": "base64", 153 | "media_type": media_type, 154 | "data": b64_image, 155 | }, 156 | } for b64_image in b64_images_b], 157 | { 158 | "type": "text", 159 | "text": end_text 160 | } 161 | ], 162 | } 163 | ] 164 | 165 | model = "claude-3-haiku-20240307" 166 | message = client.messages.create( 167 | model=model, 168 | max_tokens=128, 169 | system="You are diffusion evolver AI, a judge for an image generation contest. You will be presented images from two models with the same prompt and seed. At the end you will give your judgement based on a specified criteria.", 170 | messages=messages, 171 | ) 172 | text = message.content[0].text 173 | for i, ch in enumerate(text): 174 | if ch == "1" or ch == "2": 175 | return int(ch) 176 | logging.info("wtf bad output", text) 177 | raise "error" 178 | 179 | def claude_vlm_judge_with_retry(*args, max_retries=3, initial_wait=1, max_wait=10): 180 | for attempt in range(max_retries): 181 | try: 182 | return claude_vlm_judge(*args) 183 | except Exception as e: 184 | wait_time = min(max_wait, initial_wait * 2 ** attempt) 185 | wait_time += random.uniform(0, wait_time * 0.2) # Adding random jitter 186 | if attempt < max_retries - 1: 187 | # Log the full stack trace before retrying 188 | logging.exception(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...") 189 | time.sleep(wait_time) 190 | else: 191 | # Log the full stack trace before raising the exception after all retries have failed 192 | logging.exception("All attempts failed. 
Raising exception.") 193 | raise 194 | 195 | def combine_pil(a, b): 196 | total_width = a.width + b.width 197 | max_height = max(a.height, b.height) 198 | 199 | combined_image = Image.new('RGB', (total_width, max_height)) 200 | 201 | combined_image.paste(a, (0, 0)) 202 | combined_image.paste(b, (a.width, 0)) 203 | 204 | return combined_image 205 | 206 | def llava_vlm_decide(prompt, img, device): 207 | import llava_util 208 | model = "liuhaotian/llava-v1.6-mistral-7b" 209 | text = llava_util.run_llava(model, None, prompt, [img], device=device, max_new_tokens=128) 210 | for i, ch in enumerate(text): 211 | if ch == "1" or ch == "2": 212 | return int(ch) 213 | logging.info("wtf bad output", text) 214 | raise "error" 215 | 216 | def llava_vlm_judge(criteria, prompts, b64_images_a, b64_images_b, device): 217 | for (prompt, img_a, img_b) in zip(prompts, b64_images_a, b64_images_b): 218 | img_combine = combine_pil(img_a, img_b) 219 | prompt = f"You are a judge in an image generation contest. {criteria} '1' for the image on the left, '2' for the image on the right. Answer only '1'(left) or '2'(right). This is automated and the first number in your answer will be chosen." 220 | return llava_vlm_decide(prompt, img_combine, device) 221 | 222 | def llava_vlm_judge_with_retry(*args, max_retries=3): 223 | for i in range(max_retries): 224 | try: 225 | return llava_vlm_judge(*args) 226 | except Exception as e: 227 | if i < max_retries: 228 | logging.exception("Llava did not give output. Retrying...") 229 | else: 230 | logging.exception("Llava failed!") 231 | raise 232 | 233 | def compare(cache, criteria, device, evals, metrics, vlm, settings): 234 | async def vlm_compare(a: evolve.Candidate, b:evolve.Candidate): 235 | cache_key = 'compare:'+a.file_path+'.'+b.file_path 236 | if cache_key in cache: 237 | return cache[cache_key] 238 | reverse = random.random() > 0.5 239 | prompts = [evl["prompt"] for evl in evals] 240 | if reverse: 241 | a, b = b, a 242 | 243 | if vlm == 'claude': 244 | b64_images_a = generate_b64_images(a.file_path, evals, device, cache, settings) 245 | b64_images_b = generate_b64_images(b.file_path, evals, device, cache, settings) 246 | judgement = claude_vlm_judge_with_retry(criteria, prompts, b64_images_a, b64_images_b) 247 | elif vlm == 'llava': 248 | images_a = generate_images(a.file_path, evals, device, cache, settings) 249 | images_b = generate_images(b.file_path, evals, device, cache, settings) 250 | judgement = llava_vlm_judge_with_retry(criteria, prompts, images_a, images_b, device) 251 | else: 252 | raise "vlm not supported:" + vlm 253 | 254 | if reverse: 255 | judgement = (1 if judgement == 2 else 2) 256 | metrics.total += 1 257 | 258 | if judgement == 1: 259 | metrics.yays += 1 260 | else: 261 | metrics.nays += 1 262 | logging.info(f"Number of comparisons Total: {metrics.total} Yay: {metrics.yays} Nay: {metrics.nays}") 263 | 264 | 265 | if judgement == 1: 266 | cache[cache_key] = 1 267 | return 1 268 | cache[cache_key] = -1 269 | return -1 270 | return vlm_compare 271 | 272 | @dataclass 273 | class Metrics: 274 | total: int = 0 275 | yays: int = 0 276 | nays: int = 0 277 | 278 | @dataclass 279 | class DiffusionSettings: 280 | guidance_scale: int 281 | negative_prompt: str 282 | append_prompt: str 283 | diffusion_steps: int 284 | width: int 285 | height: int 286 | resize_width: int 287 | resize_height: int 288 | scheduler: str 289 | vae: str 290 | 291 | async def main(): 292 | # Parse command-line arguments 293 | args = parse_arguments() 294 | if args.seed is not None: 295 | 
torch.manual_seed(args.seed) 296 | os.makedirs(args.output_path, exist_ok=True) 297 | metrics = Metrics() 298 | cache = {} 299 | evals = load_random_evals(args.eval_file, args.eval_samples) 300 | settings = DiffusionSettings( 301 | append_prompt = args.append_prompt, 302 | diffusion_steps = args.diffusion_steps, 303 | guidance_scale = args.guidance_scale, 304 | height = args.height, 305 | negative_prompt = args.negative_prompt, 306 | resize_height = args.resize_height, 307 | resize_width = args.resize_width, 308 | width = args.width, 309 | vae = args.vae, 310 | scheduler = args.scheduler 311 | ) 312 | initial_population = evolve.load_candidates(args.model_list) 313 | population = list(initial_population) 314 | evolve.write_yaml(population, Path(args.output_path) / "initial.yaml") 315 | logging.info("Beginning evolution") 316 | 317 | async for i in tqdm(range(args.cycles), desc='Evolving'): 318 | if args.diffusion_prompt_change == "every_cycle": 319 | evals = load_random_evals(args.eval_file, args.eval_samples) 320 | cache = {} 321 | comparator = compare(cache, args.criteria, args.device, evals, metrics, args.vlm, settings) 322 | population = await evolve.run_evolution(population, args.elite, args.parents, args.population, args.mutation, args.output_path, comparator) 323 | evolve.write_yaml(population, Path(args.output_path) / f"step-{i}.yaml") 324 | if random.random() < args.reintroduction_threshold: 325 | population.append(random.choice(initial_population)) 326 | 327 | logging.info("Resulting population:") 328 | evolve.log_candidates(population) 329 | if __name__ == "__main__": 330 | asyncio.run(main()) 331 | --------------------------------------------------------------------------------