├── BIGAN.py
├── README.md
├── main.py
├── models.py
├── plot_utils.py
├── representation_plot.py
└── utils.py

/BIGAN.py:
--------------------------------------------------------------------------------
import numpy as np

import torch, time, os, pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from itertools import chain
from functools import reduce
import math

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import utils
from plot_utils import save_plot_losses, save_plot_pixel_norm, save_plot_z_norm
from models import Generator_FC, Discriminator_FC, Encoder_FC, Generator_CNN, Discriminator_CNN, Encoder_CNN
from representation_plot import plot_representation, plot_representation2

def log(x):
    # numerically stable log: avoids log(0) when D saturates
    return torch.log(x + 1e-8)

class Mnist:
    def __init__(self, batch_size):
        MNIST_MEAN = 0.1307
        MNIST_STD = 0.3081

        dataset_transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
        ])

        train_dataset = datasets.MNIST('../data', train=True, download=True, transform=dataset_transform)
        test_dataset = datasets.MNIST('../data', train=False, download=True, transform=dataset_transform)

        self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
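# A quick sanity-check sketch for the loader above (not part of the training
# flow; torchvision downloads MNIST into ../data on first use):
#
#   mnist = Mnist(batch_size=64)
#   images, labels = next(iter(mnist.train_loader))
#   print(images.shape)  # torch.Size([64, 1, 28, 28]), pixel values in [0, 1]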
class RobotWorld:
    def __init__(self, batch_size, dataset_path, gpu_mode):
        self.gpu_mode = gpu_mode
        self.batch_size = batch_size

        print('Loading data ...')

        ###########################################################
        # Import training dataset
        path = dataset_path + '/simple_navigation_task_train.npz'
        training_data = np.load(path)

        observations, actions = training_data['observations'], training_data['actions']
        rewards, episode_starts = training_data['rewards'], training_data['episode_starts']
        obs_dim = reduce(lambda x, y: x * y, observations.shape[1:])

        if len(observations.shape) > 2:
            # Channel first
            observations = np.transpose(observations, (0, 3, 1, 2))
            # Flatten the image
            observations = observations.reshape((-1, obs_dim))

        ###########################################################
        # Import evaluation dataset
        path = dataset_path + '/simple_navigation_task_test.npz'
        test_data = np.load(path)

        test_observations, actions = test_data['observations'], test_data['actions']
        rewards, episode_starts = test_data['rewards'], test_data['episode_starts']
        test_obs_dim = reduce(lambda x, y: x * y, test_observations.shape[1:])

        if len(test_observations.shape) > 2:
            # Channel first
            test_observations = np.transpose(test_observations, (0, 3, 1, 2))
            # Flatten the image
            test_observations = test_observations.reshape((-1, test_obs_dim))

        ##########################################################
        self.observations = observations.astype(np.float32)

        num_samples = observations.shape[0] - 1  # number of samples

        # indices for all time steps where the episode continues
        indices = np.arange(num_samples, dtype='int64')
        np.random.shuffle(indices)

        # split indices into minibatches
        self.minibatchlist = [np.array(sorted(indices[start_idx:start_idx + self.batch_size]))
                              for start_idx in range(0, num_samples - self.batch_size + 1, self.batch_size)]

        ###########################################################

        self.test_observations = test_observations.astype(np.float32)

        num_test_samples = test_observations.shape[0] - 1  # number of samples

        # indices for all time steps where the episode continues
        indices = np.arange(num_test_samples, dtype='int64')
        np.random.shuffle(indices)

        # split indices into minibatches
        self.test_minibatchlist = [np.array(sorted(indices[start_idx:start_idx + self.batch_size]))
                                   for start_idx in range(0, num_test_samples - self.batch_size + 1, self.batch_size)]

        ###########################################################

        self.train_loader = [(torch.from_numpy(self.observations[batch]).float(), 0)
                             for it, batch in enumerate(self.minibatchlist)]
        self.test_loader = [(torch.from_numpy(self.test_observations[batch]).float(), 0)
                            for it, batch in enumerate(self.test_minibatchlist)]

    def shuffle(self):
        # shuffle the minibatches
        enumerated_minibatches = list(enumerate(self.minibatchlist))
        np.random.shuffle(enumerated_minibatches)

        enumerated_test_minibatches = list(enumerate(self.test_minibatchlist))
        np.random.shuffle(enumerated_test_minibatches)

        self.train_loader = [(torch.from_numpy(self.observations[batch]).float(), it)
                             for it, batch in enumerated_minibatches]
        self.test_loader = [(torch.from_numpy(self.test_observations[batch]).float(), it)
                            for it, batch in enumerated_test_minibatches]
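# Hedged usage sketch for RobotWorld (assumes the two .npz files exist under
# `dataset_path`): both loaders mimic torchvision's (data, target) tuples, so
# the training loop below can iterate either dataset the same way.
#
#   dataset = RobotWorld(batch_size=64, dataset_path='./data', gpu_mode=False)
#   for batch_id, (data, target) in enumerate(dataset.train_loader):
#       print(batch_id, data.shape)  # e.g. torch.Size([64, 768]) for 16x16x3 images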
class BIGAN(object):
    """
    Class implementing a BIGAN network that trains from an observation dataset
    """

    def __init__(self, args):
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.learning_rate = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.slope = args.slope
        self.decay = args.decay
        self.dropout = args.dropout
        self.network_type = args.network_type
        self.dataset = args.dataset
        self.dataset_path = args.dataset_path

        # BIGAN parameters
        self.z_dim = args.z_dim  # dimension of the feature space
        self.h_dim = args.h_dim  # dimension of the hidden layers

        if args.dataset == 'mnist':
            self.X_dim = 28 * 28  # dimension of the data
            self.num_channels = 1
        elif args.dataset == 'robot_world':
            self.X_dim = 16 * 16 * 3  # dimension of the data
            self.num_channels = 3

        if args.network_type == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)
        elif args.network_type == 'CNN':
            params = {'slope': self.slope, 'dropout': self.dropout, 'batch_size': self.batch_size,
                      'num_channels': self.num_channels, 'dataset': self.dataset}

            self.G = Generator_CNN(self.z_dim, self.h_dim, self.X_dim, params)
            self.D = Discriminator_CNN(self.z_dim, self.h_dim, self.X_dim, params)
            self.E = Encoder_CNN(self.z_dim, self.h_dim, self.X_dim, params)
        else:
            raise Exception("[!] There is no option for " + args.network_type)

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()

        # E and G are optimized jointly, as in the BiGAN paper
        self.G_solver = optim.Adam(chain(self.E.parameters(), self.G.parameters()), lr=self.learning_rate,
                                   betas=[self.beta1, self.beta2], weight_decay=self.decay)
        self.D_solver = optim.Adam(self.D.parameters(), lr=self.learning_rate,
                                   betas=[self.beta1, self.beta2], weight_decay=self.decay)

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.E)
        utils.print_network(self.D)
        print('-----------------------------------------------')

    def D_(self, X, z):
        return self.D(X, z)

    def reset_grad(self):
        self.E.zero_grad()
        self.G.zero_grad()
        self.D.zero_grad()

    def train(self):
        if self.dataset == 'mnist':
            dataset = Mnist(self.batch_size)
        elif self.dataset == 'robot_world':
            dataset = RobotWorld(self.batch_size, self.dataset_path, self.gpu_mode)

        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []

        self.eval_hist = {}
        self.eval_hist['D_loss'] = []
        self.eval_hist['G_loss'] = []
        self.eval_hist['pixel_norm'] = []
        self.eval_hist['z_norm'] = []

        for epoch in range(self.epoch):
            print("epoch", str(epoch))

            self.D.train()
            self.E.train()
            self.G.train()

            train_loss_G = 0
            train_loss_D = 0

            if self.dataset == "robot_world":
                dataset.shuffle()

            for batch_id, (data, target) in enumerate(dataset.train_loader):

                # sample z from the prior; X is a real image from the dataset
                z = Variable(torch.rand(self.batch_size, self.z_dim))
                X = Variable(data)
                if self.gpu_mode:
                    z = z.cuda()
                    X = X.cuda()

                # sometimes the batch size of X is not equal to the actual batch_size
                if X.size(0) == self.batch_size:

                    if self.network_type == 'CNN':
                        if self.dataset == 'robot_world':
                            X = X.view(self.batch_size, 3, 16, 16)

                        z_hat = self.E(X)
                        X_hat = self.G(z)

                        D_enc = self.D(X, z_hat)
                        z = z.unsqueeze(2).unsqueeze(3)
                        D_gen = self.D(X_hat, z)

                    elif self.network_type == 'FC':
                        X = X.view(self.batch_size, -1)
                        z_hat = self.E(X)
                        X_hat = self.G(z)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)

                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))
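                    # The two lines above implement the BiGAN minimax objective of
                    # Donahue et al.: D is pushed towards 1 on encoder pairs (X, E(X))
                    # and 0 on generator pairs (G(z), z), while E and G are trained
                    # jointly on the flipped targets:
                    #   min_{G,E} max_D  E_X[log D(X, E(X))] + E_z[log(1 - D(G(z), z))]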
                    D_loss.backward(retain_graph=True)
                    self.D_solver.step()
                    self.reset_grad()

                    G_loss.backward()
                    self.G_solver.step()
                    self.reset_grad()

                    train_loss_G += G_loss.data[0]
                    train_loss_D += D_loss.data[0]

                    if batch_id % 1000 == 0:
                        # Print and plot every now and then
                        samples = X_hat.data.cpu().numpy()

                        fig = plt.figure(figsize=(8, 4))
                        gs = gridspec.GridSpec(4, 8)
                        gs.update(wspace=0.05, hspace=0.05)

                        for i, sample in enumerate(samples):
                            if i < 32:
                                ax = plt.subplot(gs[i])
                                plt.axis('off')
                                ax.set_xticklabels([])
                                ax.set_yticklabels([])
                                ax.set_aspect('equal')

                                if self.network_type == 'FC':
                                    if self.dataset == 'mnist':
                                        sample = sample.reshape(28, 28)
                                        plt.imshow(sample, cmap='Greys_r')
                                    elif self.dataset == 'robot_world':
                                        sample = sample.reshape(16, 16, 3)
                                        sample = np.rot90(sample, 2)
                                        plt.imshow(sample)
                                elif self.network_type == 'CNN':
                                    if self.dataset == 'mnist':
                                        plt.imshow(sample[0, :, :], cmap='Greys_r')
                                    elif self.dataset == 'robot_world':
                                        sample = np.clip(sample, 0, 1)
                                        sample = sample.reshape(16, 16, 3)
                                        sample = np.rot90(sample, 2)
                                        plt.imshow(sample)

                        if not os.path.exists(self.result_dir + '/train/'):
                            os.makedirs(self.result_dir + '/train/')

                        filename = "epoch_" + str(epoch) + "_batchid_" + str(batch_id)
                        plt.savefig(self.result_dir + '/train/{}.png'.format(filename), bbox_inches='tight')
                        plt.close()

            print("Train loss G:", train_loss_G / len(dataset.train_loader))
            print("Train loss D:", train_loss_D / len(dataset.train_loader))

            self.train_hist['D_loss'].append(train_loss_D / len(dataset.train_loader))
            self.train_hist['G_loss'].append(train_loss_G / len(dataset.train_loader))

            self.D.eval()
            self.E.eval()
            self.G.eval()
            test_loss_G = 0
            test_loss_D = 0

            mean_pixel_norm = 0
            mean_z_norm = 0
            norm_counter = 1

            for batch_id, (data, target) in enumerate(dataset.test_loader):
                # Sample data
                z = Variable(torch.rand(self.batch_size, self.z_dim))
                X_data = Variable(data)

                if self.gpu_mode:
                    z = z.cuda()
                    X_data = X_data.cuda()

                if X_data.size(0) == self.batch_size:
                    X = X_data
                    if self.network_type == 'CNN':
                        if self.dataset == 'robot_world':
                            X = X.view(self.batch_size, 3, 16, 16)
                        z_hat = self.E(X)
                        z_hat = z_hat.view(self.batch_size, -1)
                        X_hat = self.G(z)

                        z = z.unsqueeze(2).unsqueeze(3)

                        D_enc = self.D(X, z_hat)
                        D_gen = self.D(X_hat, z)

                    elif self.network_type == 'FC':
                        X = X.view(self.batch_size, -1)
                        z_hat = self.E(X)
                        X_hat = self.G(z)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)

                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))

                    test_loss_G += G_loss.data[0]
                    test_loss_D += D_loss.data[0]

                    pixel_norm = X - self.G(z_hat)
                    pixel_norm = pixel_norm.norm().data[0] / float(self.X_dim)
                    mean_pixel_norm += pixel_norm

                    z_norm = z - self.E(X_hat)
                    z_norm = z_norm.norm().data[0] / float(self.z_dim)
                    mean_z_norm += z_norm

                    norm_counter += 1
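            # Hedged note on the two diagnostics accumulated above:
            #   pixel_norm ~ ||X - G(E(X))|| / X_dim   (how well G inverts E on real images)
            #   z_norm     ~ ||z - E(G(z))|| / z_dim   (how well E inverts G on sampled codes)
            # In theory (Donahue et al.), the optimal encoder inverts the optimal
            # generator, so both norms should shrink as training converges.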
            print("Eval loss G:", test_loss_G / norm_counter)
            print("Eval loss D:", test_loss_D / norm_counter)

            self.eval_hist['D_loss'].append(test_loss_D / norm_counter)
            self.eval_hist['G_loss'].append(test_loss_G / norm_counter)

            print("Pixel norm:", mean_pixel_norm / norm_counter)
            self.eval_hist['pixel_norm'].append(mean_pixel_norm / norm_counter)

            with open('pixel_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_pixel_norm / norm_counter) + '\n')

            print("z norm:", mean_z_norm / norm_counter)
            self.eval_hist['z_norm'].append(mean_z_norm / norm_counter)

            with open('z_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_z_norm / norm_counter) + '\n')

            ##### At the end of the epoch, save X and its reconstruction G(E(X))
            samples = X.data.cpu().numpy()

            fig = plt.figure(figsize=(10, 2))
            gs = gridspec.GridSpec(2, 10)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                if i < 10:
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        elif self.dataset == 'robot_world':
                            sample = sample.reshape(16, 16, 3)
                            sample = np.rot90(sample, 2)
                            plt.imshow(sample)
                    elif self.network_type == 'CNN':
                        if self.dataset == 'mnist':
                            plt.imshow(sample[0, :, :], cmap='Greys_r')
                        elif self.dataset == 'robot_world':
                            sample = sample.reshape(16, 16, 3)
                            sample = np.rot90(sample, 2)
                            plt.imshow(sample)

            X_hat = self.G(self.E(X).view(self.batch_size, self.z_dim))
            samples = X_hat.data.cpu().numpy()

            for i, sample in enumerate(samples):
                if i < 10:
                    ax = plt.subplot(gs[10 + i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        elif self.dataset == 'robot_world':
                            sample = sample.reshape(16, 16, 3)
                            sample = np.rot90(sample, 2)
                            plt.imshow(sample)
                    elif self.network_type == 'CNN':
                        if self.dataset == 'mnist':
                            plt.imshow(sample[0, :, :], cmap='Greys_r')
                        elif self.dataset == 'robot_world':
                            sample = sample.reshape(16, 16, 3)
                            sample = np.clip(sample, 0, 1)
                            sample = np.rot90(sample, 2)
                            plt.imshow(sample)

            if not os.path.exists(self.result_dir + '/recons/'):
                os.makedirs(self.result_dir + '/recons/')

            filename = "epoch_" + str(epoch)
            plt.savefig(self.result_dir + '/recons/{}.png'.format(filename), bbox_inches='tight')
            plt.close()

            if epoch % 10 == 0:
                self.plot_states(epoch)

        save_plot_losses(self.train_hist['D_loss'], self.train_hist['G_loss'],
                         self.eval_hist['D_loss'], self.eval_hist['G_loss'],
                         self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_pixel_norm(self.eval_hist['pixel_norm'], self.network_type, self.z_dim,
                             self.epoch, self.learning_rate, self.batch_size)
        save_plot_z_norm(self.eval_hist['z_norm'], self.network_type, self.z_dim,
                         self.epoch, self.learning_rate, self.batch_size)

    def save_model(self):
        torch.save(self.G.state_dict(), self.save_dir + "/G.pt")
        torch.save(self.E.state_dict(), self.save_dir + "/E.pt")
        torch.save(self.D.state_dict(), self.save_dir + "/D.pt")
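    # Note: save_model() writes to args.save_dir, while load_model() below reads
    # the hard-coded "models/" directory, so the two only match with the default
    # --save_dir models. A sketch that honours save_dir instead would be:
    #   self.G.load_state_dict(torch.load(os.path.join(self.save_dir, "G.pt")))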
    def load_model(self, args):
        if args.network_type == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)
        elif args.network_type == 'CNN':
            params = {'slope': self.slope, 'dropout': self.dropout, 'batch_size': self.batch_size,
                      'num_channels': self.num_channels, 'dataset': self.dataset}

            self.G = Generator_CNN(self.z_dim, self.h_dim, self.X_dim, params)
            self.D = Discriminator_CNN(self.z_dim, self.h_dim, self.X_dim, params)
            self.E = Encoder_CNN(self.z_dim, self.h_dim, self.X_dim, params)

        self.G.load_state_dict(torch.load("models/G.pt"))
        self.E.load_state_dict(torch.load("models/E.pt"))
        self.D.load_state_dict(torch.load("models/D.pt"))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()

    def plot_states(self, i):
        # plot the latent space representation obtained by encoding the whole evaluation dataset
        if self.dataset == 'robot_world':
            test_data = np.load(self.dataset_path + '/simple_navigation_task_test.npz')

            test_observations, actions = test_data['observations'], test_data['actions']
            rewards, episode_starts = test_data['rewards'], test_data['episode_starts']
            test_obs_dim = reduce(lambda x, y: x * y, test_observations.shape[1:])

            if len(test_observations.shape) > 2:
                # Channel first
                test_observations = np.transpose(test_observations, (0, 3, 1, 2))
                # Flatten the image
                test_observations = test_observations.reshape((-1, test_obs_dim))

            test_observations = test_observations.astype(np.float32)

            num_samples = test_observations.shape[0] - 1

            print("NUM SAMPLES IS " + str(num_samples))

            # indices for all time steps where the episode continues
            indices = np.arange(num_samples, dtype='int64')

            # split indices into minibatches
            minibatchlist = [np.array(sorted(indices[start_idx:start_idx + self.batch_size]))
                             for start_idx in range(0, num_samples - self.batch_size + 1, self.batch_size)]

            enumerated_minibatches = list(enumerate(minibatchlist))

            for it, batch in enumerated_minibatches:
                obs = Variable(torch.from_numpy(test_observations[batch]).float())

                # Sample data
                if self.gpu_mode:
                    X = obs.cuda()
                else:
                    X = obs

                if self.network_type == 'CNN':
                    X = X.view(self.batch_size, 3, 16, 16)

                z_hat = self.E(X)
                z_hat = z_hat.view(self.batch_size, self.z_dim)

                if it == 0:
                    states = z_hat.data.cpu().numpy()
                else:
                    states = np.vstack((states, z_hat.data.cpu().numpy()))

            rewards = test_data['rewards']
            rewards = rewards[:len(states)]

            print("LEN OF REWARDS IS : ", len(rewards))
            print('LEN OF STATES IS : ', len(states))

            if self.z_dim == 2:
                plot_representation2(states, rewards, self.network_type, self.z_dim, self.epoch,
                                     self.learning_rate, self.batch_size, i)
            else:
                plot_representation(states, rewards, self.network_type, self.z_dim, self.epoch,
                                    self.learning_rate, self.batch_size, i)
                plot_representation2(states, rewards, self.network_type, self.z_dim, self.epoch,
                                     self.learning_rate, self.batch_size, i)
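# plot_z_distribution below is defined but never called in this repository.
# A hedged usage sketch, assuming `states` holds encoded z vectors as built
# in BIGAN.plot_states:
#
#   plot_z_distribution(states, 'FC', z_dim=2, epochs=50, lr=1e-3, batch_size=64)
#   # writes histograms/histogram_z_0.eps, histograms/histogram_z_1.eps, ...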
model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size:" + str(batch_size) 605 | plt.title(params, fontsize=8) 606 | plt.grid(True) 607 | 608 | plt.savefig("histograms/histogram_z_" + str(i) + ".eps", format='eps', dpi=1000) 609 | plt.close() 610 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bigan_SRL 2 | Testing BIGAN (Adversarial Feature Learning) for State Representation Learning 3 | 4 | 5 | This is a PyTorch implementation of a BIGAN Network described in the paper "Adversarial Feature Learning" by J. Donahue, P. Krahenbuhl, T. Darrell. 6 | 7 | This program will be tested on datasets from "Learning State Representations with Robotic Priors" (Jonschkowski & Brock, 2015), 8 | https://github.com/tu-rbo/learning-state-representations-with-robotic-priors 9 | 10 | ### Learn a state representation 11 | 12 | Usage: 13 | ``` 14 | python3 main.py [-h] [--dataset {mnist,robot_world}] 15 | [--dataset_path DATASET_PATH] [--gpu_mode GPU_MODE] 16 | [--save_dir SAVE_DIR] [--result_dir RESULT_DIR] 17 | [--log_dir LOG_DIR] [--epoch EPOCH] [--batch_size BATCH_SIZE] 18 | [--lr LR] [--beta1 BETA1] [--beta2 BETA2] [--slope SLOPE] 19 | [--decay DECAY] [--dropout DROPOUT] [--network_type {FC,CNN}] 20 | [--z_dim Z_DIM] [--h_dim H_DIM] 21 | 22 | 23 | Pytorch implementation of BIGAN 24 | 25 | optional arguments: 26 | -h, --help show this help message and exit 27 | --dataset {mnist,robot_world} 28 | The name of dataset 29 | --dataset_path DATASET_PATH 30 | --gpu_mode GPU_MODE 31 | --save_dir SAVE_DIR Directory name to save the model 32 | --result_dir RESULT_DIR 33 | Directory name to save the generated images 34 | --log_dir LOG_DIR Directory name to save training logs 35 | --epoch EPOCH The number of epochs to run 36 | --batch_size BATCH_SIZE 37 | The size of batch 38 | --lr LR Learning rate 39 | --beta1 BETA1 for adam 40 | --beta2 BETA2 for adam 41 | --slope SLOPE for leaky ReLU 42 | --decay DECAY for weight decay 43 | --dropout DROPOUT 44 | --network_type {FC,CNN} 45 | Type of network (Fully connectec or CNN) 46 | --z_dim Z_DIM The dimension of latent space Z 47 | --h_dim H_DIM The dimension of the hidden layers in case of a FC 48 | network 49 | ``` 50 | 51 | 52 | Example: 53 | ``` 54 | python3 main.py --network_type FC --dataset robot_world --epoch 50 --z_dim 2 55 | ``` 56 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Testing BIGAN (Adversarial Feature Learning) for State Representation Learning 3 | 4 | This is a PyTorch implementation of a BIGAN Network described in the paper "Adversarial Feature Learning" by J. Donahue, P. Krahenbuhl, T. Darrell. 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Testing BIGAN (Adversarial Feature Learning) for State Representation Learning

This is a PyTorch implementation of the BIGAN network described in the paper
"Adversarial Feature Learning" by J. Donahue, P. Krahenbuhl, T. Darrell.

This program will be tested on datasets from "Learning State Representations with
Robotic Priors" (Jonschkowski & Brock, 2015),
https://github.com/tu-rbo/learning-state-representations-with-robotic-priors
"""

import argparse, os
from BIGAN import BIGAN

"""parsing and configuration"""
def parse_args():
    desc = "Pytorch implementation of BIGAN"
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'robot_world'], help='The name of the dataset')
    parser.add_argument('--dataset_path', type=str, default='/home/williamb/project/bigan_implementation/m_bigan_learner/racecar_dataset')
    parser.add_argument('--gpu_mode', type=bool, default=True)

    # logging
    parser.add_argument('--save_dir', type=str, default='models', help='Directory name to save the model')
    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')

    # hyperparameters
    parser.add_argument('--epoch', type=int, default=25, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='for adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='for adam')
    parser.add_argument('--slope', type=float, default=1e-2, help='for leaky ReLU')
    parser.add_argument('--decay', type=float, default=2.5 * 1e-5, help='for weight decay')
    parser.add_argument('--dropout', type=float, default=0.2)

    # network parameters
    parser.add_argument('--network_type', type=str, default='FC', choices=['FC', 'CNN'], help='Type of network (Fully connected or CNN)')
    parser.add_argument('--z_dim', type=int, default=50, help='The dimension of latent space Z')
    parser.add_argument('--h_dim', type=int, default=1024, help='The dimension of the hidden layers in case of a FC network')

    return check_args(parser.parse_args())

"""checking arguments"""
def check_args(args):
    # --save_dir
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # --result_dir
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    # --log_dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # --epoch
    try:
        assert args.epoch >= 1
    except:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')

    return args

"""main"""
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    bigan = BIGAN(args)

    # overwrite the error logs from any previous run
    with open('pixel_error_BIGAN.txt', 'w') as f:
        f.writelines('')
    with open('z_error_BIGAN.txt', 'w') as f:
        f.writelines('')

    bigan.train()
    print(" [*] Training finished!")

    bigan.save_model()

    # plot_states expects an index used to label the output files;
    # use the final epoch number here
    bigan.plot_states(args.epoch)

if __name__ == '__main__':
    main()
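# Caveat on --gpu_mode: argparse's `type=bool` converts any non-empty string to
# True, so `--gpu_mode False` still enables GPU mode. A possible fix (a sketch,
# not part of the original CLI) would be:
#   parser.add_argument('--gpu_mode', type=lambda s: s.lower() in ('true', '1'), default=True)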
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import utils

class Generator_CNN(nn.Module):
    """
    CNN to model the generator of a BIGAN.
    Input is a vector from the representation space of dimension z_dim,
    output is a vector from the image space of dimension X_dim.
    """
    # Network architecture follows infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, z_dim, h_dim, X_dim, params):
        super(Generator_CNN, self).__init__()

        self.input_height = 28
        self.input_width = 28
        self.input_dim = z_dim
        self.output_dim = 1
        self.slope = params['slope']
        self.dropout = params['dropout']
        self.num_channels = params['num_channels']
        self.dataset = params['dataset']

        if self.dataset == 'mnist':
            self.inference = nn.Sequential(
                # input dim: z_dim x 1 x 1
                nn.ConvTranspose2d(z_dim, 256, 4, stride=1, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 256 x 4 x 4
                nn.ConvTranspose2d(256, 128, 4, stride=2, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 128 x 10 x 10
                nn.ConvTranspose2d(128, 64, 4, stride=1, bias=True),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 64 x 13 x 13
                nn.ConvTranspose2d(64, 32, 4, stride=2, bias=True),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 32 x 28 x 28
                nn.Conv2d(32, self.num_channels, 1, stride=1, bias=True),
                # output dim: num_channels x 28 x 28
                nn.Tanh()
            )
        elif self.dataset == 'robot_world':
            self.inference = nn.Sequential(
                # input dim: z_dim x 1 x 1
                nn.ConvTranspose2d(z_dim, 256, 4, stride=1, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 256 x 4 x 4
                nn.ConvTranspose2d(256, 128, 4, stride=2, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 128 x 10 x 10
                nn.ConvTranspose2d(128, 64, 4, stride=1, bias=True),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 64 x 13 x 13
                nn.ConvTranspose2d(64, self.num_channels, 4, stride=1, bias=True),
                # output dim: num_channels x 16 x 16
                nn.Sigmoid()
            )

        utils.initialize_weights(self)

    def forward(self, input):
        z = input.unsqueeze(2).unsqueeze(3)
        x = self.inference(z)

        return x
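# Hedged sanity check for Generator_CNN (MNIST settings; the `params` values
# mirror the defaults in main.py):
#
#   params = {'slope': 1e-2, 'dropout': 0.2, 'batch_size': 64,
#             'num_channels': 1, 'dataset': 'mnist'}
#   G = Generator_CNN(z_dim=50, h_dim=1024, X_dim=28 * 28, params=params)
#   print(G(torch.rand(64, 50)).shape)  # torch.Size([64, 1, 28, 28])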
class Encoder_CNN(nn.Module):
    """
    CNN to model the encoder of a BIGAN.
    Input is a vector X from the image space of dimension X_dim,
    output is a vector z from the representation space of dimension z_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim, params):
        super(Encoder_CNN, self).__init__()

        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = z_dim

        self.slope = params['slope']
        self.dropout = params['dropout']
        self.num_channels = params['num_channels']
        self.dataset = params['dataset']

        if self.dataset == 'mnist':
            self.inference = nn.Sequential(
                # input dim: num_channels x 28 x 28
                nn.Conv2d(self.num_channels, 32, 3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 32 x 28 x 28
                nn.Conv2d(32, 64, 4, stride=2, bias=True),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 64 x 13 x 13
                nn.Conv2d(64, 128, 4, stride=1, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 128 x 10 x 10
                nn.Conv2d(128, 256, 4, stride=2, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 256 x 4 x 4
                nn.Conv2d(256, 512, 4, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 512 x 1 x 1
                nn.Conv2d(512, 512, 1, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 512 x 1 x 1
                nn.Conv2d(512, z_dim, 1, stride=1, bias=True)
                # output dim: z_dim x 1 x 1
            )
        elif self.dataset == 'robot_world':
            self.inference = nn.Sequential(
                # input dim: num_channels x 16 x 16
                nn.Conv2d(self.num_channels, 64, 4, stride=1, padding=0, bias=True),
                # state dim: 64 x 13 x 13
                nn.Conv2d(64, 128, 4, stride=1, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 128 x 10 x 10
                nn.Conv2d(128, 256, 4, stride=2, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 256 x 4 x 4
                nn.Conv2d(256, 512, 4, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 512 x 1 x 1
                nn.Conv2d(512, 512, 1, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                # state dim: 512 x 1 x 1
                nn.Conv2d(512, z_dim, 1, stride=1, bias=True)
                # output dim: z_dim x 1 x 1
            )

        utils.initialize_weights(self)

    def forward(self, input):
        x = self.inference(input)

        return x

class Discriminator_CNN(nn.Module):
    """
    CNN to model the discriminator of a BIGAN.
    Input is a tuple (X, z) of an image vector and its corresponding
    representation z vector. For example, if X comes from the dataset, the corresponding
    z is Encoder(X), and if z is sampled from the representation space, X is Generator(z).
    Output is a 1-dimensional value.
    """
    # Network architecture follows infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, z_dim, h_dim, X_dim, params):
        super(Discriminator_CNN, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        self.input_height = 28

        self.slope = params['slope']
        self.dropout = params['dropout']
        self.batch_size = params['batch_size']
        self.num_channels = params['num_channels']
        self.dataset = params['dataset']

        if self.dataset == 'mnist':
            # NB: with 28 x 28 inputs this stack outputs 512 x 7 x 7, while
            # inference_joint below expects a 512-dim image feature; only the
            # 16 x 16 robot_world stack actually reduces to 512 x 1 x 1.
            self.inference_x = nn.Sequential(
                # state dim: num_channels x 28 x 28
                nn.Conv2d(self.num_channels, 64, 4, stride=1, bias=True),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 64 x 25 x 25
                nn.Conv2d(64, 128, 4, stride=1, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 128 x 22 x 22
                nn.Conv2d(128, 256, 4, stride=2, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 256 x 10 x 10
                nn.Conv2d(256, 512, 4, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout)
                # output dim: 512 x 7 x 7
            )
        elif self.dataset == 'robot_world':
            self.inference_x = nn.Sequential(
                # state dim: num_channels x 16 x 16
                nn.Conv2d(self.num_channels, 64, 4, stride=1, bias=True),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 64 x 13 x 13
                nn.Conv2d(64, 128, 4, stride=1, bias=True),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 128 x 10 x 10
                nn.Conv2d(128, 256, 4, stride=2, bias=True),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout),
                # state dim: 256 x 4 x 4
                nn.Conv2d(256, 512, 4, stride=1, bias=True),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(self.slope, inplace=True),
                nn.Dropout2d(p=self.dropout)
                # output dim: 512 x 1 x 1
            )

        self.inference_joint = nn.Sequential(
            torch.nn.Linear(512 + self.z_dim, self.h_dim),
            nn.LeakyReLU(0.2),  # torch.nn.ReLU(),
            torch.nn.Linear(self.h_dim, self.h_dim),
            nn.LeakyReLU(0.2),  # torch.nn.ReLU(),
            torch.nn.Linear(self.h_dim, 1),
            torch.nn.Sigmoid()
        )
        utils.initialize_weights(self)

    def forward(self, x, z):
        output_x = self.inference_x(x)
        output_x = output_x.view(self.batch_size, -1)

        output_z = z.view(self.batch_size, -1)

        output = self.inference_joint(torch.cat((output_x, output_z), 1))
        return output
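# Note: the view(self.batch_size, ...) calls in forward() hard-code the batch
# size, which is why BIGAN.train() skips any batch whose size differs from
# args.batch_size. A more flexible sketch would use the runtime size instead:
#   output_x = output_x.view(x.size(0), -1)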
class Generator_FC(nn.Module):
    """
    Simple NN with hidden layers of dimension h_dim.
    Input is a vector from the representation space of dimension z_dim,
    output is a vector from the image space of dimension X_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Generator_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(z_dim, h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, X_dim),
            torch.nn.BatchNorm1d(X_dim),
            torch.nn.Sigmoid()
        )

        utils.initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        return x


class Discriminator_FC(nn.Module):
    """
    Simple NN with hidden layers of dimension h_dim.
    Input is a tuple (X, z) of an image vector and its corresponding
    representation z vector. For example, if X comes from the dataset, the corresponding
    z is Encoder(X), and if z is sampled from the representation space, X is Generator(z).
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Discriminator_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        # project the image down to z_dim before joining it with z
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(X_dim, z_dim),
            nn.LeakyReLU(0.2),
        )

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(2 * z_dim, h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, 1),
            torch.nn.Sigmoid()
        )

        utils.initialize_weights(self)

    def forward(self, input_x, input_z):
        x = self.fc1(input_x)
        return self.fc(torch.cat([x, input_z], 1))


class Encoder_FC(nn.Module):
    """
    Simple NN with hidden layers of dimension h_dim.
    Input is a vector X from the image space of dimension X_dim,
    output is a vector z from the representation space of dimension z_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Encoder_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(X_dim, h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.BatchNorm1d(h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, z_dim)
        )

        utils.initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        return x
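# Hedged end-to-end sanity sketch for the FC trio (robot_world dimensions,
# X_dim = 16 * 16 * 3 = 768):
#
#   G = Generator_FC(z_dim=2, h_dim=64, X_dim=768)
#   D = Discriminator_FC(z_dim=2, h_dim=64, X_dim=768)
#   E = Encoder_FC(z_dim=2, h_dim=64, X_dim=768)
#   X, z = torch.rand(8, 768), torch.rand(8, 2)
#   print(D(X, E(X)).shape, D(G(z), z).shape)  # torch.Size([8, 1]) twice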
--------------------------------------------------------------------------------
/plot_utils.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np

def save_plot_losses(train_D_loss, train_G_loss, eval_D_loss, eval_G_loss, model_used, z_dim, epochs, lr, batch_size):

    x = np.arange(1, len(train_D_loss) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, train_G_loss, label="Train G loss", linewidth=2)
    plt.plot(x, eval_G_loss, label="Eval G loss", linewidth=2)

    plt.axes().set_xlabel('Epoch')
    plt.axes().set_ylabel('Loss')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the Train and Eval losses of G")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_G_losses.eps", format='eps', dpi=1000)
    plt.close()

    plt.figure(figsize=(8, 6))
    plt.plot(x, train_D_loss, label="Train D loss", linewidth=2)
    plt.plot(x, eval_D_loss, label="Eval D loss", linewidth=2)

    plt.axes().set_xlabel('Epoch')
    plt.axes().set_ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval losses of D")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_D_losses.eps", format='eps', dpi=1000)
    plt.close()


def save_plot_pixel_norm(mean_pixel_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_pixel_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_pixel_norm, label="Reconstruction error", linewidth=2)

    plt.axes().set_xlabel('Epoch')
    plt.axes().set_ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between X and G(E(X))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("pix2pix_norm.eps", format='eps', dpi=1000)
    plt.close()

def save_plot_z_norm(mean_z_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_z_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_z_norm, label="Reconstruction error", linewidth=2)

    plt.axes().set_xlabel('Epoch')
    plt.axes().set_ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between z and E(G(z))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("z_norm.eps", format='eps', dpi=1000)
    plt.close()
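# Hedged usage sketch (BIGAN.train calls these once per run with its loss
# histories; the numbers here are placeholders):
#
#   save_plot_losses([0.9, 0.7], [1.2, 1.1], [0.95, 0.8], [1.3, 1.2],
#                    'FC', 2, 2, 1e-3, 64)
#   # writes plot_G_losses.eps and plot_D_losses.eps to the working directory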
--------------------------------------------------------------------------------
/representation_plot.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division

import json
import argparse
from textwrap import fill

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
# from sklearn.manifold import TSNE
# Faster implementation of t-SNE:
# from MulticoreTSNE import MulticoreTSNE as TSNE

# Python 2/3 compatibility
try:
    input = raw_input
except NameError:
    pass

# Init seaborn
sns.set()
INTERACTIVE_PLOT = False
TITLE_MAX_LENGTH = 50


def updateDisplayMode():
    """
    Enable or disable interactive plot
    see: http://matplotlib.org/faq/usage_faq.html#what-is-interactive-mode
    """
    if INTERACTIVE_PLOT:
        plt.ion()
    else:
        plt.ioff()


def pauseOrClose(fig):
    """
    :param fig: (matplotlib figure object)
    """
    if INTERACTIVE_PLOT:
        plt.draw()
        plt.pause(0.0001)  # Small pause to update the plot
    else:
        plt.close(fig)


# def plot_tsne(states, rewards, name="T-SNE of Learned States", add_colorbar=True, path=None,
#               n_components=3, perplexity=100.0, learning_rate=200.0, n_iter=1000, cmap="coolwarm"):
#     """
#     :param states: (numpy array)
#     :param rewards: (numpy 1D array)
#     :param name: (str)
#     :param add_colorbar: (bool)
#     :param path: (str)
#     :param n_components: (int)
#     :param perplexity: (float)
#     :param learning_rate: (float)
#     :param n_iter: (int)
#     :param cmap: (str)
#     """
#     assert n_components in [2, 3], "You cannot apply t-SNE with n_components={}".format(n_components)
#     t_sne = TSNE(n_components=n_components, perplexity=perplexity,
#                  learning_rate=learning_rate, n_iter=n_iter, verbose=1, n_jobs=4)
#     s_transformed = t_sne.fit_transform(states)
#     plot_representation(s_transformed, rewards, name, add_colorbar, path, cmap=cmap, fit_pca=False)


def plot_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i, name="Learned State Representation",
                        add_colorbar=True, path=None, fit_pca=True, cmap='coolwarm'):
    """
    Plot learned state representation using rewards for coloring
    :param states: (numpy array)
    :param rewards: (numpy 1D array)
    :param name: (str)
    :param add_colorbar: (bool)
    :param path: (str)
    :param fit_pca: (bool)
    :param cmap: (str)
    """
    state_dim = states.shape[1]
    if state_dim != 1 and (fit_pca or state_dim > 3):
        name += " (PCA)"
        n_components = min(state_dim, 3)
        print("Fitting PCA with {} components".format(n_components))
        states = PCA(n_components=n_components).fit_transform(states)

    if state_dim == 1:
        # Extend states as 2D:
        states_matrix = np.zeros((states.shape[0], 2))
        states_matrix[:, 0] = states[:, 0]
        plot_2d_representation(states_matrix, rewards, model_used, z_dim, epochs, lr, batch_size, i)
    elif state_dim == 2:
        plot_2d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i)
    else:
        plot_3d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i)

    plt.savefig("representation_plot_3D_" + str(i) + ".eps", format='eps', dpi=1000)

def plot_representation2(states, rewards, model_used, z_dim, epochs, lr, batch_size, i, name="Learned State Representation",
                         add_colorbar=True, path=None, fit_pca=True, cmap='coolwarm'):
    """
    Plot learned state representation using rewards for coloring
    :param states: (numpy array)
    :param rewards: (numpy 1D array)
    :param name: (str)
    :param add_colorbar: (bool)
    :param path: (str)
    :param fit_pca: (bool)
    :param cmap: (str)
    """
    state_dim = states.shape[1]
    if state_dim != 1 and (fit_pca or state_dim > 3):
        name += " (PCA)"
        n_components = min(state_dim, 2)
        print("Fitting PCA with {} components".format(n_components))
        states = PCA(n_components=n_components).fit_transform(states)

    if state_dim == 1:
        # Extend states as 2D:
        states_matrix = np.zeros((states.shape[0], 2))
        states_matrix[:, 0] = states[:, 0]
        plot_2d_representation(states_matrix, rewards, model_used, z_dim, epochs, lr, batch_size, i)
    elif state_dim == 2:
        plot_2d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i)
    else:
        plot_2d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i)

    plt.savefig("representation_plot_2D_" + str(i) + ".eps", format='eps', dpi=1000)
    plt.close()

def plot_2d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i, name="Learned State Representation",
                           add_colorbar=True, path=None, cmap='coolwarm'):
    # updateDisplayMode()
    fig = plt.figure(name)
    plt.clf()
    plt.scatter(states[:, 0], states[:, 1], s=7, c=rewards, cmap=cmap, linewidths=0.1)
    plt.xlabel('State dimension 1')
    plt.ylabel('State dimension 2')

    # plt.suptitle(fill(name, TITLE_MAX_LENGTH))
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(i) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=6)

    fig.tight_layout()
    if add_colorbar:
        plt.colorbar(label='Reward')
    if path is not None:
        plt.savefig(path)
    # pauseOrClose(fig)


def plot_3d_representation(states, rewards, model_used, z_dim, epochs, lr, batch_size, i, name="Learned State Representation",
                           add_colorbar=True, path=None, cmap='coolwarm'):
    updateDisplayMode()
    fig = plt.figure(name)
    plt.clf()
    ax = fig.add_subplot(111, projection='3d')
    im = ax.scatter(states[:, 0], states[:, 1], states[:, 2],
                    s=7, c=rewards, cmap=cmap, linewidths=0.1)
    ax.set_xlabel('State dimension 1')
    ax.set_ylabel('State dimension 2')
    ax.set_zlabel('State dimension 3')

    # ax.set_title(fill(name, TITLE_MAX_LENGTH))
    plt.suptitle(fill(name, TITLE_MAX_LENGTH))
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(i) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=6)

    fig.tight_layout()
    if add_colorbar:
        fig.colorbar(im, label='Reward')
    if path is not None:
        plt.savefig(path)
    # pauseOrClose(fig)


def plot_observations(observations, name='Observation Samples'):
    updateDisplayMode()
    fig = plt.figure(name)
    m, n = 8, 10
    for i in range(m * n):
        plt.subplot(m, n, i + 1)
        plt.imshow(observations[i].reshape(16, 16, 3), interpolation='nearest')
        plt.gca().invert_yaxis()
        plt.xticks([])
        plt.yticks([])
    pauseOrClose(fig)


def plot_image(image, name='Observation Sample'):
    """
    Display an image
    :param image: (numpy tensor) (with values in [0, 1])
    :param name: (str)
    """
    # Reorder channels
    if image.shape[0] == 3 and len(image.shape) == 3:
        # (n_channels, height, width) -> (width, height, n_channels)
        image = np.transpose(image, (2, 1, 0))
    updateDisplayMode()
    fig = plt.figure(name)
    plt.imshow(image, interpolation='nearest')
    # plt.gca().invert_yaxis()
    plt.xticks([])
    plt.yticks([])
    pauseOrClose(fig)


def colorPerEpisode(episode_starts):
    """
    :param episode_starts: (numpy 1D array)
    :return: (numpy 1D array)
    """
    colors = np.zeros(len(episode_starts))
    color_idx = -1
    print(np.sum(episode_starts))
    for i in range(len(episode_starts)):
        # New episode
        if episode_starts[i] == 1:
            color_idx += 1
        colors[i] = color_idx
    return colors


def plot_against(states, rewards, title="Representation", fit_pca=False, cmap='coolwarm'):
    """
    State dimensions are plotted one against the other (it creates a matrix of 2d representations)
    using rewards for coloring
    :param states: (numpy tensor)
    :param rewards: (numpy array)
    :param title: (str)
    :param fit_pca: (bool)
    :param cmap: (str)
    """
    n = states.shape[1]
    fig, ax_mat = plt.subplots(n, n, figsize=(10, 10), sharex=False, sharey=False)
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    if fit_pca:
        title += " (PCA)"
        states = PCA(n_components=n).fit_transform(states)

    for i in range(n):
        for j in range(n):
            x, y = states[:, i], states[:, j]
            ax = ax_mat[i, j]
            ax.scatter(x, y, c=rewards, cmap=cmap, s=5)
            ax.set_xlim([np.min(x), np.max(x)])
            ax.set_ylim([np.min(y), np.max(y)])

            # Hide ticks
            if i != 0 and i != n - 1:
                ax.xaxis.set_visible(False)
            if j != 0 and j != n - 1:
                ax.yaxis.set_visible(False)

            # Set up ticks only on one side for the "edge" subplots...
            if j == 0:
                ax.yaxis.set_ticks_position('left')
            if j == n - 1:
                ax.yaxis.set_ticks_position('right')
            if i == 0:
                ax.set_title("Dim {}".format(j), y=1.2)
                ax.xaxis.set_ticks_position('top')
            if i == n - 1:
                ax.xaxis.set_ticks_position('bottom')

    plt.suptitle(title, fontsize=16)
    plt.show()


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description='Plotting script for representation')
#     parser.add_argument('-i', '--input_file', type=str, default="",
#                         help='Path to a npz file containing states and rewards')
#     parser.add_argument('--data_folder', type=str, default="",
#                         help='Path to a dataset folder, it will plot ground truth states')
#     parser.add_argument('--t-sne', action='store_true', default=False, help='Use t-SNE instead of PCA')
#     parser.add_argument('--color-episode', action='store_true', default=False,
#                         help='Color states per episodes instead of reward')
#     parser.add_argument('--plot-against', action='store_true', default=False,
#                         help='Plot against each dimension')
#     args = parser.parse_args()
#
#     cmap = "tab20" if args.color_episode else "coolwarm"
#     assert not (args.color_episode and args.data_folder == ""),\
#         "You must specify a datafolder when using per-episode color"
#     # Remove `data/` from the path if needed
#     if "data/" in args.data_folder:
#         args.data_folder = args.data_folder.split('data/')[1].strip("/")
#
#     if args.input_file != "":
#         print("Loading {}...".format(args.input_file))
#         states_rewards = np.load(args.input_file)
#         rewards = states_rewards['rewards']
#         if args.color_episode:
#             episode_starts = np.load('data/{}/preprocessed_data.npz'.format(args.data_folder))['episode_starts']
#             rewards = colorPerEpisode(episode_starts)[:len(rewards)]
#
#         if args.t_sne:
#             print("Using t-SNE...")
#             plot_tsne(states_rewards['states'], rewards, cmap=cmap)
#         elif args.plot_against:
#             print("Plotting against")
#             plot_against(states_rewards['states'], rewards, cmap=cmap)
#         else:
#             plot_representation(states_rewards['states'], rewards, cmap=cmap)
#         input('\nPress any key to exit.')
#
#     elif args.data_folder != "":
#
#         print("Plotting ground truth...")
#         training_data = np.load('data/{}/preprocessed_data.npz'.format(args.data_folder))
#         ground_truth = np.load('data/{}/ground_truth.npz'.format(args.data_folder))
#         true_states = ground_truth['arm_states']
#         name = "Ground Truth States - {}".format(args.data_folder)
#         episode_starts, rewards = training_data['episode_starts'], training_data['rewards']
#         button_positions = ground_truth['button_positions']
#         with open('data/{}/dataset_config.json'.format(args.data_folder), 'r') as f:
#             relative_pos = json.load(f).get('relative_pos', False)
#
#         # True state is the relative position to the button
#         if relative_pos:
#             button_idx = -1
#             for i in range(len(episode_starts)):
#                 if episode_starts[i] == 1:
#                     button_idx += 1
#                 true_states[i] -= button_positions[button_idx]
#
#         if args.color_episode:
#             rewards = colorPerEpisode(episode_starts)
#
#         if args.plot_against:
#             plot_against(true_states, rewards, cmap=cmap)
#         else:
#             plot_representation(true_states, rewards, name, fit_pca=False, cmap=cmap)
#         input('\nPress any key to exit.')
#
#     else:
#         print("You must specify one of --input_file or --data_folder")
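# Hedged usage sketch, assuming `states` is an (N, z_dim) array of encoded
# states and `rewards` an (N,) array, as produced by BIGAN.plot_states:
#
#   plot_representation(states, rewards, 'FC', z_dim=3, epochs=50, lr=1e-3,
#                       batch_size=64, i=49)
#   # fits a 3-component PCA and saves representation_plot_3D_49.eps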
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# import os, gzip, torch
import torch.nn as nn
# import numpy as np
# import scipy.misc
# import imageio
# import matplotlib.pyplot as plt
# from torchvision import datasets, transforms

# def load_mnist(dataset):
#     data_dir = os.path.join("./data", dataset)

#     def extract_data(filename, num_data, head_size, data_size):
#         with gzip.open(filename) as bytestream:
#             bytestream.read(head_size)
#             buf = bytestream.read(data_size * num_data)
#             data = np.frombuffer(buf, dtype=np.uint8).astype(np.float)
#         return data

#     data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
#     trX = data.reshape((60000, 28, 28, 1))

#     data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
#     trY = data.reshape((60000))

#     data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
#     teX = data.reshape((10000, 28, 28, 1))

#     data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
#     teY = data.reshape((10000))

#     trY = np.asarray(trY).astype(np.int)
#     teY = np.asarray(teY)

#     X = np.concatenate((trX, teX), axis=0)
#     y = np.concatenate((trY, teY), axis=0).astype(np.int)

#     seed = 547
#     np.random.seed(seed)
#     np.random.shuffle(X)
#     np.random.seed(seed)
#     np.random.shuffle(y)

#     y_vec = np.zeros((len(y), 10), dtype=np.float)
#     for i, label in enumerate(y):
#         y_vec[i, y[i]] = 1

#     X = X.transpose(0, 3, 1, 2) / 255.
#     # y_vec = y_vec.transpose(0, 3, 1, 2)

#     X = torch.from_numpy(X).type(torch.FloatTensor)
#     y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
#     return X, y_vec

# def load_celebA(dir, transform, batch_size, shuffle):
#     # transform = transforms.Compose([
#     #     transforms.CenterCrop(160),
#     #     transform.Scale(64)
#     #     transforms.ToTensor(),
#     #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
#     # ])

#     # data_dir = 'data/celebA'  # this path depends on your computer
#     dset = datasets.ImageFolder(dir, transform)
#     data_loader = torch.utils.data.DataLoader(dset, batch_size, shuffle)

#     return data_loader


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

# def save_images(images, size, image_path):
#     return imsave(images, size, image_path)

# def imsave(images, size, path):
#     image = np.squeeze(merge(images, size))
#     return scipy.misc.imsave(path, image)

# def merge(images, size):
#     h, w = images.shape[1], images.shape[2]
#     if (images.shape[3] in (3, 4)):
#         c = images.shape[3]
#         img = np.zeros((h * size[0], w * size[1], c))
#         for idx, image in enumerate(images):
#             i = idx % size[1]
#             j = idx // size[1]
#             img[j * h:j * h + h, i * w:i * w + w, :] = image
#         return img
#     elif images.shape[3] == 1:
#         img = np.zeros((h * size[0], w * size[1]))
#         for idx, image in enumerate(images):
#             i = idx % size[1]
#             j = idx // size[1]
#             img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
#         return img
#     else:
#         raise ValueError('in merge(images,size) images parameter must have dimensions: HxW or HxWx3 or HxWx4')

# def generate_animation(path, num):
#     images = []
#     for e in range(num):
#         img_name = path + '_epoch%03d' % (e + 1) + '.png'
#         images.append(imageio.imread(img_name))
#     imageio.mimsave(path + '_generate_animation.gif', images, fps=5)

# def loss_plot(hist, path='Train_hist.png', model_name=''):
#     x = range(len(hist['D_loss']))

#     y1 = hist['D_loss']
#     y2 = hist['G_loss']

#     plt.plot(x, y1, label='D_loss')
#     plt.plot(x, y2, label='G_loss')

#     plt.xlabel('Iter')
#     plt.ylabel('Loss')

#     plt.legend(loc=4)
#     plt.grid(True)
#     plt.tight_layout()

#     path = os.path.join(path, model_name + '_loss.png')

#     plt.savefig(path)

#     plt.close()

def initialize_weights(net):
    # every Conv / Linear layer in models.py is created with bias=True,
    # so m.bias is assumed to exist here
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
--------------------------------------------------------------------------------