├── README.md ├── Report.pdf ├── config.py ├── datasets └── CelebA.zip ├── main.py ├── model.py ├── roam.py ├── sample.py ├── saves └── gtm_sm_state_dict.pth ├── show_results.py ├── train.py ├── utils ├── __init__.py └── torch_utils.py └── videos └── image_navigation ├── image_navigation1.mp4 └── image_navigation2.mp4 /README.md: -------------------------------------------------------------------------------- 1 | # Generative-Temporal-Models-with-Spatial-Memory 2 | 3 | This repo contains code implementing the paper [Generative Temporal Models with Spatial Memory 4 | for Partially Observed Environments (Marco Fraccaro et al.)](https://arxiv.org/abs/1804.09401). It includes code for running the image navigation experiment. 5 | 6 | 7 | ### Dependencies 8 | This code requires the following: 9 | * python 3.\* 10 | * pytorch v 0.4.0 11 | * pyflann v 1.6.14  12 | * other frequently used packages such as numpy and matplotlib. 13 | 14 | 15 | ### Data 16 | This experiment is trained on the Large-scale CelebFaces Attributes (CelebA) Dataset, downloaded from (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). The images must first be preprocessed with a face detector, so the preprocessed result is saved in `./datasets/CelebA.zip`. Readers can unzip it into `./datasets/` to run further experiments. 17 | 18 | 19 | ### Usage 20 | Description of the Python files. 21 | - `config.py` 22 | Used for setting the parameters of the model, such as the batch size, total epochs, log interval and so on. 23 | - `main.py` 24 | The **main entry point** used to train the GTM-SM model. It calls the `train` and `test` functions in `train.py` to train the model and report the reconstruction error on the validation set. 25 | - `train.py` 26 | Implements the `train` and `test` functions. 27 | - `roam.py` 28 | Used for generating the trajectory of the 8 x 8 crop over a 32 x 32 image.
29 | - `show_results.py` 30 | Use for generating the result as the `./videos/image_navigation` shows. 31 | - `sample.py` 32 | Use for generating image navigation experiment videos. It can be directly called to producing the corresponding result. 33 | - `/utils/torch_utils.py` 34 | Provide some useful functions. 35 | 36 | I have saved the trained parameters in `./saves/`, and the file `sample.py` would reload these parameters to do the image navigation, so that it can reproduce the simulation result. 37 | 38 | If readers would like to check the result, you can directly run the file `sample.py`. For convenience, it is pleasant to have seen it first, so I recode the relevant videos in `videos/image_navigation/`. 39 | 40 | 41 | ### Contact 42 | To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/issues). 43 | -------------------------------------------------------------------------------- /Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/Report.pdf -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(description='GTM-SM Example') 5 | parser.add_argument('--batch-size', type=int, default=16, metavar='N', 6 | help='input batch size for training (default: 16)') 7 | parser.add_argument('--epochs', type=int, default=25, metavar='N', 8 | help='number of epochs to train (default: 100)') 9 | parser.add_argument('--no-cuda', action='store_true', default=False, 10 | help='enables CUDA training') 11 | parser.add_argument('--seed', type=int, default=2018, metavar='S', 12 | help='random seed (default: 1)') 13 | 
parser.add_argument('--log-interval', type=int, default=50, metavar='N', 14 | help='how many batches to wait before logging training status (default: 10)') 15 | parser.add_argument('--save-interval', type=int, default=1, metavar='N', 16 | help='how many epochs to wait before saving model status (default: 1)') 17 | parser.add_argument('--gradient-clip', type=int, default=10, metavar='N', 18 | help='the maximum norm of the gradient will be used (default: 10)') 19 | args = parser.parse_args() 20 | args.cuda = not args.no_cuda and torch.cuda.is_available() 21 | 22 | torch.manual_seed(args.seed) 23 | if args.cuda: 24 | torch.cuda.manual_seed_all(args.seed) 25 | 26 | device = torch.device("cuda" if args.cuda else "cpu") 27 | 28 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} -------------------------------------------------------------------------------- /datasets/CelebA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/datasets/CelebA.zip -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torchvision 5 | import argparse 6 | import torch.nn.functional as F 7 | import torchvision.transforms as T 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data import sampler 11 | import torchvision.datasets as dset 12 | 13 | import os 14 | import time 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import matplotlib.gridspec as gridspec 18 | 19 | from multiprocessing import Process 20 | 21 | from utils.torch_utils import initNetParams, ChunkSampler, show_images, device_agnostic_selection 22 | from model import GTM_SM 23 | from 
config import * 24 | from show_results import show_experiment_information 25 | from train import train, test 26 | 27 | plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots 28 | plt.rcParams['image.interpolation'] = 'nearest' 29 | plt.rcParams['image.cmap'] = 'gray' 30 | 31 | def main(): 32 | data_transform = T.Compose([ 33 | T.Resize((32, 32)), 34 | T.ToTensor(), 35 | ]) 36 | training_dataset = dset.ImageFolder(root='./datasets/CelebA/training', transform=data_transform) 37 | loader_train = DataLoader(training_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) 38 | 39 | val_dataset = dset.ImageFolder(root='./datasets/CelebA/val', transform=data_transform) 40 | loader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) 41 | 42 | GTM_SM_model = GTM_SM(batch_size=args.batch_size, total_dim=256 + 32).to(device=device) 43 | initNetParams(GTM_SM_model) 44 | 45 | optimizer = optim.Adam(GTM_SM_model.parameters(), lr=1e-3) 46 | 47 | for epoch in range(1, args.epochs + 1): 48 | # training + testing 49 | train(epoch, GTM_SM_model, optimizer, loader_train) 50 | test(epoch, GTM_SM_model, loader_val) 51 | # saving model 52 | if (epoch - 1) % args.save_interval == 0: 53 | fn = 'saves/gtm_sm_state_dict_' + str(epoch) + '.pth' 54 | torch.save(GTM_SM_model.state_dict(), fn) 55 | print('Saved model to ' + fn) 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torchvision 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | import torch.optim as optim 8 | 9 | import time 10 | import numpy as np 11 | import pyflann 12 | import matplotlib.pyplot as plt 13 | import matplotlib.gridspec as gridspec 14 | 15 | from utils.torch_utils import initNetParams, 
ChunkSampler, show_images, device_agnostic_selection 16 | from config import * 17 | from roam import random_walk, random_walk_wo_wall 18 | """implementation of the Generative Temporal Models 19 | with Spatial Memory (GTM-SM) from https://arxiv.org/abs/1804.09401 20 | """ 21 | 22 | class Preprocess_img(nn.Module): 23 | def forward(self, x): 24 | return x * 2 -1  # affine map x -> 2x - 1, e.g. [0, 1] pixels to [-1, 1] 25 | 26 | 27 | class Deprocess_img(nn.Module): 28 | def forward(self, x): 29 | return (x + 1) / 2  # inverse of Preprocess_img: [-1, 1] back to [0, 1] 30 | 31 | 32 | class Flatten(nn.Module): 33 | def forward(self, x): 34 | N, C, H, W = x.size() # unpacking rejects non-4-D input; only N is used 35 | return x.contiguous().view(N, -1) 36 | 37 | 38 | class Exponent(nn.Module): 39 | def forward(self, x): 40 | return torch.exp(x)  # keeps the predicted standard deviations (enc_zt_std) positive 41 | 42 | 43 | class Unflatten(nn.Module): 44 | """ 45 | An Unflatten module receives an input of shape (N, C*H*W) and reshapes it 46 | to produce an output of shape (N, C, H, W). 47 | """ 48 | 49 | def __init__(self, N=-1, C=3, H=8, W=8): 50 | super(Unflatten, self).__init__() 51 | self.N = N 52 | self.C = C 53 | self.H = H 54 | self.W = W 55 | 56 | def forward(self, x): 57 | return x.view(self.N, self.C, self.H, self.W) 58 | 59 | class GTM_SM(nn.Module):  # action-driven 2-D states s_t, per-crop latents z_t, KNN spatial memory over observed steps 60 | def __init__(self, x_dim=8, a_dim=5, s_dim=2, z_dim=16, observe_dim=256, total_dim=288, \ 61 | r_std=0.001, k_nearest_neighbour=5, delta=0.0001, kl_samples=1000, batch_size=1, \ 62 | lambda_for_mat_orth=1000, lambda_for_mat_mag=1000, lambda_for_sigmoid = 10000, \ 63 | training_wo_wall = True, training_sigmoid = False): 64 | super(GTM_SM, self).__init__() 65 | 66 | self.x_dim = x_dim 67 | self.a_dim = a_dim 68 | self.s_dim = s_dim 69 | self.z_dim = z_dim 70 | self.observe_dim = observe_dim 71 | self.z_dim = z_dim  # NOTE(review): repeats the 
assignment above; harmless 72 | self.total_dim = total_dim 73 | self.r_std = r_std 74 | self.k_nearest_neighbour = k_nearest_neighbour 75 | self.delta = delta 76 | self.kl_samples = kl_samples 77 | self.batch_size = batch_size 78 | self.lambda_for_mat_orth = lambda_for_mat_orth 79 | self.lambda_for_mat_mag = lambda_for_mat_mag 80
| self.lambda_for_sigmoid = lambda_for_sigmoid 81 | self.training_wo_wall = training_wo_wall 82 | self.training_sigmoid = training_sigmoid 83 | 84 | self.flanns = pyflann.FLANN()  # nearest-neighbour search over observed states (used in forward) 85 | 86 | # feature-extracting transformations 87 | # encoder 88 | # for zt 89 | self.enc_zt = nn.Sequential( 90 | Preprocess_img(), 91 | nn.Conv2d(3, 64, kernel_size=2, stride=2), 92 | nn.LeakyReLU(0.01), 93 | nn.Conv2d(64, 16, kernel_size=2, stride=2), 94 | nn.LeakyReLU(0.01), 95 | Flatten()  # 8x8 crop -> 16 x 2 x 2 = 64 features 96 | ) 97 | 98 | self.enc_zt_mean = nn.Sequential( 99 | nn.Linear(64, z_dim)) 100 | 101 | self.enc_zt_std = nn.Sequential( 102 | nn.Linear(64, z_dim), 103 | Exponent()) 104 | 105 | # for st 106 | self.enc_st_matrix = nn.Sequential( 107 | nn.Linear(a_dim, s_dim, bias=False))  # one-hot action -> state displacement 108 | 109 | self.enc_st_sigmoid = nn.Sequential(  # gate in [0, 1] on a displacement; used when training_sigmoid is True 110 | nn.Linear(s_dim, 10), 111 | nn.ReLU(), 112 | nn.Linear(10, 5), 113 | nn.ReLU(), 114 | nn.Linear(5, 1), 115 | nn.Sigmoid()) 116 | 117 | # decoder 118 | self.dec = nn.Sequential( 119 | nn.Linear(z_dim, 64), 120 | nn.ReLU(), 121 | Unflatten(-1, 16, 2, 2), 122 | nn.ConvTranspose2d(in_channels=16, out_channels=64, kernel_size=2, stride=2), 123 | nn.ReLU(), 124 | nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=2, stride=2), 125 | nn.Tanh(), 126 | Deprocess_img()) 127 | 128 | def forward(self, x): 129 | if not self.training: 130 | origin_total_dim = self.total_dim 131 | self.total_dim = 512  # eval rolls out a longer trajectory; restored before returning 132 | if len(x.shape) == 3: 133 | x = x.unsqueeze(0)  # a single 3-D image tensor becomes a batch of one 134 | 135 | ''' 136 | action_one_hot_value tensor (self.batch_size, self.a_dim, self.total_dim) 137 | position np (self.batch_size, self.s_dim, self.total_dim) 138 | action_selection np (self.batch_size, 
self.total_dim) 139 | st_observation_list list (self.observe_dim)(self.batch_size, self.s_dim) 140 | st_prediction_list list (self.total_dim - self.observe_dim)(self.batch_size, self.s_dim) 141 | zt_mean_observation_list list (self.observe_dim)(self.batch_size, self.z_dim) 142 | zt_std_observation_list list
(self.observe_dim)(self.batch_size, self.z_dim) 143 | zt_mean_prediction_list list (self.total_dim - self.observe_dim)(self.batch_size, self.z_dim) 144 | zt_std_prediction_list list (self.total_dim - self.observe_dim)(self.batch_size, self.z_dim) 145 | xt_prediction_list list (self.total_dim - self.observe_dim)(self.batch_size, self.x_dim) 146 | xt_ground_true_list list (self.total_dim - self.observe_dim)(self.batch_size, self.x_dim) 147 | 148 | after construct them, we will use torch.cat to eliminate the list object 149 | 150 | st_observation_tensor tensor (self.observe_dim, self.batch_size, self.s_dim) 151 | st_prediction_tensor tensor (self.total_dim - self.observe_dim, self.batch_size, self.s_dim) 152 | zt_mean_observation_tensor tensor (self.observe_dim, self.batch_size, self.z_dim) 153 | zt_std_observation_tensor tensor (self.observe_dim, self.batch_size, self.z_dim) 154 | zt_mean_prediction_tensor tensor (self.total_dim - self.observe_dim, self.batch_size, self.z_dim) 155 | zt_std_prediction_tensor tensor (self.total_dim - self.observe_dim, self.batch_size, self.z_dim) 156 | xt_prediction_tensor tensor (self.total_dim - self.observe_dim, self.batch_size, self.x_dim) 157 | xt_ground_true_tensor tensor (self.total_dim - self.observe_dim, self.batch_size, self.x_dim) 158 | 159 | ''' 160 | if self.training_wo_wall:  # NOTE(review): roam.py returns actions for total_dim - 1 steps, not total_dim as listed above 161 | action_one_hot_value, position, action_selection = random_walk_wo_wall(self) 162 | else: 163 | action_one_hot_value, position, action_selection = random_walk(self) 164 | 165 | st_observation_list = [] 166 | st_prediction_list = [] 167 | zt_mean_observation_list = [] 168 | zt_std_observation_list = [] 169 | zt_mean_prediction_list = [] 170 | zt_std_prediction_list = [] 171 | xt_prediction_list = [] 172 | 173 | kld_loss = 0 174 | nll_loss = 0 175 | 176 | # observation phase: construct st 177 | for t in range(self.observe_dim): 178 | if t == 0: 179 | 
st_observation_t = torch.zeros(self.batch_size, self.s_dim,
device=device)#torch.rand(self.batch_size, self.s_dim, device=device) - 1 180 | else: 181 | replacement = self.enc_st_matrix(action_one_hot_value[:, :, t - 1]) 182 | if not self.training_sigmoid:  # s_t = s_{t-1} + W a_{t-1} + Gaussian noise with std r_std 183 | st_observation_t = st_observation_list[t - 1] + replacement + \ 184 | torch.randn((self.batch_size, self.s_dim), device=device) * self.r_std 185 | else:  # the displacement is gated by enc_st_sigmoid 186 | st_observation_t = st_observation_list[t - 1] + replacement * \ 187 | self.enc_st_sigmoid(st_observation_list[t - 1] + replacement) + \ 188 | torch.randn((self.batch_size, self.s_dim), device=device) * self.r_std 189 | st_observation_list.append(st_observation_t) 190 | st_observation_tensor = torch.cat(st_observation_list, 0).view(self.observe_dim, self.batch_size, self.s_dim) 191 | 192 | 193 | # prediction phase: construct st 194 | for t in range(self.total_dim - self.observe_dim): 195 | if t == 0: 196 | replacement = self.enc_st_matrix(action_one_hot_value[:, :, t + self.observe_dim - 1]) 197 | if not self.training_sigmoid: 198 | st_prediction_t = st_observation_list[-1] + replacement + \ 199 | torch.randn((self.batch_size, self.s_dim), device=device) * self.r_std 200 | else: 201 | st_prediction_t = st_observation_list[-1] + replacement * \ 202 | self.enc_st_sigmoid(st_observation_list[-1] + replacement) + \ 203 | torch.randn((self.batch_size, self.s_dim), device=device) * self.r_std 204 | 205 | else: 206 | replacement = self.enc_st_matrix(action_one_hot_value[:, :, t + self.observe_dim - 1]) 207 | if not self.training_sigmoid: 208 | st_prediction_t = st_prediction_list[t - 1] + replacement + \ 209 | torch.randn((self.batch_size, self.s_dim), device=device) * self.r_std 210 | else: 211 | st_prediction_t = st_prediction_list[t - 1] + replacement * \ 212 | self.enc_st_sigmoid(st_prediction_list[t - 1] + replacement) + \ 213 | torch.randn((self.batch_size, self.s_dim), device=device) * 
self.r_std 214 | 215 | st_prediction_list.append(st_prediction_t) 216 | st_prediction_tensor = torch.cat(st_prediction_list,
0).view(self.total_dim - self.observe_dim, self.batch_size, 217 | self.s_dim) 218 | 219 | # observation phase: construct zt from xt 220 | for t in range(self.observe_dim): 221 | index_mask = torch.zeros((self.batch_size, 3, 32, 32), device=device) 222 | for index_sample in range(self.batch_size): 223 | position_h_t = position[index_sample, 0, t] 224 | position_w_t = position[index_sample, 1, t] 225 | index_mask[index_sample, :, 3 * position_h_t:3 * position_h_t + 8,  # 8x8 crop with top-left corner at 3 * position (positions 0..8 fit a 32x32 image) 226 | 3 * position_w_t:3 * position_w_t + 8] = 1 227 | index_mask_bool = index_mask.ge(0.5) 228 | x_feed = torch.masked_select(x, index_mask_bool).view(-1, 3, 8, 8)  # NOTE(review): assumes x holds exactly self.batch_size images, otherwise the mask does not broadcast 229 | zt_observation_t = self.enc_zt(x_feed) 230 | zt_mean_observation_t = self.enc_zt_mean(zt_observation_t) 231 | zt_std_observation_t = self.enc_zt_std(zt_observation_t) 232 | zt_mean_observation_list.append(zt_mean_observation_t) 233 | zt_std_observation_list.append(zt_std_observation_t) 234 | zt_mean_observation_tensor = torch.cat(zt_mean_observation_list, 0).view(self.observe_dim, self.batch_size, 235 | self.z_dim) 236 | zt_std_observation_tensor = torch.cat(zt_std_observation_list, 0).view(self.observe_dim, self.batch_size, 237 | self.z_dim) 238 | 239 | if self.training: 240 | # prediction phase: construct zt from xt 241 | for t in range(self.total_dim - self.observe_dim): 242 | index_mask = torch.zeros((self.batch_size, 3, 32, 32), device=device) 243 | for index_sample in range(self.batch_size): 244 | position_h_t = position[index_sample, 0, t + self.observe_dim] 245 | position_w_t = position[index_sample, 1, t + self.observe_dim] 246 | index_mask[index_sample, :, 3 * position_h_t:3 * position_h_t + 8, 247 | 3 * position_w_t:3 * position_w_t + 8] = 1 248 | index_mask_bool = index_mask.ge(0.5) 249 | x_feed = torch.masked_select(x, index_mask_bool).view(-1, 3, 8, 8) 250 | 
zt_prediction_t = self.enc_zt(x_feed) 251 | zt_mean_prediction_t = self.enc_zt_mean(zt_prediction_t) 252 | zt_std_prediction_t = self.enc_zt_std(zt_prediction_t) 253
| zt_mean_prediction_list.append(zt_mean_prediction_t) 254 | zt_std_prediction_list.append(zt_std_prediction_t) 255 | zt_mean_prediction_tensor = torch.cat(zt_mean_prediction_list, 0).view(self.total_dim - self.observe_dim, 256 | self.batch_size, self.z_dim) 257 | zt_std_prediction_tensor = torch.cat(zt_std_prediction_list, 0).view(self.total_dim - self.observe_dim, 258 | self.batch_size, self.z_dim) 259 | 260 | # reparameterized_sample to calculate the reconstruct error 261 | for t in range(self.total_dim - self.observe_dim): 262 | zt_prediction_sample = self._reparameterized_sample(zt_mean_prediction_list[t], 263 | zt_std_prediction_list[t]) 264 | index_mask = torch.zeros((self.batch_size, 3, 32, 32), device=device) 265 | for index_sample in range(self.batch_size): 266 | position_h_t = position[index_sample, 0, t + self.observe_dim] 267 | position_w_t = position[index_sample, 1, t + self.observe_dim] 268 | index_mask[index_sample, :, 3 * position_h_t:3 * position_h_t + 8, 269 | 3 * position_w_t:3 * position_w_t + 8] = 1 270 | index_mask_bool = index_mask.ge(0.5) 271 | x_ground_true_t = torch.masked_select(x, index_mask_bool).view(-1, 3, 8, 8) 272 | x_resconstruct_t = self.dec(zt_prediction_sample) 273 | nll_loss += self._nll_gauss(x_resconstruct_t, x_ground_true_t) 274 | xt_prediction_list.append(x_resconstruct_t) 275 | 276 | # construct kd tree 277 | st_observation_memory = st_observation_tensor.cpu().detach().numpy()  # FLANN needs NumPy; the KNN lookup is not differentiated 278 | st_prediction_memory = st_prediction_tensor.cpu().detach().numpy() 279 | 280 | results = [] 281 | for index_sample in range(self.batch_size): 282 | param = self.flanns.build_index(st_observation_memory[:, index_sample, :], algorithm='kdtree', 283 | trees=4) 284 | result, _ = self.flanns.nn_index(st_prediction_memory[:, index_sample, :], 285 | self.k_nearest_neighbour, checks=param["checks"]) 286 | results.append(result)  # (prediction steps, k) indices into the observed steps 287 | 288 | if 
self.training: 289 | # calculate the kld 290 | for index_sample in range(self.batch_size): 291 | knn_index =
results[index_sample] 292 | knn_index_vec = np.reshape(knn_index, (self.k_nearest_neighbour * (self.total_dim - self.observe_dim))) 293 | knn_st_memory = (st_observation_tensor[knn_index_vec, index_sample]).reshape((self.total_dim - self.observe_dim), \ 294 | self.k_nearest_neighbour, -1) 295 | dk2 = ((knn_st_memory.transpose(0, 1) - st_prediction_tensor[:, index_sample, :]) ** 2).sum(2).transpose(0, 1) 296 | wk = 1 / (dk2 + self.delta)  # inverse squared-distance weights; delta avoids division by zero 297 | normalized_wk = (wk.t() / torch.sum(wk, 1)).t() 298 | log_normalized_wk = torch.log(normalized_wk) 299 | zt_sampling = self._reparameterized_sample_cluster(zt_mean_prediction_tensor[:, index_sample], 300 | zt_std_prediction_tensor[:, index_sample]) 301 | log_q_phi = - 0.5 * self.z_dim * torch.log(torch.tensor(2 * 3.1415926535, device = device)) - \ 302 | 0.5 * self.z_dim - torch.log(zt_std_prediction_tensor[:, index_sample]).sum(1)  # E_q[log q] in closed form (negative entropy of a diagonal Gaussian) 303 | zt_mean_knn_tensor = zt_mean_observation_tensor[knn_index_vec, index_sample].reshape( 304 | (self.total_dim - self.observe_dim), self.k_nearest_neighbour, -1) 305 | zt_std_knn_tensor = zt_std_observation_tensor[knn_index_vec, index_sample].reshape( 306 | (self.total_dim - self.observe_dim), self.k_nearest_neighbour, -1) 307 | 308 | log_p_theta_element = self._log_gaussian_element_pdf(zt_sampling, zt_mean_knn_tensor, zt_std_knn_tensor) + \ 309 | log_normalized_wk 310 | (log_p_theta_element_max, _) = torch.max(log_p_theta_element, 2)  # log-sum-exp over the k neighbours, shifted by the max for stability 311 | log_p_theta_element_nimus_max = (log_p_theta_element.transpose(1, 2).transpose(0, 1) - log_p_theta_element_max) 312 | p_theta_nimus_max = torch.exp(log_p_theta_element_nimus_max).sum(0) 313 | kld_loss += torch.mean(log_q_phi - torch.mean(log_p_theta_element_max + torch.log(p_theta_nimus_max), 0))  # Monte-Carlo KL(q || weighted KNN mixture), averaged over prediction steps 314 | 
else:  # eval: decode a z sampled from one neighbour's posterior per prediction step 315 | xt_prediction_tensor = torch.zeros(self.total_dim - self.observe_dim, self.batch_size, 3, 8, 8, 316 | device=device) 317 | for index_sample in range(self.batch_size): 318 | knn_index = results[index_sample] 319 | knn_index_vec = np.reshape(knn_index,
(self.k_nearest_neighbour * (self.total_dim - self.observe_dim))) 320 | knn_st_memory = (st_observation_tensor[knn_index_vec, index_sample]).reshape( 321 | (self.total_dim - self.observe_dim), \ 322 | self.k_nearest_neighbour, -1) 323 | dk2 = ((knn_st_memory.transpose(0, 1) - st_prediction_tensor[:, index_sample, :]) ** 2).sum( 324 | 2).transpose(0, 1) 325 | wk = 1 / (dk2 + self.delta) 326 | normalized_wk = (wk.t() / torch.sum(wk, 1)).t() 327 | cumsum_normalized_wk = torch.cumsum(normalized_wk, dim=1)  # inverse-CDF sampling of one neighbour per step 328 | rand_sample_value = torch.rand((self.total_dim - self.observe_dim, 1), device=device) 329 | bool_index_list = cumsum_normalized_wk + torch.tensor(1e-7).to(device=device) <= rand_sample_value 330 | knn_sample_index = bool_index_list.sum(1)  # number of CDF entries below the draw = index of the sampled neighbour 331 | zt_sampling = self._reparameterized_sample( 332 | zt_mean_observation_tensor[knn_index[range(self.total_dim - self.observe_dim), knn_sample_index], index_sample], 333 | zt_std_observation_tensor[knn_index[range(self.total_dim - self.observe_dim), knn_sample_index], index_sample]) 334 | xt_prediction_tensor[:, index_sample] = self.dec(zt_sampling) 335 | 336 | # calculate the reconstruct error 337 | for t in range(self.total_dim - self.observe_dim): 338 | index_mask = torch.zeros((self.batch_size, 3, 32, 32), device=device) 339 | for index_sample in range(self.batch_size): 340 | position_h_t = position[index_sample, 0, t + self.observe_dim] 341 | position_w_t = position[index_sample, 1, t + self.observe_dim] 342 | index_mask[index_sample, :, 3 * position_h_t:3 * position_h_t + 8, 343 | 3 * position_w_t:3 * position_w_t + 8] = 1 344 | index_mask_bool = index_mask.ge(0.5) 345 | x_ground_true_t = torch.masked_select(x, index_mask_bool).view(-1, 3, 8, 8) 346 | nll_loss += self._nll_gauss(xt_prediction_tensor[t], x_ground_true_t) 347 | 348 | 349 | # reparameterized_sample to calculate the reconstruct error 350 | for t in 
range(self.total_dim - self.observe_dim): 351 | xt_prediction_list.append(xt_prediction_tensor[t]) 352 | 353
| if not self.training: 354 | self.total_dim = origin_total_dim  # NOTE(review): not restored if an exception is raised earlier in forward 355 | 356 | matrix_loss = self._matrix_loss() 357 | 358 | return kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position 359 | 360 | def _log_gaussian_pdf(self, zt, zt_mean, zt_std):  # NOTE(review): not called in model.py; sums the exponent over dim 2 but log-std over dim 1, check before reuse 361 | constant_value = torch.tensor(2 * 3.1415926535, device = device) 362 | log_exp_term = - torch.sum((((zt - zt_mean) ** 2) / (zt_std ** 2) / 2.0), 2) 363 | log_other_term = - (self.z_dim / 2.0) * torch.log(constant_value) - torch.sum(torch.log(zt_std), 1) 364 | return log_exp_term + log_other_term 365 | 366 | def _log_gaussian_element_pdf(self, zt, zt_mean, zt_std): 367 | constant_value = torch.tensor(2 * 3.1415926535, device = device) 368 | zt_repeat = zt.unsqueeze(2).repeat(1, 1, self.k_nearest_neighbour, 1)  # (kl_samples, steps, z_dim) -> one copy per neighbour 369 | log_exp_term = - torch.sum((((zt_repeat - zt_mean) ** 2) / (zt_std ** 2) / 2.0), 3) 370 | log_other_term = - (self.z_dim / 2.0) * torch.log(constant_value) - torch.sum(torch.log(zt_std), 2) 371 | return log_exp_term + log_other_term 372 | 373 | def reset_parameters(self, stdv=1e-1): 374 | for weight in self.parameters(): 375 | weight.data.normal_(0, stdv) 376 | 377 | def _init_weights(self, stdv): 378 | for weight in self.parameters(): 379 | weight.data.normal_(0, stdv) 380 | 381 | def _reparameterized_sample(self, mean, std): 382 | """Return mean + std * eps with eps ~ N(0, I) (reparameterization trick).""" 383 | eps = torch.randn_like(std, device = device) 384 | return eps.mul(std).add(mean) 385 | 386 | 387 | def _reparameterized_sample_cluster(self, mean, std): 388 | """Draw kl_samples reparameterized samples for every prediction step.""" 389 | eps = torch.randn((self.kl_samples, self.total_dim - self.observe_dim, self.z_dim), device=device) 390 | return 
eps.mul(std).add(mean) 391 | 392 | 393 | def _kld_gauss(self, mean_1, std_1, mean_2, std_2): 394 | """KL(N(mean_1, std_1^2) || N(mean_2, std_2^2)), summed over all elements.""" 395 | kld_element = (2 * torch.log(std_2) - 2 * torch.log(std_1) + 396 | (std_1.pow(2) + (mean_1 - mean_2).pow(2)) / 397 | std_2.pow(2) - 1) 398 | return 0.5 *
torch.sum(kld_element) 399 | 400 | def _nll_bernoulli(self, theta, x): 401 | return - torch.sum(x * torch.log(theta) + (1 - x) * torch.log(1 - theta)) 402 | 403 | def _nll_gauss(self, x, mean): 404 | # sum of squared errors (a Gaussian NLL up to scale and constants) 405 | return torch.sum((x - mean) ** 2) 406 | 407 | def _matrix_loss(self): 408 | # regularizer on the action -> displacement weight of enc_st_matrix 409 | for param in self.enc_st_matrix.parameters(): 410 | matrix_loss = self.lambda_for_mat_orth * torch.sum(torch.sum(param[:, 0:1] * param[:, 2:3], 0) ** 2)  # push columns 0 (right) and 2 (up) to be orthogonal 411 | for index in range(4): 412 | matrix_loss += self.lambda_for_mat_mag * (torch.norm(param[:, index]) - 1.5/8) ** 2  # pull the norm of each movement column (actions 0-3) towards 1.5/8 413 | return matrix_loss 414 | 415 | def _enc_st_sigmoid_forward(self, X_train): 416 | Y_predict = self.enc_st_sigmoid(X_train) 417 | return Y_predict -------------------------------------------------------------------------------- /roam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from config import * 4 | 5 | 6 | def random_walk_wo_wall(model):  # random walk of the crop position on a 0..8 grid; a blocked move is recorded as action 4 (stay) 7 | # construct position and action 8 | action_one_hot_value_numpy = np.zeros((model.batch_size, model.a_dim, model.total_dim - 1), np.float32) 9 | position = np.zeros((model.batch_size, model.s_dim, model.total_dim), np.int32) 10 | action_selection = np.zeros((model.batch_size, model.total_dim - 1), np.int32) 11 | for index_sample in range(model.batch_size): 12 | new_continue_action_flag = True 13 | for t in range(model.total_dim): 14 | if t == 0: 15 | #position[index_sample, :, t] = np.random.randint(0, 9, size=(2)) 16 | position[index_sample, :, t] = np.ones(2) * 4  # start at the centre of the 0..8 grid 17 | else: 18 | if new_continue_action_flag: 19 | 
new_continue_action_flag = False 20 | need_to_stop = False 21 | 22 | while 1:  # resample until the first step of the move stays on the grid 23 | action_random_selection = np.random.randint(0, 4, size=(1))  # NOTE(review): a size-1 array, later compared and stored as a scalar 24 | if not (action_random_selection == 0 and position[index_sample, 1, t - 1] == 8): 25 | if not (action_random_selection == 1 and position[index_sample, 1, t - 1] == 0): 26 | if not
(action_random_selection == 2 and position[index_sample, 0, t - 1] == 0): 27 | if not (action_random_selection == 3 and position[index_sample, 0, t - 1] == 8): 28 | break 29 | 30 | #action_random_selection = np.random.randint(0, 5, size=(1)) 31 | action_duriation = np.random.poisson(2, 1)  # the move lasts a Poisson(2)-distributed number of steps 32 | 33 | if action_duriation > 0: 34 | if not need_to_stop: 35 | if action_random_selection == 0:  # actions: 0 -> w+1, 1 -> w-1, 2 -> h-1, 3 -> h+1, 4 = stay 36 | if position[index_sample, 1, t - 1] == 8: 37 | need_to_stop = True 38 | position[index_sample, :, t] = position[index_sample, :, t - 1] 39 | action_selection[index_sample, t - 1] = 4 40 | else: 41 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([0, 1]) 42 | action_selection[index_sample, t - 1] = action_random_selection 43 | elif action_random_selection == 1: 44 | if position[index_sample, 1, t - 1] == 0: 45 | need_to_stop = True 46 | position[index_sample, :, t] = position[index_sample, :, t - 1] 47 | action_selection[index_sample, t - 1] = 4 48 | else: 49 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([0, -1]) 50 | action_selection[index_sample, t - 1] = action_random_selection 51 | elif action_random_selection == 2: 52 | if position[index_sample, 0, t - 1] == 0: 53 | need_to_stop = True 54 | position[index_sample, :, t] = position[index_sample, :, t - 1] 55 | action_selection[index_sample, t - 1] = 4 56 | else: 57 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([-1, 0]) 58 | action_selection[index_sample, t - 1] = action_random_selection 59 | elif action_random_selection == 3: 60 | if position[index_sample, 0, t - 1] == 8: 61 | need_to_stop = True 62 | position[index_sample, :, t] = position[index_sample, :, t - 1] 63 | action_selection[index_sample, t - 1] = 4 64 | else: 65 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([1, 0]) 66 | 
action_selection[index_sample, t - 1] = action_random_selection 67 | else: 68 | position[index_sample, :, t] =
position[index_sample, :, t - 1] 69 | action_selection[index_sample, t - 1] = 4 70 | action_duriation -= 1 71 | else: 72 | action_selection[index_sample, t - 1] = 4 73 | position[index_sample, :, t] = position[index_sample, :, t - 1] 74 | if action_duriation <= 0: 75 | new_continue_action_flag = True 76 | 77 | 78 | for index_sample in range(model.batch_size): 79 | action_one_hot_value_numpy[ 80 | index_sample, action_selection[index_sample], np.array(range(model.total_dim - 1))] = 1  # one-hot encode the action index along the a_dim axis 81 | 82 | action_one_hot_value = torch.from_numpy(action_one_hot_value_numpy).to(device=device) 83 | 84 | return action_one_hot_value, position, action_selection 85 | 86 | 87 | def random_walk(model):  # same walk as random_walk_wo_wall, but the attempted action is recorded even when a wall blocks it 88 | # construct position and action 89 | action_one_hot_value_numpy = np.zeros((model.batch_size, model.a_dim, model.total_dim - 1), np.float32) 90 | position = np.zeros((model.batch_size, model.s_dim, model.total_dim), np.int32) 91 | action_selection = np.zeros((model.batch_size, model.total_dim - 1), np.int32) 92 | for index_sample in range(model.batch_size): 93 | new_continue_action_flag = True 94 | for t in range(model.total_dim): 95 | if t == 0: 96 | #position[index_sample, :, t] = np.random.randint(0, 9, size=(2)) 97 | position[index_sample, :, t] = np.ones(2) * 4 98 | else: 99 | if new_continue_action_flag: 100 | new_continue_action_flag = False 101 | need_to_stop = False 102 | 103 | while 1: 104 | action_random_selection = np.random.randint(0, 4, size=(1)) 105 | if not (action_random_selection == 0 and position[index_sample, 1, t - 1] == 8): 106 | if not (action_random_selection == 1 and position[index_sample, 1, t - 1] == 0): 107 | if not (action_random_selection == 2 and position[index_sample, 0, t - 1] == 0): 108 | if not (action_random_selection == 3 and position[index_sample, 0, t - 1] == 8): 109 | break 110 | 111 | #action_random_selection = 
np.random.randint(0, 5, size=(1)) 112 | action_duriation = np.random.poisson(2, 1) 113 | 114 | if action_duriation > 0: 115 | if not
need_to_stop: 116 | if action_random_selection == 0: 117 | if position[index_sample, 1, t - 1] == 8: 118 | need_to_stop = True 119 | position[index_sample, :, t] = position[index_sample, :, t - 1] 120 | else: 121 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([0, 1]) 122 | elif action_random_selection == 1: 123 | if position[index_sample, 1, t - 1] == 0: 124 | need_to_stop = True 125 | position[index_sample, :, t] = position[index_sample, :, t - 1] 126 | else: 127 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([0, -1]) 128 | elif action_random_selection == 2: 129 | if position[index_sample, 0, t - 1] == 0: 130 | need_to_stop = True 131 | position[index_sample, :, t] = position[index_sample, :, t - 1] 132 | else: 133 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([-1, 0]) 134 | elif action_random_selection == 3: 135 | if position[index_sample, 0, t - 1] == 8: 136 | need_to_stop = True 137 | position[index_sample, :, t] = position[index_sample, :, t - 1] 138 | else: 139 | position[index_sample, :, t] = position[index_sample, :, t - 1] + np.array([1, 0]) 140 | else: 141 | position[index_sample, :, t] = position[index_sample, :, t - 1] 142 | action_duriation -= 1 143 | action_selection[index_sample, t - 1] = action_random_selection  # attempted action, recorded even when the wall blocked the move 144 | else: 145 | action_selection[index_sample, t - 1] = 4 146 | position[index_sample, :, t] = position[index_sample, :, t - 1] 147 | if action_duriation <= 0: 148 | new_continue_action_flag = True 149 | 150 | for index_sample in range(model.batch_size): 151 | action_one_hot_value_numpy[ 152 | index_sample, action_selection[index_sample], np.array(range(model.total_dim - 1))] = 1 153 | 154 | action_one_hot_value = torch.from_numpy(action_one_hot_value_numpy).to(device=device) 155 | 156 | return action_one_hot_value, position, action_selection 157 | 158 | 159 | def sample_position(model):  # training data for enc_st_sigmoid: grid 
offsets projected by the learned right/up columns, labelled 1 inside the +-4.5 box 160 | for params in model.enc_st_matrix.parameters(): 161 |
enc_st_matrix_params = params  # enc_st_matrix holds a single bias-free weight 162 | 163 | right_vector = enc_st_matrix_params[:, 0].detach().cpu().numpy() 164 | up_vector = enc_st_matrix_params[:, 2].detach().cpu().numpy() 165 | 166 | project_matrix = np.hstack((right_vector.reshape((2, 1)), up_vector.reshape((2, 1)))) 167 | 168 | x, y = np.meshgrid(np.arange(-8.,9., dtype=np.float32), np.arange(-8.,9., dtype=np.float32)) 169 | pos_index = np.vstack((x.reshape((1, -1)), y.reshape((1, -1)))) 170 | 171 | X_train = np.dot(project_matrix, pos_index).T 172 | Y_train = np.zeros((X_train.shape[0], 1), np.float32) 173 | true_index = np.logical_and(pos_index[0] <= 4.5, pos_index[0] >= -4.5)  # label 1 for grid offsets inside the +-4.5 box 174 | true_index = np.logical_and(true_index, pos_index[1] <= 4.5) 175 | true_index = np.logical_and(true_index, pos_index[1] >= -4.5) 176 | Y_train[true_index,:] = 1 177 | 178 | X_train = torch.from_numpy(X_train).to(device=device) 179 | Y_train = torch.from_numpy(Y_train).to(device=device) 180 | 181 | return X_train, Y_train 182 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torchvision 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data import sampler 10 | import torchvision.datasets as dset 11 | 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | import matplotlib.gridspec as gridspec 15 | 16 | from utils.torch_utils import initNetParams, ChunkSampler, show_images, device_agnostic_selection 17 | from model import GTM_SM 18 | from config import * 19 | from show_results import show_experiment_information 20 | 21 | plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots 22 | 
plt.rcParams['image.interpolation'] = 'nearest' 23 | plt.rcParams['image.cmap'] = 'gray' 24 | 
25 | #load data 26 | data_transform = T.Compose([ 27 | T.Resize((32, 32)), 28 | T.ToTensor(), 29 | ]) 30 | testing_dataset = dset.ImageFolder(root='./datasets/CelebA/testing', 31 | transform=data_transform) 32 | loader_val = DataLoader(testing_dataset, batch_size=args.batch_size, shuffle=True) 33 | 34 | path_to_load = 'saves/GTM_SM_state_dict.pth' 35 | if torch.cuda.is_available(): 36 | state_dict = torch.load(path_to_load) 37 | else: 38 | state_dict = torch.load(path_to_load, map_location=lambda storage, loc: storage) 39 | GTM_SM_model = GTM_SM(batch_size = args.batch_size) 40 | GTM_SM_model.load_state_dict(state_dict) 41 | GTM_SM_model.to(device=device) 42 | 43 | def sample(): 44 | GTM_SM_model.eval() 45 | with torch.no_grad(): 46 | for batch_idx, (data, _) in enumerate(loader_val): 47 | 48 | #transforming data 49 | training_data = data.to(device=device) 50 | #forward 51 | kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position = GTM_SM_model(training_data) 52 | 53 | show_experiment_information(GTM_SM_model, data, st_observation_list, st_prediction_list, xt_prediction_list, position) 54 | 55 | sample() -------------------------------------------------------------------------------- /saves/gtm_sm_state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/saves/gtm_sm_state_dict.pth -------------------------------------------------------------------------------- /show_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.gridspec as gridspec 4 | 5 | def show_experiment_information(model, x, st_observation_list, st_prediction_list, xt_prediction_list, position): 6 | if not model.training: 7 | origin_total_dim = model.total_dim 8 | 
model.total_dim = 512 9 | 10 | if len(x.shape) == 3: 11 | x = x.unsqueeze(0) 12 | sample_id = np.random.randint(0, model.batch_size, size=(1)) 13 | sample_imgs = x[sample_id] 14 | 15 | st_observation_sample = np.zeros((model.observe_dim, model.s_dim)) 16 | for t in range(model.observe_dim): 17 | st_observation_sample[t] = st_observation_list[t][sample_id].cpu().detach().numpy() 18 | 19 | st_prediction_sample = np.zeros((model.total_dim - model.observe_dim, model.s_dim)) 20 | for t in range(model.total_dim - model.observe_dim): 21 | st_prediction_sample[t] = st_prediction_list[t][sample_id].cpu().detach().numpy() 22 | 23 | st_2_max = np.maximum(np.max(st_observation_sample[:, 0]), np.max(st_prediction_sample[:, 0])) 24 | st_2_min = np.minimum(np.min(st_observation_sample[:, 0]), np.min(st_prediction_sample[:, 0])) 25 | st_1_max = np.maximum(np.max(st_observation_sample[:, 1]), np.max(st_prediction_sample[:, 1])) 26 | st_1_min = np.minimum(np.min(st_observation_sample[:, 1]), np.min(st_prediction_sample[:, 1])) 27 | axis_st_1_max = st_1_max + (st_1_max - st_1_min) / 10.0 28 | axis_st_1_min = st_1_min - (st_1_max - st_1_min) / 10.0 29 | axis_st_2_max = st_2_max + (st_2_max - st_2_min) / 10.0 30 | axis_st_2_min = st_2_min - (st_2_max - st_2_min) / 10.0 31 | 32 | fig = plt.figure() 33 | # interaction mode 34 | plt.ion() 35 | 36 | # observation phase 37 | for t in range(0, model.observe_dim): 38 | position_h_t = np.asscalar(position[sample_id, 0, t]) 39 | position_w_t = np.asscalar(position[sample_id, 1, t]) 40 | sample_imgs_t = np.copy(sample_imgs.cpu().detach().numpy()) 41 | observed_img = np.copy(sample_imgs[:, :, 3 * position_h_t: 3 * position_h_t + 8, 42 | 3 * position_w_t: 3 * position_w_t + 8].cpu().detach().numpy()) 43 | 44 | sample_imgs_t[0, 0, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 1.0 45 | sample_imgs_t[0, 0, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t + 8 - 1] = 1.0 46 | sample_imgs_t[0, 0, 3 * position_h_t, 3 * position_w_t: 
3 * position_w_t + 8] = 1.0 47 | sample_imgs_t[0, 0, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 1.0 48 | sample_imgs_t[0, 1, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 0.0 49 | sample_imgs_t[0, 1, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t + 8 - 1] = 0.0 50 | sample_imgs_t[0, 1, 3 * position_h_t, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 51 | sample_imgs_t[0, 1, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 52 | sample_imgs_t[0, 2, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 0.0 53 | sample_imgs_t[0, 2, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t + 8 - 1] = 0.0 54 | sample_imgs_t[0, 2, 3 * position_h_t, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 55 | sample_imgs_t[0, 2, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 56 | 57 | fig.clf() 58 | 59 | plt.suptitle('t = ' + str(t) + '\n' + 'OBSERVATION PHASE', fontsize=25) 60 | 61 | gs = gridspec.GridSpec(20, 20) 62 | 63 | # subfigure 1 64 | ax1 = plt.subplot(gs[1:10, 1:10]) 65 | ax1.set_xticklabels([]) 66 | ax1.set_yticklabels([]) 67 | ax1.set_aspect('equal') 68 | plt.axis('off') 69 | plt.imshow(sample_imgs_t.reshape([3, 32, 32]).transpose((1, 2, 0))) 70 | 71 | # subfigure 2 72 | ax2 = plt.subplot(gs[4:8, 11:15]) 73 | ax2.set_xticklabels([]) 74 | ax2.set_yticklabels([]) 75 | ax2.set_aspect('equal') 76 | ax2.set_title('Observation') 77 | plt.axis('off') 78 | plt.imshow(observed_img.reshape([3, 8, 8]).transpose((1, 2, 0))) 79 | 80 | # subfigure 3 81 | 82 | # subfigure 4 83 | ax4 = plt.subplot(gs[11:20, 1:10]) 84 | ax4.set_xlabel('x') 85 | ax4.set_ylabel('y') 86 | ax4.set_title('True states') 87 | ax4.set_aspect('equal') 88 | plt.axis([-1, 9, -1, 9]) 89 | plt.gca().invert_yaxis() 90 | plt.plot(position[sample_id, 1, 0: t + 1].T, position[sample_id, 0, 0: t + 1].T, color='k', 91 | linestyle='solid', marker='o') 92 | plt.plot(position[sample_id, 1, t], position[sample_id, 0, t], 'bs') 
93 | 94 | # subfigure 5 95 | ax5 = plt.subplot(gs[11:20, 11:20]) 96 | ax5.set_xlabel('$s_1$') 97 | ax5.set_ylabel('$s_2$') 98 | ax5.set_title('Inferred states') 99 | plt.axis([axis_st_1_min, axis_st_1_max, axis_st_2_min, axis_st_2_max]) 100 | plt.gca().invert_yaxis() 101 | plt.plot(st_observation_sample[0: t + 1, 1], st_observation_sample[0: t + 1, 0], color='k', 102 | linestyle='solid', marker='o') 103 | plt.plot(st_observation_sample[t, 1], st_observation_sample[t, 0], 'bs') 104 | 105 | plt.pause(0.01) 106 | 107 | # predition phase 108 | for t in range(model.total_dim - model.observe_dim): 109 | position_h_t = np.asscalar(position[sample_id, 0, t + model.observe_dim]) 110 | position_w_t = np.asscalar(position[sample_id, 1, t + model.observe_dim]) 111 | sample_imgs_t = np.copy(sample_imgs.cpu().detach().numpy()) 112 | observed_img = np.copy(sample_imgs[:, :, 3 * position_h_t: 3 * position_h_t + 8, 113 | 3 * position_w_t: 3 * position_w_t + 8].cpu().detach().numpy()) 114 | predict_img = xt_prediction_list[t][np.asscalar(sample_id)].cpu().detach().numpy() 115 | 116 | sample_imgs_t[0, 0, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 1.0 117 | sample_imgs_t[0, 0, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t + 8 - 1] = 1.0 118 | sample_imgs_t[0, 0, 3 * position_h_t, 3 * position_w_t: 3 * position_w_t + 8] = 1.0 119 | sample_imgs_t[0, 0, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 1.0 120 | sample_imgs_t[0, 1, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 0.0 121 | sample_imgs_t[0, 1, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t + 8 - 1] = 0.0 122 | sample_imgs_t[0, 1, 3 * position_h_t, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 123 | sample_imgs_t[0, 1, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 124 | sample_imgs_t[0, 2, 3 * position_h_t: 3 * position_h_t + 8, 3 * position_w_t] = 0.0 125 | sample_imgs_t[0, 2, 3 * position_h_t: 3 * position_h_t + 8, 3 * 
position_w_t + 8 - 1] = 0.0 126 | sample_imgs_t[0, 2, 3 * position_h_t, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 127 | sample_imgs_t[0, 2, 3 * position_h_t + 8 - 1, 3 * position_w_t: 3 * position_w_t + 8] = 0.0 128 | 129 | fig.clf() 130 | 131 | plt.suptitle('t = ' + str(t + model.observe_dim) + '\n' + 'PREDICTION PHASE', fontsize=25) 132 | 133 | gs = gridspec.GridSpec(20, 20) 134 | ''' 135 | ax1 = plt.subplot(gs[0:4, 0:4]) 136 | ax2 = plt.subplot(gs[1:3, 4:6]) 137 | ax3 = plt.subplot(gs[1:3, 6:8]) 138 | ax4 = plt.subplot(gs[4:8, 0:4]) 139 | ax5 = plt.subplot(gs[4:8, 4:8]) 140 | ''' 141 | # subfigure 1 142 | ax1 = plt.subplot(gs[1:10, 1:10]) 143 | ax1.set_xticklabels([]) 144 | ax1.set_yticklabels([]) 145 | ax1.set_aspect('equal') 146 | plt.axis('off') 147 | plt.imshow(sample_imgs_t.reshape([3, 32, 32]).transpose((1, 2, 0))) 148 | 149 | # subfigure 2 150 | ax2 = plt.subplot(gs[4:8, 11:15]) 151 | ax2.set_xticklabels([]) 152 | ax2.set_yticklabels([]) 153 | ax2.set_aspect('equal') 154 | ax2.set_title('Observation') 155 | plt.axis('off') 156 | plt.imshow(observed_img.reshape([3, 8, 8]).transpose((1, 2, 0))) 157 | 158 | # subfigure 3 159 | ax3 = plt.subplot(gs[4:8, 16:20]) 160 | ax3.set_xticklabels([]) 161 | ax3.set_yticklabels([]) 162 | ax3.set_aspect('equal') 163 | ax3.set_title('Prediction') 164 | plt.axis('off') 165 | plt.imshow(predict_img.reshape([3, 8, 8]).transpose((1, 2, 0))) 166 | 167 | # subfigure 4 168 | ax4 = plt.subplot(gs[11:20, 1:10]) 169 | ax4.set_xlabel('x') 170 | ax4.set_ylabel('y') 171 | ax4.set_title('True states') 172 | ax4.set_aspect('equal') 173 | plt.axis([-1, 9, -1, 9]) 174 | plt.gca().invert_yaxis() 175 | plt.plot(position[sample_id, 1, 0: model.observe_dim + 1].T, 176 | position[sample_id, 0, 0: model.observe_dim + 1].T, color='k', linestyle='solid', marker='o') 177 | plt.plot(position[sample_id, 1, t + model.observe_dim], position[sample_id, 0, t + model.observe_dim], 'bs') 178 | 179 | # subfigure 5 180 | ax5 = plt.subplot(gs[11:20, 
11:20]) 181 | ax5.set_xlabel('$s_1$') 182 | ax5.set_ylabel('$s_2$') 183 | ax5.set_title('Inferred states') 184 | plt.axis([axis_st_1_min, axis_st_1_max, axis_st_2_min, axis_st_2_max]) 185 | plt.gca().invert_yaxis() 186 | plt.plot(st_observation_sample[:, 1], st_observation_sample[:, 0], color='k', linestyle='solid', marker='o') 187 | plt.plot(st_prediction_sample[t, 1], st_prediction_sample[t, 0], 'bs') 188 | 189 | plt.pause(0.01) 190 | 191 | # show figure 192 | plt.show() 193 | 194 | # close figure 195 | plt.close(fig) 196 | 197 | # close interaction mode 198 | plt.ioff() 199 | 200 | if not model.training: 201 | model.total_dim = origin_total_dim -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torchvision 5 | import argparse 6 | import torch.nn.functional as F 7 | import torchvision.transforms as T 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data import sampler 11 | import torchvision.datasets as dset 12 | 13 | import os 14 | import time 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import matplotlib.gridspec as gridspec 18 | 19 | from multiprocessing import Process 20 | 21 | from utils.torch_utils import initNetParams, ChunkSampler, show_images, device_agnostic_selection, show_heatmap 22 | from model import GTM_SM 23 | from config import * 24 | from show_results import show_experiment_information 25 | from roam import sample_position 26 | 27 | def train(epoch, model, optimizer, loader_train): 28 | model.train() 29 | BCE_loss = nn.BCELoss() 30 | 31 | train_loss = 0 32 | bce_loss = 0 33 | wo_wall_end_epoch = 10 34 | wo_wall_and_penalty_end_epoch = 15 35 | 36 | for batch_idx, (data, _) in enumerate(loader_train): 37 | 38 | if epoch < wo_wall_end_epoch: 39 | if epoch == 1 and batch_idx 
== 0: 40 | optimizer.__init__( 41 | [{'params': model.enc_zt.parameters()}, 42 | {'params': model.enc_zt_mean.parameters()}, 43 | {'params': model.enc_zt_std.parameters()}, 44 | {'params': model.enc_st_matrix.parameters()}, 45 | {'params': model.dec.parameters()}, 46 | {'params': model.enc_st_sigmoid.parameters(), 'lr': 1e-2}], 47 | lr=1e-3) 48 | 49 | # transforming data 50 | training_data = data.to(device=device) 51 | 52 | # forward + backward + optimize 53 | optimizer.zero_grad() 54 | kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position = model.forward( 55 | training_data) 56 | 57 | train_loss += (nll_loss + kld_loss).item() 58 | loss_to_optimize = (nll_loss + kld_loss) / args.batch_size 59 | loss_to_optimize.backward() 60 | 61 | elif epoch >= wo_wall_end_epoch and epoch < wo_wall_and_penalty_end_epoch: 62 | if epoch == wo_wall_end_epoch and batch_idx == 0: 63 | optimizer.__init__( [{'params': model.enc_zt.parameters()}, 64 | {'params': model.enc_zt_mean.parameters()}, 65 | {'params': model.enc_zt_std.parameters()}, 66 | {'params': model.enc_st_matrix.parameters()}, 67 | {'params': model.dec.parameters()}, 68 | {'params': model.enc_st_sigmoid.parameters(), 'lr': 1e-2}], 69 | lr=1e-3) 70 | 71 | # transforming data 72 | training_data = data.to(device=device) 73 | 74 | # forward + backward + optimize 75 | optimizer.zero_grad() 76 | kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position = model.forward( 77 | training_data) 78 | 79 | train_loss += (nll_loss + kld_loss).item() 80 | loss_to_optimize = (nll_loss + kld_loss) / args.batch_size + matrix_loss 81 | loss_to_optimize.backward() 82 | 83 | else: 84 | if epoch == wo_wall_and_penalty_end_epoch and batch_idx == 0: 85 | model.training_wo_wall = False 86 | model.training_sigmoid = True 87 | optimizer.__init__( 88 | [{'params': model.enc_zt.parameters()}, 89 | {'params': model.enc_zt_mean.parameters()}, 90 | {'params': 
model.enc_zt_std.parameters()}, 91 | {'params': model.enc_st_matrix.parameters()}, 92 | {'params': model.dec.parameters()}, 93 | {'params': model.enc_st_sigmoid.parameters(), 'lr': 1e-2}], 94 | lr=5e-4) 95 | 96 | # transforming data 97 | training_data = data.to(device=device) 98 | 99 | # forward + backward + optimize 100 | optimizer.zero_grad() 101 | kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position = model.forward( 102 | training_data) 103 | X_train, Y_train = sample_position(model) 104 | bce_loss = BCE_loss(model._enc_st_sigmoid_forward(X_train), Y_train) 105 | 106 | train_loss += (nll_loss + kld_loss).item() 107 | loss_to_optimize = (nll_loss + kld_loss) / args.batch_size + matrix_loss + model.lambda_for_sigmoid * bce_loss 108 | loss_to_optimize.backward() 109 | 110 | # grad norm clipping, only in pytorch version >= 1.10 111 | #nn.utils.clip_grad_norm_(GTM_SM_model.parameters(), args.gradient_clip) 112 | 113 | optimizer.step() 114 | 115 | # printing 116 | if epoch < wo_wall_and_penalty_end_epoch : 117 | if batch_idx % args.log_interval == 0: 118 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f} \t MATRIX Lose: {:.6f}'.format( 119 | epoch, batch_idx * len(data), len(loader_train.dataset), 120 | 100. * batch_idx * len(data) / len(loader_train.dataset), 121 | kld_loss.item() / len(data), 122 | nll_loss.item() / len(data), 123 | matrix_loss.item())) 124 | else: 125 | if batch_idx % args.log_interval == 0: 126 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f} \t MATRIX Lose: {:.6f} \t Sigmoid Lose: {:.6f}'.format( 127 | epoch, batch_idx * len(data), len(loader_train.dataset), 128 | 100. 
* batch_idx * len(data) / len(loader_train.dataset), 129 | kld_loss.item() / len(data), 130 | nll_loss.item() / len(data), 131 | matrix_loss.item(), 132 | bce_loss.item())) 133 | 134 | print('====> Epoch: {} Average loss: {:.4f}'.format( 135 | epoch, train_loss / len(loader_train.dataset))) 136 | 137 | 138 | def test(epoch, model, loader_val): 139 | model.eval() 140 | test_loss = 0 141 | with torch.no_grad(): 142 | for i, (data, _) in enumerate(loader_val): 143 | data = data.to(device=device) 144 | kld_loss, nll_loss, matrix_loss, st_observation_list, st_prediction_list, xt_prediction_list, position = model.forward( 145 | data) 146 | test_loss += nll_loss 147 | 148 | test_loss /= len(loader_val.dataset) 149 | print('====> Test set loss: {:.4f}'.format(test_loss)) 150 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/utils/__init__.py -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from torch.utils.data import sampler 5 | import argparse 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import matplotlib.gridspec as gridspec 10 | 11 | def initNetParams(net): 12 | '''Init net parameters.''' 13 | for m in net.modules(): 14 | if isinstance(m, nn.Conv2d): 15 | init.xavier_uniform_(m.weight) 16 | if m.bias is not None: 17 | init.constant_(m.bias, 0) 18 | elif isinstance(m, nn.BatchNorm1d): 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | elif isinstance(m, nn.BatchNorm2d): 22 | init.constant_(m.weight, 1) 23 | init.constant_(m.bias, 0) 24 | elif isinstance(m, 
nn.Linear): 25 | #init.normal_(m.weight, std=1e-1) 26 | init.xavier_uniform_(m.weight) 27 | if m.bias is not None: 28 | init.constant_(m.bias, 0) 29 | # init.normal_(m.bias, std=1e-1) 30 | 31 | 32 | class ChunkSampler(sampler.Sampler): 33 | """Samples elements sequentially from some offset. 34 | Arguments: 35 | num_samples: # of desired datapoints 36 | start: offset where we should start selecting from 37 | """ 38 | 39 | def __init__(self, num_samples, start=0): 40 | self.num_samples = num_samples 41 | self.start = start 42 | 43 | def __iter__(self): 44 | return iter(range(self.start, self.start + self.num_samples)) 45 | 46 | def __len__(self): 47 | return self.num_samples 48 | 49 | 50 | def show_images(images): 51 | if len(images.shape) == 3: 52 | images = np.expand_dims(images, axis=0) 53 | print(images.shape) 54 | images = np.reshape(images.cpu().detach().numpy(), [images.shape[0], 3, -1]) # images reshape to (batch_size, D) 55 | sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) 56 | sqrtimg = int(np.ceil(np.sqrt(images.shape[2]))) 57 | 58 | fig = plt.figure(figsize=(sqrtn, sqrtn)) 59 | gs = gridspec.GridSpec(sqrtn, sqrtn) 60 | gs.update(wspace=0.05, hspace=0.05) 61 | 62 | for i, img in enumerate(images): 63 | ax = plt.subplot(gs[i]) 64 | plt.axis('off') 65 | ax.set_xticklabels([]) 66 | ax.set_yticklabels([]) 67 | ax.set_aspect('equal') 68 | plt.imshow(img.reshape([3, sqrtimg, sqrtimg]).transpose((1, 2, 0))) 69 | 70 | # show figure 71 | plt.show() 72 | return 73 | 74 | 75 | def device_agnostic_selection(): 76 | parser = argparse.ArgumentParser(description='PyTorch Example') 77 | parser.add_argument('--disable-cuda', action='store_true', 78 | help='Disable CUDA') 79 | args = parser.parse_args() 80 | args.device = None 81 | if not args.disable_cuda and torch.cuda.is_available(): 82 | args.device = torch.device('cuda') 83 | else: 84 | args.device = torch.device('cpu') 85 | return args.device 86 | 87 | 88 | def show_heatmap(x, y, z): 89 | ''' 90 | :param x: the 
np.array, such as np.linspace(0, 10, 10) 91 | :param y: the np.array, such as np.linspace(0, 10, 10) 92 | :param z: the 2D np.array, generated by the function of x and y 93 | :return: None 94 | ''' 95 | plt.close() 96 | fig, ax = plt.subplots() 97 | # interaction mode 98 | plt.ion() 99 | im = ax.pcolormesh(x, y, z) 100 | fig.colorbar(im) 101 | 102 | ax.axis('tight') 103 | plt.show() 104 | plt.pause(5) 105 | # close figure 106 | #plt.close(fig) 107 | # close interaction mode 108 | plt.ioff() -------------------------------------------------------------------------------- /videos/image_navigation/image_navigation1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/videos/image_navigation/image_navigation1.mp4 -------------------------------------------------------------------------------- /videos/image_navigation/image_navigation2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxy99/Generative-Temporal-Models-with-Spatial-Memory/4b59d33c887d0b492406d7eabbec85261a8eb8bd/videos/image_navigation/image_navigation2.mp4 --------------------------------------------------------------------------------