├── README.md ├── diffusion_lm ├── README.md ├── catsample.py ├── configs │ ├── config.yaml │ ├── ft.yaml │ └── model │ │ ├── medium.yaml │ │ └── small.yaml ├── eval.py ├── finetune_cond_data.py ├── gfn.py ├── graph_lib.py ├── load_model.py ├── losses.py ├── model │ ├── __init__.py │ ├── ema.py │ ├── fused_add_dropout_scale.py │ ├── rotary.py │ ├── transformer.py │ └── utils.py ├── noise_lib.py ├── requirements.txt ├── results.ipynb ├── run_sample.py ├── run_sample_cond.py ├── run_train.py ├── sampling.py ├── train.py └── utils.py ├── inverse_diffusion ├── README.md ├── requirements.txt ├── samples │ └── cls_finetuning_samples.png └── src │ ├── create_data.py │ ├── finetune_posterior.py │ ├── models │ ├── __init__.py │ ├── classifiers.py │ ├── denoisers.py │ ├── langevin.py │ └── samplers.py │ ├── train_classifier.py │ ├── train_prior.py │ ├── utils │ ├── __init__.py │ ├── args.py │ ├── data_loaders.py │ ├── diffusers │ │ ├── __init__.py │ │ ├── pipelines │ │ │ ├── .DS_Store │ │ │ ├── ddim_gfn │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddim_gfn.py │ │ │ ├── ddpm_dp │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddpm_dp.py │ │ │ └── ddpm_gfn │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddpm.py │ │ └── schedulers │ │ │ ├── scheduling_ddim_gfn.py │ │ │ ├── scheduling_ddpm_dp.py │ │ │ ├── scheduling_ddpm_gfn.py │ │ │ ├── scheduling_edm_euler_gfn.py │ │ │ └── scheduling_sde_ve_gfn.py │ ├── diffusion.py │ ├── fid_evaluation.py │ ├── gfn_diffusion.py │ ├── pytorch_utils.py │ ├── simple_io.py │ └── visualization.py │ └── visualize_runs.py ├── offline_RL ├── IQL_PyTorch │ ├── LICENSE │ ├── README.md │ ├── main.py │ ├── requirements.txt │ ├── results.py │ └── src │ │ ├── iql.py │ │ ├── policy.py │ │ ├── util.py │ │ └── value_functions.py ├── README.md ├── bc_models │ └── .gitkeep ├── model.py ├── q_models │ └── .gitkeep ├── qflow_offline.py └── train_bc.py ├── rtb_diffusion ├── README.md ├── __init__.py ├── classifer_guidance_posterior.py ├── energies │ ├── __init__.py │ ├── base_set.py │ ├── posterior_2Dgmm.py │ └── twenty_five_gmm.py ├── finetune_posterior.py ├── gflownet_losses.py ├── models │ ├── __init__.py │ ├── architectures.py │ └── gfn.py ├── plot_utils.py ├── pretrain_prior.py ├── pretrained │ └── prior.pt └── utils.py └── text_to_image ├── README.md ├── dataset └── drawbench │ └── data_meta.json ├── dpok_utils.py ├── environment.yaml ├── eval_div.py ├── eval_div_clip.py ├── eval_div_dino.py ├── eval_rew.py ├── img └── rabbit.gif ├── install_image_reward.sh ├── pipeline_stable_diffusion_extended.py ├── reward_model.py ├── scheduling_ddim_extended.py ├── test.py ├── train_gfn.py ├── train_online_pg.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Amortizing intractable inference in diffusion models for vision, language, and control 2 | 3 | This repository contains code for learning posteriors with diffusion priors and arbitrary constraints using _relative trajectory balance_ (RTB) introduced in 4 | 5 | **Amortizing intractable inference in diffusion models for vision, language, and control** 6 | 7 | Siddarth Venkatraman*, Moksh Jain*, Luca Scimeca*, Minsu Kim*, Marcin Sendera*, Mohsin Hasan, Luke Rowe, Sarthak Mittal, Pablo Lemos, Emmanuel Bengio, Alexandre Adam, Jarrid Rector-Brooks, Yoshua Bengio, Glen Berseth, Nikolay Malkin 8 | 9 | [arXiv](https://arxiv.org/abs/) 10 | 11 | 12 | The code and documentation for running the experiments is structured in subdirectories corresponding to each experiment. 
13 | 14 | - Class-conditional posterior sampling from unconditional diffusion priors (§3.1) in `inverse_diffusion/` 15 | - Fine-tuning a text-to-image diffusion model (§3.2) in `text_to_image/` 16 | - Text infilling with discrete diffusion language models (§3.3) in `diffusion_lm/` 17 | - KL-constrained policy search in offline reinforcement learning (§3.4) in `offline_RL/` 18 | - Learning posterior of diffusion model sampling a mixture of 25 Gaussians (§1) in `rtb_diffusion/` 19 | -------------------------------------------------------------------------------- /diffusion_lm/README.md: -------------------------------------------------------------------------------- 1 | Code is based on [`louaaron/Score-Entropy-Discrete-Diffusion`](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion). 2 | 3 | To train the reward model, please follow instructions at [GFNOrg/gfn-lm-tuning](https://github.com/GFNOrg/gfn-lm-tuning/tree/main/infill_subj_arithmetic). 4 | 5 | The entry point to the code is `finetune_cond_data.py`. The following command with the right path to the reward model should run the experiments in the paper 6 | 7 | ```bash 8 | python finetune_cond_data.py reward_type=story sampling_len=15 likelihood_model= wandb_mode=online loss_type=vargrad save_dir="" 9 | ``` 10 | 11 | For evaluation use the `eval.py` script. 12 | ```bash 13 | python eval.py --load_checkpoint_path --save_path .pkl.gz 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /diffusion_lm/catsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def gumbel_softmax(categorical_probs, hard=False, eps=1e-9): 6 | logits = categorical_probs.clamp(min=1e-9).log() 7 | return F.gumbel_softmax(logits, hard=hard) 8 | 9 | 10 | def sample_categorical(categorical_probs, method="hard"): 11 | if method == "hard": 12 | gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log() 13 | return (categorical_probs / gumbel_norm).argmax(dim=-1) 14 | else: 15 | raise ValueError(f"Method {method} for sampling categorical variables is not valid.") 16 | -------------------------------------------------------------------------------- /diffusion_lm/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # This config is taken from https://github.com/louaaron/Score-Entropy-Discrete-Diffusion. 2 | defaults: 3 | - _self_ 4 | - model: small 5 | - override hydra/launcher: submitit_slurm 6 | 7 | ngpus: 2 8 | tokens: 50257 9 | 10 | training: 11 | batch_size: 512 12 | accum: 1 13 | n_iters: 1300001 14 | snapshot_freq: 50000 15 | log_freq: 50 16 | eval_freq: 100 17 | snapshot_freq_for_preemption: 10000 18 | weight: standard 19 | snapshot_sampling: True 20 | ema: 0.9999 21 | 22 | data: 23 | train: openwebtext 24 | valid: wikitext103 25 | cache_dir: 26 | 27 | graph: 28 | type: absorb 29 | file: 30 | report_all: False 31 | 32 | noise: 33 | type: loglinear 34 | sigma_min: 1e-4 35 | sigma_max: 20 36 | 37 | sampling: 38 | predictor: euler 39 | steps: 128 40 | noise_removal: True 41 | 42 | eval: 43 | batch_size: 512 44 | perplexity: True 45 | perplexity_batch_size: 32 46 | 47 | optim: 48 | weight_decay: 0 49 | optimizer: AdamW 50 | lr: 3e-4 51 | beta1: 0.9 52 | beta2: 0.999 53 | eps: 1e-8 54 | warmup: 2500 55 | grad_clip: 1. 
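# Note (added for clarity): a negative grad_clip disables gradient clipping, and warmup is the
# number of linear learning-rate warmup steps (see optimization_manager in losses.py).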
56 | 57 | 58 | hydra: 59 | run: 60 | dir: 61 | sweep: 62 | dir: 63 | subdir: ${hydra.job.num} 64 | launcher: 65 | max_num_timeout: 100000 66 | # timeout_min: 10079 67 | partition: g40x 68 | account: stanford 69 | mem_gb: 96 70 | cpus_per_task: 40 71 | gpus_per_node: ${ngpus} 72 | constraint: null 73 | -------------------------------------------------------------------------------- /diffusion_lm/configs/ft.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | model_name: louaaron/sedd-small 3 | sampling_len: 32 4 | learning_rate: 1e-5 5 | learning_rate_Z: 1e-3 6 | temperature: 1.0 7 | grad_accumulation: 8 8 | num_steps: 1500 9 | eps: 1e-5 10 | detach_freq: 0 11 | back_and_forth: False 12 | device: "cuda" 13 | likelihood_model: 14 | reward_batch_size: 16 15 | reward_type: dummy 16 | seed: 42 17 | seq_len: 32 18 | wandb_mode: disabled 19 | cutoff: 0.1 20 | loss_type: "tb" 21 | buffer_size: 1000 22 | sim_tolerance: 0.25 23 | prioritization: reward 24 | backward: False 25 | warmup_steps: 20 26 | reward_invtemp_end: 1.2 27 | reward_invtemp_start: 0.9 28 | reward_temp_sched_steps: 5000 29 | save_dir: "." 30 | exp_name: "test" -------------------------------------------------------------------------------- /diffusion_lm/configs/model/medium.yaml: -------------------------------------------------------------------------------- 1 | name: medium 2 | type: ddit 3 | hidden_size: 1024 4 | cond_dim: 128 5 | length: 1024 6 | n_blocks: 24 7 | n_heads: 16 8 | scale_by_sigma: True 9 | dropout: 0.1 -------------------------------------------------------------------------------- /diffusion_lm/configs/model/small.yaml: -------------------------------------------------------------------------------- 1 | name: small 2 | type: ddit 3 | hidden_size: 768 4 | cond_dim: 128 5 | length: 1024 6 | n_blocks: 12 7 | n_heads: 12 8 | scale_by_sigma: True 9 | dropout: 0.1 -------------------------------------------------------------------------------- /diffusion_lm/graph_lib.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.cuda.amp import custom_fwd, custom_bwd 7 | 8 | 9 | from catsample import sample_categorical 10 | 11 | def get_graph(config, device): 12 | if config.graph.type == "uniform": 13 | return Uniform(config.tokens) 14 | elif config.graph.type == "absorb": 15 | return Absorbing(config.tokens) 16 | else: 17 | raise ValueError(f"Graph {config.graph.type} not valid") 18 | 19 | 20 | def unsqueeze_as(x, y, back=True): 21 | if back: 22 | return x.view(*x.shape, *((1,) * (len(y.shape) - len(x.shape)))) 23 | else: 24 | return x.view(*((1,) * (len(y.shape) - len(x.shape))), *x.shape) 25 | 26 | 27 | class Graph(abc.ABC): 28 | 29 | @property 30 | def dim(self): 31 | pass 32 | 33 | @property 34 | def absorb(self): 35 | """ 36 | Whether input {dim - 1} is an absorbing state (used for denoising to always remove the mask). 37 | """ 38 | pass 39 | 40 | 41 | @abc.abstractmethod 42 | def rate(self, i): 43 | """ 44 | Computes the i-th column of the rate matrix Q, where i is [B_1, ..., B_n]. 45 | 46 | This is intended to compute the "forward" rate of p(X_t | X_0 = i). 47 | """ 48 | pass 49 | 50 | 51 | @abc.abstractmethod 52 | def transp_rate(self, i): 53 | """ 54 | Computes the i-th row of the rate matrix Q. 55 | 56 | Can be used to compute the reverse rate. 
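In `reverse_rate`, this transposed rate is multiplied elementwise by the score and the
        entry at the current state is reset so that each rate vector sums to zero.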
57 | """ 58 | pass 59 | 60 | 61 | @abc.abstractmethod 62 | def transition(self, i, sigma): 63 | """ 64 | Computes the i-th column of the transition matrix e^{sigma Q}. 65 | """ 66 | pass 67 | 68 | 69 | def sample_transition(self, i, sigma): 70 | """ 71 | Samples the transition vector. 72 | """ 73 | transition_vector = self.transition(i, sigma) 74 | return sample_categorical(transition_vector, method="hard") 75 | 76 | 77 | def reverse_rate(self, i, score): 78 | """ 79 | Constructs the reverse rate. Which is score * transp_rate 80 | """ 81 | normalized_rate = self.transp_rate(i) * score 82 | 83 | normalized_rate.scatter_(-1, i[..., None], torch.zeros_like(normalized_rate)) 84 | normalized_rate.scatter_(-1, i[..., None], -normalized_rate.sum(dim=-1, keepdim=True)) 85 | return normalized_rate 86 | 87 | def sample_rate(self, i, rate): 88 | return sample_categorical(F.one_hot(i, num_classes=self.dim).to(rate) + rate) 89 | 90 | 91 | @abc.abstractmethod 92 | def staggered_score(self, score, dsigma): 93 | """ 94 | Computes p_{sigma - dsigma}(z) / p_{sigma}(x), which is approximated with 95 | e^{-{dsigma} E} score 96 | """ 97 | pass 98 | 99 | 100 | @abc.abstractmethod 101 | def sample_limit(self, *batch_dims): 102 | """ 103 | Sample the limiting distribution. Returns the probability vector as well. 104 | """ 105 | pass 106 | 107 | 108 | @abc.abstractmethod 109 | def score_entropy(self, score, sigma, x, x0): 110 | """ 111 | Computes the score entropy function (with requisite constant normalization) 112 | """ 113 | pass 114 | 115 | 116 | class Uniform(Graph): 117 | """ 118 | Everything goes to everything else. Normalized down by dimension to avoid blowup. 119 | """ 120 | def __init__(self, dim): 121 | self._dim = dim 122 | 123 | @property 124 | def dim(self): 125 | return self._dim 126 | 127 | @property 128 | def absorb(self): 129 | return False 130 | 131 | def rate(self, i): 132 | edge = torch.ones(*i.shape, self.dim, device=i.device) / self.dim 133 | edge = edge.scatter(-1, i[..., None], - (self.dim - 1) / self.dim) 134 | return edge 135 | 136 | def transp_rate(self, i): 137 | return self.rate(i) 138 | 139 | def transition(self, i, sigma): 140 | trans = torch.ones(*i.shape, self.dim, device=i.device) * (1 - (-sigma[..., None]).exp()) / self.dim 141 | trans = trans.scatter(-1, i[..., None], torch.zeros_like(trans)) 142 | trans = trans.scatter(-1, i[..., None], 1 - trans.sum(dim=-1, keepdim=True)) 143 | return trans 144 | 145 | def transp_transition(self, i, sigma): 146 | return self.transition(i, sigma) 147 | 148 | def sample_transition(self, i, sigma): 149 | move_chance = 1 - (-sigma).exp() 150 | move_indices = torch.rand(*i.shape, device=i.device) < move_chance 151 | i_pert = torch.where(move_indices, torch.randint_like(i, self.dim), i) 152 | return i_pert 153 | 154 | def staggered_score(self, score, dsigma): 155 | dim = score.shape[-1] 156 | epow = (-dsigma).exp()[..., None] 157 | return ((epow - 1) / (dim * epow)) * score.sum(dim=-1, keepdim=True) + score / epow 158 | 159 | def sample_limit(self, *batch_dims): 160 | return torch.randint(0, self.dim, batch_dims) 161 | 162 | def score_entropy(self, score, sigma, x, x0): 163 | esigm1 = torch.where( 164 | sigma < 0.5, 165 | torch.expm1(sigma), 166 | torch.exp(sigma) - 1 167 | ) 168 | ratio = 1 - self.dim / (esigm1 + self.dim) 169 | 170 | # negative term 171 | neg_term = score.mean(dim=-1) - torch.gather(score, -1, x[..., None]).squeeze(-1) / self.dim 172 | # no move means scaling by the uniform ratio. 
move means alter only one ratio away from 1 173 | neg_term = torch.where( 174 | x == x0, 175 | ratio * neg_term, 176 | torch.gather(score, -1, x0[..., None]).squeeze(-1) / esigm1 + neg_term 177 | ) 178 | 179 | # constant factor 180 | const = torch.where( 181 | x == x0, 182 | (self.dim - 1) / self.dim * ratio * (ratio.log() - 1), 183 | ((-ratio.log() - 1) / ratio - (self.dim - 2)) / self.dim 184 | ) 185 | 186 | #positive term 187 | sexp = score.exp() 188 | pos_term = sexp.mean(dim=-1) - torch.gather(sexp, -1, x[..., None]).squeeze(-1) / self.dim 189 | return pos_term - neg_term + const 190 | 191 | 192 | class Absorbing(Graph): 193 | def __init__(self, dim): 194 | super().__init__() 195 | self._dim = dim 196 | 197 | @property 198 | def dim(self): 199 | return self._dim + 1 200 | 201 | @property 202 | def absorb(self): 203 | return True 204 | 205 | def rate(self, i): 206 | # edge = - F.one_hot(i, num_classes=self.dim) 207 | # edge.scatter_add_(-1, i[..., None], torch.ones_like(edge[..., :1])) 208 | return F.one_hot((self.dim - 1) * torch.ones_like(i), num_classes=self.dim) - F.one_hot(i, num_classes=self.dim) 209 | 210 | def transp_rate(self, i): 211 | edge = -F.one_hot(i, num_classes=self.dim) 212 | edge[i == self.dim - 1] += 1 213 | return edge 214 | 215 | def transition(self, i, sigma): 216 | pass 217 | 218 | def transp_transition(self, i, sigma): 219 | sigma = unsqueeze_as(sigma, i[..., None]) 220 | edge = (-sigma).exp() * F.one_hot(i, num_classes=self.dim) 221 | edge += torch.where( 222 | i == self.dim - 1, 223 | 1 - (-sigma).squeeze(-1).exp(), 224 | 0 225 | )[..., None] 226 | return edge 227 | 228 | def sample_transition(self, i, sigma): 229 | move_chance = 1 - (-sigma).exp() 230 | move_indices = torch.rand(*i.shape, device=i.device) < move_chance 231 | i_pert = torch.where(move_indices, self.dim - 1, i) 232 | return i_pert 233 | 234 | def staggered_score(self, score, dsigma): 235 | out_score = score.clone() # yeah yeah whatever we should probably do this 236 | extra_const = (1 - (dsigma).exp()) * score.sum(dim=-1) 237 | # import pdb; pdb.set_trace(); 238 | out_score = out_score * dsigma.exp()[:, None] 239 | # mask = torch.zeros_like(score) 240 | # mask[..., -1] = 1 241 | # # score.masked_fill_(mask.bool(), extra_const) 242 | # score = score + mask * extra_const 243 | out_score[..., -1] = out_score[..., -1] + extra_const 244 | return score 245 | 246 | def sample_limit(self, *batch_dims): 247 | return (self.dim - 1) * torch.ones(*batch_dims, dtype=torch.int64).requires_grad_(False) 248 | 249 | def score_entropy(self, score, sigma, x, x0): 250 | rel_ind = x == self.dim - 1 251 | esigm1 = torch.where( 252 | sigma < 0.5, 253 | torch.expm1(sigma), 254 | torch.exp(sigma) - 1 255 | ) 256 | 257 | ratio = 1 / esigm1.expand_as(x)[rel_ind] 258 | other_ind = x0[rel_ind] 259 | 260 | # negative_term 261 | neg_term = ratio * torch.gather(score[rel_ind], -1, other_ind[..., None]).squeeze(-1) 262 | 263 | #positive term 264 | pos_term = score[rel_ind][:, :-1].exp().sum(dim=-1) 265 | 266 | # constant term 267 | const = ratio * (ratio.log() - 1) 268 | 269 | entropy = torch.zeros(*x.shape, device=x.device) 270 | entropy[rel_ind] += pos_term - neg_term + const 271 | return entropy 272 | -------------------------------------------------------------------------------- /diffusion_lm/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from model import SEDD 4 | import utils 5 | from model.ema import ExponentialMovingAverage 6 | import 
graph_lib 7 | import noise_lib 8 | 9 | from omegaconf import OmegaConf 10 | 11 | def load_model_hf(dir, device): 12 | score_model = SEDD.from_pretrained(dir).to(device) 13 | graph = graph_lib.get_graph(score_model.config, device) 14 | noise = noise_lib.get_noise(score_model.config).to(device) 15 | return score_model, graph, noise 16 | 17 | 18 | def load_model_local(root_dir, device): 19 | cfg = utils.load_hydra_config_from_run(root_dir) 20 | graph = graph_lib.get_graph(cfg, device) 21 | noise = noise_lib.get_noise(cfg).to(device) 22 | score_model = SEDD(cfg).to(device) 23 | ema = ExponentialMovingAverage(score_model.parameters(), decay=cfg.training.ema) 24 | 25 | ckpt_dir = os.path.join(root_dir, "checkpoints-meta", "checkpoint.pth") 26 | loaded_state = torch.load(ckpt_dir, map_location=device) 27 | 28 | score_model.load_state_dict(loaded_state['model']) 29 | ema.load_state_dict(loaded_state['ema']) 30 | 31 | ema.store(score_model.parameters()) 32 | ema.copy_to(score_model.parameters()) 33 | return score_model, graph, noise 34 | 35 | 36 | def load_model(root_dir, device): 37 | try: 38 | return load_model_hf(root_dir, device) 39 | except: 40 | return load_model_local(root_dir, device) -------------------------------------------------------------------------------- /diffusion_lm/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import graph_lib 6 | from model import utils as mutils 7 | 8 | 9 | def get_loss_fn(noise, graph, train, sampling_eps=1e-3, lv=False): 10 | 11 | def loss_fn(model, batch, cond=None, t=None, perturbed_batch=None): 12 | """ 13 | Batch shape: [B, L] int. D given from graph 14 | """ 15 | 16 | if t is None: 17 | if lv: 18 | raise NotImplementedError("Yeah I gotta do this later") 19 | else: 20 | t = (1 - sampling_eps) * torch.rand(batch.shape[0], device=batch.device) + sampling_eps 21 | 22 | sigma, dsigma = noise(t) 23 | 24 | if perturbed_batch is None: 25 | perturbed_batch = graph.sample_transition(batch, sigma[:, None]) 26 | 27 | log_score_fn = mutils.get_score_fn(model, train=train, sampling=False) 28 | log_score = log_score_fn(perturbed_batch, sigma) 29 | loss = graph.score_entropy(log_score, sigma[:, None], perturbed_batch, batch) 30 | 31 | loss = (dsigma[:, None] * loss).sum(dim=-1) 32 | 33 | return loss 34 | 35 | return loss_fn 36 | 37 | 38 | def get_optimizer(config, params): 39 | if config.optim.optimizer == 'Adam': 40 | optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, config.optim.beta2), eps=config.optim.eps, 41 | weight_decay=config.optim.weight_decay) 42 | elif config.optim.optimizer == 'AdamW': 43 | optimizer = optim.AdamW(params, lr=config.optim.lr, betas=(config.optim.beta1, config.optim.beta2), eps=config.optim.eps, 44 | weight_decay=config.optim.weight_decay) 45 | else: 46 | raise NotImplementedError( 47 | f'Optimizer {config.optim.optimizer} not supported yet!') 48 | 49 | return optimizer 50 | 51 | 52 | def optimization_manager(config): 53 | """Returns an optimize_fn based on `config`.""" 54 | 55 | def optimize_fn(optimizer, 56 | scaler, 57 | params, 58 | step, 59 | lr=config.optim.lr, 60 | warmup=config.optim.warmup, 61 | grad_clip=config.optim.grad_clip): 62 | """Optimizes with warmup and gradient clipping (disabled if negative).""" 63 | scaler.unscale_(optimizer) 64 | 65 | if warmup > 0: 66 | for g in optimizer.param_groups: 67 | g['lr'] = lr * np.minimum(step / warmup, 
1.0) 68 | if grad_clip >= 0: 69 | torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) 70 | 71 | scaler.step(optimizer) 72 | scaler.update() 73 | 74 | return optimize_fn 75 | 76 | 77 | def get_step_fn(noise, graph, train, optimize_fn, accum): 78 | loss_fn = get_loss_fn(noise, graph, train) 79 | 80 | accum_iter = 0 81 | total_loss = 0 82 | 83 | def step_fn(state, batch, cond=None): 84 | nonlocal accum_iter 85 | nonlocal total_loss 86 | 87 | model = state['model'] 88 | 89 | if train: 90 | optimizer = state['optimizer'] 91 | scaler = state['scaler'] 92 | loss = loss_fn(model, batch, cond=cond).mean() / accum 93 | 94 | scaler.scale(loss).backward() 95 | 96 | accum_iter += 1 97 | total_loss += loss.detach() 98 | if accum_iter == accum: 99 | accum_iter = 0 100 | 101 | state['step'] += 1 102 | optimize_fn(optimizer, scaler, model.parameters(), step=state['step']) 103 | state['ema'].update(model.parameters()) 104 | optimizer.zero_grad() 105 | 106 | loss = total_loss 107 | total_loss = 0 108 | else: 109 | with torch.no_grad(): 110 | ema = state['ema'] 111 | ema.store(model.parameters()) 112 | ema.copy_to(model.parameters()) 113 | loss = loss_fn(model, batch, cond=cond).mean() 114 | ema.restore(model.parameters()) 115 | 116 | return loss 117 | 118 | return step_fn -------------------------------------------------------------------------------- /diffusion_lm/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import SEDD -------------------------------------------------------------------------------- /diffusion_lm/model/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 
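        Note: when `use_num_updates` is enabled, the effective decay used below is
            min(decay, (1 + num_updates) / (10 + num_updates)),
        so the average tracks the parameters more closely early in training.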
42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / 47 | (10 + self.num_updates)) 48 | one_minus_decay = 1.0 - decay 49 | with torch.no_grad(): 50 | parameters = [p for p in parameters if p.requires_grad] 51 | for s_param, param in zip(self.shadow_params, parameters): 52 | s_param.sub_(one_minus_decay * (s_param - param)) 53 | 54 | 55 | def copy_to(self, parameters): 56 | """ 57 | Copy current parameters into given collection of parameters. 58 | 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | updated with the stored moving averages. 62 | """ 63 | parameters = [p for p in parameters if p.requires_grad] 64 | for s_param, param in zip(self.shadow_params, parameters): 65 | if param.requires_grad: 66 | param.data.copy_(s_param.data) 67 | 68 | def store(self, parameters): 69 | """ 70 | Save the current parameters for restoring later. 71 | 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | temporarily stored. 75 | """ 76 | self.collected_params = [param.clone() for param in parameters] 77 | 78 | def restore(self, parameters): 79 | """ 80 | Restore the parameters stored with the `store` method. 81 | Useful to validate the model with EMA parameters without affecting the 82 | original optimization process. Store the parameters before the 83 | `copy_to` method. After validation (or model saving), use this to 84 | restore the former parameters. 85 | 86 | Args: 87 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 88 | updated with the stored parameters. 89 | """ 90 | for c_param, param in zip(self.collected_params, parameters): 91 | param.data.copy_(c_param.data) 92 | 93 | def state_dict(self): 94 | return dict(decay=self.decay, num_updates=self.num_updates, 95 | shadow_params=self.shadow_params) 96 | 97 | def load_state_dict(self, state_dict): 98 | self.decay = state_dict['decay'] 99 | self.num_updates = state_dict['num_updates'] 100 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /diffusion_lm/model/fused_add_dropout_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | from torch import Tensor 5 | 6 | # flags required to enable jit fusion kernels 7 | torch._C._jit_set_profiling_mode(False) 8 | torch._C._jit_set_profiling_executor(False) 9 | torch._C._jit_override_can_fuse_on_cpu(True) 10 | torch._C._jit_override_can_fuse_on_gpu(True) 11 | 12 | 13 | def bias_dropout_add_scale( 14 | x: Tensor, bias: Optional[Tensor], scale: Tensor, residual: Optional[Tensor], prob: float, training: bool 15 | ) -> Tensor: 16 | if bias is not None: 17 | out = scale * F.dropout(x + bias, p=prob, training=training) 18 | else: 19 | out = scale * F.dropout(x, p=prob, training=training) 20 | 21 | if residual is not None: 22 | out = residual + out 23 | return out 24 | 25 | 26 | def get_bias_dropout_add_scale(training): 27 | def _bias_dropout_add(x, bias, scale, residual, prob): 28 | return bias_dropout_add_scale(x, bias, scale, residual, prob, training) 29 | 30 | return _bias_dropout_add 31 | 32 | 33 | def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: 34 | return x * (1 + scale) + shift 35 | 36 | 37 | @torch.jit.script 38 | def bias_dropout_add_scale_fused_train( 39 | x: Tensor, bias: Optional[Tensor], scale: Tensor, 
residual: Optional[Tensor], prob: float 40 | ) -> Tensor: 41 | return bias_dropout_add_scale(x, bias, scale, residual, prob, True) 42 | 43 | 44 | @torch.jit.script 45 | def bias_dropout_add_scale_fused_inference( 46 | x: Tensor, bias: Optional[Tensor], scale: Tensor, residual: Optional[Tensor], prob: float 47 | ) -> Tensor: 48 | return bias_dropout_add_scale(x, bias, scale, residual, prob, False) 49 | 50 | @torch.jit.script 51 | def modulate_fused(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: 52 | return modulate(x, shift, scale) -------------------------------------------------------------------------------- /diffusion_lm/model/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Rotary(torch.nn.Module): 6 | def __init__(self, dim, base=10_000): 7 | super().__init__() 8 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 9 | self.register_buffer("inv_freq", inv_freq) 10 | self.seq_len_cached = None 11 | self.cos_cached = None 12 | self.sin_cached = None 13 | 14 | def forward(self, x, seq_dim=1): 15 | seq_len = x.shape[seq_dim] 16 | if seq_len != self.seq_len_cached: 17 | self.seq_len_cached = seq_len 18 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 19 | freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone()) 20 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 21 | # dims are: batch, seq_len, qkv, head, dim 22 | self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1) 23 | self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1) 24 | # This makes the transformation on v an identity. 25 | self.cos_cached[:,:,2,:,:].fill_(1.) 26 | self.sin_cached[:,:,2,:,:].fill_(0.) 27 | 28 | return self.cos_cached, self.sin_cached 29 | 30 | 31 | def rotate_half(x): 32 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 33 | return torch.cat( 34 | (-x2, x1), dim=-1 35 | ) 36 | 37 | 38 | @torch.jit.script 39 | def _apply_rotary_pos_emb_torchscript(qkv, cos, sin): 40 | return (qkv * cos) + (rotate_half(qkv) * sin) 41 | 42 | 43 | def apply_rotary_pos_emb(qkv, cos, sin): 44 | try: 45 | import flash_attn.layers.rotary 46 | cos = cos[0,:,0,0,:cos.shape[-1]//2] 47 | sin = sin[0,:,0,0,:sin.shape[-1]//2] 48 | return flash_attn.layers.rotary.apply_rotary_emb_qkv_( 49 | qkv, cos, sin 50 | ) 51 | except: 52 | return _apply_rotary_pos_emb_torchscript(qkv, cos, sin) -------------------------------------------------------------------------------- /diffusion_lm/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_model_fn(model, train=False): 6 | """Create a function to give the output of the score-based model. 7 | 8 | Args: 9 | model: The score model. 10 | train: `True` for training and `False` for evaluation. 11 | mlm: If the input model is a mlm and models the base probability 12 | 13 | Returns: 14 | A model function. 15 | """ 16 | 17 | def model_fn(x, sigma): 18 | """Compute the output of the score-based model. 19 | 20 | Args: 21 | x: A mini-batch of input data. 22 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 23 | for different models. 
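        (In this codebase the conditioning variable passed in is `sigma`, the total noise at the sampled time step.)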
24 | 25 | Returns: 26 | A tuple of (model output, new mutable states) 27 | """ 28 | if train: 29 | model.train() 30 | else: 31 | model.eval() 32 | 33 | # otherwise output the raw values (we handle mlm training in losses.py) 34 | return model(x, sigma) 35 | 36 | return model_fn 37 | 38 | 39 | def get_score_fn(model, train=False, sampling=False): 40 | if sampling: 41 | assert not train, "Must sample in eval mode" 42 | model_fn = get_model_fn(model, train=train) 43 | 44 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 45 | def score_fn(x, sigma): 46 | sigma = sigma.reshape(-1) 47 | score = model_fn(x, sigma) 48 | 49 | if sampling: 50 | # when sampling return true score (not log used for training) 51 | return score.exp() 52 | 53 | return score 54 | 55 | return score_fn 56 | -------------------------------------------------------------------------------- /diffusion_lm/noise_lib.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def get_noise(config): 8 | if config.noise.type == "geometric": 9 | return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max) 10 | elif config.noise.type == "loglinear": 11 | return LogLinearNoise() 12 | else: 13 | raise ValueError(f"{config.noise.type} is not a valid noise") 14 | 15 | 16 | class Noise(abc.ABC, nn.Module): 17 | """ 18 | Baseline forward method to get the total + rate of noise at a timestep 19 | """ 20 | def forward(self, t): 21 | return self.total_noise(t), self.rate_noise(t) 22 | 23 | """ 24 | Assume time goes from 0 to 1 25 | """ 26 | @abc.abstractmethod 27 | def rate_noise(self, t): 28 | """ 29 | Rate of change of noise ie g(t) 30 | """ 31 | pass 32 | 33 | @abc.abstractmethod 34 | def total_noise(self, t): 35 | """ 36 | Total noise ie \int_0^t g(t) dt + g(0) 37 | """ 38 | pass 39 | 40 | 41 | class GeometricNoise(Noise, nn.Module): 42 | def __init__(self, sigma_min=1e-3, sigma_max=1, learnable=False): 43 | super().__init__() 44 | self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]) 45 | if learnable: 46 | self.sigmas = nn.Parameter(self.sigmas) 47 | self.empty = nn.Parameter(torch.tensor(0.0)) 48 | 49 | def rate_noise(self, t): 50 | return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log()) 51 | 52 | def total_noise(self, t): 53 | return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t 54 | 55 | 56 | class LogLinearNoise(Noise, nn.Module): 57 | """ 58 | Log Linear noise schedule built so that 1 - 1/e^(n(t)) interpolates between 0 and ~1 59 | when t goes from 0 to 1. 
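    (Exactly: 1 - exp(-total_noise(t)) = (1 - eps) * t, so the per-token move probability grows linearly and reaches 1 - eps at t = 1.)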
Used for absorbing 60 | 61 | Total noise is -log(1 - (1 - eps) * t), so the sigma will be (1 - eps) * t 62 | """ 63 | def __init__(self, eps=1e-3): 64 | super().__init__() 65 | self.eps = eps 66 | self.empty = nn.Parameter(torch.tensor(0.0)) 67 | 68 | def rate_noise(self, t): 69 | return (1 - self.eps) / (1 - (1 - self.eps) * t) 70 | 71 | def total_noise(self, t): 72 | return -torch.log1p(-(1 - self.eps) * t) 73 | 74 | -------------------------------------------------------------------------------- /diffusion_lm/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | aiohttp==3.9.4 3 | aiosignal==1.3.1 4 | antlr4-python3-runtime==4.9.3 5 | appdirs==1.4.4 6 | async-timeout==4.0.3 7 | attrs==23.2.0 8 | beartype==0.14.1 9 | better-abc==0.0.3 10 | certifi==2024.7.4 11 | charset-normalizer==2.1.1 12 | click==8.1.7 13 | cloudpickle==3.0.0 14 | cmake==3.25.0 15 | datasets==2.17.1 16 | dill==0.3.8 17 | docker-pycreds==0.4.0 18 | einops==0.7.0 19 | fancy-einsum==0.0.3 20 | filelock==3.9.0 21 | frozenlist==1.4.1 22 | fsspec==2023.10.0 23 | gitdb==4.0.11 24 | gitpython==3.1.42 25 | huggingface-hub==0.21.1 26 | hydra-core==1.3.2 27 | hydra-submitit-launcher==1.2.0 28 | idna==3.7 29 | jaxtyping==0.2.25 30 | jinja2==3.1.4 31 | lit==15.0.7 32 | markdown-it-py==3.0.0 33 | markupsafe==2.1.3 34 | mdurl==0.1.2 35 | mpmath==1.3.0 36 | multidict==6.0.5 37 | multiprocess==0.70.16 38 | networkx==3.2.1 39 | ninja==1.11.1.1 40 | numpy==1.24.1 41 | omegaconf==2.3.0 42 | packaging==23.2 43 | pandas==2.2.1 44 | pillow==10.3.0 45 | protobuf==4.25.3 46 | psutil==5.9.8 47 | pyarrow==15.0.0 48 | pyarrow-hotfix==0.6 49 | pygments==2.17.2 50 | python-dateutil==2.8.2 51 | pytz==2024.1 52 | pyyaml==6.0.1 53 | regex==2023.12.25 54 | requests==2.32.2 55 | rich==13.7.0 56 | safetensors==0.4.2 57 | sentry-sdk==1.40.6 58 | setproctitle==1.3.3 59 | six==1.16.0 60 | smmap==5.0.1 61 | submitit==1.5.1 62 | sympy==1.12 63 | tokenizers==0.15.2 64 | torch==2.2.0 65 | tqdm==4.66.3 66 | transformer-lens==1.14.0 67 | transformers==4.38.1 68 | triton==2.0.0 69 | typeguard==2.13.3 70 | typing-extensions==4.8.0 71 | tzdata==2024.1 72 | urllib3==1.26.19 73 | wandb==0.16.3 74 | xxhash==3.4.1 75 | yarl==1.9.4 -------------------------------------------------------------------------------- /diffusion_lm/run_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from load_model import load_model 5 | from transformers import GPT2TokenizerFast, GPT2LMHeadModel 6 | import torch.nn.functional as F 7 | import sampling 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description="Generate some samples") 12 | parser.add_argument("--model_path", default="louaaron/sedd-small", type=str) 13 | parser.add_argument("--dataset", default="wikitext103", type=str) 14 | parser.add_argument("--batch_size", type=int, default=8) 15 | parser.add_argument("--steps", type=int, default=512) 16 | args = parser.parse_args() 17 | 18 | 19 | device = torch.device('cuda') 20 | model, graph, noise = load_model(args.model_path, device) 21 | tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 22 | eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device) 23 | sampling_fn = sampling.get_pc_sampler( 24 | graph, noise, (args.batch_size, 128), 'analytic', args.steps, device=device 25 | ) 26 | samples = sampling_fn(model) 27 | with torch.no_grad(): 28 | eval_out = eval_model(samples, labels=samples) 29 | 
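        # Generative perplexity under GPT-2-large: logits are transposed to (batch, vocab, length)
        # for F.cross_entropy, and the per-token loss is exponentiated and averaged.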
logits = eval_out.logits.transpose(-1, -2) 30 | pplx = F.cross_entropy(logits[..., :-1], samples[..., 1:], reduction="none").exp().mean() 31 | text_samples = tokenizer.batch_decode(samples) 32 | 33 | for i in text_samples: 34 | print(i) 35 | print("=================================================") 36 | 37 | print(f"Perplexity: {pplx.item()}") 38 | 39 | if __name__=="__main__": 40 | main() -------------------------------------------------------------------------------- /diffusion_lm/run_sample_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from load_model import load_model 5 | from transformers import GPT2TokenizerFast, GPT2LMHeadModel 6 | import sampling 7 | import torch.nn.functional as F 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="Generate some samples") 11 | parser.add_argument("--model_path", default="louaaron/sedd-medium", type=str) 12 | parser.add_argument("--dataset", default="wikitext103", type=str) 13 | parser.add_argument("--batch_size", type=int, default=1) 14 | parser.add_argument("--steps", type=int, default=512) 15 | parser.add_argument("--prefix", type=str, default="Hi, my name is") 16 | parser.add_argument("--suffix", type=str, default=" and that's why I'm late.") 17 | args = parser.parse_args() 18 | 19 | tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 20 | 21 | prefix_ids = tokenizer(args.prefix).input_ids 22 | suffix_ids = tokenizer(args.suffix).input_ids 23 | input_ids = prefix_ids + suffix_ids 24 | input_locs = list(range(len(prefix_ids))) + list(range(1024-len(suffix_ids), 1024)) 25 | 26 | # more generaly commands can be defined with something like below: 27 | # input_ids = [0, 1, 512, 8080, 50256, 20000] 28 | # input_locs = [5, 6, 19, 20, 1000, 10001] 29 | 30 | 31 | input_ids = torch.tensor(input_ids, device="cuda")[None].repeat(args.batch_size, 1) 32 | 33 | def proj_fun(x): 34 | x[:, input_locs] = input_ids 35 | return x 36 | 37 | device = torch.device('cuda') 38 | model, graph, noise = load_model(args.model_path, device) 39 | 40 | eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device) 41 | sampling_fn = sampling.get_pc_sampler( 42 | graph, noise, (args.batch_size, 128), 'analytic', args.steps, device=device, proj_fun=proj_fun 43 | ) 44 | 45 | samples = proj_fun(sampling_fn(model)) 46 | with torch.no_grad(): 47 | eval_out = eval_model(samples, labels=samples) 48 | logits = eval_out.logits.transpose(-1, -2) 49 | pplx = F.cross_entropy(logits[..., :-1], samples[..., 1:], reduction="none").exp().mean() 50 | text_samples = tokenizer.batch_decode(samples) 51 | text_samples = tokenizer.batch_decode(samples) 52 | for i in text_samples: 53 | print(i) 54 | print("=================================================") 55 | print(f"Perplexity: {pplx.item()}") 56 | 57 | if __name__=="__main__": 58 | main() -------------------------------------------------------------------------------- /diffusion_lm/run_train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import os.path 4 | import gc 5 | from itertools import chain 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | import torch.nn.functional as F 12 | 13 | import data 14 | import losses 15 | import sampling 16 | import graph_lib 17 | import noise_lib 18 | import utils 19 | from model import SEDD 20 | from model.ema import 
ExponentialMovingAverage 21 | from transformers import GPT2TokenizerFast, GPT2LMHeadModel 22 | 23 | 24 | torch.backends.cudnn.benchmark = True 25 | # torch.autograd.set_detect_anomaly(True) 26 | 27 | 28 | def setup(rank, world_size, port): 29 | os.environ["MASTER_ADDR"] = "localhost" 30 | os.environ["MASTER_PORT"] = str(port) 31 | 32 | # initialize the process group 33 | dist.init_process_group( 34 | "nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(minutes=30) 35 | ) 36 | 37 | 38 | def cleanup(): 39 | dist.destroy_process_group() 40 | 41 | 42 | def run_multiprocess(rank, world_size, cfg, port): 43 | try: 44 | setup(rank, world_size, port) 45 | _run(rank, world_size, cfg) 46 | finally: 47 | cleanup() 48 | 49 | 50 | def _run(rank, world_size, cfg): 51 | torch.cuda.set_device(rank) 52 | work_dir = cfg.work_dir 53 | 54 | # Create directories for experimental logs 55 | sample_dir = os.path.join(work_dir, "samples") 56 | checkpoint_dir = os.path.join(work_dir, "checkpoints") 57 | checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth") 58 | if rank == 0: 59 | utils.makedirs(sample_dir) 60 | utils.makedirs(checkpoint_dir) 61 | utils.makedirs(os.path.dirname(checkpoint_meta_dir)) 62 | 63 | # logging 64 | if rank == 0: 65 | logger = utils.get_logger(os.path.join(work_dir, "logs")) 66 | def mprint(msg): 67 | if rank == 0: 68 | logger.info(msg) 69 | 70 | mprint(work_dir) 71 | mprint(cfg) 72 | device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 73 | if device.type == "cuda": 74 | mprint("Found {} CUDA devices.".format(torch.cuda.device_count())) 75 | for i in range(torch.cuda.device_count()): 76 | props = torch.cuda.get_device_properties(i) 77 | mprint( 78 | "{} \t Memory: {:.2f}GB".format( 79 | props.name, props.total_memory / (1024 ** 3) 80 | ) 81 | ) 82 | else: 83 | mprint("WARNING: Using device {}".format(device)) 84 | mprint(f"Found {os.cpu_count()} total number of CPUs.") 85 | 86 | # build token graph 87 | graph = graph_lib.get_graph(cfg, device) 88 | 89 | # build score model 90 | score_model = SEDD(cfg).to(device) 91 | score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True) 92 | 93 | num_parameters = sum(p.numel() for p in score_model.parameters()) 94 | mprint(f"Number of parameters in the model: {num_parameters}") 95 | 96 | ema = ExponentialMovingAverage( 97 | score_model.parameters(), decay=cfg.training.ema) 98 | mprint(score_model) 99 | mprint(f"EMA: {ema}") 100 | 101 | # build noise 102 | noise = noise_lib.get_noise(cfg).to(device) 103 | noise = DDP(noise, device_ids=[rank], static_graph=True) 104 | sampling_eps = 1e-5 105 | 106 | 107 | # build optimization state 108 | optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters())) 109 | mprint(f"Optimizer: {optimizer}") 110 | scaler = torch.cuda.amp.GradScaler() 111 | mprint(f"Scaler: {scaler}") 112 | state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0) 113 | 114 | 115 | # load in state 116 | state = utils.restore_checkpoint(checkpoint_meta_dir, state, device) 117 | initial_step = int(state['step']) 118 | 119 | 120 | # load in tokenizer 121 | tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 122 | 123 | # Build data iterators 124 | train_ds, eval_ds = data.get_dataloaders(cfg) 125 | 126 | # mprint(f"Length of datasets: {len(train_ds)}, {len(eval_ds)}") 127 | 128 | train_iter = iter(train_ds) 129 | eval_iter = iter(eval_ds) 130 | 131 | # Build one-step 
training and evaluation functions 132 | optimize_fn = losses.optimization_manager(cfg) 133 | train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum) 134 | eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum) 135 | 136 | 137 | if cfg.training.snapshot_sampling: 138 | sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length) 139 | sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device) 140 | 141 | num_train_steps = cfg.training.n_iters 142 | mprint(f"Starting training loop at step {initial_step}.") 143 | 144 | 145 | while state['step'] < num_train_steps + 1: 146 | step = state['step'] 147 | 148 | 149 | if cfg.data.train != "text8": 150 | batch = next(train_iter)['input_ids'].to(device) 151 | else: 152 | batch = next(train_iter).to(device) 153 | loss = train_step_fn(state, batch) 154 | 155 | # flag to see if there was movement ie a full batch got computed 156 | if step != state['step']: 157 | if step % cfg.training.log_freq == 0: 158 | dist.all_reduce(loss) 159 | loss /= world_size 160 | 161 | mprint("step: %d, training_loss: %.5e" % (step, loss.item())) 162 | 163 | if step % cfg.training.snapshot_freq_for_preemption == 0 and rank == 0: 164 | utils.save_checkpoint(checkpoint_meta_dir, state) 165 | 166 | if step % cfg.training.eval_freq == 0: 167 | if cfg.data.valid != "text8": 168 | eval_batch = next(eval_iter)['input_ids'].to(device) 169 | else: 170 | eval_batch = next(train_iter).to(device) 171 | eval_loss = eval_step_fn(state, eval_batch) 172 | 173 | dist.all_reduce(eval_loss) 174 | eval_loss /= world_size 175 | 176 | mprint("step: %d, evaluation_loss: %.5e" % (step, eval_loss.item())) 177 | 178 | if step > 0 and step % cfg.training.snapshot_freq == 0 or step == num_train_steps: 179 | # Save the checkpoint. 
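            # Checkpoints are indexed by snapshot count (step // snapshot_freq) rather than by raw step.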
180 | save_step = step // cfg.training.snapshot_freq 181 | if rank == 0: 182 | utils.save_checkpoint(os.path.join( 183 | checkpoint_dir, f'checkpoint_{save_step}.pth'), state) 184 | 185 | # Generate and save samples 186 | if cfg.training.snapshot_sampling: 187 | mprint(f"Generating text at step: {step}") 188 | 189 | this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step)) 190 | utils.makedirs(this_sample_dir) 191 | 192 | ema.store(score_model.parameters()) 193 | ema.copy_to(score_model.parameters()) 194 | sample = sampling_fn(score_model) 195 | ema.restore(score_model.parameters()) 196 | 197 | sentences = tokenizer.batch_decode(sample) 198 | 199 | file_name = os.path.join(this_sample_dir, f"sample_{rank}.txt") 200 | with open(file_name, 'w') as file: 201 | for sentence in sentences: 202 | file.write(sentence + "\n") 203 | file.write("============================================================================================\n") 204 | 205 | if cfg.eval.perplexity: 206 | with torch.no_grad(): 207 | eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device).eval() 208 | batches = sample.shape[0] // cfg.eval.perplexity_batch_size 209 | total_perplexity = 0 210 | for i in range(batches): 211 | s = sample[i * cfg.eval.perplexity_batch_size:(i + 1) * cfg.eval.perplexity_batch_size] 212 | loss, logits = eval_model(s, labels=s)[:2] 213 | logits = logits.transpose(-1, -2) 214 | perplexity = F.cross_entropy(logits[..., :-1], s[..., 1:], reduction="none").mean(dim=-1).exp().mean() 215 | total_perplexity += perplexity 216 | total_perplexity /= batches 217 | dist.all_reduce(total_perplexity) 218 | total_perplexity /= world_size 219 | mprint(f"Generative Perplexity at step: {step}. Perplexity: {total_perplexity:.3f}.") 220 | 221 | del eval_model, logits, loss 222 | 223 | dist.barrier() 224 | -------------------------------------------------------------------------------- /diffusion_lm/sampling.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import torch.nn.functional as F 4 | from catsample import sample_categorical 5 | 6 | from model import utils as mutils 7 | 8 | _PREDICTORS = {} 9 | 10 | 11 | def register_predictor(cls=None, *, name=None): 12 | """A decorator for registering predictor classes.""" 13 | 14 | def _register(cls): 15 | if name is None: 16 | local_name = cls.__name__ 17 | else: 18 | local_name = name 19 | if local_name in _PREDICTORS: 20 | raise ValueError( 21 | f'Already registered model with name: {local_name}') 22 | _PREDICTORS[local_name] = cls 23 | return cls 24 | 25 | if cls is None: 26 | return _register 27 | else: 28 | return _register(cls) 29 | 30 | 31 | def get_predictor(name): 32 | return _PREDICTORS[name] 33 | 34 | 35 | 36 | class Predictor(abc.ABC): 37 | """The abstract class for a predictor algorithm.""" 38 | 39 | def __init__(self, graph, noise): 40 | super().__init__() 41 | self.graph = graph 42 | self.noise = noise 43 | 44 | @abc.abstractmethod 45 | def update_fn(self, score_fn, x, t, step_size): 46 | """One update of the predictor. 47 | 48 | Args: 49 | score_fn: score function 50 | x: A PyTorch tensor representing the current state 51 | t: A Pytorch tensor representing the current time step. 52 | 53 | Returns: 54 | x: A PyTorch tensor of the next state. 
55 | """ 56 | pass 57 | 58 | 59 | @register_predictor(name="euler") 60 | class EulerPredictor(Predictor): 61 | def update_fn(self, score_fn, x, t, step_size): 62 | sigma, dsigma = self.noise(t) 63 | score = score_fn(x, sigma) 64 | 65 | rev_rate = step_size * dsigma[..., None] * self.graph.reverse_rate(x, score) 66 | x = self.graph.sample_rate(x, rev_rate) 67 | return x 68 | 69 | @register_predictor(name="none") 70 | class NonePredictor(Predictor): 71 | def update_fn(self, score_fn, x, t, step_size): 72 | return x 73 | 74 | 75 | @register_predictor(name="analytic") 76 | class AnalyticPredictor(Predictor): 77 | def update_fn(self, score_fn, x, t, step_size): 78 | curr_sigma = self.noise(t)[0] 79 | next_sigma = self.noise(t - step_size)[0] 80 | dsigma = curr_sigma - next_sigma 81 | 82 | score = score_fn(x, curr_sigma) 83 | 84 | stag_score = self.graph.staggered_score(score, dsigma) 85 | probs = stag_score * self.graph.transp_transition(x, dsigma) 86 | return sample_categorical(probs) 87 | 88 | 89 | class Denoiser: 90 | def __init__(self, graph, noise): 91 | self.graph = graph 92 | self.noise = noise 93 | 94 | def update_fn(self, score_fn, x, t): 95 | sigma = self.noise(t)[0] 96 | 97 | score = score_fn(x, sigma) 98 | stag_score = self.graph.staggered_score(score, sigma) 99 | probs = stag_score * self.graph.transp_transition(x, sigma) 100 | # truncate probabilities 101 | if self.graph.absorb: 102 | probs = probs[..., :-1] 103 | 104 | #return probs.argmax(dim=-1) 105 | return sample_categorical(probs) 106 | 107 | 108 | def get_sampling_fn(config, graph, noise, batch_dims, eps, device): 109 | 110 | sampling_fn = get_pc_sampler(graph=graph, 111 | noise=noise, 112 | batch_dims=batch_dims, 113 | predictor=config.sampling.predictor, 114 | steps=config.sampling.steps, 115 | denoise=config.sampling.noise_removal, 116 | eps=eps, 117 | device=device) 118 | 119 | return sampling_fn 120 | 121 | 122 | def get_pc_sampler(graph, noise, batch_dims, predictor, steps, denoise=True, eps=1e-5, device=torch.device('cpu'), proj_fun=lambda x: x): 123 | predictor = get_predictor(predictor)(graph, noise) 124 | projector = proj_fun 125 | denoiser = Denoiser(graph, noise) 126 | 127 | @torch.no_grad() 128 | def pc_sampler(model): 129 | sampling_score_fn = mutils.get_score_fn(model, train=False, sampling=True) 130 | x = graph.sample_limit(*batch_dims).to(device) 131 | timesteps = torch.linspace(1, eps, steps + 1, device=device) 132 | dt = (1 - eps) / steps 133 | 134 | for i in range(steps): 135 | t = timesteps[i] * torch.ones(x.shape[0], 1, device=device) 136 | x = projector(x) 137 | x = predictor.update_fn(sampling_score_fn, x, t, dt) 138 | 139 | 140 | if denoise: 141 | # denoising step 142 | x = projector(x) 143 | t = timesteps[-1] * torch.ones(x.shape[0], 1, device=device) 144 | x = denoiser.update_fn(sampling_score_fn, x, t) 145 | 146 | return x 147 | 148 | return pc_sampler 149 | 150 | -------------------------------------------------------------------------------- /diffusion_lm/train.py: -------------------------------------------------------------------------------- 1 | """Training and evaluation""" 2 | 3 | import hydra 4 | import os 5 | import numpy as np 6 | import run_train 7 | import utils 8 | import torch.multiprocessing as mp 9 | from hydra.core.hydra_config import HydraConfig 10 | from hydra.types import RunMode 11 | from omegaconf import OmegaConf, open_dict 12 | 13 | 14 | @hydra.main(version_base=None, config_path="configs", config_name="config") 15 | def main(cfg): 16 | ngpus = cfg.ngpus 17 | if "load_dir" 
in cfg: 18 | hydra_cfg_path = os.path.join(cfg.load_dir, ".hydra/hydra.yaml") 19 | hydra_cfg = OmegaConf.load(hydra_cfg_path).hydra 20 | 21 | cfg = utils.load_hydra_config_from_run(cfg.load_dir) 22 | 23 | work_dir = cfg.work_dir 24 | utils.makedirs(work_dir) 25 | else: 26 | hydra_cfg = HydraConfig.get() 27 | work_dir = hydra_cfg.run.dir if hydra_cfg.mode == RunMode.RUN else os.path.join(hydra_cfg.sweep.dir, hydra_cfg.sweep.subdir) 28 | utils.makedirs(work_dir) 29 | 30 | with open_dict(cfg): 31 | cfg.ngpus = ngpus 32 | cfg.work_dir = work_dir 33 | cfg.wandb_name = os.path.basename(os.path.normpath(work_dir)) 34 | 35 | # Run the training pipeline 36 | port = int(np.random.randint(10000, 20000)) 37 | logger = utils.get_logger(os.path.join(work_dir, "logs")) 38 | 39 | hydra_cfg = HydraConfig.get() 40 | if hydra_cfg.mode != RunMode.RUN: 41 | logger.info(f"Run id: {hydra_cfg.job.id}") 42 | 43 | try: 44 | mp.set_start_method("forkserver") 45 | mp.spawn(run_train.run_multiprocess, args=(ngpus, cfg, port), nprocs=ngpus, join=True) 46 | except Exception as e: 47 | logger.critical(e, exc_info=True) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() -------------------------------------------------------------------------------- /diffusion_lm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import logging 5 | from omegaconf import OmegaConf, open_dict 6 | 7 | 8 | def load_hydra_config_from_run(load_dir): 9 | cfg_path = os.path.join(load_dir, ".hydra/config.yaml") 10 | cfg = OmegaConf.load(cfg_path) 11 | return cfg 12 | 13 | 14 | def makedirs(dirname): 15 | os.makedirs(dirname, exist_ok=True) 16 | 17 | 18 | def get_logger(logpath, package_files=[], displaying=True, saving=True, debug=False): 19 | logger = logging.getLogger() 20 | if debug: 21 | level = logging.DEBUG 22 | else: 23 | level = logging.INFO 24 | 25 | if (logger.hasHandlers()): 26 | logger.handlers.clear() 27 | 28 | logger.setLevel(level) 29 | formatter = logging.Formatter('%(asctime)s - %(message)s') 30 | if saving: 31 | info_file_handler = logging.FileHandler(logpath, mode="a") 32 | info_file_handler.setLevel(level) 33 | info_file_handler.setFormatter(formatter) 34 | logger.addHandler(info_file_handler) 35 | if displaying: 36 | console_handler = logging.StreamHandler() 37 | console_handler.setLevel(level) 38 | console_handler.setFormatter(formatter) 39 | logger.addHandler(console_handler) 40 | 41 | for f in package_files: 42 | logger.info(f) 43 | with open(f, "r") as package_f: 44 | logger.info(package_f.read()) 45 | 46 | return logger 47 | 48 | 49 | def restore_checkpoint(ckpt_dir, state, device): 50 | if not os.path.exists(ckpt_dir): 51 | makedirs(os.path.dirname(ckpt_dir)) 52 | logging.warning(f"No checkpoint found at {ckpt_dir}. 
Returned the same state as input") 53 | return state 54 | else: 55 | loaded_state = torch.load(ckpt_dir, map_location=device) 56 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 57 | state['model'].module.load_state_dict(loaded_state['model'], strict=False) 58 | state['ema'].load_state_dict(loaded_state['ema']) 59 | state['step'] = loaded_state['step'] 60 | return state 61 | 62 | 63 | def save_checkpoint(ckpt_dir, state): 64 | saved_state = { 65 | 'optimizer': state['optimizer'].state_dict(), 66 | 'model': state['model'].module.state_dict(), 67 | 'ema': state['ema'].state_dict(), 68 | 'step': state['step'] 69 | } 70 | torch.save(saved_state, ckpt_dir) -------------------------------------------------------------------------------- /inverse_diffusion/README.md: -------------------------------------------------------------------------------- 1 | # Amortizing intractable inference in diffusion models for vision, language, and control 2 | Code to finetune a Diffusion posterior from an unconditional diffusion prior. The code base was written to be compatible with the [Diffusers](https://huggingface.co/docs/diffusers/en/index) library. 3 | 4 | ![alt text](samples/cls_finetuning_samples.png "Prior Posterior Samples") 5 | 6 | 7 | ## File Structure 8 | 9 | - `src/`: Contains all the source code. 10 | - `src/models`: Code for denoising and diffusion models. 11 | - `src/models/pretrained/`: Should contain pretrained weights. 12 | - `src/utils/`: Custom utility libraries. 13 | - Code to run experiments. 14 | - `results/`: Stores local results (can be redirected). 15 | 16 | 17 | ### Training a Prior Model Locally 18 | Train a prior model locally using `train_prior.py`. The results will be in saved in `./results/` folder unless the `--save_folder` argument is set. 19 | 20 | ### Fine-Tuning a Pretrained Model 21 | 22 | To finetune you must run the `finetune_posterior.py` file. Remember to substitute `PATH-TO-DATA` and `PATH-TO-RESULTS`. The supported datasets and `mnist` and `cifar-10`. For each, the prior model to use has already been hardcoded in our library, and will be automatically retrieved from the hugging face database. Should you want to try other prior models this can be done. Note, for best results, use prior models trained with large variance scheduling. 
23 | ##### Example Command 24 | ```bash 25 | python ../finetune_posterior.py --data_path PATH-TO-DATA --save_folder PATH-TO-RESULTS \ 26 | --load_path ./../models/pretrained --dataset mnist --lr 6e-4 \ 27 | --sampling_length 100 --batch_size 32 --accumulate_gradient_every 1 \ 28 | --epochs 5000 --finetune_class 7 --compute_fid True --checkpointing True \ 29 | --push_to_hf False --method gfn --exp_name example_run 30 | ``` 31 | 32 | # Paper 33 | 34 | If you have found any of our work useful for your own project, please cite our paper: 35 | ```bash 36 | citation 37 | ``` -------------------------------------------------------------------------------- /inverse_diffusion/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | apex==0.9.10dev 3 | deeplake==3.9.4 4 | denoising_diffusion_pytorch==1.10.12 5 | einops==0.8.0 6 | ema_pytorch==0.4.2 7 | fld==1.0.0 8 | gdown==5.2.0 9 | huggingface_hub==0.22.2 10 | matplotlib==3.8.0 11 | numpy==1.26.4 12 | packaging==24.0 13 | peft==0.10.0 14 | Pillow==10.3.0 15 | pytorch_fid==0.3.0 16 | scheduling_utils==0.2.3 17 | torch==2.2.0 18 | torchvision==0.16.2 19 | tqdm==4.66.1 20 | wandb==0.16.6 21 | -------------------------------------------------------------------------------- /inverse_diffusion/samples/cls_finetuning_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/inverse_diffusion/samples/cls_finetuning_samples.png -------------------------------------------------------------------------------- /inverse_diffusion/src/create_data.py: -------------------------------------------------------------------------------- 1 | from utils.args import fetch_args 2 | from utils.simple_io import * 3 | from PIL import Image 4 | 5 | import torchvision 6 | import gdown 7 | import deeplake 8 | 9 | # get arguments for the run 10 | args, state = fetch_args() 11 | 12 | if 'celeb' in args.dataset.lower(): 13 | dt_id = 'celeb-a' 14 | else: 15 | dt_id = args.dataset.lower() 16 | 17 | 18 | if args.dataset.lower() in ['mnist', 'celeba', 'cifar10']: 19 | splits = ['train', 'test'] 20 | for split in splits: 21 | # use deep lake 22 | deeplake.dataset(f"hub://activeloop/{dt_id}-{split}", access_method='local') 23 | 24 | elif 'utkface' in args.dataset.lower(): 25 | def setup_utkface_dataset(args): 26 | # Define Google Drive links 27 | # utkface_link = 'https://drive.google.com/file/d/1W-vm-rgSDsPA015wQQ9vWzquR_KvgBwe' 28 | # crop1_link = 'https://drive.google.com/file/d/19GNs2OPm0zvkR99nFlXNo1aQTB22HFCj' 29 | 30 | # Define destination paths 31 | utkface_zip_path = f'{args.data_path}/UTKFace.zip' 32 | crop1_zip_path = f'{args.data_path}/crop1.zip' 33 | 34 | # a file 35 | utkface_id = "1W-vm-rgSDsPA015wQQ9vWzquR_KvgBwe" 36 | crop1_id = "19GNs2OPm0zvkR99nFlXNo1aQTB22HFCj" 37 | gdown.download(id=utkface_id, output=utkface_zip_path) 38 | gdown.download(id=crop1_id, output=crop1_zip_path) 39 | 40 | # Extract ZIP files 41 | extract_zip(utkface_zip_path, f'{args.data_path}') 42 | extract_zip(crop1_zip_path, f'{args.data_path}') 43 | 44 | 45 | # Assuming fetch_args() provides necessary paths 46 | args, state = fetch_args() 47 | setup_utkface_dataset(args) 48 | 49 | #example run > python create_data.py --dataset utkface --data_path $SCRATCH/data/ 50 | -------------------------------------------------------------------------------- /inverse_diffusion/src/finetune_posterior.py: 
-------------------------------------------------------------------------------- 1 | from models.classifiers import CNN, ResNet 2 | from models.langevin import LangevinModel 3 | from utils.args import fetch_args 4 | from utils.gfn_diffusion import load_PPDGFN_from_diffusers,load_PPDGFN_from_diffusers, GFNFinetuneTrainer, ReinforceFinetuneTrainer 5 | from utils.pytorch_utils import seed_experiment, train_classifier, print_gpu_memory 6 | from utils.gfn_diffusion import load_PPDGFN_from_diffusers, GFNFinetuneTrainer, \ 7 | load_DDPM_from_diffusers, ReinforceFinetuneTrainer, BaselineGuidance 8 | from utils.pytorch_utils import seed_experiment, train_classifier, get_train_classifier 9 | from utils.simple_io import * 10 | 11 | import torch as T 12 | import numpy as np 13 | import torch.nn as nn 14 | 15 | 16 | # get arguments for the run 17 | args, state = fetch_args(exp_prepend='train_posterior') 18 | print(f"Running experiment on '{args.device}'") 19 | 20 | # -------- OVERRIDE ARGUMENTS (if you have to) -------- 21 | # args.dataset = 'mnist' 22 | # args.t_scale = 1 23 | # args.batch_size = 128 24 | # args.lr = 1e-5 25 | # args.lr_logZ = 5e-2 26 | args.learn_var = True 27 | 28 | logtwopi = np.log(2 * 3.14159265358979) 29 | 30 | seed_experiment(args.seed) 31 | 32 | # ------------ Load pretrained PosteriorPriorGFN from pretrained diffusion model --------- 33 | 34 | if args.method in['gfn', "reinforce"]: 35 | sampler = load_PPDGFN_from_diffusers(args) # note, this function may change some of the args in place 36 | else: 37 | # todo baseline case 38 | sampler = load_DDPM_from_diffusers(args) 39 | print(args) 40 | 41 | # -------- Train a reward model (if not already pretrained) --------- 42 | classifier = get_train_classifier(args, scheduler=sampler.get_scheduler()) 43 | 44 | # ---- add langevin if gfn and user specified (here because we need a classifier first) ---- 45 | if args.method == 'gfn' and args.langevin: 46 | 47 | log_reward = classifier if args.langevin else lambda x: T.zeros((x.shape[0],), device=args.device) 48 | 49 | problem_dim = int(args.channels * (args.image_size ** 2)) 50 | lgv_model = LangevinModel(problem_dim, args.lgv_t_dim, args.lgv_hidden_dim, 1, 51 | args.lgv_num_layers, args.lgv_zero_init) 52 | 53 | sampler.add_langevin(log_reward=log_reward, 54 | lgv_model=lgv_model, 55 | lgv_clip=args.lgv_clip, 56 | lgv_clipping=args.lgv_clipping) 57 | else: 58 | sampler.add_classifier(classifier=classifier) # add classifier for classifier guidance 59 | 60 | # ------------ Set learning params --------- 61 | if args.method == 'gfn': 62 | params = [param for param in sampler.posterior_node.get_parameters() if param.requires_grad] 63 | opt = T.optim.Adam([{'params': params, 64 | 'lr': args.lr}, 65 | {'params': [sampler.logZ], 66 | 'lr': args.lr_logZ, 67 | 'weight_decay':args.z_weight_decay}]) 68 | 69 | Trainer = GFNFinetuneTrainer 70 | else: 71 | opt = None 72 | Trainer = BaselineGuidance 73 | 74 | 75 | # ------------ Train posterior --------- 76 | if args.method in ['gfn', 'dp', 'lgd_mc']: 77 | trainer = Trainer( 78 | sampler=sampler, 79 | classifier=classifier, 80 | optimizer=opt, 81 | finetune_class=args.finetune_class, 82 | save_folder=args.save_folder, 83 | config=args 84 | ) 85 | elif args.method == 'reinforce': 86 | trainer = ReinforceFinetuneTrainer( 87 | sampler=sampler, 88 | classifier=classifier, 89 | optimizer=opt, 90 | finetune_class=args.finetune_class, 91 | save_folder=args.save_folder, 92 | config=args 93 | ) 94 | else: 95 | raise ValueError(f"Method '{args.method}' not 
recognized.") 96 | 97 | trainer.run( 98 | finetune_class=args.finetune_class, 99 | epochs=args.epochs, 100 | back_and_forth=args.back_and_forth, 101 | ) 102 | -------------------------------------------------------------------------------- /inverse_diffusion/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/inverse_diffusion/src/models/__init__.py -------------------------------------------------------------------------------- /inverse_diffusion/src/models/classifiers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torchvision.models as models 4 | 5 | 6 | class CNN(nn.Module): 7 | def __init__(self, input_size, channels, num_classes): 8 | super(CNN, self).__init__() 9 | self.input_shape = (channels, input_size, input_size) 10 | 11 | self.conv1 = nn.Sequential( 12 | nn.Conv2d( 13 | in_channels=channels, 14 | out_channels=16, 15 | kernel_size=5, 16 | stride=1, 17 | padding=2, 18 | ), 19 | nn.ReLU(), 20 | nn.MaxPool2d(kernel_size=2), 21 | ) 22 | self.conv2 = nn.Sequential( 23 | nn.Conv2d(16, 32, 5, 1, 2), 24 | nn.ReLU(), 25 | nn.MaxPool2d(2), 26 | ) 27 | self.embed = nn.Sequential(self.conv1, self.conv2) 28 | 29 | x = self.embed(torch.zeros((1,) + tuple(self.input_shape))) 30 | input_dim = x.view(x.size(0), -1).shape[-1] 31 | self.out = nn.Linear(input_dim, num_classes) 32 | 33 | def forward(self, x): 34 | x = self.conv1(x) 35 | x = self.conv2(x) 36 | x = x.view(x.size(0), -1) 37 | output = self.out(x) 38 | return output 39 | 40 | 41 | class ResNet(nn.Module): 42 | 43 | name = "ResNet" 44 | 45 | def __init__(self, input_size, channels, num_classes, depth=18, finetune=False): 46 | super(ResNet, self).__init__() 47 | self.input_shape = (channels, input_size, input_size) 48 | self.depth = depth 49 | self.num_classes = num_classes 50 | 51 | # create the embedding (sub) model 52 | self.resnet = self._init_resnet(finetune) 53 | 54 | if finetune: 55 | for name, param in self.resnet.named_parameters(): 56 | param.requires_grad = False 57 | 58 | # we dynamically create readout layer dimensions 59 | x = self.resnet(torch.zeros((2,) + tuple(self.input_shape))) 60 | 61 | # make sure to set the input dimension, so the paent class knows how to build the models 62 | input_dim = x.view(x.size(0), -1).shape[-1] 63 | 64 | # creates and initializes all models and associated optimizers within ensemble 65 | self.out_layer = self._create_readout_layer(input_dim, num_classes) 66 | 67 | def _init_resnet(self, finetune=False): 68 | 69 | # Dynamically select the ResNet architecture 70 | resnet_class = getattr(models, f'resnet{self.depth}') 71 | 72 | # Initialize a pre-trained model 73 | resnet_model = resnet_class(weights=models.ResNet18_Weights.DEFAULT if finetune else None) 74 | # resnet_model = resnet_class(weights=models.ResNet18_Weights.DEFAULT) 75 | 76 | # Modify the first convolutional layer (for different input channels) 77 | if self.input_shape[0] != 3: 78 | resnet_model.conv1 = torch.nn.Conv2d(self.input_shape[0], 64, 79 | kernel_size=(7, 7), 80 | stride=(2, 2), 81 | padding=(0, 3), 82 | bias=False) 83 | 84 | modules = list(resnet_model.children())[:-1] # Exclude the last fc layer 85 | resnet_model = nn.Sequential(*modules) 86 | 87 | return resnet_model 88 | 89 | def _create_readout_layer(self, input_dimension, output_dimension): 90 | return 
nn.Sequential(nn.Flatten(), 91 | nn.Linear(input_dimension, output_dimension)) 92 | 93 | def forward(self, x): 94 | x = self.resnet(x) 95 | return self.out_layer(x) 96 | -------------------------------------------------------------------------------- /inverse_diffusion/src/models/langevin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LangevinModel(nn.Module): 6 | def __init__(self, problem_dim: int, t_dim: int, hidden_dim: int = 256, out_dim: int = 1, num_layers: int = 3, 7 | zero_init: bool = False): 8 | super(LangevinModel, self).__init__() 9 | 10 | pe = torch.linspace(start=0.1, end=100, steps=t_dim)[None] 11 | 12 | self.timestep_phase = nn.Parameter(torch.randn(t_dim)[None]) 13 | 14 | self.lgv_model = nn.Sequential( 15 | nn.Linear(problem_dim + (2 * t_dim), hidden_dim), 16 | *[ 17 | nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(hidden_dim, hidden_dim), 20 | ) 21 | for _ in range(num_layers - 1) 22 | ], 23 | nn.GELU(), 24 | nn.Linear(hidden_dim, out_dim) 25 | ) 26 | 27 | self.register_buffer('pe', pe) 28 | 29 | if zero_init: 30 | self.lgv_model[-1].weight.data.fill_(0.0) 31 | self.lgv_model[-1].bias.data.fill_(0.01) 32 | 33 | def forward(self, x, t): 34 | bs, _, _, _ = x.shape 35 | t_sin = ((t * self.pe) + self.timestep_phase).sin() 36 | t_cos = ((t * self.pe) + self.timestep_phase).cos() 37 | t_emb = torch.cat([t_sin, t_cos], dim=-1) 38 | t_emb = t_emb.repeat(bs, 1) 39 | x = torch.flatten(x, start_dim=1) 40 | scaling_factor = self.lgv_model(torch.cat([x, t_emb], dim=-1)).reshape(bs, 1, 1, 1) 41 | return scaling_factor -------------------------------------------------------------------------------- /inverse_diffusion/src/train_classifier.py: -------------------------------------------------------------------------------- 1 | from models.classifiers import CNN, ResNet 2 | from utils.args import fetch_args 3 | from utils.data_loaders import get_dataset, cycle 4 | from utils.pytorch_utils import seed_experiment, train_classifier, NoContext 5 | from utils.simple_io import * 6 | from torch.cuda.amp import autocast 7 | 8 | import torch as T 9 | import numpy as np 10 | import gc 11 | 12 | # get arguments for the run 13 | args, state = fetch_args(exp_prepend='') 14 | print(f"Running experiment on '{args.device}'") 15 | 16 | seed_experiment(args.seed) 17 | 18 | if 'cnn' in args.classifier_model.lower(): 19 | classifier = CNN( 20 | input_size=args.image_size, 21 | channels=args.channels, 22 | num_classes=args.num_classes 23 | ) 24 | else: 25 | classifier = ResNet( 26 | input_size=args.image_size, 27 | channels=args.channels, 28 | num_classes=args.num_classes, 29 | depth=args.classifier_depth, 30 | finetune=args.classifier_pretrained 31 | ) 32 | 33 | pretrained_classifer = get_filenames( 34 | args.load_path, 35 | contains=[args.dataset, "classifier", f"_{args.multi_class_index}_"], 36 | ends_with='.pth' 37 | ) 38 | 39 | # if len(pretrained_classifer) == 0: 40 | print(f"No classifier '{args.model}' found for '{args.dataset}' data, \nTraining classifier...") 41 | trained_params, best_epoch, classifier = train_classifier( 42 | model=classifier, 43 | batch_size=128, 44 | data_path=args.data_path, 45 | dataset=args.dataset, 46 | channels=args.channels, 47 | image_size=args.image_size, 48 | x_tensor=args.x_tensor, 49 | y_tensor=args.y_tensor, 50 | multi_class_index=args.multi_class_index, 51 | workers=0 if args is None or 'linux' not in args.system else 4, 52 | use_cuda=args.use_cuda 53 | ) 54 | 
T.save(trained_params, f'{args.load_path}/{args.dataset}_classifier_{args.multi_class_index}_{best_epoch}.pth') 55 | # else: 56 | # print(f"A classifier already exists for {args.dataset}, classes: {args.multi_class_index}: {args.finetune_class}!") -------------------------------------------------------------------------------- /inverse_diffusion/src/train_prior.py: -------------------------------------------------------------------------------- 1 | from diffusers import DDPMScheduler, DDIMScheduler, DDPMPipeline 2 | from diffusers.models.unet_2d import UNet2DModel 3 | from models.denoisers import ScoreNet, Unet 4 | from utils.args import fetch_args 5 | from utils.data_loaders import get_dataset 6 | from utils.diffusers.schedulers.scheduling_ddpm_gfn import DDPMGFNScheduler 7 | from utils.pytorch_utils import * 8 | from utils.diffusion import GaussianDiffusion, DiffTrainer, DiffuserTrainer, TrainingConfig 9 | from diffusers.optimization import get_cosine_schedule_with_warmup 10 | 11 | import torch as T 12 | import numpy as np 13 | 14 | # get arguments for the run 15 | args, state = fetch_args(exp_prepend="train_prior") 16 | 17 | # -------- OVERRIDE ARGUMENTS (if you have to) -------- 18 | # args.epochs = 30000 19 | # 20 | # args.algo = 'mle' 21 | # args.t_scale = 1 22 | # args.batch_size = 128 23 | # args.lr = 1e-3 24 | # args.lr_logZ = 1e-1 25 | 26 | logtwopi = np.log(2 * 3.14159265358979) 27 | # ----------------------------------------------------- 28 | 29 | print(args) 30 | logger = Logger(args) 31 | logger.save_args() 32 | 33 | seed_experiment(args.seed) 34 | 35 | # Get dataset for training -- default args.dataset is 'mnist' 36 | 37 | data_dict = get_dataset( 38 | dataset=args.dataset, 39 | batch_size=args.batch_size, 40 | data_path=args.data_path, 41 | channels=args.channels, 42 | image_size=args.image_size, 43 | multi_class_index=args.multi_class_index, 44 | x_tensor=args.x_tensor, 45 | y_tensor=args.y_tensor, 46 | splits=args.splits, 47 | workers=0 if args is None or 'linux' not in args.system else 4 48 | ) 49 | train_dataset = data_dict['train_data'] 50 | train_loader = data_dict['train_loader'] 51 | x_dim = (args.image_size, args.image_size, args.channels) 52 | 53 | model = UNet2DModel( 54 | sample_size=args.image_size, # the target image resolution 55 | in_channels=args.channels, # the number of input channels, 3 for RGB images 56 | out_channels=args.channels, # the number of output channels 57 | layers_per_block=2, # how many ResNet layers to use per UNet block 58 | block_out_channels=args.block_out_channels 59 | ) 60 | 61 | print(f'Total params: \nFwd policy model: {(sum(p.numel() for p in model.parameters()) / 1e6):.2f}M ') 62 | 63 | 64 | noise_scheduler = DDPMGFNScheduler( 65 | num_train_timesteps=args.traj_length, 66 | beta_end=0.02, 67 | beta_schedule="linear", 68 | beta_start=0.0001, 69 | clip_sample=True, 70 | variance_type='fixed_large' 71 | ) 72 | 73 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 74 | lr_scheduler = get_cosine_schedule_with_warmup( 75 | optimizer=optimizer, 76 | num_warmup_steps=500, 77 | num_training_steps=args.epochs, 78 | ) 79 | 80 | trainer = DiffuserTrainer( 81 | config=TrainingConfig(args), 82 | model=model, 83 | noise_scheduler=noise_scheduler, 84 | optimizer=optimizer, 85 | train_dataloader=data_dict['train_loader'], 86 | lr_scheduler=lr_scheduler 87 | ) 88 | 89 | trainer.train() 90 | 91 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/inverse_diffusion/src/utils/__init__.py -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/inverse_diffusion/src/utils/diffusers/__init__.py -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/inverse_diffusion/src/utils/diffusers/pipelines/.DS_Store -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddim_gfn/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from diffusers.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_ddim_gfn": ["DDIMGFNPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_ddim_gfn import DDIMPipeline 10 | else: 11 | import sys 12 | 13 | sys.modules[__name__] = _LazyModule( 14 | __name__, 15 | globals()["__file__"], 16 | _import_structure, 17 | ) 18 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddim_gfn/pipeline_ddim_gfn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from utils.diffusers.schedulers.scheduling_ddim_gfn import DDIMGFNScheduler 20 | from diffusers.utils.torch_utils import randn_tensor 21 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | 23 | 24 | class DDIMGFNPipeline(DiffusionPipeline): 25 | r""" 26 | Pipeline for image generation. 27 | 28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 29 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 30 | 31 | Parameters: 32 | unet ([`UNet2DModel`]): 33 | A `UNet2DModel` to denoise the encoded image latents. 34 | scheduler ([`SchedulerMixin`]): 35 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 36 | [`DDPMScheduler`], or [`DDIMScheduler`], or [`DDIMGFNScheduler`]. 
37 | """ 38 | 39 | model_cpu_offload_seq = "unet" 40 | 41 | def __init__(self, unet, scheduler): 42 | super().__init__() 43 | 44 | # make sure scheduler can always be converted to DDIM 45 | scheduler = DDIMGFNScheduler.from_config(scheduler.config) 46 | 47 | self.register_modules(unet=unet, scheduler=scheduler) 48 | 49 | @torch.no_grad() 50 | def __call__( 51 | self, 52 | batch_size: int = 1, 53 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 54 | eta: float = 0.0, 55 | num_inference_steps: int = 50, 56 | use_clipped_model_output: Optional[bool] = None, 57 | output_type: Optional[str] = "pil", 58 | return_dict: bool = True, 59 | ) -> Union[ImagePipelineOutput, Tuple]: 60 | r""" 61 | The call function to the pipeline for generation. 62 | 63 | Args: 64 | batch_size (`int`, *optional*, defaults to 1): 65 | The number of images to generate. 66 | generator (`torch.Generator`, *optional*): 67 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 68 | generation deterministic. 69 | eta (`float`, *optional*, defaults to 0.0): 70 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 71 | to the [`~schedulers.DDIMGFNScheduler`], and is ignored in other schedulers. A value of `0` corresponds to 72 | DDIM and `1` corresponds to DDPM. 73 | num_inference_steps (`int`, *optional*, defaults to 50): 74 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 75 | expense of slower inference. 76 | use_clipped_model_output (`bool`, *optional*, defaults to `None`): 77 | If `True` or `False`, see documentation for [`DDIMGFNScheduler.step`]. If `None`, nothing is passed 78 | downstream to the scheduler (use `None` for schedulers which don't support this argument). 79 | output_type (`str`, *optional*, defaults to `"pil"`): 80 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 81 | return_dict (`bool`, *optional*, defaults to `True`): 82 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
83 | 84 | Example: 85 | 86 | ```py 87 | >>> from utils.diffusers.pipelines import DDIMGFNPipeline 88 | >>> import PIL.Image 89 | >>> import numpy as np 90 | 91 | >>> # load model and scheduler 92 | >>> pipe = DDIMGFNPipeline.from_pretrained("fusing/ddim_gfn-lsun-bedroom") 93 | 94 | >>> # run pipeline in inference (sample random noise and denoise) 95 | >>> image = pipe(eta=0.0, num_inference_steps=50) 96 | 97 | >>> # process image to PIL 98 | >>> image_processed = image.cpu().permute(0, 2, 3, 1) 99 | >>> image_processed = (image_processed + 1.0) * 127.5 100 | >>> image_processed = image_processed.numpy().astype(np.uint8) 101 | >>> image_pil = PIL.Image.fromarray(image_processed[0]) 102 | 103 | >>> # save image 104 | >>> image_pil.save("test.png") 105 | ``` 106 | 107 | Returns: 108 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 109 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 110 | returned where the first element is a list with the generated images 111 | """ 112 | 113 | # Sample gaussian noise to begin loop 114 | if isinstance(self.unet.config.sample_size, int): 115 | image_shape = ( 116 | batch_size, 117 | self.unet.config.in_channels, 118 | self.unet.config.sample_size, 119 | self.unet.config.sample_size, 120 | ) 121 | else: 122 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 123 | 124 | if isinstance(generator, list) and len(generator) != batch_size: 125 | raise ValueError( 126 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 127 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 128 | ) 129 | 130 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 131 | 132 | # set step values 133 | self.scheduler.set_timesteps(num_inference_steps) 134 | 135 | for t in self.progress_bar(self.scheduler.timesteps): 136 | # 1. predict noise model_output 137 | model_output = self.unet(image, t).sample 138 | 139 | # 2. 
predict previous mean of image x_t-1 and add variance depending on eta 140 | # eta corresponds to η in paper and should be between [0, 1] 141 | # do x_t -> x_t-1 142 | image = self.scheduler.step( 143 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 144 | ).prev_sample 145 | 146 | image = (image / 2 + 0.5).clamp(0, 1) 147 | image = image.cpu().permute(0, 2, 3, 1).numpy() 148 | if output_type == "pil": 149 | image = self.numpy_to_pil(image) 150 | 151 | if not return_dict: 152 | return (image,) 153 | 154 | return ImagePipelineOutput(images=image) 155 | 156 | def sample(self, *args, **kwargs): 157 | res = self(*args, convert=False, **kwargs)[0] 158 | if isinstance(res, ImagePipelineOutput): 159 | return (res.images * 2) - 1 160 | else: 161 | return (res * 2) - 1 162 | 163 | def eval(self): 164 | self.unet.eval() 165 | 166 | def train(self): 167 | self.unet.train() -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddpm_dp/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from diffusers.utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | _LazyModule, 6 | ) 7 | 8 | 9 | _import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} 10 | 11 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 12 | from utils.diffusers.pipelines import DDPMGFNPipeline 13 | 14 | else: 15 | import sys 16 | 17 | sys.modules[__name__] = _LazyModule( 18 | __name__, 19 | globals()["__file__"], 20 | _import_structure, 21 | ) 22 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddpm_dp/pipeline_ddpm_dp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from utils.diffusers.schedulers.scheduling_ddpm_dp import DDPMDPScheduler 21 | from diffusers.utils.torch_utils import randn_tensor 22 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 23 | 24 | 25 | class DDPMDPPipeline(DiffusionPipeline): 26 | r""" 27 | Pipeline for image generation. 28 | 29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 30 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 31 | 32 | Parameters: 33 | unet ([`UNet2DModel`]): 34 | A `UNet2DModel` to denoise the encoded image latents. 35 | scheduler ([`SchedulerMixin`]): 36 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 37 | [`DDPMScheduler`], or [`DDIMScheduler`]. 
38 | """ 39 | 40 | model_cpu_offload_seq = "unet" 41 | 42 | def __init__(self, unet, scheduler): 43 | super().__init__() 44 | # make sure scheduler can always be converted to DDIM 45 | scheduler = DDPMDPScheduler.from_config(scheduler.config) 46 | self.register_modules(unet=unet, scheduler=scheduler) 47 | 48 | @torch.no_grad() 49 | def __call__( 50 | self, 51 | batch_size: int = 1, 52 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 53 | num_inference_steps: int = 1000, 54 | output_type: Optional[str] = "pil", 55 | return_dict: bool = True, 56 | ) -> Union[ImagePipelineOutput, Tuple]: 57 | r""" 58 | The call function to the pipeline for generation. 59 | 60 | Args: 61 | batch_size (`int`, *optional*, defaults to 1): 62 | The number of images to generate. 63 | generator (`torch.Generator`, *optional*): 64 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 65 | generation deterministic. 66 | num_inference_steps (`int`, *optional*, defaults to 1000): 67 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 68 | expense of slower inference. 69 | output_type (`str`, *optional*, defaults to `"pil"`): 70 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 71 | return_dict (`bool`, *optional*, defaults to `True`): 72 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from diffusers import DDPMDPPipeline 78 | 79 | >>> # load model and scheduler 80 | >>> pipe = DDPMDPPipeline.from_pretrained("google/ddpm-cat-256") 81 | 82 | >>> # run pipeline in inference (sample random noise and denoise) 83 | >>> image = pipe().images[0] 84 | 85 | >>> # save image 86 | >>> image.save("ddpm_generated_image.png") 87 | ``` 88 | 89 | Returns: 90 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 91 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 92 | returned where the first element is a list with the generated images 93 | """ 94 | # Sample gaussian noise to begin loop 95 | if isinstance(self.unet.config.sample_size, int): 96 | image_shape = ( 97 | batch_size, 98 | self.unet.config.in_channels, 99 | self.unet.config.sample_size, 100 | self.unet.config.sample_size, 101 | ) 102 | else: 103 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 104 | 105 | if self.device.type == "mps": 106 | # randn does not work reproducibly on mps 107 | image = randn_tensor(image_shape, generator=generator) 108 | image = image.to(self.device) 109 | else: 110 | image = randn_tensor(image_shape, generator=generator, device=self.device) 111 | 112 | # set step values 113 | self.scheduler.set_timesteps(num_inference_steps) 114 | 115 | for t in self.progress_bar(self.scheduler.timesteps): 116 | # 1. predict noise model_output 117 | model_output = self.unet(image, t).sample 118 | 119 | # 2. 
compute previous image: x_t -> x_t-1 120 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample 121 | 122 | image = (image / 2 + 0.5).clamp(0, 1) 123 | image = image.cpu().permute(0, 2, 3, 1).numpy() 124 | if output_type == "pil": 125 | image = self.numpy_to_pil(image) 126 | 127 | if not return_dict: 128 | return (image,) 129 | 130 | return ImagePipelineOutput(images=image) 131 | 132 | def sample(self, *args, **kwargs): 133 | res = self(*args, output_type='tensor', **kwargs)[0] 134 | if isinstance(res, ImagePipelineOutput): 135 | return res.images 136 | else: 137 | return res 138 | 139 | def eval(self): 140 | self.unet.eval() 141 | 142 | def train(self): 143 | self.unet.train() 144 | 145 | def get_parameters(self): 146 | return self.unet.parameters() 147 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddpm_gfn/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from diffusers.utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | _LazyModule, 6 | ) 7 | 8 | 9 | _import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} 10 | 11 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 12 | from utils.diffusers.pipelines import DDPMGFNPipeline 13 | 14 | else: 15 | import sys 16 | 17 | sys.modules[__name__] = _LazyModule( 18 | __name__, 19 | globals()["__file__"], 20 | _import_structure, 21 | ) 22 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/diffusers/pipelines/ddpm_gfn/pipeline_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from utils.diffusers.schedulers.scheduling_ddpm_gfn import DDPMGFNScheduler 21 | from diffusers.utils.torch_utils import randn_tensor 22 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 23 | 24 | 25 | class DDPMGFNPipeline(DiffusionPipeline): 26 | r""" 27 | Pipeline for image generation. 28 | 29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 30 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 31 | 32 | Parameters: 33 | unet ([`UNet2DModel`]): 34 | A `UNet2DModel` to denoise the encoded image latents. 35 | scheduler ([`SchedulerMixin`]): 36 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 37 | [`DDPMScheduler`], or [`DDIMScheduler`]. 
38 | """ 39 | 40 | model_cpu_offload_seq = "unet" 41 | 42 | def __init__(self, unet, scheduler): 43 | super().__init__() 44 | # make sure scheduler can always be converted to DDIM 45 | scheduler = DDPMGFNScheduler.from_config(scheduler.config) 46 | self.register_modules(unet=unet, scheduler=scheduler) 47 | 48 | @torch.no_grad() 49 | def __call__( 50 | self, 51 | batch_size: int = 1, 52 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 53 | num_inference_steps: int = 1000, 54 | output_type: Optional[str] = "pil", 55 | return_dict: bool = True, 56 | ) -> Union[ImagePipelineOutput, Tuple]: 57 | r""" 58 | The call function to the pipeline for generation. 59 | 60 | Args: 61 | batch_size (`int`, *optional*, defaults to 1): 62 | The number of images to generate. 63 | generator (`torch.Generator`, *optional*): 64 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 65 | generation deterministic. 66 | num_inference_steps (`int`, *optional*, defaults to 1000): 67 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 68 | expense of slower inference. 69 | output_type (`str`, *optional*, defaults to `"pil"`): 70 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 71 | return_dict (`bool`, *optional*, defaults to `True`): 72 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from diffusers import DDPMGFNPipeline 78 | 79 | >>> # load model and scheduler 80 | >>> pipe = DDPMGFNPipeline.from_pretrained("google/ddpm-cat-256") 81 | 82 | >>> # run pipeline in inference (sample random noise and denoise) 83 | >>> image = pipe().images[0] 84 | 85 | >>> # save image 86 | >>> image.save("ddpm_generated_image.png") 87 | ``` 88 | 89 | Returns: 90 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 91 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 92 | returned where the first element is a list with the generated images 93 | """ 94 | # Sample gaussian noise to begin loop 95 | if isinstance(self.unet.config.sample_size, int): 96 | image_shape = ( 97 | batch_size, 98 | self.unet.config.in_channels, 99 | self.unet.config.sample_size, 100 | self.unet.config.sample_size, 101 | ) 102 | else: 103 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 104 | 105 | if self.device.type == "mps": 106 | # randn does not work reproducibly on mps 107 | image = randn_tensor(image_shape, generator=generator) 108 | image = image.to(self.device) 109 | else: 110 | image = randn_tensor(image_shape, generator=generator, device=self.device) 111 | 112 | # set step values 113 | self.scheduler.set_timesteps(num_inference_steps) 114 | 115 | for t in self.progress_bar(self.scheduler.timesteps): 116 | # 1. predict noise model_output 117 | model_output = self.unet(image, t).sample 118 | 119 | # 2. 
compute previous image: x_t -> x_t-1 120 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample 121 | 122 | image = (image / 2 + 0.5).clamp(0, 1) 123 | image = image.cpu().permute(0, 2, 3, 1).numpy() 124 | if output_type == "pil": 125 | image = self.numpy_to_pil(image) 126 | 127 | if not return_dict: 128 | return (image,) 129 | 130 | return ImagePipelineOutput(images=image) 131 | 132 | def sample(self, *args, **kwargs): 133 | res = self(*args, output_type='tensor', **kwargs)[0] 134 | if isinstance(res, ImagePipelineOutput): 135 | return res.images 136 | else: 137 | return res 138 | 139 | def eval(self): 140 | self.unet.eval() 141 | 142 | def train(self): 143 | self.unet.train() -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/fid_evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from itertools import combinations 4 | 5 | import numpy as np 6 | import torch 7 | from einops import rearrange, repeat 8 | from pytorch_fid.fid_score import calculate_frechet_distance 9 | from pytorch_fid.inception import InceptionV3 10 | from torch import cosine_similarity 11 | from torch.nn.functional import adaptive_avg_pool2d 12 | from tqdm.auto import tqdm 13 | from fld.metrics.FLD import FLD 14 | from fld.metrics.FID import FID 15 | from fld.metrics.AuthPct import AuthPct 16 | from fld.metrics.CTTest import CTTest 17 | from fld.metrics.KID import KID 18 | from fld.metrics.PrecisionRecall import PrecisionRecall 19 | 20 | 21 | class NoContext: 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, exc_type, exc_val, exc_tb): 26 | pass 27 | 28 | 29 | def num_to_groups(num, divisor): 30 | groups = num // divisor 31 | remainder = num % divisor 32 | arr = [divisor] * groups 33 | if remainder > 0: 34 | arr.append(remainder) 35 | return arr 36 | 37 | class DIVERSITY: 38 | def compute_metric( 39 | self, 40 | train_feat, 41 | test_feat, 42 | gen_feat, 43 | samples=5000, 44 | pairs=20000 45 | ): 46 | """ 47 | Computes the average cosine similarity between all pairs of elements in a tensor. 48 | 49 | Args: 50 | features: A PyTorch tensor of shape (BS, N). 51 | max_pairs: (Optional) Maximum number of pairs to consider. Defaults to None (all pairs). 52 | 53 | Returns: 54 | A float representing the average cosine similarity. 
55 | """ 56 | batch_size, num_features = train_feat.shape 57 | 58 | samples = min(samples, batch_size) 59 | 60 | # Ensure max_pairs doesn't exceed total possible pairs 61 | sampled_indices = np.random.choice(list(range(batch_size)), samples, replace=False) 62 | idx = np.array(list(combinations(sampled_indices, 2))) 63 | pairs = min(len(idx), pairs) 64 | pairs_idx = np.random.choice(list(range(len(idx))), pairs, replace=False) 65 | idx = idx[pairs_idx] 66 | 67 | # Return the average cosine similarity 68 | return cosine_similarity(train_feat[idx[:, 0]], train_feat[idx[:, 1]], dim=1).mean() 69 | 70 | 71 | class SCOREEvaluation: 72 | def __init__( 73 | self, 74 | batch_size, 75 | dl, 76 | sampler, 77 | channels=3, 78 | dl_test=None, 79 | accelerator=None, 80 | stats_dir="./results", 81 | device="cuda", 82 | num_fid_samples=50000, 83 | normalize_input=True, 84 | inception_block_idx=2048, 85 | ): 86 | self.batch_size = batch_size 87 | self.n_samples = num_fid_samples 88 | self.device = device 89 | self.channels = channels 90 | self.dl = dl 91 | self.dl_test = dl_test 92 | self.sampler = sampler 93 | self.stats_dir = stats_dir 94 | self.print_fn = print if accelerator is None else accelerator.print 95 | assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM 96 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx] 97 | if accelerator is not None: 98 | self.inception_v3 = accelerator.prepare(InceptionV3([block_idx], normalize_input=normalize_input)).to(device) 99 | else: 100 | self.inception_v3 = InceptionV3([block_idx], normalize_input=normalize_input).to(device) 101 | self.dataset_stats_loaded = False 102 | self.accelerator = accelerator 103 | accelerator.prepare(self.inception_v3) 104 | 105 | print(f"using {inception_block_idx} block for fid computation") 106 | 107 | def calculate_inception_features(self, samples): 108 | if self.channels == 1: 109 | samples = repeat(samples, "b 1 ... -> b c ...", c=3) 110 | 111 | self.inception_v3.eval() 112 | features = self.inception_v3(samples)[0] 113 | 114 | if features.size(2) != 1 or features.size(3) != 1: 115 | features = adaptive_avg_pool2d(features, output_size=(1, 1)) 116 | features = rearrange(features, "... 1 1 -> ...") 117 | return features 118 | 119 | @torch.inference_mode() 120 | def load_or_precalc_dataset_stats(self): 121 | path = os.path.join(self.stats_dir, "dataset_stats") 122 | try: 123 | ckpt = np.load(path + ".npz") 124 | self.train_real_features, self.test_real_features = torch.FloatTensor(ckpt["train_real_features"]), torch.FloatTensor(ckpt["test_real_features"]) 125 | self.print_fn("Dataset stats loaded from disk.") 126 | ckpt.close() 127 | except OSError: 128 | num_batches = int(math.ceil(self.n_samples / self.batch_size)) 129 | loaders = {'train': self.dl, 'test': self.dl_test} 130 | features = {"train_real_features": None, "test_real_features":None} 131 | for split in loaders.keys(): 132 | if loaders[split] is None: 133 | continue 134 | stacked_real_features = [] 135 | self.print_fn( 136 | f"Stacking Inception features for {split}:{self.n_samples} samples from the real dataset." 
137 | ) 138 | for _ in tqdm(range(num_batches)): 139 | try: 140 | real_samples = next(loaders[split]) 141 | except StopIteration: 142 | break 143 | if isinstance(real_samples, dict): 144 | real_samples = real_samples['images'] 145 | real_samples = real_samples.to(self.device) 146 | real_features = self.calculate_inception_features(real_samples) 147 | stacked_real_features.append(real_features) 148 | features[split + "_real_features"] = torch.cat(stacked_real_features, dim=0).cpu() 149 | 150 | self.train_real_features = features["train_real_features"] 151 | self.test_real_features = features["test_real_features"] 152 | np.savez_compressed(path, 153 | train_real_features=self.train_real_features.numpy(), 154 | test_real_features=self.test_real_features.numpy()) 155 | self.print_fn(f"Dataset stats cached to {path}.npz for future use.") 156 | self.dataset_stats_loaded = True 157 | print("generated features for real dataset") 158 | 159 | def fid_score(self, grad=False): 160 | 161 | context = torch.no_grad() if not grad else NoContext() 162 | with context: 163 | if not self.dataset_stats_loaded: 164 | self.load_or_precalc_dataset_stats() 165 | self.sampler.eval() 166 | batches = num_to_groups(self.n_samples, self.batch_size) 167 | stacked_fake_features = [] 168 | self.print_fn( 169 | f"Stacking Inception features for {self.n_samples} generated samples." 170 | ) 171 | for batch in tqdm(batches): 172 | fake_samples = self.sampler.sample(batch_size=batch) 173 | fake_features = self.calculate_inception_features(fake_samples) 174 | stacked_fake_features.append(fake_features) 175 | print("stacking features") 176 | stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu() 177 | print("features stacked") 178 | 179 | 180 | res = {} 181 | 182 | scores_fs = { 183 | 'COS-SIMILARITY': DIVERSITY(), 184 | 'FID': FID(), 185 | 'FLD': FLD(eval_feat="train"), 186 | # 'AuthPct': AuthPct(), 187 | # 'CTTest': CTTest(), 188 | 'KID': KID(ref_feat='train'), 189 | # 'Precision': PrecisionRecall(mode='Precision'), 190 | # 'Recall': PrecisionRecall(mode='Recall'), 191 | } 192 | for s_name, fn in scores_fs.items(): 193 | try: 194 | print(f"computing {s_name}...") 195 | res[s_name] = fn.compute_metric(stacked_fake_features, self.test_real_features, self.train_real_features) 196 | print(f"computed {s_name}!") 197 | except Exception as e: 198 | print(f"\nWARNING: score '{s_name}' could not be computed. 
\nException:\n{e.__class__.__name__}:{e}\n\n") 199 | # raise e 200 | 201 | return res 202 | 203 | -------------------------------------------------------------------------------- /inverse_diffusion/src/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import math 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_samples(samples, sample_logs=None, title='', show=False, save=False, filename=''): 6 | num_samples = len(samples) 7 | grid_size = math.ceil(math.sqrt(num_samples)) 8 | 9 | fig, axs = plt.subplots(nrows=grid_size, ncols=grid_size, figsize=(10, 10)) 10 | axs = axs.flatten() 11 | 12 | for i in range(grid_size * grid_size): 13 | axs[i].axis('off') 14 | if i < num_samples: 15 | axs[i].imshow(samples[i], vmin=-1, vmax=1, cmap='gray') 16 | if sample_logs is not None: 17 | axs[i].set_title(sample_logs[i], fontsize=26) 18 | 19 | fig.suptitle(title, fontsize=32) 20 | 21 | if save and filename: 22 | plt.savefig(filename) 23 | 24 | if show: 25 | plt.show() 26 | 27 | if save or show: 28 | plt.close(fig) # Close the figure after saving 29 | 30 | return fig, axs # Return the figure and axes for further use 31 | 32 | 33 | def compare_samples(samples1, samples2, sample1_title, sample2_title, sample1_logs=None, sample2_logs=None, dpi=300, 34 | show=False, save=False, filename=''): 35 | num_samples1 = len(samples1) 36 | num_samples2 = len(samples2) 37 | grid_size1 = math.ceil(math.sqrt(num_samples1)) 38 | grid_size2 = math.ceil(math.sqrt(num_samples2)) 39 | 40 | # Calculate the total grid size needed 41 | total_cols = grid_size1 + grid_size2 + 1 # +1 for space between sets 42 | total_rows = max(grid_size1, grid_size2) 43 | 44 | fig, axs = plt.subplots(nrows=total_rows, ncols=total_cols, figsize=(40, 20)) 45 | 46 | # Hide all axes initially 47 | for ax in axs.flat: 48 | ax.axis('off') 49 | 50 | # Function to display samples in the grid 51 | def display_samples(samples, axs, start_row, start_col, num_samples, grid_size, sample_logs=None): 52 | for i in range(num_samples): 53 | row_idx = (i // grid_size) + start_row 54 | col_idx = (i % grid_size) + start_col 55 | ax = axs[row_idx, col_idx] 56 | ax.imshow((samples[i]+1)/2, cmap='gray') 57 | ax.axis('off') # Only turn on the axis for images 58 | if sample_logs is not None: 59 | ax.set_title(sample_logs[i], fontsize=24) 60 | 61 | # Display first set of samples 62 | display_samples(samples1, axs, 0, 0, num_samples1, grid_size1, sample1_logs) 63 | 64 | # Display second set of samples, offset by the first set plus an additional column for spacing 65 | start_col_for_samples2 = grid_size1 + 1 # +1 for space 66 | display_samples(samples2, axs, 0, start_col_for_samples2, num_samples2, grid_size2, sample2_logs) 67 | 68 | # Set super title for the entire figure 69 | fig.suptitle(f"{sample1_title} vs {sample2_title}", fontsize=32) 70 | 71 | if save and filename: 72 | plt.savefig(filename, dpi=dpi) 73 | 74 | if show: 75 | plt.show() 76 | 77 | plt.close(fig) # Close the figure to clean up 78 | 79 | 80 | def smooth(scalars, weight): # Weight between 0 and 1 81 | last = scalars[0] # First value in the plot (first timestep) 82 | smoothed = list() 83 | for point in scalars: 84 | smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value 85 | smoothed.append(smoothed_val) # Save it 86 | last = smoothed_val # Anchor the last smoothed value 87 | 88 | return smoothed 89 | 90 | 91 | def plot_exp_logs(exp_logs, exp_args, smoothing=.9, path='', filter='', show=True, save=False): 92 | # Extract 
unique keys from the experiments (assuming all experiments have the same keys) 93 | keys = next(iter(exp_logs.values())).keys() 94 | num_keys = len(keys) 95 | 96 | # Determine the layout for subplots (as square as possible) 97 | rows = cols = math.ceil(math.sqrt(num_keys)) 98 | 99 | fig, axs = plt.subplots(rows, cols, figsize=(15, 15)) 100 | if not isinstance(axs, plt.Axes): 101 | axs = axs.flatten() # Flatten the array for easy indexing 102 | else: 103 | axs = [axs] 104 | 105 | handles, labels = [], [] 106 | 107 | # Loop over each key and plot the corresponding data for each experiment 108 | for idx, key in enumerate(keys): 109 | for exp_name, exp_data in exp_logs.items(): 110 | if filter in exp_name: 111 | if key in exp_data.keys() and len(exp_data[key]) > 0: 112 | line, = axs[idx].plot(exp_data[key], alpha=.1) 113 | line, = axs[idx].plot(smooth(exp_data[key], smoothing), color=line.get_color()) 114 | axs[idx].tick_params(axis='both', which='major', labelsize=16) 115 | if idx == 0: # Only add to legend once 116 | handles.append(line) 117 | # labels.append(f"lr: {exp_args[exp_name]['lr']}, " 118 | # f"lr logZ: {exp_args[exp_name]['lr_logZ']}, " 119 | # f"ct: {exp_args[exp_name]['learning_cutoff']}") 120 | 121 | labels.append(exp_name) 122 | axs[idx].set_title(key, fontsize=22) 123 | 124 | # axs[idx].legend(fontsize=14) 125 | fig.legend(handles, labels, 126 | loc='upper center', 127 | bbox_to_anchor=(0.5, 0.3), 128 | ncol=2, 129 | fontsize=18) 130 | 131 | # Hide any unused subplots 132 | for i in range(num_keys, len(axs)): 133 | axs[i].axis('off') 134 | 135 | plt.tight_layout() 136 | if save: 137 | plt.savefig(path+"/experiment_run.png") 138 | if show: 139 | plt.show() 140 | 141 | 142 | def plot_separate_exp_logs(exp_logs, exp_args, smoothing=.9, path='', filter='', show=True, save=False): 143 | # Extract unique keys from the experiments (assuming all experiments have the same keys) 144 | keys = next(iter(exp_logs.values())).keys() 145 | num_keys = len(keys) 146 | 147 | # Determine the layout for subplots (as square as possible) 148 | cols = 5 149 | rows = num_keys // cols + (1 if num_keys % cols != 0 else 0) 150 | 151 | fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) 152 | if not isinstance(axs, plt.Axes): 153 | axs = axs.flatten() # Flatten the array for easy indexing 154 | else: 155 | axs = [axs] 156 | 157 | # Loop over each key and plot the corresponding data for each experiment 158 | for exp_name, exp_data in exp_logs.items(): 159 | for idx, key in enumerate(keys): 160 | if filter in exp_name: 161 | if key in exp_data.keys() and len(exp_data[key]) > 0: 162 | line, = axs[idx].plot(exp_data[key], alpha=.1) 163 | axs[idx].plot(smooth(exp_data[key], smoothing), color=line.get_color()) 164 | axs[idx].tick_params(axis='both', which='major', labelsize=16) 165 | 166 | axs[idx].set_title(key, fontsize=22) 167 | 168 | # Hide any unused subplots 169 | for i in range(num_keys, len(axs)): 170 | axs[i].axis('off') 171 | 172 | plt.suptitle(exp_name, fontsize=32) 173 | plt.tight_layout() 174 | if save: 175 | plt.savefig(path+"/experiment_run.png") 176 | if show: 177 | plt.show() 178 | 179 | fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) 180 | if not isinstance(axs, plt.Axes): 181 | axs = axs.flatten() # Flatten the array for easy indexing 182 | else: 183 | axs = [axs] -------------------------------------------------------------------------------- /inverse_diffusion/src/visualize_runs.py: 
-------------------------------------------------------------------------------- 1 | from utils.gfn_diffusion import diffusion_resample 2 | from utils.visualization import * 3 | from utils.simple_io import * 4 | from utils.args import fetch_args 5 | 6 | args, state = fetch_args(experiment_run=False) 7 | 8 | exp_names = [expn for expn in get_filenames(path=args.save_folder) if '.DS' not in expn] 9 | exp_paths = [f"{args.save_folder}/{expn}" for expn in exp_names] 10 | 11 | exp_args = {exp_name: DictObj(load_dict_from_file(f"{exp_path}/run_args.json")) for exp_name, exp_path in zip(exp_names, exp_paths) if file_exists(f"{exp_path}/run_args.json")} 12 | exp_logs = {exp_name: load_dict_from_file(f"{exp_path}/run_logs.json") for exp_name, exp_path in zip(exp_names, exp_paths) if file_exists(f"{exp_path}/run_logs.json")} 13 | 14 | # plot_exp_logs(exp_logs, exp_args) 15 | # plot_separate_exp_logs(exp_logs, exp_args) 16 | diffusion_resample( 17 | exp_args, 18 | exp_paths=[exp_path for exp_path in exp_paths if file_exists(f"{exp_path}/run_args.json")], 19 | batch_size=args.plot_batch_size, 20 | device=args.device 21 | ) 22 | 23 | 24 | -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Garrett Thomas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # Codebase adapted from public github repo https://github.com/gwthomas/IQL-PyTorch -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import gym 4 | import d4rl 5 | import numpy as np 6 | import torch 7 | from tqdm import trange 8 | from distutils.util import strtobool 9 | import time 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from src.iql import ImplicitQLearning 13 | from src.policy import GaussianPolicy, DeterministicPolicy 14 | from src.value_functions import TwinQ, ValueFunction 15 | from src.util import return_range, set_seed, Log, sample_batch, torchify, evaluate_policy 16 | 17 | 18 | def get_env_and_dataset(log, env_name, max_episode_steps): 19 | env = gym.make(env_name) 20 | dataset = d4rl.qlearning_dataset(env) 21 | 22 | if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')): 23 | min_ret, max_ret = return_range(dataset, max_episode_steps) 24 | log(f'Dataset returns have range [{min_ret}, {max_ret}]') 25 | dataset['rewards'] /= (max_ret - min_ret) 26 | dataset['rewards'] *= max_episode_steps 27 | elif 'antmaze' in env_name: 28 | dataset['rewards'] -= 1. 29 | 30 | for k, v in dataset.items(): 31 | dataset[k] = torchify(v) 32 | 33 | return env, dataset 34 | 35 | 36 | def main(args): 37 | run_name = f"{args.env_name}__{args.seed}__{int(time.time())}" 38 | if args.track: 39 | import wandb 40 | wandb.init( 41 | project=args.wandb_project_name, 42 | entity=args.wandb_entity, 43 | sync_tensorboard=True, 44 | config=vars(args), 45 | name=run_name, 46 | monitor_gym=True, 47 | save_code=True, 48 | ) 49 | writer = SummaryWriter(f"runs/{run_name}") 50 | writer.add_text( 51 | "hyperparameters", 52 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 53 | ) 54 | 55 | torch.set_num_threads(1) 56 | log = Log(Path(args.log_dir)/args.env_name, vars(args)) 57 | log(f'Log dir: {log.dir}') 58 | 59 | env, dataset = get_env_and_dataset(log, args.env_name, args.max_episode_steps) 60 | obs_dim = dataset['observations'].shape[1] 61 | act_dim = dataset['actions'].shape[1] # this assume continuous actions 62 | set_seed(args.seed, env=env) 63 | 64 | if args.deterministic_policy: 65 | policy = DeterministicPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 66 | else: 67 | policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 68 | def eval_policy(): 69 | eval_returns = np.array([evaluate_policy(env, policy, args.max_episode_steps) \ 70 | for _ in range(args.n_eval_episodes)]) 71 | if 'antmaze' not in args.env_name: 72 | normalized_returns = d4rl.get_normalized_score(args.env_name, eval_returns) * 100.0 73 | else: 74 | normalized_returns = eval_returns 75 | log.row({ 76 | 'return mean': eval_returns.mean(), 77 | 'return std': eval_returns.std(), 78 | 'normalized return mean': normalized_returns.mean(), 79 | 'normalized return std': normalized_returns.std(), 80 | }) 81 | return normalized_returns.mean() 82 | 83 | iql = ImplicitQLearning( 84 | qf=TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden), 85 | vf=ValueFunction(obs_dim, hidden_dim=args.hidden_dim, 
n_hidden=args.n_hidden), 86 | policy=policy, 87 | optimizer_factory=lambda params: torch.optim.Adam(params, lr=args.learning_rate), 88 | max_steps=args.n_steps, 89 | tau=args.tau, 90 | beta=args.beta, 91 | alpha=args.alpha, 92 | discount=args.discount 93 | ) 94 | 95 | best_return = -100000 96 | for step in trange(args.n_steps): 97 | iql.update(**sample_batch(dataset, args.batch_size)) 98 | if (step+1) % args.eval_period == 0: 99 | avg_return = eval_policy() 100 | torch.save(iql.qf.state_dict(), '../q_models/' + args.env_name + '_qf.pth') 101 | if args.track: 102 | wandb.log({"avg_reward": avg_return}) 103 | if avg_return > best_return: 104 | best_return = avg_return 105 | torch.save(iql.qf.state_dict(), '../q_models/' + args.env_name + '_qf_best.pth') 106 | 107 | torch.save(iql.qf.state_dict(), '../q_models/' + args.env_name + '_qf_final.pth') 108 | log.close() 109 | 110 | 111 | if __name__ == '__main__': 112 | from argparse import ArgumentParser 113 | parser = ArgumentParser() 114 | parser.add_argument('--env-name', required=True) 115 | parser.add_argument('--log-dir', default='logs') 116 | parser.add_argument('--seed', type=int, default=0) 117 | parser.add_argument('--discount', type=float, default=0.99) 118 | parser.add_argument('--hidden-dim', type=int, default=256) 119 | parser.add_argument('--n-hidden', type=int, default=2) 120 | parser.add_argument('--n-steps', type=int, default=(10**7)) 121 | parser.add_argument('--batch-size', type=int, default=256) 122 | parser.add_argument('--learning-rate', type=float, default=3e-4) 123 | parser.add_argument('--alpha', type=float, default=0.005) 124 | parser.add_argument('--tau', type=float, default=0.7) 125 | parser.add_argument('--beta', type=float, default=3.0) 126 | parser.add_argument('--deterministic-policy', action='store_true') 127 | parser.add_argument('--eval-period', type=int, default=5000) 128 | parser.add_argument('--n-eval-episodes', type=int, default=10) 129 | parser.add_argument('--max-episode-steps', type=int, default=1000) 130 | parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 131 | help="if toggled, this experiment will be tracked with Weights and Biases") 132 | parser.add_argument("--wandb-project-name", type=str, default="project-name", 133 | help="the wandb's project name") 134 | parser.add_argument("--wandb-entity", type=str, default='entity-name', 135 | help="the entity (team) of wandb's project") 136 | main(parser.parse_args()) -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | tqdm 5 | gym[mujoco] >= 0.18.0 6 | torch>=1.7.0 7 | git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/results.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | LOCOMOTION_ENVS = { 8 | 'halfcheetah-medium-v2': 47.4, 9 | 'hopper-medium-v2': 66.3, 10 | 'walker2d-medium-v2': 78.3, 11 | 'halfcheetah-medium-replay-v2': 44.2, 12 | 'hopper-medium-replay-v2': 94.7, 13 | 'walker2d-medium-replay-v2': 73.9, 14 | 'halfcheetah-medium-expert-v2': 86.7, 15 | 'hopper-medium-expert-v2': 91.5, 16 | 'walker2d-medium-expert-v2': 109.6 17 | } 18 | 19 | ANTMAZE_ENVS = { 20 | 'antmaze-umaze-v0': 
87.5, 21 | 'antmaze-umaze-diverse-v0': 62.2, 22 | 'antmaze-medium-play-v0': 71.2, 23 | 'antmaze-medium-diverse-v0': 70.0, 24 | 'antmaze-large-play-v0': 39.6, 25 | 'antmaze-large-diverse-v0': 47.5 26 | } 27 | 28 | KITCHEN_ENVS = { 29 | 'kitchen-complete-v0': 62.5, 30 | 'kitchen-partial-v0': 46.3, 31 | 'kitchen-mixed-v0': 51.0 32 | } 33 | 34 | ADROIT_ENVS = { 35 | 'pen-human-v0': 71.5, 36 | 'hammer-human-v0': 1.4, 37 | 'door-human-v0': 4.3, 38 | 'relocate-human-v0': 0.1, 39 | 'pen-cloned-v0': 37.3, 40 | 'hammer-cloned-v0': 2.1, 41 | 'door-cloned-v0': 1.6, 42 | 'relocate-cloned-v0': -0.2 43 | } 44 | 45 | ENV_COLLECTIONS = { 46 | 'locomotion-all': LOCOMOTION_ENVS, 47 | 'antmaze-all': ANTMAZE_ENVS, 48 | 'kitchen-all': KITCHEN_ENVS, 49 | 'adroit-all': ADROIT_ENVS 50 | } 51 | 52 | 53 | def main(args): 54 | dir = Path(args.dir) 55 | assert dir.is_dir(), f'{dir} is not a directory' 56 | print('| Environment | This implementation | Official implementation |\n' 57 | '| ----------- | ------------------- | ----------------------- |') 58 | envs = ENV_COLLECTIONS[args.envs] 59 | for env, ref_score in envs.items(): 60 | env_dir = dir/env 61 | assert env_dir.is_dir(), f'{env_dir} is not a directory' 62 | run_dirs = [d for d in env_dir.iterdir() if d.is_dir()] 63 | final_perfs = [] 64 | for run_dir in run_dirs: 65 | data = pd.read_csv(run_dir/'progress.csv') 66 | normalized_returns = data['normalized return mean'].to_numpy() 67 | final_perfs.append(normalized_returns[-args.last_k:]) 68 | print(f'| {env} | {np.mean(final_perfs):.1f} +/- {np.std(final_perfs):.1f} | {ref_score:.1f} |') 69 | 70 | 71 | if __name__ == '__main__': 72 | from argparse import ArgumentParser 73 | parser = ArgumentParser() 74 | parser.add_argument('-d', '--dir', required=True) 75 | parser.add_argument('-e', '--envs', required=True) 76 | parser.add_argument('-k', '--last-k', type=int, default=10) # average over last k evals 77 | main(parser.parse_args()) -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/src/iql.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | 8 | from .util import DEFAULT_DEVICE, compute_batched, update_exponential_moving_average 9 | 10 | 11 | EXP_ADV_MAX = 100. 
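# Note: `asymmetric_l2_loss` below is IQL's expectile regression loss,
#     L_tau(u) = |tau - 1{u < 0}| * u**2,
# i.e. positive residuals (target Q above the current value estimate) are weighted by tau
# and negative ones by (1 - tau); with tau > 0.5 this pushes V(s) toward an upper expectile
# of the target Q-values. EXP_ADV_MAX caps the exponentiated advantage exp(beta * adv) that
# weights the behaviour-cloning term during policy extraction.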
12 | 13 | 14 | def asymmetric_l2_loss(u, tau): 15 | return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) 16 | 17 | 18 | class ImplicitQLearning(nn.Module): 19 | def __init__(self, qf, vf, policy, optimizer_factory, max_steps, 20 | tau, beta, discount=0.99, alpha=0.005): 21 | super().__init__() 22 | self.qf = qf.to(DEFAULT_DEVICE) 23 | self.q_target = copy.deepcopy(qf).requires_grad_(False).to(DEFAULT_DEVICE) 24 | self.vf = vf.to(DEFAULT_DEVICE) 25 | self.policy = policy.to(DEFAULT_DEVICE) 26 | self.v_optimizer = optimizer_factory(self.vf.parameters()) 27 | self.q_optimizer = optimizer_factory(self.qf.parameters()) 28 | self.policy_optimizer = optimizer_factory(self.policy.parameters()) 29 | self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, max_steps) 30 | self.tau = tau 31 | self.beta = beta 32 | self.discount = discount 33 | self.alpha = alpha 34 | 35 | def update(self, observations, actions, next_observations, rewards, terminals): 36 | with torch.no_grad(): 37 | target_q = self.q_target(observations, actions) 38 | next_v = self.vf(next_observations) 39 | 40 | # v, next_v = compute_batched(self.vf, [observations, next_observations]) 41 | 42 | # Update value function 43 | v = self.vf(observations) 44 | adv = target_q - v 45 | v_loss = asymmetric_l2_loss(adv, self.tau) 46 | self.v_optimizer.zero_grad(set_to_none=True) 47 | v_loss.backward() 48 | self.v_optimizer.step() 49 | 50 | # Update Q function 51 | targets = rewards + (1. - terminals.float()) * self.discount * next_v.detach() 52 | qs = self.qf.both(observations, actions) 53 | q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) 54 | self.q_optimizer.zero_grad(set_to_none=True) 55 | q_loss.backward() 56 | self.q_optimizer.step() 57 | 58 | # Update target Q network 59 | update_exponential_moving_average(self.q_target, self.qf, self.alpha) 60 | 61 | # Update policy 62 | exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX) 63 | policy_out = self.policy(observations) 64 | if isinstance(policy_out, torch.distributions.Distribution): 65 | bc_losses = -policy_out.log_prob(actions) 66 | elif torch.is_tensor(policy_out): 67 | assert policy_out.shape == actions.shape 68 | bc_losses = torch.sum((policy_out - actions)**2, dim=1) 69 | else: 70 | raise NotImplementedError 71 | policy_loss = torch.mean(exp_adv * bc_losses) 72 | self.policy_optimizer.zero_grad(set_to_none=True) 73 | policy_loss.backward() 74 | self.policy_optimizer.step() 75 | self.policy_lr_schedule.step() -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/src/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import MultivariateNormal 4 | 5 | from .util import mlp 6 | 7 | 8 | LOG_STD_MIN = -5.0 9 | LOG_STD_MAX = 2.0 10 | 11 | 12 | class GaussianPolicy(nn.Module): 13 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2): 14 | super().__init__() 15 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim]) 16 | self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 17 | 18 | def forward(self, obs): 19 | mean = self.net(obs) 20 | std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 21 | scale_tril = torch.diag(std) 22 | return MultivariateNormal(mean, scale_tril=scale_tril) 23 | # if mean.ndim > 1: 24 | # batch_size = len(obs) 25 | # return MultivariateNormal(mean, scale_tril=scale_tril.repeat(batch_size, 1, 1)) 26 | # else: 27 | # 
return MultivariateNormal(mean, scale_tril=scale_tril) 28 | 29 | def act(self, obs, deterministic=False, enable_grad=False): 30 | with torch.set_grad_enabled(enable_grad): 31 | dist = self(obs) 32 | return dist.mean if deterministic else dist.sample() 33 | 34 | 35 | class DeterministicPolicy(nn.Module): 36 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2): 37 | super().__init__() 38 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim], 39 | output_activation=nn.Tanh) 40 | 41 | def forward(self, obs): 42 | return self.net(obs) 43 | 44 | def act(self, obs, deterministic=False, enable_grad=False): 45 | with torch.set_grad_enabled(enable_grad): 46 | return self(obs) -------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/src/util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from datetime import datetime 3 | import json 4 | from pathlib import Path 5 | import random 6 | import string 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | class Squeeze(nn.Module): 18 | def __init__(self, dim=None): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | return x.squeeze(dim=self.dim) 24 | 25 | 26 | def mlp(dims, activation=nn.ReLU, output_activation=None, squeeze_output=False): 27 | n_dims = len(dims) 28 | assert n_dims >= 2, 'MLP requires at least two dims (input and output)' 29 | 30 | layers = [] 31 | for i in range(n_dims - 2): 32 | layers.append(nn.Linear(dims[i], dims[i+1])) 33 | layers.append(activation()) 34 | layers.append(nn.Linear(dims[-2], dims[-1])) 35 | if output_activation is not None: 36 | layers.append(output_activation()) 37 | if squeeze_output: 38 | assert dims[-1] == 1 39 | layers.append(Squeeze(-1)) 40 | net = nn.Sequential(*layers) 41 | net.to(dtype=torch.float32) 42 | return net 43 | 44 | 45 | def compute_batched(f, xs): 46 | return f(torch.cat(xs, dim=0)).split([len(x) for x in xs]) 47 | 48 | 49 | def update_exponential_moving_average(target, source, alpha): 50 | for target_param, source_param in zip(target.parameters(), source.parameters()): 51 | target_param.data.mul_(1. 
- alpha).add_(source_param.data, alpha=alpha) 52 | 53 | 54 | def torchify(x): 55 | x = torch.from_numpy(x) 56 | if x.dtype is torch.float64: 57 | x = x.float() 58 | x = x.to(device=DEFAULT_DEVICE) 59 | return x 60 | 61 | 62 | 63 | def return_range(dataset, max_episode_steps): 64 | returns, lengths = [], [] 65 | ep_ret, ep_len = 0., 0 66 | for r, d in zip(dataset['rewards'], dataset['terminals']): 67 | ep_ret += float(r) 68 | ep_len += 1 69 | if d or ep_len == max_episode_steps: 70 | returns.append(ep_ret) 71 | lengths.append(ep_len) 72 | ep_ret, ep_len = 0., 0 73 | # returns.append(ep_ret) # incomplete trajectory 74 | lengths.append(ep_len) # but still keep track of number of steps 75 | assert sum(lengths) == len(dataset['rewards']) 76 | return min(returns), max(returns) 77 | 78 | 79 | # dataset is a dict, values of which are tensors of same first dimension 80 | def sample_batch(dataset, batch_size): 81 | k = list(dataset.keys())[0] 82 | n, device = len(dataset[k]), dataset[k].device 83 | for v in dataset.values(): 84 | assert len(v) == n, 'Dataset values must have same length' 85 | indices = torch.randint(low=0, high=n, size=(batch_size,), device=device) 86 | return {k: v[indices] for k, v in dataset.items()} 87 | 88 | 89 | def evaluate_policy(env, policy, max_episode_steps, deterministic=True): 90 | obs = env.reset() 91 | total_reward = 0. 92 | for _ in range(max_episode_steps): 93 | with torch.no_grad(): 94 | action = policy.act(torchify(obs), deterministic=deterministic).cpu().numpy() 95 | next_obs, reward, done, info = env.step(action) 96 | total_reward += reward 97 | if done: 98 | break 99 | else: 100 | obs = next_obs 101 | return total_reward 102 | 103 | 104 | def set_seed(seed, env=None): 105 | torch.manual_seed(seed) 106 | if torch.cuda.is_available(): 107 | torch.cuda.manual_seed_all(seed) 108 | np.random.seed(seed) 109 | random.seed(seed) 110 | if env is not None: 111 | env.seed(seed) 112 | 113 | 114 | def _gen_dir_name(): 115 | now_str = datetime.now().strftime('%m-%d-%y_%H.%M.%S') 116 | rand_str = ''.join(random.choices(string.ascii_lowercase, k=4)) 117 | return f'{now_str}_{rand_str}' 118 | 119 | class Log: 120 | def __init__(self, root_log_dir, cfg_dict, 121 | txt_filename='log.txt', 122 | csv_filename='progress.csv', 123 | cfg_filename='config.json', 124 | flush=True): 125 | self.dir = Path(root_log_dir)/_gen_dir_name() 126 | self.dir.mkdir(parents=True) 127 | self.txt_file = open(self.dir/txt_filename, 'w') 128 | self.csv_file = None 129 | (self.dir/cfg_filename).write_text(json.dumps(cfg_dict)) 130 | self.txt_filename = txt_filename 131 | self.csv_filename = csv_filename 132 | self.cfg_filename = cfg_filename 133 | self.flush = flush 134 | 135 | def write(self, message, end='\n'): 136 | now_str = datetime.now().strftime('%H:%M:%S') 137 | message = f'[{now_str}] ' + message 138 | for f in [sys.stdout, self.txt_file]: 139 | print(message, end=end, file=f, flush=self.flush) 140 | 141 | def __call__(self, *args, **kwargs): 142 | self.write(*args, **kwargs) 143 | 144 | def row(self, dict): 145 | if self.csv_file is None: 146 | self.csv_file = open(self.dir/self.csv_filename, 'w', newline='') 147 | self.csv_writer = csv.DictWriter(self.csv_file, list(dict.keys())) 148 | self.csv_writer.writeheader() 149 | 150 | self(str(dict)) 151 | self.csv_writer.writerow(dict) 152 | if self.flush: 153 | self.csv_file.flush() 154 | 155 | def close(self): 156 | self.txt_file.close() 157 | if self.csv_file is not None: 158 | self.csv_file.close() 
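# Illustrative usage of the helpers above (a sketch, not part of the original file; the
# argument values are hypothetical):
#   log = Log('logs/halfcheetah-medium-replay-v2', {'seed': 0})
#   log('starting training')                           # timestamped line to stdout and log.txt
#   batch = sample_batch(dataset, 256)                 # dict of tensors sharing the same row indices
#   log.row({'return mean': 0.0, 'return std': 0.0})   # header inferred from keys, row appended to progress.csv
#   log.close()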
-------------------------------------------------------------------------------- /offline_RL/IQL_PyTorch/src/value_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .util import mlp 4 | 5 | 6 | class TwinQ(nn.Module): 7 | def __init__(self, state_dim, action_dim, hidden_dim=256, n_hidden=2): 8 | super().__init__() 9 | dims = [state_dim + action_dim, *([hidden_dim] * n_hidden), 1] 10 | self.q1 = mlp(dims, squeeze_output=True) 11 | self.q2 = mlp(dims, squeeze_output=True) 12 | 13 | def both(self, state, action): 14 | sa = torch.cat([state, action], 1) 15 | return self.q1(sa), self.q2(sa) 16 | 17 | def forward(self, state, action): 18 | return torch.min(*self.both(state, action)) 19 | 20 | def log_reward(self, s, a_arctanh, alpha=1.0): 21 | a = torch.tanh(a_arctanh) 22 | q_sa = self(s, a) 23 | r = q_sa + alpha*torch.log((1 - (a)**2) + 1e-7).sum(1) 24 | return r 25 | 26 | def score(self, s, a, alpha=1.0): 27 | a = a.detach() 28 | a.requires_grad_(True) 29 | r = self.log_reward(s, a, alpha=alpha) 30 | # get gradient wrt r_sa 31 | score = torch.clamp(torch.autograd.grad(r.sum(), a)[0], -100, 100) 32 | return score.detach() 33 | 34 | 35 | class ValueFunction(nn.Module): 36 | def __init__(self, state_dim, hidden_dim=256, n_hidden=2): 37 | super().__init__() 38 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 39 | self.v = mlp(dims, squeeze_output=True) 40 | 41 | def forward(self, state): 42 | return self.v(state) -------------------------------------------------------------------------------- /offline_RL/README.md: -------------------------------------------------------------------------------- 1 | # QFlow Offline 2 | 3 | IQL codebase adapted from the public GitHub repo: https://github.com/gwthomas/IQL-PyTorch 4 | 5 | 6 | ### Install MuJoCo 7 | 8 | Install MuJoCo, following the instructions here: https://github.com/openai/mujoco-py?tab=readme-ov-file#install-mujoco 9 | 10 | ### Train BC prior 11 | 12 | ``` 13 | python train_bc.py --env-id 'halfcheetah-medium-replay-v2' --diffusion-steps 75 --n-epochs 1500 [--track] 14 | ``` 15 | 16 | ### Train Q function with IQL 17 | 18 | ``` 19 | cd 'IQL_PyTorch' 20 | python3 main.py --env-name 'halfcheetah-medium-replay-v2' [--track] 21 | ``` 22 | 23 | ### Policy extraction with RTB 24 | To reproduce the results in the paper, use the alpha values from Table G.2 of the paper.
25 | 26 | ``` 27 | python qflow_offline.py --env-id 'halfcheetah-medium-replay' --alpha 0.05 --diffusion-steps 75 --batch-size 64 --num-eval 10 [--track] 28 | ``` -------------------------------------------------------------------------------- /offline_RL/bc_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/offline_RL/bc_models/.gitkeep -------------------------------------------------------------------------------- /offline_RL/q_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/offline_RL/q_models/.gitkeep -------------------------------------------------------------------------------- /offline_RL/qflow_offline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import d4rl 7 | import random 8 | import argparse 9 | from distutils.util import strtobool 10 | import os 11 | import time 12 | from model import DiffusionModel, QFlow 13 | from IQL_PyTorch.src.value_functions import TwinQ 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | import stable_baselines3 as sb3 17 | from stable_baselines3.common.vec_env import SubprocVecEnv 18 | 19 | def make_env(env_id, seed, rank, run_name, args): 20 | def thunk(): 21 | env = gym.make(env_id) 22 | env = gym.wrappers.RecordEpisodeStatistics(env) 23 | env.action_space.seed(seed + rank) 24 | return env 25 | return thunk 26 | 27 | class D4RLDataset(Dataset): 28 | def __init__(self, data): 29 | self.states = data['observations'] 30 | self.actions = data['actions'] 31 | 32 | def __len__(self): 33 | return len(self.states) 34 | 35 | def __getitem__(self, idx): 36 | states = self.states[idx] 37 | actions = self.actions[idx] 38 | return states, actions 39 | 40 | def parse_args(): 41 | # fmt: off 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 44 | help="the name of this experiment") 45 | parser.add_argument("--seed", type=int, default=1, 46 | help="seed of the experiment") 47 | parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 48 | help="if toggled, `torch.backends.cudnn.deterministic=False`") 49 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 50 | help="if toggled, cuda will be enabled by default") 51 | parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 52 | help="if toggled, this experiment will be tracked with Weights and Biases") 53 | parser.add_argument("--wandb-project-name", type=str, default="project-name", 54 | help="the wandb's project name") 55 | parser.add_argument("--wandb-entity", type=str, default='entity-name', 56 | help="the entity (team) of wandb's project") 57 | 58 | parser.add_argument("--env-id", type=str, default="hopper-medium-expert-v2", 59 | help="the id of the environment") 60 | parser.add_argument("--diffusion-steps", type=int, default=75) 61 | parser.add_argument("--batch-size", type=int, default=64) 62 | parser.add_argument("--lr", type=float, default=5e-4) 63 | 
parser.add_argument("--schedule", type=str, default='linear') 64 | parser.add_argument("--n-epochs", type=int, default=1000) 65 | parser.add_argument("--sample-freq", type=int, default=1) 66 | parser.add_argument("--predict", type=str, default='epsilon') 67 | parser.add_argument("--policy-net", type=str, default='mlp') 68 | parser.add_argument("--num-eval", type=int, default=10) 69 | parser.add_argument("--alpha", type=float, default=1.0) 70 | parser.add_argument("--extra", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 71 | help="Extra sampling steps") 72 | 73 | args = parser.parse_args() 74 | return args 75 | 76 | if __name__ == '__main__': 77 | args = parse_args() 78 | run_name = f"{args.env_id}__{args.exp_name}__{args.alpha}__{args.seed}__{int(time.time())}" 79 | filename = args.env_id+"_"+args.exp_name 80 | if args.track: 81 | import wandb 82 | 83 | wandb.init( 84 | project=args.wandb_project_name, 85 | entity=args.wandb_entity, 86 | sync_tensorboard=True, 87 | config=vars(args), 88 | name=run_name, 89 | monitor_gym=True, 90 | save_code=True, 91 | ) 92 | writer = SummaryWriter(f"runs/{run_name}") 93 | writer.add_text( 94 | "hyperparameters", 95 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 96 | ) 97 | 98 | # TRY NOT TO MODIFY: seeding 99 | random.seed(args.seed) 100 | np.random.seed(args.seed) 101 | torch.manual_seed(args.seed) 102 | torch.backends.cudnn.deterministic = args.torch_deterministic 103 | 104 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 105 | 106 | env = gym.make(args.env_id) 107 | dataset = env.get_dataset() 108 | #if args.predict == 'epsilon': 109 | dataset['actions'] = np.arctanh(np.clip(dataset['actions'],-0.99,0.99)) 110 | data = D4RLDataset(dataset) 111 | dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=True) 112 | 113 | bc_model = DiffusionModel(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], diffusion_steps=args.diffusion_steps, predict=args.predict, policy_net=args.policy_net).to(device) 114 | bc_model.load_state_dict(torch.load('bc_models/'+args.env_id+'_'+'train_bc.pth')) 115 | q = TwinQ(env.observation_space.shape[0], env.action_space.shape[0]).to(device) 116 | q.load_state_dict(torch.load('q_models/'+args.env_id+'_qf.pth')) 117 | 118 | qflow = QFlow(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], diffusion_steps=args.diffusion_steps, predict=args.predict, q_net=q, bc_net=bc_model, alpha=args.alpha).to(device) 119 | optimizer = torch.optim.Adam(list(qflow.qflow.out_model.parameters()) + list(qflow.qflow.means_scaling_model.parameters()) + list(qflow.qflow.x_model.parameters()), lr=args.lr) 120 | save_path = f'/home/mila/l/luke.rowe/qflowoffline/qflow_models/qflow_{args.env_id}_{args.alpha}_{args.seed}.pt' 121 | 122 | if os.path.exists(save_path): 123 | state = torch.load(save_path, map_location='cuda:0') 124 | qflow.load_state_dict(state['state_dict']) 125 | optimizer.load_state_dict(state['optimizer']) 126 | global_step = state['global_step'] 127 | current_epoch = state['epoch'] + 1 128 | else: 129 | global_step = 0 130 | current_epoch = 0 131 | 132 | for epoch in range(current_epoch, args.n_epochs): 133 | for states, actions in dataloader: 134 | if global_step % args.sample_freq == 0: 135 | optimizer.zero_grad() 136 | states = states.to(device) 137 | loss, logZSample = qflow.compute_loss(states) 138 | loss.backward() 139 | optimizer.step() 140 | sample_loss 
= loss.item() 141 | 142 | states = states.to(device) 143 | actions = actions.to(device) 144 | loss, logC = qflow.compute_loss_with_sample(states, actions) 145 | optimizer.zero_grad() 146 | loss.backward() 147 | optimizer.step() 148 | batch_loss = loss.item() 149 | 150 | if global_step%5 == 0 and args.track: 151 | writer.add_scalar("loss/sample_loss", sample_loss, global_step) 152 | writer.add_scalar("loss/batch_loss", batch_loss, global_step) 153 | writer.add_scalar("loss/logZSample", logZSample, global_step) 154 | writer.add_scalar("loss/logC", logC, global_step) 155 | with torch.no_grad(): 156 | if ((global_step)%5000) == 0: 157 | avg_reward = 0.0 158 | envs = SubprocVecEnv([make_env(args.env_id, 159 | args.seed, 160 | i, 161 | run_name, 162 | args) for i in range(args.num_eval)]) 163 | 164 | s = envs.reset() 165 | steps = 0 166 | done = [False for i in range(args.num_eval)] 167 | while False in done: 168 | steps+=1 169 | s_tensor = torch.tensor(s).float().to(device) 170 | a, _, _ = qflow.sample(s_tensor, extra=args.extra) 171 | a = torch.tanh(torch.tensor(a)).detach().cpu().numpy() 172 | s, r, terminations, infos = envs.step(a) 173 | 174 | for i in range(args.num_eval): 175 | if terminations[i] and not done[i]: 176 | done[i] = True 177 | avg_reward += float(infos[i]['episode']['r']) 178 | 179 | avg_reward /= args.num_eval 180 | if not 'antmaze' in args.env_id: 181 | env = gym.make(args.env_id) 182 | avg_reward = env.get_normalized_score(avg_reward)*100 183 | print('AVG REWARD:', avg_reward) 184 | writer.add_scalar("eval/avg_reward", avg_reward, global_step) 185 | 186 | global_step += 1 187 | 188 | # save model 189 | state = { 190 | 'epoch': epoch, 191 | 'state_dict': qflow.state_dict(), 192 | 'optimizer': optimizer.state_dict(), 193 | 'global_step': global_step 194 | } 195 | torch.save(state, save_path) -------------------------------------------------------------------------------- /offline_RL/train_bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import d4rl 7 | import random 8 | import argparse 9 | from distutils.util import strtobool 10 | import os 11 | import time 12 | from model import DiffusionModel 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | class D4RLDataset(Dataset): 17 | def __init__(self, data): 18 | self.states = data['observations'] 19 | self.actions = data['actions'] 20 | 21 | def __len__(self): 22 | return len(self.states) 23 | 24 | def __getitem__(self, idx): 25 | states = self.states[idx] 26 | actions = self.actions[idx] 27 | return states, actions 28 | 29 | def parse_args(): 30 | # fmt: off 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 33 | help="the name of this experiment") 34 | parser.add_argument("--seed", type=int, default=1, 35 | help="seed of the experiment") 36 | parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 37 | help="if toggled, `torch.backends.cudnn.deterministic=False`") 38 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 39 | help="if toggled, cuda will be enabled by default") 40 | parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 41 | help="if toggled, 
this experiment will be tracked with Weights and Biases") 42 | parser.add_argument("--wandb-project-name", type=str, default="project-name", 43 | help="the wandb's project name") 44 | parser.add_argument("--wandb-entity", type=str, default='entity-name', 45 | help="the entity (team) of wandb's project") 46 | 47 | parser.add_argument("--env-id", type=str, default="hopper-medium-expert", 48 | help="the id of the environment") 49 | parser.add_argument("--diffusion-steps", type=int, default=75) 50 | parser.add_argument("--batch-size", type=float, default=512) 51 | parser.add_argument("--lr", type=float, default=5e-4) 52 | parser.add_argument("--schedule", type=str, default='linear') 53 | parser.add_argument("--n-epochs", type=int, default=100000) 54 | parser.add_argument("--predict", type=str, default='epsilon') 55 | parser.add_argument("--policy-net", type=str, default='mlp') 56 | parser.add_argument("--num-eval", type=int, default=10) 57 | 58 | args = parser.parse_args() 59 | return args 60 | 61 | if __name__ == '__main__': 62 | args = parse_args() 63 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 64 | filename = args.env_id+"_"+args.exp_name 65 | if args.track: 66 | import wandb 67 | 68 | wandb.init( 69 | project=args.wandb_project_name, 70 | entity=args.wandb_entity, 71 | sync_tensorboard=True, 72 | config=vars(args), 73 | name=run_name, 74 | monitor_gym=True, 75 | save_code=True, 76 | ) 77 | writer = SummaryWriter(f"runs/{run_name}") 78 | writer.add_text( 79 | "hyperparameters", 80 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 81 | ) 82 | 83 | # TRY NOT TO MODIFY: seeding 84 | random.seed(args.seed) 85 | np.random.seed(args.seed) 86 | torch.manual_seed(args.seed) 87 | torch.backends.cudnn.deterministic = args.torch_deterministic 88 | 89 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 90 | 91 | env = gym.make(args.env_id) 92 | dataset = env.get_dataset() 93 | #if args.predict == 'epsilon': 94 | dataset['actions'] = np.arctanh(np.clip(dataset['actions'],-0.99,0.99)) 95 | data = D4RLDataset(dataset) 96 | dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=True) 97 | 98 | model = DiffusionModel(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], diffusion_steps=args.diffusion_steps, predict=args.predict, policy_net=args.policy_net).to(device) 99 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 100 | 101 | best_eval = -10000 102 | for epoch in range(args.n_epochs): 103 | loss_epoch = 0.0 104 | for states, actions in dataloader: 105 | states = states.float().to(device) 106 | actions = actions.float().to(device) 107 | optimizer.zero_grad() 108 | loss = model.compute_loss(states, actions).mean() 109 | loss.backward() 110 | optimizer.step() 111 | loss_epoch += loss.item() 112 | #print(loss.item(), epoch) 113 | 114 | if epoch % 15 == 0: 115 | torch.save(model.state_dict(), "bc_models/"+filename+".pth") 116 | with torch.no_grad(): 117 | if epoch % 15 == 0: 118 | avg_reward = 0.0 119 | for i in range(args.num_eval): 120 | s = env.reset() 121 | done = False 122 | while not done: 123 | s_tensor = torch.tensor(s).float().to(device).unsqueeze(0) 124 | a = model.sample(s_tensor).detach().cpu().numpy() 125 | a = torch.tanh(torch.tensor(a)).detach().cpu().numpy()[0] 126 | s, r, done, _ = env.step(a) 127 | avg_reward += r 128 | avg_reward /= args.num_eval 129 | avg_reward = env.get_normalized_score(avg_reward)*100 130 | 
wandb.log({"loss": loss_epoch/len(dataloader), "avg_reward": avg_reward}) 131 | if avg_reward > best_eval: 132 | best_eval = avg_reward 133 | torch.save(model.state_dict(), "bc_models/"+filename+"_best.pth") 134 | else: 135 | wandb.log({"loss": loss_epoch/len(dataloader)}) -------------------------------------------------------------------------------- /rtb_diffusion/README.md: -------------------------------------------------------------------------------- 1 | # Relative Trajectory Balance for Posterior Diffusion Sampler 2 | 3 | This repository builds upon the original source code developed in: 4 | 5 | [On Diffusion Models for Amortized Inference: Benchmarking and Improving Stochastic Control and Sampling](https://arxiv.org/abs/2402.05098) 6 | 7 | ## Overview 8 | 9 | This repository provides a posterior diffusion sampler where the prior model is a 2D Gaussian mixture model with 25 modes (25gmm), and the posterior is a Gaussian mixture model with 9 modes whose densities are reweighted relative to the prior. This setup can be extended to various prior-reward combinations by incorporating custom energy functions into the `energies/` directory. 10 | 11 | ## Dependencies 12 | 13 | Ensure you have the following libraries installed: 14 | 15 | - `torch` 16 | - `einops` 17 | - `pot` 18 | - `matplotlib` 19 | 20 | ## Getting Started 21 | 22 | #### RTB finetuning: 23 | 24 | To fine-tune the posterior using Relative Trajectory Balance (RTB), run: 25 | 26 | ``` 27 | python finetune_posterior.py --method rtb --name "save_rtb_finetune" 28 | ``` 29 | 30 | #### RL finetuning: 31 | 32 | For reinforcement learning (RL)-based fine-tuning, run: 33 | 34 | 35 | ``` 36 | python finetune_posterior.py --method rl --kl_weight 0.01 --name "save_rl_finetune" 37 | ``` 38 | 39 | #### Classifier Guidance: 40 | 41 | To use classifier guidance for fine-tuning the posterior, run: 42 | 43 | 44 | ``` 45 | python classifer_guidance_posterior.py 46 | ``` 47 | 48 | #### Prior pretraining: 49 | 50 | To pretrain the prior model, run: 51 | 52 | 53 | 54 | ``` 55 | python pretrain_prior.py 56 | ``` 57 | -------------------------------------------------------------------------------- /rtb_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/rtb_diffusion/__init__.py -------------------------------------------------------------------------------- /rtb_diffusion/classifer_guidance_posterior.py: -------------------------------------------------------------------------------- 1 | from plot_utils import * 2 | import argparse 3 | import torch 4 | import os 5 | 6 | from utils import set_seed, fig_to_image 7 | from models import GFN 8 | from gflownet_losses import * 9 | from energies import * 10 | import copy 11 | import matplotlib.pyplot as plt 12 | from tqdm import trange 13 | 14 | 15 | parser = argparse.ArgumentParser(description='classifier_guidance_posterior') 16 | 17 | 18 | parser.add_argument('--seed', type=int, default=12345) 19 | 20 | parser.add_argument('--name', type=str, default='classifier_guidance') 21 | 22 | 23 | args = parser.parse_args() 24 | 25 | set_seed(args.seed) 26 | if 'SLURM_PROCID' in os.environ: 27 | args.seed += int(os.environ["SLURM_PROCID"]) 28 | 29 | plot_data_size = 2000 30 | 31 | 32 | 33 | args.zero_init = True 34 | 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | 37 | def get_energy(): 38 | prior =
TwentyFiveGaussianMixture(device=device) 39 | energy = Posterior2DGaussianMixture(device=device) 40 | return energy, prior 41 | 42 | 43 | 44 | 45 | def inference(): 46 | energy, prior_energy = get_energy() 47 | name = args.name 48 | 49 | gfn_model = GFN(2, 64, 64, 64, 64, 50 | trajectory_length=100, clipping=True, lgv_clip=1e2, gfn_clip=1e4, 51 | langevin=False, learned_variance=False, 52 | partial_energy=False, log_var_range=4., 53 | pb_scale_range=0.1, 54 | t_scale=5.0, langevin_scaling_per_dimension=False, 55 | conditional_flow_model=False, learn_pb=False, 56 | pis_architectures=True, lgv_layers=3, 57 | joint_layers=2, zero_init=True, device=device).to(device) 58 | 59 | start_epoch = 0 60 | 61 | 62 | checkpoint_path = 'pretrained/prior.pt' 63 | 64 | checkpoint = torch.load(checkpoint_path) 65 | 66 | if 'model_state_dict' in checkpoint: 67 | gfn_model.load_state_dict(checkpoint['model_state_dict']) 68 | else: 69 | gfn_model.load_state_dict(checkpoint) 70 | 71 | start_epoch = 0 72 | 73 | prior = copy.deepcopy(gfn_model) 74 | prior.eval() 75 | 76 | 77 | initial_state = torch.zeros(plot_data_size, energy.data_ndim).to(device) 78 | states, _, _, _ = prior.get_trajectory_fwd_classifier_guidance(initial_state, None, log_r= prior_energy.log_reward, 79 | log_classifier = energy.log_reward, guid_stren = 1.0) 80 | samples = states[:, -1] 81 | gt_samples = energy.sample(plot_data_size) 82 | 83 | fig_contour, ax_contour = get_figure(bounds=(-13., 13.)) 84 | fig_kde, ax_kde = get_figure(bounds=(-13., 13.)) 85 | fig_kde_overlay, ax_kde_overlay = get_figure(bounds=(-13., 13.)) 86 | 87 | plot_contours(energy.log_reward, ax=ax_contour, bounds=(-13., 13.), n_contour_levels=150, device=device) 88 | plot_kde(gt_samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 89 | plot_kde(samples, ax=ax_kde, bounds=(-13., 13.)) 90 | plot_samples(samples, ax=ax_contour, bounds=(-13., 13.)) 91 | plot_samples(samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 92 | 93 | fig_contour.savefig(f'output/{name}_contour.png', bbox_inches='tight') 94 | fig_kde_overlay.savefig(f'output/{name}_kde_overlay.png', bbox_inches='tight') 95 | fig_kde.savefig(f'output/{name}_kde.png', bbox_inches='tight') 96 | 97 | 98 | 99 | 100 | if __name__ == '__main__': 101 | inference() 102 | -------------------------------------------------------------------------------- /rtb_diffusion/energies/__init__.py: -------------------------------------------------------------------------------- 1 | from .twenty_five_gmm import TwentyFiveGaussianMixture 2 | from .posterior_2Dgmm import Posterior2DGaussianMixture -------------------------------------------------------------------------------- /rtb_diffusion/energies/base_set.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | def nll_unit_gaussian(data, sigma=1.0): 8 | data = data.view(data.shape[0], -1) 9 | loss = 0.5 * np.log(2 * np.pi) + np.log(sigma) + 0.5 * data * data / (sigma ** 2) 10 | return torch.sum(torch.flatten(loss, start_dim=1), -1) 11 | 12 | 13 | class BaseSet(abc.ABC, Dataset): 14 | def __init__(self, len_data=-2333): 15 | self.num_sample = len_data 16 | self.data = None 17 | self.data_ndim = None 18 | self._gt_ksd = None 19 | 20 | def gt_logz(self): 21 | raise NotImplementedError 22 | 23 | @abc.abstractmethod 24 | def energy(self, x): 25 | return 26 | 27 | def unnorm_pdf(self, x): 28 | return torch.exp(-self.energy(x)) 29 | 30 | # hmt stands for hamiltonian 31 | 
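# Illustrative sketch (not part of the original file): a new target for the `energies/`
# package only needs to subclass BaseSet and implement `energy`; `log_reward`, `score`
# and `unnorm_pdf` are all defined in this class in terms of `energy`. The class name
# below is hypothetical.
#
#     class UnitGaussian(BaseSet):
#         def __init__(self, device, dim=2):
#             super().__init__()
#             self.device = device
#             self.data = torch.tensor([0.0])
#             self.data_ndim = dim
#
#         def energy(self, x):
#             # negative log-density of N(0, I) up to an additive constant
#             return 0.5 * (x ** 2).sum(-1)
#
#         def sample(self, batch_size):
#             return torch.randn(batch_size, self.data_ndim, device=self.device)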
def hmt_energy(self, x): 32 | dim = x.shape[-1] 33 | x, v = torch.split(x, dim // 2, dim=-1) 34 | neg_log_p_x = self.sample_energy_fn(x) 35 | neg_log_p_v = nll_unit_gaussian(v) 36 | return neg_log_p_x + neg_log_p_v 37 | 38 | @property 39 | def ndim(self): 40 | return self.data_ndim 41 | 42 | def sample(self, batch_size): 43 | del batch_size 44 | raise NotImplementedError 45 | 46 | def score(self, x): 47 | with torch.no_grad(): 48 | copy_x = x.detach().clone() 49 | copy_x.requires_grad = True 50 | with torch.enable_grad(): 51 | self.energy(copy_x).sum().backward() 52 | lgv_data = copy_x.grad.data 53 | return lgv_data 54 | 55 | def log_reward(self, x): 56 | return -self.energy(x) 57 | 58 | def hmt_score(self, x): 59 | with torch.no_grad(): 60 | copy_x = x.detach().clone() 61 | copy_x.requires_grad = True 62 | with torch.enable_grad(): 63 | self.hmt_energy(copy_x).sum().backward() 64 | lgv_data = copy_x.grad.data 65 | return lgv_data 66 | -------------------------------------------------------------------------------- /rtb_diffusion/energies/posterior_2Dgmm.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import torch 4 | import torch.distributions as D 5 | from torch.distributions.mixture_same_family import MixtureSameFamily 6 | 7 | from .base_set import BaseSet 8 | 9 | from . import twenty_five_gmm 10 | 11 | 12 | class Posterior2DGaussianMixture(BaseSet): 13 | def __init__(self, device, scale=0.5477222, dim=2): 14 | super().__init__() 15 | self.device = device 16 | self.data = torch.tensor([0.0]) 17 | self.data_ndim = 2 18 | 19 | self.prior = twenty_five_gmm.TwentyFiveGaussianMixture(device, dim=2) 20 | 21 | mean_ls = [ 22 | [-10., -5.], [-5., -10.], [-5., 0.], 23 | [10., -5.], [0., 0.], [0., 5.], 24 | [5., -5.], [5., 0.], [5., 10.], 25 | ] 26 | 27 | nmode = len(mean_ls) 28 | mean = torch.stack([torch.tensor(xy) for xy in mean_ls]) 29 | comp = D.Independent(D.Normal(mean.to(self.device), torch.ones_like(mean).to(self.device) * scale), 1) 30 | 31 | probs = torch.Tensor([4, 10, 4, 5, 10, 5, 4, 15, 4]).to(self.device) 32 | probs = probs / probs.sum() 33 | mix = D.Categorical(probs=probs) 34 | self.gmm = MixtureSameFamily(mix, comp) 35 | self.data_ndim = dim 36 | 37 | def gt_logz(self): 38 | return 0. 39 | 40 | def energy(self, x): 41 | en = -(self.gmm.log_prob(x).flatten()) - self.prior.energy(x) #- self.prior.gmm.log_prob(x).flatten()) 42 | #print("x shape: ", x.shape) 43 | #print("en shape: ", en.shape) 44 | #exit() 45 | return en 46 | 47 | def sample(self, batch_size): 48 | return self.gmm.sample((batch_size,)) 49 | 50 | def viz_pdf(self, fsave="ou-density.png"): 51 | x = torch.linspace(-15, 15, 100).to(self.device) 52 | y = torch.linspace(-15, 15, 100).to(self.device) 53 | X, Y = torch.meshgrid(x, y) 54 | x = torch.stack([X.flatten(), Y.flatten()], dim=1) # ? 
55 | 56 | density = self.unnorm_pdf(x) 57 | return x, density 58 | 59 | def __getitem__(self, idx): 60 | del idx 61 | return self.data[0] 62 | -------------------------------------------------------------------------------- /rtb_diffusion/energies/twenty_five_gmm.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import torch 4 | import torch.distributions as D 5 | from torch.distributions.mixture_same_family import MixtureSameFamily 6 | 7 | from .base_set import BaseSet 8 | 9 | 10 | class TwentyFiveGaussianMixture(BaseSet): 11 | def __init__(self, device, dim=2): 12 | super().__init__() 13 | self.data = torch.tensor([0.0]) 14 | self.device = device 15 | 16 | modes = torch.Tensor([(a, b) for a in [-10, -5, 0, 5, 10] for b in [-10, -5, 0, 5, 10]]).to(self.device) 17 | 18 | nmode = 25 19 | self.nmode = nmode 20 | 21 | self.data_ndim = dim 22 | 23 | self.gmm = [D.MultivariateNormal(loc=mode.to(self.device), 24 | covariance_matrix=0.3 * torch.eye(self.data_ndim, device=self.device)) 25 | for mode in modes] 26 | 27 | def gt_logz(self): 28 | return 0. 29 | 30 | def energy(self, x): 31 | log_prob = torch.logsumexp(torch.stack([mvn.log_prob(x) for mvn in self.gmm]), dim=0, 32 | keepdim=False) - torch.log(torch.tensor(self.nmode, device=self.device)) 33 | return -log_prob 34 | 35 | def sample(self, batch_size): 36 | samples = torch.cat([mvn.sample((batch_size // self.nmode,)) for mvn in self.gmm], dim=0).to(self.device) 37 | return samples 38 | 39 | def viz_pdf(self, fsave="25gmm-density.png"): 40 | x = torch.linspace(-15, 15, 100).to(self.device) 41 | y = torch.linspace(-15, 15, 100).to(self.device) 42 | X, Y = torch.meshgrid(x, y) 43 | x = torch.stack([X.flatten(), Y.flatten()], dim=1) # ? 44 | 45 | density = self.unnorm_pdf(x) 46 | return x, density 47 | 48 | def __getitem__(self, idx): 49 | del idx 50 | return self.data[0] 51 | -------------------------------------------------------------------------------- /rtb_diffusion/finetune_posterior.py: -------------------------------------------------------------------------------- 1 | from plot_utils import * 2 | import argparse 3 | import torch 4 | import os 5 | 6 | from utils import set_seed, fig_to_image, get_gfn_optimizer, \ 7 | get_gfn_backward_loss, get_exploration_std, get_finetuning_loss 8 | from models import GFN 9 | from gflownet_losses import * 10 | from energies import * 11 | import copy 12 | import matplotlib.pyplot as plt 13 | from tqdm import trange 14 | 15 | WANDB = True 16 | 17 | if WANDB: 18 | import wandb 19 | 20 | parser = argparse.ArgumentParser(description='finetuning posterior') 21 | parser.add_argument('--lr_policy', type=float, default=1e-3) 22 | parser.add_argument('--batch_size', type=int, default=500) 23 | parser.add_argument('--epochs', type=int, default=5000) 24 | parser.add_argument('--seed', type=int, default=12345) 25 | parser.add_argument('--kl_weight', type=float, default=1.) 
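# Note: --kl_weight only affects the RL baseline (--method rl), whose objective is
# rl_loss + kl_weight * kl_loss; the RTB objective (--method rtb) does not use it.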
26 | parser.add_argument('--name', type=str, default='rtb_finetuning') 27 | parser.add_argument('--method', type=str, default='rtb') 28 | 29 | args = parser.parse_args() 30 | 31 | set_seed(args.seed) 32 | if 'SLURM_PROCID' in os.environ: 33 | args.seed += int(os.environ["SLURM_PROCID"]) 34 | 35 | plot_data_size = 2000 36 | 37 | 38 | 39 | args.zero_init = True 40 | 41 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 42 | 43 | def get_energy(): 44 | prior = TwentyFiveGaussianMixture(device=device) 45 | energy = Posterior2DGaussianMixture(device=device) 46 | return energy, prior 47 | 48 | 49 | def plot_step(energy, gfn_model, name): 50 | 51 | 52 | batch_size = plot_data_size 53 | samples = gfn_model.sample(batch_size, energy.log_reward) 54 | gt_samples = energy.sample(batch_size) 55 | 56 | fig_contour, ax_contour = get_figure(bounds=(-13., 13.)) 57 | fig_kde, ax_kde = get_figure(bounds=(-13., 13.)) 58 | fig_kde_overlay, ax_kde_overlay = get_figure(bounds=(-13., 13.)) 59 | 60 | plot_contours(energy.log_reward, ax=ax_contour, bounds=(-13., 13.), n_contour_levels=150, device=device) 61 | plot_kde(gt_samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 62 | plot_kde(samples, ax=ax_kde, bounds=(-13., 13.)) 63 | plot_samples(samples, ax=ax_contour, bounds=(-13., 13.)) 64 | plot_samples(samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 65 | 66 | fig_contour.savefig(f'output/{name}_contour.png', bbox_inches='tight') 67 | fig_kde_overlay.savefig(f'output/{name}_kde_overlay.png', bbox_inches='tight') 68 | fig_kde.savefig(f'output/{name}_kde.png', bbox_inches='tight') 69 | try: 70 | return {"visualization/contour": wandb.Image(fig_to_image(fig_contour)), 71 | "visualization/kde_overlay": wandb.Image(fig_to_image(fig_kde_overlay)), 72 | "visualization/kde": wandb.Image(fig_to_image(fig_kde))} 73 | except: 74 | return {} 75 | 76 | 77 | def train_step(energy, prior, gfn_model, gfn_optimizer, it, method, exploratory, exploration_factor, exploration_wd, beta = 1.0, kl_weight=1.0): 78 | gfn_model.zero_grad() 79 | exploration_std = get_exploration_std(it, exploratory, exploration_factor, exploration_wd) 80 | if method == 'rtb': 81 | 82 | loss, kl_div = fwd_train_step(energy, prior, gfn_model, exploration_std, method, beta = beta) 83 | loss.backward() 84 | gfn_optimizer.step() 85 | return loss.item(), kl_div.item() 86 | elif method == 'rl': 87 | exploration_std = get_exploration_std(it, False, 0.0, False) 88 | 89 | rl_loss, kl_loss, kl_div = fwd_train_step(energy, prior, gfn_model, exploration_std, method) 90 | loss = (rl_loss + kl_weight * kl_loss).mean() 91 | loss.backward() 92 | gfn_optimizer.step() 93 | return loss.item(), rl_loss.mean().item(), kl_loss.mean().item(), kl_div.item() 94 | 95 | def fwd_train_step(energy, prior, gfn_model, exploration_std, method, return_exp=False, beta=1.0): 96 | init_state = torch.zeros(args.batch_size, energy.data_ndim).to(device) 97 | 98 | if method == 'rtb': 99 | return get_finetuning_loss('rtb', init_state, prior, gfn_model, energy.log_reward, beta = beta, exploration_std=exploration_std, return_exp=return_exp) 100 | else: 101 | return get_finetuning_loss('rl', init_state, prior, gfn_model, energy.log_reward, beta = beta, exploration_std=exploration_std, return_exp=return_exp) 102 | 103 | 104 | 105 | 106 | 107 | def train(): 108 | energy, prior_energy = get_energy() 109 | name = args.name 110 | 111 | gfn_model = GFN(2, 64, 64, 64, 64, 112 | trajectory_length=100, clipping=True, lgv_clip=1e2, gfn_clip=1e4, 113 | langevin=False, learned_variance=False, 114 | 
partial_energy=False, log_var_range=4., 115 | pb_scale_range=0.1, 116 | t_scale=5.0, langevin_scaling_per_dimension=False, 117 | conditional_flow_model=False, learn_pb=False, 118 | pis_architectures=True, lgv_layers=3, 119 | joint_layers=2, zero_init=True, device=device).to(device) 120 | 121 | 122 | gfn_optimizer = get_gfn_optimizer(gfn_model, args.lr_policy, 0.1, args.lr_policy, False, 123 | False, False, False) 124 | 125 | start_epoch = 0 126 | 127 | 128 | checkpoint_path = 'pretrained/prior.pt' 129 | 130 | checkpoint = torch.load(checkpoint_path) 131 | 132 | if 'model_state_dict' in checkpoint: 133 | gfn_model.load_state_dict(checkpoint['model_state_dict']) 134 | gfn_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 135 | else: 136 | gfn_model.load_state_dict(checkpoint) 137 | 138 | start_epoch = 0 139 | 140 | prior = copy.deepcopy(gfn_model) 141 | prior.eval() 142 | method = args.method 143 | 144 | metrics = dict() 145 | 146 | 147 | gfn_model.train() 148 | for i in trange(start_epoch, args.epochs + 1): 149 | if method == 'rtb': 150 | # off-policy: with exploration noise (True) 151 | loss, kl_div = train_step(energy, prior, gfn_model, gfn_optimizer, i, method, True, 0.5, True, beta=1.0) 152 | 153 | metrics['train/loss'] = loss 154 | metrics['train/kl_div'] = kl_div 155 | else: 156 | # on-policy: no exploratinon noise (False) 157 | loss, rl_loss, kl_loss, kl_div = train_step(energy, prior, gfn_model, gfn_optimizer, i, method, False, 0.0, False, beta=1.0, kl_weight=args.kl_weight) 158 | 159 | metrics['train/loss'] = loss 160 | metrics['train/rl_loss'] = rl_loss 161 | metrics['train/kl_loss'] = kl_loss 162 | metrics['train/kl_div'] = kl_div 163 | 164 | if i % 100 == 0: 165 | 166 | images = plot_step(energy, gfn_model, name) 167 | metrics.update(images) 168 | plt.close('all') 169 | 170 | # you may put logger here 171 | ######################### 172 | # wandb.log(metrics) 173 | 174 | 175 | ######################### 176 | print(metrics) 177 | 178 | if i % 1000 == 0: 179 | torch.save({ 180 | 'epoch': i, 181 | 'model_state_dict': gfn_model.state_dict(), 182 | 'optimizer_state_dict': gfn_optimizer.state_dict(), 183 | }, f'output/{name}.pt') 184 | 185 | images = plot_step(energy, gfn_model, name) 186 | metrics.update(images) 187 | plt.close('all') 188 | 189 | torch.save({ 190 | 'epoch': args.epochs, 191 | 'model_state_dict': gfn_model.state_dict(), 192 | 'optimizer_state_dict': gfn_optimizer.state_dict(), 193 | }, f'output/{name}.pt') 194 | 195 | 196 | 197 | if __name__ == '__main__': 198 | train() 199 | -------------------------------------------------------------------------------- /rtb_diffusion/gflownet_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | 4 | 5 | def fwd_rl(initial_state, prior, gfn, log_reward_fn, exploration_std=None, return_exp = False): 6 | 7 | states, log_p_posterior, log_pbs, log_fs = gfn.get_trajectory_fwd(initial_state, exploration_std, log_reward_fn) 8 | log_p_prior = prior.get_trajectory_fwd_off(states.detach(), log_reward_fn).detach() 9 | 10 | 11 | with torch.no_grad(): 12 | log_r = log_reward_fn(states[:, -1]).detach() 13 | reward = log_r.exp() 14 | adv = reward - reward.mean() 15 | 16 | kl_loss = (log_p_posterior.sum(-1) - log_p_prior.sum(-1))**2 17 | kl_div = (log_p_posterior.sum(-1) - log_p_prior.sum(-1)).mean() 18 | 19 | reinforce_loss = -adv * log_p_posterior.sum(-1) 20 | 21 | return reinforce_loss, kl_loss, kl_div 22 | 23 | def 
fwd_rtb(initial_state, prior, gfn, log_reward_fn, exploration_std=None, return_exp = False, beta = 1.0): 24 | states, log_pfs, log_pbs, log_fs = gfn.get_trajectory_fwd(initial_state, exploration_std, log_reward_fn) 25 | log_p_prior = prior.get_trajectory_fwd_off(states.detach(), log_reward_fn).detach() 26 | with torch.no_grad(): 27 | log_r = log_reward_fn(states[:, -1]).detach() 28 | kl_div = (log_pfs.sum(-1) - log_p_prior.sum(-1)).mean() 29 | 30 | loss = 0.5 * ((log_pfs.sum(-1) + log_fs[:, 0] - log_p_prior.sum(-1) - beta * log_r) ** 2) 31 | if return_exp: 32 | return loss.mean(), states, log_pfs, log_pbs, log_r, kl_div 33 | return loss.mean(), kl_div 34 | 35 | 36 | def bwd_mle(samples, gfn, log_reward_fn, exploration_std=None): 37 | states, log_pfs, log_pbs, log_fs = gfn.get_trajectory_bwd(samples, exploration_std, log_reward_fn) 38 | loss = -log_pfs.sum(-1) 39 | return loss.mean() 40 | 41 | 42 | -------------------------------------------------------------------------------- /rtb_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gfn import * 2 | -------------------------------------------------------------------------------- /rtb_diffusion/plot_utils.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import itertools 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | 9 | def get_figure(bounds=(-10., 10.)): 10 | fig, ax = plt.subplots(1, figsize=(16, 16)) 11 | ax.axis('off') 12 | ax.set_autoscale_on(False) 13 | ax.set_xlim([bounds[0], bounds[1]]) 14 | ax.set_ylim([bounds[0], bounds[1]]) 15 | return fig, ax 16 | 17 | 18 | def plot_contours(log_prob, ax=None, bounds=(-10., 10.), grid_width_n_points=200, n_contour_levels=50, 19 | log_prob_min=-1000., device=torch.device('cuda')): 20 | """Plot contours of a log_prob_func that is defined on 2D""" 21 | if ax is None: 22 | fig, ax = plt.subplots(1) 23 | x_points_dim1 = torch.linspace(bounds[0], bounds[1], grid_width_n_points) 24 | x_points_dim2 = x_points_dim1 25 | x_points = torch.tensor(list(itertools.product(x_points_dim1, x_points_dim2))) 26 | log_p_x = log_prob(x_points.to(device)).detach().cpu() 27 | log_p_x = torch.clamp_min(log_p_x, log_prob_min) 28 | log_p_x = log_p_x.reshape((grid_width_n_points, grid_width_n_points)) 29 | x_points_dim1 = x_points[:, 0].reshape((grid_width_n_points, grid_width_n_points)).numpy() 30 | x_points_dim2 = x_points[:, 1].reshape((grid_width_n_points, grid_width_n_points)).numpy() 31 | if n_contour_levels: 32 | ax.contour(x_points_dim1, x_points_dim2, log_p_x, levels=n_contour_levels) 33 | else: 34 | ax.contour(x_points_dim1, x_points_dim2, log_p_x) 35 | 36 | 37 | def plot_samples(samples, ax=None, bounds=(-10., 10.), alpha=0.5): 38 | if ax is None: 39 | fig, ax = plt.subplots(1) 40 | samples = torch.clamp(samples, bounds[0], bounds[1]) 41 | samples = samples.cpu().detach() 42 | ax.scatter(samples[:, 0], samples[:, 1], alpha=alpha, marker="o", s=10) 43 | 44 | 45 | def plot_kde(samples, ax=None, bounds=(-10., 10.)): 46 | if ax is None: 47 | fig, ax = plt.subplots(1) 48 | samples = samples.cpu().detach() 49 | sns.kdeplot(x=samples[:, 0], y=samples[:, 1], cmap="Blues", fill=True, ax=ax, clip=bounds) 50 | 51 | 52 | def viz_many_well(mw_energy, samples=None, num_samples=5000): 53 | if samples is None: 54 | samples = mw_energy.sample(num_samples) 55 | 56 | x13 = samples[:, 0:3:2].detach().cpu() 57 | fig_samples_x13, 
ax_samples_x13 = viz_sample2d(x13, "samples", f"distx13.png", lim=3) 58 | fig_kde_x13, ax_kde_x13 = viz_kde2d(x13, "kde", f"kdex13.png", lim=3) 59 | 60 | lim = 3 61 | alpha = 0.8 62 | n_contour_levels = 20 63 | 64 | def logp_func(x_2d): 65 | x = torch.zeros((x_2d.shape[0], mw_energy.data_ndim)).to(mw_energy.device) 66 | x[:, 0] = x_2d[:, 0] 67 | x[:, 2] = x_2d[:, 1] 68 | return -mw_energy.energy(x).detach().cpu() 69 | 70 | x13 = samples[:, 0:3:2] 71 | contour_img_path = f"contourx13.png" 72 | fig_contour_x13, ax_contour_x13 = viz_contour_sample2d(x13, contour_img_path, logp_func, lim=lim, alpha=alpha, 73 | n_contour_levels=n_contour_levels) 74 | 75 | x23 = samples[:, 1:3].detach().cpu() 76 | fig_samples_x23, ax_samples_x23 = viz_sample2d(x23, "samples", f"distx23.png", lim=3) 77 | fig_kde_x23, ax_kde_x23 = viz_kde2d(x23, "kde", f"kdex23.png", lim=3) 78 | 79 | def logp_func(x_2d): 80 | x = torch.zeros((x_2d.shape[0], mw_energy.data_ndim)).to(mw_energy.device) 81 | x[:, 1] = x_2d[:, 0] 82 | x[:, 2] = x_2d[:, 1] 83 | return -mw_energy.energy(x).detach().cpu() 84 | 85 | x23 = samples[:, 1:3] 86 | contour_img_path2 = f"contourx23.png" 87 | fig_contour_x23, ax_contour_x23 = viz_contour_sample2d(x23, contour_img_path2, logp_func, lim=lim, alpha=alpha, 88 | n_contour_levels=n_contour_levels) 89 | 90 | return fig_samples_x13, ax_samples_x13, fig_kde_x13, ax_kde_x13, fig_contour_x13, ax_contour_x13, fig_samples_x23, ax_samples_x23, fig_kde_x23, ax_kde_x23, fig_contour_x23, ax_contour_x23 91 | 92 | 93 | def traj_plot1d(traj_len, samples, xlabel, ylabel, title="", fsave="img.png"): 94 | samples = rearrange(samples, "t b d -> b t d").cpu() 95 | inds = np.linspace(0, samples.shape[1], traj_len, endpoint=False, dtype=int) 96 | samples = samples[:, inds] 97 | plt.figure() 98 | for i, sample in enumerate(samples): 99 | plt.plot(np.arange(traj_len), sample.flatten(), marker="x", label=f"sample {i}") 100 | plt.title(title) 101 | plt.xlabel(xlabel) 102 | plt.ylabel(ylabel) 103 | plt.savefig(fsave) 104 | plt.close() 105 | 106 | 107 | ########### 2D plot 108 | def viz_sample2d(points, title, fsave, lim=7.0, sample_num=50000): 109 | fig, ax = plt.subplots(1, 1, figsize=(7, 7)) 110 | if title is not None: 111 | ax.set_title(title) 112 | ax.plot( 113 | points[:sample_num, 0], 114 | points[:sample_num, 1], 115 | linewidth=0, 116 | marker=".", 117 | markersize=1, 118 | ) 119 | ax.set_xlim(-lim, lim) 120 | ax.set_ylim(-lim, lim) 121 | return fig, ax 122 | 123 | 124 | def viz_kde2d(points, title, fname, lim=7.0, sample_num=2000): 125 | fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=200) 126 | if title is not None: 127 | ax.set_title(title) 128 | sns.kdeplot( 129 | x=points[:sample_num, 0], y=points[:sample_num, 1], 130 | cmap="coolwarm", fill=True, ax=ax 131 | ) 132 | ax.set_xlim(-lim, lim) 133 | ax.set_ylim(-lim, lim) 134 | return fig, ax 135 | 136 | 137 | def viz_coutour_with_ax(ax, log_prob_func, lim=3.0, n_contour_levels=None): 138 | grid_width_n_points = 100 139 | log_prob_min = -1000.0 140 | x_points_dim1 = torch.linspace(-lim, lim, grid_width_n_points) 141 | x_points_dim2 = x_points_dim1 142 | x_points = torch.tensor(list(itertools.product(x_points_dim1, x_points_dim2))) 143 | log_p_x = log_prob_func(x_points).detach().cpu() 144 | log_p_x = torch.clamp_min(log_p_x, log_prob_min) 145 | log_p_x = log_p_x.reshape((grid_width_n_points, grid_width_n_points)) 146 | x_points_dim1 = x_points[:, 0].reshape((grid_width_n_points, grid_width_n_points)).numpy() 147 | x_points_dim2 = x_points[:, 
1].reshape((grid_width_n_points, grid_width_n_points)).numpy() 148 | if n_contour_levels: 149 | ax.contour(x_points_dim1, x_points_dim2, log_p_x, levels=n_contour_levels) 150 | else: 151 | ax.contour(x_points_dim1, x_points_dim2, log_p_x) 152 | 153 | 154 | def viz_contour_sample2d(points, fname, log_prob_func, 155 | lim=3.0, alpha=0.7, n_contour_levels=None): 156 | fig, ax = plt.subplots(1, 1, figsize=(7, 7)) 157 | 158 | viz_coutour_with_ax(ax, log_prob_func, lim=lim, n_contour_levels=n_contour_levels) 159 | 160 | samples = torch.clamp(points, -lim, lim) 161 | samples = samples.cpu().detach() 162 | ax.plot(samples[:, 0], samples[:, 1], 163 | linewidth=0, marker=".", markersize=1.5, alpha=alpha) 164 | 165 | return fig, ax 166 | -------------------------------------------------------------------------------- /rtb_diffusion/pretrain_prior.py: -------------------------------------------------------------------------------- 1 | from plot_utils import * 2 | import argparse 3 | import torch 4 | import os 5 | 6 | from utils import set_seed, fig_to_image, get_gfn_optimizer, get_gfn_backward_loss, get_exploration_std 7 | from models import GFN 8 | from gflownet_losses import * 9 | from energies import * 10 | 11 | 12 | import matplotlib.pyplot as plt 13 | from tqdm import trange 14 | 15 | 16 | parser = argparse.ArgumentParser(description='pretrain_prior') 17 | parser.add_argument('--lr_policy', type=float, default=1e-3) 18 | parser.add_argument('--lr_flow', type=float, default=1e-1) 19 | parser.add_argument('--lr_back', type=float, default=1e-3) 20 | parser.add_argument('--hidden_dim', type=int, default=64) 21 | parser.add_argument('--s_emb_dim', type=int, default=64) 22 | parser.add_argument('--t_emb_dim', type=int, default=64) 23 | parser.add_argument('--harmonics_dim', type=int, default=64) 24 | parser.add_argument('--batch_size', type=int, default=500) 25 | parser.add_argument('--epochs', type=int, default=10000) 26 | parser.add_argument('--T', type=int, default=100) 27 | parser.add_argument('--t_scale', type=float, default=5.) 28 | parser.add_argument('--log_var_range', type=float, default=4.) 
29 | 30 | 31 | parser.add_argument('--exploratory', action='store_true', default=False) 32 | parser.add_argument('--langevin', action='store_true', default=False) 33 | parser.add_argument('--langevin_scaling_per_dimension', action='store_true', default=False) 34 | parser.add_argument('--conditional_flow_model', action='store_true', default=False) 35 | parser.add_argument('--learn_pb', action='store_true', default=False) 36 | parser.add_argument('--pb_scale_range', type=float, default=0.1) 37 | parser.add_argument('--learned_variance', action='store_true', default=False) 38 | parser.add_argument('--partial_energy', action='store_true', default=False) 39 | parser.add_argument('--exploration_factor', type=float, default=0.1) 40 | parser.add_argument('--exploration_wd', action='store_true', default=False) 41 | parser.add_argument('--clipping', action='store_true', default=False) 42 | parser.add_argument('--lgv_clip', type=float, default=1e2) 43 | parser.add_argument('--gfn_clip', type=float, default=1e4) 44 | parser.add_argument('--zero_init', action='store_true', default=True) 45 | parser.add_argument('--pis_architectures', action='store_true', default=True) 46 | parser.add_argument('--lgv_layers', type=int, default=3) 47 | parser.add_argument('--joint_layers', type=int, default=2) 48 | parser.add_argument('--seed', type=int, default=12345) 49 | parser.add_argument('--weight_decay', type=float, default=1e-7) 50 | parser.add_argument('--use_weight_decay', action='store_true', default=False) 51 | 52 | parser.add_argument('--name', type=str, default='prior_25gmm', help='Name of the run') 53 | args = parser.parse_args() 54 | 55 | set_seed(args.seed) 56 | if 'SLURM_PROCID' in os.environ: 57 | args.seed += int(os.environ["SLURM_PROCID"]) 58 | 59 | plot_data_size = 2000 60 | 61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | 63 | def get_energy(): 64 | 65 | energy = TwentyFiveGaussianMixture(device=device) 66 | 67 | return energy 68 | 69 | 70 | def plot_step(energy, gfn_model, name): 71 | batch_size = plot_data_size 72 | samples = gfn_model.sample(batch_size, energy.log_reward) 73 | gt_samples = energy.sample(batch_size) 74 | 75 | fig_contour, ax_contour = get_figure(bounds=(-13., 13.)) 76 | fig_kde, ax_kde = get_figure(bounds=(-13., 13.)) 77 | fig_kde_overlay, ax_kde_overlay = get_figure(bounds=(-13., 13.)) 78 | 79 | plot_contours(energy.log_reward, ax=ax_contour, bounds=(-13., 13.), n_contour_levels=150, device=device) 80 | plot_kde(gt_samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 81 | plot_kde(samples, ax=ax_kde, bounds=(-13., 13.)) 82 | plot_samples(samples, ax=ax_contour, bounds=(-13., 13.)) 83 | plot_samples(samples, ax=ax_kde_overlay, bounds=(-13., 13.)) 84 | 85 | fig_contour.savefig(f'{name}contour.pdf', bbox_inches='tight') 86 | fig_kde_overlay.savefig(f'{name}kde_overlay.pdf', bbox_inches='tight') 87 | fig_kde.savefig(f'{name}kde.pdf', bbox_inches='tight') 88 | 89 | try: 90 | return {"visualization/contour": wandb.Image(fig_to_image(fig_contour)), 91 | "visualization/kde_overlay": wandb.Image(fig_to_image(fig_kde_overlay)), 92 | "visualization/kde": wandb.Image(fig_to_image(fig_kde))} 93 | except: 94 | return {} 95 | 96 | 97 | 98 | 99 | def train_step(energy, gfn_model, gfn_optimizer, it, exploratory, exploration_factor, exploration_wd): 100 | gfn_model.zero_grad() 101 | 102 | exploration_std = get_exploration_std(it, exploratory, exploration_factor, exploration_wd) 103 | 104 | # True samples 105 | samples = energy.sample(args.batch_size).to(device) 106 | 107 
| # MLE training 108 | loss = get_gfn_backward_loss('mle', samples, gfn_model, energy.log_reward, 109 | exploration_std=exploration_std) 110 | 111 | 112 | loss.backward() 113 | gfn_optimizer.step() 114 | return loss.item() 115 | 116 | 117 | def train(): 118 | 119 | name = f'pretrained/{args.name}.pt' 120 | 121 | energy = get_energy() 122 | 123 | 124 | config = args.__dict__ 125 | config["Experiment"] = "{args.energy}" 126 | 127 | 128 | gfn_model = GFN(energy.data_ndim, args.s_emb_dim, args.hidden_dim, args.harmonics_dim, args.t_emb_dim, 129 | trajectory_length=args.T, clipping=args.clipping, lgv_clip=args.lgv_clip, gfn_clip=args.gfn_clip, 130 | langevin=args.langevin, learned_variance=args.learned_variance, 131 | partial_energy=args.partial_energy, log_var_range=args.log_var_range, 132 | pb_scale_range=args.pb_scale_range, 133 | t_scale=args.t_scale, langevin_scaling_per_dimension=args.langevin_scaling_per_dimension, 134 | conditional_flow_model=args.conditional_flow_model, learn_pb=args.learn_pb, 135 | pis_architectures=args.pis_architectures, lgv_layers=args.lgv_layers, 136 | joint_layers=args.joint_layers, zero_init=args.zero_init, device=device).to(device) 137 | 138 | 139 | gfn_optimizer = get_gfn_optimizer(gfn_model, args.lr_policy, args.lr_flow, args.lr_back, args.learn_pb, 140 | args.conditional_flow_model, args.use_weight_decay, args.weight_decay) 141 | 142 | start_epoch = 0 143 | 144 | 145 | print(gfn_model) 146 | metrics = dict() 147 | 148 | 149 | gfn_model.train() 150 | for i in trange(start_epoch, args.epochs + 1): 151 | metrics['train/loss'] = train_step(energy, gfn_model, gfn_optimizer, i, args.exploratory, args.exploration_factor, args.exploration_wd) 152 | if i % 100 == 0: 153 | 154 | images = plot_step(energy, gfn_model, name) 155 | metrics.update(images) 156 | plt.close('all') 157 | 158 | #wandb.log(metrics, step=i) 159 | 160 | if i % 1000 == 0: 161 | torch.save({ 162 | 'epoch': i, 163 | 'model_state_dict': gfn_model.state_dict(), 164 | 'optimizer_state_dict': gfn_optimizer.state_dict(), 165 | }, name) 166 | images = plot_step(energy, gfn_model, name) 167 | metrics.update(images) 168 | plt.close('all') 169 | torch.save({ 170 | 'epoch': args.epochs, 171 | 'model_state_dict': gfn_model.state_dict(), 172 | 'optimizer_state_dict': gfn_optimizer.state_dict(), 173 | }, name) 174 | 175 | 176 | 177 | 178 | if __name__ == '__main__': 179 | train() 180 | -------------------------------------------------------------------------------- /rtb_diffusion/pretrained/prior.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/rtb_diffusion/pretrained/prior.pt -------------------------------------------------------------------------------- /rtb_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import math 4 | import PIL 5 | 6 | from gflownet_losses import * 7 | 8 | 9 | def set_seed(seed): 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | random.seed(seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | np.random.seed(seed) 17 | 18 | 19 | def logmeanexp(x, dim=0): 20 | return x.logsumexp(dim) - math.log(x.shape[dim]) 21 | 22 | 23 | def dcp(tensor): 24 | return tensor.detach().cpu() 25 | 26 | 27 | def gaussian_params(tensor): 28 | mean, logvar = torch.chunk(tensor, 
2, dim=-1) 29 | return mean, logvar 30 | 31 | 32 | def fig_to_image(fig): 33 | fig.canvas.draw() 34 | 35 | return PIL.Image.frombytes( 36 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 37 | ) 38 | 39 | 40 | def get_gfn_optimizer(gfn_model, lr_policy, lr_flow, lr_back, back_model=False, conditional_flow_model=False, use_weight_decay=False, weight_decay=1e-7): 41 | param_groups = [ {'params': gfn_model.t_model.parameters()}, 42 | {'params': gfn_model.s_model.parameters()}, 43 | {'params': gfn_model.joint_model.parameters()}, 44 | {'params': gfn_model.langevin_scaling_model.parameters()} ] 45 | if conditional_flow_model: 46 | param_groups += [ {'params': gfn_model.flow_model.parameters(), 'lr': lr_flow} ] 47 | else: 48 | param_groups += [ {'params': [gfn_model.flow_model], 'lr': lr_flow} ] 49 | 50 | if back_model: 51 | param_groups += [ {'params': gfn_model.back_model.parameters(), 'lr': lr_back} ] 52 | 53 | if use_weight_decay: 54 | gfn_optimizer = torch.optim.Adam(param_groups, lr_policy, weight_decay=weight_decay) 55 | else: 56 | gfn_optimizer = torch.optim.Adam(param_groups, lr_policy) 57 | return gfn_optimizer 58 | 59 | 60 | 61 | 62 | 63 | def get_finetuning_loss(mode, init_state, prior, gfn_model, log_reward, beta = 1.0, exploration_std=None, return_exp=False): 64 | if mode == 'rl': 65 | rl_loss, kl_loss, kl_div = fwd_rl(init_state, prior, gfn_model, log_reward, exploration_std) 66 | return rl_loss, kl_loss, kl_div 67 | elif mode == 'rtb': 68 | return fwd_rtb(init_state, prior, gfn_model, log_reward, exploration_std, beta = beta, return_exp=return_exp) 69 | 70 | return None 71 | 72 | def get_gfn_backward_loss(mode, samples, gfn_model, log_reward, exploration_std=None): 73 | 74 | loss = bwd_mle(samples, gfn_model, log_reward, exploration_std) 75 | return loss 76 | 77 | 78 | def get_exploration_std(iter, exploratory, exploration_factor=0.1, exploration_wd=False): 79 | if exploratory is False: 80 | return None 81 | if exploration_wd: 82 | if iter < 500: 83 | exploration_std = exploration_factor 84 | else: 85 | exploration_std = exploration_factor * max(0, 1. - iter / 4500.) 86 | else: 87 | exploration_std = exploration_factor 88 | expl = lambda x: exploration_std 89 | return expl 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /text_to_image/README.md: -------------------------------------------------------------------------------- 1 | # Text-to-image with Stable Diffusion 2 | Codebase adapted from the official [DPOK: Reinforcement Learning for Fine-tuning Text-to-Image Diffusion Models](https://github.com/google-research/google-research/tree/master/dpok). 3 | See the [DPOK paper](https://arxiv.org/abs/2305.16381) for reference. 4 | 5 | ![](./img/rabbit.gif) 6 | 7 | ## Env Installation 8 | 9 | We recommend the Anaconda version Anaconda3-2022.10. 10 | 11 | ```bash 12 | wget -P /tmp https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh 13 | bash /tmp/Anaconda3-2022.10-Linux-x86_64.sh 14 | conda init bash 15 | source ~/.bashrc 16 | ``` 17 | 18 | Create a conda environment and install required modules. 19 | 20 | ```bash 21 | conda env create -f environment.yaml 22 | conda activate dpok 23 | ``` 24 | 25 | Install ImageReward module. 
26 | 27 | ```bash 28 | bash install_image_reward.sh 29 | ``` 30 | 31 | If training gets stuck, try uninstalling torch_xla and reinstalling accelerate: 32 | 33 | ```bash 34 | pip uninstall torch_xla 35 | pip uninstall accelerate 36 | pip install accelerate==0.17.0 37 | ``` 38 | 39 | ## DPOK Training 40 | 41 | ``` 42 | accelerate launch train_online_pg.py --p_batch_size 8 --reward_weight 10 --kl_weight 0.01 --learning_rate 1e-5 --single_flag 1 --single_prompt "A green colored rabbit." --gradient_accumulation_steps 12 --clip_norm 0.1 --g_batch_size 10 --multi_gpu 0 --v_flag 1 --output_dir 'green_rabbit' [--report_to "wandb"] 43 | ``` 44 | 45 | Explanation of the arguments: 46 | - `p_batch_size`: batch size for policy training. Batch size 4 can be used on a single A100 GPU. If multiple GPUs are used, the entire UNet has to be shared, which is more memory-consuming; in this case, use a smaller policy batch size (2 per A100). 47 | - `reward_weight`: weight for the reward term in the policy loss 48 | - `kl_weight`: weight for the KL term in the policy loss 49 | - `learning_rate`: learning rate for the policy network 50 | - `single_flag`: whether to train on a single prompt 51 | - `single_prompt`: the single prompt to train on 52 | - `gradient_accumulation_steps`: number of gradient accumulation steps 53 | - `clip_norm`: gradient clipping norm 54 | - `g_batch_size`: batch size for generation. Batch size 12 can be used on a single A100 GPU. 55 | - `multi_gpu`: whether to use multiple GPUs. 56 | - `v_flag`: whether to use value learning. 57 | 58 | The LoRA weights and TensorBoard logs will be saved under `./online_model/img_reward_0/pre_train/single_prompt/prompt_name`. 59 | 60 | ## RTB Training 61 | ``` 62 | accelerate launch train_gfn.py --p_batch_size 8 --reward_weight 1.0 --learning_rate 1e-5 --single_flag 1 --single_prompt "A green colored rabbit." --gradient_accumulation_steps 12 --clip_norm 0.1 --g_batch_size 10 --multi_gpu 0 --v_flag 0 --output_dir 'green_rabbit' [--report_to "wandb"] 63 | ``` 64 | 65 | ## Test 66 | Put the LoRA weights under `./test` and run `python test.py --prompt "A green colored rabbit."`. The test prompt and model path can be changed via the script's arguments. 67 | -------------------------------------------------------------------------------- /text_to_image/dpok_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
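# NOTE: this module is duplicated verbatim as text_to_image/utils.py; both provide
# image_reward_get_reward, which scores a prompt-image pair with the ImageReward
# model's BLIP backbone and MLP head (eval_rew.py calls it through utils.py).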
15 | 16 | """Helper functions.""" 17 | 18 | from ImageReward import ImageReward 19 | import torch 20 | 21 | 22 | def image_reward_get_reward( 23 | model, pil_image, prompt, weight_dtype 24 | ): 25 | """Gets rewards using ImageReward model.""" 26 | image = ( 27 | model.preprocess(pil_image).unsqueeze(0).to(weight_dtype).to(model.device) 28 | ) 29 | image_embeds = model.blip.visual_encoder(image) 30 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 31 | model.device 32 | ) 33 | 34 | text_input = model.blip.tokenizer( 35 | prompt, 36 | padding="max_length", 37 | truncation=True, 38 | max_length=35, 39 | return_tensors="pt", 40 | ).to(model.device) 41 | text_output = model.blip.text_encoder( 42 | text_input.input_ids, 43 | attention_mask=text_input.attention_mask, 44 | encoder_hidden_states=image_embeds, 45 | encoder_attention_mask=image_atts, 46 | return_dict=True, 47 | ) 48 | txt_features = text_output.last_hidden_state[:, 0, :] 49 | rewards = model.mlp(txt_features) 50 | rewards = (rewards - model.mean) / model.std 51 | return rewards, txt_features 52 | -------------------------------------------------------------------------------- /text_to_image/environment.yaml: -------------------------------------------------------------------------------- 1 | name: dpok 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.7 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.21 12 | - pip: 13 | - accelerate==0.17.0 14 | - diffusers==0.14.0 15 | - opencv-python==4.1.2.30 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - einops==0.3.0 19 | - tensorboard 20 | - transformers 21 | - kornia 22 | - fairscale 23 | - importlib-metadata 24 | - timm 25 | - datasets 26 | - huggingface_hub==0.14.1 27 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 28 | -------------------------------------------------------------------------------- /text_to_image/eval_div.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import timm 5 | import glob 6 | from torch.nn.functional import cosine_similarity 7 | from argparse import ArgumentParser 8 | 9 | def main(args): 10 | # Load the pre-trained DINO model 11 | model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16') 12 | model.eval() 13 | 14 | # Assume CUDA is available and use GPU for computation 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | model.to(device) 17 | 18 | # Transformations for the input images 19 | transform = transforms.Compose([ 20 | transforms.Resize((224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 23 | ]) 24 | 25 | # Load images 26 | image_paths = glob.glob(args.img_dir+"*.png") # Adjust the path and extension 27 | images = [transform(Image.open(x).convert('RGB')).unsqueeze(0).to(device) for x in image_paths] 28 | 29 | # Extract features 30 | with torch.no_grad(): 31 | features = [model(x).squeeze(0) for x in images] 32 | 33 | # Compute pairwise cosine similarity 34 | n = len(features) 35 | similarity_matrix = torch.zeros((n, n), device=device) 36 | for i in range(n): 37 | for j in range(i + 1, n): 38 | similarity = cosine_similarity(features[i].unsqueeze(0), features[j].unsqueeze(0)) 39 | similarity_matrix[i][j] = similarity 40 | similarity_matrix[j][i] = similarity # since cosine similarity is symmetric 41 
| 42 | # Calculate average pairwise cosine similarity (excluding self-similarity) 43 | upper_tri_indices = torch.triu_indices(row=n, col=n, offset=1) # Offset 1 to exclude diagonal 44 | average_similarity = torch.mean(similarity_matrix[upper_tri_indices[0], upper_tri_indices[1]]) 45 | 46 | # Print the average cosine similarity 47 | print("Average Pairwise Cosine Similarity:", average_similarity.item()) 48 | 49 | if __name__ == "__main__": 50 | parser = ArgumentParser() 51 | parser.add_argument("--img_dir", default="img_post/") 52 | main(parser.parse_args()) -------------------------------------------------------------------------------- /text_to_image/eval_div_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import glob 5 | from torch.nn.functional import cosine_similarity 6 | from argparse import ArgumentParser 7 | from transformers import CLIPModel, CLIPProcessor # pylint: disable=g-multiple-import 8 | from transformers import CLIPTokenizer # pylint: disable=g-multiple-import 9 | 10 | def main(args): 11 | # Load the pre-trained CLIP model 12 | model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") 13 | reward_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 14 | reward_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 15 | 16 | # Assume CUDA is available and use GPU for computation 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | model.to(device) 19 | 20 | # Load images 21 | image_paths = glob.glob(args.img_dir + "*.png") # Adjust the path and extension 22 | images = [Image.open(x).convert('RGB') for x in image_paths] 23 | 24 | # Process images 25 | inputs = reward_processor(images=images, return_tensors="pt") 26 | pixels = inputs.pixel_values.to(device) 27 | 28 | with torch.no_grad(): 29 | features = model.get_image_features(pixels) 30 | 31 | # Compute pairwise cosine similarity 32 | n = features.size(0) 33 | similarity_matrix = torch.zeros((n, n), device=device) 34 | for i in range(n): 35 | for j in range(i + 1, n): 36 | similarity = cosine_similarity(features[i].unsqueeze(0), features[j].unsqueeze(0)) 37 | similarity_matrix[i][j] = similarity 38 | similarity_matrix[j][i] = similarity # since cosine similarity is symmetric 39 | 40 | # Calculate average pairwise cosine similarity (excluding self-similarity) 41 | upper_tri_indices = torch.triu_indices(row=n, col=n, offset=1) # Offset 1 to exclude diagonal 42 | average_similarity = torch.mean(similarity_matrix[upper_tri_indices[0], upper_tri_indices[1]]) 43 | 44 | # Print the average cosine similarity 45 | print("Average Pairwise Cosine Similarity:", average_similarity.item()) 46 | 47 | if __name__ == "__main__": 48 | parser = ArgumentParser() 49 | parser.add_argument("--img_dir", default="output_img/green_rabbit/kl_0.01/") 50 | args = parser.parse_args() 51 | main(args) -------------------------------------------------------------------------------- /text_to_image/eval_div_dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import timm 5 | import glob 6 | from torch.nn.functional import cosine_similarity 7 | from argparse import ArgumentParser 8 | 9 | def main(args): 10 | # Load the pre-trained DINO model 11 | model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16') 12 | 
model.eval() 13 | 14 | # Assume CUDA is available and use GPU for computation 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | model.to(device) 17 | 18 | # Transformations for the input images 19 | transform = transforms.Compose([ 20 | transforms.Resize((224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 23 | ]) 24 | 25 | # Load images 26 | image_paths = glob.glob(args.img_dir+"*.png") # Adjust the path and extension 27 | images = [transform(Image.open(x).convert('RGB')).unsqueeze(0).to(device) for x in image_paths] 28 | 29 | # Extract features 30 | with torch.no_grad(): 31 | features = [model(x).squeeze(0) for x in images] 32 | 33 | # Compute pairwise cosine similarity 34 | n = len(features) 35 | similarity_matrix = torch.zeros((n, n), device=device) 36 | for i in range(n): 37 | for j in range(i + 1, n): 38 | similarity = cosine_similarity(features[i].unsqueeze(0), features[j].unsqueeze(0)) 39 | similarity_matrix[i][j] = similarity 40 | similarity_matrix[j][i] = similarity # since cosine similarity is symmetric 41 | 42 | # Calculate average pairwise cosine similarity (excluding self-similarity) 43 | upper_tri_indices = torch.triu_indices(row=n, col=n, offset=1) # Offset 1 to exclude diagonal 44 | average_similarity = torch.mean(similarity_matrix[upper_tri_indices[0], upper_tri_indices[1]]) 45 | 46 | # Print the average cosine similarity 47 | print("Average Pairwise Cosine Similarity:", average_similarity.item()) 48 | 49 | if __name__ == "__main__": 50 | parser = ArgumentParser() 51 | parser.add_argument("--img_dir", default="output_img/green_rabbit/kl_0.01/") 52 | main(parser.parse_args()) -------------------------------------------------------------------------------- /text_to_image/eval_rew.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from PIL import Image 5 | import ImageReward as imagereward 6 | from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer 7 | from argparse import ArgumentParser 8 | import utils 9 | 10 | def evaluate_image_reward(image_folder, prompt, output_file="image_rewards.json"): 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | # Load the ImageReward model 14 | image_reward = imagereward.load("ImageReward-v1.0") 15 | image_reward.requires_grad_(False) 16 | image_reward.to(device, dtype=torch.float16) 17 | 18 | rewards = 0.0 19 | n = 0 20 | 21 | for image_file in os.listdir(image_folder): 22 | if image_file.endswith(('.png', '.jpg', '.jpeg')): 23 | image_path = os.path.join(image_folder, image_file) 24 | image = Image.open(image_path).convert("RGB") 25 | 26 | # Get image reward 27 | blip_reward, _ = utils.image_reward_get_reward(image_reward, image, prompt, torch.float16)#image_reward(image, prompt, dtype=torch.float16) 28 | 29 | # Store the rewards in a dictionary 30 | rewards += blip_reward.item() 31 | n += 1 32 | 33 | print("Avg reward: ", rewards/n) 34 | 35 | if __name__ == "__main__": 36 | parser = ArgumentParser() 37 | parser.add_argument("--img_dir") 38 | parser.add_argument("--prompt") 39 | args = parser.parse_args() 40 | image_folder = args.img_dir 41 | prompt = args.prompt 42 | evaluate_image_reward(image_folder, prompt) 43 | -------------------------------------------------------------------------------- /text_to_image/img/rabbit.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GFNOrg/diffusion-finetuning/21bfa222b606abd64873994f424a477efc29707d/text_to_image/img/rabbit.gif -------------------------------------------------------------------------------- /text_to_image/install_image_reward.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/usr/bin/bash 16 | ROOT=".tmp" 17 | 18 | mkdir $ROOT 19 | cd $ROOT 20 | git clone https://github.com/THUDM/ImageReward.git 21 | echo "from .utils import *" > ImageReward/ImageReward/__init__.py 22 | cp -r "ImageReward/ImageReward" ../ 23 | 24 | rm -rf $ROOT 25 | -------------------------------------------------------------------------------- /text_to_image/reward_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Reward model.""" 17 | 18 | import numpy as np 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | 23 | 24 | def gen_net(in_size=1024, out_size=1, h_size=512, n_layers=3, activation='sig'): 25 | """Gets NN.""" 26 | net = [] 27 | for _ in range(n_layers): 28 | net.append(nn.Linear(in_size, h_size)) 29 | net.append(nn.LeakyReLU()) 30 | in_size = h_size 31 | net.append(nn.Linear(in_size, out_size)) 32 | if activation == 'tanh': 33 | net.append(nn.Tanh()) 34 | elif activation == 'sig': 35 | net.append(nn.Sigmoid()) 36 | elif activation == 'relu': 37 | net.append(nn.ReLU()) 38 | return net 39 | 40 | 41 | class RewardModel(nn.Module): 42 | """Reward model.""" 43 | 44 | def __init__(self, in_size=1024, h_size=512, n_layers=3, activation='sig'): 45 | super().__init__() 46 | 47 | self.in_size = in_size 48 | self.h_size = h_size 49 | self.n_layers = n_layers 50 | self.activation = activation 51 | 52 | self.model = nn.Sequential( 53 | *gen_net( 54 | in_size=in_size, 55 | out_size=1, 56 | h_size=h_size, 57 | n_layers=n_layers, 58 | activation=activation, 59 | ) 60 | ) 61 | 62 | def forward( 63 | self, text, img 64 | ): 65 | input_f = torch.cat([text, img], axis=-1) 66 | score = self.model(input_f) 67 | 68 | return score 69 | 70 | 71 | class RewardModelAdapter(nn.Module): 72 | """Reward model adapter.""" 73 | 74 | def __init__( 75 | self, 76 | vis_in_size, 77 | txt_in_size, 78 | out_size, 79 | h_size=512, 80 | n_layers=3, 81 | activation='no', 82 | ): 83 | super().__init__() 84 | 85 | self.vision_model = nn.Sequential( 86 | *gen_net( 87 | in_size=vis_in_size, 88 | out_size=out_size, 89 | h_size=h_size, 90 | n_layers=n_layers, 91 | activation=activation, 92 | ) 93 | ) 94 | self.text_model = nn.Sequential( 95 | *gen_net( 96 | in_size=txt_in_size, 97 | out_size=out_size, 98 | h_size=h_size, 99 | n_layers=n_layers, 100 | activation=activation, 101 | ) 102 | ) 103 | 104 | def forward( 105 | self, text, img 106 | ): 107 | vis_f = self.vision_model(img) 108 | txt_f = self.text_model(text) 109 | return vis_f, txt_f 110 | 111 | 112 | def leaky_relu(p=0.2): 113 | return nn.LeakyReLU(p, inplace=True) 114 | 115 | 116 | class ConditionalLinear(nn.Module): 117 | """Conditional linear.""" 118 | 119 | def __init__(self, num_in, num_out, n_steps): 120 | super(ConditionalLinear, self).__init__() 121 | self.num_out = num_out 122 | self.lin = nn.Linear(num_in, num_out) 123 | self.embed = nn.Embedding(n_steps, num_out) 124 | self.embed.weight.data.uniform_() 125 | torch.nn.init.xavier_normal_(self.lin.weight) 126 | 127 | def forward(self, x, y): 128 | out = self.lin(x) 129 | gamma = self.embed(y) 130 | out = gamma.view(-1, self.num_out) * out 131 | return out 132 | 133 | 134 | class Value(nn.Module): 135 | """Value.""" 136 | 137 | def __init__(self, num_steps, img_shape): 138 | super(Value, self).__init__() 139 | self.lin1 = ConditionalLinear(int(np.prod(img_shape)), 4096, num_steps) 140 | self.lin2 = ConditionalLinear(4096, 1024, num_steps) 141 | self.lin3 = ConditionalLinear(1024, 256, num_steps) 142 | self.lin4 = nn.Linear(256, 1) 143 | torch.nn.init.xavier_normal_(self.lin4.weight) 144 | 145 | def forward(self, img, t): 146 | x = img.view(img.shape[0], -1) 147 | x = F.relu(self.lin1(x, t)) 148 | x = F.relu(self.lin2(x, t)) 149 | x = F.relu(self.lin3(x, t)) 150 | return self.lin4(x) 151 | 152 | 153 | class ValueMulti(nn.Module): 154 | """ValueMulti.""" 155 | 156 | def __init__(self, num_steps, img_shape): 157 | super(ValueMulti, self).__init__() 158 | self.lin1 = 
ConditionalLinear(int(np.prod(img_shape)) + 768, 256, num_steps) 159 | self.lin2 = ConditionalLinear(256, 256, num_steps) 160 | self.lin3 = ConditionalLinear(256, 256, num_steps) 161 | self.lin4 = nn.Linear(256, 1) 162 | torch.nn.init.xavier_normal_(self.lin4.weight) 163 | 164 | def forward(self, img, txt_emb, t): 165 | x = img.view(img.shape[0], -1) 166 | x = torch.cat([x, txt_emb], dim=1) 167 | # x = torch.cat([x, txt_emb], dim=1) 168 | x = F.relu(self.lin1(x, t)) 169 | x = F.relu(self.lin2(x, t)) 170 | x = F.relu(self.lin3(x, t)) 171 | return self.lin4(x) 172 | 173 | 174 | class TimeEmbedding(nn.Module): 175 | """Time embedding.""" 176 | 177 | def __init__(self, max_time, embed_dim): 178 | super(TimeEmbedding, self).__init__() 179 | self.max_time = max_time 180 | self.embed_dim = embed_dim 181 | self.embedding = nn.Embedding(max_time, embed_dim) 182 | 183 | def forward(self, time): 184 | # time is of shape [batch_size, 1] 185 | time_embed = self.embedding(time) 186 | time_embed = time_embed.view(-1, self.embed_dim) 187 | return time_embed 188 | 189 | 190 | class TDCNN(nn.Module): 191 | """TDCNN.""" 192 | 193 | def __init__(self, time_dim, max_time, embed_dim, num_classes): 194 | super(TDCNN, self).__init__() 195 | self.time_dim = time_dim 196 | self.max_time = max_time 197 | self.embed_dim = embed_dim 198 | self.time_embed = TimeEmbedding(max_time, embed_dim) 199 | self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3) 200 | self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3) 201 | self.fc1 = nn.Linear(32 * ((time_dim - 4) // 2) ** 2 + embed_dim, 64) 202 | self.fc2 = nn.Linear(64, num_classes) 203 | 204 | def forward(self, x, time): 205 | # x is of shape [batch_size, 1, time_dim] 206 | # time is of shape [batch_size, 1] 207 | x = F.relu(self.conv1(x)) 208 | x = F.max_pool1d(x, 2) 209 | x = F.relu(self.conv2(x)) 210 | x = F.max_pool1d(x, 2) 211 | x = x.view(-1, 32 * ((self.time_dim - 4) // 2) ** 2) 212 | time_embed = self.time_embed(time) 213 | x = torch.cat((x, time_embed), dim=1) 214 | x = F.relu(self.fc1(x)) 215 | x = self.fc2(x) 216 | return x 217 | -------------------------------------------------------------------------------- /text_to_image/test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Generate an image given a prompt using a trained model.""" 17 | 18 | from argparse import ArgumentParser # pylint: disable=g-importing-member 19 | import random 20 | from diffusers import DDIMScheduler # pylint: disable=g-importing-member 21 | from diffusers import StableDiffusionPipeline # pylint: disable=g-importing-member 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def main(args): 27 | torch.manual_seed(args.seed) 28 | torch.cuda.manual_seed(args.seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | np.random.seed(args.seed) 32 | random.seed(args.seed) 33 | model_path = args.model_path 34 | pipe = StableDiffusionPipeline.from_pretrained( 35 | "runwayml/stable-diffusion-v1-5" 36 | ) 37 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 38 | pipe.unet.load_attn_procs(model_path) 39 | pipe.to("cuda") 40 | prompt_list = [args.prompt for i in range(15)] 41 | for i in range(15): 42 | images = pipe(prompt=prompt_list[i], eta=1.0).images 43 | image = images[0] 44 | 45 | image.save("img_post/image{}.png".format(i)) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = ArgumentParser() 50 | parser.add_argument("--prompt", default="A green colored rabbit.") 51 | parser.add_argument("--model-path", default="./test") 52 | parser.add_argument("--seed", default=0, type=int) 53 | main(parser.parse_args()) 54 | -------------------------------------------------------------------------------- /text_to_image/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Helper functions.""" 17 | 18 | from ImageReward import ImageReward 19 | import torch 20 | 21 | 22 | def image_reward_get_reward( 23 | model, pil_image, prompt, weight_dtype 24 | ): 25 | """Gets rewards using ImageReward model.""" 26 | image = ( 27 | model.preprocess(pil_image).unsqueeze(0).to(weight_dtype).to(model.device) 28 | ) 29 | image_embeds = model.blip.visual_encoder(image) 30 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 31 | model.device 32 | ) 33 | 34 | text_input = model.blip.tokenizer( 35 | prompt, 36 | padding="max_length", 37 | truncation=True, 38 | max_length=35, 39 | return_tensors="pt", 40 | ).to(model.device) 41 | text_output = model.blip.text_encoder( 42 | text_input.input_ids, 43 | attention_mask=text_input.attention_mask, 44 | encoder_hidden_states=image_embeds, 45 | encoder_attention_mask=image_atts, 46 | return_dict=True, 47 | ) 48 | txt_features = text_output.last_hidden_state[:, 0, :] 49 | rewards = model.mlp(txt_features) 50 | rewards = (rewards - model.mean) / model.std 51 | return rewards, txt_features 52 | --------------------------------------------------------------------------------
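The `image_reward_get_reward` helper above scores a prompt-image pair with a pretrained ImageReward model. A minimal scoring sketch, mirroring how `eval_rew.py` loads and calls it (the image path and prompt below are placeholders):

```python
import torch
from PIL import Image
import ImageReward as imagereward  # local package set up by install_image_reward.sh

import utils  # text_to_image/utils.py, which provides image_reward_get_reward

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained ImageReward model, frozen and in half precision, as in eval_rew.py.
reward_model = imagereward.load("ImageReward-v1.0")
reward_model.requires_grad_(False)
reward_model.to(device, dtype=torch.float16)

# Placeholder inputs: substitute an image generated by test.py and its prompt.
image = Image.open("img_post/image0.png").convert("RGB")
prompt = "A green colored rabbit."

# Returns a standardized scalar reward and the BLIP text features it was computed from.
reward, txt_features = utils.image_reward_get_reward(reward_model, image, prompt, torch.float16)
print("ImageReward score:", reward.item())
```

This is the same call that `eval_rew.py` averages over a folder of generated samples.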