├── Gifs
│   ├── a.gif
│   ├── b.gif
│   ├── c.gif
│   └── old
│       ├── a_old.gif
│       ├── b_old.gif
│       └── c_old.gif
├── Learning to Draw samples with application to amortized MLE for generative adversarial learning.pdf
├── README.md
├── .gitignore
└── Code
    ├── Stein_GAN_torch.py
    ├── main_Stein_GAN.py
    ├── networks.py
    └── utilities.py

/Gifs/a.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/a.gif
--------------------------------------------------------------------------------
/Gifs/b.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/b.gif
--------------------------------------------------------------------------------
/Gifs/c.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/c.gif
--------------------------------------------------------------------------------
/Gifs/old/a_old.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/old/a_old.gif
--------------------------------------------------------------------------------
/Gifs/old/b_old.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/old/b_old.gif
--------------------------------------------------------------------------------
/Gifs/old/c_old.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Gifs/old/c_old.gif
--------------------------------------------------------------------------------
/Learning to Draw samples with application to amortized MLE for generative adversarial learning.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mokeddembillel/Amortized-SVGD-GAN/HEAD/Learning to Draw samples with application to amortized MLE for generative adversarial learning.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementation of Amortized Stein Variational Gradient Descent

Based on the paper: [Learning to draw samples: with application to amortized maximum likelihood estimator for generative adversarial learning](https://arxiv.org/abs/1611.01722)

## Here are some examples:

![Gaussian.gif](Gifs/b.gif)

![Gaussian2.gif](Gifs/a.gif)

![Gaussian3.gif](Gifs/c.gif)

### If anyone is interested in contributing or collaborating, I would be very glad for that.
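For anyone who wants to dig into the code before contributing: below is a condensed, illustrative sketch of what `learn_G` in `Code/Stein_GAN_torch.py` does (Equation 7 of the paper). Generated particles are pushed along the Stein variational gradient of the discriminator's energy, and that direction is back-propagated into the generator. The names `g_net`, `d_net` and `rbf_kernel` refer to the modules in `Code/`; the batch size and noise shape here are illustrative, so treat this as a sketch of the update rather than the exact training loop.

```python
import torch as T
import torch.autograd as autograd

def amortized_svgd_step(g_net, d_net, rbf_kernel, batch_size=64):
    zeta = T.rand(batch_size, 2)                        # noise input
    f_x = g_net(zeta)                                   # generated particles
    energy = d_net(f_x)                                 # discriminator energy (low = real-looking)
    grad_score = autograd.grad(-energy.sum(), f_x)[0]   # score of the implicit target density
    kappa, grad_kappa = rbf_kernel(f_x, f_x)            # RBF kernel matrix and its gradient
    svgd = (kappa.matmul(grad_score) + grad_kappa) / f_x.size(0)
    g_net.optimizer.zero_grad()
    autograd.backward(-f_x, grad_tensors=svgd)          # chain rule: move particles along `svgd`
    g_net.optimizer.step()
```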
Thank you
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/Stein_GAN-checkpoint.ipynb
*.png
archive/Stein_GAN works 3.py
archive/Stein.ipynb
archive/stein.py
archive/Stein_GAN - worked once.py
archive/Stein_GAN 2 gen.py
archive/Stein_GAN E10.py
archive/Stein_GAN works 1.py
archive/Stein_GAN Works 2.py
archive/Stein_GAN.ipynb
archive/Stein_GAN_O_Edited.py
archive/Stein_GAN_original.py
.ipynb_checkpoints/Stein-checkpoint.ipynb
archive/Stein_GAN_completely_working.py
Stein_GAN_Conditional - Copy.py
plot_heatmap.py
Stein_GAN_Conditional.py
*.ipynb
*.pyc
Stein_GAN copy.py
--------------------------------------------------------------------------------
/Code/Stein_GAN_torch.py:
--------------------------------------------------------------------------------
import numpy as np
import seaborn as sns
sns.set()

import torch as T
import torch.autograd as autograd


print(T.__version__)
device = T.device('cuda' if T.cuda.is_available() else 'cpu')


def rbf_kernel(X, Y, h_min=1e-3):
    XX = X.matmul(X.t())
    XY = X.matmul(Y.t())
    YY = Y.matmul(Y.t())

    dnorm2 = XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0) - 2 * XY

    # Apply the median heuristic (PyTorch does not give true median)
    np_dnorm2 = dnorm2.detach().cpu().numpy()
    h = np.median(np_dnorm2) / (2 * np.log(X.size(0) + 1))
    sigma = np.sqrt(h).item()

    gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
    K_XY = (-gamma * dnorm2).exp()

    grad_K = -autograd.grad(K_XY.mean(), X)[0]

    return K_XY, grad_K


def learn_G(P, g_net, d_net, batch_size=10):
    # Draw zeta random samples
    zeta = T.FloatTensor(batch_size, 2).uniform_(0, 1)
    # Forward the noise through the network
    f_x = g_net.forward(zeta)

    ### Equation 7 (compute Delta xi)
    # Get the energy using the discriminator
    score = d_net.forward(f_x)

    # Get the gradients of the energy with respect to x and y
    grad_score = autograd.grad(-score.sum(), f_x)[0].squeeze(-1)
    # grad_score = autograd.grad(P.log_prob(f_x).sum(), f_x)[0].squeeze(-1)

    # Compute the similarity using the RBF kernel
    kappa, grad_kappa = rbf_kernel(f_x, f_x)

    # Compute the SVGD update direction
    svgd = (T.matmul(kappa.squeeze(-1), grad_score) + grad_kappa) / f_x.size(0)

    # Update the network
    g_net.optimizer.zero_grad()
    autograd.backward(-f_x, grad_tensors=svgd)
    g_net.optimizer.step()


def learn_D(g_net, d_net, x_obs, batch_size=10):
    # Draw zeta random samples
    zeta = T.FloatTensor(batch_size, 2).uniform_(0, 1)
    # Forward the noise through the network
    x_obs.requires_grad_(True)
    f_x = g_net.forward(zeta).requires_grad_(True)
    # Get the energy of the observed data using the discriminator
    data_score = d_net.forward(x_obs)
    # Get the energy of the generated data using the discriminator
    gen_score = d_net.forward(f_x)
    print("Data Score : ", data_score.mean().detach().numpy(),
          "\nGen Score : ", gen_score.mean().detach().numpy())

    # Calculate the gradient-penalty (GP) loss
    grad_r = autograd.grad(data_score.sum(), x_obs,
                           allow_unused=True,
                           create_graph=True,
                           retain_graph=True)[0]
    grad_f = autograd.grad(gen_score.sum(), f_x,
                           allow_unused=True,
                           create_graph=True,
                           retain_graph=True)[0]

    loss_gp = T.mean(grad_r.norm(dim=1, p=2) ** 2) + T.mean(grad_f.norm(dim=1, p=2) ** 2)

    loss = data_score.mean() - gen_score.mean() + 10 * loss_gp

    # loss = data_score.mean() - gen_score.mean()
    # print("Loss : ", loss.detach().numpy())

    # Update the network
    d_net.optimizer.zero_grad()
    autograd.backward(loss)
    d_net.optimizer.step()

    return data_score.detach().numpy(), gen_score.detach().numpy(), loss.detach().numpy()
--------------------------------------------------------------------------------
/Code/main_Stein_GAN.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import torch as T

from utilities import generate_dist_1, generate_dist_2, generate_dist_3
from networks import G, D
from Stein_GAN_torch import learn_G, learn_D

print(T.__version__)
device = T.device('cuda' if T.cuda.is_available() else 'cpu')

ITER_NUM = int(6e4)
BATCH_SIZE = 64
IMAGE_SHOW = 5e+2

# x, y, mog2 = generate_dist_1()
x, y, mog2 = generate_dist_2()
# x, y, mog2 = generate_dist_3()


plt.rcParams["figure.figsize"] = (20, 10)
fig, ax = plt.subplots()
plt.xlim(-30, 30)
plt.ylim(-20, 20)
zeta = T.FloatTensor(3000, 2).normal_(0, 3)
ax.scatter(x, y, s=1, color='blue')
ax.scatter(zeta[:, 0].numpy(), zeta[:, 1].numpy(), s=1, color='red')
ax.set_title('Iter:' + str(0) + ' alpha:')
plt.show()


def train(alpha=1.0):
    data_score, gen_score, loss = [], [], []
    for i in range(ITER_NUM):
        # print('Iteration :', i)
        # Sample a minibatch of observed data
        index = np.random.choice(range(len(x)), size=BATCH_SIZE, replace=False)
        mini_x = x[index]
        mini_y = y[index]
        x_obs = []
        for j in range(len(mini_x)):
            x_obs.append([mini_x[j], mini_y[j]])
        x_obs = T.from_numpy(np.array(x_obs, dtype=np.float32)).float()

        if i % 1 == 0:
            # Learn the discriminator
            a, b, c = learn_D(g_net, d_net, x_obs, batch_size=BATCH_SIZE)
            data_score.append(a[0])
            gen_score.append(b[0])
            loss.append(c)
            # print(a)

        if i % 20 == 0:
            # Learn the generator (amortized SVGD step)
            learn_G(mog2, g_net, d_net, batch_size=BATCH_SIZE)

        # print(i)
        if (i + 1) % IMAGE_SHOW == 0:
            plt.rcParams["figure.figsize"] = (20, 10)
            fig, ax = plt.subplots()
            plt.xlim(-30, 30)
            plt.ylim(-20, 20)

            zeta = T.FloatTensor(1000, 2).normal_(0, 20)

            ax.scatter(x, y, s=1, color='blue')

            predict = g_net.forward(zeta.cpu()).detach().cpu().squeeze(-1)

            ax.scatter(predict[:, 0].numpy(), predict[:, 1].numpy(), s=1, color='red')

            ax.set_title('Iter:' + str(i + 1) + ' alpha:' + str(alpha))
            plt.show()

    def moving_average(a, n=10):
        ret = np.cumsum(a, dtype=float)
        ret[n:] = ret[n:] - ret[:-n]
        return ret[n - 1:] / n

    data_score = moving_average(data_score, 5)
    gen_score = moving_average(gen_score, 5)
    loss = moving_average(loss, 5)

    plt.plot(data_score)
    plt.ylabel('Data score')
    plt.show()

    plt.plot(gen_score)
    plt.ylabel('Gen score')
    plt.show()

    plt.plot(loss)
    plt.ylabel('Loss')
    plt.show()


g_net = G().cpu()
d_net = D().cpu()
train(1.)

# Plot the learned energy along the x-axis (y = 0)
x = np.arange(-20, 20, 0.5)
y = np.zeros(len(x))
particles = []
for i in range(len(x)):
    particles.append([x[i], y[i]])
particles = T.tensor(np.array(particles, dtype=np.float32))

energy = d_net.forward(particles).detach().numpy()
# energy = mog2.log_prob(particles).detach().numpy()
# P.log_prob(f_x)

# print(energy)

plt.plot(energy)
plt.ylabel('energy')
plt.show()
--------------------------------------------------------------------------------
/Code/networks.py:
--------------------------------------------------------------------------------
import seaborn as sns
sns.set()

import torch as T
import torch.nn as nn
import torch.nn.functional as F


class G(nn.Module):
    def __init__(self, lr=1e-4, input_dim=2):
        super().__init__()

        # Initialize input dimensions
        # (fc3_dim and fc4_dim are defined but unused; every hidden layer is fc2_dim wide)
        self.fc1_dim = 2
        self.fc2_dim = 128
        self.fc3_dim = 50
        self.fc4_dim = 50
        self.fc5_dim = 2

        self.fc1 = nn.Linear(self.fc1_dim, self.fc2_dim)
        self.fc2 = nn.Linear(self.fc2_dim, self.fc2_dim)
        self.fc3 = nn.Linear(self.fc2_dim, self.fc2_dim)
        self.fc4 = nn.Linear(self.fc2_dim, self.fc2_dim)
        self.fc5 = nn.Linear(self.fc2_dim, self.fc5_dim)

        self.bn1 = nn.BatchNorm1d(self.fc2_dim)
        self.bn2 = nn.BatchNorm1d(self.fc2_dim)
        self.bn3 = nn.BatchNorm1d(self.fc2_dim)
        self.bn4 = nn.BatchNorm1d(self.fc2_dim)

        # Initialize layer weights
        self.fc1.weight.data.normal_(0, 1)
        self.fc2.weight.data.normal_(0, 1)
        self.fc3.weight.data.normal_(0, 1)
        self.fc4.weight.data.normal_(0, 1)
        self.fc5.weight.data.normal_(0, 1)

        # Initialize layer biases
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.constant_(self.fc2.bias, 0)
        nn.init.constant_(self.fc3.bias, 0)
        nn.init.constant_(self.fc4.bias, 0)
        nn.init.constant_(self.fc5.bias, 0)

        self.lr = lr
        # Define optimizer
        self.optimizer = T.optim.Adam(self.parameters(), lr=self.lr, betas=(0.0, 0.99))

        # Set device
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, input, num_particle=5):
        X = self.fc1(input)
        X = T.tanh(self.bn1(X))
        X = self.fc2(X)
        X = T.tanh(self.bn2(X))
        X = self.fc3(X)
        X = T.tanh(self.bn3(X))
        X = self.fc4(X)
        X = T.tanh(self.bn4(X))
        X = self.fc5(X)
        return X


class D(nn.Module):
    def __init__(self, lr=1e-4, input_dim=2):
        super().__init__()

        # Initialize input dimensions
        self.fc1_dim = 2
        self.fc2_dim = 128
        self.fc3_dim = 1

        # self.fc1 = nn.Linear(self.fc1_dim, self.fc3_dim)

        # Define the NN layers
        self.fc1 = nn.Linear(self.fc1_dim, self.fc2_dim)
        self.fc2 = nn.Linear(self.fc2_dim, self.fc2_dim)
        self.fc3 = nn.Linear(self.fc2_dim, self.fc3_dim)

        # self.bn1 = nn.BatchNorm1d(self.fc2_dim)
        # self.bn2 = nn.BatchNorm1d(self.fc2_dim)

        # Initialize layer weights
        self.fc1.weight.data.normal_(0, 0.2)
        self.fc2.weight.data.normal_(0, 0.2)
        self.fc3.weight.data.normal_(0, 0.2)

        # Initialize layer biases
        nn.init.constant_(self.fc1.bias, 0.0)
        nn.init.constant_(self.fc2.bias, 0.0)
        nn.init.constant_(self.fc3.bias, 0.0)

        self.lr = lr
        # Define optimizer
        self.optimizer = T.optim.Adam(self.parameters(), lr=self.lr, betas=(0.0, 0.99))

        # Set device
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, input):
        X = F.relu(self.fc1(input))
        X = F.relu(self.fc2(X))
        X = self.fc3(X)
        # Return an energy value: low for points the discriminator considers real
        return -T.log(T.sigmoid(X) + .05)
--------------------------------------------------------------------------------
/Code/utilities.py:
--------------------------------------------------------------------------------
import torch as T

import altair as alt
import numpy as np

import seaborn as sns
import pandas as pd
sns.set()
import torch.optim as optim


device = T.device('cuda' if T.cuda.is_available() else 'cpu')


class MoG(T.distributions.Distribution):
    def __init__(self, loc, covariance_matrix):
        self.num_components = loc.size(0)
        self.loc = loc
        self.covariance_matrix = covariance_matrix

        self.dists = [
            T.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
            for mu, sigma in zip(loc, covariance_matrix)
        ]

        super(MoG, self).__init__(T.Size([]), T.Size([loc.size(-1)]))

    @property
    def arg_constraints(self):
        return self.dists[0].arg_constraints

    @property
    def support(self):
        return self.dists[0].support

    @property
    def has_rsample(self):
        return False

    def log_prob(self, value):
        return T.cat(
            [p.log_prob(value).unsqueeze(-1) for p in self.dists], dim=-1).logsumexp(dim=-1)

    def enumerate_support(self):
        return self.dists[0].enumerate_support()


class MoG2(MoG):
    def __init__(self, device=None):
        loc = T.Tensor([[-10.0, 0.0], [10.0, 0.0]]).to(device)
        cov = T.Tensor([0.5, 0.5]).diag().unsqueeze(0).repeat(2, 1, 1).to(device)
        # print(cov.T)
        super(MoG2, self).__init__(loc, cov)


class RBF(T.nn.Module):
    def __init__(self, sigma=None):
        super(RBF, self).__init__()

        self.sigma = sigma

    def forward(self, X, Y):
        XX = X.matmul(X.t())
        XY = X.matmul(Y.t())
        YY = Y.matmul(Y.t())

        dnorm2 = -2 * XY + XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0)

        # Apply the median heuristic (PyTorch does not give true median)
        if self.sigma is None:
            np_dnorm2 = dnorm2.detach().cpu().numpy()
            h = np.median(np_dnorm2) / (2 * np.log(X.size(0) + 1))
            sigma = np.sqrt(h).item()
        else:
            sigma = self.sigma

        gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
        K_XY = (-gamma * dnorm2).exp()

        return K_XY


# Let us initialize a reusable instance right away.
K = RBF()


class SVGD:
    def __init__(self, P, K, optimizer):
        self.P = P
        self.K = K
        self.optim = optimizer

    def phi(self, X):
        X = X.detach().requires_grad_(True)

        log_prob = self.P.log_prob(X)
        score_func = T.autograd.grad(log_prob.sum(), X)[0]

        K_XX = self.K(X, X.detach())
        grad_K = -T.autograd.grad(K_XX.sum(), X)[0]

        phi = (K_XX.detach().matmul(score_func) + grad_K) / X.size(0)

        return phi

    def step(self, X):
        self.optim.zero_grad()
        X.grad = -self.phi(X)
        self.optim.step()


def get_density_chart(P, d=7.0, step=0.1):
    xv, yv = T.meshgrid([
        T.arange(-d, d, step),
        T.arange(-d, d, step)
    ])
    pos_xy = T.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), dim=-1)
    p_xy = P.log_prob(pos_xy.to(device)).exp().unsqueeze(-1).cpu()

    df = T.cat([pos_xy, p_xy], dim=-1).numpy()
    df = pd.DataFrame({
        'x': df[:, :, 0].ravel(),
        'y': df[:, :, 1].ravel(),
        'p': df[:, :, 2].ravel(),
    })

    chart = alt.Chart(df).mark_point().encode(
        x='x:Q',
        y='y:Q',
        color=alt.Color('p:Q', scale=alt.Scale(scheme='viridis')),
        tooltip=['x', 'y', 'p']
    )

    return chart


def get_particles_chart(X):
    df = pd.DataFrame({
        'x': X[:, 0],
        'y': X[:, 1],
    })

    chart = alt.Chart(df).mark_circle(color='red').encode(
        x='x:Q',
        y='y:Q'
    )

    return chart


def generate_dist_1():
    f = lambda x: x ** 2

    x = np.linspace(-3, 3, 300)
    # y = f(x) + np.random.uniform(low=-1.0, high=1.0, size=x.shape) * 4 * np.cos(x/2 + .5)
    y = f(x) + np.random.normal(0, 1, size=x.shape) + np.random.normal(1, 1, size=x.shape)

    y_ = -y
    x = np.concatenate((x, x), axis=0)
    y = np.concatenate((y, y_), axis=0)
    return x, y, None


def generate_dist_2():
    mog2 = MoG2(device=device)

    n = 300
    X_init = (10 * T.randn(n, *mog2.event_shape)).to(device)

    # mog2_chart = get_density_chart(mog2, d=7.0, step=0.1)

    X = X_init.clone()
    svgd = SVGD(mog2, K, T.optim.Adam([X], lr=1e-1))
    for _ in range(1500):
        svgd.step(X)

    x = X[:, 0].reshape(-1, 1)
    y = X[:, 1].reshape(-1, 1)
    return x, y, mog2


def generate_dist_3():
    gauss = T.distributions.MultivariateNormal(
        T.Tensor([10, 10]).to(device),
        covariance_matrix=5 * T.Tensor([[0.3, 0], [0, 0.3]]).to(device))
    n = 300
    X_init = (3 * T.randn(n, *gauss.event_shape)).to(device)
    X = X_init.clone()
    svgd = SVGD(gauss, K, optim.Adam([X], lr=1e-1))
    for _ in range(1000):
        svgd.step(X)
    x = X[:, 0].reshape(-1, 1)
    y = X[:, 1].reshape(-1, 1)
    return x, y, None
--------------------------------------------------------------------------------
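A closing usage note (not part of the original repository files): once the generator has been trained by the loop in `Code/main_Stein_GAN.py`, drawing fresh samples from the amortized sampler is a single forward pass. The following is a minimal sketch under that assumption; the untrained `G()` is only a placeholder so the snippet runs, and the noise scale mirrors the one used in the plotting code rather than a fixed interface.

```python
import torch as T
from networks import G

# Placeholder: in practice g_net would already be trained by the loop in
# Code/main_Stein_GAN.py; an untrained G() is used here only so the sketch runs.
g_net = G().cpu()

g_net.eval()                                       # freeze BatchNorm statistics for sampling
with T.no_grad():
    zeta = T.FloatTensor(1000, 2).normal_(0, 20)   # same noise scale as the plotting code
    samples = g_net(zeta)                          # (1000, 2) generated particles

print(samples.shape)
```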