├── README.md
├── evolve.py
├── images
│   ├── after_evolution_sample1.png
│   ├── after_evolution_sample2.png
│   ├── after_evolution_sample3.png
│   ├── before_evolution_sample1.png
│   ├── before_evolution_sample2.png
│   └── before_evolution_sample3.png
├── llava_util.py
├── merge.py
├── requirements.txt
├── scripts
│   └── eval-model.py
├── sd3-evolve.py
└── sdxl-evolve.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Diffusion Evolver
3 |
4 | This project is an evolutionary framework for optimizing Stable Diffusion XL models. It evolves a population of models through crossover, mutation, and selection, with fitness judged by a vision-language model (VLM).
5 |
6 | ## Examples
7 |
8 | | After Evolution | Before Evolution | Model |
9 | |------------------|-----------------| ----- |
10 | | ![After evolution sample 1](images/after_evolution_sample1.png) | ![Before evolution sample 1](images/before_evolution_sample1.png) | Model A * |
11 | | ![After evolution sample 2](images/after_evolution_sample2.png) | ![Before evolution sample 2](images/before_evolution_sample2.png) | Model B * |
12 | | ![After evolution sample 3](images/after_evolution_sample3.png) | ![Before evolution sample 3](images/before_evolution_sample3.png) | Model C * |
13 |
14 | After 10 cycles with the prompt "T-Rex wearing aviator sunglasses, posing in front of a diffusion-generated Jurassic landscape, 80s vaporwave style, 4K".
15 |
16 | \* Images are paired with the closest model output from the initial population. All images share the same seed and diffusion settings. The VLM was not shown the prompt and default settings were used.
17 |
18 | \* Demo models are available here [https://huggingface.co/martyn/sdxl-evolved-demo-models](https://huggingface.co/martyn/sdxl-evolved-demo-models)
19 |
20 | ## Models
21 |
22 | More models are available here [https://huggingface.co/collections/martyn/evolved-sdxl-models-660b9185df88d3dbac68c052](https://huggingface.co/collections/martyn/evolved-sdxl-models-660b9185df88d3dbac68c052)
23 |
24 | and on Civit.ai [https://civitai.com/user/chandro/models](https://civitai.com/user/chandro/models)
25 |
26 | ## Installation
27 |
28 | 1. Clone the repository:
29 | ```
30 | git clone https://github.com/255BITS/diffusion-evolver
31 | ```
32 |
33 | 2. Install the required dependencies:
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 | 3. Set up the necessary API credentials:
39 | - Obtain an API key from Anthropic.
40 | - Set the `ANTHROPIC_API_KEY` environment variable with your API key.
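
For example, on Linux or macOS the key can be exported in your shell before running the scripts (the value below is a placeholder):

```
export ANTHROPIC_API_KEY="your-api-key-here"
```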
41 |
42 | ## System requirements
43 |
44 | ### Minimum
45 |
46 | * A system capable of running inference on Stable Diffusion XL
47 | * Enough hard disk space: roughly 7 GB × (population_size + initial_population). For example, a population of 20 plus 6 initial models needs about 7 GB × 26 ≈ 182 GB.
48 |
49 | ### Recommended
50 |
51 | * A modern GPU
52 | * A lot of RAM (> 128 GB); more RAM allows for a larger population size
53 | * A ramdisk to store the evolved population (see the example below)
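
For example, on Linux a tmpfs ramdisk can be mounted over the output directory before a run (the size and mount point below are illustrative):

```
sudo mount -t tmpfs -o size=200G tmpfs evolve_output
```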
54 |
55 | ## Usage
56 |
57 | To run the evolutionary framework, use the following command:
58 |
59 | ```
60 | python sdxl-evolve.py model_list.yml [options]
61 | ```
62 |
63 | The `model_list.yml` file should contain a list of initial model candidates in YAML format.
64 |
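The layout matches what `load_candidates` in `evolve.py` expects: a top-level `models` list in which each entry names a `.safetensors` checkpoint and may optionally pin `p`, `lambda`, `generation`, and `seed` (the paths below are placeholders):

```
models:
  - model: /path/to/model_a.safetensors
  - model: /path/to/model_b.safetensors
    p: 0.3        # optional DARE drop probability; randomized when omitted
    lambda: 1.5   # optional DARE rescale factor; randomized when omitted
```
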
65 | Available options:
66 | - `-seed`: Random seed for reproducibility.
67 | - `-cycles`: Number of evolutionary cycles to run (default: 10).
68 | - `-elite`: Number of elite candidates to keep in each iteration (default: 5).
69 | - `-parents`: Number of parents for each child (default: 2).
70 | - `-population`: Size of the population (default: 20).
71 | - `-mutation`: Chance of mutation (default: 0.05).
72 | - `-output_path`: Directory to save the results (default: "evolve_output").
73 | - `-eval_file`: Text file containing a newline-delimited list of evaluation prompts (default: "evals.txt"); see the example below.
74 | - `-eval_samples`: Number of samples to evaluate between candidates (default: 2).
75 | - `-vlm`: The VLM to use: claude (default) or llava.
76 | - `-append_prompt`: Text appended to the end of each eval prompt.
77 | - `-negative_prompt`: Negative prompt for diffusion sampling.
78 | - `-guidance_scale`: Guidance scale for diffusion sampling.
79 | - `-diffusion_steps`: Number of diffusion steps to run during evaluation.
80 | - `-width`: Generation width.
81 | - `-height`: Generation height.
82 | - `-resize_width`: Width to resize images to after generation.
83 | - `-resize_height`: Height to resize images to after generation.
84 |
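The evaluation file is a plain text file with one prompt per line; prompts are sampled at random from it during evaluation. For example:

```
a macro photo of a dew-covered spider web at sunrise
a watercolor painting of a lighthouse in a storm
an isometric illustration of a tiny cyberpunk city
```

A typical invocation combining these options might look like this (all values are illustrative):

```
python sdxl-evolve.py model_list.yml -cycles 10 -population 20 -elite 5 \
    -eval_file evals.txt -eval_samples 2 -vlm claude \
    -diffusion_steps 8 -guidance_scale 1 -width 1024 -height 1024
```
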
85 | ## Documentation
86 |
87 | The framework consists of the following main components:
88 |
89 | - `evolve.py`: Defines the core evolutionary algorithm, including candidate representation, selection, crossover, mutation, and population management.
90 | - `merge.py`: Provides functions for merging SafeTensors checkpoints into new model candidates using DARE (drop-and-rescale; see the sketch below).
91 | - `sdxl-evolve.py` / `sd3-evolve.py`: The main scripts that orchestrate the evolutionary process for SDXL and SD3 models, including image generation, evaluation, and comparison using the VLM.
92 |
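For reference, the per-tensor math in `merge.py` boils down to the following drop-and-rescale step (a minimal sketch; `dare_merge` is an illustrative name, the real logic lives in `merge_tensors` and `merge_safetensors`):

```
import torch

def dare_merge(theta_a, theta_b, p, lambda_val):
    # Difference ("task vector") between the two parent tensors
    delta = theta_b - theta_a
    # Keep each element of the delta with probability p, zero out the rest
    mask = torch.bernoulli(torch.full(delta.shape, p)).to(delta.dtype)
    # Rescale the surviving delta by 1/(1 - p) and blend it back into the first parent
    return theta_a + lambda_val * (mask * delta) / (1 - p)
```
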
93 | ### Details
94 |
95 | - `Candidate`: Represents a model candidate with its file path, initial-population flag, DARE drop probability `p`, and rescale factor `lambda`.
96 | - `selection`: Selects a subset of candidates as parents for breeding.
97 | - `mutation`: Applies random mutations to an offspring candidate.
98 | - `breed`: Performs crossover and mutation to create a new offspring candidate.
99 | - `evolve`: Evolves the population by selecting parents, breeding offspring, and updating the population.
100 | - `run_evolution`: Runs the evolutionary process for a specified number of cycles.
101 | - `load_candidates`: Loads initial model candidates from a YAML file.
102 | - `write_yaml`: Writes the population to a YAML file.
103 | - `generate_images`: Generates images using a Stable Diffusion XL pipeline for evaluation.
104 | - `vlm_judge`: Uses the chosen VLM (Claude or LLaVA) to compare and judge the quality of generated images.
105 |
106 | ## References
107 |
108 | - sakana.ai, Evolving New Foundation Models: Unleashing the Power of Automating Model Development: [https://sakana.ai/evolutionary-model-merge/](https://sakana.ai/evolutionary-model-merge/)
109 | - DARE, Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch: [https://arxiv.org/pdf/2311.03099.pdf](https://arxiv.org/pdf/2311.03099.pdf)
110 | - Stable Diffusion XL: [https://github.com/CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
111 | - Anthropic API: [https://www.anthropic.com](https://www.anthropic.com)
112 | - SafeTensors: [https://github.com/huggingface/safetensors](https://github.com/huggingface/safetensors)
113 |
114 | Feel free to contribute, report issues, or suggest improvements!
115 |
--------------------------------------------------------------------------------
/evolve.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import merge
3 | import numpy as np
4 | import os
5 | import random
6 | import torch
7 | import uuid
8 | import yaml
9 | import sys
10 | from safetensors.torch import save_file, safe_open
11 |
12 | from pathlib import Path
13 |
14 | class Candidate:
15 | def __init__(self, file_path, p, lambda_val, initial_population=False, generation=0, seed=None):
16 | self.file_path = file_path
17 | self.initial_population = initial_population
18 | self.p = p
19 | self.lambda_val = lambda_val
20 | self.generation = generation
21 |         # Use the provided seed, or draw a random one when none is given
22 |         self.seed = seed if seed is not None else random.randint(-sys.maxsize-1, sys.maxsize)
23 | if initial_population:
24 | rand_point = torch.randn(4)
25 | self.location = rand_point / torch.norm(rand_point)
26 |
27 | def to_dict(self):
28 | return {
29 | "model": self.file_path,
30 | "p": self.p,
31 | "lambda": self.lambda_val,
32 | "generation": self.generation,
33 | "seed": self.seed
34 | }
35 |
36 | def random_p():
37 |     return (random.random() / 2.0) + 0.1  # DARE drop probability p, uniform in [0.1, 0.6)
38 |
39 | def random_lambda():
40 |     return random.random() * 2.5 + 0.5  # DARE rescale factor lambda, uniform in [0.5, 3.0)
41 |
42 | def calculate_diversity_scores(candidates):
43 | locations = torch.stack([candidate.location for candidate in candidates])
44 | centroid = torch.mean(locations, dim=0)
45 | distances = torch.norm(locations - centroid, dim=1)
46 | return distances
47 |
48 | def adjust_selection_probabilities(distances):
49 | # Normalize distances to get a base probability, ensuring not overwhelming influence
50 | min_dist, max_dist = distances.min(), distances.max()
51 |     if max_dist - min_dist < 0.001:  # all candidates are (nearly) equidistant; fall back to uniform selection
52 | return [1.0/len(distances) for d in distances]
53 | # Simple linear transformation to ensure max 30% selection chance increase
54 | adjusted_probs = (distances - min_dist) / (max_dist - min_dist) * 0.3 + 0.7
55 | adjusted_probs /= adjusted_probs.sum() # Normalize to ensure it sums to 1
56 | return adjusted_probs
57 |
58 | def selection(population, num_parents):
59 | logging.info("Selecting candidates.")
60 | distances = calculate_diversity_scores(population)
61 | adjusted_probs = adjust_selection_probabilities(distances)
62 | selected_indices = np.random.choice(range(len(population)), size=num_parents, replace=False, p=adjusted_probs)
63 |
64 | return [population[i] for i in selected_indices]
65 |
66 | def perturb_tensor_map(tensor_map):
67 | for key, value in tensor_map.items():
68 | if 'diffusion_model' in key:
69 | tensor_map[key] = value + torch.normal(torch.zeros_like(value), value.std() * 0.01)
70 | return tensor_map
71 |
72 | def perturb(candidate):
73 | tensor_map = {}
74 | with safe_open(candidate.file_path, framework="pt", device="cpu") as f1:
75 | for key in f1.keys():
76 | if 'diffusion_model' in key:
77 | v = f1.get_tensor(key)
78 | tensor_map[key] = v + torch.normal(torch.zeros_like(v), v.std() * 0.01)
79 | else:
80 | tensor_map[key] = f1.get_tensor(key)
81 | return tensor_map
82 |
83 | def mutation(offspring):
84 | offspring.p = random_p()
85 | offspring.lambda_val = random_lambda()
86 |
87 | def breed(parents, mutation_rate, output_path):
88 | logging.info("Crossover and mutation...")
89 | file_path = str(Path(output_path) / (str(uuid.uuid4())+".safetensors"))
90 | offspring = Candidate(file_path, parents[0].p, parents[0].lambda_val)
91 | mutation_event = random.random() <= mutation_rate
92 | if mutation_event:
93 | mutation(offspring)
94 | tensor_map = merge.merge_safetensors(parents[0].file_path, parents[1].file_path, offspring.p, offspring.lambda_val)
95 | mutation_event = random.random() <= mutation_rate
96 | if mutation_event:
97 | tensor_map = perturb_tensor_map(tensor_map)
98 |
99 |     for parent in parents[2:]:
100 |         save_file(tensor_map, offspring.file_path)  # persist the intermediate merge so it can be re-read
101 |         tensor_map = merge.merge_safetensors(offspring.file_path, parent.file_path, offspring.p, offspring.lambda_val)
102 |
103 | offspring.generation = max([parent.generation for parent in parents]) + 1
104 | offspring.location = torch.mean(torch.stack([parent.location for parent in parents]), dim=0)
105 |
106 | logging.info(f"Saving to {offspring.file_path}, from {','.join([p.file_path for p in parents])} p={offspring.p} λ={offspring.lambda_val} gen={offspring.generation}")
107 | save_file(tensor_map, offspring.file_path)
108 | del tensor_map
109 | return offspring
110 |
111 | def evolve(population, population_size, num_parents, mutation_rate, output_path, children_count=1):
112 | seed_population = list(population)
113 | while len(population) < population_size:
114 | parents = selection(seed_population, num_parents)
115 | for i in range(min(children_count, population_size - len(population))):
116 | offspring = breed(parents, mutation_rate, output_path)
117 | population.append(offspring)
118 |
119 | return population
120 |
121 | async def correct_insert_element(item, sorted_list, compare, top_k):
122 | if not sorted_list:
123 | return [item]
124 | # find a place for insertion
125 | insert_pos = await find_insertion_point(item, sorted_list, compare, top_k)
126 | # insert item tentatively
127 | sorted_list.insert(insert_pos, item)
128 | return sorted_list
129 |
130 | async def find_insertion_point(item, sorted_list, compare, top_k):
131 | # binary search variant that accounts for potential comparison errors
132 | low, high = 0, len(sorted_list) - 1
133 | while low <= high:
134 | if low > top_k and top_k > 0:
135 | return low
136 | mid = (low + high) // 2
137 | result = await compare(item, sorted_list[mid])
138 | # adjust binary search based on comparison, considering potential inaccuracies
139 | if result == 1:
140 | high = mid - 1
141 | else:
142 | low = mid + 1
143 | return low
144 |
145 | async def sort_with_correction(buffer, compare, top_k=-1):
146 | sorted_list = []
147 | for item in buffer:
148 | sorted_list = await correct_insert_element(item, sorted_list, compare, top_k)
149 | # correction mechanism here
150 | sorted_list = await correction_pass(sorted_list)
151 | return sorted_list
152 |
153 | async def correction_pass(sorted_list):
154 | # implement a correction pass, potentially re-comparing elements
155 | # this could involve heuristic-based swaps or reinsertions
156 | return sorted_list
157 |
158 | def choose_first_occurrence(s, opta, optb):
159 | # find the index of a and b
160 | index_a = s.find(opta)
161 | index_b = s.find(optb)
162 | # check if both a and b are found
163 | if index_a != -1 and index_b != -1:
164 | # return the one that occurs first
165 | if index_a < index_b:
166 | return opta
167 | else:
168 | return optb
169 | elif index_a != -1:
170 | # only a is found
171 | return opta
172 | elif index_b != -1:
173 | # only b is found
174 | return optb
175 | else:
176 | # neither a nor b is found
177 |         return None
178 |
179 | async def run_evolution(population, elite_size, num_parents, population_size, mutation_rate, output_path, evaluation_criteria):
180 | logging.info("Before evolve")
181 | log_candidates(population)
182 | population = evolve(population, population_size, num_parents, mutation_rate, output_path)
183 |
184 | logging.info("Before sorting")
185 | log_candidates(population)
186 | population = await sort_with_correction(population, evaluation_criteria)
187 | logging.info("After sorting")
188 | log_candidates(population)
189 | for tokill in population[elite_size:]:
190 | if not tokill.initial_population:
191 | os.remove(tokill.file_path)
192 | return population[:elite_size]
193 |
194 | def log_candidates(population):
195 |     format_str = "{0}. {1:<24}"
196 | for index, candidate in enumerate(population, start=1):
197 | logging.info(format_str.format(index, candidate.file_path))
198 |
199 | def load_candidates(file_path):
200 | candidates = []
201 | with open(file_path, 'r') as file:
202 | data = yaml.safe_load(file)
203 | for candidate_data in data["models"]:
204 | p = candidate_data.get('p', random_p())
205 | lambda_val = candidate_data.get("lambda", random_lambda())
206 | generation = candidate_data.get("generation", 0)
207 | seed = candidate_data.get("seed", None)
208 | candidate = Candidate(candidate_data['model'], p=p, lambda_val=lambda_val, initial_population=True, generation=generation, seed=seed)
209 | candidates.append(candidate)
210 | return candidates
211 |
212 | def write_yaml(population, path):
213 | yaml_str = yaml.dump({"models": [c.to_dict() for c in population]}, sort_keys=False)
214 |
215 | # Write the YAML string to a file
216 | with open(path, "w") as file:
217 | file.write(yaml_str)
218 |
--------------------------------------------------------------------------------
/images/after_evolution_sample1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample1.png
--------------------------------------------------------------------------------
/images/after_evolution_sample2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample2.png
--------------------------------------------------------------------------------
/images/after_evolution_sample3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/after_evolution_sample3.png
--------------------------------------------------------------------------------
/images/before_evolution_sample1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample1.png
--------------------------------------------------------------------------------
/images/before_evolution_sample2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample2.png
--------------------------------------------------------------------------------
/images/before_evolution_sample3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/255BITS/diffusion-evolver/10e19e4a2db73099f3c3aa591b724cc6a8616dec/images/before_evolution_sample3.png
--------------------------------------------------------------------------------
/llava_util.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is modified from https://github.com/haotian-liu/LLaVA/
3 |
4 | import argparse
5 | import torch
6 |
7 | from transformers import TextStreamer
8 |
9 | from llava.constants import (
10 | IMAGE_TOKEN_INDEX,
11 | DEFAULT_IMAGE_TOKEN,
12 | DEFAULT_IM_START_TOKEN,
13 | DEFAULT_IM_END_TOKEN,
14 | IMAGE_PLACEHOLDER,
15 | )
16 | from llava.conversation import conv_templates, SeparatorStyle
17 | from llava.model.builder import load_pretrained_model
18 | from llava.utils import disable_torch_init
19 | from llava.mm_utils import (
20 | process_images,
21 | tokenizer_image_token,
22 | get_model_name_from_path,
23 | KeywordsStoppingCriteria,
24 | )
25 |
26 | import requests
27 | from PIL import Image
28 | from io import BytesIO
29 | import re
30 |
31 |
32 | def run_llava(model_path, conv_mode, query, images, sep=",", temperature=0.2, top_p=None, num_beams=1, max_new_tokens=512, model_base=None, device="cuda:0"):
33 | args = argparse.Namespace(
34 | model_path=model_path,
35 | model_base=model_base,
36 | temperature=temperature,
37 | top_p=top_p,
38 | num_beams=num_beams,
39 | max_new_tokens=max_new_tokens,
40 | conv_mode=conv_mode,
41 | query=query,
42 | images=images,
43 | sep=sep,
44 | device=device
45 | )
46 |
47 | return eval_model(args)
48 |
49 | def load_image(image_file):
50 | if image_file.startswith("http") or image_file.startswith("https"):
51 | response = requests.get(image_file)
52 | image = Image.open(BytesIO(response.content)).convert("RGB")
53 | else:
54 | image = Image.open(image_file).convert("RGB")
55 | return image
56 |
57 | def load_images(image_files):
58 | out = []
59 | for image_file in image_files:
60 | image = load_image(image_file)
61 | out.append(image)
62 | return out
63 |
64 | llavamodel = None
65 |
66 | def eval_model(args):
67 | # Model
68 | disable_torch_init()
69 | global llavamodel
70 |
71 | model_name = get_model_name_from_path(args.model_path)
72 | if llavamodel is None:
73 | llavamodel = load_pretrained_model(
74 | args.model_path, args.model_base, model_name, device=args.device
75 | )
76 | tokenizer, model, image_processor, context_len = llavamodel
77 |
78 | qs = args.query
79 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
80 | if IMAGE_PLACEHOLDER in qs:
81 | if model.config.mm_use_im_start_end:
82 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
83 | else:
84 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
85 | else:
86 | if DEFAULT_IMAGE_TOKEN not in qs:
87 | if model.config.mm_use_im_start_end:
88 | qs = image_token_se + "\n" + qs
89 | else:
90 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
91 |
92 | if "llama-2" in model_name.lower():
93 | conv_mode = "llava_llama_2"
94 | elif "v1" in model_name.lower():
95 | conv_mode = "llava_v1"
96 | elif "mpt" in model_name.lower():
97 | conv_mode = "mpt"
98 | else:
99 | conv_mode = "llava_v0"
100 |
101 | if args.conv_mode is not None and conv_mode != args.conv_mode:
102 | pass
103 | #print(
104 | # "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
105 | # conv_mode, args.conv_mode, args.conv_mode
106 | # )
107 | #)
108 | else:
109 | args.conv_mode = conv_mode
110 |
111 | conv = conv_templates[args.conv_mode].copy()
112 | conv.append_message(conv.roles[0], qs)
113 | conv.append_message(conv.roles[1], None)
114 | prompt = conv.get_prompt()
115 |
116 | images_tensor = process_images(
117 | args.images,
118 | image_processor,
119 | model.config
120 | ).to(args.device, dtype=torch.float16)
121 |
122 | input_ids = (
123 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
124 | .unsqueeze(0)
125 | .to(args.device)
126 | )
127 |
128 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
129 | keywords = [stop_str]
130 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
131 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
132 |
133 | with torch.inference_mode():
134 | output_ids = model.generate(
135 | input_ids,
136 | images=images_tensor,
137 | do_sample=True if args.temperature > 0 else False,
138 | temperature=args.temperature,
139 | #top_p=args.top_p,
140 | #num_beams=args.num_beams,
141 | max_new_tokens=args.max_new_tokens,
142 | use_cache=True,
143 | #stopping_criteria=[stopping_criteria],
144 | streamer=streamer,
145 | image_sizes=[image.size for image in args.images]
146 | )
147 |
148 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
149 |
150 | return outputs
151 |
152 |
153 | if __name__ == "__main__":
154 | parser = argparse.ArgumentParser()
155 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
156 | parser.add_argument("--model-base", type=str, default=None)
157 | parser.add_argument("--image-file", type=str, required=True)
158 | parser.add_argument("--query", type=str, required=True)
159 | parser.add_argument("--conv-mode", type=str, default=None)
160 | parser.add_argument("--sep", type=str, default=",")
161 | parser.add_argument("--temperature", type=float, default=0.2)
162 | parser.add_argument("--top_p", type=float, default=None)
163 | parser.add_argument("--num_beams", type=int, default=1)
164 | parser.add_argument("--max_new_tokens", type=int, default=512)
165 |     args = parser.parse_args()
166 |     args.images = load_images(args.image_file.split(args.sep))  # eval_model expects loaded PIL images
167 |     args.device = "cuda:0"
168 |     eval_model(args)
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import torch
5 | import torch.nn.functional as F
6 | from safetensors.torch import safe_open, save_file
7 |
8 | def merge_tensors(tensor1, tensor2, p):
9 | # Calculate the delta of the weights
10 | delta = tensor2 - tensor1
11 | # Generate the mask m^t from Bernoulli distribution
12 | m = torch.bernoulli(torch.full(delta.shape, p)).to(delta.dtype)
13 | # Apply the mask to the delta to get δ̃^t
14 | delta_tilde = m * delta
15 |     # Rescale the masked delta by 1/(1 - p) to get δ̂^t
16 | delta_hat = delta_tilde / (1 - p)
17 | return delta_hat
18 |
19 | def merge_safetensors(file_path1, file_path2, p, lambda_val):
20 | merged_tensors = {}
21 |
22 | with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2:
23 | keys1 = set(f1.keys())
24 | keys2 = set(f2.keys())
25 | common_keys = keys1.intersection(keys2)
26 |
27 | for key in common_keys:
28 | tensor1 = f1.get_tensor(key)
29 | tensor2 = f2.get_tensor(key)
30 | merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)
31 |
32 | return merged_tensors
33 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.28.0
2 | annotated-types==0.6.0
3 | anthropic==0.21.3
4 | anyio==4.3.0
5 | certifi==2024.2.2
6 | charset-normalizer==3.3.2
7 | diffusers==0.27.2
8 | distro==1.9.0
9 | filelock==3.13.3
10 | fsspec==2024.3.1
11 | h11==0.14.0
12 | httpcore==1.0.5
13 | httpx==0.27.0
14 | huggingface-hub==0.22.1
15 | idna==3.6
16 | importlib_metadata==7.1.0
17 | Jinja2==3.1.3
18 | MarkupSafe==2.1.5
19 | mpmath==1.3.0
20 | networkx==3.2.1
21 | numpy==1.26.4
22 | nvidia-cublas-cu12==12.1.3.1
23 | nvidia-cuda-cupti-cu12==12.1.105
24 | nvidia-cuda-nvrtc-cu12==12.1.105
25 | nvidia-cuda-runtime-cu12==12.1.105
26 | nvidia-cudnn-cu12==8.9.2.26
27 | nvidia-cufft-cu12==11.0.2.54
28 | nvidia-curand-cu12==10.3.2.106
29 | nvidia-cusolver-cu12==11.4.5.107
30 | nvidia-cusparse-cu12==12.1.0.106
31 | nvidia-nccl-cu12==2.19.3
32 | nvidia-nvjitlink-cu12==12.4.99
33 | nvidia-nvtx-cu12==12.1.105
34 | packaging==24.0
35 | peft==0.10.0
36 | pillow==10.2.0
37 | psutil==5.9.8
38 | pydantic==2.6.4
39 | pydantic_core==2.16.3
40 | PyYAML==6.0.1
41 | regex==2023.12.25
42 | requests==2.31.0
43 | safetensors==0.4.2
44 | setuptools==69.2.0
45 | sniffio==1.3.1
46 | sympy==1.12
47 | tokenizers==0.15.2
48 | torch==2.2.1
49 | tqdm==4.66.2
50 | transformers==4.39.1
51 | typing_extensions==4.10.0
52 | urllib3==2.2.1
53 | wheel==0.43.0
54 | zipp==3.18.1
55 |
--------------------------------------------------------------------------------
/scripts/eval-model.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import base64
3 | import io
4 | import json
5 | import random
6 | import re
7 | import requests
8 | import sys
9 | import time
10 | import torch
11 | import os
12 |
13 | from PIL import Image, PngImagePlugin
14 |
15 | NUM_IMAGES=128
16 | random.seed(238)
17 |
18 | def load_random_evals(file_path, count):
19 | evals = []
20 |
21 | with open(file_path, 'r') as file:
22 | lines = file.readlines()
23 |
24 | count = min(count, len(lines))
25 | selected_lines = random.sample(lines, count)
26 |
27 | for line in selected_lines:
28 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)})
29 |
30 | return evals
31 |
32 | def load_file_and_return_random_line(file_path):
33 | with open(file_path, 'r') as file:
34 | file_map = file.read().split('\n')
35 |
36 | line = random.choice(file_map)
37 | return line
38 |
39 | def wildcard_replace(s, directory):
40 | if directory is None:
41 | return s
42 | wildcards = re.findall(r'__(.*?)__', s)
43 | replaced = [load_file_and_return_random_line(directory+"/"+w+".txt") for w in wildcards]
44 | replacements = dict(zip(wildcards, replaced))
45 | for wildcard, replacement in replacements.items():
46 | s = s.replace('__{}__'.format(wildcard), replacement)
47 |
48 | return s
49 |
50 | def generate_image(prompt, negative_prompt, config_file=None, fname=None):
51 | seed = random.randint(0, 2**32-1)
52 | if config_file is None:
53 | with open("txt2img/sdwebui_config.json", 'r') as file:
54 | config_file = json.load(file)
55 | headers = {"Content-Type": "application/json"}
56 | data = dict(config_file)
57 |
58 | prompt = wildcard_replace(prompt, "wildcards")
59 | data["prompt"]=prompt
60 | data["negative_prompt"]=negative_prompt
61 |
62 | data["seed"]=seed
63 | url = data["sd_webui_url"]
64 | del data["sd_webui_url"]
65 |
66 | response = requests.post(url, headers=headers, data=json.dumps(data))
67 |
68 | if response.status_code == 200:
69 | r = response.json()
70 | image = Image.open(io.BytesIO(base64.b64decode(r['images'][0].split(",",1)[0])))
71 | jsoninfo = json.loads(r['info'])
72 | #print(jsoninfo["infotexts"][0])
73 | png_payload = {
74 | "image": "data:image/png;base64," + r['images'][0]
75 | }
76 | response2 = requests.post(url=url.replace("txt2img", "png-info"), json=png_payload)
77 |
78 | pnginfo = PngImagePlugin.PngInfo()
79 | pnginfo.add_text("parameters", response2.json().get("info"))
80 |
81 |         if fname is None:
82 |             fname = f"eval-{random.getrandbits(32):08x}.png"  # fallback file name when none is provided
83 | image.save(fname, pnginfo=pnginfo)
84 | return fname
85 |
86 | else:
87 | print(f"Request failed with status code {response.status_code}")
88 | return generate_image(prompt, negative_prompt, config_file, fname)
89 |
90 | def run_eval(working_dir):
91 | for i in range(NUM_IMAGES):
92 | config = {
93 | "sd_webui_url": "http://localhost:3000/sdapi/v1/txt2img",
94 | "height": 1152,
95 | "width": 896,
96 | "sampler_name": "Euler",
97 | "scheduler": "SGM Uniform",
98 | "cfg_scale": 1,
99 | "steps": 8
100 | }
101 |         prompt = "__person__, __sdprompt__, __bg__"  # wildcards are expanded by wildcard_replace
102 | img = generate_image(prompt, "nsfw", config, working_dir+"/"+str(i)+".png")
103 |
104 | async def main(model, working_dir="."):
105 | subdir = os.path.join(working_dir, "evals", model.split("/")[-1].split(".")[0], "imgs")
106 | os.makedirs(subdir, exist_ok=True)
107 | run_eval(subdir)
108 |
109 |
110 | if __name__ == "__main__":
111 | if len(sys.argv) > 1:
112 | models = sys.argv[1:]
113 | else:
114 |         print("Usage: eval-model.py <model> [model ...]")
115 | sys.exit(1)
116 | for model in models:
117 | asyncio.run(main(model))
118 |
119 |
--------------------------------------------------------------------------------
/sd3-evolve.py:
--------------------------------------------------------------------------------
1 | import anthropic
2 | import argparse
3 | import asyncio
4 | import base64
5 | import evolve
6 | import logging
7 | import time
8 | import os
9 | import random
10 | import torch
11 | import uuid
12 |
13 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
14 |
15 | from PIL import Image
16 | from dataclasses import dataclass
17 | from diffusers import StableDiffusion3Pipeline
18 | from huggingface_hub import hf_hub_download
19 | from io import BytesIO
20 | from pathlib import Path
21 | from tqdm.asyncio import tqdm
22 |
23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24 |
25 | def parse_arguments():
26 | parser = argparse.ArgumentParser(description="Evolutionary merge")
27 | parser.add_argument('model_list', type=str, help='yml file containing list of models in the initial population')
28 | parser.add_argument('-seed', dest='seed', type=int, default=None, help='Random seed')
29 | parser.add_argument('-cycles', dest='cycles', type=int, default=10, help='Number of evolutionary cycles')
30 | parser.add_argument('-elite', dest='elite', type=int, default=5, help='Number of elite candidates to keep every iteration')
31 | parser.add_argument('-parents', dest='parents', type=int, default=2, help='Number of parents for each child')
32 | parser.add_argument('-population', dest='population', type=int, default=20, help='Size of population')
33 | parser.add_argument('-mutation', dest='mutation', type=float, default=0.05, help='Chance of mutation')
34 | parser.add_argument('-output_path', dest='output_path', type=str, default="evolve_output", help='Directory to save results')
35 | parser.add_argument('-criteria', dest='criteria', type=str, default='Which candidate generated more colorful images?', help='Criteria for decision making in the VLM.')
36 | parser.add_argument('-eval_file', dest='eval_file', type=str, default='evals.txt', help='A txt file containing a newline delimited list of prompts to evaluation against')
37 | parser.add_argument('-eval_samples', dest='eval_samples', type=int, default=2, help='The number of samples to evaluate between candidates')
38 | parser.add_argument('-device', dest='device', type=str, default="cuda:0", help='The device to run on')
39 | parser.add_argument('-reintroduction_threshold', dest='reintroduction_threshold', type=float, default=0, help='The chance to reintroduce an initial model back into the elite population. Can help with solution diversity.')
40 | parser.add_argument('-vlm', dest='vlm', type=str, default="claude", help='The vlm to use. claude or llava')
41 | parser.add_argument('-append_prompt', dest='append_prompt', type=str, default="", help='Appends to the prompt')
42 | parser.add_argument('-negative_prompt', dest='negative_prompt', type=str, default="", help='Set the negative prompt')
43 | parser.add_argument('-guidance_scale', dest='guidance_scale', type=float, default=1, help='The guidance scale to use')
44 | parser.add_argument('-scheduler', dest='scheduler', type=str, default="sgm_uniform", help='The diffusion scheduler to use')
45 | parser.add_argument('-diffusion_steps', dest='diffusion_steps', type=int, default=8, help='The number of diffusion steps to run')
46 | parser.add_argument('-diffusion_prompt_change', dest='diffusion_prompt_change', type=str, choices=["every_cycle", "never"], default="cycle", help='The type of generation cache to use. Controls when vlm image prompts are changed. Choices: never, every_cycle')
47 | parser.add_argument("-width", dest='width', type=int, default=1024, help='Width of diffusion samples to generate')
48 | parser.add_argument("-height", dest='height', type=int, default=1024, help='Height of diffusion samples to generate')
49 |     parser.add_argument("-resize_width", dest='resize_width', type=int, default=512, help='Width to resize diffusion samples to before sending to the VLM')
50 | parser.add_argument("-resize_height", dest='resize_height', type=int, default=512, help='Height to resize diffusion samples before sending to the VLM')
51 | parser.add_argument("-vae", dest='vae', type=str, default=None, help='Custom VAE to use during sampling')
52 | parser.add_argument("-perturb_seed_population", dest='perturb_seed_population', type=int, default=0, help='Build this many children of the seed population')
53 | return parser.parse_args()
54 |
55 | def generate_images(file_path, evals, device, cache, settings):
56 | if file_path in cache:
57 | return cache[file_path]
58 | images = []
59 | logging.info(f"Loading {file_path}")
60 |
61 | dtype = torch.bfloat16
62 | model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
63 |
64 | pipe = StableDiffusion3Pipeline.from_single_file(file_path, text_encoder = None, text_encoder_2 = None, text_encoder_3 = None, device=device)
65 | pipe.text_encoder = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder")
66 | pipe.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2")
67 | #pipe.text_encoder_3 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_3")
68 | pipe.tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
69 | pipe.tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
70 | #pipe.tokenizer_3 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_3")
71 | pipe = pipe.to(dtype).to(device)
72 |
73 | for i, evl in enumerate(evals):
74 | image = pipe(evl['prompt']+settings.append_prompt, width=settings.width, height=settings.height, negative_prompt=settings.negative_prompt, num_inference_steps=settings.diffusion_steps, guidance_scale=settings.guidance_scale, generator=torch.manual_seed(evl['seed'])).images[0]
75 | if settings.resize_width != settings.width:
76 | image = image.resize((settings.resize_width, settings.resize_height))
77 | image.save(f"output-evolve-{file_path.split('/')[-1]}-{i}.png")
78 | images.append(image)
79 |
80 | del pipe
81 | cache[file_path] = images
82 | return images
83 |
84 | def generate_b64_images(*args):
85 | # Assuming generate_images() returns a list of PIL Image objects
86 | images = generate_images(*args)
87 | b64_images = []
88 |
89 | for img in images:
90 | buffered = BytesIO()
91 | img.save(buffered, format="PNG")
92 | img_str = base64.b64encode(buffered.getvalue()).decode()
93 | b64_images.append(img_str)
94 |
95 | return b64_images
96 |
97 | def load_random_evals(file_path, count):
98 | evals = []
99 |
100 | with open(file_path, 'r') as file:
101 | lines = file.readlines()
102 |
103 | count = min(count, len(lines))
104 | selected_lines = random.sample(lines, count)
105 |
106 | for line in selected_lines:
107 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)})
108 |
109 | return evals
110 |
111 | def claude_vlm_judge(criteria, prompts, b64_images_a, b64_images_b):
112 | client = anthropic.Anthropic()
113 | media_type = "image/png"
114 | prompts_list = "\n".join(["{0}. {1}".format(i, prompt) for i, prompt in enumerate(prompts)])
115 | # begin_text = f"""
116 | #Here are the prompts for the generations:
117 | #```
118 | #{prompts_list}
119 | #```
120 | #
121 | #Each candidate will be given these prompts to generate images. First you will receive candidate 1 generations based on these prompts, then candidate 2.
122 | #
123 | begin_text = f"""You will first see both candidates images then judge which did better based on the following criteria:
124 | ```
125 | Criteria: {criteria}
126 | ```
127 |
128 | Candidate 1 generations:
129 | """.strip()
130 | end_text = """
131 | Which candidate won based on the criteria? If candidate 1, output '1'. If candidate 2, output '2'. This is automated, the first 1 or 2 you output will be the winner.
132 | """.strip()
133 | messages = [
134 | {
135 | "role": "user",
136 | "content": [
137 | {
138 | "type": "text",
139 | "text": begin_text
140 | },
141 | *[
142 | {
143 | "type": "image",
144 | "source": {
145 | "type": "base64",
146 | "media_type": media_type,
147 | "data": b64_image,
148 | },
149 | } for b64_image in b64_images_a],
150 | {
151 | "type": "text",
152 | "text": "Candidate 2 generations:"
153 | },
154 | *[
155 | {
156 | "type": "image",
157 | "source": {
158 | "type": "base64",
159 | "media_type": media_type,
160 | "data": b64_image,
161 | },
162 | } for b64_image in b64_images_b],
163 | {
164 | "type": "text",
165 | "text": end_text
166 | }
167 | ],
168 | }
169 | ]
170 |
171 | model = "claude-3-haiku-20240307"
172 | message = client.messages.create(
173 | model=model,
174 | max_tokens=128,
175 | system="You are diffusion evolver AI, a judge for an image generation contest. You will be presented images from two models with the same prompt and seed. At the end you will give your judgement based on a specified criteria.",
176 | messages=messages,
177 | )
178 | text = message.content[0].text
179 | for i, ch in enumerate(text):
180 | if ch == "1" or ch == "2":
181 | return int(ch)
182 |     logging.error("Unexpected VLM output: %s", text)
183 |     raise ValueError(f"VLM did not return a 1 or 2: {text}")
184 |
185 | def claude_vlm_judge_with_retry(*args, max_retries=3, initial_wait=1, max_wait=10):
186 | for attempt in range(max_retries):
187 | try:
188 | return claude_vlm_judge(*args)
189 | except Exception as e:
190 | wait_time = min(max_wait, initial_wait * 2 ** attempt)
191 | wait_time += random.uniform(0, wait_time * 0.2) # Adding random jitter
192 | if attempt < max_retries - 1:
193 | # Log the full stack trace before retrying
194 | logging.exception(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
195 | time.sleep(wait_time)
196 | else:
197 | # Log the full stack trace before raising the exception after all retries have failed
198 | logging.exception("All attempts failed. Raising exception.")
199 | raise
200 |
201 | def combine_pil(a, b):
202 | total_width = a.width + b.width
203 | max_height = max(a.height, b.height)
204 |
205 | combined_image = Image.new('RGB', (total_width, max_height))
206 |
207 | combined_image.paste(a, (0, 0))
208 | combined_image.paste(b, (a.width, 0))
209 |
210 | return combined_image
211 |
212 | def llava_vlm_decide(prompt, img, device):
213 | import llava_util
214 | model = "liuhaotian/llava-v1.6-mistral-7b"
215 | text = llava_util.run_llava(model, None, prompt, [img], device=device, max_new_tokens=128)
216 | for i, ch in enumerate(text):
217 | if ch == "1" or ch == "2":
218 | return int(ch)
219 |     logging.error("Unexpected LLaVA output: %s", text)
220 |     raise ValueError(f"LLaVA did not return a 1 or 2: {text}")
221 |
222 | def llava_vlm_judge(criteria, prompts, b64_images_a, b64_images_b, device):
223 | for (prompt, img_a, img_b) in zip(prompts, b64_images_a, b64_images_b):
224 | img_combine = combine_pil(img_a, img_b)
225 | prompt = f"You are a judge in an image generation contest. {criteria} '1' for the image on the left, '2' for the image on the right. Answer only '1'(left) or '2'(right). This is automated and the first number in your answer will be chosen."
226 | return llava_vlm_decide(prompt, img_combine, device)
227 |
228 | def llava_vlm_judge_with_retry(*args, max_retries=3):
229 | for i in range(max_retries):
230 | try:
231 | return llava_vlm_judge(*args)
232 | except Exception as e:
233 |             if i < max_retries - 1:
234 | logging.exception("Llava did not give output. Retrying...")
235 | else:
236 | logging.exception("Llava failed!")
237 | raise
238 |
239 | def compare(cache, criteria, device, evals, metrics, vlm, settings):
240 | async def vlm_compare(a: evolve.Candidate, b:evolve.Candidate):
241 | cache_key = 'compare:'+a.file_path+'.'+b.file_path
242 | if cache_key in cache:
243 | return cache[cache_key]
244 | reverse = random.random() > 0.5
245 | prompts = [evl["prompt"] for evl in evals]
246 | if reverse:
247 | a, b = b, a
248 |
249 | if vlm == 'claude':
250 | b64_images_a = generate_b64_images(a.file_path, evals, device, cache, settings)
251 | b64_images_b = generate_b64_images(b.file_path, evals, device, cache, settings)
252 | judgement = claude_vlm_judge_with_retry(criteria, prompts, b64_images_a, b64_images_b)
253 | elif vlm == 'llava':
254 | images_a = generate_images(a.file_path, evals, device, cache, settings)
255 | images_b = generate_images(b.file_path, evals, device, cache, settings)
256 | judgement = llava_vlm_judge_with_retry(criteria, prompts, images_a, images_b, device)
257 | else:
258 |             raise ValueError("vlm not supported: " + vlm)
259 |
260 | if reverse:
261 | judgement = (1 if judgement == 2 else 2)
262 | metrics.total += 1
263 |
264 | if judgement == 1:
265 | metrics.yays += 1
266 | else:
267 | metrics.nays += 1
268 | logging.info(f"Number of comparisons Total: {metrics.total} Yay: {metrics.yays} Nay: {metrics.nays}")
269 |
270 |
271 | if judgement == 1:
272 | cache[cache_key] = 1
273 | return 1
274 | cache[cache_key] = -1
275 | return -1
276 | return vlm_compare
277 |
278 | @dataclass
279 | class Metrics:
280 | total: int = 0
281 | yays: int = 0
282 | nays: int = 0
283 |
284 | @dataclass
285 | class DiffusionSettings:
286 |     guidance_scale: float
287 | negative_prompt: str
288 | append_prompt: str
289 | diffusion_steps: int
290 | width: int
291 | height: int
292 | resize_width: int
293 | resize_height: int
294 | scheduler: str
295 | vae: str
296 |
297 | async def main():
298 | # Parse command-line arguments
299 | args = parse_arguments()
300 | if args.seed is not None:
301 | torch.manual_seed(args.seed)
302 | os.makedirs(args.output_path, exist_ok=True)
303 | metrics = Metrics()
304 | cache = {}
305 | evals = load_random_evals(args.eval_file, args.eval_samples)
306 | settings = DiffusionSettings(
307 | append_prompt = args.append_prompt,
308 | diffusion_steps = args.diffusion_steps,
309 | guidance_scale = args.guidance_scale,
310 | height = args.height,
311 | negative_prompt = args.negative_prompt,
312 | resize_height = args.resize_height,
313 | resize_width = args.resize_width,
314 | width = args.width,
315 | vae = args.vae,
316 | scheduler = args.scheduler
317 | )
318 | initial_population = evolve.load_candidates(args.model_list)
319 |     initial_population_count = len(initial_population)
320 |     while len(initial_population) < args.perturb_seed_population:
321 |         parent = random.choice(initial_population[:initial_population_count])
322 | file_path = str(Path(args.output_path) / (str(uuid.uuid4())+".safetensors"))
323 | offspring = evolve.Candidate(file_path, parent.p, parent.lambda_val, initial_population=True)
324 | offspring.generation = parent.generation + 1
325 | print("perturbing from(clone)", parent.file_path)
326 | tensor_map = evolve.perturb(parent)
327 | print("saving", offspring.file_path)
328 | evolve.save_file(tensor_map, offspring.file_path)
329 | del tensor_map
330 | initial_population.append(offspring)
331 | print("--", initial_population)
332 |
333 | population = list(initial_population)
334 | evolve.write_yaml(population, Path(args.output_path) / "initial.yaml")
335 | logging.info("Beginning evolution")
336 |
337 | async for i in tqdm(range(args.cycles), desc='Evolving'):
338 | if args.diffusion_prompt_change == "every_cycle":
339 | evals = load_random_evals(args.eval_file, args.eval_samples)
340 | cache = {}
341 | comparator = compare(cache, args.criteria, args.device, evals, metrics, args.vlm, settings)
342 | population = await evolve.run_evolution(population, args.elite, args.parents, args.population, args.mutation, args.output_path, comparator)
343 | evolve.write_yaml(population, Path(args.output_path) / f"step-{i}.yaml")
344 | if random.random() < args.reintroduction_threshold:
345 | population.append(random.choice(initial_population))
346 |
347 | logging.info("Resulting population:")
348 | evolve.log_candidates(population)
349 | if __name__ == "__main__":
350 | asyncio.run(main())
351 |
--------------------------------------------------------------------------------
/sdxl-evolve.py:
--------------------------------------------------------------------------------
1 | import anthropic
2 | import argparse
3 | import asyncio
4 | import base64
5 | import evolve
6 | import logging
7 | import time
8 | import os
9 | import random
10 | import torch
11 |
12 | from PIL import Image
13 | from dataclasses import dataclass
14 | from diffusers import AutoencoderKL
15 | from diffusers import EulerDiscreteScheduler
16 | from diffusers import StableDiffusionXLPipeline
17 | from huggingface_hub import hf_hub_download
18 | from io import BytesIO
19 | from pathlib import Path
20 | from tqdm.asyncio import tqdm
21 |
22 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23 |
24 | def parse_arguments():
25 | parser = argparse.ArgumentParser(description="Evolutionary merge")
26 | parser.add_argument('model_list', type=str, help='yml file containing list of models in the initial population')
27 | parser.add_argument('-seed', dest='seed', type=int, default=None, help='Random seed')
28 | parser.add_argument('-cycles', dest='cycles', type=int, default=10, help='Number of evolutionary cycles')
29 | parser.add_argument('-elite', dest='elite', type=int, default=5, help='Number of elite candidates to keep every iteration')
30 | parser.add_argument('-parents', dest='parents', type=int, default=2, help='Number of parents for each child')
31 | parser.add_argument('-population', dest='population', type=int, default=20, help='Size of population')
32 | parser.add_argument('-mutation', dest='mutation', type=float, default=0.05, help='Chance of mutation')
33 | parser.add_argument('-output_path', dest='output_path', type=str, default="evolve_output", help='Directory to save results')
34 | parser.add_argument('-criteria', dest='criteria', type=str, default='Which candidate generated more colorful images?', help='Criteria for decision making in the VLM.')
35 | parser.add_argument('-eval_file', dest='eval_file', type=str, default='evals.txt', help='A txt file containing a newline delimited list of prompts to evaluation against')
36 | parser.add_argument('-eval_samples', dest='eval_samples', type=int, default=2, help='The number of samples to evaluate between candidates')
37 | parser.add_argument('-device', dest='device', type=str, default="cuda:0", help='The device to run on')
38 | parser.add_argument('-reintroduction_threshold', dest='reintroduction_threshold', type=float, default=0, help='The chance to reintroduce an initial model back into the elite population. Can help with solution diversity.')
39 | parser.add_argument('-vlm', dest='vlm', type=str, default="claude", help='The vlm to use. claude or llava')
40 | parser.add_argument('-append_prompt', dest='append_prompt', type=str, default="", help='Appends to the prompt')
41 | parser.add_argument('-negative_prompt', dest='negative_prompt', type=str, default="", help='Set the negative prompt')
42 | parser.add_argument('-guidance_scale', dest='guidance_scale', type=float, default=1, help='The guidance scale to use')
43 | parser.add_argument('-scheduler', dest='scheduler', type=str, default="sgm_uniform", help='The diffusion scheduler to use')
44 | parser.add_argument('-diffusion_steps', dest='diffusion_steps', type=int, default=8, help='The number of diffusion steps to run')
45 | parser.add_argument('-diffusion_prompt_change', dest='diffusion_prompt_change', type=str, choices=["every_cycle", "never"], default="cycle", help='The type of generation cache to use. Controls when vlm image prompts are changed. Choices: never, every_cycle')
46 | parser.add_argument("-width", dest='width', type=int, default=1024, help='Width of diffusion samples to generate')
47 | parser.add_argument("-height", dest='height', type=int, default=1024, help='Height of diffusion samples to generate')
48 |     parser.add_argument("-resize_width", dest='resize_width', type=int, default=512, help='Width to resize diffusion samples to before sending to the VLM')
49 | parser.add_argument("-resize_height", dest='resize_height', type=int, default=512, help='Height to resize diffusion samples before sending to the VLM')
50 | parser.add_argument("-vae", dest='vae', type=str, default=None, help='Custom VAE to use during sampling')
51 | return parser.parse_args()
52 |
53 | def generate_images(file_path, evals, device, cache, settings):
54 | if file_path in cache:
55 | return cache[file_path]
56 | images = []
57 | logging.info(f"Loading {file_path}")
58 |
59 | pipe = StableDiffusionXLPipeline.from_single_file(file_path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
60 | if settings.vae:
61 | pipe.vae = AutoencoderKL.from_single_file(settings.vae, torch_dtype=torch.float16).to(pipe.device)
62 | if settings.scheduler == "sgm_uniform":
63 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
64 | else:
65 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
66 |
67 | for i, evl in enumerate(evals):
68 | image = pipe(evl['prompt']+settings.append_prompt, width=settings.width, height=settings.height, negative_prompt=settings.negative_prompt, num_inference_steps=settings.diffusion_steps, guidance_scale=settings.guidance_scale, generator=torch.manual_seed(evl['seed'])).images[0]
69 | if settings.resize_width != settings.width:
70 | image = image.resize((settings.resize_width, settings.resize_height))
71 | image.save(f"output-evolve-{file_path.split('/')[-1]}-{i}.png")
72 | images.append(image)
73 |
74 | del pipe
75 | cache[file_path] = images
76 | return images
77 |
78 | def generate_b64_images(*args):
79 | # Assuming generate_images() returns a list of PIL Image objects
80 | images = generate_images(*args)
81 | b64_images = []
82 |
83 | for img in images:
84 | buffered = BytesIO()
85 | img.save(buffered, format="PNG")
86 | img_str = base64.b64encode(buffered.getvalue()).decode()
87 | b64_images.append(img_str)
88 |
89 | return b64_images
90 |
91 | def load_random_evals(file_path, count):
92 | evals = []
93 |
94 | with open(file_path, 'r') as file:
95 | lines = file.readlines()
96 |
97 | count = min(count, len(lines))
98 | selected_lines = random.sample(lines, count)
99 |
100 | for line in selected_lines:
101 | evals.append({"prompt": line.strip(), "seed": random.randint(0, 2**32)})
102 |
103 | return evals
104 |
105 | def claude_vlm_judge(criteria, prompts, b64_images_a, b64_images_b):
106 | client = anthropic.Anthropic()
107 | media_type = "image/png"
108 | prompts_list = "\n".join(["{0}. {1}".format(i, prompt) for i, prompt in enumerate(prompts)])
109 | # begin_text = f"""
110 | #Here are the prompts for the generations:
111 | #```
112 | #{prompts_list}
113 | #```
114 | #
115 | #Each candidate will be given these prompts to generate images. First you will receive candidate 1 generations based on these prompts, then candidate 2.
116 | #
117 | begin_text = f"""You will first see both candidates images then judge which did better based on the following criteria:
118 | ```
119 | Criteria: {criteria}
120 | ```
121 |
122 | Candidate 1 generations:
123 | """.strip()
124 | end_text = """
125 | Which candidate won based on the criteria? If candidate 1, output '1'. If candidate 2, output '2'. This is automated, the first 1 or 2 you output will be the winner.
126 | """.strip()
127 | messages = [
128 | {
129 | "role": "user",
130 | "content": [
131 | {
132 | "type": "text",
133 | "text": begin_text
134 | },
135 | *[
136 | {
137 | "type": "image",
138 | "source": {
139 | "type": "base64",
140 | "media_type": media_type,
141 | "data": b64_image,
142 | },
143 | } for b64_image in b64_images_a],
144 | {
145 | "type": "text",
146 | "text": "Candidate 2 generations:"
147 | },
148 | *[
149 | {
150 | "type": "image",
151 | "source": {
152 | "type": "base64",
153 | "media_type": media_type,
154 | "data": b64_image,
155 | },
156 | } for b64_image in b64_images_b],
157 | {
158 | "type": "text",
159 | "text": end_text
160 | }
161 | ],
162 | }
163 | ]
164 |
165 | model = "claude-3-haiku-20240307"
166 | message = client.messages.create(
167 | model=model,
168 | max_tokens=128,
169 | system="You are diffusion evolver AI, a judge for an image generation contest. You will be presented images from two models with the same prompt and seed. At the end you will give your judgement based on a specified criteria.",
170 | messages=messages,
171 | )
172 | text = message.content[0].text
173 | for i, ch in enumerate(text):
174 | if ch == "1" or ch == "2":
175 | return int(ch)
176 |     logging.error("Unexpected VLM output: %s", text)
177 |     raise ValueError(f"VLM did not return a 1 or 2: {text}")
178 |
179 | def claude_vlm_judge_with_retry(*args, max_retries=3, initial_wait=1, max_wait=10):
180 | for attempt in range(max_retries):
181 | try:
182 | return claude_vlm_judge(*args)
183 | except Exception as e:
184 | wait_time = min(max_wait, initial_wait * 2 ** attempt)
185 | wait_time += random.uniform(0, wait_time * 0.2) # Adding random jitter
186 | if attempt < max_retries - 1:
187 | # Log the full stack trace before retrying
188 | logging.exception(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
189 | time.sleep(wait_time)
190 | else:
191 | # Log the full stack trace before raising the exception after all retries have failed
192 | logging.exception("All attempts failed. Raising exception.")
193 | raise
194 |
195 | def combine_pil(a, b):
196 | total_width = a.width + b.width
197 | max_height = max(a.height, b.height)
198 |
199 | combined_image = Image.new('RGB', (total_width, max_height))
200 |
201 | combined_image.paste(a, (0, 0))
202 | combined_image.paste(b, (a.width, 0))
203 |
204 | return combined_image
205 |
206 | def llava_vlm_decide(prompt, img, device):
207 | import llava_util
208 | model = "liuhaotian/llava-v1.6-mistral-7b"
209 | text = llava_util.run_llava(model, None, prompt, [img], device=device, max_new_tokens=128)
210 | for i, ch in enumerate(text):
211 | if ch == "1" or ch == "2":
212 | return int(ch)
213 |     logging.error("Unexpected LLaVA output: %s", text)
214 |     raise ValueError(f"LLaVA did not return a 1 or 2: {text}")
215 |
216 | def llava_vlm_judge(criteria, prompts, b64_images_a, b64_images_b, device):
217 | for (prompt, img_a, img_b) in zip(prompts, b64_images_a, b64_images_b):
218 | img_combine = combine_pil(img_a, img_b)
219 | prompt = f"You are a judge in an image generation contest. {criteria} '1' for the image on the left, '2' for the image on the right. Answer only '1'(left) or '2'(right). This is automated and the first number in your answer will be chosen."
220 | return llava_vlm_decide(prompt, img_combine, device)
221 |
222 | def llava_vlm_judge_with_retry(*args, max_retries=3):
223 | for i in range(max_retries):
224 | try:
225 | return llava_vlm_judge(*args)
226 | except Exception as e:
227 |             if i < max_retries - 1:
228 | logging.exception("Llava did not give output. Retrying...")
229 | else:
230 | logging.exception("Llava failed!")
231 | raise
232 |
233 | def compare(cache, criteria, device, evals, metrics, vlm, settings):
234 | async def vlm_compare(a: evolve.Candidate, b:evolve.Candidate):
235 | cache_key = 'compare:'+a.file_path+'.'+b.file_path
236 | if cache_key in cache:
237 | return cache[cache_key]
238 | reverse = random.random() > 0.5
239 | prompts = [evl["prompt"] for evl in evals]
240 | if reverse:
241 | a, b = b, a
242 |
243 | if vlm == 'claude':
244 | b64_images_a = generate_b64_images(a.file_path, evals, device, cache, settings)
245 | b64_images_b = generate_b64_images(b.file_path, evals, device, cache, settings)
246 | judgement = claude_vlm_judge_with_retry(criteria, prompts, b64_images_a, b64_images_b)
247 | elif vlm == 'llava':
248 | images_a = generate_images(a.file_path, evals, device, cache, settings)
249 | images_b = generate_images(b.file_path, evals, device, cache, settings)
250 | judgement = llava_vlm_judge_with_retry(criteria, prompts, images_a, images_b, device)
251 | else:
252 |             raise ValueError("vlm not supported: " + vlm)
253 |
254 | if reverse:
255 | judgement = (1 if judgement == 2 else 2)
256 | metrics.total += 1
257 |
258 | if judgement == 1:
259 | metrics.yays += 1
260 | else:
261 | metrics.nays += 1
262 | logging.info(f"Number of comparisons Total: {metrics.total} Yay: {metrics.yays} Nay: {metrics.nays}")
263 |
264 |
265 | if judgement == 1:
266 | cache[cache_key] = 1
267 | return 1
268 | cache[cache_key] = -1
269 | return -1
270 | return vlm_compare
271 |
272 | @dataclass
273 | class Metrics:
274 | total: int = 0
275 | yays: int = 0
276 | nays: int = 0
277 |
278 | @dataclass
279 | class DiffusionSettings:
280 |     guidance_scale: float
281 | negative_prompt: str
282 | append_prompt: str
283 | diffusion_steps: int
284 | width: int
285 | height: int
286 | resize_width: int
287 | resize_height: int
288 | scheduler: str
289 | vae: str
290 |
291 | async def main():
292 | # Parse command-line arguments
293 | args = parse_arguments()
294 | if args.seed is not None:
295 | torch.manual_seed(args.seed)
296 | os.makedirs(args.output_path, exist_ok=True)
297 | metrics = Metrics()
298 | cache = {}
299 | evals = load_random_evals(args.eval_file, args.eval_samples)
300 | settings = DiffusionSettings(
301 | append_prompt = args.append_prompt,
302 | diffusion_steps = args.diffusion_steps,
303 | guidance_scale = args.guidance_scale,
304 | height = args.height,
305 | negative_prompt = args.negative_prompt,
306 | resize_height = args.resize_height,
307 | resize_width = args.resize_width,
308 | width = args.width,
309 | vae = args.vae,
310 | scheduler = args.scheduler
311 | )
312 | initial_population = evolve.load_candidates(args.model_list)
313 | population = list(initial_population)
314 | evolve.write_yaml(population, Path(args.output_path) / "initial.yaml")
315 | logging.info("Beginning evolution")
316 |
317 | async for i in tqdm(range(args.cycles), desc='Evolving'):
318 | if args.diffusion_prompt_change == "every_cycle":
319 | evals = load_random_evals(args.eval_file, args.eval_samples)
320 | cache = {}
321 | comparator = compare(cache, args.criteria, args.device, evals, metrics, args.vlm, settings)
322 | population = await evolve.run_evolution(population, args.elite, args.parents, args.population, args.mutation, args.output_path, comparator)
323 | evolve.write_yaml(population, Path(args.output_path) / f"step-{i}.yaml")
324 | if random.random() < args.reintroduction_threshold:
325 | population.append(random.choice(initial_population))
326 |
327 | logging.info("Resulting population:")
328 | evolve.log_candidates(population)
329 | if __name__ == "__main__":
330 | asyncio.run(main())
331 |
--------------------------------------------------------------------------------