├── .gitignore ├── README.md ├── configs └── hydrogen.py ├── core ├── hmodel.py ├── losses.py └── samplers.py ├── launch.py ├── models ├── ema.py └── mlp.py ├── notebooks ├── annealed_Langevin.ipynb ├── gifs │ ├── am_celeba_diffusion.gif │ ├── am_celeba_inpaint.gif │ ├── am_celeba_superres.gif │ ├── am_celeba_torus.gif │ ├── am_cifar_color.gif │ ├── am_cifar_diffusion.gif │ ├── am_cifar_superres.gif │ ├── am_cifar_torus.gif │ ├── am_results.gif │ ├── dynamics_densities.gif │ ├── sm_results.gif │ └── ssm_results.gif ├── loss_plots.ipynb ├── losses_se.png ├── mmd_se.png └── visualize.ipynb └── utils ├── eval_utils.py ├── plot_utils.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Action Matching for Schrödinger Equation Simulation 2 | 3 | We demonstrate that Action Matching can learn a wide range of stochastic dynamics by applying it to the dynamics of a quantum system evolving according to the Schrödinger equation. The Schrödinger equation describes the evolution of many quantum systems, and in particular, it describes the physics of molecular systems. Here, for the ground truth dynamics, we take the dynamics of an excited state of the hydrogen atom, which is described by the following equation: 4 | 5 | $i\frac{\partial}{\partial t}\psi(x,t) = -\frac{1}{|x|}\psi(x,t) -\frac{1}{2}\nabla^2\psi(x,t).$ 6 | 7 | The function $\psi(x,t): \mathbb{R}^3\times \mathbb{R} \to \mathbb{C}$ is called a wavefunction, and it completely describes the state of the quantum system. 8 | In particular, it defines the distribution of the coordinates $x$ through the density $q_t(x) := |\psi(x,t)|^2$, whose dynamics is induced by the dynamics of $\psi(x,t)$. 9 | Below we show the ground truth dynamics by projecting the density $q_t(x)$ onto three different planes. 10 | 11 | drawing 12 | 13 | In what follows we illustrate the histograms for the learned dynamics. Since the original distribution is in $\mathbb{R}^3$, we project the samples onto three different planes and draw 2d-histograms. The top row for every model corresponds to the ground truth dynamics (the training data), and the bottom rows correspond to the learned models. 
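As a reference for how such samples can be obtained, below is a minimal sketch that calls the `WaveFunction` and `RWMH` classes from `core/` directly (the repository's own data pipeline is set up in `utils/train_utils.py`). The quantum numbers follow `configs/hydrogen.py`, while the evolution time, number of walkers, and number of MCMC steps here are arbitrary illustrative choices.

```python
import torch
from core.hmodel import WaveFunction
from core.samplers import RWMH

device = torch.device('cuda')

# Superposition of two hydrogen eigenstates (n, l, m), as in configs/hydrogen.py.
n = torch.tensor([3, 2])
l = torch.tensor([2, 1])
m = torch.tensor([-1, 0])
c0 = torch.tensor([1.0 + 0.0j, 1.0 + 0.0j])  # normalized inside WaveFunction

psi_0 = WaveFunction(n, l, m, c0, device)
psi_t = psi_0.evolve_to(100.0)  # illustrative time; each coefficient picks up a phase exp(-i E_k t)

# Draw samples from q_t(x) = |psi(x,t)|^2 with random-walk Metropolis-Hastings.
x_0 = 5.0 * torch.randn(5_000, 3, device=device)     # initial walkers (illustrative count and scale)
x_t, accept_rate = RWMH(psi_t).sample_n(x_0, 100)     # x_t ~ q_t(x), shape [5000, 3]
```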
15 | 16 | #### Action Matching (AM) results visualization 17 | drawing 18 | 19 | #### Score Matching (SM) results visualization 20 | drawing 21 | 22 | #### Sliced Score Matching (SSM) results visualization 23 | drawing 24 | -------------------------------------------------------------------------------- /configs/hydrogen.py: -------------------------------------------------------------------------------- 1 | from ml_collections import config_dict 2 | import torch 3 | import numpy as np 4 | 5 | def get_config(): 6 | config = config_dict.ConfigDict() 7 | 8 | config.data = config_dict.ConfigDict() 9 | config.data.T = 14_000 10 | config.data.n_steps = 1_000 11 | config.data.batch_size = 5_000 12 | config.data.n = torch.tensor([3,2]) 13 | config.data.l = torch.tensor([2,1]) 14 | config.data.m = torch.tensor([-1,0]) 15 | config.data.c = torch.tensor([1.0+0.0j, 1.0+0.0j]) 16 | config.data.name = 'hydrogen' 17 | 18 | config.model = config_dict.ConfigDict() 19 | config.model.method = 'am' 20 | config.model.n_hid = 256 21 | config.model.savepath = config.model.method + '_' + config.data.name 22 | config.model.checkpoints = [] 23 | 24 | config.train = config_dict.ConfigDict() 25 | config.train.batch_size = 1_000 26 | config.train.lr = 1e-4 27 | config.train.warmup = 5_000 28 | config.train.grad_clip = 1.0 29 | config.train.betas = (0.9, 0.999) 30 | config.train.wd = 0.0 31 | config.train.n_iter = 100_000 32 | config.train.regen_every = 5_000 33 | config.train.save_every = 10_000 34 | config.train.eval_every = 5_000 35 | config.train.current_step = 0 36 | 37 | config.eval = config_dict.ConfigDict() 38 | config.eval.ema = 0.9999 39 | 40 | return config 41 | -------------------------------------------------------------------------------- /core/hmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from scipy.special import genlaguerre 5 | from scipy.special import lpmv 6 | from scipy.special import factorial2 7 | 8 | 9 | hbar = 1.0 # reduced Planck constant 10 | m_e = 1.0 # electron mass 11 | eps0 = 1.0 12 | e = 1.0 13 | a0 = 4*np.pi*eps0*hbar**2/m_e/e**2 # Bohr radius 14 | EPS = 1e-2 15 | 16 | def asslaguerre_torch(n, alpha, x): 17 | if n == 0: 18 | return torch.ones_like(x) 19 | elif n == 1: 20 | return 1.0 + alpha - x 21 | else: 22 | output = (2*n-1.0+alpha-x)*asslaguerre_torch(n-1,alpha,x) 23 | output = output-(n-1+alpha)*asslaguerre_torch(n-2,alpha,x) 24 | output = output/n 25 | return output 26 | 27 | def asslegendre_torch(m, l, x): 28 | if m < 0: 29 | m = np.abs(m) 30 | return (-1)**m*np.math.factorial(l-m)/np.math.factorial(l+m)*asslegendre_torch(m, l, x) 31 | if m == l: 32 | return (-1)**l*factorial2(2*l-1)*torch.pow(1.0-x**2, torch.tensor([l]).to(x.device)/2.0) 33 | elif m > l: 34 | return torch.zeros_like(x) 35 | else: 36 | output = (2*l-1)*x*asslegendre_torch(m, l-1, x) 37 | output = output - (l+m-1)*asslegendre_torch(m, l-2, x) 38 | output = output/(l-m) 39 | return output 40 | 41 | class EigenState: 42 | def __init__(self, n, l, m): 43 | self.n, self.l, self.m = n, l, m 44 | self.E = -hbar**2/(2.0*m_e*a0**2)/self.n**2 45 | self.L2 = hbar**2*self.l*(self.l+1) 46 | self.Lz = hbar**2*self.m 47 | 48 | def radial(self, r): 49 | n, l, m = self.n, self.l, self.m 50 | # output = torch.exp(-r/n/a0)*(2*r/(n*a0))**l 51 | output = torch.exp(-r/n/a0 + l*torch.log(2*r) - l*np.log(n*a0)) 52 | output = output*np.sqrt((2.0/n/a0)**3*np.math.factorial(n-l-1)/np.math.factorial(n+l)/2.0/n) 53 | output = 
output*asslaguerre_torch(n-l-1, 2*l+1, 2.0*r/(n*a0)) 54 | return output 55 | 56 | def angular(self, theta, phi): 57 | n, l, m = self.n, self.l, self.m 58 | output = asslegendre_torch(m, l, torch.cos(theta)) 59 | output = output*(-1)**m*np.sqrt((2.0*l+1.0)*np.math.factorial(l-m)/np.math.factorial(l+m)/4.0/np.pi) 60 | output = output*torch.exp(1.0j*m*phi) 61 | return output 62 | 63 | def _radial(self, r): 64 | n, l, m = self.n, self.l, self.m 65 | output = np.sqrt((2.0/n/a0)**3*np.math.factorial(n-l-1)/np.math.factorial(n+l)/2.0/n) 66 | output = output*asslaguerre_torch(n-l-1, 2*l+1, 2.0*r/(n*a0)) 67 | return output 68 | 69 | def _radial_log(self, r): 70 | n, l, m = self.n, self.l, self.m 71 | return -r/n/a0 + l*torch.log(2*r) - l*np.log(n*a0) 72 | 73 | class WaveFunction: 74 | def __init__(self, n, l, m, c0, device): 75 | assert (n < 0).sum() == 0 76 | assert (l < 0).sum() == (l >= n).sum() == 0 77 | assert (m > l).sum() == (m < -l).sum() == 0 78 | self.n, self.l, self.m, self.c0 = n, l, m, c0 79 | self.c0 = self.c0.to(device) 80 | self.c0 = self.c0/torch.sqrt(torch.sum(self.c0.abs()**2)) 81 | self.states = list(EigenState(qnum[0], qnum[1], qnum[2]) for qnum in zip(n,l,m)) 82 | self.dim = 3 83 | self.device = device 84 | 85 | def evolve_to(self, t): 86 | E = torch.tensor(list(psi.E for psi in self.states)).to(self.device) 87 | return WaveFunction(self.n, self.l, self.m, torch.exp(-1j*E*t/hbar)*self.c0, self.device) 88 | 89 | def avgH(self): 90 | E = torch.tensor(list(psi.E for psi in self.states)) 91 | return torch.sum(self.c0.abs()**2*E) 92 | 93 | def avgL2(self): 94 | L2 = torch.tensor(list(psi.L2 for psi in self.states)) 95 | return torch.sum(self.c0.abs()**2*L2) 96 | 97 | def avgLz(self): 98 | Lz = torch.tensor(list(psi.Lz for psi in self.states)) 99 | return torch.sum(self.c0.abs()**2*Lz) 100 | 101 | def at(self, x): 102 | r = torch.sqrt(x[:,0]**2 + x[:,1]**2 + x[:,2]**2+EPS).flatten() 103 | theta = torch.atan2(torch.sqrt(x[:,0]**2 + x[:,1]**2+EPS),x[:,2]).flatten() 104 | x_coord = torch.sign(x[:,0])*(torch.abs(x[:,0])+EPS) 105 | phi = torch.atan2(x[:,1],x_coord).flatten() 106 | return self.at_polar(r, theta, phi) 107 | 108 | def at_polar(self, r, theta, phi): 109 | assert r.shape == theta.shape == phi.shape 110 | output = 1j*torch.zeros_like(r) 111 | for i in range(len(self.states)): 112 | psi_i = self.states[i] 113 | output += self.c0[i]*psi_i.radial(r)*psi_i.angular(theta, phi) 114 | return output 115 | 116 | def log_prob(self, x): 117 | # x.shape = [num_points, 3] 118 | r = torch.sqrt(x[:,0]**2 + x[:,1]**2 + x[:,2]**2+EPS).flatten() 119 | z = torch.sign(x[:,2])*(torch.abs(x[:,2])+EPS) 120 | theta = torch.atan2(torch.sqrt(x[:,0]**2 + x[:,1]**2+EPS),z).flatten() 121 | x_coord = torch.sign(x[:,0])*(torch.abs(x[:,0])+EPS) 122 | phi = torch.atan2(x[:,1],x_coord).flatten() 123 | 124 | radial_log = torch.stack([psi._radial_log(r) for psi in self.states]) 125 | angular = torch.stack([psi._radial(r)*psi.angular(theta, phi) for psi in self.states]) 126 | coords = self.c0.view([-1,1]) 127 | max_log, _ = torch.max(radial_log, dim=0) 128 | psi = (torch.exp(radial_log-max_log)*angular*coords).sum(0) 129 | output = 2*torch.log(psi.abs()) + 2*max_log 130 | return output 131 | 132 | 133 | class BohmianDynamics: 134 | def __init__(self, wave_function, samples): 135 | self.psi = wave_function 136 | self.samples = samples 137 | 138 | def propagate(self, dt): 139 | samples = self.samples 140 | samples.requires_grad = True 141 | v = torch.autograd.grad(self.psi.at(samples).angle().sum(), samples)[0] 142 
| samples.data += dt*v 143 | samples.requires_grad = False 144 | self.samples = samples 145 | self.psi = self.psi.evolve_to(dt) 146 | -------------------------------------------------------------------------------- /core/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | 6 | class AMLoss: 7 | def __init__(self, s, q_t, config, device): 8 | self.u0 = 0.5 9 | self.batch_size = config.train.batch_size 10 | self.s = s 11 | self.q_t = q_t 12 | self.device = device 13 | 14 | def sample_t(self, n): 15 | u = (self.u0 + np.sqrt(2)*np.arange(n)) % 1 16 | return torch.tensor(u).view(-1,1).to(self.device) 17 | 18 | def eval_loss(self): 19 | q_t, s = self.q_t, self.s 20 | bs = self.batch_size 21 | 22 | t = self.sample_t(bs) 23 | x_t, t = q_t(t) 24 | x_t.requires_grad, t.requires_grad = True, True 25 | s_t = s(t, x_t) 26 | assert (2 == s_t.dim()) 27 | dsdt, dsdx = torch.autograd.grad(s_t.sum(), [t, x_t], create_graph=True, retain_graph=True) 28 | x_t.requires_grad, t.requires_grad = False, False 29 | 30 | loss = 0.5*(dsdx**2).sum(1, keepdim=True) + dsdt.sum(1, keepdim=True) 31 | loss = loss.squeeze() 32 | 33 | t_0 = torch.zeros([bs, 1], device=self.device) 34 | x_0, _ = q_t(t_0) 35 | loss = loss + s(t_0,x_0).squeeze() 36 | t_1 = torch.ones([bs, 1], device=self.device) 37 | x_1, _ = q_t(t_1) 38 | loss = loss - s(t_1,x_1).squeeze() 39 | return loss.mean() 40 | 41 | 42 | class SMLoss: 43 | def __init__(self, s, q_t, config, device, sliced=True): 44 | self.u0 = 0.5 45 | self.batch_size = config.train.batch_size 46 | self.s = s 47 | self.q_t = q_t 48 | self.device = device 49 | self.div = div 50 | if sliced: 51 | self.div = divH 52 | 53 | def sample_t(self, n): 54 | u = (self.u0 + np.sqrt(2)*np.arange(n)) % 1 55 | return torch.tensor(u).view(-1,1).to(self.device) 56 | 57 | def eval_loss(self): 58 | q_t, s = self.q_t, self.s 59 | bs = self.batch_size 60 | 61 | t = self.sample_t(bs) 62 | x_t, t = q_t(t) 63 | dsdx = self.div(s, t, x_t, create_graph=True) 64 | 65 | loss = 0.5*(s(t, x_t)**2).sum(1, keepdim=True) + dsdx 66 | loss = loss.squeeze() 67 | return loss.mean() 68 | 69 | def div(v, t, x, create_graph=False): 70 | f = lambda x: v(t, x).sum(0) 71 | J = torch.autograd.functional.jacobian(f, x, create_graph=create_graph).swapaxes(0,1) 72 | return J.diagonal(dim1=1,dim2=2).sum(1, keepdim=True) 73 | 74 | def divH(v, t, x, create_graph=False): 75 | eps = torch.randint_like(x, low=0, high=2).float() * 2 - 1. 
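    # eps is a Rademacher probe (entries are +/-1). Since E[eps^T (dv/dx) eps] = tr(dv/dx) = div v,
    # the single vector-Jacobian product below gives an unbiased Hutchinson estimate of the
    # divergence; this is the estimator the sliced ('ssm') variant of the loss relies on.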
76 | x.requires_grad = True 77 | dxdt = v(t, x) 78 | div = (eps*torch.autograd.grad(dxdt, x, grad_outputs=eps, create_graph=create_graph)[0]).sum(1) 79 | x.requires_grad = False 80 | return div 81 | -------------------------------------------------------------------------------- /core/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class RWMH: 6 | def __init__(self, target, sigma=5.0): 7 | self.device = target.device 8 | self.target = target 9 | self.sigma = sigma 10 | 11 | @torch.no_grad() 12 | def log_prob(self, x): 13 | return self.target.log_prob(x).view([-1,1]) 14 | 15 | def sample_n(self, x_0, n): 16 | # x_0.shape = [batch_size, dim] 17 | x = x_0.clone() 18 | log_p = self.log_prob(x) 19 | ar = torch.zeros_like(log_p) 20 | for i in range(n): 21 | _x = x + self.sigma*torch.empty_like(x).normal_() 22 | _log_p = self.log_prob(_x) 23 | accept_mask = (_log_p - log_p > torch.log(torch.zeros_like(log_p).uniform_())).float() 24 | x = accept_mask*_x + (1-accept_mask)*x 25 | log_p = accept_mask*_log_p + (1-accept_mask)*log_p 26 | ar += accept_mask 27 | ar = ar/n 28 | return x, ar 29 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import argparse 4 | import json 5 | 6 | import torch 7 | import numpy as np 8 | import wandb 9 | 10 | from torch import nn 11 | from tqdm.auto import tqdm, trange 12 | 13 | from models.mlp import * 14 | from models import ema 15 | from utils.train_utils import * 16 | 17 | from core.hmodel import * 18 | from core.losses import * 19 | 20 | 21 | def main(args): 22 | device = torch.device('cuda') 23 | 24 | model, data_gen, loss, config = prepare_hydrogen(device) 25 | config.model.savepath = os.path.join(args.checkpoint_dir, config.model.savepath) 26 | config.train.wandbid = wandb.util.generate_id() 27 | 28 | wandb.login() 29 | wandb.init(id=config.train.wandbid, 30 | project=config.data.name, 31 | resume="allow", 32 | config=json.loads(config.to_json_best_effort())) 33 | os.environ["WANDB_RESUME"] = "allow" 34 | os.environ["WANDB_RUN_ID"] = config.train.wandbid 35 | 36 | optim = torch.optim.Adam(model.parameters(), lr=config.train.lr, betas=config.train.betas, 37 | eps=1e-8, weight_decay=config.train.wd) 38 | ema_ = ema.ExponentialMovingAverage(model.parameters(), decay=config.eval.ema) 39 | train(model, ema_, loss, data_gen, optim, config, device) 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser( 43 | description='' 44 | ) 45 | 46 | parser.add_argument( 47 | '--checkpoint_dir', 48 | type=str, 49 | help='path to save and look for the checkpoint file', 50 | default=os.getcwd() 51 | ) 52 | 53 | main(parser.parse_args()) 54 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 
13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] 99 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class ActionNet(nn.Module): 6 | def __init__(self, net): 7 | super(ActionNet, self).__init__() 8 | self.net = net 9 | 10 | def forward(self, t, x): 11 | h = self.net(t, x) 12 | out = 0.5*((x-h)**2).sum(dim=1, keepdim=True) 13 | return out 14 | 15 | def propagate(self, t, x, dt): 16 | x.requires_grad = True 17 | v = torch.autograd.grad(self(t,x).sum(), x)[0] 18 | x.data += dt*v 19 | x.requires_grad = False 20 | return x 21 | 22 | 23 | class ScoreNet(nn.Module): 24 | def __init__(self, net): 25 | super(ScoreNet, self).__init__() 26 | self.net = net 27 | 28 | def forward(self, t, x): 29 | return 1e-2*self.net(t, x) 30 | 31 | def propagate(self, t, x, dt, eps=15.0, n_steps=5): 32 | for _ in range(n_steps): 33 | x.data += 0.5*eps*self(t+dt, x) + math.sqrt(eps)*torch.randn_like(x) 34 | return x 35 | 36 | 37 | class MLP(nn.Module): 38 | def __init__(self, n_dims=4, n_out=3, n_hid=512, layer=nn.Linear, relu=False): 39 | super(MLP, self).__init__() 40 | self._built = False 41 | self.net = nn.Sequential( 42 | layer(n_dims, n_hid), 43 | nn.LeakyReLU(.2) if relu else nn.SiLU(n_hid), 44 | layer(n_hid, n_hid), 45 | nn.LeakyReLU(.2) if relu else nn.SiLU(n_hid), 46 | layer(n_hid, n_hid), 47 | nn.LeakyReLU(.2) if relu else nn.SiLU(n_hid), 48 | layer(n_hid, n_hid), 49 | nn.LeakyReLU(.2) if relu else nn.SiLU(n_hid), 50 | layer(n_hid, n_out) 51 | ) 52 | 53 | def forward(self, t, x): 54 | x = x.view(x.size(0), -1) 55 | t = t.view(t.size(0), 1) 56 | h = torch.hstack([t,x]) 57 | h = self.net(h) 58 | return h 59 | -------------------------------------------------------------------------------- /notebooks/annealed_Langevin.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "de3d2fa2", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "aefa42d3", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "import numpy as np\n", 22 | "\n", 23 | "import sys\n", 24 | "sys.path.append('../')\n", 25 | "from core.hmodel import *\n", 26 | "from core.losses import *\n", 27 | "from models.mlp import MLP\n", 28 | "from models import ema\n", 29 | "from utils.train_utils import *\n", 30 | "from utils.plot_utils import *\n", 31 | "from utils.eval_utils import *\n", 32 | "\n", 33 | "from ml_collections import config_dict\n", 34 | "from tqdm.auto import tqdm, trange" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "id": "ca4cf6cd", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def propagate(t,x,dt,eps=10.0,n_steps=5):\n", 45 | " x_t, gen_t = data_gen.q_t(t + dt, replace=False)\n", 46 | " gen_t = gen_t*data_gen.T\n", 47 | " psi_t = 
data_gen.psi.evolve_to(gen_t[0].to(device))\n", 48 | " for _ in range(n_steps):\n", 49 | " x.requires_grad = True\n", 50 | " nabla_logp = torch.autograd.grad(psi_t.log_prob(x.to(device)).sum(), x)[0]\n", 51 | " x.requires_grad = False\n", 52 | " x.data += 0.5*eps*nabla_logp + math.sqrt(eps)*torch.randn_like(x)\n", 53 | " return x\n", 54 | "\n", 55 | "def evaluate(data_gen, eps, n_steps, device, config):\n", 56 | " N = config.data.n_steps//4\n", 57 | " x = data_gen.samples[0].to(device)\n", 58 | " dt = 1./config.data.n_steps\n", 59 | " t = torch.zeros([x.shape[0],1], device=device)\n", 60 | " n_evals = 5\n", 61 | " eval_every = N//n_evals\n", 62 | " avg_mmd = 0.0\n", 63 | " mmd = MMDStatistic(config.data.batch_size, config.data.batch_size)\n", 64 | " for i in range(N):\n", 65 | " x = propagate(t, x, dt, eps, n_steps)\n", 66 | " t.data += dt\n", 67 | " if ((i+1) % eval_every) == 0:\n", 68 | " x_t, gen_t = data_gen.q_t(t, replace=False)\n", 69 | " gen_t = gen_t*data_gen.T\n", 70 | " cur_mmd = mmd(x, x_t, 1e-4*torch.ones(x.shape[1], device=device))\n", 71 | " avg_mmd += cur_mmd.abs().cpu().numpy()/n_evals\n", 72 | " return avg_mmd\n", 73 | "\n", 74 | "def plot_frame(x, x_gt, kde=False):\n", 75 | " fig, ax = plt.subplots(2,3, figsize=(16,10))\n", 76 | " if kde:\n", 77 | " plot_samples_kde(x_gt, axes=ax[0])\n", 78 | " plot_samples_kde(x, axes=ax[1])\n", 79 | " else:\n", 80 | " plot_samples(x_gt, bins=40, axes=ax[0])\n", 81 | " plot_samples(x, bins=40, axes=ax[1])\n", 82 | " for j in range(3):\n", 83 | " ax[0,j].set_title('training data', fontsize=15)\n", 84 | " ax[1,j].set_title(f'samples from model', fontsize=15)\n", 85 | " plt.draw()\n", 86 | " fig.tight_layout()\n", 87 | " return fig" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "id": "06063f4c", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "device = torch.device('cuda')\n", 98 | "model, data_gen, loss, config = prepare_hydrogen(device)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "id": "4da7ce14", 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "application/vnd.jupyter.widget-view+json": { 110 | "model_id": "48e9394110c64be0b06f735ce320ac69", 111 | "version_major": 2, 112 | "version_minor": 0 113 | }, 114 | "text/plain": [ 115 | " 0%| | 0/30 [00:00" 141 | ] 142 | }, 143 | "metadata": {}, 144 | "output_type": "display_data" 145 | } 146 | ], 147 | "source": [ 148 | "plt.plot(eps_space, mmd_plot)\n", 149 | "plt.grid()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "id": "d9a8f717", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "application/vnd.jupyter.widget-view+json": { 161 | "model_id": "285c745f5e8548c69cda83d59df6354a", 162 | "version_major": 2, 163 | "version_minor": 0 164 | }, 165 | "text/plain": [ 166 | " 0%| | 0/10 [00:00)}" 108 | ] 109 | }, 110 | "execution_count": 7, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "evaluate(model, ema_, data_gen, device, config)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 8, 122 | "id": "b7b7b1de", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def plot_frame(x, x_gt, kde=False):\n", 127 | " fig, ax = plt.subplots(2,3, figsize=(16,10))\n", 128 | " if kde:\n", 129 | " plot_samples_kde(x_gt, axes=ax[0])\n", 130 | " plot_samples_kde(x, axes=ax[1])\n", 131 | " else:\n", 132 | " plot_samples(x_gt, bins=40, 
axes=ax[0])\n", 133 | " plot_samples(x, bins=40, axes=ax[1])\n", 134 | " for j in range(3):\n", 135 | " ax[0,j].set_title('training data', fontsize=15)\n", 136 | " ax[1,j].set_title(f'samples from {config.model.method}', fontsize=15)\n", 137 | " plt.draw()\n", 138 | " fig.tight_layout()\n", 139 | " return fig" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 9, 145 | "id": "e0aebc6d", 146 | "metadata": { 147 | "scrolled": false 148 | }, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "application/vnd.jupyter.widget-view+json": { 153 | "model_id": "289a3829da92490e823a28a3fa536cce", 154 | "version_major": 2, 155 | "version_minor": 0 156 | }, 157 | "text/plain": [ 158 | " 0%| | 0/1000 [00:00" 169 | ] 170 | }, 171 | "metadata": {}, 172 | "output_type": "display_data" 173 | } 174 | ], 175 | "source": [ 176 | "x = data_gen.samples[0].to(device)\n", 177 | "t = torch.zeros([len(x),1]).to(device)\n", 178 | "dt = 1.0/config.data.n_steps\n", 179 | "for i in trange(config.data.n_steps):\n", 180 | " x = model.propagate(t, x, dt)\n", 181 | " t += dt\n", 182 | "\n", 183 | "plot_frame(x, data_gen.samples[-1]).show()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 10, 189 | "id": "356b669e", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "!mkdir gifs/ssm_results" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "id": "d54fa282", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "application/vnd.jupyter.widget-view+json": { 205 | "model_id": "e892eb913d934395a54f1d0224881d67", 206 | "version_major": 2, 207 | "version_minor": 0 208 | }, 209 | "text/plain": [ 210 | " 0%| | 0/1000 [00:00 0: 82 | for g in optim.param_groups: 83 | g['lr'] = config.train.lr * np.minimum(current_step / config.train.warmup, 1.0) 84 | if config.train.grad_clip >= 0: 85 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.train.grad_clip) 86 | optim.step() 87 | ema.update(model.parameters()) 88 | 89 | if (current_step % 50) == 0: 90 | logging_dict = { 91 | 'loss_' + config.model.method: loss_total.detach().cpu() 92 | } 93 | wandb.log(logging_dict, step=current_step) 94 | 95 | if ((current_step % config.train.regen_every) == 0) and current_step > 0: 96 | data_gen.gen_data() 97 | if ((current_step % config.train.eval_every) == 0): 98 | metric_dict = evaluate(model, ema, data_gen, device, config) 99 | wandb.log(metric_dict, step=current_step) 100 | if ((current_step % config.train.save_every) == 0): 101 | save(model, ema, optim, loss, config) 102 | config.train.current_step = current_step 103 | save(model, ema, optim, loss, config) 104 | metric_dict = evaluate(model, ema, data_gen, device, config) 105 | wandb.log(metric_dict, step=current_step) 106 | 107 | def save(model, ema, optim, loss, config): 108 | checkpoint_name = config.model.savepath + '_%d.cpt' % config.train.current_step 109 | config.model.checkpoints.append(checkpoint_name) 110 | torch.save({'model': model.state_dict(), 111 | 'ema': ema.state_dict(), 112 | 'optim': optim.state_dict()}, checkpoint_name) 113 | torch.save(config, config.model.savepath + '.config') 114 | 115 | def evaluate(model, ema, data_gen, device, config): 116 | ema.store(model.parameters()) 117 | ema.copy_to(model.parameters()) 118 | model.eval() 119 | ######## evaluation ######## 120 | metric_dict = {} 121 | x = data_gen.samples[0].to(device) 122 | dt = 1./config.data.n_steps 123 | t = torch.zeros([x.shape[0],1], device=device) 124 | n_evals = 10 125 | 
eval_every = config.data.n_steps//n_evals 126 | metric_dict['avg_mmd'] = 0.0 127 | metric_dict['score_loss'] = 0.0 128 | mmd = MMDStatistic(config.data.batch_size, config.data.batch_size) 129 | for i in range(config.data.n_steps): 130 | x = model.propagate(t, x, dt) 131 | t.data += dt 132 | if ((i+1) % eval_every) == 0: 133 | x_t, gen_t = data_gen.q_t(t, replace=False) 134 | gen_t = gen_t*data_gen.T 135 | cur_mmd = mmd(x, x_t, 1e-4*torch.ones(x.shape[1], device=device)) 136 | metric_dict['avg_mmd'] += cur_mmd.abs().cpu().numpy()/n_evals 137 | if config.model.method in {'sm', 'ssm'}: 138 | psi_t = data_gen.psi.evolve_to(gen_t[0].to(device)) 139 | x_t.requires_grad = True 140 | nabla_logp = torch.autograd.grad(psi_t.log_prob(x_t.to(device)).sum(), x_t)[0] 141 | x_t.requires_grad = False 142 | loss = 0.5*((nabla_logp - model(t, x_t))**2).sum(1) 143 | metric_dict['score_loss'] += loss.mean()/n_evals 144 | ############################ 145 | ema.restore(model.parameters()) 146 | return metric_dict 147 | --------------------------------------------------------------------------------