├── README.md
├── blog
│   ├── 1 zkOBQ9Izq28yXCANTmdKtA.webp
│   ├── ddpo.ipynb
│   ├── gasbI.png
│   ├── image-1.png
│   ├── image-2.png
│   ├── image-3.png
│   └── image.png
├── main.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# RLHF for Diffusion Models

This is an implementation of [Training Diffusion Models with Reinforcement Learning](https://arxiv.org/abs/2305.13301). It is meant as an educational codebase, with lots of comments explaining the code
and only basic features. It currently only implements the [LAION aesthetic classifier](https://github.com/LAION-AI/aesthetic-predictor) as a reward function, but more examples will be added soon.

*Tutorial blog post coming soon*

_This codebase is just for educational purposes; a separate codebase for scalable training is being developed [here](https://github.com/CarperAI/DRLX)._

## Installation
```
git clone https://github.com/tmabraham/ddpo-pytorch.git
cd ddpo-pytorch
pip install -r requirements.txt
```

## Usage

It's as simple as running:
```
python main.py
```

To save memory (you'll likely need it), use the arguments `--enable_attention_slicing`, `--enable_xformers_memory_efficient_attention`, and `--enable_grad_checkpointing`.

## Results

Original samples:
![image](https://github.com/tmabraham/ddpo-pytorch/assets/37097934/6a9489a2-9cfb-4e21-84c5-eaa2694acbd4)

After training for 50 epochs:
![image](https://github.com/tmabraham/ddpo-pytorch/assets/37097934/a82ce5ce-2e29-4adf-b06c-601295be288d)

--------------------------------------------------------------------------------
/blog/1 zkOBQ9Izq28yXCANTmdKtA.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/1 zkOBQ9Izq28yXCANTmdKtA.webp
--------------------------------------------------------------------------------
/blog/gasbI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/gasbI.png
--------------------------------------------------------------------------------
/blog/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/image-1.png
--------------------------------------------------------------------------------
/blog/image-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/image-2.png
--------------------------------------------------------------------------------
/blog/image-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/image-3.png
--------------------------------------------------------------------------------
/blog/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tmabraham/ddpo-pytorch/a8ca13f95b2fb052438f2d461f0ffb541eec208b/blog/image.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import requests
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import clip  # pip install git+https://github.com/openai/CLIP.git
import torch
import random
import math
import wandb
from torch import nn
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from fastprogress import progress_bar, master_bar
from collections import deque


# tf32, performance optimization
torch.backends.cuda.matmul.allow_tf32 = True

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model name", default="CompVis/stable-diffusion-v1-4")
    parser.add_argument("--enable_attention_slicing", action="store_true")
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
    parser.add_argument("--enable_grad_checkpointing", action="store_true")
    parser.add_argument("--num_samples_per_epoch", type=int, default=128)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--num_inner_epochs", type=int, default=1)
    parser.add_argument("--num_timesteps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--sample_batch_size", type=int, default=32)
    parser.add_argument("--img_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--clip_advantages", type=float, default=10.0)
    parser.add_argument("--clip_ratio", type=float, default=1e-4)
    parser.add_argument("--cfg", type=float, default=5.0)
    parser.add_argument("--buffer_size", type=int, default=32)
    parser.add_argument("--min_count", type=int, default=16)
    parser.add_argument("--wandb_project", type=str, default="DDPO")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="ddpo_model")
    return parser.parse_args()


class MLP(nn.Module):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            #nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            #nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)

def load_aesthetic_model_weights(cache="."):
    weights_fname = "sac+logos+ava1-l14-linearMSE.pth"
    loadpath = os.path.join(cache, weights_fname)

    if not os.path.exists(loadpath):
        url = (
            "https://github.com/christophschuhmann/"
            f"improved-aesthetic-predictor/blob/main/{weights_fname}?raw=true"
        )
        r = requests.get(url)

        with open(loadpath, "wb") as f:
            f.write(r.content)

    weights = torch.load(loadpath, map_location=torch.device("cpu"))
    return weights

def aesthetic_model_normalize(a, axis=-1, order=2):
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)
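
# The three pieces above make up the LAION aesthetic predictor used as the reward model:
# an image is embedded with CLIP ViT-L/14 (a 768-dim vector), the embedding is
# L2-normalized with aesthetic_model_normalize, and the small MLP maps it to a single
# aesthetic score. As a purely illustrative sketch (the helper name below is hypothetical
# and nothing else in this file calls it), scoring one PIL image end to end looks like:

def example_aesthetic_score(pil_img, preprocess, clip_model, aesthetic_model):
    with torch.no_grad():
        emb = clip_model.encode_image(preprocess(pil_img).unsqueeze(0).cuda())  # (1, 768) CLIP image embedding
    emb = aesthetic_model_normalize(emb.cpu().numpy())  # L2-normalize, as the predictor expects
    return aesthetic_model(torch.from_numpy(emb).float().cuda()).item()  # scalar aesthetic score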


def imagenet_animal_prompts():
    animal = random.choice(imagenet_classes[:397])
    prompts = f'{animal}'
    return prompts

class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, prompt_fn, num):
        super().__init__()
        self.prompt_fn = prompt_fn
        self.num = num

    def __len__(self): return self.num
    def __getitem__(self, x): return self.prompt_fn()

@torch.no_grad()
def decoding_fn(latents, pipe):
    images = pipe.vae.decode(1 / 0.18215 * latents.cuda()).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (images * 255).round().astype("uint8")
    return images

def aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model):
    imgs = torch.stack([preprocess(Image.fromarray(img)).cuda() for img in imgs])
    with torch.no_grad(): image_features = clip_model.encode_image(imgs)
    im_emb_arr = aesthetic_model_normalize(image_features.cpu().detach().numpy())
    prediction = aesthetic_model(torch.from_numpy(im_emb_arr).float().cuda())
    return prediction

class PerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

def calculate_log_probs(prev_sample, prev_sample_mean, std_dev_t):
    std_dev_t = torch.clip(std_dev_t, 1e-6)
    log_probs = -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi))
    return log_probs

@torch.no_grad()
def sd_sample(prompts, pipe, height, width, guidance_scale, num_inference_steps, eta, device):
    scheduler = pipe.scheduler
    unet = pipe.unet
    text_embeddings = pipe._encode_prompt(prompts, device, 1, do_classifier_free_guidance=guidance_scale > 1.0)

    scheduler.set_timesteps(num_inference_steps, device=device)
    latents = torch.randn((len(prompts), unet.in_channels, height//8, width//8)).to(device)

    all_step_preds, log_probs = [latents], []

    for i, t in enumerate(progress_bar(scheduler.timesteps)):
        input = torch.cat([latents] * 2)
        input = scheduler.scale_model_input(input, t)

        # predict the noise residual
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        # compute the "previous" noisy sample mean and variance, and get log probs
        scheduler_output = scheduler.step(pred, t, latents, eta, variance_noise=0)
        t_1 = t - scheduler.config.num_train_timesteps // num_inference_steps
        variance = scheduler._get_variance(t, t_1)
        std_dev_t = eta * variance ** (0.5)
        prev_sample_mean = scheduler_output.prev_sample  # this is the mean and not full sample since variance is 0
        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t  # get full sample by adding noise
        log_probs.append(calculate_log_probs(prev_sample, prev_sample_mean, std_dev_t).mean(dim=tuple(range(1, prev_sample_mean.ndim))))

        all_step_preds.append(prev_sample)
        latents = prev_sample

    return latents, torch.stack(all_step_preds), torch.stack(log_probs)
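
# How the PPO-style objective below relates to sd_sample: each DDIM denoising step is
# treated as an action drawn from a Gaussian policy N(prev_sample_mean, std_dev_t^2),
# and calculate_log_probs is just the Gaussian log-density
#     log N(x; mu, sigma) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi)),
# averaged over the latent dimensions. sd_sample stores the whole trajectory plus these
# log-probs under the sampling-time ("old") weights; compute_loss re-evaluates the
# log-probs under the current weights and forms the clipped surrogate per timestep:
#     ratio_t = exp(log_prob_new_t - log_prob_old_t)
#     loss_t  = -min(A * ratio_t, A * clip(ratio_t, 1 - clip_ratio, 1 + clip_ratio))
# where A is the (clipped) per-sample advantage computed from the reward. Note that
# compute_loss calls loss.backward() inside the timestep loop, so gradients accumulate
# step by step and the graph for each step can be freed immediately.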

def compute_loss(x_t, original_log_probs, advantages, clip_advantages, clip_ratio, prompts, pipe, num_inference_steps, guidance_scale, eta, device):
    scheduler = pipe.scheduler
    unet = pipe.unet
    text_embeddings = pipe._encode_prompt(prompts, device, 1, do_classifier_free_guidance=guidance_scale > 1.0).detach()
    scheduler.set_timesteps(num_inference_steps, device=device)
    loss_value = 0.
    for i, t in enumerate(progress_bar(scheduler.timesteps)):
        clipped_advantages = torch.clip(advantages, -clip_advantages, clip_advantages).detach()

        input = torch.cat([x_t[i].detach()] * 2)
        input = scheduler.scale_model_input(input, t)

        # predict the noise residual
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        # compute the "previous" noisy sample mean and variance, and get log probs
        scheduler_output = scheduler.step(pred, t, x_t[i].detach(), eta, variance_noise=0)
        t_1 = t - scheduler.config.num_train_timesteps // num_inference_steps
        variance = scheduler._get_variance(t, t_1)
        std_dev_t = eta * variance ** (0.5)
        prev_sample_mean = scheduler_output.prev_sample
        current_log_probs = calculate_log_probs(x_t[i+1].detach(), prev_sample_mean, std_dev_t).mean(dim=tuple(range(1, prev_sample_mean.ndim)))

        # calculate loss
        ratio = torch.exp(current_log_probs - original_log_probs[i].detach())  # this is the ratio of the new policy to the old policy
        unclipped_loss = -clipped_advantages * ratio  # this is the surrogate loss
        clipped_loss = -clipped_advantages * torch.clip(ratio, 1. - clip_ratio, 1. + clip_ratio)  # this is the surrogate loss, but with artificially clipped ratios
        loss = torch.max(unclipped_loss, clipped_loss).mean()  # we take the max of the clipped and unclipped surrogate losses, and take the mean over the batch
        loss.backward()

        loss_value += loss.item()
    return loss_value
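
# The reward used below is the LAION aesthetic score, but any function that maps decoded
# images to scalar rewards can be plugged in. As a hypothetical example (a sketch only --
# the name is made up and nothing below uses it), the JPEG-compressibility reward from
# the DDPO paper rewards images that compress to fewer bytes:

def jpeg_compressibility_reward(imgs):
    import io  # local import to keep this illustrative helper self-contained
    rewards = []
    for img in imgs:  # imgs are uint8 HxWx3 numpy arrays, as produced by decoding_fn
        buffer = io.BytesIO()
        Image.fromarray(img).save(buffer, format="JPEG", quality=95)
        rewards.append(-buffer.getbuffer().nbytes / 1000.)  # negative encoded size in kB
    return torch.tensor(rewards).unsqueeze(1).cuda()  # same (batch, 1) shape as aesthetic_scoring

# To train against it, reward_fn in the main block would simply call this instead of
# aesthetic_scoring.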


if __name__ == '__main__':
    args = parse_args()

    # set the gpu
    torch.cuda.set_device(args.gpu)

    wandb.init(
        # set the wandb project where this run will be logged
        project=args.wandb_project,

        # track hyperparameters and run metadata
        config={
            "num_samples_per_epoch": args.num_samples_per_epoch,
            "num_epochs": args.num_epochs,
            "num_inner_epochs": args.num_inner_epochs,
            "num_timesteps": args.num_timesteps,
            "batch_size": args.batch_size,
            "lr": args.lr
        }
    )

    # setup diffusion model
    pipe = StableDiffusionPipeline.from_pretrained(args.model).to("cuda")
    if args.enable_attention_slicing: pipe.enable_attention_slicing()
    if args.enable_xformers_memory_efficient_attention: pipe.enable_xformers_memory_efficient_attention()
    pipe.text_encoder.requires_grad_(False)
    pipe.vae.requires_grad_(False)

    # only tested and works with DDIM for now
    pipe.scheduler = DDIMScheduler(
        num_train_timesteps=pipe.scheduler.num_train_timesteps,
        beta_start=pipe.scheduler.beta_start,
        beta_end=pipe.scheduler.beta_end,
        beta_schedule=pipe.scheduler.beta_schedule,
        trained_betas=pipe.scheduler.trained_betas,
        clip_sample=pipe.scheduler.clip_sample,
        set_alpha_to_one=pipe.scheduler.set_alpha_to_one,
        steps_offset=pipe.scheduler.steps_offset,
        prediction_type=pipe.scheduler.prediction_type
    )

    # setup reward model
    clip_model, preprocess = clip.load("ViT-L/14", device="cuda")
    aesthetic_model = MLP(768)
    aesthetic_model.load_state_dict(load_aesthetic_model_weights())
    aesthetic_model.cuda()

    # download the ImageNet synset mapping, used to build the animal prompts
    r = requests.get("https://raw.githubusercontent.com/formigone/tf-imagenet/master/LOC_synset_mapping.txt")
    with open("LOC_synset_mapping.txt", "wb") as f: f.write(r.content)
    synsets = {k: v for k, v in [o.split(',')[0].split(' ', maxsplit=1) for o in Path('LOC_synset_mapping.txt').read_text().splitlines()]}
    imagenet_classes = list(synsets.values())

    # group all reward function stuff
    def reward_fn(imgs, device):
        clip_model.to(device)
        aesthetic_model.to(device)
        rewards = aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)
        clip_model.to('cpu')
        aesthetic_model.to('cpu')
        return rewards

    # a function to sample from the model and calculate rewards
    def sample_and_calculate_rewards(prompts, pipe, image_size, cfg, num_timesteps, decoding_fn, reward_fn, device):
        preds, all_step_preds, log_probs = sd_sample(prompts, pipe, image_size, image_size, cfg, num_timesteps, 1, device)
        imgs = decoding_fn(preds, pipe)
        rewards = reward_fn(imgs, device)
        return imgs, rewards, all_step_preds, log_probs

    train_set = PromptDataset(imagenet_animal_prompts, args.num_samples_per_epoch)
    train_dl = torch.utils.data.DataLoader(train_set, batch_size=args.sample_batch_size, shuffle=True, num_workers=0)

    sample_prompts = next(iter(train_dl))  # sample a batch of prompts to use for visualization
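
    # Training loop structure (DDPO): each outer epoch samples `num_samples_per_epoch`
    # trajectories with the current model, scores the final images with the reward model,
    # and normalizes the rewards into per-prompt advantages via PerPromptStatTracker.
    # Each inner epoch then splits those stored trajectories into minibatches of
    # `batch_size` and takes PPO-style gradient steps on the clipped surrogate in
    # compute_loss. With the defaults (128 samples, batch size 4) that is 32 updates
    # per inner epoch.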

    if args.enable_grad_checkpointing: pipe.unet.enable_gradient_checkpointing()  # more performance optimization

    optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    per_prompt_stat_tracker = PerPromptStatTracker(args.buffer_size, args.min_count)

    mean_rewards = []
    for epoch in master_bar(range(args.num_epochs)):
        print(f'Epoch {epoch}')
        all_step_preds, log_probs, advantages, all_prompts, all_rewards = [], [], [], [], []

        # sampling `num_samples_per_epoch` images and calculating rewards
        for i, prompts in enumerate(progress_bar(train_dl)):
            batch_imgs, rewards, batch_all_step_preds, batch_log_probs = sample_and_calculate_rewards(prompts, pipe, args.img_size, args.cfg, args.num_timesteps, decoding_fn, reward_fn, 'cuda')
            batch_advantages = torch.from_numpy(per_prompt_stat_tracker.update(np.array(prompts), rewards.squeeze().cpu().detach().numpy())).float().to('cuda')
            all_step_preds.append(batch_all_step_preds)
            log_probs.append(batch_log_probs)
            advantages.append(batch_advantages)
            all_prompts += prompts
            all_rewards.append(rewards)

        all_step_preds = torch.cat(all_step_preds, dim=1)
        log_probs = torch.cat(log_probs, dim=1)
        advantages = torch.cat(advantages)
        all_rewards = torch.cat(all_rewards)

        mean_rewards.append(all_rewards.mean().item())

        wandb.log({"mean_reward": mean_rewards[-1]})
        wandb.log({"reward_hist": wandb.Histogram(all_rewards.detach().cpu().numpy())})
        wandb.log({"img batch": [wandb.Image(Image.fromarray(img), caption=prompt) for img, prompt in zip(batch_imgs, prompts)]})

        # sample some images with the consistent prompt for visualization
        sample_imgs, sample_rewards, _, _ = sample_and_calculate_rewards(sample_prompts, pipe, args.img_size, args.cfg, args.num_timesteps, decoding_fn, reward_fn, 'cuda')
        wandb.log({"sample img batch": [wandb.Image(Image.fromarray(img), caption=prompt + f', {reward.item()}') for img, prompt, reward in zip(sample_imgs, sample_prompts, sample_rewards)]})

        # inner loop
        for inner_epoch in progress_bar(range(args.num_inner_epochs)):
            print(f'Inner epoch {inner_epoch}')

            # chunk them into batches
            all_step_preds_chunked = torch.chunk(all_step_preds, args.num_samples_per_epoch // args.batch_size, dim=1)
            log_probs_chunked = torch.chunk(log_probs, args.num_samples_per_epoch // args.batch_size, dim=1)
            advantages_chunked = torch.chunk(advantages, args.num_samples_per_epoch // args.batch_size, dim=0)

            # chunk the prompts (list of strings) into batches
            all_prompts_chunked = [all_prompts[i:i + args.batch_size] for i in range(0, len(all_prompts), args.batch_size)]
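
            # Shapes: all_step_preds is (num_timesteps + 1, num_samples, 4, img_size/8, img_size/8)
            # and log_probs is (num_timesteps, num_samples), i.e. the sample dimension is dim=1,
            # which is why trajectories are chunked along dim=1 while the per-sample advantages
            # (shape (num_samples,)) are chunked along dim=0.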

            for i in progress_bar(range(len(all_step_preds_chunked))):
                optimizer.zero_grad()

                loss = compute_loss(all_step_preds_chunked[i], log_probs_chunked[i],
                                    advantages_chunked[i], args.clip_advantages, args.clip_ratio, all_prompts_chunked[i], pipe, args.num_timesteps, args.cfg, 1, 'cuda'
                                    )  # loss.backward happens inside

                torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), 1.0)  # gradient clipping
                optimizer.step()
                wandb.log({"loss": loss, "epoch": epoch, "inner_epoch": inner_epoch, "batch": i})

    # end of training evaluation
    all_rewards = []
    for i, prompts in enumerate(progress_bar(train_dl)):
        batch_imgs, rewards, _, _ = sample_and_calculate_rewards(prompts, pipe, args.img_size, args.cfg, args.num_timesteps, decoding_fn, reward_fn, 'cuda')
        all_rewards.append(rewards)

    all_rewards = torch.cat(all_rewards)
    mean_rewards.append(all_rewards.mean().item())
    wandb.log({"reward_hist": wandb.Histogram(all_rewards.detach().cpu().numpy())})
    wandb.log({"mean_reward": mean_rewards[-1]})
    wandb.log({"random img batch": [wandb.Image(Image.fromarray(img), caption=prompt) for img, prompt in zip(batch_imgs, prompts)]})

    # sample some images with the consistent prompt for visualization
    sample_imgs, sample_rewards, _, _ = sample_and_calculate_rewards(sample_prompts, pipe, args.img_size, args.cfg, args.num_timesteps, decoding_fn, reward_fn, 'cuda')
    wandb.log({"sample img batch": [wandb.Image(Image.fromarray(img), caption=prompt + f', {reward.item()}') for img, prompt, reward in zip(sample_imgs, sample_prompts, sample_rewards)]})

    # save the model
    pipe.save_pretrained(args.output_dir)

    wandb.finish()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
diffusers
transformers
accelerate
xformers
wandb
fastprogress
matplotlib
git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------