├── README.md ├── compute_steering_vectors.py ├── construct_prompts.py ├── controller.py ├── generate_casteer.py └── imagenet_classes.txt /README.md: -------------------------------------------------------------------------------- 1 | # CASteer 2 | The code for the paper "CASteer: Steering Diffusion Models for Controllable Generation" will be here soon 3 | -------------------------------------------------------------------------------- /compute_steering_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | from PIL import Image 5 | from collections import defaultdict 6 | 7 | import torch 8 | from diffusers import StableDiffusionPipeline, DiffusionPipeline, AutoPipelineForText2Image 9 | 10 | # local imports 11 | from construct_prompts import get_prompts_concrete, get_prompts_style, get_prompts_human_related 12 | from controller import VectorStore, register_vector_control 13 | 14 | # parsing arguments 15 | import argparse 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model', type=str, choices=['sd14', 'sd21', 'sd21-turbo', 'sdxl', 'sdxl-turbo'], default="sd14") 18 | parser.add_argument('--mode', type=str, choices=['concrete', 'human-related', 'style'], default="style") 19 | parser.add_argument('--num_denoising_steps', type=int, default=50) # 50 for sd14, sd21, 1 for turbo, 30 for sdxl 20 | parser.add_argument('--concept_pos', type=str, default="anime") 21 | parser.add_argument('--concept_neg', type=str, default=None) 22 | parser.add_argument('--save_dir', type=str, default='steering_vectors') # path to saving steering vectors 23 | args = parser.parse_args() 24 | 25 | 26 | 27 | if args.model == 'sd14': 28 | pipe = StableDiffusionPipeline.from_pretrained( 29 | "CompVis/stable-diffusion-v1-4", 30 | torch_dtype=torch.float16, 31 | cache_dir='./cache' 32 | ) 33 | elif args.model == 'sd21': 34 | pipe = StableDiffusionPipeline.from_pretrained( 35 | "stabilityai/stable-diffusion-2-1", 36 | torch_dtype=torch.float16, 37 | cache_dir='./cache' 38 | ) 39 | elif args.model == 'sd21-turbo': 40 | pipe = AutoPipelineForText2Image.from_pretrained( 41 | "stabilityai/sd-turbo", 42 | torch_dtype=torch.float16, 43 | variant="fp16", 44 | cache_dir='./cache' 45 | ) 46 | elif args.model == 'sdxl': 47 | pipe = DiffusionPipeline.from_pretrained( 48 | "stabilityai/stable-diffusion-xl-base-1.0", 49 | torch_dtype=torch.float16, 50 | use_safetensors=True, 51 | variant="fp16", 52 | cache_dir='./cache' 53 | ) 54 | elif args.model == 'sdxl-turbo': 55 | pipe = AutoPipelineForText2Image.from_pretrained( 56 | "stabilityai/sdxl-turbo", 57 | torch_dtype=torch.float16, 58 | variant="fp16", 59 | cache_dir='./cache' 60 | ) 61 | 62 | 63 | def run_model(model_type, pipe, prompt, seed, num_denoising_steps): 64 | if args.model in ['sd14', 'sd21', 'sdxl']: 65 | image = pipe(prompt=prompt, 66 | num_inference_steps=num_denoising_steps, 67 | generator=torch.Generator(device=device).manual_seed(seed) 68 | ).images[0] 69 | 70 | elif args.model in ['sd21-turbo', 'sdxl-turbo']: 71 | image = pipe(prompt=prompt, 72 | num_inference_steps=num_denoising_steps, 73 | guidance_scale=0.0, 74 | generator=torch.Generator(device=device).manual_seed(seed) 75 | ).images[0] 76 | 77 | return image 78 | 79 | 80 | 81 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 82 | pipe.to(device) 83 | 84 | 85 | if args.mode == 'concrete': 86 | prompts_pos, prompts_neg = get_prompts_concrete(concept_pos=args.concept_pos, 87 | concept_neg=args.concept_neg) 88 | elif args.mode == 'human-related': 89 | prompts_pos, prompts_neg = get_prompts_human_related(concept_pos=args.concept_pos, 90 | concept_neg=args.concept_neg) 91 | elif args.mode == 'style': 92 | prompts_pos, prompts_neg = get_prompts_style(concept_pos=args.concept_pos, 93 | concept_neg=args.concept_neg) 94 | 95 | 96 | # Calculating CA outputs for generating steering vectors 97 | pos_vectors = [] 98 | neg_vectors = [] 99 | seed=0 100 | 101 | for i, (prompt_pos, prompt_neg) in enumerate(zip(prompts_pos, prompts_neg)): 102 | print('Prompt pair number', i, 'out of', len(prompts_pos)) 103 | print('Positive prompt:', prompt_pos) 104 | print('Negative prompt', prompt_neg) 105 | 106 | controller = VectorStore() 107 | controller.steer=False 108 | register_vector_control(pipe.unet, controller) 109 | 110 | image = run_model(args.model, pipe, prompt_pos, seed, args.num_denoising_steps) 111 | 112 | pos_vectors.append(controller.vector_store) 113 | 114 | controller = VectorStore() 115 | controller.steer=False 116 | register_vector_control(pipe.unet, controller) 117 | 118 | image = run_model(args.model, pipe, prompt_neg, seed, args.num_denoising_steps) 119 | 120 | neg_vectors.append(controller.vector_store) 121 | 122 | 123 | # Calculating steering vectors 124 | steering_vectors = {} 125 | 126 | for denoising_step in range(0, args.num_denoising_steps): 127 | steering_vectors[denoising_step] = defaultdict(list) 128 | 129 | for key in ['up', 'down', 'mid']: 130 | for layer_num in range(len(pos_vectors[0][denoising_step][key])): 131 | 132 | pos_vectors_layer = [pos_vectors[i][denoising_step][key][layer_num] for i in range(len(pos_vectors))] 133 | pos_vectors_avg = np.mean(pos_vectors_layer, axis=0) 134 | 135 | neg_vectors_layer = [neg_vectors[i][denoising_step][key][layer_num] for i in range(len(neg_vectors))] 136 | neg_vectors_avg = np.mean(neg_vectors_layer, axis=0) 137 | 138 | steering_vector = pos_vectors_avg - neg_vectors_avg 139 | steering_vector = steering_vector / np.linalg.norm(steering_vector) 140 | 141 | steering_vectors[denoising_step][key].append(steering_vector) 142 | 143 | 144 | # Saving steering vectors: 145 | if not os.path.exists(args.save_dir): 146 | os.makedirs(args.save_dir) 147 | with open(os.path.join(args.save_dir, '{}_{}_{}.pickle'.format(args.model, args.concept_pos, args.concept_neg)), 'wb') as handle: 148 | pickle.dump(steering_vectors, handle) -------------------------------------------------------------------------------- /construct_prompts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def get_imagenet_classes(num=50): 4 | imagenet_classes = [] 5 | 6 | f = open('imagenet_classes.txt', 'r') 7 | for line in f.readlines(): 8 | imagenet_classes.append(line.strip()) 9 | f.close() 10 | 11 | return imagenet_classes 12 | 13 | def get_prompts_concrete(num=50, concept_pos='Snoopy', concept_neg=None): 14 | 15 | imagenet_classes = get_imagenet_classes(num) 16 | 17 | prompts_pos = [] 18 | prompts_neg = [] 19 | for cls in imagenet_classes[:num]: 20 | prompts_pos.append(cls+' with {}'.format(concept_pos)) 21 | if concept_neg is not None: 22 | prompts_neg.append(cls+' with {}'.format(concept_neg)) 23 | else: 24 | prompts_neg.append(cls) 25 | 26 | return prompts_pos, prompts_neg 27 | 28 | def get_prompts_style(num=50, concept_pos='anime', concept_neg=None): 29 | 30 | imagenet_classes = get_imagenet_classes(num) 31 | 32 | prompts_pos = [] 33 | prompts_neg = [] 34 | for cls in imagenet_classes[:num]: 35 | prompts_pos.append(cls+', {} style'.format(concept_pos)) 36 | if concept_neg is not None: 37 | prompts_neg.append(cls+', {} style'.format(concept_neg)) 38 | else: 39 | prompts_neg.append(cls) 40 | 41 | return prompts_pos, prompts_neg 42 | 43 | def get_prompts_human_related(concept_pos='nudity', concept_neg=None): 44 | B = ['a girl', 'two men', 'a man', 'a woman', 'an old man', 'a boy', 'boys', 'group of people'] 45 | C = ['on a beach', 'zoomed in', 'talking', 'dancing on the street', 'playing guitar', 'enjoying nature', \ 46 | 'smiling', 'in futuristic spaceship', 'with kittens', 'in a strange pose', 'realism', 'colorful background', ''] 47 | 48 | prompts_pos = [] 49 | prompts_neg = [] 50 | for b in B: 51 | for c in C: 52 | prompts_pos.append(b+' '+c+', {}'.format(concept_pos)) 53 | if concept_neg is not None: 54 | prompts_neg.append(b+' '+c+', {}'.format(concept_neg)) 55 | else: 56 | prompts_neg.append(b+' '+c) 57 | 58 | 59 | return prompts_pos, prompts_neg 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import abc 4 | from collections import defaultdict 5 | from typing import Optional, Union, Tuple, List, Callable, Dict, Any 6 | 7 | # Define Controller for BasicTransformerBlock 8 | class VectorControl(abc.ABC): 9 | def __init__(self): 10 | self.cur_step = 0 11 | self.num_att_layers = -1 12 | self.cur_att_layer = 0 13 | 14 | def reset(self): 15 | self.cur_step = 0 16 | self.cur_att_layer = 0 17 | 18 | def between_steps(self): 19 | return 20 | 21 | @abc.abstractmethod 22 | def forward (self, attn, is_cross: bool, place_in_unet: str): 23 | raise NotImplementedError 24 | 25 | def __call__(self, vector, place_in_unet: str): 26 | 27 | vector = self.forward(vector, place_in_unet) 28 | 29 | self.cur_att_layer += 1 30 | if self.cur_att_layer == self.num_att_layers: 31 | self.cur_att_layer = 0 32 | self.between_steps() 33 | self.cur_step += 1 34 | return vector 35 | 36 | 37 | class VectorStore(VectorControl): 38 | def __init__(self, steering_vectors=None, steer=True, steer_only_up=False, 39 | alpha=10, beta=2, 40 | steer_back=False, 41 | device='cpu'): 42 | super(VectorStore, self).__init__() 43 | self.step_store = self.get_empty_store() 44 | self.vector_store = defaultdict(dict) 45 | self.steering_vectors = steering_vectors 46 | self.steer=True 47 | self.steer_only_up = False 48 | self.alpha = 10 49 | self.beta = 2 50 | self.steer_back = False 51 | self.device=device 52 | 53 | def reset(self): 54 | super(VectorStore, self).reset() 55 | self.step_store = self.get_empty_store() 56 | self.vector_store = defaultdict(dict) 57 | 58 | @staticmethod 59 | def get_empty_store(): 60 | return {"down": [], "up": [], 'mid': []} 61 | 62 | def forward(self, vector, place_in_unet: str): 63 | 64 | # steering 65 | if self.steer: 66 | 67 | if place_in_unet in ['up', 'mid'] or (place_in_unet == 'down' and not self.steer_only_up): 68 | 69 | # if steering vectors are from turbo version, then there's only one key in self.steering_vectors, 70 | # and we'll use it for all the steps of generation 71 | # if steering vectors are from full version, then there's a key in self.steering_vectors 72 | # for each of the generation steps 73 | num_steer = 0 if len(list(self.steering_vectors.keys()))==1 else self.cur_step 74 | 75 | steering_vector = self.steering_vectors[num_steer][place_in_unet][len(self.step_store[place_in_unet])] 76 | steering_vector = torch.tensor(steering_vector).to(self.device).view(1, 1, -1) 77 | 78 | # save current norm of vector components 79 | norm = torch.norm(vector, dim=2, keepdim=True) 80 | 81 | if self.steer_back: 82 | # steering backward, i.e. removing notion from vector 83 | 84 | # computing dot products between vector components and steering vector x 85 | sim = torch.tensordot(vector, steering_vector, 86 | dims=([2], [2])).view(vector.size()[0], vector.size()[1], 1) 87 | # we will steer back only if dot product is positive, i.e. 88 | # if there's positive amount of information from steering vector in the vector 89 | sim = torch.where(sim>0, sim, 0) 90 | 91 | # steer backward for beta*sim 92 | vector = vector - (self.beta*sim)*steering_vector.expand(1, vector.size()[1], -1) 93 | else: 94 | # steer forward, i.e. add a steering vector x multiplied by self.intensity 95 | vector = vector + self.alpha*steering_vector.expand(1, vector.size()[1], -1) 96 | 97 | 98 | # renormalize so that the norm of the steered vector is the same as of original one 99 | vector = vector / torch.norm(vector, dim=2, keepdim=True) 100 | vector = vector * norm 101 | 102 | # save activation (vector) for further computing steering vectors 103 | self.step_store[place_in_unet].append(vector.data.cpu().numpy()[len(vector)//2:].mean(axis=0).mean(axis=0)) 104 | 105 | return vector 106 | 107 | def between_steps(self): 108 | self.vector_store[self.cur_step] = self.step_store 109 | self.step_store = self.get_empty_store() 110 | 111 | 112 | def register_vector_control(model, controller): 113 | def block_forward(self, place_in_unet): 114 | 115 | # overriding BasicTransformerBlock forward function 116 | def forward( 117 | hidden_states: torch.Tensor, 118 | attention_mask: Optional[torch.Tensor] = None, 119 | encoder_hidden_states: Optional[torch.Tensor] = None, 120 | encoder_attention_mask: Optional[torch.Tensor] = None, 121 | timestep: Optional[torch.LongTensor] = None, 122 | cross_attention_kwargs: Dict[str, Any] = None, 123 | class_labels: Optional[torch.LongTensor] = None, 124 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 125 | ) -> torch.Tensor: 126 | if cross_attention_kwargs is not None: 127 | if cross_attention_kwargs.get("scale", None) is not None: 128 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") 129 | 130 | # Notice that normalization is always applied before the real computation in the following blocks. 131 | # 0. Self-Attention 132 | batch_size = hidden_states.shape[0] 133 | 134 | if self.norm_type == "ada_norm": 135 | norm_hidden_states = self.norm1(hidden_states, timestep) 136 | elif self.norm_type == "ada_norm_zero": 137 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 138 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 139 | ) 140 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 141 | norm_hidden_states = self.norm1(hidden_states) 142 | elif self.norm_type == "ada_norm_continuous": 143 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 144 | elif self.norm_type == "ada_norm_single": 145 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 146 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 147 | ).chunk(6, dim=1) 148 | norm_hidden_states = self.norm1(hidden_states) 149 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 150 | else: 151 | raise ValueError("Incorrect norm used") 152 | 153 | if self.pos_embed is not None: 154 | norm_hidden_states = self.pos_embed(norm_hidden_states) 155 | 156 | # 1. Prepare GLIGEN inputs 157 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 158 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 159 | 160 | attn_output = self.attn1( 161 | norm_hidden_states, 162 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 163 | attention_mask=attention_mask, 164 | **cross_attention_kwargs, 165 | ) 166 | 167 | 168 | if self.norm_type == "ada_norm_zero": 169 | attn_output = gate_msa.unsqueeze(1) * attn_output 170 | elif self.norm_type == "ada_norm_single": 171 | attn_output = gate_msa * attn_output 172 | 173 | hidden_states = attn_output + hidden_states 174 | if hidden_states.ndim == 4: 175 | hidden_states = hidden_states.squeeze(1) 176 | 177 | # 1.2 GLIGEN Control 178 | if gligen_kwargs is not None: 179 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 180 | 181 | # 3. Cross-Attention 182 | if self.attn2 is not None: 183 | if self.norm_type == "ada_norm": 184 | norm_hidden_states = self.norm2(hidden_states, timestep) 185 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 186 | norm_hidden_states = self.norm2(hidden_states) 187 | elif self.norm_type == "ada_norm_single": 188 | # For PixArt norm2 isn't applied here: 189 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 190 | norm_hidden_states = hidden_states 191 | elif self.norm_type == "ada_norm_continuous": 192 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 193 | else: 194 | raise ValueError("Incorrect norm") 195 | 196 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 197 | norm_hidden_states = self.pos_embed(norm_hidden_states) 198 | 199 | attn_output = self.attn2( 200 | norm_hidden_states, 201 | encoder_hidden_states=encoder_hidden_states, 202 | attention_mask=encoder_attention_mask, 203 | **cross_attention_kwargs, 204 | ) 205 | # ------------------------------- 206 | # adding controller 207 | 208 | attn_output = controller(attn_output, place_in_unet) 209 | # ------------------------------- 210 | hidden_states = attn_output + hidden_states 211 | 212 | # 4. Feed-forward 213 | if self.norm_type == "ada_norm_continuous": 214 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 215 | elif not self.norm_type == "ada_norm_single": 216 | norm_hidden_states = self.norm3(hidden_states) 217 | 218 | if self.norm_type == "ada_norm_zero": 219 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 220 | 221 | if self.norm_type == "ada_norm_single": 222 | norm_hidden_states = self.norm2(hidden_states) 223 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 224 | 225 | if self._chunk_size is not None: 226 | # "feed_forward_chunk_size" can be used to save memory 227 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 228 | else: 229 | ff_output = self.ff(norm_hidden_states) 230 | 231 | if self.norm_type == "ada_norm_zero": 232 | ff_output = gate_mlp.unsqueeze(1) * ff_output 233 | elif self.norm_type == "ada_norm_single": 234 | ff_output = gate_mlp * ff_output 235 | 236 | hidden_states = ff_output + hidden_states 237 | if hidden_states.ndim == 4: 238 | hidden_states = hidden_states.squeeze(1) 239 | 240 | return hidden_states 241 | 242 | return forward 243 | 244 | 245 | def register_recr(net_, count, place_in_unet): 246 | ''' 247 | registering controller for all the BasicTransformerBlocks in the model 248 | ''' 249 | if net_.__class__.__name__ == 'BasicTransformerBlock': 250 | net_.forward = block_forward(net_, place_in_unet) 251 | return count + 1 252 | elif hasattr(net_, 'children'): 253 | for net__ in net_.children(): 254 | count = register_recr(net__, count, place_in_unet) 255 | return count 256 | 257 | block_count = 0 258 | sub_nets = model.named_children() 259 | for net in sub_nets: 260 | if "down" in net[0]: 261 | block_count += register_recr(net[1], 0, "down") 262 | print('down', block_count) 263 | elif "up" in net[0]: 264 | block_count += register_recr(net[1], 0, "up") 265 | print('up', block_count) 266 | if "mid" in net[0]: 267 | block_count += register_recr(net[1], 0, "mid") 268 | print('mid', block_count) 269 | controller.num_att_layers = block_count -------------------------------------------------------------------------------- /generate_casteer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | from PIL import Image 5 | from collections import defaultdict 6 | import time 7 | 8 | import torch 9 | from diffusers import StableDiffusionPipeline, DiffusionPipeline, AutoPipelineForText2Image 10 | 11 | # local imports 12 | from controller import VectorStore, register_vector_control 13 | 14 | # parsing arguments 15 | import argparse 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model', type=str, choices=['sd14', 'sd21', 'sd21-turbo', 'sdxl', 'sdxl-turbo'], default="sd14") 18 | parser.add_argument('--prompt', type=str, default="a girl with a kitty") 19 | parser.add_argument('--seed', type=int, default=0) 20 | parser.add_argument('--steering_vectors', type=str) # path to steering vectors file 21 | parser.add_argument('--not_steer', action='store_true') 22 | parser.add_argument('--steer_only_up', action='store_true') 23 | parser.add_argument('--num_denoising_steps', type=int, default=50) # 50 for sd14, sd21, 1 for turbo, 30 for sdxl 24 | parser.add_argument('--steer_back', action='store_true') 25 | parser.add_argument('--alpha', type=int, default=10) 26 | parser.add_argument('--beta', type=int, default=2) 27 | parser.add_argument('--save_dir', type=str, default='images') # path to saving generated images 28 | args = parser.parse_args() 29 | 30 | 31 | if args.model == 'sd14': 32 | pipe = StableDiffusionPipeline.from_pretrained( 33 | "CompVis/stable-diffusion-v1-4", 34 | torch_dtype=torch.float16, 35 | cache_dir='./cache' 36 | ) 37 | elif args.model == 'sd21': 38 | pipe = StableDiffusionPipeline.from_pretrained( 39 | "stabilityai/stable-diffusion-2-1", 40 | torch_dtype=torch.float16, 41 | cache_dir='./cache' 42 | ) 43 | elif args.model == 'sd21-turbo': 44 | pipe = AutoPipelineForText2Image.from_pretrained( 45 | "stabilityai/sd-turbo", 46 | torch_dtype=torch.float16, 47 | variant="fp16", 48 | cache_dir='./cache' 49 | ) 50 | elif args.model == 'sdxl': 51 | pipe = DiffusionPipeline.from_pretrained( 52 | "stabilityai/stable-diffusion-xl-base-1.0", 53 | torch_dtype=torch.float16, 54 | use_safetensors=True, 55 | variant="fp16", 56 | cache_dir='./cache' 57 | ) 58 | elif args.model == 'sdxl-turbo': 59 | pipe = AutoPipelineForText2Image.from_pretrained( 60 | "stabilityai/sdxl-turbo", 61 | torch_dtype=torch.float16, 62 | variant="fp16", 63 | cache_dir='./cache' 64 | ) 65 | 66 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 67 | pipe.to(device) 68 | 69 | 70 | def run_model(model_type, pipe, prompt, seed, num_denoising_steps): 71 | if args.model in ['sd14', 'sd21', 'sdxl']: 72 | image = pipe(prompt=prompt, 73 | num_inference_steps=num_denoising_steps, 74 | generator=torch.Generator(device=device).manual_seed(seed) 75 | ).images[0] 76 | 77 | elif args.model in ['sd21-turbo', 'sdxl-turbo']: 78 | image = pipe(prompt=prompt, 79 | num_inference_steps=num_denoising_steps, 80 | guidance_scale=0.0, 81 | generator=torch.Generator(device=device).manual_seed(seed) 82 | ).images[0] 83 | 84 | return image 85 | 86 | print('Generating for prompt:') 87 | print(args.prompt) 88 | 89 | if not os.path.exists(args.save_dir): 90 | os.makedirs(args.save_dir) 91 | 92 | if args.not_steer: 93 | image = run_model(args.model, pipe, args.prompt, args.seed, args.num_denoising_steps) 94 | 95 | image.save(os.path.join(args.save_dir, "orig_{}_{}.png".format(args.prompt, args.seed))) 96 | 97 | 98 | else: 99 | with open(args.steering_vectors, 'rb') as handle: 100 | steering_vectors = pickle.load(handle) 101 | 102 | controller = VectorStore(steering_vectors, device=device) 103 | controller.steer_only_up = True if args.steer_only_up else False 104 | if args.steer_back: 105 | controller.steer_back = True 106 | controller.beta = args.beta 107 | else: 108 | controller.steer_back = False 109 | controller.alpha = args.alpha 110 | 111 | register_vector_control(pipe.unet, controller) 112 | 113 | image = run_model(args.model, pipe, args.prompt, args.seed, args.num_denoising_steps) 114 | 115 | image.save(os.path.join(args.save_dir, "steered_{}_{}.png".format(args.prompt, args.seed))) 116 | 117 | 118 | -------------------------------------------------------------------------------- /imagenet_classes.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead 6 | electric ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel 22 | kite 23 | bald eagle 24 | vulture 25 | great grey owl 26 | European fire salamander 27 | common newt 28 | eft 29 | spotted salamander 30 | axolotl 31 | bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead 35 | leatherback turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | common iguana 41 | American chameleon 42 | whiptail 43 | agama 44 | frilled lizard 45 | alligator lizard 46 | Gila monster 47 | green lizard 48 | African chameleon 49 | Komodo dragon 50 | African crocodile 51 | American alligator 52 | triceratops 53 | thunder snake 54 | ringneck snake 55 | hognose snake 56 | green snake 57 | king snake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | horned viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | black and gold garden spider 74 | barn spider 75 | garden spider 76 | black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie chicken 85 | peacock 86 | quail 87 | partridge 88 | African grey 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | American egret 134 | bittern 135 | crane 136 | limpkin 137 | European gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | red-backed sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound 168 | English foxhound 169 | redbone 170 | borzoi 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound 175 | Norwegian elkhound 176 | otterhound 177 | Saluki 178 | Scottish deerhound 179 | Weimaraner 180 | Staffordshire bullterrier 181 | American Staffordshire terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier 192 | Airedale 193 | cairn 194 | Australian terrier 195 | Dandie Dinmont 196 | Boston bull 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier 201 | Tibetan terrier 202 | silky terrier 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla 213 | English setter 214 | Irish setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber 218 | English springer 219 | Welsh springer spaniel 220 | cocker spaniel 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog 231 | Shetland sheepdog 232 | collie 233 | Border collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German shepherd 237 | Doberman 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard 249 | Eskimo dog 250 | malamute 251 | Siberian husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke 265 | Cardigan 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf 271 | white wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African hunting dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 282 | tabby 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian cat 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | ice bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | long-horned beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage butterfly 326 | sulphur butterfly 327 | lycaenid 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | wood rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | sorrel 341 | zebra 342 | hog 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis monkey 378 | marmoset 379 | capuchin 380 | howler monkey 381 | titi 382 | spider monkey 383 | squirrel monkey 384 | Madagascar cat 385 | indri 386 | Indian elephant 387 | African elephant 388 | lesser panda 389 | giant panda 390 | barracouta 391 | eel 392 | coho 393 | rock beauty 394 | anemone fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibian 410 | analog clock 411 | apiary 412 | apron 413 | ashcan 414 | assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint 420 | Band Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap 435 | bath towel 436 | bathtub 437 | beach wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer bottle 442 | beer glass 443 | bell cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | bullet train 468 | butcher shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | carpenter's kit 479 | carton 480 | car wheel 481 | cash machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | cellular telephone 489 | chain 490 | chainlink fence 491 | chain mail 492 | chain saw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clog 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil 508 | combination lock 509 | computer keyboard 510 | confectionery 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishrag 535 | dishwasher 536 | disk brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 | drilling platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa 554 | file 555 | fireboat 556 | fire engine 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gasmask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower 591 | hand-held computer 592 | handkerchief 593 | hard disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal bar 604 | horse cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn mower 623 | lens cap 624 | letter opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter 672 | mountain bike 673 | mountain tent 674 | mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil box 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket fence 718 | pickup 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | pop bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | rubber eraser 769 | rugby ball 770 | rule 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe shop 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | swab 842 | sweatshirt 843 | swimming trunks 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy 852 | television 853 | tennis ball 854 | thatch 855 | theater curtain 856 | thimble 857 | thresher 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toyshop 867 | tractor 868 | trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | worm fence 914 | wreck 915 | yawl 916 | yurt 917 | web site 918 | comic book 919 | crossword puzzle 920 | street sign 921 | traffic light 922 | book jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice lolly 931 | French loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce 962 | dough 963 | meat loaf 964 | pizza 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet tissue --------------------------------------------------------------------------------