├── LICENSE ├── README.md ├── deno_score_match.py ├── fokker_plank.py ├── forward.py ├── langevin_dyn.py ├── probing_score_learning.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ayan Das 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anonymous code repo for ICLR Blog post track 2 | 3 | More details will be added later. 
"""deno_score_match.py -- animate denoising score matching on a 2-D mixture.

Every iteration draws a KDE of the data, overlays the score field currently
predicted by the small network inside ``EmpiricalDensity``, saves the frame
to ``figs/test_<i>.png`` and then takes one training step on freshly
perturbed data.
"""
import os

from tqdm import trange

import numpy as np
import torch as th
import matplotlib.pyplot as plt

from utils import (
    GDensity,
    GMixDensity,
    EmpiricalDensity
)

GRAN = 15          # grid points per axis for the quiver plot
SIGMA = 0.2        # standard deviation of the perturbation noise
N_STEPS = 200      # number of frames / training steps
LIMITS = [-4, 4]   # axis limits, shared by x and y
OUT_DIR = 'figs'


def main():
    """Run the training loop and write one PNG frame per step."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    g1 = GDensity([0., 0.], [
        [1., 0.9],
        [0.9, 1.]
    ])
    g2 = GDensity([0., 0.], [
        [1., -0.9],
        [-0.9, 1.]
    ])
    G = EmpiricalDensity(GMixDensity([g2, g1]).sample())

    fig = plt.figure(figsize=(7, 7))
    ax = plt.gca()

    # Regular grid of query points, flattened to (GRAN**2, 2) for the model.
    axis_ticks = np.linspace(LIMITS[0], LIMITS[1], GRAN)
    x_grid, y_grid = np.meshgrid(axis_ticks, axis_ticks)
    X = np.stack([x_grid.ravel(), y_grid.ravel()], axis=-1)

    for step in trange(N_STEPS):
        ax.cla()
        G.plot_density(ax, cmap='Reds', alpha=0.5)

        # Predicted score at every grid point, reshaped back onto the grid.
        S = G.score(X)
        ax.quiver(x_grid, y_grid,
                  S[:, 0].reshape(GRAN, GRAN), S[:, 1].reshape(GRAN, GRAN),
                  headwidth=4, headlength=4, alpha=0.3)

        ax.set_xlim(LIMITS)
        ax.set_ylim(LIMITS)
        fig.savefig(os.path.join(OUT_DIR, f'test_{step}.png'))

        # One denoising-score-matching step on a fresh perturbation.
        noise = np.random.randn(*G.data.shape)
        G.train_score(G.data + SIGMA * noise, noise, SIGMA)


if __name__ == '__main__':
    main()
"""fokker_plank.py -- two different initial densities under Langevin dynamics.

Left and right panels show KDEs of two sample sets (initialised from a
standard normal and from a uniform square) as they are repeatedly nudged with
a Langevin step driven by the score of the target mixture ``G``.  The middle
panel is a schematic: two markers moving along a vector field that points at
the origin, which is labelled ``q_data``.
"""
import os

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from utils import (
    GDensity,
    GMixDensity,
    Uniform,
    EmpiricalDensity
)

ITER = 50          # number of Langevin steps
N_FRAMES = 50      # upper bound on the number of frames written
MARKER_STEP = 5.e-2
XRANGE = [-3, 3]
YRANGE = [-3, 3]
OUT_DIR = 'figs'

g1 = GDensity([0., 0.], [
    [1., 0.9],
    [0.9, 1.]
])

g2 = GDensity([0., 0.], [
    [1., -0.9],
    [-0.9, 1.]
])

G = GMixDensity([g2, g1])


def stoch_proc_1(x, delta=0.01):
    """One Langevin step of size ``delta`` towards the target density ``G``."""
    return x + G.score(x) * delta + np.sqrt(2 * delta) * np.random.randn(*x.shape)


def toward_origin(x, y):
    """Return the components of the unit vector pointing from (x, y) to the origin."""
    r = np.sqrt(x ** 2 + y ** 2)
    return -x / r, -y / r


def main():
    """Simulate both sample sets and write one PNG per plotted step."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    N = GDensity([0., 0.], [
        [1., 0],
        [0., 1.]
    ])
    U = Uniform([-2, 2], [-2, 2])
    n = EmpiricalDensity(N.sample(N=1000))
    u = EmpiricalDensity(U.sample(N=1000))

    fig, ax = plt.subplots(1, 3, figsize=(14, 4))

    x_grid, y_grid = np.meshgrid(np.linspace(-1, 3, 10), np.linspace(-3, 1, 10))
    x_vf, y_vf = toward_origin(x_grid, y_grid)

    p_x, p_y = 1, -2
    q_x, q_y = 2, 0

    # ``ITER // N_FRAMES`` is 0 when ITER < N_FRAMES; clamp so ``%`` is defined.
    stride = max(1, ITER // N_FRAMES)

    for i in tqdm(range(ITER)):
        if i % stride == 0:
            for panel in ax:
                panel.cla()

            n.plot_density(ax=ax[0], cmap='Blues')

            ax[1].quiver(x_grid, y_grid, x_vf, y_vf)
            ax[1].scatter([p_x, ], [p_y, ])
            ax[1].scatter([q_x, ], [q_y, ], color='green')
            ax[1].scatter([0., ], [0., ], color='red')
            ax[1].text(0.3, 0.4, r'$q_{data}$', ha='center', va='center', color='red', fontsize=20)

            # Advance both schematic markers one step along the field.
            dp_x, dp_y = toward_origin(p_x, p_y)
            p_x, p_y = p_x + MARKER_STEP * dp_x, p_y + MARKER_STEP * dp_y
            dq_x, dq_y = toward_origin(q_x, q_y)
            q_x, q_y = q_x + MARKER_STEP * dq_x, q_y + MARKER_STEP * dq_y

            u.plot_density(ax=ax[2], cmap='Greens')

            ax[0].set_xlim(XRANGE); ax[0].set_ylim(YRANGE)
            ax[1].set_xlim([-1, 3]); ax[1].set_ylim([-3, 1])
            ax[2].set_xlim(XRANGE); ax[2].set_ylim(YRANGE)
            ax[0].set_title(r'$p_t\ |\ p_0 = \mathcal{N}(0, I)$', fontsize=20)
            ax[1].set_title(r'$p$ space', fontsize=20)
            ax[2].set_title(r'$p_t\ |\ p_0 = \mathcal{U}(-2, 2)$', fontsize=20)
            for panel in ax:
                panel.axis('off')

            fig.savefig(os.path.join(OUT_DIR, f'test_{i}.png'), bbox_inches='tight', pad_inches=0)

        n.nudge(stoch_proc_1, delta=5.e-3)
        u.nudge(stoch_proc_1, delta=5.e-3)


if __name__ == '__main__':
    main()
"""forward.py -- render the forward (noising) process by reversing Langevin frames.

Samples start from a standard normal and are nudged towards the mixture ``G``
with Langevin steps.  Frames are numbered ``ITER - i`` so that viewing them in
ascending order shows the density moving from the data distribution back to
noise.  The right panel is a schematic with a marker travelling along a
vector field that points at the origin (labelled ``q_data``).
"""
import os

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from utils import (
    GDensity,
    GMixDensity,
    Uniform,
    EmpiricalDensity
)

ITER = 50          # number of Langevin steps
N_FRAMES = 50      # upper bound on the number of frames written
MARKER_STEP = 5.e-2
XRANGE = [-3, 3]
YRANGE = [-3, 3]
OUT_DIR = 'figs'

g1 = GDensity([0., 0.], [
    [1., 0.9],
    [0.9, 1.]
])

g2 = GDensity([0., 0.], [
    [1., -0.9],
    [-0.9, 1.]
])

G = GMixDensity([g2, g1])


def stoch_proc_1(x, delta=0.01):
    """One Langevin step of size ``delta`` towards the target density ``G``."""
    return x + G.score(x) * delta + np.sqrt(2 * delta) * np.random.randn(*x.shape)


def toward_origin(x, y):
    """Return the components of the unit vector pointing from (x, y) to the origin."""
    r = np.sqrt(x ** 2 + y ** 2)
    return -x / r, -y / r


def main():
    """Simulate the sample set and write one PNG per plotted step."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    N = GDensity([0., 0.], [
        [1., 0],
        [0., 1.]
    ])
    n = EmpiricalDensity(N.sample(N=1000))

    fig, ax = plt.subplots(1, 2, figsize=(11, 4))

    x_grid, y_grid = np.meshgrid(np.linspace(-1, 3, 10), np.linspace(-3, 1, 10))
    x_vf, y_vf = toward_origin(x_grid, y_grid)

    p_x, p_y = 1, -2

    # ``ITER // N_FRAMES`` is 0 when ITER < N_FRAMES; clamp so ``%`` is defined.
    stride = max(1, ITER // N_FRAMES)

    for i in tqdm(range(ITER)):
        if i % stride == 0:
            for panel in ax:
                panel.cla()

            n.plot_density(ax=ax[0], cmap='Reds')

            ax[1].quiver(x_grid, y_grid, x_vf, y_vf)
            ax[1].scatter([p_x, ], [p_y, ], color='red')
            ax[1].scatter([1, ], [-2, ], color='blue')
            ax[1].text(0.3, 0.4, r'$q_{data}$', ha='center', va='center', color='red', fontsize=20)
            ax[1].text(1.3, -2.4, r'$\mathcal{N}(0, I)$', ha='center', va='center', color='blue', fontsize=20)

            # Advance the schematic marker one step along the field.
            dp_x, dp_y = toward_origin(p_x, p_y)
            p_x, p_y = p_x + MARKER_STEP * dp_x, p_y + MARKER_STEP * dp_y

            ax[0].set_xlim(XRANGE); ax[0].set_ylim(YRANGE)
            ax[1].set_xlim([-1, 3]); ax[1].set_ylim([-3, 1])
            ax[0].set_title(r'$q_t\ |\ q_0 = q_{data}$', fontsize=20)
            ax[1].set_title(r'$p$ space', fontsize=20)
            for panel in ax:
                panel.axis('off')

            # Reverse numbering: the last simulated state becomes frame 1.
            fig.savefig(os.path.join(OUT_DIR, f'test_{ITER - i}.png'), bbox_inches='tight', pad_inches=0)

        n.nudge(stoch_proc_1, delta=5.e-3)


if __name__ == '__main__':
    main()
"""langevin_dyn.py -- visualise Langevin dynamics towards a 2-D mixture.

Three panels per frame: recent sample trajectories, a KDE of the current
samples, and a KDE of the target mixture ``G``.
"""
import os

import numpy as np
import matplotlib.pyplot as plt

from utils import (
    GDensity,
    GMixDensity,
    EmpiricalDensity
)

ITER = 50          # number of Langevin steps
N_FRAMES = 50      # upper bound on the number of frames written
XRANGE = [-3, 3]
YRANGE = [-3, 3]
OUT_DIR = 'figs'

g1 = GDensity([0., 0.], [
    [1., 0.9],
    [0.9, 1.]
])

g2 = GDensity([0., 0.], [
    [1., -0.9],
    [-0.9, 1.]
])

G = GMixDensity([g2, g1])


def stoch_proc_1(x, delta=0.01):
    """One Langevin step of size ``delta`` towards the target density ``G``."""
    return x + G.score(x) * delta + np.sqrt(2 * delta) * np.random.randn(*x.shape)


def main():
    """Simulate the sample set and write one PNG per plotted step."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    N = GDensity([0., 0.], [
        [1., 0],
        [0., 1.]
    ])
    n = EmpiricalDensity(N.sample(N=1000))

    fig, ax = plt.subplots(1, 3, figsize=(14, 4))
    titles = ('$x_t$', '$p_t(x)$', '$q_{data}(x)$')

    # ``ITER // N_FRAMES`` is 0 when ITER < N_FRAMES; clamp so ``%`` is defined.
    stride = max(1, ITER // N_FRAMES)

    for i in range(ITER):
        if i % stride == 0:
            for panel in ax:
                panel.cla()

            n.plot_density(ax=ax[1], cmap='Blues')
            n.plot_traj(ax=ax[0], color='black')
            G.plot_density(ax=ax[2], cmap='Reds')

            for panel, title in zip(ax, titles):
                panel.set_xlim(XRANGE)
                panel.set_ylim(YRANGE)
                panel.set_title(title, fontsize=20)
                panel.axis('off')

            fig.savefig(os.path.join(OUT_DIR, f'test_{i}.png'), bbox_inches='tight', pad_inches=0)

        n.nudge(stoch_proc_1, delta=1.e-2)


if __name__ == '__main__':
    main()
"""probing_score_learning.py -- what a denoiser learns at a few probe points.

Clean data points are perturbed with Gaussian noise in every frame.  Whenever
a noisy sample lands inside the ball around a probe point, the corresponding
clean point is folded into that probe's running mean, which is drawn as the
estimate of E[x | x~].
"""
import os
from dataclasses import dataclass

from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt

SIGMA = 1
BALL_RADIUS = 0.4
N_FRAMES = 200
OUT_DIR = 'figs'

PROBE_ARROW = dict(head_width=0.07, head_length=0.07, length_includes_head=True,
                   linewidth=2, color='gray', alpha=0.5)
SAMPLE_ARROW = dict(head_width=0.05, head_length=0.05, length_includes_head=True,
                    linewidth=1, alpha=0.3, linestyle='--')


@dataclass
class Probe:
    """A probe point x~ and its running estimate of E[x | x~]."""

    index: int
    point: np.ndarray
    mean: np.ndarray      # starts at a hand-picked guess
    color: str
    radius: float
    count: int = 1        # the initial guess is weighted as one observation

    @property
    def point_label(self):
        return rf'$\tilde{{x}}_{self.index}$'

    @property
    def mean_label(self):
        return rf'$\mathbb{{E}}_{{x|\tilde{{x}}_{self.index}}}[x]$'

    def absorb(self, clean_datum):
        """Fold one clean point into the running mean."""
        self.mean = (self.mean * self.count + clean_datum) / (self.count + 1)
        self.count += 1


def make_data():
    """Return the ten clean points (two mirrored strands), y stretched by 3."""
    data = np.array([
        [1.0, 0.1],
        [1.25, 0.15],
        [1.5, 0.2],
        [1.75, 0.25],
        [2.0, 0.3],

        [-1.0, 0.1],
        [-1.25, 0.15],
        [-1.5, 0.2],
        [-1.75, 0.25],
        [-2.0, 0.3],
    ])
    data[:, 1] *= 3
    return data


def main():
    """Render all frames to ``figs/test_<i>.png``."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    data = make_data()
    probes = [
        Probe(1, np.array([-1., 1.]), np.array([-3., 3.]), 'blue', BALL_RADIUS),
        Probe(2, np.array([1., 1.]), np.array([2., 2.]), 'magenta', BALL_RADIUS),
        # The third probe sits further from the data and uses a wider ball.
        Probe(3, np.array([0., 2.]), np.array([-1., 2.]), 'green', BALL_RADIUS * 2),
    ]

    fig = plt.figure(figsize=(8, 8), dpi=600)
    ax = plt.gca()

    for i in trange(N_FRAMES):
        ax.scatter(data[:, 0], data[:, 1], color='red', marker='o')

        for probe in probes:
            ax.scatter(*probe.point, color=probe.color, marker='x')
        for probe in probes:
            ax.text(*(probe.point + 0.1), probe.point_label, color=probe.color, fontsize=15)

        noisy_data = data + SIGMA * np.random.randn(*data.shape)

        # Means are drawn before this frame's samples are absorbed.
        for probe in probes:
            ax.scatter(*probe.mean, color=probe.color, marker='X')
        for probe in probes:
            ax.text(*(probe.mean + 0.05), probe.mean_label, color=probe.color, fontsize=15)
        for probe in probes:
            # Half-length arrow from the probe towards its current mean.
            ax.arrow(*probe.point, *((probe.mean - probe.point) * 0.5), **PROBE_ARROW)

        ax.scatter(noisy_data[:, 0], noisy_data[:, 1], color='red', marker='.', alpha=0.4)
        for clean_datum, noisy_datum in zip(data, noisy_data):
            for probe in probes:
                if np.linalg.norm(noisy_datum - probe.point) < probe.radius:
                    ax.arrow(*noisy_datum, *(clean_datum - noisy_datum),
                             color=probe.color, **SAMPLE_ARROW)
                    probe.absorb(clean_datum)

        ax.set_xlim([-2.5, 2.5])
        ax.set_ylim([-1, 2.5])
        ax.axis('off')

        fig.savefig(os.path.join(OUT_DIR, f'test_{i}.png'), bbox_inches='tight', pad_inches=0)
        ax.cla()


if __name__ == '__main__':
    main()
"""utils.py -- toy 2-D densities shared by the score-matching demo scripts."""
import queue
import typing as ty

import numpy as np
import torch as th
from scipy.stats import multivariate_normal

# seaborn is imported lazily inside EmpiricalDensity.plot_density so that the
# sampling and score utilities stay importable where plotting is unavailable.

# Floor applied to log-densities before exponentiation (numerical stability).
_MIN_LOG_PDF = -20.0


def shuffle_along_axis(a, axis):
    """Return a copy of ``a`` with entries independently permuted along ``axis``."""
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a, idx, axis=axis)


class GDensity:
    """Multivariate Gaussian with closed-form likelihood and score."""

    def __init__(self, mean, cov) -> None:
        super().__init__()
        self.mean, self.cov = np.array(mean), np.array(cov)
        self.dist = multivariate_normal(self.mean, self.cov)
        self.data = self.sample()

    def sample(self, N=1000):
        """Draw ``N`` samples as an array of shape ``(N, dim)``.

        scipy squeezes the result for ``N == 1``; the reshape keeps the
        leading sample axis so results can always be concatenated.
        """
        return np.reshape(self.dist.rvs(size=N), (N, self.mean.size))

    def likelihood(self, x):
        """Return the density at ``x`` as a column of shape ``(N, 1)``.

        A single point (shape ``(dim,)`` or ``(1, dim)``) yields shape
        ``(1, 1)``.  Log-densities are floored at ``_MIN_LOG_PDF`` so the
        result is never exactly zero.
        """
        # scipy returns a scalar for a single point; restore the sample axis.
        lpdf = np.atleast_1d(self.dist.logpdf(x))[:, None]
        return np.exp(np.maximum(lpdf, _MIN_LOG_PDF))

    def score(self, x, normalize=False):
        """Return the score ``-cov^{-1} (x - mean)`` at ``x``.

        Args:
            x: one point of shape ``(dim,)`` or a batch of shape ``(N, dim)``.
            normalize: when true, scale every score vector to unit Euclidean
                length (all-zero vectors are left as zeros).

        Returns:
            The scores with singleton axes squeezed out, i.e. ``(N, dim)``
            for a batch and ``(dim,)`` for a single point.
        """
        x = np.atleast_2d(np.asarray(x))

        score = -np.einsum('ab,Nb->Na', np.linalg.inv(self.cov), x - self.mean[None, :])
        if normalize:
            # ord=2 is the Euclidean length; ord=0 would only count the
            # non-zero entries of each row.
            norm = np.linalg.norm(score, ord=2, axis=1, keepdims=True)
            score = score / np.where(norm > 0, norm, 1.0)
        return score.squeeze()

    def plot_density(self, ax, cmap='Reds', **kwargs):
        """Draw a KDE of ``self.data`` on ``ax``."""
        EmpiricalDensity.plot_density(self, ax, cmap=cmap, **kwargs)


class Uniform:
    """Uniform density on an axis-aligned rectangle."""

    def __init__(self, xrange, yrange) -> None:
        self.xlow, self.xhigh = xrange
        self.xrange = self.xhigh - self.xlow
        self.ylow, self.yhigh = yrange
        self.yrange = self.yhigh - self.ylow
        self.data = self.sample()

    def sample(self, N=1000):
        """Draw ``N`` points as an array of shape ``(N, 2)``."""
        x = np.random.rand(N, 1) * self.xrange + self.xlow
        y = np.random.rand(N, 1) * self.yrange + self.ylow
        return np.concatenate([x, y], axis=-1)


class GMixDensity:
    """Equal-weight mixture of ``GDensity`` components."""

    def __init__(self, components: ty.List[GDensity]) -> None:
        if not components:
            raise ValueError("a mixture needs at least one component")
        self.components = components
        self.M = len(self.components)
        self.data = self.sample()

    def sample(self, N=1000):
        """Draw ``N`` samples, split evenly across the components.

        The remainder of ``N / M`` goes to the last component.  Samples are
        returned grouped by component, not shuffled.
        """
        counts = [N // self.M] * self.M
        counts[-1] += N - sum(counts)
        return np.concatenate(
            [comp.sample(n) for comp, n in zip(self.components, counts)], 0
        )

    def score(self, x, normalize=False):
        """Return the likelihood-weighted average of the component scores."""
        numerator = sum(
            [comp.score(x, normalize=normalize) * comp.likelihood(x) for comp in self.components]
        )
        denominator = sum([comp.likelihood(x) for comp in self.components])
        return numerator / denominator

    def plot_density(self, ax, cmap='Reds', **kwargs):
        """Draw a KDE of ``self.data`` on ``ax``."""
        EmpiricalDensity.plot_density(self, ax, cmap=cmap, **kwargs)


class EmpiricalDensity:
    """A cloud of 2-D samples with a short history and a learnable score model."""

    N_TRAJ = 70    # maximum number of sample trajectories drawn by plot_traj
    TRAJ_LEN = 7   # number of past snapshots kept

    def __init__(self, data) -> None:
        if np.ndim(data) != 2 or np.shape(data)[1] != 2:
            raise ValueError("data array not properly sized")
        # Bounded FIFO of snapshots; the newest snapshot is the current data.
        self._data = queue.Queue(EmpiricalDensity.TRAJ_LEN)
        self._data.put(data)

        self.score_model = th.nn.Sequential(
            th.nn.Linear(2, 12),
            th.nn.Tanh(),
            th.nn.Linear(12, 2)
        ).double()

        self.opt = th.optim.Adam(self.score_model.parameters(), lr=5e-3)

    @property
    def data(self):
        """The most recent snapshot of the samples."""
        return self._data.queue[-1]

    @data.setter
    def data(self, d):
        # Drop the oldest snapshot first so that put() never blocks.
        if self._data.full():
            self._data.get()

        self._data.put(d)

    def plot_density(self, ax, cmap='Reds', **kwargs):
        """Draw a filled KDE of the current samples on ``ax``."""
        import seaborn as sns  # lazy: only needed for plotting

        sns.kdeplot(x=self.data[:, 0], y=self.data[:, 1], ax=ax, cmap=cmap, fill=True, **kwargs)

    def plot_traj(self, ax, color='black'):
        """Draw the stored history of up to ``N_TRAJ`` samples plus their latest position."""
        history = list(self._data.queue)
        if len(history) > 1:
            # (n_samples, n_snapshots, 2); slicing copes with small clouds.
            traj = np.stack(history, 1)[:EmpiricalDensity.N_TRAJ]
            for path in traj:
                ax.plot(path[:, 0], path[:, 1], color=color, linewidth=0.5, alpha=0.2)

        latest = history[-1][:EmpiricalDensity.N_TRAJ, ...]
        ax.scatter(latest[:, 0], latest[:, 1], color=color, s=2)

    def nudge(self, fn, delta=0.01):
        """Replace the samples by ``fn(samples, delta)`` and record the snapshot."""
        self.data = fn(self.data, delta)

    def score(self, x, normalize=False):
        """Evaluate the score model at ``x``; returns a NumPy array.

        ``normalize`` is accepted for signature parity with the analytic
        densities and is ignored.
        """
        # The model runs in double precision; cast so any real dtype works.
        points = th.from_numpy(np.asarray(x, dtype=np.float64))
        with th.no_grad():
            return self.score_model(points).numpy()

    def train_score(self, noisy_data, target, sigma=0.1):
        """Take one Adam step of denoising score matching.

        Args:
            noisy_data: perturbed samples ``x + sigma * target``.
            target: the unit-variance noise that was added.
            sigma: noise scale; the regression target is ``-target / sigma``.

        Returns:
            The scalar loss before the update, as a Python float.
        """
        noisy = th.from_numpy(np.asarray(noisy_data, dtype=np.float64))
        eps = th.from_numpy(np.asarray(target, dtype=np.float64))
        residual = self.score_model(noisy) - (-eps / sigma)
        loss = (residual ** 2).sum(dim=-1).mean()
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()