├── .gitignore ├── LICENSE ├── README.md ├── deepebm └── ebm.py ├── gflownet.py ├── network.py ├── synthetic ├── synthetic_data.py ├── synthetic_utils.py └── train.py └── utils_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | **data/ 2 | *__pycache__* 3 | *mig* 4 | **ttt*/ 5 | **/*mig*/ 6 | **log/** 7 | **/log/** 8 | *.txt 9 | *.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Narsil-Dinghuai Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Energy-based GFlowNets 2 | 3 | Code for our ICML 2022 paper [Generative Flow Networks for Discrete Probabilistic Modeling](https://arxiv.org/abs/2202.01361) 4 | by [Dinghuai Zhang](https://zdhnarsil.github.io/), [Nikolay Malkin](https://malkin1729.github.io/), [Zhen Liu](http://itszhen.com/), 5 | [Alexandra Volokhova](https://alexandravolokhova.github.io/), Aaron Courville, 6 | [Yoshua Bengio](https://yoshuabengio.org/). 
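In brief (see `gflownet.py`): a GFlowNet sampler is trained with a trajectory balance (TB) objective against the energy model's unnormalized log-density, and the energy model is in turn updated on samples drawn from the GFlowNet. A minimal sketch of the per-trajectory TB loss, for orientation only (the function and argument names below are hypothetical, not this repo's API):

```python
# Illustrative sketch of the trajectory balance loss optimized in gflownet.py.
# For the "tbrf" variant the backward policy is fixed and uniform, so its log-probability
# is a constant; for the "tblb" variant it is predicted by the network.
import torch

def tb_loss_sketch(log_Z, log_pf_traj, log_pb_traj, log_reward):
    # log_Z:       learnable scalar estimate of the log partition function
    # log_pf_traj: summed log-probability of the forward policy along a trajectory
    # log_pb_traj: summed log-probability of the backward policy along the same trajectory
    # log_reward:  unnormalized log-density of the terminal state under the energy model
    return (log_Z + log_pf_traj - log_pb_traj - log_reward) ** 2
```

The `--back_ratio` flag interpolates between this loss evaluated on trajectories sampled forward from scratch and the same loss evaluated on trajectories reconstructed backward from training data.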
7 | 8 | 9 | ### Example 10 | 11 | Synthetic tasks 12 | 13 | ``` 14 | python -m synthetic.train --data checkerboard --lr 1e-3 --type tblb --hid_layer 3 --hid 256 --print_every 100 --glr 1e-3 --zlr 1 --rand_coef 0 --back_ratio 0.5 --lin_k 1 --warmup_k 1e5 --with_mh 1 15 | ``` 16 | 17 | Discrete image modeling 18 | 19 | ```angular2html 20 | python -m deepebm.ebm --model mlp-256 --lr 1e-4 --type tblb --hid_layer 3 --hid 256 --glr 1e-3 --zlr 1 --rand_coef 0 --back_ratio 0.5 --lin_k 1 --warmup_k 5e4 --with_mh 1 --print_every 100 --mc_num 5 21 | ``` 22 | 23 | -------------------------------------------------------------------------------- /deepebm/ebm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import os, sys 6 | import copy 7 | import time 8 | import random 9 | import ipdb 10 | from tqdm import tqdm 11 | import argparse 12 | import network 13 | 14 | sys.path.append("/home/zhangdh/EB_GFN") 15 | from gflownet import get_GFlowNet 16 | import utils_data 17 | 18 | 19 | def makedirs(path): 20 | if not os.path.exists(path): 21 | print('creating dir: {}'.format(path)) 22 | os.makedirs(path) 23 | else: 24 | print(path, "already exist!") 25 | 26 | class EBM(nn.Module): 27 | def __init__(self, net, mean=None): 28 | super().__init__() 29 | self.net = net 30 | if mean is None: 31 | self.mean = None 32 | else: 33 | self.mean = nn.Parameter(mean, requires_grad=False) 34 | self.base_dist = torch.distributions.Bernoulli(probs=self.mean) 35 | 36 | def forward(self, x): 37 | if self.mean is None: 38 | bd = 0. 39 | else: 40 | bd = self.base_dist.log_prob(x).sum(-1) 41 | 42 | logp = self.net(x).squeeze() 43 | return logp + bd 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--device", "--d", default=0, type=int) 49 | # data 50 | parser.add_argument('--save_dir', type=str, default="./") 51 | parser.add_argument('--data', type=str, default='dmnist') 52 | parser.add_argument("--down_sample", "--ds", default=0, type=int, choices=[0, 1]) 53 | parser.add_argument('--ckpt_path', type=str, default=None) 54 | # models 55 | parser.add_argument('--model', type=str, default='mlp-256') 56 | parser.add_argument('--base_dist', "--bd", type=int, default=1, choices=[0, 1]) 57 | parser.add_argument('--gradnorm', "--gn", type=float, default=0.0) 58 | parser.add_argument('--l2', type=float, default=0.0) 59 | parser.add_argument('--n_iters', "--ni", type=lambda x: int(float(x)), default=5e4) 60 | parser.add_argument('--batch_size', "--bs", type=int, default=100) 61 | parser.add_argument('--test_batch_size', type=int, default=100) 62 | parser.add_argument('--print_every', "--pe", type=int, default=100) 63 | parser.add_argument('--viz_every', "--ve", type=int, default=2000) 64 | parser.add_argument('--eval_every', type=int, default=2000) 65 | parser.add_argument('--lr', type=float, default=.0001) 66 | parser.add_argument("--ebm_every", "--ee", type=int, default=1, help="EBM training frequency") 67 | 68 | # for GFN 69 | parser.add_argument("--type", type=str) 70 | parser.add_argument("--hid", type=int, default=256) 71 | parser.add_argument("--hid_layers", "--hl", type=int, default=5) 72 | parser.add_argument("--leaky", type=int, default=1, choices=[0, 1]) 73 | parser.add_argument("--gfn_bn", "--gbn", type=int, default=0, choices=[0, 1]) 74 | parser.add_argument("--init_zero", "--iz", type=int, default=0, choices=[0, 1]) 75 | 
parser.add_argument("--gmodel", "--gm", type=str, default="mlp") 76 | parser.add_argument("--train_steps", "--ts", type=int, default=1) 77 | parser.add_argument("--l1loss", "--l1l", type=int, default=0, choices=[0, 1], help="use soft l1 loss instead of l2") 78 | 79 | parser.add_argument("--with_mh", "--wm", type=int, default=0, choices=[0, 1]) 80 | parser.add_argument("--rand_k", "--rk", type=int, default=0, choices=[0, 1]) 81 | parser.add_argument("--lin_k", "--lk", type=int, default=0, choices=[0, 1]) 82 | parser.add_argument("--warmup_k", "--wk", type=lambda x: int(float(x)), default=0, help="need to use w/ lin_k") 83 | parser.add_argument("--K", type=int, default=-1, help="for gfn back forth negative sample generation") 84 | 85 | parser.add_argument("--rand_coef", "--rc", type=float, default=0, help="for tb") 86 | parser.add_argument("--back_ratio", "--br", type=float, default=0.) 87 | parser.add_argument("--clip", type=float, default=-1., help="for gfn's linf gradient clipping") 88 | parser.add_argument("--temp", type=float, default=1) 89 | parser.add_argument("--opt", type=str, default="adam", choices=["adam", "sgd"]) 90 | parser.add_argument("--glr", type=float, default=1e-3) 91 | parser.add_argument("--zlr", type=float, default=1e-1) 92 | parser.add_argument("--momentum", "--mom", type=float, default=0.0) 93 | parser.add_argument("--gfn_weight_decay", "--gwd", type=float, default=0.0) 94 | parser.add_argument('--mc_num', "--mcn", type=int, default=5) 95 | args = parser.parse_args() 96 | 97 | os.environ['CUDA_VISIBLE_DEVICES'] = "{:}".format(args.device) 98 | device = torch.device("cpu") if args.device < 0 else torch.device("cuda") 99 | 100 | args.device = device 101 | args.save_dir = os.path.join(args.save_dir, "test") 102 | makedirs(args.save_dir) 103 | 104 | print("Device:" + str(device)) 105 | print("Args:" + str(args)) 106 | 107 | before_load = time.time() 108 | train_loader, val_loader, test_loader, args = utils_data.load_dataset(args) 109 | plot = lambda p, x: torchvision.utils.save_image(x.view(x.size(0), args.input_size[0], 110 | args.input_size[1], args.input_size[2]), p, normalize=True, nrow=int(x.size(0) ** .5)) 111 | print(f"It takes {time.time() - before_load:.3f}s to load {args.data} dataset.") 112 | 113 | def preprocess(data): 114 | if args.dynamic_binarization: 115 | return torch.bernoulli(data) 116 | else: 117 | return data 118 | 119 | if args.down_sample: 120 | assert args.model.startswith("mlp-") 121 | 122 | if args.model.startswith("mlp-"): 123 | nint = int(args.model.split('-')[1]) 124 | net = network.mlp_ebm(np.prod(args.input_size), nint) 125 | elif args.model.startswith("cnn-"): 126 | nint = int(args.model.split('-')[1]) 127 | net = network.MNISTConvNet(nint) 128 | elif args.model.startswith("resnet-"): 129 | nint = int(args.model.split('-')[1]) 130 | net = network.ResNetEBM(nint) 131 | else: 132 | raise ValueError("invalid model definition") 133 | 134 | init_batch = [] 135 | for x, _ in train_loader: 136 | init_batch.append(preprocess(x)) 137 | init_batch = torch.cat(init_batch, 0) 138 | eps = 1e-2 139 | init_mean = init_batch.mean(0) * (1. 
- 2 * eps) + eps 140 | 141 | if args.base_dist: 142 | model = EBM(net, init_mean) 143 | else: 144 | model = EBM(net) 145 | 146 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 147 | 148 | xdim = np.prod(args.input_size) 149 | assert args.gmodel == "mlp" 150 | gfn = get_GFlowNet(args.type, xdim, args, device) 151 | model.to(device) 152 | print("model: {:}".format(model)) 153 | 154 | itr = 0 155 | while itr < args.n_iters: 156 | for x in train_loader: 157 | st = time.time() 158 | x = preprocess(x[0].to(device)) # -> (bs, 784) 159 | 160 | if args.gradnorm > 0: 161 | x.requires_grad_() 162 | 163 | update_success_rate = -1. 164 | assert "tb" in args.type 165 | train_loss, train_logZ = gfn.train(args.batch_size, scorer=lambda inp: model(inp).detach(), 166 | silent=itr % args.print_every != 0, data=x, back_ratio=args.back_ratio) 167 | 168 | if args.rand_k or args.lin_k or (args.K > 0): 169 | if args.rand_k: 170 | K = random.randrange(xdim) + 1 171 | elif args.lin_k: 172 | K = min(xdim, int(xdim * float(itr + 1) / args.warmup_k)) 173 | K = max(K, 1) 174 | elif args.K > 0: 175 | K = args.K 176 | else: 177 | raise ValueError 178 | 179 | gfn.model.eval() 180 | x_fake, delta_logp_traj = gfn.backforth_sample(x, K) 181 | 182 | delta_logp_traj = delta_logp_traj.detach() 183 | if args.with_mh: 184 | # MH step, calculate log p(x') - log p(x) 185 | lp_update = model(x_fake).squeeze() - model(x).squeeze() 186 | update_dist = torch.distributions.Bernoulli(logits=lp_update + delta_logp_traj) 187 | updates = update_dist.sample() 188 | x_fake = x_fake * updates[:, None] + x * (1. - updates[:, None]) 189 | update_success_rate = updates.mean().item() 190 | 191 | else: 192 | x_fake = gfn.sample(args.batch_size) 193 | 194 | if itr % args.ebm_every == 0: 195 | st = time.time() - st 196 | 197 | model.train() 198 | logp_real = model(x).squeeze() 199 | if args.gradnorm > 0: 200 | grad_ld = torch.autograd.grad(logp_real.sum(), x, 201 | create_graph=True)[0].flatten(start_dim=1).norm(2, 1) 202 | grad_reg = (grad_ld ** 2. 
/ 2.).mean() 203 | else: 204 | grad_reg = torch.tensor(0.).to(device) 205 | 206 | logp_fake = model(x_fake).squeeze() 207 | obj = logp_real.mean() - logp_fake.mean() 208 | l2_reg = (logp_real ** 2.).mean() + (logp_fake ** 2.).mean() 209 | loss = -obj + grad_reg * args.gradnorm + args.l2 * l2_reg 210 | 211 | optimizer.zero_grad() 212 | loss.backward() 213 | optimizer.step() 214 | 215 | if itr % args.print_every == 0 or itr == args.n_iters - 1: 216 | print("({:5d}) | ({:.3f}s/iter) |log p(real)={:.2e}, " 217 | "log p(fake)={:.2e}, diff={:.2e}, grad_reg={:.2e}, l2_reg={:.2e} update_rate={:.1f}".format(itr, st, 218 | logp_real.mean().item(), logp_fake.mean().item(), obj.item(), grad_reg.item(), l2_reg.item(), update_success_rate)) 219 | 220 | if (itr + 1) % args.eval_every == 0: 221 | model.eval() 222 | print("GFN TEST") 223 | gfn.model.eval() 224 | gfn_test_ll = gfn.evaluate(test_loader, preprocess, args.mc_num) 225 | print("GFN Test log-likelihood ({}) with {} samples: {}".format(itr, args.mc_num, gfn_test_ll.item())) 226 | 227 | model.cpu() 228 | d = {} 229 | d['model'] = model.state_dict() 230 | d['optimizer'] = optimizer.state_dict() 231 | gfn_ckpt = {"model": gfn.model.state_dict(), "optimizer": gfn.optimizer.state_dict(),} 232 | gfn_ckpt["logZ"] = gfn.logZ.detach().cpu() 233 | torch.save(d, "{}/ckpt.pt".format(args.save_dir)) 234 | torch.save(gfn_ckpt, "{}/gfn_ckpt.pt".format(args.save_dir)) 235 | 236 | model.to(device) 237 | 238 | itr += 1 239 | if itr > args.n_iters: 240 | print("Training finished!") 241 | quit(0) 242 | -------------------------------------------------------------------------------- /gflownet.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import time 3 | import copy 4 | import random 5 | import ipdb 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.distributions as dists 13 | import torchvision 14 | 15 | from network import make_mlp 16 | 17 | 18 | def get_GFlowNet(type, xdim, args, device, net=None): 19 | if type == "tbrf": 20 | return GFlowNet_Randf_TB(xdim=xdim, args=args, device=device, net=net) 21 | elif type == "tblb": 22 | return GFlowNet_LearnedPb_TB(xdim=xdim, args=args, device=device, net=net) 23 | else: 24 | raise NotImplementedError 25 | 26 | 27 | class GFlowNet_Randf_TB: 28 | # binary data, train w/ long DB loss 29 | def __init__(self, xdim, args, device, net=None): 30 | self.xdim = xdim 31 | self._hops = 0. 
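        # Note on the policy network: the MLP below maps a state of size xdim to 3 * xdim logits.
        # In this class the first 2 * xdim entries are read as per-position value logits when a
        # coordinate is appended; the trailing xdim deletion logits are unused here and are only
        # consumed by the learned-backward variant, GFlowNet_LearnedPb_TB.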
32 | # (bs, data_dim) -> (bs, data_dim) 33 | if net is None: 34 | self.model = make_mlp([xdim] + [args.hid] * args.hid_layers + 35 | [3 * xdim], act=(nn.LeakyReLU() if args.leaky else nn.ReLU()), with_bn=args.gfn_bn) 36 | else: 37 | self.model = net 38 | self.model.to(device) 39 | 40 | self.logZ = nn.Parameter(torch.tensor(0.)) 41 | self.logZ.to(device) 42 | 43 | self.device = device 44 | 45 | self.exp_temp = args.temp 46 | self.rand_coef = args.rand_coef # involving exploration 47 | self.init_zero = args.init_zero 48 | self.clip = args.clip 49 | self.l1loss = args.l1loss 50 | 51 | self.replay = None 52 | self.tau = args.tau if hasattr(args, "tau") else -1 53 | 54 | self.train_steps = args.train_steps 55 | 56 | param_list = [{'params': self.model.parameters(), 'lr': args.glr}, 57 | {'params': self.logZ, 'lr': args.zlr}] 58 | if args.opt == "adam": 59 | self.optimizer = torch.optim.Adam(param_list, weight_decay=args.gfn_weight_decay) 60 | elif args.opt == "sgd": 61 | self.optimizer = torch.optim.SGD(param_list, momentum=args.momentum, weight_decay=args.gfn_weight_decay) 62 | 63 | def backforth_sample(self, x, K, rand_coef=0.): 64 | assert K > 0 65 | batch_size = x.size(0) 66 | 67 | # "backward" 68 | logp_xprime2x = torch.zeros(batch_size).to(self.device) 69 | for step in range(K + 1): 70 | del_val_logits = self.model(x)[:, :2 * self.xdim] 71 | 72 | if step > 0: 73 | del_val_logits = del_val_logits.reshape(-1, self.xdim, 2) 74 | log_del_val_prob = del_val_logits.gather(1, del_locs.unsqueeze(2).repeat(1, 1, 2)).squeeze().log_softmax(1) 75 | logp_xprime2x = logp_xprime2x + log_del_val_prob.gather(1, deleted_val).squeeze(1) 76 | 77 | if step < K: 78 | if self.init_zero: 79 | # mask = (x == 0).float() 80 | mask = (x.abs() < 1e-8).float() 81 | else: 82 | mask = (x < -0.5).float() 83 | del_locs = (0 - 1e9 * mask).softmax(1).multinomial(1) # row sum not need to be 1 84 | deleted_val = x.gather(1, del_locs).long() 85 | del_values = torch.ones(batch_size, 1).to(self.device) * (0 if self.init_zero else -1) 86 | x = x.scatter(1, del_locs, del_values) 87 | 88 | # forward 89 | logp_x2xprime = torch.zeros(batch_size).to(self.device) 90 | for step in range(K): 91 | logits = self.model(x) 92 | add_logits = logits[:, :2 * self.xdim] 93 | 94 | # those have been edited 95 | if self.init_zero: 96 | mask = (x != 0).float() 97 | else: 98 | mask = (x > -0.5).float() 99 | add_prob = (1 - mask) / (1e-9 + (1 - mask).sum(1)).unsqueeze(1) 100 | add_locs = add_prob.multinomial(1) 101 | add_val_logits = add_logits.reshape(-1, self.xdim, 2) 102 | add_val_prob = add_val_logits.gather(1, add_locs.unsqueeze(2).repeat(1, 1, 2)).squeeze().softmax(1) 103 | add_values = add_val_prob.multinomial(1) 104 | if rand_coef > 0: 105 | updates = torch.bernoulli(rand_coef * torch.ones(x.shape[0])).int().to(x.device) 106 | add_values = (1 - add_values) * updates[:, None] + add_values * (1 - updates[:, None]) 107 | 108 | logp_x2xprime = logp_x2xprime + add_val_prob.log().gather(1, add_values).squeeze(1) # (bs, 1) -> (bs,) 109 | 110 | if self.init_zero: 111 | add_values = 2 * add_values - 1 112 | 113 | x = x.scatter(1, add_locs, add_values.float()) 114 | 115 | return x, logp_xprime2x - logp_x2xprime # leave MH step to out loop code 116 | 117 | def sample(self, batch_size): 118 | self.model.eval() 119 | if self.init_zero: 120 | x = torch.zeros((batch_size, self.xdim)).to(self.device) 121 | else: 122 | x = -1 * torch.ones((batch_size, self.xdim)).to(self.device) 123 | 124 | for step in range(self.xdim + 1): 125 | if step < self.xdim: 126 | 
logits = self.model(x) 127 | add_logits, _ = logits[:, :2 * self.xdim], logits[:, 2 * self.xdim:] 128 | 129 | if self.init_zero: 130 | mask = (x != 0).float() 131 | else: 132 | mask = (x > -0.5).float() 133 | add_prob = (1 - mask) / (1e-9 + (1 - mask).sum(1)).unsqueeze(1) 134 | add_locs = add_prob.multinomial(1) # row sum not need to be 1 135 | 136 | add_val_logits = add_logits.reshape(-1, self.xdim, 2) 137 | add_val_prob = add_val_logits.gather(1, add_locs.unsqueeze(2).repeat(1, 1, 2)).squeeze().softmax(1) 138 | add_values = add_val_prob.multinomial(1) 139 | 140 | if self.init_zero: 141 | add_values = 2 * add_values - 1 142 | 143 | x = x.scatter(1, add_locs, add_values.float()) 144 | return x 145 | 146 | def cal_logp(self, data, num: int): 147 | logp_ls = [] 148 | for _ in range(num): 149 | _, _, _, mle_loss, = tb_mle_randf_loss(lambda inp: torch.tensor(0.).to(self.device), 150 | self, data.shape[0], back_ratio=1, data=data) 151 | logpj = - mle_loss.detach().cpu() - torch.tensor(num).log() 152 | logp_ls.append(logpj.reshape(logpj.shape[0], -1)) 153 | 154 | batch_logp = torch.logsumexp(torch.cat(logp_ls, dim=1), dim=1) # (bs,) 155 | return batch_logp.mean() 156 | 157 | def evaluate(self, loader, preprocess, num, use_tqdm=False): 158 | logps = [] 159 | if use_tqdm: 160 | pbar = tqdm(loader) 161 | else: 162 | pbar = loader 163 | 164 | if hasattr(pbar, "set_description"): 165 | pbar.set_description("Calculating likelihood") 166 | self.model.eval() 167 | for x, _ in pbar: 168 | x = preprocess(x.to(self.device)) 169 | logp = self.cal_logp(x, num) 170 | logps.append(logp.reshape(-1)) 171 | if hasattr(pbar, "set_postfix"): 172 | pbar.set_postfix({"logp": f"{torch.cat(logps).mean().item():.2f}"}) 173 | 174 | return torch.cat(logps).mean() 175 | 176 | def train(self, batch_size, scorer, silent=False, data=None, back_ratio=0.,): #mle_coef=0., kl_coef=0., kl2_coef=0., pdb=False): 177 | # scorer: x -> logp 178 | if silent: 179 | pbar = range(self.train_steps) 180 | else: 181 | pbar = tqdm(range(self.train_steps)) 182 | curr_lr = self.optimizer.param_groups[0]['lr'] 183 | pbar.set_description(f"Lr={curr_lr:.1e}") 184 | 185 | train_loss = [] 186 | train_mle_loss = [] 187 | train_logZ = [] 188 | # train_kl_loss = [] 189 | self.model.train() 190 | self.model.zero_grad() 191 | torch.cuda.empty_cache() 192 | 193 | for _ in pbar: 194 | gfn_loss, forth_loss, back_loss, mle_loss = \ 195 | tb_mle_randf_loss(scorer, self, batch_size, back_ratio=back_ratio, data=data) 196 | gfn_loss, forth_loss, back_loss, mle_loss = \ 197 | gfn_loss.mean(), forth_loss.mean(), back_loss.mean(), mle_loss.mean() 198 | 199 | loss = gfn_loss 200 | 201 | self.optimizer.zero_grad() 202 | loss.backward() 203 | if self.clip > 0: 204 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip, norm_type="inf") 205 | self.optimizer.step() 206 | 207 | train_loss.append(gfn_loss.item()) 208 | train_mle_loss.append(mle_loss.item()) 209 | train_logZ.append(self.logZ.item()) 210 | 211 | if not silent: 212 | pbar.set_postfix({"MLE": "{:.2e}".format(mle_loss.item()), 213 | "GFN": "{:.2e}".format(gfn_loss.item()), 214 | "Forth": "{:.2e}".format(forth_loss.item()), 215 | "Back": "{:.2e}".format(back_loss.item()), 216 | "LogZ": "{:.2e}".format(self.logZ.item()), 217 | }) 218 | 219 | return np.mean(train_loss), np.mean(train_logZ) 220 | 221 | 222 | def tb_mle_randf_loss(ebm_model, gfn, batch_size, back_ratio=0., data=None): 223 | if back_ratio < 1.: 224 | if gfn.init_zero: 225 | x = torch.zeros((batch_size, gfn.xdim)).to(gfn.device) 
226 | else: 227 | x = -1 * torch.ones((batch_size, gfn.xdim)).to(gfn.device) 228 | 229 | log_pf = 0. 230 | for step in range(gfn.xdim + 1): 231 | logits = gfn.model(x) 232 | add_logits, _ = logits[:, :2 * gfn.xdim], logits[:, 2 * gfn.xdim:] 233 | 234 | if step < gfn.xdim: 235 | # mask those that have been edited 236 | if gfn.init_zero: 237 | mask = (x != 0).float() 238 | else: 239 | mask = (x > -0.5).float() 240 | add_prob = (1 - mask) / (1e-9 + (1 - mask).sum(1)).unsqueeze(1) 241 | add_locs = add_prob.multinomial(1) 242 | 243 | add_val_logits = add_logits.reshape(-1, gfn.xdim, 2) 244 | add_val_prob = add_val_logits.gather(1, add_locs.unsqueeze(2).repeat(1, 1, 2)).squeeze().softmax(1) 245 | add_values = add_val_prob.multinomial(1) 246 | 247 | if gfn.rand_coef > 0: 248 | # updates = torch.distributions.Bernoulli(probs=gfn.rand_coef).sample(sample_shape=torch.Size([x.shape[0]])) 249 | updates = torch.bernoulli(gfn.rand_coef * torch.ones(x.shape[0])).int().to(x.device) 250 | add_values = (1 - add_values) * updates[:, None] + add_values * (1 - updates[:, None]) 251 | 252 | log_pf = log_pf + add_val_prob.log().gather(1, add_values).squeeze(1) # (bs, 1) -> (bs,) 253 | 254 | if gfn.init_zero: 255 | add_values = 2 * add_values - 1 256 | 257 | x = x.scatter(1, add_locs, add_values.float()) 258 | 259 | assert torch.all(x != 0) if gfn.init_zero else torch.all(x >= 0) 260 | 261 | score_value = ebm_model(x) 262 | if gfn.l1loss: 263 | forth_loss = F.smooth_l1_loss(gfn.logZ + log_pf - score_value, torch.zeros_like(score_value)) 264 | else: 265 | forth_loss = (gfn.logZ + log_pf - score_value) ** 2 266 | else: 267 | forth_loss = torch.tensor(0.).to(gfn.device) 268 | 269 | # traj is from given data back to s0, sample w/ unif back prob 270 | mle_loss = torch.tensor(0.).to(gfn.device) 271 | if back_ratio <= 0.: 272 | back_loss = torch.tensor(0.).to(gfn.device) 273 | else: 274 | assert data is not None 275 | x = data 276 | batch_size = x.size(0) 277 | back_loss = torch.zeros(batch_size).to(gfn.device) 278 | 279 | for step in range(gfn.xdim + 1): 280 | logits = gfn.model(x) 281 | del_val_logits, _ = logits[:, :2 * gfn.xdim], logits[:, 2 * gfn.xdim:] 282 | 283 | if step > 0: 284 | del_val_logits = del_val_logits.reshape(-1, gfn.xdim, 2) 285 | log_del_val_prob = del_val_logits.gather(1, del_locs.unsqueeze(2).repeat(1, 1, 2)).squeeze().log_softmax(1) 286 | mle_loss = mle_loss + log_del_val_prob.gather(1, deleted_val).squeeze(1) 287 | 288 | if step < gfn.xdim: 289 | if gfn.init_zero: 290 | mask = (x.abs() < 1e-8).float() 291 | else: 292 | mask = (x < -0.5).float() 293 | del_locs = (0 - 1e9 * mask).softmax(1).multinomial(1) # row sum not need to be 1 294 | deleted_val = x.gather(1, del_locs).long() 295 | del_values = torch.ones(batch_size, 1).to(gfn.device) * (0 if gfn.init_zero else -1) 296 | x = x.scatter(1, del_locs, del_values) 297 | 298 | # if back_ratio > 0.: 299 | if gfn.l1loss: 300 | back_loss = F.smooth_l1_loss(gfn.logZ + mle_loss - ebm_model(data).detach(), torch.zeros_like(mle_loss)) 301 | else: 302 | back_loss = (gfn.logZ + mle_loss - ebm_model(data).detach()) ** 2 303 | 304 | gfn_loss = (1 - back_ratio) * forth_loss + back_ratio * back_loss 305 | 306 | return gfn_loss, forth_loss, back_loss, mle_loss 307 | 308 | 309 | class GFlowNet_LearnedPb_TB: 310 | def __init__(self, xdim, args, device, net=None): 311 | self.xdim = xdim 312 | self._hops = 0. 
313 | # (bs, data_dim) -> (bs, data_dim) 314 | if net is None: 315 | self.model = make_mlp([xdim] + [args.hid] * args.hid_layers + 316 | [3 * xdim], act=(nn.LeakyReLU() if args.leaky else nn.ReLU()), with_bn=args.gfn_bn) 317 | else: 318 | self.model = net 319 | self.model.to(device) 320 | 321 | self.logZ = nn.Parameter(torch.tensor(0.)) 322 | self.logZ.to(device) 323 | self.device = device 324 | 325 | self.exp_temp = args.temp 326 | self.rand_coef = args.rand_coef # involving exploration 327 | self.init_zero = args.init_zero 328 | self.clip = args.clip 329 | self.l1loss = args.l1loss 330 | 331 | self.replay = None 332 | self.tau = args.tau if hasattr(args, "tau") else -1 333 | 334 | self.train_steps = args.train_steps 335 | param_list = [{'params': self.model.parameters(), 'lr': args.glr}, 336 | {'params': self.logZ, 'lr': args.zlr}] 337 | if args.opt == "adam": 338 | self.optimizer = torch.optim.Adam(param_list) 339 | elif args.opt == "sgd": 340 | self.optimizer = torch.optim.SGD(param_list, momentum=args.momentum) 341 | 342 | def backforth_sample(self, x, K): 343 | assert K > 0 344 | batch_size = x.size(0) 345 | 346 | logp_xprime2x = torch.zeros(batch_size).to(self.device) 347 | logp_x2xprime = torch.zeros(batch_size).to(self.device) 348 | 349 | # "backward" 350 | for step in range(K + 1): 351 | logits = self.model(x) 352 | add_logits, del_logits = logits[:, :2 * self.xdim], logits[:, 2 * self.xdim:] 353 | 354 | if step > 0: 355 | if self.init_zero: 356 | mask = (x != 0).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 357 | else: 358 | mask = (x > -0.5).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 359 | add_sample = del_locs * 2 + (deleted_values == 1).long() # whether it's init_zero, this holds true 360 | logp_xprime2x = logp_xprime2x + (add_logits - 1e9 * mask).float().log_softmax(1).gather(1,add_sample).squeeze(1) 361 | 362 | if step < K: 363 | if self.init_zero: 364 | mask = (x.abs() < 1e-8).float() 365 | else: 366 | mask = (x < -0.5).float() 367 | del_logits = (del_logits - 1e9 * mask).float() 368 | del_locs = del_logits.softmax(1).multinomial(1) # row sum not need to be 1 369 | del_values = torch.ones(batch_size, 1).to(self.device) * (0 if self.init_zero else -1) 370 | deleted_values = x.gather(1, del_locs) 371 | logp_x2xprime = logp_x2xprime + del_logits.float().log_softmax(1).gather(1, del_locs).squeeze(1) 372 | x = x.scatter(1, del_locs, del_values) 373 | 374 | # forward 375 | for step in range(K + 1): 376 | logits = self.model(x) 377 | add_logits, del_logits = logits[:, :2 * self.xdim], logits[:, 2 * self.xdim:] 378 | 379 | if step > 0: 380 | if self.init_zero: 381 | mask = (x.abs() < 1e-8).float() 382 | else: 383 | mask = (x < 0).float() 384 | logp_xprime2x = logp_xprime2x + (del_logits - 1e9 * mask).log_softmax(1).gather(1, add_locs).squeeze(1) 385 | 386 | if step < K: 387 | # those have been edited 388 | if self.init_zero: 389 | mask = (x != 0).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 390 | else: 391 | mask = (x > -0.5).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 392 | add_logits = (add_logits - 1e9 * mask).float() 393 | add_prob = add_logits.softmax(1) 394 | 395 | # haven't used rand coef here 396 | add_sample = add_prob.multinomial(1) # row sum not need to be 1 397 | if self.init_zero: 398 | add_locs, add_values = add_sample // 2, 2 * (add_sample % 2) - 1 399 | else: 400 | add_locs, add_values = add_sample // 2, add_sample % 2 401 | 402 | logp_x2xprime = 
logp_x2xprime + add_logits.log_softmax(1).gather(1, add_sample).squeeze(1) 403 | x = x.scatter(1, add_locs, add_values.float()) 404 | 405 | return x, logp_xprime2x - logp_x2xprime # leave MH step to out loop code 406 | 407 | def sample(self, batch_size): 408 | self.model.eval() 409 | if self.init_zero: 410 | x = torch.zeros((batch_size, self.xdim)).to(self.device) 411 | else: 412 | x = -1 * torch.ones((batch_size, self.xdim)).to(self.device) 413 | 414 | for step in range(self.xdim + 1): 415 | logits = self.model(x) 416 | add_logits, del_logits = logits[:, :2 * self.xdim], logits[:, 2 * self.xdim:] 417 | 418 | # those have been edited 419 | if self.init_zero: 420 | mask = (x != 0).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 421 | else: 422 | mask = (x > -0.5).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * self.xdim).float() 423 | add_prob = (add_logits - 1e9 * mask).float().softmax(1) 424 | 425 | if step < self.xdim: 426 | # add_prob = add_prob ** (1 / self.exp_temp) 427 | add_sample = add_prob.multinomial(1) # row sum not need to be 1 428 | if self.init_zero: 429 | add_locs, add_values = add_sample // 2, 2 * (add_sample % 2) - 1 430 | else: 431 | add_locs, add_values = add_sample // 2, add_sample % 2 432 | 433 | x = x.scatter(1, add_locs, add_values.float()) 434 | return x 435 | 436 | def cal_logp(self, data, num: int): 437 | logp_ls = [] 438 | for _ in range(num): 439 | _, _, _, mle_loss, data_log_pb = tb_mle_learnedpb_loss(lambda inp: torch.tensor(0.).to(self.device), self, data.shape[0], back_ratio=1, data=data) 440 | logpj = - mle_loss.detach().cpu() - data_log_pb.detach().cpu() 441 | logp_ls.append(logpj.reshape(logpj.shape[0], -1)) 442 | batch_logp = torch.logsumexp(torch.cat(logp_ls, dim=1), dim=1) # (bs,) 443 | 444 | return batch_logp.mean() - torch.tensor(num).log() 445 | 446 | def evaluate(self, loader, preprocess, num, use_tqdm=False): 447 | logps = [] 448 | if use_tqdm: 449 | pbar = tqdm(loader) 450 | else: 451 | pbar = loader 452 | 453 | if hasattr(pbar, "set_description"): 454 | pbar.set_description("Calculating likelihood") 455 | self.model.eval() 456 | for x, _ in pbar: 457 | x = preprocess(x.to(self.device)) 458 | logp = self.cal_logp(x, num) 459 | logps.append(logp.reshape(-1)) 460 | if hasattr(pbar, "set_postfix"): 461 | pbar.set_postfix({"logp": f"{torch.cat(logps).mean().item():.2f}"}) 462 | 463 | return torch.cat(logps).mean() 464 | 465 | def train(self, batch_size, scorer, silent=False, 466 | data=None, back_ratio=0.): 467 | if silent: 468 | pbar = range(self.train_steps) 469 | else: 470 | pbar = tqdm(range(self.train_steps)) 471 | curr_lr = self.optimizer.param_groups[0]['lr'] 472 | pbar.set_description(f"Alg: GFN LongDB Training, Lr={curr_lr:.1e}") 473 | 474 | train_loss = [] 475 | train_mle_loss = [] 476 | train_logZ = [] 477 | self.model.train() 478 | self.model.zero_grad() 479 | torch.cuda.empty_cache() 480 | 481 | for _ in pbar: 482 | gfn_loss, forth_loss, back_loss, mle_loss, data_log_pb = \ 483 | tb_mle_learnedpb_loss(scorer, self, batch_size, back_ratio=back_ratio, data=data) 484 | gfn_loss, forth_loss, back_loss, mle_loss, data_log_pb = \ 485 | gfn_loss.mean(), forth_loss.mean(), back_loss.mean(), mle_loss.mean(), data_log_pb.mean() 486 | 487 | loss = gfn_loss 488 | 489 | self.optimizer.zero_grad() 490 | loss.backward() 491 | if self.clip > 0: 492 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip, norm_type="inf") 493 | self.optimizer.step() 494 | 495 | train_loss.append(gfn_loss.item()) 496 | 
train_mle_loss.append(mle_loss.item()) 497 | train_logZ.append(self.logZ.item()) 498 | 499 | if not silent: 500 | pbar.set_postfix({"MLE": "{:.2e}".format(mle_loss.item()), 501 | "GFN": "{:.2e}".format(gfn_loss.item()), 502 | "Forth": "{:.2e}".format(forth_loss.item()), 503 | "Back": "{:.2e}".format(back_loss.item()), 504 | "LogZ": "{:.2e}".format(self.logZ.item()), 505 | }) 506 | 507 | return np.mean(train_loss), np.mean(train_logZ) 508 | 509 | 510 | def tb_mle_learnedpb_loss(ebm_model, gfn, batch_size, back_ratio=0., data=None): 511 | # traj is from s0 -> sf, sample by current gfn policy 512 | if back_ratio < 1.: 513 | if gfn.init_zero: 514 | x = torch.zeros((batch_size, gfn.xdim)).to(gfn.device) 515 | else: 516 | # -1 denotes "have not been edited" 517 | x = -1 * torch.ones((batch_size, gfn.xdim)).to(gfn.device) 518 | 519 | # forth_loss = 0. 520 | log_pb = 0. 521 | log_pf = 0. 522 | for step in range(gfn.xdim + 1): 523 | logits = gfn.model(x) 524 | add_logits, del_logits = logits[:, :2 * gfn.xdim], logits[:, 2 * gfn.xdim:] 525 | 526 | if step > 0: 527 | if gfn.init_zero: 528 | mask = (x.abs() < 1e-8).float() 529 | else: 530 | mask = (x < 0).float() 531 | log_pb = log_pb + (del_logits - 1e9 * mask).log_softmax(1).gather(1, add_locs).squeeze(1) 532 | # log_pb = log_pb + torch.tensor(1 / step).log().to(gfn.device) 533 | 534 | if step < gfn.xdim: 535 | # mask those that have been edited 536 | if gfn.init_zero: 537 | mask = (x != 0).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * gfn.xdim).float() 538 | else: 539 | mask = (x > -0.5).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * gfn.xdim).float() 540 | 541 | add_logits = (add_logits - 1e9 * mask).float() 542 | add_prob = add_logits.softmax(1) 543 | 544 | add_prob = add_prob ** (1 / gfn.exp_temp) 545 | add_prob = add_prob / (1e-9 + add_prob.sum(1, keepdim=True)) 546 | add_prob = (1 - gfn.rand_coef) * add_prob + \ 547 | gfn.rand_coef * (1 - mask) / (1e-9 + (1 - mask).sum(1)).unsqueeze(1) 548 | 549 | add_sample = add_prob.multinomial(1) 550 | if gfn.init_zero: 551 | add_locs, add_values = add_sample // 2, 2 * (add_sample % 2) - 1 552 | else: 553 | add_locs, add_values = add_sample // 2, add_sample % 2 554 | # P_F 555 | log_pf = log_pf + add_logits.log_softmax(1).gather(1, add_sample).squeeze(1) 556 | # update x 557 | x = x.scatter(1, add_locs, add_values.float()) 558 | 559 | assert torch.all(x != 0) if gfn.init_zero else torch.all(x >= 0) 560 | 561 | score_value = ebm_model(x) 562 | if gfn.l1loss: 563 | forth_loss = F.smooth_l1_loss(gfn.logZ + log_pf - log_pb - score_value, torch.zeros_like(score_value)) 564 | else: 565 | forth_loss = ((gfn.logZ + log_pf - log_pb - score_value) ** 2) 566 | else: 567 | forth_loss = torch.tensor(0.).to(gfn.device) 568 | 569 | mle_loss = torch.tensor(0.).to(gfn.device) # log_pf 570 | if back_ratio <= 0.: 571 | data_log_pb = torch.tensor(0.).to(gfn.device) 572 | back_loss = torch.tensor(0.).to(gfn.device) 573 | else: 574 | assert data is not None 575 | x = data 576 | batch_size = x.size(0) 577 | data_log_pb = torch.zeros(batch_size).to(gfn.device) 578 | 579 | for step in range(gfn.xdim + 1): 580 | logits = gfn.model(x) 581 | add_logits, del_logits = logits[:, :2 * gfn.xdim], logits[:, 2 * gfn.xdim:] 582 | 583 | if step > 0: 584 | if gfn.init_zero: 585 | mask = (x != 0).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * gfn.xdim).float() 586 | else: 587 | mask = (x > -0.5).unsqueeze(2).repeat(1, 1, 2).reshape(batch_size, 2 * gfn.xdim).float() 588 | 589 | add_sample = del_locs * 2 + (deleted_values == 
1).long() # whether it's init_zero, this holds true 590 | add_logits = (add_logits - 1e9 * mask).float() 591 | mle_loss = mle_loss + add_logits.log_softmax(1).gather(1, add_sample).squeeze(1) 592 | 593 | if step < gfn.xdim: 594 | if gfn.init_zero: 595 | # mask = (x == 0).float() 596 | mask = (x.abs() < 1e-8).float() 597 | else: 598 | mask = (x < -0.5).float() 599 | del_logits = (del_logits - 1e9 * mask).float() 600 | del_prob = del_logits.softmax(1) 601 | del_prob = (1 - gfn.rand_coef) * del_prob + gfn.rand_coef * (1 - mask) / (1e-9 + (1 - mask).sum(1)).unsqueeze(1) 602 | del_locs = del_prob.multinomial(1) # row sum not need to be 1 603 | deleted_values = x.gather(1, del_locs) 604 | data_log_pb = data_log_pb + del_logits.log_softmax(1).gather(1, del_locs).squeeze(1) 605 | 606 | del_values = torch.ones(batch_size, 1).to(gfn.device) * (0 if gfn.init_zero else -1) 607 | x = x.scatter(1, del_locs, del_values) 608 | 609 | if gfn.l1loss: 610 | back_loss = F.smooth_l1_loss(gfn.logZ + mle_loss - data_log_pb - ebm_model(data).detach(), torch.zeros_like(mle_loss)) 611 | else: 612 | back_loss = ((gfn.logZ + mle_loss - data_log_pb - ebm_model(data).detach()) ** 2) 613 | 614 | gfn_loss = (1 - back_ratio) * forth_loss + back_ratio * back_loss 615 | mle_loss = - mle_loss 616 | 617 | return gfn_loss, forth_loss, back_loss, mle_loss, data_log_pb 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Swish(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x): 13 | return x * torch.sigmoid(x) 14 | 15 | 16 | def make_mlp(l, act=nn.LeakyReLU(), tail=[], with_bn=False): 17 | """makes an MLP with no top layer activation""" 18 | net = nn.Sequential(*(sum( 19 | [[nn.Linear(i, o)] + (([nn.BatchNorm1d(o), act] if with_bn else [act]) if n < len(l) - 2 else []) 20 | for n, (i, o) in enumerate(zip(l, l[1:]))], [] 21 | ) + tail)) 22 | return net 23 | 24 | 25 | def mlp_ebm(nin, nint=256, nout=1): 26 | return nn.Sequential( 27 | nn.Linear(nin, nint), 28 | Swish(), 29 | nn.Linear(nint, nint), 30 | Swish(), 31 | nn.Linear(nint, nint), 32 | Swish(), 33 | nn.Linear(nint, nout), 34 | ) -------------------------------------------------------------------------------- /synthetic/synthetic_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | import sklearn.datasets 4 | 5 | 6 | class ToyDataset(object): 7 | def __init__(self, dim, data_file=None, static_data=None): 8 | if data_file is not None: 9 | self.static_data = np.load(data_file) 10 | elif static_data is not None: 11 | self.static_data = static_data 12 | else: 13 | self.static_data = None 14 | self.dim = dim 15 | 16 | def gen_batch(self, batch_size): 17 | raise NotImplementedError 18 | 19 | def data_gen(self, batch_size, auto_reset): 20 | if self.static_data is not None: 21 | num_obs = self.static_data.shape[0] 22 | while True: 23 | for pos in range(0, num_obs, batch_size): 24 | if pos + batch_size > num_obs: # the last mini-batch has fewer samples 25 | if auto_reset: # no need to use this last mini-batch 26 | break 27 | else: 28 | num_samples = num_obs - pos 29 | else: 30 | num_samples = batch_size 31 | yield self.static_data[pos : pos + num_samples, :] 32 | 
if not auto_reset: 33 | break 34 | np.random.shuffle(self.static_data) 35 | else: 36 | while True: 37 | yield self.gen_batch(batch_size) 38 | 39 | 40 | class OnlineToyDataset(ToyDataset): 41 | def __init__(self, data_name, discrete_dim=16): 42 | super(OnlineToyDataset, self).__init__(2) 43 | assert discrete_dim % 2 == 0 44 | self.data_name = data_name 45 | self.rng = np.random.RandomState() 46 | 47 | rng = np.random.RandomState(1) 48 | samples = inf_train_gen(self.data_name, rng, 5000) 49 | self.f_scale = np.max(np.abs(samples)) + 1 # for normalization 50 | self.int_scale = 2 ** (discrete_dim / 2 - 1) / (self.f_scale + 1) 51 | print('f_scale,', self.f_scale, 'int_scale,', self.int_scale) 52 | 53 | def gen_batch(self, batch_size): 54 | return inf_train_gen(self.data_name, self.rng, batch_size) 55 | 56 | def gen_batch_with_seed(self, batch_size, seed): 57 | rng = np.random.RandomState(seed) 58 | return inf_train_gen(self.data_name, rng, batch_size) 59 | 60 | 61 | # Dataset iterator 62 | def inf_train_gen(data, rng=None, batch_size=200): 63 | if rng is None: 64 | rng = np.random.RandomState() 65 | 66 | if data == "swissroll": 67 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0, random_state=rng)[0] 68 | data = data.astype("float32")[:, [0, 2]] 69 | data /= 5 70 | return data 71 | 72 | elif data == "circles": 73 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08, random_state=rng)[0] 74 | data = data.astype("float32") 75 | data *= 3 76 | return data 77 | 78 | elif data == "moons": 79 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1, random_state=rng)[0] 80 | data = data.astype("float32") 81 | data = data * 2 + np.array([-1, -0.2]) 82 | return data 83 | 84 | elif data == "8gaussians": 85 | scale = 4. 86 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 87 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 88 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 89 | centers = [(scale * x, scale * y) for x, y in centers] 90 | 91 | dataset = [] 92 | for i in range(batch_size): 93 | point = rng.randn(2) * 0.5 94 | idx = rng.randint(8) 95 | center = centers[idx] 96 | point[0] += center[0] 97 | point[1] += center[1] 98 | dataset.append(point) 99 | dataset = np.array(dataset, dtype="float32") 100 | dataset /= 1.414 101 | return dataset 102 | 103 | elif data == "pinwheel": 104 | radial_std = 0.3 105 | tangential_std = 0.1 106 | num_classes = 5 107 | num_per_class = batch_size // 5 108 | rate = 0.25 109 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 110 | 111 | features = rng.randn(num_classes*num_per_class, 2) \ 112 | * np.array([radial_std, tangential_std]) 113 | features[:, 0] += 1. 
114 | labels = np.repeat(np.arange(num_classes), num_per_class) 115 | 116 | angles = rads[labels] + rate * np.exp(features[:, 0]) 117 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 118 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 119 | 120 | return 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations)) 121 | 122 | elif data == "2spirals": 123 | # n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 124 | # d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 125 | # d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 126 | # x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 127 | # x += np.random.randn(*x.shape) * 0.1 128 | 129 | n = np.sqrt(rng.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 130 | d1x = -np.cos(n) * n + rng.rand(batch_size // 2, 1) * 0.5 131 | d1y = np.sin(n) * n + rng.rand(batch_size // 2, 1) * 0.5 132 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 133 | x += rng.randn(*x.shape) * 0.1 134 | return x 135 | 136 | elif data == "checkerboard": 137 | # x1 = np.random.rand(batch_size) * 4 - 2 138 | # x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 139 | # x2 = x2_ + (np.floor(x1) % 2) 140 | 141 | x1 = rng.rand(batch_size) * 4 - 2 142 | x2_ = rng.rand(batch_size) - rng.randint(0, 2, batch_size) * 2 143 | x2 = x2_ + (np.floor(x1) % 2) 144 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 145 | 146 | elif data == "line": 147 | x = rng.rand(batch_size) * 5 - 2.5 148 | y = x 149 | return np.stack((x, y), 1) 150 | elif data == "cos": 151 | x = rng.rand(batch_size) * 5 - 2.5 152 | y = np.sin(x) * 2.5 153 | return np.stack((x, y), 1) 154 | else: 155 | raise NotImplementedError -------------------------------------------------------------------------------- /synthetic/synthetic_utils.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import tqdm 7 | import random 8 | import sys, os 9 | from matplotlib import pyplot as plt 10 | from sympy.combinatorics.graycode import GrayCode 11 | import time 12 | import ipdb 13 | 14 | 15 | 16 | def get_true_samples(db, size, bm, int_salce, discrete_dim, seed=None): 17 | if seed is None: 18 | samples = float2bin(db.gen_batch(size), bm, int_salce, discrete_dim) 19 | else: 20 | samples = float2bin(db.gen_batch_with_seed(size, seed), bm, int_salce, discrete_dim) 21 | return torch.from_numpy(samples).float() 22 | 23 | def get_ebm_samples(score_func, size, inv_bm, int_scale, discrete_dim, device, gibbs_sampler=None, gibbs_steps=20): 24 | unif_dist = torch.distributions.Bernoulli(probs=0.5) 25 | ebm_samples = unif_dist.sample((size, discrete_dim)).to(device) 26 | ebm_samp_float = [] 27 | for ind in range(gibbs_steps * discrete_dim): # takes about 1s 28 | ebm_samples = gibbs_sampler.step(ebm_samples, score_func) 29 | ebm_samp_float.append(bin2float(ebm_samples.data.cpu().numpy().astype(int), inv_bm, int_scale, discrete_dim)) 30 | ebm_samp_float = np.concatenate(ebm_samp_float, axis=0) 31 | return ebm_samples, ebm_samp_float 32 | 33 | def estimate_ll(score_func, samples, n_partition=None, rand_samples=None): 34 | with torch.no_grad(): 35 | if rand_samples is None: 36 | rand_samples = torch.randint(2, (n_partition, samples.shape[1])).float().to(samples.device) 37 | n_partition = rand_samples.shape[0] 38 | f_z_list = [] 39 | for i in range(0, 
n_partition, samples.shape[0]): # 从0数到n_partition,每一份是samples.shape[0]大小 40 | f_z = score_func(rand_samples[i:i+samples.shape[0]]).view(-1, 1) 41 | f_z_list.append(f_z) 42 | f_z = torch.cat(f_z_list, dim=0) 43 | f_z = f_z - samples.shape[1] * np.log(0.5) - np.log(n_partition) # log(1/2)是unif的概率,importance sample的时候在分母 44 | 45 | # log_part = logsumexp(f_z) 46 | log_part = f_z.logsumexp(0) 47 | f_sample = score_func(samples) 48 | ll = f_sample - log_part 49 | 50 | return torch.mean(ll).item() 51 | 52 | 53 | def exp_hamming_sim(x, y, bd): 54 | x = x.unsqueeze(1) 55 | y = y.unsqueeze(0) 56 | d = T.sum(T.abs(x - y), dim=-1) 57 | return T.exp(-bd * d) 58 | 59 | 60 | def exp_hamming_mmd(x, y, bandwidth=0.1): 61 | x = x.float() 62 | y = y.float() 63 | 64 | with T.no_grad(): 65 | kxx = exp_hamming_sim(x, x, bd=bandwidth) 66 | idx = T.arange(0, x.shape[0], out=T.LongTensor()) 67 | kxx[idx, idx] = 0.0 68 | kxx = T.sum(kxx) / x.shape[0] / (x.shape[0] - 1) 69 | 70 | kyy = exp_hamming_sim(y, y, bd=bandwidth) 71 | idx = T.arange(0, y.shape[0], out=T.LongTensor()) 72 | kyy[idx, idx] = 0.0 73 | kyy = T.sum(kyy) / y.shape[0] / (y.shape[0] - 1) 74 | 75 | kxy = T.sum(exp_hamming_sim(x, y, bd=bandwidth)) / x.shape[0] / y.shape[0] 76 | 77 | mmd = kxx + kyy - 2 * kxy 78 | return mmd 79 | 80 | 81 | def hamming_sim(x, y): 82 | x = x.unsqueeze(1) 83 | y = y.unsqueeze(0) 84 | d = torch.sum(torch.abs(x - y), dim=-1) 85 | return x.shape[-1] - d 86 | 87 | def hamming_mmd(x, y): 88 | x = x.float() 89 | y = y.float() 90 | with torch.no_grad(): 91 | kxx = hamming_sim(x, x) 92 | idx = torch.arange(0, x.shape[0], out=torch.LongTensor()) 93 | kxx[idx, idx] = 0.0 94 | kxx = torch.sum(kxx) / x.shape[0] / (x.shape[0] - 1) 95 | 96 | kyy = hamming_sim(y, y) 97 | idx = torch.arange(0, y.shape[0], out=torch.LongTensor()) 98 | kyy[idx, idx] = 0.0 99 | kyy = torch.sum(kyy) / y.shape[0] / (y.shape[0] - 1) 100 | kxy = torch.sum(hamming_sim(x, y)) / x.shape[0] / y.shape[0] 101 | mmd = kxx + kyy - 2 * kxy 102 | return mmd 103 | 104 | 105 | def linear_mmd(x, y): 106 | x = x.float() 107 | y = y.float() 108 | with torch.no_grad(): 109 | kxx = torch.mm(x, x.transpose(0, 1)) 110 | idx = torch.arange(0, x.shape[0], out=torch.LongTensor()) 111 | kxx = kxx * (1 - torch.eye(x.shape[0]).to(x.device)) 112 | kxx = torch.sum(kxx) / x.shape[0] / (x.shape[0] - 1) 113 | 114 | kyy = torch.mm(y, y.transpose(0, 1)) 115 | idx = torch.arange(0, y.shape[0], out=torch.LongTensor()) 116 | kyy[idx, idx] = 0.0 117 | kyy = torch.sum(kyy) / y.shape[0] / (y.shape[0] - 1) 118 | kxy = torch.sum(torch.mm(y, x.transpose(0, 1))) / x.shape[0] / y.shape[0] 119 | mmd = kxx + kyy - 2 * kxy 120 | return mmd 121 | 122 | 123 | from torch.autograd import Variable, Function 124 | def get_gamma(X, bandwidth): 125 | with torch.no_grad(): 126 | x_norm = torch.sum(X ** 2, dim=1, keepdim=True) 127 | x_t = torch.transpose(X, 0, 1) 128 | x_norm_t = x_norm.view(1, -1) 129 | t = x_norm + x_norm_t - 2.0 * torch.matmul(X, x_t) 130 | dist2 = F.relu(Variable(t)).detach().data 131 | 132 | d = dist2.cpu().numpy() 133 | d = d[np.isfinite(d)] 134 | d = d[d > 0] 135 | median_dist2 = float(np.median(d)) 136 | gamma = 0.5 / median_dist2 / bandwidth 137 | return gamma 138 | 139 | def pairwise_distances(x, y=None): 140 | ''' 141 | Input: x is a Nxd matrix 142 | y is an optional Mxd matirx 143 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 144 | if y is not given then use 'y=x'. 145 | i.e. 
dist[i,j] = ||x[i,:]-y[j,:]||^2 146 | ''' 147 | x_norm = (x**2).sum(1).view(-1, 1) 148 | if y is not None: 149 | y_t = torch.transpose(y, 0, 1) 150 | y_norm = (y**2).sum(1).view(1, -1) 151 | else: 152 | y_t = torch.transpose(x, 0, 1) 153 | y_norm = x_norm.view(1, -1) 154 | 155 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) 156 | # Ensure diagonal is zero if x=y 157 | # if y is None: 158 | # dist = dist - torch.diag(dist.diag) 159 | return torch.clamp(dist, 0.0, np.inf) 160 | 161 | def get_kernel_mat(x, landmarks, gamma): 162 | d = pairwise_distances(x, landmarks) 163 | k = torch.exp(d * -gamma) 164 | k = k.view(x.shape[0], -1) 165 | return k 166 | 167 | def MMD(x, y, bandwidth=1.0): 168 | y = y.detach() 169 | gamma = get_gamma(x.detach(), bandwidth) 170 | kxx = get_kernel_mat(x, x, gamma) 171 | idx = torch.arange(0, x.shape[0], out=torch.LongTensor()) 172 | kxx = kxx * (1 - torch.eye(x.shape[0]).to(x.device)) 173 | kxx = torch.sum(kxx) / x.shape[0] / (x.shape[0] - 1) 174 | 175 | kyy = get_kernel_mat(y, y, gamma) 176 | idx = torch.arange(0, y.shape[0], out=torch.LongTensor()) 177 | kyy[idx, idx] = 0.0 178 | kyy = torch.sum(kyy) / y.shape[0] / (y.shape[0] - 1) 179 | kxy = torch.sum(get_kernel_mat(y, x, gamma)) / x.shape[0] / y.shape[0] 180 | mmd = kxx + kyy - 2 * kxy 181 | return mmd 182 | 183 | 184 | 185 | def get_binmap(discrete_dim, binmode): 186 | b = discrete_dim // 2 - 1 187 | all_bins = [] 188 | for i in range(1 << b): 189 | bx = np.binary_repr(i, width=discrete_dim // 2 - 1) 190 | all_bins.append('0' + bx) 191 | all_bins.append('1' + bx) 192 | vals = all_bins[:] 193 | if binmode == 'rand': 194 | print('remapping binary repr with random permute') 195 | random.shuffle(vals) 196 | elif binmode == 'gray': 197 | print('remapping binary repr with gray code') 198 | a = GrayCode(b) 199 | vals = [] 200 | for x in a.generate_gray(): 201 | vals.append('0' + x) 202 | vals.append('1' + x) 203 | else: 204 | assert binmode == 'normal' 205 | bm = {} 206 | inv_bm = {} 207 | for i, key in enumerate(all_bins): 208 | bm[key] = vals[i] 209 | inv_bm[vals[i]] = key 210 | return bm, inv_bm 211 | 212 | 213 | def compress(x, discrete_dim): 214 | bx = np.binary_repr(int(abs(x)), width=discrete_dim // 2 - 1) 215 | bx = '0' + bx if x >= 0 else '1' + bx 216 | return bx 217 | 218 | 219 | def recover(bx): 220 | x = int(bx[1:], 2) 221 | return x if bx[0] == '0' else -x 222 | 223 | 224 | def float2bin(samples, bm, int_scale, discrete_dim): 225 | bin_list = [] 226 | for i in range(samples.shape[0]): 227 | x, y = samples[i] * int_scale 228 | bx, by = compress(x, discrete_dim), compress(y, discrete_dim) 229 | bx, by = bm[bx], bm[by] 230 | bin_list.append(np.array(list(bx + by), dtype=int)) 231 | return np.array(bin_list) 232 | 233 | 234 | def bin2float(samples, inv_bm, int_scale, discrete_dim): 235 | floats = [] 236 | for i in range(samples.shape[0]): 237 | s = '' 238 | for j in range(samples.shape[1]): 239 | s += str(samples[i, j]) 240 | x, y = s[:discrete_dim // 2], s[discrete_dim // 2:] 241 | x, y = inv_bm[x], inv_bm[y] 242 | x, y = recover(x), recover(y) 243 | x /= int_scale 244 | y /= int_scale 245 | floats.append((x, y)) 246 | return np.array(floats) 247 | 248 | 249 | def plot_heat(score_func, bm, size, device, int_scale, discrete_dim, out_file=None): 250 | w = 100 251 | x = np.linspace(-size, size, w) 252 | y = np.linspace(-size, size, w) 253 | xx, yy = np.meshgrid(x, y) 254 | xx = np.reshape(xx, [-1, 1]) 255 | yy = np.reshape(yy, [-1, 1]) 256 | heat_samples = float2bin(np.concatenate((xx, yy), axis=-1), bm, 
int_scale, discrete_dim) 257 | heat_samples = torch.from_numpy(heat_samples).to(device).float() 258 | heat_score = F.softmax(score_func(heat_samples).view(1, -1), dim=-1) 259 | a = heat_score.view(w, w).data.cpu().numpy() 260 | a = np.flip(a, axis=0) 261 | print("energy max and min:", a.max(), a.min()) 262 | plt.imshow(a) 263 | plt.axis('equal') 264 | plt.axis('off') 265 | # if out_file is None: 266 | # out_file = os.path.join(save_dir, 'heat.pdf') 267 | plt.savefig(out_file, bbox_inches='tight') 268 | plt.close() 269 | 270 | 271 | def plot_samples(samples, out_name, lim=None, axis=True): 272 | plt.scatter(samples[:, 0], samples[:, 1], marker='.') 273 | plt.axis('equal') 274 | if lim is not None: 275 | plt.xlim(-lim, lim) 276 | plt.ylim(-lim, lim) 277 | if not axis: 278 | plt.axis('off') 279 | plt.savefig(out_name, bbox_inches='tight') 280 | plt.close() 281 | 282 | 283 | ############# Model Architecture 284 | 285 | class EnergyModel(T.nn.Module): 286 | 287 | def __init__(self, s, mid_size): 288 | super(EnergyModel, self).__init__() 289 | 290 | self.m = T.nn.Sequential(T.nn.Linear(s, mid_size), 291 | T.nn.ELU(), 292 | T.nn.Linear(mid_size, mid_size), 293 | T.nn.ELU(), 294 | T.nn.Linear(mid_size, mid_size), 295 | T.nn.ELU(), 296 | T.nn.Linear(mid_size, 1)) 297 | 298 | def forward(self, x): 299 | x = x.view((x.shape[0], -1)) 300 | x = self.m(x) 301 | 302 | return x[:, -1] 303 | -------------------------------------------------------------------------------- /synthetic/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import os, sys 6 | 7 | import time 8 | import random 9 | import ipdb, pdb 10 | from tqdm import tqdm 11 | import argparse 12 | 13 | sys.path.append("/home/zhangdh/EB_GFN") 14 | from gflownet import get_GFlowNet 15 | 16 | sys.path.append("/home/zhangdh/EB_GFN/synthetic") 17 | from synthetic_utils import plot_heat, plot_samples,\ 18 | float2bin, bin2float, get_binmap, get_true_samples, get_ebm_samples, EnergyModel 19 | from synthetic_data import inf_train_gen, OnlineToyDataset 20 | 21 | 22 | def makedirs(path): 23 | if not os.path.exists(path): 24 | print('creating dir: {}'.format(path)) 25 | os.makedirs(path) 26 | else: 27 | print(path, "already exist!") 28 | 29 | 30 | unif_dist = torch.distributions.Bernoulli(probs=0.5) 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--device", "--d", default=0, type=int) 34 | # data 35 | parser.add_argument('--save_dir', type=str, default="./") 36 | parser.add_argument('--data', type=str, default='circles') # 2spirals 8gaussians pinwheel circles moons swissroll checkerboard 37 | 38 | # training 39 | parser.add_argument('--n_iters', "--ni", type=lambda x: int(float(x)), default=1e5) 40 | parser.add_argument('--batch_size', "--bs", type=int, default=128) 41 | parser.add_argument('--print_every', "--pe", type=int, default=100) 42 | parser.add_argument('--eval_every', type=int, default=2000) 43 | parser.add_argument('--lr', type=float, default=.0001) 44 | parser.add_argument("--ebm_every", "--ee", type=int, default=1, help="EBM training frequency") 45 | 46 | # for GFN 47 | parser.add_argument("--type", type=str) 48 | parser.add_argument("--hid", type=int, default=512) 49 | parser.add_argument("--hid_layers", "--hl", type=int, default=3) 50 | parser.add_argument("--leaky", type=int, default=1, choices=[0, 1]) 51 | parser.add_argument("--gfn_bn", "--gbn", type=int, 
default=0, choices=[0, 1]) 52 | parser.add_argument("--init_zero", "--iz", type=int, default=0, choices=[0, 1], ) 53 | parser.add_argument("--gmodel", "--gm", type=str,default="mlp") 54 | parser.add_argument("--train_steps", "--ts", type=int, default=1) 55 | parser.add_argument("--l1loss", "--l1l", type=int, default=0, choices=[0, 1], help="use soft l1 loss instead of l2") 56 | 57 | parser.add_argument("--with_mh", "--wm", type=int, default=0, choices=[0, 1]) 58 | parser.add_argument("--rand_k", "--rk", type=int, default=0, choices=[0, 1]) 59 | parser.add_argument("--lin_k", "--lk", type=int, default=0, choices=[0, 1]) 60 | parser.add_argument("--warmup_k", "--wk", type=lambda x: int(float(x)), default=0, help="need to use w/ lin_k") 61 | parser.add_argument("--K", type=int, default=-1, help="for gfn back forth negative sample generation") 62 | 63 | parser.add_argument("--rand_coef", "--rc", type=float, default=0, help="for tb") 64 | parser.add_argument("--back_ratio", "--br", type=float, default=0.) 65 | parser.add_argument("--clip", type=float, default=-1., help="for gfn's linf gradient clipping") 66 | parser.add_argument("--temp", type=float, default=1) 67 | parser.add_argument("--opt", type=str, default="adam", choices=["adam", "sgd"]) 68 | parser.add_argument("--glr", type=float, default=1e-3) 69 | parser.add_argument("--zlr", type=float, default=1) 70 | parser.add_argument("--momentum", "--mom", type=float, default=0.0) 71 | parser.add_argument("--gfn_weight_decay", "--gwd", type=float, default=0.0) 72 | args = parser.parse_args() 73 | 74 | os.environ['CUDA_VISIBLE_DEVICES'] = "{:}".format(args.device) 75 | device = torch.device("cpu") if args.device < 0 else torch.device("cuda") 76 | 77 | args.save_dir = os.path.join(args.save_dir, "test") 78 | makedirs(args.save_dir) 79 | 80 | print("Device:" + str(device)) 81 | print("Args:" + str(args)) 82 | 83 | ############## Data 84 | discrete_dim = 32 85 | bm, inv_bm = get_binmap(discrete_dim, 'gray') 86 | 87 | db = OnlineToyDataset(args.data, discrete_dim) 88 | if not hasattr(args, "int_scale"): 89 | int_scale = db.int_scale 90 | else: 91 | int_scale = args.int_scale 92 | if not hasattr(args, "plot_size"): 93 | plot_size = db.f_scale 94 | else: 95 | db.f_scale = args.plot_size 96 | plot_size = args.plot_size 97 | # plot_size = 4.1 98 | 99 | batch_size = args.batch_size 100 | multiples = {'pinwheel': 5, '2spirals': 2} 101 | batch_size = batch_size - batch_size % multiples.get(args.data, 1) 102 | 103 | ############## EBM model 104 | energy_model = EnergyModel(discrete_dim, 256).to(device) 105 | optimizer = torch.optim.Adam(energy_model.parameters(), lr=args.lr) 106 | 107 | ############## GFN 108 | xdim = discrete_dim 109 | assert args.gmodel == "mlp" 110 | gfn = get_GFlowNet(args.type, xdim, args, device) 111 | 112 | energy_model.to(device) 113 | print("model: {:}".format(energy_model)) 114 | 115 | itr = 0 116 | best_val_ll = -np.inf 117 | best_itr = -1 118 | lr = args.lr 119 | while itr < args.n_iters: 120 | st = time.time() 121 | 122 | x = get_true_samples(db, batch_size, bm, int_scale, discrete_dim).to(device) 123 | 124 | update_success_rate = -1. 
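        # What happens below: (1) the GFN sampler takes its training step(s) on the current batch;
        # (2) negative samples for the EBM come either from a K-step "back-and-forth" walk started
        # at the data (gfn.backforth_sample) or from sampling the GFN from scratch; (3) with
        # --with_mh, each proposal x' is kept with probability sigmoid(log p(x') - log p(x) +
        # delta_logp_traj), where delta_logp_traj = log q(x' -> x) - log q(x -> x') corrects for
        # the asymmetry of the proposal.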
125 | gfn.model.train() 126 | train_loss, train_logZ = gfn.train(batch_size, 127 | scorer=lambda inp: energy_model(inp).detach(), silent=itr % args.print_every != 0, data=x, 128 | back_ratio=args.back_ratio) 129 | 130 | if args.rand_k or args.lin_k or (args.K > 0): 131 | if args.rand_k: 132 | K = random.randrange(xdim) + 1 133 | elif args.lin_k: 134 | K = min(xdim, int(xdim * float(itr + 1) / args.warmup_k)) 135 | K = max(K, 1) 136 | elif args.K > 0: 137 | K = args.K 138 | else: 139 | raise ValueError 140 | 141 | gfn.model.eval() 142 | x_fake, delta_logp_traj = gfn.backforth_sample(x, K) 143 | 144 | delta_logp_traj = delta_logp_traj.detach() 145 | if args.with_mh: 146 | # MH step, calculate log p(x') - log p(x) 147 | lp_update = energy_model(x_fake).squeeze() - energy_model(x).squeeze() 148 | update_dist = torch.distributions.Bernoulli(logits=lp_update + delta_logp_traj) 149 | updates = update_dist.sample() 150 | x_fake = x_fake * updates[:, None] + x * (1. - updates[:, None]) 151 | update_success_rate = updates.mean().item() 152 | 153 | else: 154 | x_fake = gfn.sample(batch_size) 155 | 156 | 157 | if itr % args.ebm_every == 0: 158 | st = time.time() - st 159 | 160 | energy_model.train() 161 | logp_real = energy_model(x).squeeze() 162 | 163 | logp_fake = energy_model(x_fake).squeeze() 164 | obj = logp_real.mean() - logp_fake.mean() 165 | l2_reg = (logp_real ** 2.).mean() + (logp_fake ** 2.).mean() 166 | loss = -obj 167 | 168 | optimizer.zero_grad() 169 | loss.backward() 170 | optimizer.step() 171 | 172 | if itr % args.print_every == 0 or itr == args.n_iters - 1: 173 | print("({:5d}) | ({:.3f}s/iter) cur lr= {:.2e} |log p(real)={:.2e}, " 174 | "log p(fake)={:.2e}, diff={:.2e}, update_rate={:.1f}".format( 175 | itr, st, lr, logp_real.mean().item(), logp_fake.mean().item(), obj.item(), update_success_rate)) 176 | 177 | 178 | if (itr + 1) % args.eval_every == 0: 179 | # heat map of energy 180 | plot_heat(energy_model, bm, plot_size, device, int_scale, discrete_dim, 181 | out_file=os.path.join(args.save_dir, f'heat_{itr}.pdf')) 182 | 183 | # samples of gfn 184 | gfn_samples = gfn.sample(4000).detach() 185 | gfn_samp_float = bin2float(gfn_samples.data.cpu().numpy().astype(int), inv_bm, int_scale, discrete_dim) 186 | plot_samples(gfn_samp_float, os.path.join(args.save_dir, f'gfn_samples_{itr}.pdf'), lim=plot_size) 187 | 188 | # GFN LL 189 | gfn.model.eval() 190 | logps = [] 191 | pbar = tqdm(range(10)) 192 | pbar.set_description("GFN Calculating likelihood") 193 | for _ in pbar: 194 | pos_samples_bs = get_true_samples(db, 1000, bm, int_scale, discrete_dim).to(device) 195 | logp = gfn.cal_logp(pos_samples_bs, 20) 196 | logps.append(logp.reshape(-1)) 197 | pbar.set_postfix({"logp": f"{torch.cat(logps).mean().item():.2f}"}) 198 | gfn_test_ll = torch.cat(logps).mean() 199 | 200 | print(f"Test NLL ({itr}): GFN: {-gfn_test_ll.item():.3f}") 201 | 202 | 203 | itr += 1 204 | if itr > args.n_iters: 205 | quit(0) 206 | -------------------------------------------------------------------------------- /utils_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data_utils 3 | import torchvision 4 | 5 | import numpy as np 6 | import ipdb 7 | import os 8 | 9 | 10 | # ====================================================================================================================== 11 | def load_dynamic_mnist(args, **kwargs): 12 | # set args 13 | if args.down_sample: 14 | args.input_size = [1, 14, 14] 15 | else: 16 | 
args.input_size = [1, 28, 28] 17 | args.input_type = 'binary' 18 | args.dynamic_binarization = True 19 | 20 | # start processing 21 | from torchvision import datasets, transforms 22 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('/home/zhangdh/data', train=True, download=True, 23 | transform=transforms.Compose([transforms.ToTensor()])), batch_size=args.batch_size, shuffle=True) 24 | 25 | test_loader = torch.utils.data.DataLoader(datasets.MNIST('/home/zhangdh/data', train=False, 26 | transform=transforms.Compose([transforms.ToTensor()])), batch_size=args.batch_size, shuffle=True) 27 | 28 | # preparing data 29 | x_train = train_loader.dataset.train_data.float().numpy() / 255. 30 | x_test = test_loader.dataset.test_data.float().numpy() / 255. 31 | if args.down_sample: 32 | x_train = x_train[:, ::2, ::2] 33 | x_test = x_test[:, ::2, ::2] 34 | x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1] * x_train.shape[2])) 35 | y_train = np.array(train_loader.dataset.train_labels.float().numpy(), dtype=int) 36 | x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1] * x_test.shape[2])) 37 | y_test = np.array(test_loader.dataset.test_labels.float().numpy(), dtype=int) 38 | 39 | # validation set 40 | x_val = x_train[50000:60000] 41 | y_val = np.array(y_train[50000:60000], dtype=int) 42 | x_train = x_train[0:50000] 43 | y_train = np.array(y_train[0:50000], dtype=int) 44 | 45 | # binarize 46 | if args.dynamic_binarization: 47 | args.input_type = 'binary' 48 | np.random.seed(777) 49 | x_val = np.random.binomial(1, x_val) 50 | x_test = np.random.binomial(1, x_test) 51 | else: 52 | args.input_type = 'gray' 53 | 54 | train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) 55 | train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) 56 | 57 | validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) 58 | val_loader = data_utils.DataLoader(validation, batch_size=args.test_batch_size, shuffle=False, **kwargs) 59 | 60 | test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) 61 | test_loader = data_utils.DataLoader(test, batch_size=args.test_batch_size, shuffle=False, **kwargs) 62 | 63 | return train_loader, val_loader, test_loader, args 64 | 65 | 66 | 67 | # ====================================================================================================================== 68 | def load_dataset(args, **kwargs): 69 | assert args.data in ['dynamic_mnist', "dmnist"] 70 | train_loader, val_loader, test_loader, args = load_dynamic_mnist(args, **kwargs) 71 | 72 | return train_loader, val_loader, test_loader, args 73 | --------------------------------------------------------------------------------
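For reference, a minimal sketch of how these loaders can be consumed. The `argparse.Namespace` fields mirror the options defined in the training scripts; the per-batch `torch.bernoulli` call is an assumption about how dynamic binarization of the (un-binarized) training split is handled by the caller, and the snippet inherits the hard-coded `/home/zhangdh/data` MNIST path and the torchvision version the code was written against.

```
import argparse
import torch
import utils_data

# Hypothetical namespace standing in for the parsed command-line arguments.
args = argparse.Namespace(data="dmnist", down_sample=0, batch_size=100, test_batch_size=100)
train_loader, val_loader, test_loader, args = utils_data.load_dataset(args)

x, _ = next(iter(train_loader))
# The training split is returned as floats in [0, 1]; binarize per batch (assumed convention,
# matching the static binarization applied to the validation and test splits above).
x = torch.bernoulli(x)
print(args.input_size, x.shape)
```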