├── fNRI.png ├── LICENSE ├── trajectory_plot.py ├── README.md ├── data └── generate_dataset.py ├── train_dec.py ├── train_sigmoid.py ├── utils.py ├── train_enc.py ├── modules.py └── train.py /fNRI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekwebb/fNRI/HEAD/fNRI.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 E K Webb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /trajectory_plot.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | from matplotlib.colors import ListedColormap 6 | import matplotlib.collections as mcoll 7 | 8 | 9 | def draw_lines(output,output_i,linestyle='-',alpha=1,darker=False,linewidth=2): 10 | """ 11 | http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb 12 | http://matplotlib.org/examples/pylab_examples/multicolored_line.html 13 | """ 14 | loc = np.array(output[output_i,:,:,0:2]) 15 | loc = np.transpose( loc, [1,2,0] ) 16 | 17 | x = loc[:,0,:] 18 | y = loc[:,1,:] 19 | x_min = np.min(x) 20 | x_max = np.max(x) 21 | y_min = np.min(y) 22 | y_max = np.max(y) 23 | max_range = max( y_max-y_min, x_max-x_min ) 24 | xmin = (x_min+x_max)/2-max_range/2-0.1 25 | xmax = (x_min+x_max)/2+max_range/2+0.1 26 | ymin = (y_min+y_max)/2-max_range/2-0.1 27 | ymax = (y_min+y_max)/2+max_range/2+0.1 28 | 29 | cmaps = [ 'Purples', 'Greens', 'Blues', 'Oranges', 'Reds', 'Purples', 'Greens', 'Blues', 'Oranges', 'Reds' ] 30 | cmaps = [ matplotlib.cm.get_cmap(cmap, 512) for cmap in cmaps ] 31 | cmaps = [ ListedColormap(cmap(np.linspace(0., 0.8, 256))) for cmap in cmaps ] 32 | if darker: 33 | cmaps = [ ListedColormap(cmap(np.linspace(0.2, 0.8, 256))) for cmap in cmaps ] 34 | 35 | for i in range(loc.shape[-1]): 36 | lc = colorline(loc[:,0,i], loc[:,1,i], cmap=cmaps[i],linestyle=linestyle,alpha=alpha,linewidth=linewidth) 37 | return xmin, ymin, xmax, ymax 38 | 39 | def colorline( 40 | x, y, z=None, cmap='copper', norm=plt.Normalize(0.0, 1.0), 41 | linewidth=2, alpha=0.8, linestyle='-'): 42 | """ 43 | 
http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb 44 | http://matplotlib.org/examples/pylab_examples/multicolored_line.html 45 | """ 46 | # Default colors equally spaced on [0,1]: 47 | if z is None: 48 | z = np.linspace(0.0, 1.0, len(x)) 49 | if not hasattr(z, "__iter__"): 50 | z = np.array([z]) 51 | z = np.asarray(z) 52 | segments = make_segments(x, y) 53 | 54 | lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm, 55 | linewidth=linewidth, alpha=alpha, linestyle=linestyle) 56 | ax = plt.gca() 57 | ax.add_collection(lc) 58 | return lc 59 | 60 | def make_segments(x, y): 61 | points = np.array([x, y]).T.reshape(-1, 1, 2) 62 | segments = np.concatenate([points[:-1], points[1:]], axis=1) 63 | return segments 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Factorised Neural Relational Inference for Multi-Interaction Systems 2 | 3 | This repository contains the official PyTorch implementation of: 4 | 5 | **Factorised Neural Relational Inference for Multi-Interaction Systems.** 6 | Ezra Webb, Ben Day, Helena Andres-Terre, Pietro Lió 7 | https://arxiv.org/abs/1905.08721 8 | 9 | ![Factorised Neural Relational Inference (fNRI)](fNRI.png) 10 | 11 | **Abstract:** Many complex natural and cultural phenomena are well modelled by systems of simple interactions between particles. A number of architectures have been developed to articulate this kind of structure, both implicitly and explicitly. We consider an unsupervised explicit model, the NRI model, and make a series of representational adaptations and physically motivated changes. Most notably we factorise the inferred latent interaction graph into a multiplex graph, allowing each layer to encode for a different interaction-type. This fNRI model is smaller in size and significantly outperforms the original in both edge and trajectory prediction, establishing a new state-of-the-art. We also present a simplified variant of our model, which demonstrates the NRI's formulation as a variational auto-encoder is not necessary for good performance, and make an adaptation to the NRI's training routine, significantly improving its ability to model complex physical dynamical systems. 12 | 13 | Much of the code here is based on https://github.com/ethanfetaya/NRI (MIT licence). We would like to thank Thomas Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling & Richard Zemel for making the codebase for the Neural Relational Inference model (arXiv:1802.04687) publicly available. 14 | 15 | 16 | ### Requirements 17 | * Pytorch 1.0 18 | * Python 3.6 19 | 20 | ### Data generation 21 | 22 | To replicate the experiments on simulated physical data, first generate training, validation and test data by running: 23 | 24 | ``` 25 | cd data 26 | python generate_dataset.py 27 | ``` 28 | This generates the ideal springs and charges (I+C) dataset. Add the argument `--sim-type springchargefspring` to the command above to generate the ideal spring, charges and finite springs dataset. 29 | 30 | ### Run experiments 31 | 32 | From the project's root folder, run 33 | ``` 34 | python train.py 35 | ``` 36 | to train an fNRI model on the I+C dataset. To run the standard NRI model, add the `--NRI` argument. 
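For example, the two basic invocations (all other options left at their defaults) are:
```
python train.py         # fNRI on the I+C dataset
python train.py --NRI   # standard NRI baseline
```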
You can specify a different dataset by modifying the `sim-folder` argument: `--sim-folder springchargefspring_5` will run the model on the I+C+F particle simulation with 5 particles (if it has been generated). The number of edge types in each layer graph is specified using the `--edge-types-list` argument: for the I+C+F dataset, use `--edge-types-list 2 2 2`. 37 | 38 | To train an sfNRI model, run 39 | ``` 40 | python train_sigmoid.py 41 | ``` 42 | Here the K parameter (in this case equivalent to the number of layer-graphs) is specifed using the `--num-factors` argument: for the I+C+F dataset, use `--num-factors 3`. 43 | 44 | To train the fNRI encoder in isolation in order to replicate the (supervised) experiments, run 45 | ``` 46 | python train_enc.py 47 | ``` 48 | 49 | To train the fNRI decoder in isolation in order to replicate the (true graph) experiments, run 50 | ``` 51 | python train_dec.py 52 | ``` 53 | The arguments `--NRI` or `--sigmoid` can be added to train the NRI and sfNRI models respectively in both of these scripts. A number of other training options are documented in the respective training files. 54 | -------------------------------------------------------------------------------- /data/generate_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | from synthetic_sim import * 6 | import time 7 | import numpy as np 8 | import argparse 9 | import os 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--num-train', type=int, default=50000, 13 | help='Number of training simulations to generate.') 14 | parser.add_argument('--num-valid', type=int, default=10000, 15 | help='Number of validation simulations to generate.') 16 | parser.add_argument('--num-test', type=int, default=10000, 17 | help='Number of test simulations to generate.') 18 | parser.add_argument('--length', type=int, default=10000, 19 | help='Length of trajectory.') 20 | parser.add_argument('--length-test', type=int, default=10000, 21 | help='Length of test set trajectory.') 22 | parser.add_argument('--sample-freq', type=int, default=100, 23 | help='How often to sample the trajectory.') 24 | parser.add_argument('--n-balls', type=int, default=5, 25 | help='Number of balls in the simulation.') 26 | parser.add_argument('--seed', type=int, default=42, 27 | help='Random seed.') 28 | parser.add_argument('--savefolder', type=str, default='springcharge_5', 29 | help='name of folder to save everything in') 30 | parser.add_argument('--sim-type', type=str, default='springcharge', 31 | help='Type of simulation system') 32 | 33 | args = parser.parse_args() 34 | os.makedirs(args.savefolder) 35 | par_file = open(os.path.join(args.savefolder,'sim_args.txt'),'w') 36 | print(args, file=par_file) 37 | par_file.flush() 38 | par_file.close() 39 | 40 | if args.sim_type == 'springcharge': 41 | sim = SpringChargeSim(noise_var=0.0, n_balls=args.n_balls, box_size=5.0) 42 | 43 | elif args.sim_type == 'springchargequad': 44 | sim = SpringChargeQuadSim(noise_var=0.0, n_balls=args.n_balls, box_size=5.0) 45 | 46 | elif args.sim_type == 'springquad': 47 | sim = SpringQuadSim(noise_var=0.0, n_balls=args.n_balls, box_size=5.0) 48 | 49 | elif args.sim_type == 'springchargefspring': 50 | sim = SpringChargeFspringSim(noise_var=0.0, n_balls=args.n_balls, box_size=5.0) 51 | 52 | np.random.seed(args.seed) 53 | 54 | def generate_dataset(num_sims, length, sample_freq): 55 | loc_all = list() 56 | 
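    # loc_all, vel_all and edges_all accumulate one entry per simulation:
    # particle positions, velocities and the ground-truth interaction graph
    # returned by sim.sample_trajectory; they are stacked into arrays below.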
vel_all = list() 57 | edges_all = list() 58 | 59 | for i in range(num_sims): 60 | t = time.time() 61 | loc, vel, edges = sim.sample_trajectory(T=length, sample_freq=sample_freq) 62 | if i % 100 == 0: 63 | print("Iter: {}, Simulation time: {}".format(i, time.time() - t)) 64 | loc_all.append(loc) 65 | vel_all.append(vel) 66 | edges_all.append(edges) 67 | 68 | loc_all = np.stack(loc_all) 69 | vel_all = np.stack(vel_all) 70 | edges_all = np.stack(edges_all) 71 | 72 | return loc_all, vel_all, edges_all 73 | 74 | 75 | print("Generating {} training simulations".format(args.num_train)) 76 | loc_train, vel_train, edges_train = generate_dataset(args.num_train, args.length, args.sample_freq) 77 | 78 | np.save(os.path.join(args.savefolder,'loc_train.npy'), loc_train) 79 | np.save(os.path.join(args.savefolder,'vel_train.npy'), vel_train) 80 | np.save(os.path.join(args.savefolder,'edges_train.npy'), edges_train) 81 | 82 | print("Generating {} validation simulations".format(args.num_valid)) 83 | loc_valid, vel_valid, edges_valid = generate_dataset(args.num_valid, args.length, args.sample_freq) 84 | 85 | np.save(os.path.join(args.savefolder,'loc_valid.npy'), loc_valid) 86 | np.save(os.path.join(args.savefolder,'vel_valid.npy'), vel_valid) 87 | np.save(os.path.join(args.savefolder,'edges_valid.npy'), edges_valid) 88 | 89 | print("Generating {} test simulations".format(args.num_test)) 90 | loc_test, vel_test, edges_test= generate_dataset(args.num_test, args.length_test, args.sample_freq) 91 | 92 | np.save(os.path.join(args.savefolder,'loc_test.npy'), loc_test) 93 | np.save(os.path.join(args.savefolder,'vel_test.npy'), vel_test) 94 | np.save(os.path.join(args.savefolder,'edges_test.npy'), edges_test) 95 | -------------------------------------------------------------------------------- /train_dec.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import time 9 | import argparse 10 | import pickle 11 | import os 12 | import datetime 13 | import csv 14 | import math 15 | 16 | import torch.optim as optim 17 | from torch.optim import lr_scheduler 18 | 19 | from utils import * 20 | from modules import * 21 | 22 | parser = argparse.ArgumentParser() 23 | ## arguments related to training ## 24 | parser.add_argument('--epochs', type=int, default=500, 25 | help='Number of epochs to train.') 26 | parser.add_argument('--batch-size', type=int, default=128, 27 | help='Number of samples per batch.') 28 | parser.add_argument('--lr', type=float, default=0.0005, 29 | help='Initial learning rate.') 30 | parser.add_argument('--prediction-steps', type=int, default=10, metavar='N', 31 | help='Num steps to predict before re-using teacher forcing.') 32 | parser.add_argument('--lr-decay', type=int, default=200, 33 | help='After how epochs to decay LR by a factor of gamma.') 34 | parser.add_argument('--gamma', type=float, default=0.5, 35 | help='LR decay factor.') 36 | parser.add_argument('--patience', type=int, default=500, 37 | help='Early stopping patience') 38 | parser.add_argument('--decoder-dropout', type=float, default=0.0, 39 | help='Dropout rate (1 - keep probability).') 40 | parser.add_argument('--dont-split-data', action='store_true', default=False, 41 | help='Whether to not split training and validation data into two parts') 42 | parser.add_argument('--split-enc-only', action='store_true', default=False, 43 | 
help='Whether to give the encoder the first half of trajectories \ 44 | and the decoder the whole of the trajectories') 45 | 46 | ## arguments related to loss function ## 47 | parser.add_argument('--var', type=float, default=5e-5, 48 | help='Output variance.') 49 | 50 | ## arguments related to weight and bias initialisation ## 51 | parser.add_argument('--seed', type=int, default=1, 52 | help='Random seed.') 53 | parser.add_argument('--decoder-init-type',type=str, default='default', 54 | help='The type of weight initialization to use in the decoder') 55 | 56 | 57 | ## arguments related to changing the model ## 58 | parser.add_argument('--NRI', action='store_true', default=False, 59 | help='Use the NRI model, rather than the fNRI model') 60 | parser.add_argument('--sigmoid', action='store_true', default=False, 61 | help='Use the sfNRI model, rather than the fNRI model') 62 | parser.add_argument('--edge-types-list', nargs='+', default=[2,2], 63 | help='The number of edge types to infer.') # takes arguments from cmd line as: --edge-types-list 2 2 64 | parser.add_argument('--decoder', type=str, default='mlp', 65 | help='Type of decoder model (mlp, rnn, or sim).') 66 | parser.add_argument('--decoder-hidden', type=int, default=256, 67 | help='Number of hidden units.') 68 | parser.add_argument('--skip-first', action='store_true', default=False, 69 | help='Skip the first edge type in each block in the decoder, i.e. it represents no-edge.') 70 | parser.add_argument('--full-graph', action='store_true', default=False, 71 | help='Use a fixed fully connected graph rather than the ground truth labels') 72 | parser.add_argument('--num-factors', type=int, default=2, 73 | help='The number of factor graphs (this is only for sfNRI model, replaces edge-types-list)') 74 | 75 | ## arguments related to the simulation data ## 76 | parser.add_argument('--sim-folder', type=str, default='springcharge_5', 77 | help='Name of the folder in the data folder to load simulation data from') 78 | parser.add_argument('--data-folder', type=str, default='data', 79 | help='Name of the data folder to load data from') 80 | parser.add_argument('--num-atoms', type=int, default=5, 81 | help='Number of atoms in simulation.') 82 | parser.add_argument('--dims', type=int, default=4, 83 | help='The number of input dimensions (position + velocity).') 84 | parser.add_argument('--timesteps', type=int, default=49, 85 | help='The number of time steps per sample.') 86 | 87 | ## Saving, loading etc. ## 88 | parser.add_argument('--no-cuda', action='store_true', default=False, 89 | help='Disables CUDA training.') 90 | parser.add_argument('--save-folder', type=str, default='logs', 91 | help='Where to save the trained model, leave empty to not save anything.') 92 | parser.add_argument('--load-folder', type=str, default='', 93 | help='Where to load the trained model if finetunning. 
' + 94 | 'Leave empty to train from scratch') 95 | parser.add_argument('--test', action='store_true', default=False, 96 | help='Skip training and validation') 97 | parser.add_argument('--plot', action='store_true', default=False, 98 | help='Skip training and plot trajectories against actual') 99 | 100 | 101 | args = parser.parse_args() 102 | args.cuda = not args.no_cuda and torch.cuda.is_available() 103 | #args.factor = not args.no_factor 104 | args.edge_types_list = list(map(int, args.edge_types_list)) 105 | args.edge_types_list.sort(reverse=True) 106 | 107 | if all( (isinstance(k, int) and k >= 1) for k in args.edge_types_list): 108 | if args.NRI: 109 | edge_types = np.prod(args.edge_types_list) 110 | else: 111 | edge_types = sum(args.edge_types_list) 112 | else: 113 | raise ValueError('Could not compute the edge-types-list') 114 | 115 | np.random.seed(args.seed) 116 | torch.manual_seed(args.seed) 117 | if args.cuda: 118 | torch.cuda.manual_seed(args.seed) 119 | 120 | print(args) 121 | 122 | # Save model and meta-data. Always saves in a new sub-folder. 123 | if args.save_folder: 124 | exp_counter = 0 125 | now = datetime.datetime.now() 126 | timestamp = now.isoformat().replace(':','-')[:-7] 127 | save_folder = os.path.join(args.save_folder,'exp'+timestamp) 128 | os.makedirs(save_folder) 129 | meta_file = os.path.join(save_folder, 'metadata.pkl') 130 | decoder_file = os.path.join(save_folder, 'decoder.pt') 131 | 132 | log_file = os.path.join(save_folder, 'log.txt') 133 | log_csv_file = os.path.join(save_folder, 'log_csv.csv') 134 | log = open(log_file, 'w') 135 | log_csv = open(log_csv_file, 'w') 136 | csv_writer = csv.writer(log_csv, delimiter=',') 137 | 138 | pickle.dump({'args': args}, open(meta_file, "wb")) 139 | 140 | par_file = open(os.path.join(save_folder,'args.txt'),'w') 141 | print(args,file=par_file) 142 | par_file.flush 143 | par_file.close() 144 | 145 | else: 146 | print("WARNING: No save_folder provided!" 
+ 147 | "Testing (within this script) will throw an error.") 148 | 149 | 150 | if args.NRI: 151 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_NRI( 152 | args.batch_size, args.sim_folder, shuffle=True, 153 | data_folder=args.data_folder) 154 | else: 155 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_fNRI( 156 | args.batch_size, args.sim_folder, shuffle=True, 157 | data_folder=args.data_folder) 158 | 159 | 160 | # Generate off-diagonal interaction graph 161 | off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) 162 | rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) 163 | rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) 164 | rel_rec = torch.FloatTensor(rel_rec) 165 | rel_send = torch.FloatTensor(rel_send) 166 | 167 | if args.NRI: 168 | edge_types_list = [ edge_types ] 169 | else: 170 | edge_types_list = args.edge_types_list 171 | 172 | if args.decoder == 'mlp': 173 | if args.sigmoid: 174 | decoder = MLPDecoder_sigmoid(n_in_node=args.dims, 175 | num_factors=args.num_factors, 176 | msg_hid=args.decoder_hidden, 177 | msg_out=args.decoder_hidden, 178 | n_hid=args.decoder_hidden, 179 | do_prob=args.decoder_dropout, 180 | init_type=args.decoder_init_type) 181 | else: 182 | decoder = MLPDecoder_multi(n_in_node=args.dims, 183 | edge_types=edge_types, 184 | edge_types_list=edge_types_list, 185 | msg_hid=args.decoder_hidden, 186 | msg_out=args.decoder_hidden, 187 | n_hid=args.decoder_hidden, 188 | do_prob=args.decoder_dropout, 189 | skip_first=args.skip_first, 190 | init_type=args.decoder_init_type) 191 | 192 | elif args.decoder == 'stationary': 193 | decoder = StationaryDecoder() 194 | 195 | elif args.decoder == 'velocity': 196 | decoder = VelocityStepDecoder() 197 | 198 | if args.load_folder: 199 | print('Loading model from: '+args.load_folder) 200 | decoder_file = os.path.join(args.load_folder, 'decoder.pt') 201 | if not args.cuda: 202 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 203 | else: 204 | decoder.load_state_dict(torch.load(decoder_file)) 205 | args.save_folder = False 206 | 207 | optimizer = optim.Adam(list(decoder.parameters()), lr=args.lr) 208 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay, 209 | gamma=args.gamma) 210 | 211 | 212 | if args.cuda: 213 | decoder.cuda() 214 | rel_rec = rel_rec.cuda() 215 | rel_send = rel_send.cuda() 216 | 217 | rel_rec = Variable(rel_rec) 218 | rel_send = Variable(rel_send) 219 | 220 | 221 | def train(epoch, best_val_loss): 222 | t = time.time() 223 | nll_train = [] 224 | nll_var_train = [] 225 | mse_train = [] 226 | 227 | decoder.train() 228 | scheduler.step() 229 | if not args.plot: 230 | for batch_idx, (data, relations) in enumerate(train_loader): # relations are the ground truth interactions graphs 231 | 232 | optimizer.zero_grad() 233 | 234 | if args.full_graph: 235 | zeros = torch.zeros([data.size(0), rel_rec.size(0)]) 236 | ones = torch.ones([data.size(0), rel_rec.size(0)]) 237 | if args.NRI: 238 | stack = [ ones ] + [ zeros for _ in range(edge_types-1) ] 239 | rel_type_onehot = torch.stack(stack, -1) 240 | elif args.sigmoid: 241 | stack = [ ones for _ in range(args.num_factors) ] 242 | rel_type_onehot = torch.stack(stack, -1) 243 | else: 244 | stack = [] 245 | for i in range(len(args.edge_types_list)): 246 | stack += [ ones ] + [ zeros for _ in range(args.edge_types_list[i]-1) ] 247 | rel_type_onehot = torch.stack(stack, -1) 248 | 249 | 
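            # The branch below one-hot encodes the ground-truth integer relations:
            # for the fNRI model, one [batch, num_edges, k] block is built per
            # layer-graph with scatter_ and the blocks are concatenated along the
            # last dimension, giving [batch, num_edges, sum(edge_types_list)].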
else: 250 | if args.NRI: 251 | rel_type_onehot = torch.FloatTensor(data.size(0), rel_rec.size(0), edge_types) 252 | rel_type_onehot.zero_() 253 | rel_type_onehot.scatter_(2, relations.view(data.size(0), -1, 1), 1) 254 | elif args.sigmoid: 255 | rel_type_onehot = relations.transpose(1,2).type(torch.FloatTensor) 256 | else: 257 | rel_type_onehot = [ torch.FloatTensor(data.size(0), rel_rec.size(0), types) for types in args.edge_types_list ] 258 | rel_type_onehot = [ rel.zero_() for rel in rel_type_onehot ] 259 | rel_type_onehot = [ rel_type_onehot[i].scatter_(2, relations[:,i,:].view(data.size(0), -1, 1), 1) for i in range(len(rel_type_onehot)) ] 260 | rel_type_onehot = torch.cat( rel_type_onehot, dim=-1 ) 261 | 262 | if args.dont_split_data: 263 | data_decoder = data[:, :, :args.timesteps, :] 264 | elif args.split_enc_only: 265 | data_decoder = data 266 | else: 267 | assert (data.size(2) - args.timesteps) >= args.timesteps 268 | data_decoder = data[:, :, -args.timesteps:, :] 269 | 270 | if args.cuda: 271 | data_decoder, rel_type_onehot = data_decoder.cuda(), rel_type_onehot.cuda() 272 | data_decoder = data_decoder.contiguous() 273 | 274 | data_decoder, rel_type_onehot = Variable(data_decoder), Variable(rel_type_onehot) 275 | 276 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 277 | output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, args.prediction_steps) 278 | 279 | loss_nll = nll_gaussian(output, target, args.var) 280 | loss_nll_var = nll_gaussian_var(output, target, args.var) 281 | 282 | 283 | loss_nll.backward() 284 | optimizer.step() 285 | 286 | mse_train.append(F.mse_loss(output, target).data.item()) 287 | nll_train.append(loss_nll.data.item()) 288 | nll_var_train.append(loss_nll_var.data.item()) 289 | 290 | 291 | nll_val = [] 292 | nll_var_val = [] 293 | mse_val = [] 294 | 295 | nll_M_val = [] 296 | nll_M_var_val = [] 297 | 298 | decoder.eval() 299 | for batch_idx, (data, relations) in enumerate(valid_loader): 300 | with torch.no_grad(): 301 | 302 | if args.full_graph: 303 | zeros = torch.zeros([data.size(0), rel_rec.size(0)]) 304 | ones = torch.ones([data.size(0), rel_rec.size(0)]) 305 | if args.NRI: 306 | stack = [ ones ] + [ zeros for _ in range(edge_types-1) ] 307 | rel_type_onehot = torch.stack(stack, -1) 308 | elif args.sigmoid: 309 | stack = [ ones for _ in range(args.num_factors) ] 310 | rel_type_onehot = torch.stack(stack, -1) 311 | else: 312 | stack = [] 313 | for i in range(len(args.edge_types_list)): 314 | stack += [ ones ] + [ zeros for _ in range(args.edge_types_list[i]-1) ] 315 | rel_type_onehot = torch.stack(stack, -1) 316 | 317 | else: 318 | if args.NRI: 319 | rel_type_onehot = torch.FloatTensor(data.size(0), rel_rec.size(0), edge_types) 320 | rel_type_onehot.zero_() 321 | rel_type_onehot.scatter_(2, relations.view(data.size(0), -1, 1), 1) 322 | elif args.sigmoid: 323 | rel_type_onehot = relations.transpose(1,2).type(torch.FloatTensor) 324 | else: 325 | rel_type_onehot = [ torch.FloatTensor(data.size(0), rel_rec.size(0), types) for types in args.edge_types_list ] 326 | rel_type_onehot = [ rel.zero_() for rel in rel_type_onehot ] 327 | rel_type_onehot = [ rel_type_onehot[i].scatter_(2, relations[:,i,:].view(data.size(0), -1, 1), 1) for i in range(len(rel_type_onehot)) ] 328 | rel_type_onehot = torch.cat( rel_type_onehot, dim=-1 ) 329 | 330 | if args.dont_split_data: 331 | data_decoder = data[:, :, :args.timesteps, :] 332 | elif args.split_enc_only: 333 | data_decoder = data 334 | else: 335 | assert (data.size(2) - 
args.timesteps) >= args.timesteps 336 | data_decoder = data[:, :, -args.timesteps:, :] 337 | 338 | if args.cuda: 339 | data_decoder, rel_type_onehot = data_decoder.cuda(), rel_type_onehot.cuda() 340 | data_decoder = data_decoder.contiguous() 341 | 342 | data_decoder, rel_type_onehot = Variable(data_decoder), Variable(rel_type_onehot) 343 | 344 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 345 | output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 1) 346 | 347 | if args.plot: 348 | import matplotlib.pyplot as plt 349 | output_plot = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 49) 350 | 351 | from trajectory_plot import draw_lines 352 | for i in range(args.batch_size): 353 | fig = plt.figure(figsize=(7, 7)) 354 | ax = fig.add_axes([0, 0, 1, 1]) 355 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 356 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 357 | 358 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 359 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 360 | ax.set_xticks([]) 361 | ax.set_yticks([]) 362 | plt.show() 363 | 364 | 365 | loss_nll = nll_gaussian(output, target, args.var) 366 | loss_nll_var = nll_gaussian_var(output, target, args.var) 367 | 368 | output_M = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, args.prediction_steps) 369 | loss_nll_M = nll_gaussian(output_M, target, args.var) 370 | loss_nll_M_var = nll_gaussian_var(output_M, target, args.var) 371 | 372 | mse_val.append(F.mse_loss(output_M, target).data.item()) 373 | nll_val.append(loss_nll.data.item()) 374 | nll_var_val.append(loss_nll_var.data.item()) 375 | 376 | nll_M_val.append(loss_nll_M.data.item()) 377 | nll_M_var_val.append(loss_nll_M_var.data.item()) 378 | 379 | 380 | print('Epoch: {:03d}'.format(epoch), 381 | 'time: {:.1f}s'.format(time.time() - t), 382 | 'nll_trn: {:.2f}'.format(np.mean(nll_train)), 383 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 384 | 'nll_val: {:.2f}'.format(np.mean(nll_M_val)), 385 | 'mse_val: {:.10f}'.format(np.mean(mse_val)) 386 | ) 387 | 388 | print('Epoch: {:03d}'.format(epoch), 389 | 'time: {:.1f}s'.format(time.time() - t), 390 | 'nll_trn: {:.2f}'.format(np.mean(nll_train)), 391 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 392 | 'nll_val: {:.2f}'.format(np.mean(nll_M_val)), 393 | 'mse_val: {:.10f}'.format(np.mean(mse_val)), 394 | file=log) 395 | 396 | if epoch == 0: 397 | labels = [ 'epoch', 'nll trn', 'mse train', 'nll var trn' ] 398 | labels += [ 'nll val', 'nll M val', 'mse val', 'nll var val', 'nll M var val' ] 399 | csv_writer.writerow( labels ) 400 | 401 | csv_writer.writerow( [epoch, np.mean(nll_train), np.mean(mse_train), np.mean(nll_var_train)] + 402 | [np.mean(nll_val), np.mean(nll_M_val), np.mean(mse_val)] + 403 | [np.mean(nll_var_val), np.mean(nll_M_var_val)] 404 | ) 405 | 406 | log.flush() 407 | if args.save_folder and np.mean(nll_M_val) < best_val_loss: 408 | torch.save(decoder.state_dict(), decoder_file) 409 | print('Best model so far, saving...') 410 | return np.mean(nll_M_val) 411 | 412 | 413 | def test(): 414 | t = time.time() 415 | nll_test = [] 416 | nll_var_test = [] 417 | mse_1_test = [] 418 | mse_10_test = [] 419 | mse_20_test = [] 420 | mse_static = [] 421 | 422 | nll_M_test = [] 423 | nll_M_var_test = [] 424 | 425 | decoder.eval() 426 | if not args.cuda: 427 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 428 | else: 429 | 
decoder.load_state_dict(torch.load(decoder_file)) 430 | 431 | for batch_idx, (data, relations) in enumerate(test_loader): 432 | with torch.no_grad(): 433 | 434 | if args.full_graph: 435 | zeros = torch.zeros([data.size(0), rel_rec.size(0)]) 436 | ones = torch.ones([data.size(0), rel_rec.size(0)]) 437 | if args.NRI: 438 | stack = [ ones ] + [ zeros for _ in range(edge_types-1) ] 439 | rel_type_onehot = torch.stack(stack, -1) 440 | elif args.sigmoid: 441 | stack = [ ones for _ in range(args.num_factors) ] 442 | rel_type_onehot = torch.stack(stack, -1) 443 | else: 444 | stack = [] 445 | for i in range(len(args.edge_types_list)): 446 | stack += [ ones ] + [ zeros for _ in range(args.edge_types_list[i]-1) ] 447 | rel_type_onehot = torch.stack(stack, -1) 448 | 449 | else: 450 | if args.NRI: 451 | rel_type_onehot = torch.FloatTensor(data.size(0), rel_rec.size(0), edge_types) 452 | rel_type_onehot.zero_() 453 | rel_type_onehot.scatter_(2, relations.view(data.size(0), -1, 1), 1) 454 | elif args.sigmoid: 455 | rel_type_onehot = relations.transpose(1,2).type(torch.FloatTensor) 456 | else: 457 | rel_type_onehot = [ torch.FloatTensor(data.size(0), rel_rec.size(0), types) for types in args.edge_types_list ] 458 | rel_type_onehot = [ rel.zero_() for rel in rel_type_onehot ] 459 | rel_type_onehot = [ rel_type_onehot[i].scatter_(2, relations[:,i,:].view(data.size(0), -1, 1), 1) for i in range(len(rel_type_onehot)) ] 460 | rel_type_onehot = torch.cat( rel_type_onehot, dim=-1 ) 461 | 462 | data_decoder = data[:, :, -args.timesteps:, :] 463 | 464 | if args.cuda: 465 | data_decoder, rel_type_onehot = data_decoder.cuda(), rel_type_onehot.cuda() 466 | data_decoder = data_decoder.contiguous() 467 | 468 | data_decoder, rel_type_onehot = Variable(data_decoder), Variable(rel_type_onehot) 469 | 470 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 471 | output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 1) 472 | 473 | 474 | if args.plot: 475 | import matplotlib.pyplot as plt 476 | output_plot = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 49) 477 | from trajectory_plot import draw_lines 478 | for i in range(args.batch_size): 479 | fig = plt.figure(figsize=(7, 7)) 480 | ax = fig.add_axes([0, 0, 1, 1]) 481 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 482 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 483 | 484 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 485 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 486 | ax.set_xticks([]) 487 | ax.set_yticks([]) 488 | #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true_.png'), dpi=300) 489 | plt.show() 490 | 491 | 492 | loss_nll = nll_gaussian(output, target, args.var) 493 | loss_nll_var = nll_gaussian_var(output, target, args.var) 494 | 495 | output_M = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, args.prediction_steps) 496 | loss_nll_M = nll_gaussian(output_M, target, args.var) 497 | 498 | output_10 = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 10) 499 | output_20 = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send, 20) 500 | mse_1_test.append(F.mse_loss(output, target).data.item()) 501 | mse_10_test.append(F.mse_loss(output_10, target).data.item()) 502 | mse_20_test.append(F.mse_loss(output_20, target).data.item()) 503 | 504 | static = F.mse_loss(data_decoder[:, :, :-1, :], data_decoder[:, :, 1:, :]) 505 | mse_static.append(static.data.item()) 506 | 507 | 
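            # mse_static is a trivial baseline: the error incurred by "predicting"
            # that every particle simply stays at its previous state for one step.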
nll_test.append(loss_nll.data.item()) 508 | nll_var_test.append(loss_nll_var.data.item()) 509 | nll_M_test.append(loss_nll_M.data.item()) 510 | 511 | 512 | print('--------------------------------') 513 | print('------------Testing-------------') 514 | print('--------------------------------') 515 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 516 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 517 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 518 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 519 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 520 | 'mse_static: {:.10f}'.format(np.mean(mse_static)), 521 | 'time: {:.1f}s'.format(time.time() - t)) 522 | print('--------------------------------', file=log) 523 | print('------------Testing-------------', file=log) 524 | print('--------------------------------', file=log) 525 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 526 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 527 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 528 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 529 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 530 | 'mse_static: {:.10f}'.format(np.mean(mse_static)), 531 | 'time: {:.1f}s'.format(time.time() - t), 532 | file=log) 533 | log.flush() 534 | 535 | 536 | # Train model 537 | if not args.test: 538 | t_total = time.time() 539 | best_val_loss = np.inf 540 | best_epoch = 0 541 | for epoch in range(args.epochs): 542 | val_loss = train(epoch, best_val_loss) 543 | if val_loss < best_val_loss: 544 | best_val_loss = val_loss 545 | best_epoch = epoch 546 | if epoch - best_epoch > args.patience and epoch > 99: 547 | break 548 | print("Optimization Finished!") 549 | print("Best Epoch: {:04d}".format(best_epoch)) 550 | if args.save_folder: 551 | print("Best Epoch: {:04d}".format(best_epoch), file=log) 552 | log.flush() 553 | 554 | print('Reloading best model') 555 | test() 556 | if log is not None: 557 | print(save_folder) 558 | log.close() 559 | log_csv.close() 560 | -------------------------------------------------------------------------------- /train_sigmoid.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import time 9 | import argparse 10 | import pickle 11 | import os 12 | import datetime 13 | import csv 14 | import math 15 | 16 | import torch.optim as optim 17 | from torch.optim import lr_scheduler 18 | 19 | from utils import * 20 | from modules import * 21 | 22 | parser = argparse.ArgumentParser() 23 | ## arguments related to training ## 24 | parser.add_argument('--epochs', type=int, default=500, 25 | help='Number of epochs to train.') 26 | parser.add_argument('--batch-size', type=int, default=128, 27 | help='Number of samples per batch.') 28 | parser.add_argument('--lr', type=float, default=0.0005, 29 | help='Initial learning rate.') 30 | parser.add_argument('--prediction-steps', type=int, default=10, metavar='N', 31 | help='Num steps to predict before re-using teacher forcing.') 32 | parser.add_argument('--lr-decay', type=int, default=1000, 33 | help='After how epochs to decay LR by a factor of gamma.') 34 | parser.add_argument('--weight-decay', type=float, default=0.0, 35 | help='Weight decay value for L2 regularisation in Adam optimiser') 36 | parser.add_argument('--gamma', type=float, default=0.5, 37 | help='LR decay factor.') 38 | 
parser.add_argument('--patience', type=int, default=500, 39 | help='Early stopping patience') 40 | parser.add_argument('--encoder-dropout', type=float, default=0.0, 41 | help='Dropout rate (1 - keep probability).') 42 | parser.add_argument('--decoder-dropout', type=float, default=0.0, 43 | help='Dropout rate (1 - keep probability).') 44 | parser.add_argument('--dont-split-data', action='store_true', default=False, 45 | help='Whether to not split training and validation data into two parts') 46 | parser.add_argument('--split-enc-only', action='store_true', default=False, 47 | help='Whether to give the encoder the first half of trajectories \ 48 | and the decoder the whole of the trajectories') 49 | 50 | ## arguments related to loss function ## 51 | parser.add_argument('--var', type=float, default=5e-5, 52 | help='Output variance.') ## this is only here to rescale mse for NRI and fNRI comparison 53 | 54 | ## arguments related to weight and bias initialisation ## 55 | parser.add_argument('--seed', type=int, default=1, 56 | help='Random seed.') 57 | parser.add_argument('--encoder-init-type',type=str, default='xavier_normal', 58 | help='The type of weight initialization to use in the encoder') 59 | parser.add_argument('--decoder-init-type',type=str, default='default', 60 | help='The type of weight initialization to use in the decoder') 61 | parser.add_argument('--encoder-bias-scale',type=float, default=0.1, 62 | help='The type of weight initialization to use in the encoder') 63 | 64 | ## arguments related to changing the model ## 65 | parser.add_argument('--num-factors', type=int, default=2, 66 | help='The number of factors to use') 67 | parser.add_argument('--split-point', type=int, default=0, 68 | help='The point at which factor graphs are split up in the encoder' ) 69 | parser.add_argument('--encoder', type=str, default='mlp', 70 | help='Type of path encoder model (mlp or cnn).') 71 | parser.add_argument('--decoder', type=str, default='mlp', 72 | help='Type of decoder model (mlp, rnn, or sim).') 73 | parser.add_argument('--encoder-hidden', type=int, default=256, 74 | help='Number of hidden units.') 75 | parser.add_argument('--decoder-hidden', type=int, default=256, 76 | help='Number of hidden units.') 77 | parser.add_argument('--sigmoid-sharpness', type=float, default=1., 78 | help='Coefficient in the power of the sigmoid function') 79 | parser.add_argument('--hard', action='store_true', default=False, 80 | help='Round edges to integers while retaining non-rounded gradients during training.') 81 | 82 | ## arguments related to the simulation data ## 83 | parser.add_argument('--sim-folder', type=str, default='springcharge_5', 84 | help='Name of the folder in the data folder to load simulation data from') 85 | parser.add_argument('--data-folder', type=str, default='data', 86 | help='Name of the data folder to load data from') 87 | parser.add_argument('--num-atoms', type=int, default=5, 88 | help='Number of atoms in simulation.') 89 | parser.add_argument('--dims', type=int, default=4, 90 | help='The number of input dimensions (position + velocity).') 91 | parser.add_argument('--timesteps', type=int, default=49, 92 | help='The number of time steps per sample.') 93 | 94 | ## Saving, loading etc. 
## 95 | parser.add_argument('--no-cuda', action='store_true', default=False, 96 | help='Disables CUDA training.') 97 | parser.add_argument('--save-folder', type=str, default='logs', 98 | help='Where to save the trained model, leave empty to not save anything.') 99 | parser.add_argument('--load-folder', type=str, default='', 100 | help='Where to load the trained model if finetunning. ' + 101 | 'Leave empty to train from scratch') 102 | parser.add_argument('--test', action='store_true', default=False, 103 | help='Skip training and validation') 104 | parser.add_argument('--plot', action='store_true', default=False, 105 | help='Skip training and plot trajectories against actual') 106 | 107 | 108 | args = parser.parse_args() 109 | args.cuda = not args.no_cuda and torch.cuda.is_available() 110 | 111 | print(args) 112 | 113 | np.random.seed(args.seed) 114 | torch.manual_seed(args.seed) 115 | if args.cuda: 116 | torch.cuda.manual_seed(args.seed) 117 | 118 | 119 | # Save model and meta-data. Always saves in a new sub-folder. 120 | if args.save_folder: 121 | exp_counter = 0 122 | now = datetime.datetime.now() 123 | timestamp = now.isoformat().replace(':','-')[:-7] 124 | save_folder = os.path.join(args.save_folder,'exp'+timestamp) 125 | os.makedirs(save_folder) 126 | meta_file = os.path.join(save_folder, 'metadata.pkl') 127 | encoder_file = os.path.join(save_folder, 'encoder.pt') 128 | decoder_file = os.path.join(save_folder, 'decoder.pt') 129 | 130 | log_file = os.path.join(save_folder, 'log.txt') 131 | log_csv_file = os.path.join(save_folder, 'log_csv.csv') 132 | log = open(log_file, 'w') 133 | log_csv = open(log_csv_file, 'w') 134 | csv_writer = csv.writer(log_csv, delimiter=',') 135 | 136 | pickle.dump({'args': args}, open(meta_file, "wb")) 137 | 138 | par_file = open(os.path.join(save_folder,'args.txt'),'w') 139 | print(args,file=par_file) 140 | par_file.flush 141 | par_file.close() 142 | 143 | perm_csv_file = os.path.join(save_folder, 'perm_csv.csv') 144 | perm_csv = open(perm_csv_file, 'w') 145 | perm_writer = csv.writer(perm_csv, delimiter=',') 146 | else: 147 | print("WARNING: No save_folder provided!" 
+ 148 | "Testing (within this script) will throw an error.") 149 | 150 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_fNRI( 151 | args.batch_size, args.sim_folder, 152 | shuffle=True, 153 | data_folder=args.data_folder) 154 | 155 | 156 | # Generate off-diagonal interaction graph 157 | off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) 158 | rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) 159 | rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) 160 | rel_rec = torch.FloatTensor(rel_rec) 161 | rel_send = torch.FloatTensor(rel_send) 162 | 163 | 164 | if args.encoder == 'mlp': 165 | encoder = MLPEncoder_sigmoid(args.timesteps * args.dims, args.encoder_hidden, 166 | args.num_factors,args.encoder_dropout, 167 | split_point=args.split_point) 168 | 169 | elif args.encoder == 'random': 170 | encoder = RandomEncoder(args.num_factors, args.cuda) 171 | 172 | elif args.encoder == 'ones': 173 | encoder = OnesEncoder(args.num_factors, args.cuda) 174 | 175 | if args.decoder == 'mlp': 176 | decoder = MLPDecoder_sigmoid(n_in_node=args.dims, 177 | num_factors=args.num_factors, 178 | msg_hid=args.decoder_hidden, 179 | msg_out=args.decoder_hidden, 180 | n_hid=args.decoder_hidden, 181 | do_prob=args.decoder_dropout, 182 | init_type=args.decoder_init_type) 183 | 184 | elif args.decoder == 'stationary': 185 | decoder = StationaryDecoder() 186 | 187 | elif args.decoder == 'velocity': 188 | decoder = VelocityStepDecoder() 189 | 190 | if args.load_folder: 191 | print('Loading model from: '+args.load_folder) 192 | encoder_file = os.path.join(args.load_folder, 'encoder.pt') 193 | decoder_file = os.path.join(args.load_folder, 'decoder.pt') 194 | if not args.cuda: 195 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 196 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 197 | else: 198 | encoder.load_state_dict(torch.load(encoder_file)) 199 | decoder.load_state_dict(torch.load(decoder_file)) 200 | args.save_folder = False 201 | 202 | optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), 203 | lr=args.lr, weight_decay=args.weight_decay) 204 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay, 205 | gamma=args.gamma) 206 | 207 | 208 | if args.cuda: 209 | encoder.cuda() 210 | decoder.cuda() 211 | rel_rec = rel_rec.cuda() 212 | rel_send = rel_send.cuda() 213 | 214 | rel_rec = Variable(rel_rec) 215 | rel_send = Variable(rel_send) 216 | 217 | 218 | def train(epoch, best_val_loss): 219 | t = time.time() 220 | nll_train = [] 221 | nll_var_train = [] 222 | mse_train = [] 223 | 224 | kl_train = [] 225 | kl_list_train = [] 226 | kl_var_list_train = [] 227 | 228 | acc_train = [] 229 | perm_train = [] 230 | acc_blocks_train = [] 231 | acc_var_train = [] 232 | acc_var_blocks_train = [] 233 | 234 | KLb_train = [] 235 | KLb_blocks_train = [] 236 | 237 | encoder.train() 238 | decoder.train() 239 | scheduler.step() 240 | if not args.plot: 241 | for batch_idx, (data, relations) in enumerate(train_loader): # relations are the ground truth interactions graphs 242 | if args.cuda: 243 | data, relations = data.cuda(), relations.cuda() 244 | data, relations = Variable(data), Variable(relations) 245 | 246 | if args.dont_split_data: 247 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 248 | data_decoder = data[:, :, :args.timesteps, :].contiguous() 249 | elif args.split_enc_only: 250 | data_encoder = data[:, :, :args.timesteps, 
:].contiguous() 251 | data_decoder = data 252 | else: 253 | assert (data.size(2) - args.timesteps) >= args.timesteps 254 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 255 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 256 | 257 | optimizer.zero_grad() 258 | 259 | logits = encoder(data_encoder, rel_rec, rel_send) 260 | 261 | # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles 262 | 263 | edges = my_sigmoid(logits, hard=args.hard, sharpness=args.sigmoid_sharpness) 264 | 265 | loss_kl = 0 266 | loss_kl_split = [ 0 ] 267 | loss_kl_var_split = [ 0 ] 268 | 269 | KLb_train.append( 0 ) 270 | KLb_blocks_train.append([0]) 271 | 272 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(edges, relations) 273 | 274 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 275 | output = decoder(data_decoder, edges, rel_rec, rel_send, args.prediction_steps) 276 | 277 | loss_nll = nll_gaussian(output, target, args.var) 278 | loss_nll_var = nll_gaussian_var(output, target, args.var) 279 | 280 | loss = F.mse_loss(output, target) 281 | 282 | perm_train.append(perm) 283 | acc_train.append(acc_perm) 284 | acc_blocks_train.append(acc_blocks) 285 | acc_var_train.append(acc_var) 286 | acc_var_blocks_train.append(acc_var_blocks) 287 | 288 | loss.backward() 289 | optimizer.step() 290 | 291 | mse_train.append(loss.data.item()) 292 | nll_train.append(loss_nll.data.item()) 293 | nll_var_train.append(loss_nll_var.data.item()) 294 | 295 | 296 | nll_val = [] 297 | nll_var_val = [] 298 | mse_val = [] 299 | 300 | kl_val = [] 301 | kl_list_val = [] 302 | kl_var_list_val = [] 303 | 304 | acc_val = [] 305 | acc_blocks_val = [] 306 | acc_var_val = [] 307 | acc_var_blocks_val = [] 308 | perm_val = [] 309 | 310 | KLb_val = [] 311 | KLb_blocks_val = [] # KL between blocks list 312 | 313 | nll_M_val = [] 314 | nll_M_var_val = [] 315 | 316 | encoder.eval() 317 | decoder.eval() 318 | for batch_idx, (data, relations) in enumerate(valid_loader): 319 | with torch.no_grad(): 320 | if args.cuda: 321 | data, relations = data.cuda(), relations.cuda() 322 | 323 | if args.dont_split_data: 324 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 325 | data_decoder = data[:, :, :args.timesteps, :].contiguous() 326 | elif args.split_enc_only: 327 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 328 | data_decoder = data 329 | else: 330 | assert (data.size(2) - args.timesteps) >= args.timesteps 331 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 332 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 333 | 334 | # dim of logits, edges are [batchsize, N^2-N, sum(edge_types_list)] where N = no. 
of particles 335 | logits = encoder(data_encoder, rel_rec, rel_send) 336 | 337 | edges = my_sigmoid(logits, hard=args.hard, sharpness=args.sigmoid_sharpness) 338 | 339 | loss_kl = 0 340 | loss_kl_split = [ 0 ] 341 | loss_kl_var_split = [ 0 ] 342 | 343 | KLb_train.append( 0 ) 344 | KLb_blocks_train.append([0]) 345 | 346 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(edges, relations) 347 | 348 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 349 | output = decoder(data_decoder, edges, rel_rec, rel_send, 1) 350 | 351 | if args.plot: 352 | import matplotlib.pyplot as plt 353 | output_plot = decoder(data_decoder, edges, rel_rec, rel_send, 49) 354 | 355 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_sigmoid_batch(edges, relations) 356 | 357 | from trajectory_plot import draw_lines 358 | for i in range(args.batch_size): 359 | fig = plt.figure(figsize=(7, 7)) 360 | ax = fig.add_axes([0, 0, 1, 1]) 361 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 362 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 363 | 364 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 365 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 366 | ax.set_xticks([]) 367 | ax.set_yticks([]) 368 | block_names = [ str(j) for j in range(args.num_factors) ] 369 | acc_text = [ 'layer ' + block_names[j] + ' acc: {:02.0f}%'.format(100*acc_blocks_batch[i,j]) 370 | for j in range(acc_blocks_batch.shape[1]) ] 371 | acc_text = ', '.join(acc_text) 372 | plt.text( 0.5, 0.95, acc_text, horizontalalignment='center', transform=ax.transAxes ) 373 | plt.show() 374 | 375 | 376 | loss_nll = nll_gaussian(output, target, args.var) 377 | loss_nll_var = nll_gaussian_var(output, target, args.var) 378 | 379 | output_M = decoder(data_decoder, edges, rel_rec, rel_send, args.prediction_steps) 380 | loss_nll_M = nll_gaussian(output_M, target, args.var) 381 | loss_nll_M_var = nll_gaussian_var(output_M, target, args.var) 382 | 383 | perm_val.append(perm) 384 | acc_val.append(acc_perm) 385 | acc_blocks_val.append(acc_blocks) 386 | acc_var_val.append(acc_var) 387 | acc_var_blocks_val.append(acc_var_blocks) 388 | 389 | mse_val.append(F.mse_loss(output_M, target).data.item()) 390 | nll_val.append(loss_nll.data.item()) 391 | nll_var_val.append(loss_nll_var.data.item()) 392 | 393 | nll_M_val.append(loss_nll_M.data.item()) 394 | nll_M_var_val.append(loss_nll_M_var.data.item()) 395 | 396 | print('Epoch: {:03d}'.format(epoch), 397 | 'perm_val: '+str( np.around(np.mean(np.array(perm_val),axis=0),4 ) ), 398 | 'time: {:.1f}s'.format(time.time() - t)) 399 | print('nll_trn: {:.2f}'.format(np.mean(nll_train)), 400 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 401 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 402 | 'acc_b_trn: '+str( np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ) 403 | ) 404 | print('nll_val: {:.2f}'.format(np.mean(nll_M_val)), 405 | 'mse_val: {:.10f}'.format(np.mean(mse_val)), 406 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 407 | 'acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ) 408 | ) 409 | print('Epoch: {:03d}'.format(epoch), 410 | 'perm_val: '+str( np.around(np.mean(np.array(perm_val),axis=0),4 ) ), 411 | 'time: {:.1f}s'.format(time.time() - t), 412 | file=log ) 413 | print('nll_trn: {:.2f}'.format(np.mean(nll_train)), 414 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 415 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 416 | 
'acc_b_trn: '+str( np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ), 417 | file=log ) 418 | print('nll_val: {:.2f}'.format(np.mean(nll_val)), 419 | 'nll_M_val: {:.2f}'.format(np.mean(nll_M_val)), 420 | 'mse_val: {:.10f}'.format(np.mean(mse_val)), 421 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 422 | 'acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ), 423 | file=log ) 424 | if epoch == 0: 425 | labels = [ 'epoch', 'nll trn', 'mse train', 'nll var trn', 'acc trn' ] 426 | labels += [ 'b'+str(i)+' acc trn' for i in range( args.num_factors ) ] 427 | labels += [ 'acc var trn'] + [ 'b'+str(i)+' acc var trn' for i in range( args.num_factors ) ] 428 | labels += [ 'nll val', 'nll M val', 'mse val', 'acc val' ] 429 | labels += [ 'b'+str(i)+' acc val' for i in range( args.num_factors ) ] 430 | labels += [ 'nll var val', 'nll M var val' ] 431 | labels += [ 'acc var val'] + [ 'b'+str(i)+' acc var val' for i in range( args.num_factors ) ] 432 | csv_writer.writerow( labels ) 433 | 434 | labels = [ 'trn '+str(i) for i in range(len(perm_train[0])) ] 435 | labels += [ 'val '+str(i) for i in range(len(perm_val[0])) ] 436 | perm_writer.writerow( labels ) 437 | 438 | csv_writer.writerow( [epoch, np.mean(nll_train), np.mean(mse_train), np.mean(nll_var_train), np.mean(acc_train)] + 439 | list(np.mean(np.array(acc_blocks_train),axis=0)) + 440 | [np.mean(acc_var_train)] + list(np.mean(np.array(acc_var_blocks_train),axis=0)) + 441 | [np.mean(nll_val), np.mean(nll_M_val), np.mean(mse_val), np.mean(acc_val) ] + 442 | list(np.mean(np.array(acc_blocks_val),axis=0)) + 443 | [np.mean(nll_var_val), np.mean(nll_M_var_val)] + 444 | [np.mean(acc_var_val)] + list(np.mean(np.array(acc_var_blocks_val),axis=0)) 445 | ) 446 | perm_writer.writerow( list(np.mean(np.array(perm_train),axis=0)) + 447 | list(np.mean(np.array(perm_val),axis=0)) 448 | ) 449 | 450 | log.flush() 451 | if args.save_folder and np.mean(nll_M_val) < best_val_loss: 452 | torch.save(encoder.state_dict(), encoder_file) 453 | torch.save(decoder.state_dict(), decoder_file) 454 | print('Best model so far, saving...') 455 | return np.mean(nll_M_val) 456 | 457 | 458 | def test(): 459 | nll_test = [] 460 | nll_var_test = [] 461 | 462 | acc_test = [] 463 | acc_blocks_test = [] 464 | acc_var_test = [] 465 | acc_var_blocks_test = [] 466 | perm_test = [] 467 | 468 | mse_1_test = [] 469 | mse_10_test = [] 470 | mse_20_test = [] 471 | 472 | nll_M_test = [] 473 | nll_M_var_test = [] 474 | 475 | encoder.eval() 476 | decoder.eval() 477 | if not args.cuda: 478 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 479 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 480 | else: 481 | encoder.load_state_dict(torch.load(encoder_file)) 482 | decoder.load_state_dict(torch.load(decoder_file)) 483 | 484 | for batch_idx, (data, relations) in enumerate(test_loader): 485 | with torch.no_grad(): 486 | if args.cuda: 487 | data, relations = data.cuda(), relations.cuda() 488 | 489 | assert (data.size(2) - args.timesteps) >= args.timesteps 490 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 491 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 492 | 493 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. 
of particles 494 | logits = encoder(data_encoder, rel_rec, rel_send) 495 | edges = edges = my_sigmoid(logits, hard=args.hard, sharpness=args.sigmoid_sharpness) 496 | 497 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(edges, relations) 498 | 499 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 500 | output = decoder(data_decoder, edges, rel_rec, rel_send, 1) 501 | 502 | if args.plot: 503 | import matplotlib.pyplot as plt 504 | output_plot = decoder(data_decoder, edges, rel_rec, rel_send, 49) 505 | 506 | output_plot_en = decoder(data_encoder, edges, rel_rec, rel_send, 49) 507 | from trajectory_plot import draw_lines 508 | 509 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_sigmoid_batch(edges, relations) 510 | 511 | for i in range(args.batch_size): 512 | fig = plt.figure(figsize=(7, 7)) 513 | ax = fig.add_axes([0, 0, 1, 1]) 514 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 515 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 516 | 517 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 518 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 519 | ax.set_xticks([]) 520 | ax.set_yticks([]) 521 | block_names = [str(j) for j in range(args.num_factors)] 522 | acc_text = [ 'layer ' + block_names[j] + ' acc: {:02.0f}%'.format(100*acc_blocks_batch[i,j]) 523 | for j in range(acc_blocks_batch.shape[1]) ] 524 | acc_text = ', '.join(acc_text) 525 | plt.text( 0.5, 0.95, acc_text, horizontalalignment='center', transform=ax.transAxes ) 526 | #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true_.png'), dpi=300) 527 | plt.show() 528 | 529 | 530 | loss_nll = nll_gaussian(output, target, args.var) 531 | loss_nll_var = nll_gaussian_var(output, target, args.var) 532 | 533 | output_10 = decoder(data_decoder, edges, rel_rec, rel_send, 10) 534 | output_20 = decoder(data_decoder, edges, rel_rec, rel_send, 20) 535 | mse_1_test.append(F.mse_loss(output, target).data.item()) 536 | mse_10_test.append(F.mse_loss(output_10, target).data.item()) 537 | mse_20_test.append(F.mse_loss(output_20, target).data.item()) 538 | 539 | loss_nll_M = nll_gaussian(output_10, target, args.var) 540 | loss_nll_M_var = nll_gaussian_var(output_10, target, args.var) 541 | 542 | perm_test.append(perm) 543 | acc_test.append(acc_perm) 544 | acc_blocks_test.append(acc_blocks) 545 | acc_var_test.append(acc_var) 546 | acc_var_blocks_test.append(acc_var_blocks) 547 | 548 | nll_test.append(loss_nll.data.item()) 549 | nll_var_test.append(loss_nll_var.data.item()) 550 | nll_M_test.append(loss_nll_M.data.item()) 551 | nll_M_var_test.append(loss_nll_M_var.data.item()) 552 | 553 | 554 | print('--------------------------------') 555 | print('------------Testing-------------') 556 | print('--------------------------------') 557 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 558 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 559 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 560 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 561 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 562 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 563 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 564 | 'acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 565 | 'acc_var_b_test: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ) 566 | ) 567 | print('--------------------------------', file=log) 568 | 
print('------------Testing-------------', file=log) 569 | print('--------------------------------', file=log) 570 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 571 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 572 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 573 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 574 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 575 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 576 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 577 | 'acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 578 | 'acc_var_b_test: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ), 579 | file=log) 580 | log.flush() 581 | 582 | 583 | # Train model 584 | if not args.test: 585 | t_total = time.time() 586 | best_val_loss = np.inf 587 | best_epoch = 0 588 | for epoch in range(args.epochs): 589 | val_loss = train(epoch, best_val_loss) 590 | if val_loss < best_val_loss: 591 | best_val_loss = val_loss 592 | best_epoch = epoch 593 | if epoch - best_epoch > args.patience and epoch > 99: 594 | break 595 | print("Optimization Finished!") 596 | print("Best Epoch: {:04d}".format(best_epoch)) 597 | if args.save_folder: 598 | print("Best Epoch: {:04d}".format(best_epoch), file=log) 599 | log.flush() 600 | 601 | print('Reloading best model') 602 | test() 603 | if log is not None: 604 | print(save_folder) 605 | log.close() 606 | log_csv.close() 607 | perm_csv.close() 608 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data.dataset import TensorDataset 9 | from torch.utils.data import DataLoader 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | from itertools import permutations, chain 14 | from math import factorial 15 | 16 | from os import path 17 | 18 | def my_softmax(input, axis=1): 19 | trans_input = input.transpose(axis, 0).contiguous() 20 | soft_max_1d = F.softmax(trans_input, dim=0) # added dim=0 as implicit choice is deprecated, dim 0 is edgetype due to transpose 21 | return soft_max_1d.transpose(axis, 0) 22 | 23 | 24 | def binary_concrete(logits, tau=1, hard=False, eps=1e-10): 25 | y_soft = binary_concrete_sample(logits, tau=tau, eps=eps) 26 | if hard: 27 | y_hard = (y_soft > 0.5).float() 28 | y = Variable(y_hard.data - y_soft.data) + y_soft 29 | else: 30 | y = y_soft 31 | return y 32 | 33 | 34 | def binary_concrete_sample(logits, tau=1, eps=1e-10): 35 | logistic_noise = sample_logistic(logits.size(), eps=eps) 36 | if logits.is_cuda: 37 | logistic_noise = logistic_noise.cuda() 38 | y = logits + Variable(logistic_noise) 39 | return F.sigmoid(y / tau) 40 | 41 | 42 | def sample_logistic(shape, eps=1e-10): 43 | uniform = torch.rand(shape).float() 44 | return torch.log(uniform + eps) - torch.log(1 - uniform + eps) 45 | 46 | 47 | def sample_gumbel(shape, eps=1e-10): 48 | """ 49 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 50 | 51 | Sample from Gumbel(0, 1) 52 | 53 | based on 54 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 55 | (MIT license) 56 | """ 57 | U = torch.rand(shape).float() 58 | return - torch.log(eps - torch.log(U + eps)) 59 | 60 | 61 | 
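# --- Illustrative sketch (added for exposition; not part of the original utilities) ---
# sample_gumbel draws, up to the eps stabilisers, G = -log(-log(U)) with U ~ Uniform(0, 1),
# i.e. standard Gumbel(0, 1) noise. Adding this noise to the logits and taking a softmax at
# temperature tau gives the Gumbel-softmax (concrete) relaxation used to sample discrete edge
# types while keeping gradients. The demo below is a minimal, hypothetical usage example of
# the gumbel_softmax function defined further down in this file: with hard=True the forward
# pass is (numerically) one-hot, while gradients still flow through the soft sample.
def _demo_gumbel_softmax():
    import torch
    torch.manual_seed(0)
    # hypothetical logits for a batch of 4 edges and 2 edge types
    logits = torch.randn(4, 2, requires_grad=True)
    y_hard = gumbel_softmax(logits, tau=0.5, hard=True)   # one-hot samples, straight-through gradient
    y_soft = gumbel_softmax(logits, tau=0.5, hard=False)  # relaxed samples on the simplex
    assert torch.allclose(y_hard.sum(-1), torch.ones(4))  # each row sums to one (one-hot)
    y_hard.sum().backward()                                # gradients reach `logits` despite the rounding
    return y_soft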
def gumbel_softmax_sample(logits, tau=1, eps=1e-10): 62 | """ 63 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 64 | 65 | Draw a sample from the Gumbel-Softmax distribution 66 | 67 | based on 68 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb 69 | (MIT license) 70 | """ 71 | gumbel_noise = sample_gumbel(logits.size(), eps=eps) 72 | if logits.is_cuda: 73 | gumbel_noise = gumbel_noise.cuda() 74 | y = logits + Variable(gumbel_noise) 75 | return my_softmax(y / tau, axis=-1) 76 | 77 | 78 | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): 79 | """ 80 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 81 | 82 | Sample from the Gumbel-Softmax distribution and optionally discretize. 83 | Args: 84 | logits: [batch_size, n_class] unnormalized log-probs 85 | tau: non-negative scalar temperature 86 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 87 | Returns: 88 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 89 | If hard=True, then the returned sample will be one-hot, otherwise it will 90 | be a probability distribution that sums to 1 across classes 91 | 92 | Constraints: 93 | - this implementation only works on batch_size x num_features tensor for now 94 | 95 | based on 96 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 97 | (MIT license) 98 | """ 99 | y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) 100 | if hard: 101 | shape = logits.size() 102 | _, k = y_soft.data.max(-1) 103 | # this bit is based on 104 | # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5 105 | y_hard = torch.zeros(*shape) 106 | if y_soft.is_cuda: 107 | y_hard = y_hard.cuda() 108 | y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) 109 | # this cool bit of code achieves two things: 110 | # - makes the output value exactly one-hot (since we add then 111 | # subtract y_soft value) 112 | # - makes the gradient equal to y_soft gradient (since we strip 113 | # all other gradients) 114 | y = Variable(y_hard - y_soft.data) + y_soft 115 | else: 116 | y = y_soft 117 | return y 118 | 119 | def my_sigmoid(logits, hard=True, sharpness=1.0): 120 | 121 | edges_soft = 1/(1+torch.exp(-sharpness*logits)) 122 | if hard: 123 | edges_hard = torch.round(edges_soft) 124 | # this bit is based on 125 | # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5 126 | if edges_soft.is_cuda: 127 | edges_hard = edges_hard.cuda() 128 | # this cool bit of code achieves two things: 129 | # - makes the output value exactly one-hot (since we add then 130 | # subtract y_soft value) 131 | # - makes the gradient equal to y_soft gradient (since we strip 132 | # all other gradients) 133 | edges = Variable(edges_hard - edges_soft.data) + edges_soft 134 | else: 135 | edges = edges_soft 136 | return edges 137 | 138 | def binary_accuracy(output, labels): 139 | preds = output > 0.5 140 | correct = preds.type_as(labels).eq(labels).double() 141 | correct = correct.sum() 142 | return correct / len(labels) 143 | 144 | def edge_type_encode(edges): # this is used to gives each 'interaction strength' a unique integer = 0, 1, 2 .. 
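# (Added explanatory note, not in the original source.) edge_type_encode relabels the raw
# interaction strengths stored in the simulation's edge arrays as consecutive integers:
# np.unique returns the sorted distinct values, and each entry of `edges` is replaced by the
# index of its value in that sorted list. A small worked example with a hypothetical 2x2
# strength matrix containing couplings 0.0 and 0.5:
#
#     edge_type_encode(np.array([[0.0, 0.5],
#                                [0.5, 0.0]]))
#     # -> array([[0., 1.],
#     #           [1., 0.]])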
145 | unique = np.unique(edges) 146 | encode = np.zeros(edges.shape) 147 | for i in range(unique.shape[0]): 148 | encode += np.where( edges == unique[i], i, 0) 149 | return encode 150 | 151 | def loader_edges_encode(edges, num_atoms): 152 | edges = np.reshape(edges, [edges.shape[0], edges.shape[1], num_atoms ** 2]) 153 | edges = np.array(edge_type_encode(edges), dtype=np.int64) 154 | off_diag_idx = np.ravel_multi_index( 155 | np.where(np.ones((num_atoms, num_atoms)) - np.eye(num_atoms)), 156 | [num_atoms, num_atoms]) 157 | edges = edges[:,:, off_diag_idx] 158 | return edges 159 | 160 | def loader_combine_edges(edges): 161 | edge_types_list = [ int(np.max(edges[:,i,:]))+1 for i in range(edges.shape[1]) ] 162 | assert( edge_types_list == sorted(edge_types_list)[::-1] ) 163 | encoded_target = np.zeros( edges[:,0,:].shape ) 164 | base = 1 165 | for i in reversed(range(edges.shape[1])): 166 | encoded_target += base*edges[:,i,:] 167 | base *= edge_types_list[i] 168 | return encoded_target.astype('int') 169 | 170 | def load_data_NRI(batch_size=1, sim_folder='', shuffle=True, data_folder='data'): 171 | # the edges numpy arrays below are [ num_sims, N, N ] 172 | loc_train = np.load(path.join(data_folder,sim_folder,'loc_train.npy')) 173 | vel_train = np.load(path.join(data_folder,sim_folder,'vel_train.npy')) 174 | edges_train = np.load(path.join(data_folder,sim_folder,'edges_train.npy')) 175 | 176 | loc_valid = np.load(path.join(data_folder,sim_folder,'loc_valid.npy')) 177 | vel_valid = np.load(path.join(data_folder,sim_folder,'vel_valid.npy')) 178 | edges_valid = np.load(path.join(data_folder,sim_folder,'edges_valid.npy')) 179 | 180 | loc_test = np.load(path.join(data_folder,sim_folder,'loc_test.npy')) 181 | vel_test = np.load(path.join(data_folder,sim_folder,'vel_test.npy')) 182 | edges_test = np.load(path.join(data_folder,sim_folder,'edges_test.npy')) 183 | 184 | # [num_samples, num_timesteps, num_dims, num_atoms] 185 | num_atoms = loc_train.shape[3] 186 | 187 | loc_max = loc_train.max() 188 | loc_min = loc_train.min() 189 | vel_max = vel_train.max() 190 | vel_min = vel_train.min() 191 | 192 | # Normalize to [-1, 1] 193 | loc_train = (loc_train - loc_min) * 2 / (loc_max - loc_min) - 1 194 | vel_train = (vel_train - vel_min) * 2 / (vel_max - vel_min) - 1 195 | 196 | loc_valid = (loc_valid - loc_min) * 2 / (loc_max - loc_min) - 1 197 | vel_valid = (vel_valid - vel_min) * 2 / (vel_max - vel_min) - 1 198 | 199 | loc_test = (loc_test - loc_min) * 2 / (loc_max - loc_min) - 1 200 | vel_test = (vel_test - vel_min) * 2 / (vel_max - vel_min) - 1 201 | 202 | # Reshape to: [num_sims, num_atoms, num_timesteps, num_dims] 203 | loc_train = np.transpose(loc_train, [0, 3, 1, 2]) 204 | vel_train = np.transpose(vel_train, [0, 3, 1, 2]) 205 | feat_train = np.concatenate([loc_train, vel_train], axis=3) 206 | 207 | loc_valid = np.transpose(loc_valid, [0, 3, 1, 2]) 208 | vel_valid = np.transpose(vel_valid, [0, 3, 1, 2]) 209 | feat_valid = np.concatenate([loc_valid, vel_valid], axis=3) 210 | 211 | loc_test = np.transpose(loc_test, [0, 3, 1, 2]) 212 | vel_test = np.transpose(vel_test, [0, 3, 1, 2]) 213 | feat_test = np.concatenate([loc_test, vel_test], axis=3) 214 | 215 | edges_train = loader_edges_encode(edges_train, num_atoms) 216 | edges_valid = loader_edges_encode(edges_valid, num_atoms) 217 | edges_test = loader_edges_encode(edges_test, num_atoms) 218 | 219 | edges_train = loader_combine_edges(edges_train) 220 | edges_valid = loader_combine_edges(edges_valid) 221 | edges_test = loader_combine_edges(edges_test) 222 | 
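# (Added explanatory note, not in the original source.) At this point each edges_* array has
# been reduced by loader_edges_encode to shape [num_sims, num_interaction_layers, N**2 - N]:
# the interaction strengths are relabelled as consecutive integers and only the off-diagonal
# entries are kept, since self-edges are never modelled. loader_combine_edges then merges the
# per-layer labels into a single integer per edge using a mixed-radix code, which is what the
# non-factorised NRI baseline is trained against; for example, with two binary layers a and b
# the combined label is 2*a + b, taking values in {0, 1, 2, 3}.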
223 | feat_train = torch.FloatTensor(feat_train) 224 | edges_train = torch.LongTensor(edges_train) 225 | feat_valid = torch.FloatTensor(feat_valid) 226 | edges_valid = torch.LongTensor(edges_valid) 227 | feat_test = torch.FloatTensor(feat_test) 228 | edges_test = torch.LongTensor(edges_test) 229 | 230 | train_data = TensorDataset(feat_train, edges_train) 231 | valid_data = TensorDataset(feat_valid, edges_valid) 232 | test_data = TensorDataset(feat_test, edges_test) 233 | 234 | train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle) 235 | valid_data_loader = DataLoader(valid_data, batch_size=batch_size) 236 | test_data_loader = DataLoader(test_data, batch_size=batch_size) 237 | 238 | return train_data_loader, valid_data_loader, test_data_loader, loc_max, loc_min, vel_max, vel_min 239 | 240 | 241 | def load_data_fNRI(batch_size=1, sim_folder='', shuffle=True, data_folder='data'): 242 | # the edges numpy arrays below are [ num_sims, N, N ] 243 | loc_train = np.load(path.join(data_folder,sim_folder,'loc_train.npy')) 244 | vel_train = np.load(path.join(data_folder,sim_folder,'vel_train.npy')) 245 | edges_train = np.load(path.join(data_folder,sim_folder,'edges_train.npy')) 246 | 247 | loc_valid = np.load(path.join(data_folder,sim_folder,'loc_valid.npy')) 248 | vel_valid = np.load(path.join(data_folder,sim_folder,'vel_valid.npy')) 249 | edges_valid = np.load(path.join(data_folder,sim_folder,'edges_valid.npy')) 250 | 251 | loc_test = np.load(path.join(data_folder,sim_folder,'loc_test.npy')) 252 | vel_test = np.load(path.join(data_folder,sim_folder,'vel_test.npy')) 253 | edges_test = np.load(path.join(data_folder,sim_folder,'edges_test.npy')) 254 | 255 | # [num_samples, num_timesteps, num_dims, num_atoms] 256 | num_atoms = loc_train.shape[3] 257 | 258 | loc_max = loc_train.max() 259 | loc_min = loc_train.min() 260 | vel_max = vel_train.max() 261 | vel_min = vel_train.min() 262 | 263 | # Normalize to [-1, 1] 264 | loc_train = (loc_train - loc_min) * 2 / (loc_max - loc_min) - 1 265 | vel_train = (vel_train - vel_min) * 2 / (vel_max - vel_min) - 1 266 | 267 | loc_valid = (loc_valid - loc_min) * 2 / (loc_max - loc_min) - 1 268 | vel_valid = (vel_valid - vel_min) * 2 / (vel_max - vel_min) - 1 269 | 270 | loc_test = (loc_test - loc_min) * 2 / (loc_max - loc_min) - 1 271 | vel_test = (vel_test - vel_min) * 2 / (vel_max - vel_min) - 1 272 | 273 | # Reshape to: [num_sims, num_atoms, num_timesteps, num_dims] 274 | loc_train = np.transpose(loc_train, [0, 3, 1, 2]) 275 | vel_train = np.transpose(vel_train, [0, 3, 1, 2]) 276 | feat_train = np.concatenate([loc_train, vel_train], axis=3) 277 | 278 | loc_valid = np.transpose(loc_valid, [0, 3, 1, 2]) 279 | vel_valid = np.transpose(vel_valid, [0, 3, 1, 2]) 280 | feat_valid = np.concatenate([loc_valid, vel_valid], axis=3) 281 | 282 | loc_test = np.transpose(loc_test, [0, 3, 1, 2]) 283 | vel_test = np.transpose(vel_test, [0, 3, 1, 2]) 284 | feat_test = np.concatenate([loc_test, vel_test], axis=3) 285 | 286 | edges_train = loader_edges_encode( edges_train, num_atoms ) 287 | edges_valid = loader_edges_encode( edges_valid, num_atoms ) 288 | edges_test = loader_edges_encode( edges_test, num_atoms ) 289 | 290 | edges_train = torch.LongTensor(edges_train) 291 | edges_valid = torch.LongTensor(edges_valid) 292 | edges_test = torch.LongTensor(edges_test) 293 | 294 | feat_train = torch.FloatTensor(feat_train) 295 | feat_valid = torch.FloatTensor(feat_valid) 296 | feat_test = torch.FloatTensor(feat_test) 297 | 298 | train_data = 
TensorDataset(feat_train, edges_train) 299 | valid_data = TensorDataset(feat_valid, edges_valid) 300 | test_data = TensorDataset(feat_test, edges_test) 301 | 302 | train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle) 303 | valid_data_loader = DataLoader(valid_data, batch_size=batch_size) 304 | test_data_loader = DataLoader(test_data, batch_size=batch_size) 305 | 306 | return train_data_loader, valid_data_loader, test_data_loader, loc_max, loc_min, vel_max, vel_min 307 | 308 | 309 | def to_2d_idx(idx, num_cols): 310 | idx = np.array(idx, dtype=np.int64) 311 | y_idx = np.array(np.floor(idx / float(num_cols)), dtype=np.int64) 312 | x_idx = idx % num_cols 313 | return x_idx, y_idx 314 | 315 | 316 | def encode_onehot(labels): 317 | classes = set(labels) 318 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 319 | enumerate(classes)} 320 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 321 | dtype=np.int32) 322 | return labels_onehot 323 | 324 | 325 | def get_triu_indices(num_nodes): 326 | """Linear triu (upper triangular) indices.""" 327 | ones = torch.ones(num_nodes, num_nodes) 328 | eye = torch.eye(num_nodes, num_nodes) 329 | triu_indices = (ones.triu() - eye).nonzero().t() 330 | triu_indices = triu_indices[0] * num_nodes + triu_indices[1] 331 | return triu_indices 332 | 333 | 334 | def get_tril_indices(num_nodes): 335 | """Linear tril (lower triangular) indices.""" 336 | ones = torch.ones(num_nodes, num_nodes) 337 | eye = torch.eye(num_nodes, num_nodes) 338 | tril_indices = (ones.tril() - eye).nonzero().t() 339 | tril_indices = tril_indices[0] * num_nodes + tril_indices[1] 340 | return tril_indices 341 | 342 | 343 | def get_offdiag_indices(num_nodes): 344 | """Linear off-diagonal indices.""" 345 | ones = torch.ones(num_nodes, num_nodes) 346 | eye = torch.eye(num_nodes, num_nodes) 347 | offdiag_indices = (ones - eye).nonzero().t() 348 | offdiag_indices = offdiag_indices[0] * num_nodes + offdiag_indices[1] 349 | return offdiag_indices 350 | 351 | 352 | def get_triu_offdiag_indices(num_nodes): 353 | """Linear triu (upper) indices w.r.t. vector of off-diagonal elements.""" 354 | triu_idx = torch.zeros(num_nodes * num_nodes) 355 | triu_idx[get_triu_indices(num_nodes)] = 1. 356 | triu_idx = triu_idx[get_offdiag_indices(num_nodes)] 357 | return triu_idx.nonzero() 358 | 359 | 360 | def get_tril_offdiag_indices(num_nodes): 361 | """Linear tril (lower) indices w.r.t. vector of off-diagonal elements.""" 362 | tril_idx = torch.zeros(num_nodes * num_nodes) 363 | tril_idx[get_tril_indices(num_nodes)] = 1. 
364 | tril_idx = tril_idx[get_offdiag_indices(num_nodes)] 365 | return tril_idx.nonzero() 366 | 367 | 368 | def get_minimum_distance(data): 369 | data = data[:, :, :, :2].transpose(1, 2) 370 | data_norm = (data ** 2).sum(-1, keepdim=True) 371 | dist = data_norm + \ 372 | data_norm.transpose(2, 3) - \ 373 | 2 * torch.matmul(data, data.transpose(2, 3)) 374 | min_dist, _ = dist.min(1) 375 | return min_dist.view(min_dist.size(0), -1) 376 | 377 | 378 | def get_buckets(dist, num_buckets): 379 | dist = dist.cpu().data.numpy() 380 | 381 | min_dist = np.min(dist) 382 | max_dist = np.max(dist) 383 | bucket_size = (max_dist - min_dist) / num_buckets 384 | thresholds = bucket_size * np.arange(num_buckets) 385 | 386 | bucket_idx = [] 387 | for i in range(num_buckets): 388 | if i < num_buckets - 1: 389 | idx = np.where(np.all(np.vstack((dist > thresholds[i], 390 | dist <= thresholds[i + 1])), 0))[0] 391 | else: 392 | idx = np.where(dist > thresholds[i])[0] 393 | bucket_idx.append(idx) 394 | 395 | return bucket_idx, thresholds 396 | 397 | 398 | def get_correct_per_bucket(bucket_idx, pred, target): 399 | pred = pred.cpu().numpy()[:, 0] 400 | target = target.cpu().data.numpy() 401 | 402 | correct_per_bucket = [] 403 | for i in range(len(bucket_idx)): 404 | preds_bucket = pred[bucket_idx[i]] 405 | target_bucket = target[bucket_idx[i]] 406 | correct_bucket = np.sum(preds_bucket == target_bucket) 407 | correct_per_bucket.append(correct_bucket) 408 | 409 | return correct_per_bucket 410 | 411 | 412 | def get_correct_per_bucket_(bucket_idx, pred, target): 413 | pred = pred.cpu().numpy() 414 | target = target.cpu().data.numpy() 415 | 416 | correct_per_bucket = [] 417 | for i in range(len(bucket_idx)): 418 | preds_bucket = pred[bucket_idx[i]] 419 | target_bucket = target[bucket_idx[i]] 420 | correct_bucket = np.sum(preds_bucket == target_bucket) 421 | correct_per_bucket.append(correct_bucket) 422 | 423 | return correct_per_bucket 424 | 425 | 426 | def kl_categorical(preds, log_prior, num_atoms, eps=1e-16): 427 | kl_div = preds * (torch.log(preds + eps) - log_prior) 428 | return kl_div.sum() / (num_atoms * preds.size(0)) # normalisation here is (batch * num atoms) 429 | 430 | 431 | def kl_categorical_uniform(preds, num_atoms, num_edge_types, add_const=False, 432 | eps=1e-16): 433 | kl_div = preds * torch.log(preds + eps) 434 | if add_const: 435 | const = np.log(num_edge_types) 436 | kl_div += const 437 | return kl_div.sum() / (num_atoms * preds.size(0)) 438 | 439 | def kl_categorical_uniform_var(preds, num_atoms, num_edge_types, add_const=False, 440 | eps=1e-16): 441 | kl_div = preds * torch.log(preds + eps) 442 | if add_const: 443 | const = np.log(num_edge_types) 444 | kl_div += const 445 | return (kl_div.sum(dim=1) / num_atoms).var() 446 | 447 | 448 | def nll_gaussian(preds, target, variance, add_const=False): 449 | neg_log_p = ((preds - target) ** 2 / (2 * variance)) 450 | if add_const: 451 | const = 0.5 * np.log(2 * np.pi * variance) 452 | neg_log_p += const 453 | return neg_log_p.sum() / (target.size(0) * target.size(1)) # normalisation here is (batch * num atoms) 454 | 455 | def nll_gaussian_var(preds, target, variance, add_const=False): 456 | # returns the variance over the batch of the reconstruction loss 457 | neg_log_p = ((preds - target) ** 2 / (2 * variance)) 458 | if add_const: 459 | const = 0.5 * np.log(2 * np.pi * variance) 460 | neg_log_p += const 461 | return (neg_log_p.sum(dim=1)/target.size(1)).var() 462 | 463 | 464 | 465 | def true_flip(x, dim): 466 | indices = [slice(None)] * x.dim() 467 | 
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, 468 | dtype=torch.long, device=x.device) 469 | return x[tuple(indices)] 470 | 471 | def KL_between_blocks(prob_list, num_atoms, eps=1e-16): 472 | # Return a list of the mutual information between every block pair 473 | KL_list = [] 474 | for i in range(len(prob_list)): 475 | for j in range(len(prob_list)): 476 | if i != j: 477 | KL = prob_list[i] *( torch.log(prob_list[i] + eps) - torch.log(prob_list[j] + eps) ) 478 | KL_list.append( KL.sum() / (num_atoms * prob_list[i].size(0)) ) 479 | KL = prob_list[i] *( torch.log(prob_list[i] + eps) - torch.log( true_flip(prob_list[j],-1) + eps) ) 480 | KL_list.append( KL.sum() / (num_atoms * prob_list[i].size(0)) ) 481 | return KL_list 482 | 483 | 484 | def decode_target( target, num_edge_types_list ): 485 | target_list = [] 486 | base = np.prod(num_edge_types_list) 487 | for i in range(len(num_edge_types_list)): 488 | base /= num_edge_types_list[i] 489 | target_list.append( target//base ) 490 | target = target % base 491 | return target_list 492 | 493 | def encode_target_list( target_list, edge_types_list ): 494 | encoded_target = np.zeros( target_list[0].shape ) 495 | base = 1 496 | for i in reversed(range(len(target_list))): 497 | encoded_target += base*np.array(target_list[i]) 498 | base *= edge_types_list[i] 499 | return encoded_target.astype('int') 500 | 501 | def edge_accuracy_perm_NRI_batch(preds, target, num_edge_types_list): 502 | # permutation edge accuracy calculator for the standard NRI model 503 | # return the maximum accuracy of the batch over the permutations of the edge labels 504 | # also returns a one-hot encoding of the number which represents this permutation 505 | # also returns the accuracies for the individual factor graphs 506 | 507 | _, preds = preds.max(-1) # returns index of max in each z_ij to reduce dim by 1 508 | 509 | num_edge_types = np.prod(num_edge_types_list) 510 | preds = np.eye(num_edge_types)[np.array(preds.cpu())] # this is nice way to turn integers into one-hot vectors 511 | target = np.array(target.cpu()) 512 | 513 | perms = [p for p in permutations(range(num_edge_types))] # list of edge type permutations 514 | # in the below, for each permutation of edge-types, permute preds, then take argmax to go from one-hot to integers 515 | # then compare to target, compute accuracy 516 | acc = np.array([np.mean(np.equal(target, np.argmax(preds[:,:,p], axis=-1),dtype=object)) for p in perms]) 517 | max_acc, idx = np.amax(acc), np.argmax(acc) 518 | preds_deperm = np.argmax(preds[:,:,perms[idx]], axis=-1) 519 | 520 | target_list = decode_target( target, num_edge_types_list ) 521 | preds_deperm_list = decode_target( preds_deperm, num_edge_types_list ) 522 | 523 | blocks_acc = [ np.mean(np.equal(target_list[i], preds_deperm_list[i], dtype=object),axis=-1) 524 | for i in range(len(target_list)) ] 525 | acc = np.mean(np.equal(target, preds_deperm ,dtype=object), axis=-1) 526 | blocks_acc = np.swapaxes(np.array(blocks_acc),0,1) 527 | 528 | idx_onehot = np.eye(len(perms))[np.array(idx)] 529 | return acc, idx_onehot, blocks_acc 530 | 531 | def edge_accuracy_perm_NRI(preds, targets, num_edge_types_list): 532 | acc_batch, perm_code_onehot, acc_blocks_batch = edge_accuracy_perm_NRI_batch(preds, targets, num_edge_types_list) 533 | 534 | acc = np.mean(acc_batch) 535 | acc_var = np.var(acc_batch) 536 | acc_blocks = np.mean(acc_blocks_batch, axis=0) 537 | acc_var_blocks = np.var(acc_blocks_batch, axis=0) 538 | 539 | return acc, perm_code_onehot, acc_blocks, acc_var, acc_var_blocks 540 
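# --- Illustrative sketch (added for exposition; not part of the original utilities) ---
# Because the latent edge types are interchangeable, a model may recover the correct graph
# with the type labels permuted (e.g. predicting "type 1" wherever the ground truth says
# "type 0" and vice versa). The edge_accuracy_perm_* functions below therefore score every
# permutation of the predicted labels against the targets and report the best one, together
# with the per-factor-graph ("block") accuracies under that same permutation. A minimal
# numpy sketch of the idea, using hypothetical integer label arrays:
def _demo_permutation_matched_accuracy():
    from itertools import permutations
    import numpy as np
    target = np.array([0, 0, 1, 1, 2])
    pred = np.array([1, 1, 0, 0, 2])          # correct partition, but labels 0 and 1 swapped
    accs = [np.mean(np.array([p[k] for k in pred]) == target)
            for p in permutations(range(3))]
    return max(accs)                           # -> 1.0, although the raw (unpermuted) accuracy is 0.2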
| 541 | 542 | def edge_accuracy_perm_fNRI_batch(preds_list, targets, num_edge_types_list): 543 | # permutation edge accuracy calculator for the fNRI model 544 | # return the maximum accuracy of the batch over the permutations of the edge labels 545 | # also returns a one-hot encoding of the number which represents this permutation 546 | # also returns the accuracies for the individual factor graphs 547 | 548 | target_list = [ targets[:,i,:].cpu() for i in range(targets.shape[1])] 549 | preds_list = [ pred.max(-1)[1].cpu() for pred in preds_list] 550 | preds = encode_target_list(preds_list, num_edge_types_list) 551 | target = encode_target_list(target_list, num_edge_types_list) 552 | 553 | target_list = [ np.array(t.cpu()).astype('int') for t in target_list ] 554 | 555 | num_edge_types = np.prod(num_edge_types_list) 556 | preds = np.eye(num_edge_types)[preds] # this is nice way to turn integers into one-hot vectors 557 | 558 | perms = [p for p in permutations(range(num_edge_types))] # list of edge type permutations 559 | 560 | # in the below, for each permutation of edge-types, permute preds, then take argmax to go from one-hot to integers 561 | # then compare to target to compute accuracy 562 | acc = np.array([np.mean(np.equal(target, np.argmax(preds[:,:,p], axis=-1),dtype=object)) for p in perms]) 563 | max_acc, idx = np.amax(acc), np.argmax(acc) 564 | 565 | preds_deperm = np.argmax(preds[:,:,perms[idx]], axis=-1) 566 | preds_deperm_list = decode_target( preds_deperm, num_edge_types_list ) 567 | 568 | blocks_acc = [ np.mean(np.equal(target_list[i], preds_deperm_list[i], dtype=object),axis=-1) 569 | for i in range(len(target_list)) ] 570 | acc = np.mean(np.equal(target, preds_deperm ,dtype=object), axis=-1) 571 | blocks_acc = np.swapaxes(np.array(blocks_acc),0,1) 572 | 573 | idx_onehot = np.array([0])#np.eye(len(perms))[np.array(idx)] 574 | 575 | return acc, idx_onehot, blocks_acc 576 | 577 | def edge_accuracy_perm_fNRI_batch_skipfirst(preds_list, targets, num_factors): 578 | # permutation edge accuracy calculator for the fNRI model when using skip-first argument 579 | # and all factor graphs have two edge types 580 | # return the maximum accuracy of the batch over the permutations of the edge labels 581 | # also returns a one-hot encoding of the number which represents this permutation 582 | # also returns the accuracies for the individual factor graphs 583 | 584 | targets = np.swapaxes(np.array(targets.cpu()),1,2) 585 | preds = torch.cat( [ torch.unsqueeze(pred.max(-1)[1],-1) for pred in preds_list], -1 ) 586 | preds = np.array(preds.cpu()) 587 | perms = [p for p in permutations(range(num_factors))] 588 | 589 | acc = np.array([np.mean( np.sum(np.equal(targets, preds[:,:,p],dtype=object),axis=-1)==num_factors ) for p in perms]) 590 | max_acc, idx = np.amax(acc), np.argmax(acc) 591 | 592 | preds_deperm = preds[:,:,perms[idx]] 593 | blocks_acc = np.mean(np.equal(targets, preds_deperm, dtype=object),axis=1) 594 | acc = np.mean( np.sum(np.equal(targets, preds_deperm,dtype=object),axis=-1)==num_factors, axis=-1) 595 | 596 | idx_onehot = np.eye(len(perms))[np.array(idx)] 597 | 598 | return acc, idx_onehot, blocks_acc 599 | 600 | 601 | def edge_accuracy_perm_fNRI(preds_list, targets, num_edge_types_list, skip_first=False): 602 | 603 | if skip_first and all(e == 2 for e in num_edge_types_list): 604 | acc_batch, perm_code_onehot, acc_blocks_batch = edge_accuracy_perm_fNRI_batch_skipfirst(preds_list, targets, len(num_edge_types_list)) 605 | else: 606 | acc_batch, perm_code_onehot, acc_blocks_batch 
= edge_accuracy_perm_fNRI_batch(preds_list, targets, num_edge_types_list) 607 | 608 | acc = np.mean(acc_batch) 609 | acc_var = np.var(acc_batch) 610 | acc_blocks = np.mean(acc_blocks_batch, axis=0) 611 | acc_var_blocks = np.var(acc_blocks_batch, axis=0) 612 | 613 | return acc, perm_code_onehot, acc_blocks, acc_var, acc_var_blocks 614 | 615 | def edge_accuracy_perm_sigmoid_batch(preds, targets): 616 | # permutation edge accuracy calculator for the sigmoid model 617 | # return the maximum accuracy of the batch over the permutations of the edge labels 618 | # also returns a one-hot encoding of the number which represents this permutation 619 | # also returns the accuracies for the individual factor graph_list 620 | 621 | targets = np.swapaxes(np.array(targets.cpu()),1,2) 622 | preds = np.array(preds.cpu().detach()) 623 | preds = np.rint(preds).astype('int') 624 | num_factors = targets.shape[-1] 625 | perms = [p for p in permutations(range(num_factors))] # list of edge type permutations 626 | 627 | # in the below, for each permutation of edge-types, permute preds, then take argmax to go from one-hot to integers 628 | # then compare to target to compute accuracy 629 | acc = np.array([np.mean( np.sum(np.equal(targets, preds[:,:,p],dtype=object),axis=-1)==num_factors ) for p in perms]) 630 | max_acc, idx = np.amax(acc), np.argmax(acc) 631 | 632 | preds_deperm = preds[:,:,perms[idx]] 633 | blocks_acc = np.mean(np.equal(targets, preds_deperm, dtype=object),axis=1) 634 | acc = np.mean( np.sum(np.equal(targets, preds_deperm,dtype=object),axis=-1)==num_factors, axis=-1) 635 | 636 | idx_onehot = np.eye(len(perms))[np.array(idx)] 637 | return acc, idx_onehot, blocks_acc 638 | 639 | 640 | def edge_accuracy_perm_sigmoid(preds, targets): 641 | acc_batch, perm_code_onehot, acc_blocks_batch= edge_accuracy_perm_sigmoid_batch(preds, targets) 642 | 643 | acc = np.mean(acc_batch) 644 | acc_var = np.var(acc_batch) 645 | acc_blocks = np.mean(acc_blocks_batch, axis=0) 646 | acc_var_blocks = np.var(acc_blocks_batch, axis=0) 647 | 648 | return acc, perm_code_onehot, acc_blocks, acc_var, acc_var_blocks 649 | -------------------------------------------------------------------------------- /train_enc.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import time 9 | import argparse 10 | import pickle 11 | import os 12 | import datetime 13 | import csv 14 | import math 15 | 16 | import torch.optim as optim 17 | from torch.optim import lr_scheduler 18 | 19 | from utils import * 20 | from modules import * 21 | 22 | parser = argparse.ArgumentParser() 23 | ## arguments related to training ## 24 | parser.add_argument('--epochs', type=int, default=500, 25 | help='Number of epochs to train.') 26 | parser.add_argument('--batch-size', type=int, default=128, 27 | help='Number of samples per batch.') 28 | parser.add_argument('--lr', type=float, default=0.0005, 29 | help='Initial learning rate.') 30 | parser.add_argument('--lr-decay', type=int, default=200, 31 | help='After how epochs to decay LR by a factor of gamma.') 32 | parser.add_argument('--gamma', type=float, default=0.5, 33 | help='LR decay factor.') 34 | parser.add_argument('--patience', type=int, default=500, 35 | help='Early stopping patience') 36 | parser.add_argument('--encoder-dropout', type=float, default=0.5, 37 | help='Dropout rate (1 - keep 
probability).') 38 | parser.add_argument('--weight-decay', type=float, default=0.0, 39 | help='Weight decay value for L2 regularisation in Adam optimiser') 40 | 41 | ## arguments related to weight and bias initialisation ## 42 | parser.add_argument('--seed', type=int, default=1, 43 | help='Random seed.') 44 | parser.add_argument('--encoder-init-type',type=str, default='xavier_normal', 45 | help='The type of weight initialization to use in the encoder') 46 | parser.add_argument('--encoder-bias-scale',type=float, default=0.1, 47 | help='The type of weight initialization to use in the encoder') 48 | 49 | ## arguments related to changing the model ## 50 | parser.add_argument('--NRI', action='store_true', default=False, 51 | help='Use the NRI model, rather than the fNRI model') 52 | parser.add_argument('--edge-types-list', nargs='+', default=[2,2], 53 | help='The number of edge types to infer.') # takes arguments from cmd line as: --edge-types-list 2 2 54 | parser.add_argument('--split-point', type=int, default=0, 55 | help='The point at which factor graphs are split up in the encoder' ) 56 | parser.add_argument('--encoder', type=str, default='mlp', 57 | help='Type of path encoder model (mlp or cnn).') 58 | parser.add_argument('--encoder-hidden', type=int, default=256, 59 | help='Number of hidden units.') 60 | parser.add_argument('--prior', action='store_true', default=False, 61 | help='Whether to use sparsity prior.') 62 | 63 | parser.add_argument('--sigmoid', action='store_true', default=False, 64 | help='Use the sfNRI model, rather than the fNRI model') 65 | parser.add_argument('--num-factors', type=int, default=2, 66 | help='The number of factor graphs (this is only for sigmoid)') 67 | parser.add_argument('--sigmoid-sharpness', type=float, default=1., 68 | help='Coefficient in the power of the sigmoid function') 69 | 70 | ## arguments related to the simulation data ## 71 | parser.add_argument('--sim-folder', type=str, default='springcharge_5', 72 | help='Name of the folder in the data folder to load simulation data from') 73 | parser.add_argument('--data-folder', type=str, default='data', 74 | help='Name of the data folder to load data from') 75 | parser.add_argument('--num-atoms', type=int, default=5, 76 | help='Number of atoms in simulation.') 77 | parser.add_argument('--dims', type=int, default=4, 78 | help='The number of input dimensions (position + velocity).') 79 | parser.add_argument('--timesteps', type=int, default=49, 80 | help='The number of time steps per sample.') 81 | 82 | ## Saving, loading etc. ## 83 | parser.add_argument('--no-cuda', action='store_true', default=False, 84 | help='Disables CUDA training.') 85 | parser.add_argument('--save-folder', type=str, default='logs', 86 | help='Where to save the trained model, leave empty to not save anything.') 87 | parser.add_argument('--load-folder', type=str, default='', 88 | help='Where to load the trained model if finetunning. 
' + 89 | 'Leave empty to train from scratch') 90 | parser.add_argument('--test', action='store_true', default=False, 91 | help='Skip training and validation') 92 | parser.add_argument('--no-edge-acc', action='store_true', default=False, 93 | help='Skip training and plot accuracy distributions') 94 | 95 | 96 | args = parser.parse_args() 97 | args.cuda = not args.no_cuda and torch.cuda.is_available() 98 | #args.factor = not args.no_factor 99 | args.edge_types_list = list(map(int, args.edge_types_list)) 100 | args.edge_types_list.sort(reverse=True) 101 | 102 | if all( (isinstance(k, int) and k >= 1) for k in args.edge_types_list): 103 | if args.NRI: 104 | edge_types = np.prod(args.edge_types_list) 105 | else: 106 | edge_types = sum(args.edge_types_list) 107 | else: 108 | raise ValueError('Could not compute the edge-types-list') 109 | 110 | if args.NRI: 111 | print('Using NRI model') 112 | if args.split_point != 0: 113 | args.split_point = 0 114 | print(args) 115 | 116 | if args.prior: 117 | prior = [ [0.9, 0.1] , [0.9, 0.1] ] # TODO: hard coded for now 118 | if not all( prior[i].size == edge_types_list[i] for i in range(len(args.edge_types_list))): 119 | raise ValueError('Prior is incompatable with the edge types list') 120 | print("Using prior: "+str(prior)) 121 | log_prior = [] 122 | for i in range(len(args.edge_types_list)): 123 | prior_i = np.array(prior[i]) 124 | log_prior_i = torch.FloatTensor(np.log(prior)) 125 | log_prior_i = torch.unsqueeze(log_prior_i, 0) 126 | log_prior_i = torch.unsqueeze(log_prior_i, 0) 127 | log_prior_i = Variable(log_prior_i) 128 | log_prior.append(log_prior_i) 129 | if args.cuda: 130 | log_prior = log_prior.cuda() 131 | 132 | np.random.seed(args.seed) 133 | torch.manual_seed(args.seed) 134 | if args.cuda: 135 | torch.cuda.manual_seed(args.seed) 136 | 137 | 138 | # Save model and meta-data. Always saves in a new sub-folder. 139 | if args.save_folder: 140 | exp_counter = 0 141 | now = datetime.datetime.now() 142 | timestamp = now.isoformat().replace(':','-')[:-7] 143 | save_folder = os.path.join(args.save_folder,'exp'+timestamp) 144 | os.makedirs(save_folder) 145 | meta_file = os.path.join(save_folder, 'metadata.pkl') 146 | encoder_file = os.path.join(save_folder, 'encoder.pt') 147 | 148 | log_file = os.path.join(save_folder, 'log.txt') 149 | log_csv_file = os.path.join(save_folder, 'log_csv.csv') 150 | log = open(log_file, 'w') 151 | log_csv = open(log_csv_file, 'w') 152 | csv_writer = csv.writer(log_csv, delimiter=',') 153 | 154 | pickle.dump({'args': args}, open(meta_file, "wb")) 155 | par_file = open(os.path.join(save_folder,'args.txt'),'w') 156 | print(args,file=par_file) 157 | par_file.flush 158 | par_file.close() 159 | 160 | else: 161 | print("WARNING: No save_folder provided!" 
+ 162 | "Testing (within this script) will throw an error.") 163 | 164 | 165 | if args.NRI: 166 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_NRI( 167 | args.batch_size, args.sim_folder, shuffle=True, 168 | data_folder=args.data_folder) 169 | else: 170 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_fNRI( 171 | args.batch_size, args.sim_folder, shuffle=True, 172 | data_folder=args.data_folder) 173 | 174 | 175 | # Generate off-diagonal interaction graph 176 | off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) 177 | rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) 178 | rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) 179 | rel_rec = torch.FloatTensor(rel_rec) 180 | rel_send = torch.FloatTensor(rel_send) 181 | 182 | if args.NRI: 183 | edge_types_list = [ edge_types ] 184 | else: 185 | edge_types_list = args.edge_types_list 186 | 187 | if args.encoder == 'mlp': 188 | if args.sigmoid: 189 | encoder = MLPEncoder_sigmoid(args.timesteps * args.dims, args.encoder_hidden, 190 | args.num_factors,args.encoder_dropout, 191 | split_point=args.split_point) 192 | else: 193 | encoder = MLPEncoder_multi(args.timesteps * args.dims, args.encoder_hidden, 194 | edge_types_list, args.encoder_dropout, 195 | split_point=args.split_point, 196 | init_type=args.encoder_init_type, 197 | bias_init=args.encoder_bias_scale) 198 | 199 | elif args.encoder == 'cnn': 200 | encoder = CNNEncoder_multi(args.dims, args.encoder_hidden, 201 | edge_types_list, 202 | args.encoder_dropout, 203 | split_point=args.split_point, 204 | init_type=args.encoder_init_type) 205 | 206 | 207 | 208 | if args.load_folder: 209 | print('Loading model from: '+args.load_folder) 210 | encoder_file = os.path.join(args.load_folder, 'encoder.pt') 211 | if not args.cuda: 212 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 213 | else: 214 | encoder.load_state_dict(torch.load(encoder_file)) 215 | args.save_folder = False 216 | 217 | optimizer = optim.Adam(list(encoder.parameters()), 218 | lr=args.lr, weight_decay=args.weight_decay) 219 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay, 220 | gamma=args.gamma) 221 | 222 | 223 | if args.cuda: 224 | encoder.cuda() 225 | rel_rec = rel_rec.cuda() 226 | rel_send = rel_send.cuda() 227 | 228 | rel_rec = Variable(rel_rec) 229 | rel_send = Variable(rel_send) 230 | 231 | 232 | def train(epoch, best_val_loss): 233 | t = time.time() 234 | 235 | kl_train = [] 236 | kl_list_train = [] 237 | kl_var_list_train = [] 238 | 239 | acc_train = [] 240 | acc_blocks_train = [] 241 | acc_var_train = [] 242 | acc_var_blocks_train = [] 243 | 244 | KLb_train = [] 245 | KLb_blocks_train = [] 246 | 247 | ce_train = [] 248 | 249 | encoder.train() 250 | scheduler.step() 251 | 252 | for batch_idx, (data, relations) in enumerate(train_loader): # relations are the ground truth interactions graphs 253 | if args.cuda: 254 | data, relations = data.cuda(), relations.cuda() 255 | data, relations = Variable(data), Variable(relations) 256 | 257 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 258 | 259 | optimizer.zero_grad() 260 | 261 | logits = encoder(data_encoder, rel_rec, rel_send) 262 | 263 | if args.NRI: 264 | prob = my_softmax(logits, -1) 265 | 266 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 267 | loss_kl_split = [ loss_kl ] 268 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, 
edge_types) ] 269 | kl_train.append(loss_kl.data.item()) 270 | kl_list_train.append([kl.data.item() for kl in loss_kl_split]) 271 | kl_var_list_train.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 272 | 273 | KLb_train.append( 0 ) 274 | KLb_blocks_train.append([0]) 275 | 276 | preds = np.array(logits.max(-1)[1].cpu()) 277 | targets = np.array(relations.cpu()) 278 | preds_list = decode_target( preds, args.edge_types_list ) 279 | target_list = decode_target( targets, args.edge_types_list ) 280 | 281 | acc = np.mean(np.equal(targets, preds,dtype=object)) 282 | acc_blocks = np.array([ np.mean(np.equal(target_list[i], preds_list[i], dtype=object)) 283 | for i in range(len(target_list)) ]) 284 | acc_var = np.var(np.mean(np.equal(targets, preds,dtype=object),axis=-1)) 285 | acc_var_blocks = np.array([ np.var(np.mean(np.equal(target_list[i], preds_list[i], dtype=object),axis=-1)) 286 | for i in range(len(target_list)) ]) 287 | 288 | logits = logits.view(-1, edge_types) 289 | relations = relations.view(-1) 290 | 291 | loss = F.cross_entropy(logits, relations) 292 | 293 | elif args.sigmoid: 294 | edges = 1/(1+torch.exp(-args.sigmoid_sharpness*logits)) 295 | 296 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 297 | preds = np.array(edges.cpu().detach()) 298 | preds = np.rint(preds).astype('int') 299 | 300 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==args.num_factors ) 301 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 302 | acc_var = np.var(np.mean( np.sum(np.equal(targets, preds,dtype=object), axis=-1)==args.num_factors, axis=1)) 303 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1), axis=0) 304 | 305 | edges = edges.view(-1) 306 | relations = relations.transpose(1,2).type(torch.FloatTensor).contiguous().view(-1) 307 | if args.cuda: 308 | relations = relations.cuda() 309 | loss = F.binary_cross_entropy( edges, relations ) 310 | 311 | kl_train.append(0) 312 | kl_list_train.append([0]) 313 | kl_var_list_train.append([0]) 314 | KLb_train.append( 0 ) 315 | KLb_blocks_train.append( [0] ) 316 | 317 | else: 318 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. 
of particles 319 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 320 | 321 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 322 | 323 | if args.prior: 324 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], args.num_atoms) 325 | for type_idx in range(len(args.edge_types_list)) ] 326 | loss_kl = sum(loss_kl_split) 327 | else: 328 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 329 | args.edge_types_list[type_idx]) 330 | for type_idx in range(len(args.edge_types_list)) ] 331 | loss_kl = sum(loss_kl_split) 332 | 333 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 334 | args.edge_types_list[type_idx]) 335 | for type_idx in range(len(args.edge_types_list)) ] 336 | 337 | kl_train.append(loss_kl.data.item()) 338 | kl_list_train.append([kl.data.item() for kl in loss_kl_split]) 339 | kl_var_list_train.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 340 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 341 | KLb_train.append(sum(KLb_blocks).data.item()) 342 | KLb_blocks_train.append([KL.data.item() for KL in KLb_blocks]) 343 | 344 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 345 | preds = torch.cat( [ torch.unsqueeze(pred.max(-1)[1],-1) for pred in logits_split], -1 ) 346 | preds = np.array(preds.cpu()) 347 | 348 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list) ) 349 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 350 | acc_var = np.var(np.mean(np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list), axis=-1)) 351 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1),axis=0) 352 | 353 | loss = 0 354 | for i in range(len(args.edge_types_list)): 355 | logits_i = logits_split[i].view(-1, args.edge_types_list[i]) 356 | relations_i = relations[:,i,:].contiguous().view(-1) 357 | loss += F.cross_entropy(logits_i, relations_i) 358 | 359 | 360 | loss.backward() 361 | optimizer.step() 362 | 363 | acc_train.append(acc) 364 | acc_blocks_train.append(acc_blocks) 365 | acc_var_train.append(acc_var) 366 | acc_var_blocks_train.append(acc_var_blocks) 367 | 368 | ce_train.append(loss.data.item()) 369 | 370 | 371 | 372 | kl_val = [] 373 | kl_list_val = [] 374 | kl_var_list_val = [] 375 | 376 | acc_val = [] 377 | acc_blocks_val = [] 378 | acc_var_val = [] 379 | acc_var_blocks_val = [] 380 | 381 | KLb_val = [] 382 | KLb_blocks_val = [] # KL between blocks list 383 | 384 | ce_val = [] 385 | 386 | encoder.eval() 387 | for batch_idx, (data, relations) in enumerate(valid_loader): 388 | with torch.no_grad(): 389 | if args.cuda: 390 | data, relations = data.cuda(), relations.cuda() 391 | 392 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 393 | 394 | logits = encoder(data_encoder, rel_rec, rel_send) 395 | 396 | if args.NRI: 397 | prob = my_softmax(logits, -1) 398 | 399 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 400 | loss_kl_split = [ loss_kl ] 401 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, edge_types) ] 402 | kl_val.append(loss_kl.data.item()) 403 | kl_list_val.append([kl.data.item() for kl in loss_kl_split]) 404 | kl_var_list_val.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 405 | 406 | KLb_val.append( 0 ) 407 | KLb_blocks_val.append([0]) 408 | 409 | preds = np.array(logits.max(-1)[1].cpu()) 410 | targets = np.array(relations.cpu()) 411 | 
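# (Added explanatory note, not in the original source.) For the NRI baseline the encoder
# predicts one combined edge label per edge, so per-factor accuracies are recovered by
# decoding that label back into one label per factor graph. decode_target does this by
# repeated integer division: e.g. with edge_types_list = [2, 2] a combined label of 3
# decodes to [1, 1] and a label of 2 decodes to [1, 0].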
preds_list = decode_target( preds, args.edge_types_list ) 412 | target_list = decode_target( targets, args.edge_types_list ) 413 | 414 | acc = np.mean(np.equal(targets, preds,dtype=object)) 415 | acc_blocks = np.array([ np.mean(np.equal(target_list[i], preds_list[i], dtype=object)) 416 | for i in range(len(target_list)) ]) 417 | acc_var = np.var(np.mean(np.equal(targets, preds,dtype=object),axis=-1)) 418 | acc_var_blocks = np.array([ np.var(np.mean(np.equal(target_list[i], preds_list[i], dtype=object),axis=-1)) 419 | for i in range(len(target_list)) ]) 420 | 421 | logits = logits.view(-1, edge_types) 422 | relations = relations.view(-1) 423 | 424 | loss = F.cross_entropy(logits, relations) 425 | 426 | elif args.sigmoid: 427 | edges = 1/(1+torch.exp(-args.sigmoid_sharpness*logits)) 428 | 429 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 430 | preds = np.array(edges.cpu().detach()) 431 | preds = np.rint(preds).astype('int') 432 | 433 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==args.num_factors ) 434 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 435 | acc_var = np.var(np.mean( np.sum(np.equal(targets, preds,dtype=object), axis=-1)==args.num_factors, axis=1)) 436 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1), axis=0) 437 | 438 | edges = edges.view(-1) 439 | relations = relations.transpose(1,2).type(torch.FloatTensor).contiguous().view(-1) 440 | if args.cuda: 441 | relations = relations.cuda() 442 | loss = F.binary_cross_entropy( edges, relations ) 443 | 444 | kl_val.append(0) 445 | kl_list_val.append([0]) 446 | kl_var_list_val.append([0]) 447 | KLb_val.append( 0 ) 448 | KLb_blocks_val.append( [0] ) 449 | 450 | else: 451 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 452 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 453 | 454 | if args.prior: 455 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], args.num_atoms) 456 | for type_idx in range(len(args.edge_types_list)) ] 457 | loss_kl = sum(loss_kl_split) 458 | else: 459 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 460 | args.edge_types_list[type_idx]) 461 | for type_idx in range(len(args.edge_types_list)) ] 462 | loss_kl = sum(loss_kl_split) 463 | 464 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 465 | args.edge_types_list[type_idx]) 466 | for type_idx in range(len(args.edge_types_list)) ] 467 | 468 | kl_val.append(loss_kl.data.item()) 469 | kl_list_val.append([kl.data.item() for kl in loss_kl_split]) 470 | kl_var_list_val.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 471 | 472 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 473 | preds = torch.cat( [ torch.unsqueeze(pred.max(-1)[1],-1) for pred in logits_split], -1 ) 474 | preds = np.array(preds.cpu()) 475 | 476 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list) ) 477 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 478 | acc_var = np.var(np.mean(np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list), axis=-1)) 479 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1),axis=0) 480 | 481 | loss = 0 482 | for i in range(len(args.edge_types_list)): 483 | logits_i = logits_split[i].view(-1, args.edge_types_list[i]) 484 | relations_i = relations[:,i,:].contiguous().view(-1) 485 | loss += 
F.cross_entropy(logits_i, relations_i) 486 | 487 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 488 | KLb_val.append(sum(KLb_blocks).data.item()) 489 | KLb_blocks_val.append([KL.data.item() for KL in KLb_blocks]) 490 | 491 | 492 | acc_val.append(acc) 493 | acc_blocks_val.append(acc_blocks) 494 | acc_var_val.append(acc_var) 495 | acc_var_blocks_val.append(acc_var_blocks) 496 | 497 | ce_val.append(loss.data.item()) 498 | 499 | 500 | print('Epoch: {:03d}'.format(epoch), 501 | 'time: {:.1f}s'.format(time.time() - t)) 502 | print('ce_trn: {:.5f}'.format(np.mean(ce_train)), 503 | 'kl_trn: {:.5f}'.format(np.mean(kl_train)), 504 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 505 | 'KLb_trn: {:.5f}'.format(np.mean(KLb_train)), 506 | 'acc_b_trn: '+str( np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ), 507 | 'kl_trn: '+str( np.around(np.mean(np.array(kl_list_train),axis=0),4 ) ) 508 | ) 509 | print('ce_val: {:.5f}'.format(np.mean(ce_val)), 510 | 'kl_val: {:.5f}'.format(np.mean(kl_val)), 511 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 512 | 'KLb_val: {:.5f}'.format(np.mean(KLb_val)), 513 | 'acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ), 514 | 'kl_val: '+str( np.around(np.mean(np.array(kl_list_val),axis=0),4 ) ), 515 | ) 516 | print('Epoch: {:04d}'.format(epoch), 517 | 'time: {:.4f}s'.format(time.time() - t), 518 | file=log) 519 | print('ce_trn: {:.5f}'.format(np.mean(ce_train)), 520 | 'kl_trn: {:.5f}'.format(np.mean(kl_train)), 521 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 522 | 'KLb_trn: {:.5f}'.format(np.mean(KLb_train)), 523 | 'acc_b_trn: '+str( np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ), 524 | 'kl_trn: '+str( np.around(np.mean(np.array(kl_list_train),axis=0),4 ) ), 525 | file=log ) 526 | print('ce_val: {:.5f}'.format(np.mean(ce_val)), 527 | 'kl_val: {:.5f}'.format(np.mean(kl_val)), 528 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 529 | 'KLb_val: {:.5f}'.format(np.mean(KLb_val)), 530 | 'acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ), 531 | 'kl_val: '+str( np.around(np.mean(np.array(kl_list_val),axis=0),4 ) ), 532 | file=log) 533 | if epoch == 0: 534 | labels = [ 'epoch', 'ce trn', 'kl trn', 'KLb trn', 'acc trn' ] 535 | labels += [ 'b'+str(i)+ ' acc trn' for i in range( len(args.edge_types_list) ) ] 536 | labels += [ 'b'+str(i)+ ' kl trn' for i in range( len(kl_list_train[0]) ) ] 537 | labels += [ 'b'+str(i)+' kl var trn' for i in range( len(kl_list_train[0]) ) ] 538 | labels += [ 'acc var trn'] + [ 'b'+str(i)+' acc var trn' for i in range( len(args.edge_types_list) ) ] 539 | labels += [ 'ce val', 'kl val', 'KLb val', 'acc val' ] 540 | labels += [ 'b'+str(i)+ ' acc val' for i in range( len(args.edge_types_list) ) ] 541 | labels += [ 'b'+str(i)+ ' kl val' for i in range( len(kl_list_val[0]) ) ] 542 | labels += [ 'b'+str(i)+' kl var val' for i in range( len(kl_list_val[0]) ) ] 543 | labels += [ 'acc var val'] + [ 'b'+str(i)+' acc var val' for i in range( len(args.edge_types_list) ) ] 544 | csv_writer.writerow( labels ) 545 | 546 | 547 | csv_writer.writerow( [epoch, np.mean(ce_train), np.mean(kl_train), np.mean(KLb_train), np.mean(acc_train)] + 548 | list(np.mean(np.array(acc_blocks_train),axis=0)) + 549 | list(np.mean(np.array(kl_list_train),axis=0)) + 550 | list(np.mean(np.array(kl_var_list_train),axis=0)) + 551 | [np.mean(acc_var_train)] + list(np.mean(np.array(acc_var_blocks_train),axis=0)) + 552 | [ np.mean(ce_val), np.mean(kl_val), np.mean(KLb_val), np.mean(acc_val) ] + 553 | 
list(np.mean(np.array(acc_blocks_val ),axis=0)) + 554 | list(np.mean(np.array(kl_list_val),axis=0)) + 555 | list(np.mean(np.array(kl_var_list_val),axis=0)) + 556 | [np.mean(acc_var_val)] + list(np.mean(np.array(acc_var_blocks_val),axis=0)) 557 | ) 558 | 559 | log.flush() 560 | if args.save_folder and np.mean(acc_val) > best_val_loss: 561 | torch.save(encoder.state_dict(), encoder_file) 562 | print('Best model so far, saving...') 563 | return np.mean(acc_val) 564 | 565 | 566 | def test(): 567 | t = time.time() 568 | 569 | ce_test = [] 570 | 571 | kl_test = [] 572 | kl_list_test = [] 573 | kl_var_list_test = [] 574 | 575 | acc_test = [] 576 | acc_blocks_test = [] 577 | acc_var_test = [] 578 | acc_var_blocks_test = [] 579 | 580 | KLb_test = [] 581 | KLb_blocks_test = [] # KL between blocks list 582 | 583 | encoder.eval() 584 | if not args.cuda: 585 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 586 | else: 587 | encoder.load_state_dict(torch.load(encoder_file)) 588 | 589 | for batch_idx, (data, relations) in enumerate(test_loader): 590 | with torch.no_grad(): 591 | if args.cuda: 592 | data, relations = data.cuda(), relations.cuda() 593 | 594 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 595 | 596 | logits = encoder(data_encoder, rel_rec, rel_send) 597 | 598 | if args.NRI: 599 | prob = my_softmax(logits, -1) 600 | 601 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 602 | loss_kl_split = [ loss_kl ] 603 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, edge_types) ] 604 | kl_test.append(loss_kl.data.item()) 605 | kl_list_test.append([kl.data.item() for kl in loss_kl_split]) 606 | kl_var_list_test.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 607 | 608 | KLb_test.append( 0 ) 609 | KLb_blocks_test.append([0]) 610 | 611 | preds = np.array(logits.max(-1)[1].cpu()) 612 | targets = np.array(relations.cpu()) 613 | preds_list = decode_target( preds, args.edge_types_list ) 614 | target_list = decode_target( targets, args.edge_types_list ) 615 | 616 | acc = np.mean(np.equal(targets, preds,dtype=object)) 617 | acc_blocks = np.array([ np.mean(np.equal(target_list[i], preds_list[i], dtype=object)) 618 | for i in range(len(target_list)) ]) 619 | acc_var = np.var(np.mean(np.equal(targets, preds,dtype=object),axis=-1)) 620 | acc_var_blocks = np.array([ np.var(np.mean(np.equal(target_list[i], preds_list[i], dtype=object),axis=-1)) 621 | for i in range(len(target_list)) ]) 622 | 623 | logits = logits.view(-1, edge_types) 624 | relations = relations.view(-1) 625 | 626 | loss = F.cross_entropy(logits, relations) 627 | 628 | elif args.sigmoid: 629 | edges = 1/(1+torch.exp(-args.sigmoid_sharpness*logits)) 630 | 631 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 632 | preds = np.array(edges.cpu().detach()) 633 | preds = np.rint(preds).astype('int') 634 | 635 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==args.num_factors ) 636 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 637 | acc_var = np.var(np.mean( np.sum(np.equal(targets, preds,dtype=object), axis=-1)==args.num_factors, axis=1)) 638 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1), axis=0) 639 | 640 | edges = edges.view(-1) 641 | relations = relations.transpose(1,2).type(torch.FloatTensor).contiguous().view(-1) 642 | if args.cuda: 643 | relations = relations.cuda() 644 | loss = F.binary_cross_entropy( edges, relations ) 645 | 646 | kl_test.append(0) 647 | 
kl_list_test.append([0]) 648 | kl_var_list_test.append([0]) 649 | KLb_test.append( 0 ) 650 | KLb_blocks_test.append( [0] ) 651 | 652 | else: 653 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 654 | 655 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 656 | 657 | if args.prior: 658 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], args.num_atoms) 659 | for type_idx in range(len(args.edge_types_list)) ] 660 | loss_kl = sum(loss_kl_split) 661 | else: 662 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 663 | args.edge_types_list[type_idx]) 664 | for type_idx in range(len(args.edge_types_list)) ] 665 | loss_kl = sum(loss_kl_split) 666 | 667 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 668 | args.edge_types_list[type_idx]) 669 | for type_idx in range(len(args.edge_types_list)) ] 670 | 671 | kl_test.append(loss_kl.data.item()) 672 | kl_list_test.append([kl.data.item() for kl in loss_kl_split]) 673 | kl_var_list_test.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 674 | 675 | targets = np.swapaxes(np.array(relations.cpu()),1,2) 676 | preds = torch.cat( [ torch.unsqueeze(pred.max(-1)[1],-1) for pred in logits_split], -1 ) 677 | preds = np.array(preds.cpu()) 678 | 679 | acc = np.mean( np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list) ) 680 | acc_blocks = np.mean(np.equal(targets, preds, dtype=object),axis=(0,1)) 681 | acc_var = np.var(np.mean(np.sum(np.equal(targets, preds,dtype=object),axis=-1)==len(args.edge_types_list), axis=-1)) 682 | acc_var_blocks = np.var(np.mean(np.equal(targets, preds, dtype=object), axis=1),axis=0) 683 | 684 | loss = 0 685 | for i in range(len(args.edge_types_list)): 686 | logits_i = logits_split[i].view(-1, args.edge_types_list[i]) 687 | relations_i = relations[:,i,:].contiguous().view(-1) 688 | loss += F.cross_entropy(logits_i, relations_i) 689 | 690 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 691 | KLb_test.append(sum(KLb_blocks).data.item()) 692 | KLb_blocks_test.append([KL.data.item() for KL in KLb_blocks]) 693 | 694 | ce_test.append(loss.data.item()) 695 | acc_test.append(acc) 696 | acc_blocks_test.append(acc_blocks) 697 | acc_var_test.append(acc_var) 698 | acc_var_blocks_test.append(acc_var_blocks) 699 | 700 | 701 | print('--------------------------------') 702 | print('------------Testing-------------') 703 | print('--------------------------------') 704 | print('ce_test: {:.2f}'.format(np.mean(ce_test)), 705 | 'kl_test: {:.5f}'.format(np.mean(kl_test)), 706 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 707 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 708 | 'KLb_test: {:.5f}'.format(np.mean(KLb_test)), 709 | 'time: {:.1f}s'.format(time.time() - t)) 710 | print('acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 711 | 'acc_var_b_test: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ), 712 | 'kl_test: '+str( np.around(np.mean(np.array(kl_list_test),axis=0),4 ) )) 713 | if args.save_folder: 714 | print('--------------------------------', file=log) 715 | print('------------Testing-------------', file=log) 716 | print('--------------------------------', file=log) 717 | print('ce_test: {:.2f}'.format(np.mean(ce_test)), 718 | 'kl_test: {:.5f}'.format(np.mean(kl_test)), 719 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 720 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 721 | 'KLb_test: 
{:.5f}'.format(np.mean(KLb_test)), 722 | 'time: {:.1f}s'.format(time.time() - t), 723 | file=log) 724 | print('acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 725 | 'acc_var_b_test: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ), 726 | 'kl_test: '+str( np.around(np.mean(np.array(kl_list_test),axis=0),4 ) ), 727 | file=log) 728 | log.flush() 729 | 730 | 731 | # Train model 732 | if not args.test: 733 | t_total = time.time() 734 | best_val_loss = 0 735 | best_epoch = 0 736 | for epoch in range(args.epochs): 737 | val_loss = train(epoch, best_val_loss) 738 | if val_loss > best_val_loss: 739 | best_val_loss = val_loss 740 | best_epoch = epoch 741 | if epoch - best_epoch > args.patience and epoch > 99: 742 | break 743 | print("Optimization Finished!") 744 | print("Best Epoch: {:04d}".format(best_epoch)) 745 | if args.save_folder: 746 | print("Best Epoch: {:04d}".format(best_epoch), file=log) 747 | log.flush() 748 | 749 | print('Reloading best model') 750 | test() 751 | if log is not None: 752 | print(save_folder) 753 | log.close() 754 | log_csv.close() 755 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import math 10 | 11 | from torch.autograd import Variable 12 | from utils import my_softmax, get_offdiag_indices, gumbel_softmax 13 | 14 | _EPS = 1e-10 15 | 16 | 17 | class MLP(nn.Module): 18 | """Two-layer fully-connected ELU net with batch norm.""" 19 | 20 | def __init__(self, n_in, n_hid, n_out, do_prob=0.): 21 | super(MLP, self).__init__() 22 | self.fc1 = nn.Linear(n_in, n_hid) 23 | self.fc2 = nn.Linear(n_hid, n_out) 24 | self.bn = nn.BatchNorm1d(n_out) 25 | self.dropout_prob = do_prob 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | for m in self.modules(): 31 | if isinstance(m, nn.Linear): 32 | nn.init.xavier_normal_(m.weight.data) 33 | m.bias.data.fill_(0.1) 34 | elif isinstance(m, nn.BatchNorm1d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | 38 | def batch_norm(self, inputs): 39 | x = inputs.view(inputs.size(0) * inputs.size(1), -1) 40 | x = self.bn(x) 41 | return x.view(inputs.size(0), inputs.size(1), -1) 42 | 43 | def forward(self, inputs): 44 | # Input shape: [num_sims, num_things, num_features] 45 | x = F.elu(self.fc1(inputs)) 46 | x = F.dropout(x, self.dropout_prob, training=self.training) 47 | x = F.elu(self.fc2(x)) 48 | return self.batch_norm(x) 49 | 50 | 51 | class CNN(nn.Module): 52 | def __init__(self, n_in, n_hid, n_out, do_prob=0.): 53 | super(CNN, self).__init__() 54 | self.pool = nn.MaxPool1d(kernel_size=2, stride=None, padding=0, 55 | dilation=1, return_indices=False, 56 | ceil_mode=False) 57 | 58 | self.conv1 = nn.Conv1d(n_in, n_hid, kernel_size=5, stride=1, padding=0) 59 | self.bn1 = nn.BatchNorm1d(n_hid) 60 | self.conv2 = nn.Conv1d(n_hid, n_hid, kernel_size=5, stride=1, padding=0) 61 | self.bn2 = nn.BatchNorm1d(n_hid) 62 | self.conv_predict = nn.Conv1d(n_hid, n_out, kernel_size=1) 63 | self.conv_attention = nn.Conv1d(n_hid, 1, kernel_size=1) 64 | self.dropout_prob = do_prob 65 | 66 | self.init_weights() 67 | 68 | def init_weights(self): 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv1d): 71 | n = m.kernel_size[0] * m.out_channels 72 | m.weight.data.normal_(0, 
math.sqrt(2. / n)) 73 | m.bias.data.fill_(0.1) 74 | elif isinstance(m, nn.BatchNorm1d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | 78 | def forward(self, inputs): 79 | # Input shape: [num_sims * num_edges, num_dims, num_timesteps] 80 | 81 | x = F.relu(self.conv1(inputs)) 82 | x = self.bn1(x) 83 | x = F.dropout(x, self.dropout_prob, training=self.training) 84 | x = self.pool(x) 85 | x = F.relu(self.conv2(x)) 86 | x = self.bn2(x) 87 | pred = self.conv_predict(x) 88 | attention = my_softmax(self.conv_attention(x), axis=2) 89 | 90 | edge_prob = (pred * attention).mean(dim=2) 91 | return edge_prob 92 | 93 | 94 | class MLPEncoder(nn.Module): 95 | def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): 96 | super(MLPEncoder, self).__init__() 97 | 98 | self.factor = factor 99 | # n_hid = num edge types 100 | self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) 101 | self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob) 102 | self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) 103 | if self.factor: 104 | self.mlp4 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 105 | print("Using factor graph MLP encoder.") 106 | else: 107 | self.mlp4 = MLP(n_hid * 2, n_hid, n_hid, do_prob) 108 | print("Using MLP encoder.") 109 | self.fc_out = nn.Linear(n_hid, n_out) 110 | self.init_weights() 111 | 112 | def init_weights(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Linear): 115 | nn.init.xavier_normal_(m.weight.data) 116 | m.bias.data.fill_(0.1) 117 | 118 | def edge2node(self, x, rel_rec, rel_send): 119 | # NOTE: Assumes that we have the same graph across all samples. 120 | incoming = torch.matmul(rel_rec.t(), x) 121 | return incoming / incoming.size(1) 122 | 123 | def node2edge(self, x, rel_rec, rel_send): 124 | # NOTE: Assumes that we have the same graph across all samples. 
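# rel_rec and rel_send are one-hot matrices of shape [num_atoms*(num_atoms-1), num_atoms],
# built in train.py from the off-diagonal adjacency, so the matmuls below gather the
# receiver and sender node embeddings for every directed edge and concatenate them into
# a single per-edge feature vector.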
125 | receivers = torch.matmul(rel_rec, x) 126 | senders = torch.matmul(rel_send, x) 127 | edges = torch.cat([receivers, senders], dim=2) 128 | return edges 129 | 130 | def forward(self, inputs, rel_rec, rel_send): 131 | # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] 132 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 133 | # New shape: [num_sims, num_atoms, num_timesteps*num_dims] 134 | 135 | x = self.mlp1(x) # 2-layer ELU net per node 136 | 137 | x = self.node2edge(x, rel_rec, rel_send) 138 | x = self.mlp2(x) 139 | x_skip = x 140 | 141 | if self.factor: 142 | x = self.edge2node(x, rel_rec, rel_send) 143 | x = self.mlp3(x) 144 | x = self.node2edge(x, rel_rec, rel_send) 145 | x = torch.cat((x, x_skip), dim=2) # Skip connection 146 | x = self.mlp4(x) 147 | else: 148 | x = self.mlp3(x) 149 | x = torch.cat((x, x_skip), dim=2) # Skip connection 150 | x = self.mlp4(x) 151 | 152 | return self.fc_out(x) 153 | 154 | class MLPEncoder_multi(nn.Module): 155 | def __init__(self, n_in, n_hid, edge_types_list, do_prob=0., split_point=1, 156 | init_type='xavier_normal', bias_init=0.0): 157 | super(MLPEncoder_multi, self).__init__() 158 | 159 | self.edge_types_list = edge_types_list 160 | self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) 161 | #print(self.mlp1.fc1.weight[0][0:5]) 162 | self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob) 163 | 164 | self.init_type = init_type 165 | if self.init_type not in [ 'xavier_normal', 'orthogonal', 'sparse' ]: 166 | raise ValueError('This initialization type has not been coded') 167 | #print('Using '+self.init_type+' for encoder weight initialization') 168 | self.bias_init = bias_init 169 | 170 | self.split_point = split_point 171 | if split_point == 0: 172 | self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) 173 | self.mlp4 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 174 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, sum(edge_types_list))]) 175 | elif split_point == 1: 176 | self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) 177 | self.mlp4 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in edge_types_list]) 178 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, K) for K in edge_types_list]) 179 | elif split_point == 2: 180 | self.mlp3 = nn.ModuleList([MLP(n_hid, n_hid, n_hid, do_prob) for _ in edge_types_list]) 181 | self.mlp4 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in edge_types_list]) 182 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, K) for K in edge_types_list]) 183 | else: 184 | raise ValueError('Split point is not valid, must be 0, 1, or 2') 185 | 186 | self.init_weights() 187 | 188 | 189 | def init_weights(self): 190 | for m in self.modules(): 191 | if isinstance(m, nn.Linear): 192 | if self.init_type == 'orthogonal': 193 | nn.init.orthogonal_(m.weight.data) 194 | elif self.init_type == 'xavier_normal': 195 | nn.init.xavier_normal_(m.weight.data) 196 | elif self.init_type == 'sparse': 197 | nn.init.sparse_(m.weight.data, sparsity=0.1) 198 | 199 | if not math.isclose(self.bias_init, 0, rel_tol=1e-9): 200 | m.bias.data.fill_(self.bias_init) 201 | 202 | def edge2node(self, x, rel_rec, rel_send): 203 | # NOTE: Assumes that we have the same graph across all samples. 204 | incoming = torch.matmul(rel_rec.t(), x) 205 | return incoming / incoming.size(1) 206 | 207 | def node2edge(self, x, rel_rec, rel_send): 208 | # NOTE: Assumes that we have the same graph across all samples. 
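# For reference, with the defaults in train.py (num_atoms=5, timesteps=49, dims=4,
# edge_types_list=[2, 2]) this encoder maps inputs of shape [batch, 5, 49, 4] to edge
# logits of shape [batch, 20, 4]; train.py then recovers the per-layer blocks with
#   logits = encoder(data_encoder, rel_rec, rel_send)     # [batch, 20, 4]
#   logits_split = torch.split(logits, [2, 2], dim=-1)    # one block per latent graph layer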
209 | receivers = torch.matmul(rel_rec, x) 210 | senders = torch.matmul(rel_send, x) 211 | edges = torch.cat([receivers, senders], dim=2) 212 | return edges 213 | 214 | def forward(self, inputs, rel_rec, rel_send): 215 | # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] 216 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 217 | # New shape: [num_sims, num_atoms, num_timesteps*num_dims] 218 | 219 | x = self.mlp1(x) # 2-layer ELU net per node 220 | 221 | x = self.node2edge(x, rel_rec, rel_send) 222 | x = self.mlp2(x) 223 | x_skip = x 224 | 225 | x = self.edge2node(x, rel_rec, rel_send) 226 | if self.split_point == 0: 227 | x = self.mlp3(x) 228 | x = self.node2edge(x, rel_rec, rel_send) 229 | x = torch.cat((x, x_skip), dim=2) # Skip connection 230 | x = self.mlp4(x) 231 | return self.fc_out[0](x) 232 | elif self.split_point == 1: 233 | x = self.mlp3(x) 234 | x = self.node2edge(x, rel_rec, rel_send) 235 | x = torch.cat((x, x_skip), dim=2) # Skip connection 236 | y_list = [] 237 | for i in range(len(self.edge_types_list)): 238 | y = self.mlp4[i](x) 239 | y_list.append( self.fc_out[i](y) ) 240 | return torch.cat(y_list,dim=-1) 241 | elif self.split_point == 2: 242 | y_list = [] 243 | for i in range(len(self.edge_types_list)): 244 | y = self.mlp3[i](x) 245 | y = self.node2edge(y, rel_rec, rel_send) 246 | y = torch.cat((y, x_skip), dim=2) # Skip connection 247 | y = self.mlp4[i](y) 248 | y_list.append( self.fc_out[i](y) ) 249 | return torch.cat(y_list,dim=-1) 250 | 251 | class MLPEncoder_sigmoid(nn.Module): 252 | def __init__(self, n_in, n_hid, num_factors, do_prob=0., split_point=1): 253 | super(MLPEncoder_sigmoid, self).__init__() 254 | 255 | self.num_factors = num_factors 256 | self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) 257 | self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob) 258 | self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) 259 | 260 | self.split_point = split_point 261 | if split_point == 0: 262 | self.mlp4 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 263 | self.fc_out = nn.Linear(n_hid, num_factors) 264 | elif split_point == 1: 265 | self.mlp4 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in range(num_factors)]) 266 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, 1) for i in range(num_factors)]) 267 | elif split_point == 2: 268 | self.mlp3 = nn.ModuleList([MLP(n_hid, n_hid, n_hid, do_prob) for _ in range(num_factors)]) 269 | self.mlp4 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in range(num_factors)]) 270 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, 1) for i in range(num_factors)]) 271 | else: 272 | raise ValueError('Split point is not valid, must be 0, 1, or 2') 273 | 274 | self.init_weights() 275 | 276 | def init_weights(self): 277 | for m in self.modules(): 278 | if isinstance(m, nn.Linear): 279 | nn.init.xavier_normal_(m.weight.data) 280 | m.bias.data.fill_(0.1) 281 | 282 | def edge2node(self, x, rel_rec, rel_send): 283 | # NOTE: Assumes that we have the same graph across all samples. 284 | incoming = torch.matmul(rel_rec.t(), x) 285 | return incoming / incoming.size(1) 286 | 287 | def node2edge(self, x, rel_rec, rel_send): 288 | # NOTE: Assumes that we have the same graph across all samples. 
289 | receivers = torch.matmul(rel_rec, x) 290 | senders = torch.matmul(rel_send, x) 291 | edges = torch.cat([receivers, senders], dim=2) 292 | return edges 293 | 294 | def forward(self, inputs, rel_rec, rel_send): 295 | # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] 296 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 297 | # New shape: [num_sims, num_atoms, num_timesteps*num_dims] 298 | 299 | x = self.mlp1(x) # 2-layer ELU net per node 300 | 301 | x = self.node2edge(x, rel_rec, rel_send) 302 | x = self.mlp2(x) 303 | x_skip = x 304 | 305 | x = self.edge2node(x, rel_rec, rel_send) 306 | if self.split_point == 0: 307 | x = self.mlp3(x) 308 | x = self.node2edge(x, rel_rec, rel_send) 309 | x = torch.cat((x, x_skip), dim=2) # Skip connection 310 | x = self.mlp4(x) 311 | return self.fc_out(x) 312 | elif self.split_point == 1: 313 | x = self.mlp3(x) 314 | x = self.node2edge(x, rel_rec, rel_send) 315 | x = torch.cat((x, x_skip), dim=2) # Skip connection 316 | y_list = [] 317 | for i in range(self.num_factors): 318 | y = self.mlp4[i](x) 319 | y_list.append( self.fc_out[i](y) ) 320 | return torch.cat(y_list,dim=-1) 321 | elif self.split_point == 2: 322 | y_list = [] 323 | for i in range(self.num_factors): 324 | y = self.mlp3[i](x) 325 | y = self.node2edge(y, rel_rec, rel_send) 326 | y = torch.cat((y, x_skip), dim=2) # Skip connection 327 | y = self.mlp4[i](y) 328 | y_list.append( self.fc_out[i](y) ) 329 | return torch.cat(y_list,dim=-1) 330 | 331 | class RandomEncoder(nn.Module): 332 | """MLP decoder module.""" 333 | 334 | def __init__(self, edge_types_list, cuda_on): 335 | super(RandomEncoder, self).__init__() 336 | 337 | self.edge_types_list = edge_types_list 338 | self.cuda_on = cuda_on 339 | print('Using a random encoder.') 340 | 341 | def forward(self, inputs, rel_rec, rel_send): 342 | n = inputs.shape[1] 343 | output = Variable(torch.randn(inputs.shape[0],n**2-n,sum(self.edge_types_list))) 344 | if self.cuda_on: 345 | output = output.cuda() 346 | return output 347 | 348 | class OnesEncoder(nn.Module): 349 | """MLP decoder module.""" 350 | 351 | def __init__(self, edge_types_list, cuda_on): 352 | super(OnesEncoder, self).__init__() 353 | 354 | self.edge_types_list = edge_types_list 355 | self.cuda_on = cuda_on 356 | print('Using a "ones" encoder.') 357 | 358 | def forward(self, inputs, rel_rec, rel_send): 359 | n = inputs.shape[1] 360 | output = Variable(torch.ones(inputs.shape[0],n**2-n,sum(self.edge_types_list))) 361 | if self.cuda_on: 362 | output = output.cuda() 363 | return output 364 | 365 | 366 | class CNNEncoder(nn.Module): 367 | def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): 368 | super(CNNEncoder, self).__init__() 369 | self.dropout_prob = do_prob 370 | 371 | self.factor = factor 372 | 373 | self.cnn = CNN(n_in * 2, n_hid, n_hid, do_prob) 374 | self.mlp1 = MLP(n_hid, n_hid, n_hid, do_prob) 375 | self.mlp2 = MLP(n_hid, n_hid, n_hid, do_prob) 376 | self.mlp3 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 377 | self.fc_out = nn.Linear(n_hid, n_out) 378 | 379 | if self.factor: 380 | print("Using factor graph CNN encoder.") 381 | else: 382 | print("Using CNN encoder.") 383 | 384 | self.init_weights() 385 | 386 | def init_weights(self): 387 | for m in self.modules(): 388 | if isinstance(m, nn.Linear): 389 | nn.init.xavier_normal(m.weight.data) 390 | m.bias.data.fill_(0.1) 391 | 392 | def node2edge_temporal(self, inputs, rel_rec, rel_send): 393 | # NOTE: Assumes that we have the same graph across all samples. 
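# Unlike node2edge below, this keeps the time axis: each directed edge gets a
# [2*num_dims, num_timesteps] sequence (receiver and sender channels stacked), which the
# 1D CNN convolves over time and attention-pools into a single per-edge embedding.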
394 | 395 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 396 | 397 | receivers = torch.matmul(rel_rec, x) 398 | receivers = receivers.view(inputs.size(0) * receivers.size(1), 399 | inputs.size(2), inputs.size(3)) 400 | receivers = receivers.transpose(2, 1) 401 | 402 | senders = torch.matmul(rel_send, x) 403 | senders = senders.view(inputs.size(0) * senders.size(1), 404 | inputs.size(2), 405 | inputs.size(3)) 406 | senders = senders.transpose(2, 1) 407 | 408 | # receivers and senders have shape: 409 | # [num_sims * num_edges, num_dims, num_timesteps] 410 | edges = torch.cat([receivers, senders], dim=1) 411 | return edges 412 | 413 | def edge2node(self, x, rel_rec, rel_send): 414 | # NOTE: Assumes that we have the same graph across all samples. 415 | incoming = torch.matmul(rel_rec.t(), x) 416 | return incoming / incoming.size(1) 417 | 418 | def node2edge(self, x, rel_rec, rel_send): 419 | # NOTE: Assumes that we have the same graph across all samples. 420 | receivers = torch.matmul(rel_rec, x) 421 | senders = torch.matmul(rel_send, x) 422 | edges = torch.cat([receivers, senders], dim=2) 423 | return edges 424 | 425 | def forward(self, inputs, rel_rec, rel_send): 426 | 427 | # Input has shape: [num_sims, num_atoms, num_timesteps, num_dims] 428 | edges = self.node2edge_temporal(inputs, rel_rec, rel_send) 429 | x = self.cnn(edges) 430 | x = x.view(inputs.size(0), (inputs.size(1) - 1) * inputs.size(1), -1) 431 | x = self.mlp1(x) 432 | x_skip = x 433 | 434 | if self.factor: 435 | x = self.edge2node(x, rel_rec, rel_send) 436 | x = self.mlp2(x) 437 | 438 | x = self.node2edge(x, rel_rec, rel_send) 439 | x = torch.cat((x, x_skip), dim=2) # Skip connection 440 | x = self.mlp3(x) 441 | 442 | return self.fc_out(x) 443 | 444 | class CNNEncoder_multi(nn.Module): 445 | def __init__(self, n_in, n_hid, edge_types_list, do_prob=0., split_point=0, init_type='xavier_normal'): 446 | super(CNNEncoder_multi, self).__init__() 447 | self.dropout_prob = do_prob 448 | 449 | self.edge_types_list = edge_types_list 450 | 451 | self.init_type = init_type 452 | if self.init_type not in [ 'xavier_normal', 'orthogonal' ]: 453 | raise ValueError('This initialization type has not been coded') 454 | print('Using '+self.init_type+' for encoder weight initialization') 455 | 456 | self.cnn = CNN(n_in * 2, n_hid, n_hid, do_prob) 457 | self.mlp1 = MLP(n_hid, n_hid, n_hid, do_prob) 458 | 459 | self.split_point = split_point 460 | if split_point == 0: 461 | self.mlp2 = MLP(n_hid, n_hid, n_hid, do_prob) 462 | self.mlp3 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 463 | self.fc_out = nn.Linear(n_hid, sum(edge_types_list)) 464 | elif split_point == 1: 465 | self.mlp2 = MLP(n_hid, n_hid, n_hid, do_prob) 466 | self.mlp3 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in edge_types_list]) 467 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, K) for K in edge_types_list]) 468 | elif split_point == 2: 469 | self.mlp2 = nn.ModuleList([MLP(n_hid, n_hid, n_hid, do_prob) for _ in edge_types_list]) 470 | self.mlp3 = nn.ModuleList([MLP(n_hid * 3, n_hid, n_hid, do_prob) for _ in edge_types_list]) 471 | self.fc_out = nn.ModuleList([nn.Linear(n_hid, K) for K in edge_types_list]) 472 | else: 473 | raise ValueError('Split point is not valid, must be 0, 1, or 2') 474 | 475 | self.init_weights() 476 | 477 | def init_weights(self): 478 | for m in self.modules(): 479 | if isinstance(m, nn.Linear): 480 | if self.init_type == 'orthogonal': 481 | nn.init.orthogonal_(m.weight.data) 482 | elif self.init_type == 'xavier_normal': 483 | 
nn.init.xavier_normal_(m.weight.data) 484 | m.bias.data.fill_(0.1) 485 | 486 | def node2edge_temporal(self, inputs, rel_rec, rel_send): 487 | # NOTE: Assumes that we have the same graph across all samples. 488 | 489 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 490 | 491 | receivers = torch.matmul(rel_rec, x) 492 | receivers = receivers.view(inputs.size(0) * receivers.size(1), 493 | inputs.size(2), inputs.size(3)) 494 | receivers = receivers.transpose(2, 1) 495 | 496 | senders = torch.matmul(rel_send, x) 497 | senders = senders.view(inputs.size(0) * senders.size(1), 498 | inputs.size(2), 499 | inputs.size(3)) 500 | senders = senders.transpose(2, 1) 501 | 502 | # receivers and senders have shape: 503 | # [num_sims * num_edges, num_dims, num_timesteps] 504 | edges = torch.cat([receivers, senders], dim=1) 505 | return edges 506 | 507 | def edge2node(self, x, rel_rec, rel_send): 508 | # NOTE: Assumes that we have the same graph across all samples. 509 | incoming = torch.matmul(rel_rec.t(), x) 510 | return incoming / incoming.size(1) 511 | 512 | def node2edge(self, x, rel_rec, rel_send): 513 | # NOTE: Assumes that we have the same graph across all samples. 514 | receivers = torch.matmul(rel_rec, x) 515 | senders = torch.matmul(rel_send, x) 516 | edges = torch.cat([receivers, senders], dim=2) 517 | return edges 518 | 519 | def forward(self, inputs, rel_rec, rel_send): 520 | 521 | # Input has shape: [num_sims, num_atoms, num_timesteps, num_dims] 522 | edges = self.node2edge_temporal(inputs, rel_rec, rel_send) 523 | x = self.cnn(edges) 524 | x = x.view(inputs.size(0), (inputs.size(1) - 1) * inputs.size(1), -1) 525 | x = self.mlp1(x) 526 | x_skip = x 527 | x = self.edge2node(x, rel_rec, rel_send) 528 | 529 | if self.split_point == 0: 530 | x = self.mlp2(x) 531 | x = self.node2edge(x, rel_rec, rel_send) 532 | x = torch.cat((x, x_skip), dim=2) # Skip connection 533 | x = self.mlp3(x) 534 | return self.fc_out(x) 535 | elif self.split_point == 1: 536 | x = self.mlp2(x) 537 | x = self.node2edge(x, rel_rec, rel_send) 538 | x = torch.cat((x, x_skip), dim=2) # Skip connection 539 | y_list = [] 540 | for i in range(len(self.edge_types_list)): 541 | y = self.mlp3[i](x) 542 | y_list.append( self.fc_out[i](y) ) 543 | return torch.cat(y_list,dim=-1) 544 | elif self.split_point == 2: 545 | y_list = [] 546 | for i in range(len(self.edge_types_list)): 547 | y = self.mlp2[i](x) 548 | y = self.node2edge(y, rel_rec, rel_send) 549 | y = torch.cat((y, x_skip), dim=2) # Skip connection 550 | y = self.mlp3[i](y) 551 | y_list.append( self.fc_out[i](y) ) 552 | return torch.cat(y_list,dim=-1) 553 | 554 | 555 | class MLPDecoder(nn.Module): 556 | """MLP decoder module.""" 557 | 558 | def __init__(self, n_in_node, edge_types, msg_hid, msg_out, n_hid, 559 | do_prob=0., skip_first=False): 560 | super(MLPDecoder, self).__init__() 561 | self.msg_fc1 = nn.ModuleList( 562 | [nn.Linear(2 * n_in_node, msg_hid) for _ in range(edge_types)]) 563 | self.msg_fc2 = nn.ModuleList( 564 | [nn.Linear(msg_hid, msg_out) for _ in range(edge_types)]) 565 | self.msg_out_shape = msg_out 566 | self.skip_first_edge_type = skip_first 567 | 568 | self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid) 569 | self.out_fc2 = nn.Linear(n_hid, n_hid) 570 | self.out_fc3 = nn.Linear(n_hid, n_in_node) 571 | 572 | print('Using learned interaction net decoder.') 573 | 574 | self.dropout_prob = do_prob 575 | 576 | def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send, 577 | single_timestep_rel_type): 578 | 579 | # single_timestep_inputs has 
shape 580 | # [batch_size, num_timesteps, num_atoms, num_dims] 581 | 582 | # single_timestep_rel_type has shape: 583 | # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types] 584 | 585 | # Node2edge 586 | receivers = torch.matmul(rel_rec, single_timestep_inputs) 587 | senders = torch.matmul(rel_send, single_timestep_inputs) 588 | pre_msg = torch.cat([receivers, senders], dim=-1) 589 | 590 | all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1), 591 | pre_msg.size(2), self.msg_out_shape)) 592 | if single_timestep_inputs.is_cuda: 593 | all_msgs = all_msgs.cuda() 594 | 595 | if self.skip_first_edge_type: 596 | start_idx = 1 597 | else: 598 | start_idx = 0 599 | 600 | # Run separate MLP for every edge type 601 | # NOTE: To exlude one edge type, simply offset range by 1 602 | for i in range(start_idx, len(self.msg_fc2)): 603 | msg = F.relu(self.msg_fc1[i](pre_msg)) 604 | msg = F.dropout(msg, p=self.dropout_prob) 605 | msg = F.relu(self.msg_fc2[i](msg)) 606 | msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] 607 | all_msgs += msg 608 | 609 | # Aggregate all msgs to receiver 610 | agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) 611 | agg_msgs = agg_msgs.contiguous() 612 | 613 | # Skip connection 614 | aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1) 615 | 616 | # Output MLP 617 | pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob) 618 | pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob) 619 | pred = self.out_fc3(pred) 620 | 621 | # Predict position/velocity difference 622 | return single_timestep_inputs + pred 623 | 624 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 625 | # NOTE: Assumes that we have the same graph across all samples. 626 | 627 | inputs = inputs.transpose(1, 2).contiguous() 628 | 629 | sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1), 630 | rel_type.size(2)] 631 | rel_type = rel_type.unsqueeze(1).expand(sizes) 632 | 633 | time_steps = inputs.size(1) 634 | assert (pred_steps <= time_steps) 635 | preds = [] 636 | 637 | # Only take n-th timesteps as starting points (n: pred_steps) 638 | last_pred = inputs[:, 0::pred_steps, :, :] 639 | curr_rel_type = rel_type[:, 0::pred_steps, :, :] 640 | # NOTE: Assumes rel_type is constant (i.e. same across all time steps). 
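# Rollout scheme: ground-truth frames at indices 0, pred_steps, 2*pred_steps, ... are taken
# as starting points, each is rolled forward for pred_steps single-step predictions, and the
# predictions are interleaved back into the original timeline below, so output[:, t] is the
# prediction for frame t+1. With the training default pred_steps=10 the decoder therefore
# only sees ground truth every 10th frame; the remaining inputs come from its own output.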
641 | 642 | # Run n prediction steps 643 | for step in range(0, pred_steps): 644 | last_pred = self.single_step_forward(last_pred, rel_rec, rel_send, 645 | curr_rel_type) 646 | preds.append(last_pred) 647 | 648 | sizes = [preds[0].size(0), preds[0].size(1) * pred_steps, 649 | preds[0].size(2), preds[0].size(3)] 650 | 651 | output = Variable(torch.zeros(sizes)) 652 | if inputs.is_cuda: 653 | output = output.cuda() 654 | 655 | # Re-assemble correct timeline 656 | for i in range(len(preds)): 657 | output[:, i::pred_steps, :, :] = preds[i] 658 | 659 | pred_all = output[:, :(inputs.size(1) - 1), :, :] 660 | 661 | return pred_all.transpose(1, 2).contiguous() 662 | 663 | 664 | class MLPDecoder_multi(nn.Module): 665 | """MLP decoder module.""" 666 | 667 | def __init__(self, n_in_node, edge_types, edge_types_list, msg_hid, msg_out, n_hid, 668 | do_prob=0., skip_first=False, init_type='default'): 669 | super(MLPDecoder_multi, self).__init__() 670 | self.msg_fc1 = nn.ModuleList( 671 | [nn.Linear(2 * n_in_node, msg_hid) for _ in range(edge_types)]) 672 | self.msg_fc2 = nn.ModuleList( 673 | [nn.Linear(msg_hid, msg_out) for _ in range(edge_types)]) 674 | self.msg_out_shape = msg_out 675 | self.skip_first = skip_first 676 | self.edge_types = edge_types 677 | self.edge_types_list = edge_types_list 678 | 679 | self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid) 680 | self.out_fc2 = nn.Linear(n_hid, n_hid) 681 | self.out_fc3 = nn.Linear(n_hid, n_in_node) 682 | 683 | print('Using learned interaction net decoder.') 684 | 685 | self.dropout_prob = do_prob 686 | 687 | self.init_type = init_type 688 | if self.init_type not in [ 'xavier_normal', 'orthogonal', 'default' ]: 689 | raise ValueError('This initialization type has not been coded') 690 | #print('Using '+self.init_type+' for decoder weight initialization') 691 | 692 | if self.init_type != 'default': 693 | self.init_weights() 694 | 695 | def init_weights(self): 696 | for m in self.modules(): 697 | if isinstance(m, nn.Linear): 698 | if self.init_type == 'orthogonal': 699 | nn.init.orthogonal_(m.weight.data,gain=0.000001) 700 | elif self.init_type == 'xavier_normal': 701 | nn.init.xavier_normal_(m.weight.data,gain=0.000001) 702 | #m.bias.data.fill_(0.1) 703 | 704 | def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send, 705 | single_timestep_rel_type): 706 | 707 | # single_timestep_inputs has shape 708 | # [batch_size, num_timesteps, num_atoms, num_dims] 709 | 710 | # single_timestep_rel_type has shape: 711 | # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types] 712 | 713 | # Node2edge 714 | receivers = torch.matmul(rel_rec, single_timestep_inputs) 715 | senders = torch.matmul(rel_send, single_timestep_inputs) 716 | pre_msg = torch.cat([receivers, senders], dim=-1) 717 | 718 | all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1), 719 | pre_msg.size(2), self.msg_out_shape)) 720 | if single_timestep_inputs.is_cuda: 721 | all_msgs = all_msgs.cuda() 722 | 723 | # non_null_idxs = list of indexs of edge types which as non null (i.e. 
edges over which messages can be passed) 724 | non_null_idxs = list(range(self.edge_types)) 725 | if self.skip_first: 726 | # if skip_first is True, the first edge type in each factor block is null 727 | edge = 0 728 | for k in self.edge_types_list: 729 | non_null_idxs.remove(edge) 730 | edge += k 731 | 732 | # Run separate MLP for every edge type 733 | # NOTE: To exlude one edge type, simply offset range by 1 734 | for i in non_null_idxs: 735 | msg = F.relu(self.msg_fc1[i](pre_msg)) 736 | msg = F.dropout(msg, p=self.dropout_prob) 737 | msg = F.relu(self.msg_fc2[i](msg)) 738 | msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] 739 | all_msgs += msg 740 | 741 | # Aggregate all msgs to receiver 742 | agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) 743 | agg_msgs = agg_msgs.contiguous() 744 | 745 | # Skip connection 746 | aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1) 747 | 748 | # Output MLP 749 | pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob) 750 | pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob) 751 | pred = self.out_fc3(pred) 752 | 753 | # Predict position/velocity difference 754 | return single_timestep_inputs + pred 755 | 756 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 757 | # NOTE: Assumes that we have the same graph across all samples. 758 | 759 | inputs = inputs.transpose(1, 2).contiguous() 760 | 761 | sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1), 762 | rel_type.size(2)] 763 | rel_type = rel_type.unsqueeze(1).expand(sizes) 764 | 765 | time_steps = inputs.size(1) 766 | assert (pred_steps <= time_steps) 767 | preds = [] 768 | 769 | # Only take n-th timesteps as starting points (n: pred_steps) 770 | last_pred = inputs[:, 0::pred_steps, :, :] 771 | curr_rel_type = rel_type[:, 0::pred_steps, :, :] 772 | # NOTE: Assumes rel_type is constant (i.e. same across all time steps). 
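# The rollout below is identical to MLPDecoder.forward; the factorisation only changes
# single_step_forward above, where messages are summed over the non-null edge types of every
# latent graph layer rather than over one flat set of edge types.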
773 | 774 | # Run n prediction steps 775 | for step in range(0, pred_steps): 776 | last_pred = self.single_step_forward(last_pred, rel_rec, rel_send, 777 | curr_rel_type) 778 | preds.append(last_pred) 779 | 780 | sizes = [preds[0].size(0), preds[0].size(1) * pred_steps, 781 | preds[0].size(2), preds[0].size(3)] 782 | 783 | output = Variable(torch.zeros(sizes)) 784 | if inputs.is_cuda: 785 | output = output.cuda() 786 | 787 | # Re-assemble correct timeline 788 | for i in range(len(preds)): 789 | output[:, i::pred_steps, :, :] = preds[i] 790 | 791 | pred_all = output[:, :(inputs.size(1) - 1), :, :] 792 | 793 | return pred_all.transpose(1, 2).contiguous() 794 | 795 | class MLPDecoder_sigmoid(nn.Module): 796 | """MLP decoder module.""" 797 | 798 | def __init__(self, n_in_node, num_factors, msg_hid, msg_out, n_hid, 799 | do_prob=0., skip_first=False, init_type='default'): 800 | super(MLPDecoder_sigmoid, self).__init__() 801 | self.msg_fc1 = nn.ModuleList( 802 | [nn.Linear(2 * n_in_node, msg_hid) for _ in range(num_factors)]) 803 | self.msg_fc2 = nn.ModuleList( 804 | [nn.Linear(msg_hid, msg_out) for _ in range(num_factors)]) 805 | self.msg_out_shape = msg_out 806 | self.num_factors = num_factors 807 | 808 | self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid) 809 | self.out_fc2 = nn.Linear(n_hid, n_hid) 810 | self.out_fc3 = nn.Linear(n_hid, n_in_node) 811 | 812 | print('Using learned interaction net decoder.') 813 | 814 | self.dropout_prob = do_prob 815 | 816 | self.init_type = init_type 817 | if self.init_type not in [ 'xavier_normal', 'orthogonal', 'default' ]: 818 | raise ValueError('This initialization type has not been coded') 819 | #print('Using '+self.init_type+' for decoder weight initialization') 820 | 821 | if self.init_type != 'default': 822 | self.init_weights() 823 | 824 | def init_weights(self): 825 | for m in self.modules(): 826 | if isinstance(m, nn.Linear): 827 | if self.init_type == 'orthogonal': 828 | nn.init.orthogonal_(m.weight.data,gain=0.000001) 829 | elif self.init_type == 'xavier_normal': 830 | nn.init.xavier_normal_(m.weight.data,gain=0.000001) 831 | #m.bias.data.fill_(0.1) 832 | 833 | def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send, 834 | single_timestep_rel_type): 835 | 836 | # single_timestep_inputs has shape 837 | # [batch_size, num_timesteps, num_atoms, num_dims] 838 | 839 | # single_timestep_rel_type has shape: 840 | # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types] 841 | 842 | # Node2edge 843 | receivers = torch.matmul(rel_rec, single_timestep_inputs) 844 | senders = torch.matmul(rel_send, single_timestep_inputs) 845 | pre_msg = torch.cat([receivers, senders], dim=-1) 846 | 847 | all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1), 848 | pre_msg.size(2), self.msg_out_shape)) 849 | if single_timestep_inputs.is_cuda: 850 | all_msgs = all_msgs.cuda() 851 | 852 | 853 | # Run separate MLP for every edge type 854 | # NOTE: To exlude one edge type, simply offset range by 1 855 | for i in range(self.num_factors): 856 | msg = F.relu(self.msg_fc1[i](pre_msg)) 857 | msg = F.dropout(msg, p=self.dropout_prob) 858 | msg = F.relu(self.msg_fc2[i](msg)) 859 | msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] 860 | all_msgs += msg 861 | 862 | # Aggregate all msgs to receiver 863 | agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) 864 | agg_msgs = agg_msgs.contiguous() 865 | 866 | # Skip connection 867 | aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1) 868 | 869 | # Output MLP 870 
| pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob) 871 | pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob) 872 | pred = self.out_fc3(pred) 873 | 874 | # Predict position/velocity difference 875 | return single_timestep_inputs + pred 876 | 877 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 878 | # NOTE: Assumes that we have the same graph across all samples. 879 | 880 | inputs = inputs.transpose(1, 2).contiguous() 881 | 882 | sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1), 883 | rel_type.size(2)] 884 | rel_type = rel_type.unsqueeze(1).expand(sizes) 885 | 886 | time_steps = inputs.size(1) 887 | assert (pred_steps <= time_steps) 888 | preds = [] 889 | 890 | # Only take n-th timesteps as starting points (n: pred_steps) 891 | last_pred = inputs[:, 0::pred_steps, :, :] 892 | curr_rel_type = rel_type[:, 0::pred_steps, :, :] 893 | # NOTE: Assumes rel_type is constant (i.e. same across all time steps). 894 | 895 | # Run n prediction steps 896 | for step in range(0, pred_steps): 897 | last_pred = self.single_step_forward(last_pred, rel_rec, rel_send, 898 | curr_rel_type) 899 | preds.append(last_pred) 900 | 901 | sizes = [preds[0].size(0), preds[0].size(1) * pred_steps, 902 | preds[0].size(2), preds[0].size(3)] 903 | 904 | output = Variable(torch.zeros(sizes)) 905 | if inputs.is_cuda: 906 | output = output.cuda() 907 | 908 | # Re-assemble correct timeline 909 | for i in range(len(preds)): 910 | output[:, i::pred_steps, :, :] = preds[i] 911 | 912 | pred_all = output[:, :(inputs.size(1) - 1), :, :] 913 | 914 | return pred_all.transpose(1, 2).contiguous() 915 | 916 | 917 | class StationaryDecoder(nn.Module): 918 | """MLP decoder module.""" 919 | 920 | def __init__(self): 921 | super(StationaryDecoder, self).__init__() 922 | 923 | print('Using stationary decoder.') 924 | 925 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 926 | 927 | inputs = inputs.transpose(1, 2).contiguous() 928 | 929 | time_steps = inputs.size(1) 930 | assert (pred_steps <= time_steps) 931 | preds = [] 932 | 933 | # Only take n-th timesteps as starting points (n: pred_steps) 934 | last_pred = inputs[:, 0::pred_steps, :, :] 935 | 936 | # Run n prediction steps 937 | for step in range(0, pred_steps): 938 | preds.append(last_pred) 939 | 940 | sizes = [preds[0].size(0), preds[0].size(1) * pred_steps, 941 | preds[0].size(2), preds[0].size(3)] 942 | 943 | output = Variable(torch.zeros(sizes)) 944 | if inputs.is_cuda: 945 | output = output.cuda() 946 | 947 | # Re-assemble correct timeline 948 | for i in range(len(preds)): 949 | output[:, i::pred_steps, :, :] = preds[i] 950 | 951 | pred_all = output[:, :(inputs.size(1) - 1), :, :] 952 | 953 | return pred_all.transpose(1, 2).contiguous() 954 | 955 | 956 | class VelocityStepDecoder(nn.Module): 957 | """MLP decoder module.""" 958 | 959 | def __init__(self, delta_T=0.1): 960 | super(VelocityStepDecoder, self).__init__() 961 | self.delta_T = delta_T 962 | 963 | print('Using velocity step decoder.') 964 | 965 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 966 | 967 | # input dimensions ofinputs are [batch, particle, time, state] 968 | 969 | inputs = inputs.transpose(1, 2).contiguous() 970 | 971 | time_steps = inputs.size(1) 972 | assert (pred_steps <= time_steps) 973 | preds = [] 974 | 975 | # Only take n-th timesteps as starting points (n: pred_steps) 976 | last_pred = inputs[:, 0::pred_steps, :, :] 977 | 978 | # Run n prediction steps 979 | for step in range(0, 
pred_steps): 980 | last_pred[:, :, :, 0:2] = last_pred[:, :, :, 0:2] + self.delta_T*last_pred[:, :, :, 2:] 981 | preds.append(last_pred) 982 | 983 | sizes = [preds[0].size(0), preds[0].size(1) * pred_steps, 984 | preds[0].size(2), preds[0].size(3)] 985 | 986 | output = Variable(torch.zeros(sizes)) 987 | if inputs.is_cuda: 988 | output = output.cuda() 989 | 990 | # Re-assemble correct timeline 991 | for i in range(len(preds)): 992 | output[:, i::pred_steps, :, :] = preds[i] 993 | 994 | pred_all = output[:, :(inputs.size(1) - 1), :, :] 995 | 996 | return pred_all.transpose(1, 2).contiguous() 997 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on https://github.com/ethanfetaya/NRI 3 | (MIT licence) 4 | """ 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import time 9 | import argparse 10 | import pickle 11 | import os 12 | import datetime 13 | import csv 14 | import math 15 | 16 | import torch.optim as optim 17 | from torch.optim import lr_scheduler 18 | 19 | from utils import * 20 | from modules import * 21 | 22 | parser = argparse.ArgumentParser() 23 | ## arguments related to training ## 24 | parser.add_argument('--epochs', type=int, default=500, 25 | help='Number of epochs to train.') 26 | parser.add_argument('--batch-size', type=int, default=128, 27 | help='Number of samples per batch.') 28 | parser.add_argument('--lr', type=float, default=0.0005, 29 | help='Initial learning rate.') 30 | parser.add_argument('--prediction-steps', type=int, default=10, metavar='N', 31 | help='Num steps to predict before re-using teacher forcing.') 32 | parser.add_argument('--lr-decay', type=int, default=200, 33 | help='After how epochs to decay LR by a factor of gamma.') 34 | parser.add_argument('--gamma', type=float, default=0.5, 35 | help='LR decay factor.') 36 | parser.add_argument('--patience', type=int, default=500, 37 | help='Early stopping patience') 38 | parser.add_argument('--encoder-dropout', type=float, default=0.0, 39 | help='Dropout rate (1 - keep probability).') 40 | parser.add_argument('--decoder-dropout', type=float, default=0.0, 41 | help='Dropout rate (1 - keep probability).') 42 | parser.add_argument('--dont-split-data', action='store_true', default=False, 43 | help='Whether to not split training and validation data into two parts') 44 | parser.add_argument('--split-enc-only', action='store_true', default=False, 45 | help='Whether to give the encoder the first half of trajectories \ 46 | and the decoder the whole of the trajectories') 47 | 48 | ## arguments related to loss function ## 49 | parser.add_argument('--var', type=float, default=5e-5, 50 | help='Output variance.') 51 | parser.add_argument('--beta', type=float, default=1.0, 52 | help='KL-divergence beta factor' ) 53 | parser.add_argument('--mse-loss', action='store_true', default=False, 54 | help='Use the MSE as the loss') 55 | 56 | ## arguments related to weight and bias initialisation ## 57 | parser.add_argument('--seed', type=int, default=1, 58 | help='Random seed.') 59 | parser.add_argument('--encoder-init-type',type=str, default='xavier_normal', 60 | help='The type of weight initialization to use in the encoder') 61 | parser.add_argument('--decoder-init-type',type=str, default='default', 62 | help='The type of weight initialization to use in the decoder') 63 | parser.add_argument('--encoder-bias-scale',type=float, default=0.1, 64 | 
help='The value used to initialise the biases in the encoder') 65 | 66 | ## arguments related to changing the model ## 67 | parser.add_argument('--NRI', action='store_true', default=False, 68 | help='Use the NRI model, rather than the fNRI model') 69 | parser.add_argument('--edge-types-list', nargs='+', default=[2,2], 70 | help='The number of edge types to infer.') # takes arguments from cmd line as: --edge-types-list 2 2 71 | parser.add_argument('--split-point', type=int, default=0, 72 | help='The point at which factor graphs are split up in the encoder' ) 73 | parser.add_argument('--encoder', type=str, default='mlp', 74 | help='Type of path encoder model (mlp, cnn, random or ones).') 75 | parser.add_argument('--decoder', type=str, default='mlp', 76 | help='Type of decoder model (mlp, stationary, or velocity).') 77 | parser.add_argument('--encoder-hidden', type=int, default=256, 78 | help='Number of hidden units.') 79 | parser.add_argument('--decoder-hidden', type=int, default=256, 80 | help='Number of hidden units.') 81 | parser.add_argument('--temp', type=float, default=0.5, 82 | help='Temperature for Gumbel softmax.') 83 | parser.add_argument('--skip-first', action='store_true', default=False, 84 | help='Skip the first edge type in each block in the decoder, i.e. it represents no-edge.') 85 | parser.add_argument('--hard', action='store_true', default=False, 86 | help='Uses discrete samples in training forward pass.') 87 | parser.add_argument('--soft-valid', action='store_true', default=False, 88 | help='Do not use hard sampling in validation') 89 | parser.add_argument('--prior', action='store_true', default=False, 90 | help='Whether to use sparsity prior.') 91 | 92 | ## arguments related to the simulation data ## 93 | parser.add_argument('--sim-folder', type=str, default='springcharge_5', 94 | help='Name of the folder in the data folder to load simulation data from') 95 | parser.add_argument('--data-folder', type=str, default='data', 96 | help='Name of the data folder to load data from') 97 | parser.add_argument('--num-atoms', type=int, default=5, 98 | help='Number of atoms in simulation.') 99 | parser.add_argument('--dims', type=int, default=4, 100 | help='The number of input dimensions (position + velocity).') 101 | parser.add_argument('--timesteps', type=int, default=49, 102 | help='The number of time steps per sample.') 103 | 104 | ## Saving, loading etc. ## 105 | parser.add_argument('--no-cuda', action='store_true', default=False, 106 | help='Disables CUDA training.') 107 | parser.add_argument('--save-folder', type=str, default='logs', 108 | help='Where to save the trained model, leave empty to not save anything.') 109 | parser.add_argument('--load-folder', type=str, default='', 110 | help='Where to load the trained model if fine-tuning. 
' + 111 | 'Leave empty to train from scratch') 112 | parser.add_argument('--test', action='store_true', default=False, 113 | help='Skip training and validation') 114 | parser.add_argument('--plot', action='store_true', default=False, 115 | help='Skip training and plot predicted trajectories against the ground truth') 116 | parser.add_argument('--no-edge-acc', action='store_true', default=False, 117 | help='Do not compute edge accuracies during training and validation') 118 | 119 | 120 | args = parser.parse_args() 121 | args.cuda = not args.no_cuda and torch.cuda.is_available() 122 | args.edge_types_list = list(map(int, args.edge_types_list)) 123 | args.edge_types_list.sort(reverse=True) 124 | 125 | if all( (isinstance(k, int) and k >= 1) for k in args.edge_types_list): 126 | if args.NRI: 127 | edge_types = np.prod(args.edge_types_list) 128 | else: 129 | edge_types = sum(args.edge_types_list) 130 | else: 131 | raise ValueError('Could not compute the edge-types-list') 132 | 133 | if args.NRI: 134 | print('Using NRI model') 135 | if args.split_point != 0: 136 | args.split_point = 0 137 | print(args) 138 | 139 | if args.prior: 140 | prior = [ [0.9, 0.1] , [0.9, 0.1] ] # TODO: hard coded for now 141 | if not all( len(prior[i]) == args.edge_types_list[i] for i in range(len(args.edge_types_list))): 142 | raise ValueError('Prior is incompatible with the edge types list') 143 | print("Using prior: "+str(prior)) 144 | log_prior = [] 145 | for i in range(len(args.edge_types_list)): 146 | prior_i = np.array(prior[i]) 147 | log_prior_i = torch.FloatTensor(np.log(prior_i)) 148 | log_prior_i = torch.unsqueeze(log_prior_i, 0) 149 | log_prior_i = torch.unsqueeze(log_prior_i, 0) 150 | log_prior_i = Variable(log_prior_i) 151 | log_prior.append(log_prior_i) 152 | if args.cuda: 153 | log_prior = [lp.cuda() for lp in log_prior] 154 | 155 | np.random.seed(args.seed) 156 | torch.manual_seed(args.seed) 157 | if args.cuda: 158 | torch.cuda.manual_seed(args.seed) 159 | 160 | 161 | # Save model and meta-data. Always saves in a new sub-folder. 162 | if args.save_folder: 163 | exp_counter = 0 164 | now = datetime.datetime.now() 165 | timestamp = now.isoformat().replace(':','-')[:-7] 166 | save_folder = os.path.join(args.save_folder,'exp'+timestamp) 167 | os.makedirs(save_folder) 168 | meta_file = os.path.join(save_folder, 'metadata.pkl') 169 | encoder_file = os.path.join(save_folder, 'encoder.pt') 170 | decoder_file = os.path.join(save_folder, 'decoder.pt') 171 | 172 | log_file = os.path.join(save_folder, 'log.txt') 173 | log_csv_file = os.path.join(save_folder, 'log_csv.csv') 174 | log = open(log_file, 'w') 175 | log_csv = open(log_csv_file, 'w') 176 | csv_writer = csv.writer(log_csv, delimiter=',') 177 | 178 | pickle.dump({'args': args}, open(meta_file, "wb")) 179 | par_file = open(os.path.join(save_folder,'args.txt'),'w') 180 | print(args,file=par_file) 181 | par_file.flush() 182 | par_file.close() 183 | 184 | perm_csv_file = os.path.join(save_folder, 'perm_csv.csv') 185 | perm_csv = open(perm_csv_file, 'w') 186 | perm_writer = csv.writer(perm_csv, delimiter=',') 187 | else: 188 | print("WARNING: No save_folder provided! " 
+ 189 | "Testing (within this script) will throw an error.") 190 | 191 | 192 | if args.NRI: 193 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_NRI( 194 | args.batch_size, args.sim_folder, shuffle=True, 195 | data_folder=args.data_folder) 196 | else: 197 | train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data_fNRI( 198 | args.batch_size, args.sim_folder, shuffle=True, 199 | data_folder=args.data_folder) 200 | 201 | 202 | # Generate off-diagonal interaction graph 203 | off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) 204 | rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) 205 | rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) 206 | rel_rec = torch.FloatTensor(rel_rec) 207 | rel_send = torch.FloatTensor(rel_send) 208 | 209 | if args.NRI: 210 | edge_types_list = [ edge_types ] 211 | else: 212 | edge_types_list = args.edge_types_list 213 | 214 | if args.encoder == 'mlp': 215 | encoder = MLPEncoder_multi(args.timesteps * args.dims, args.encoder_hidden, 216 | edge_types_list, args.encoder_dropout, 217 | split_point=args.split_point, 218 | init_type=args.encoder_init_type, 219 | bias_init=args.encoder_bias_scale) 220 | 221 | elif args.encoder == 'cnn': 222 | encoder = CNNEncoder_multi(args.dims, args.encoder_hidden, 223 | edge_types_list, 224 | args.encoder_dropout, 225 | split_point=args.split_point, 226 | init_type=args.encoder_init_type) 227 | 228 | elif args.encoder == 'random': 229 | encoder = RandomEncoder(args.edge_types_list, args.cuda) 230 | 231 | elif args.encoder == 'ones': 232 | encoder = OnesEncoder(args.edge_types_list, args.cuda) 233 | 234 | if args.decoder == 'mlp': 235 | decoder = MLPDecoder_multi(n_in_node=args.dims, 236 | edge_types=edge_types, 237 | edge_types_list=edge_types_list, 238 | msg_hid=args.decoder_hidden, 239 | msg_out=args.decoder_hidden, 240 | n_hid=args.decoder_hidden, 241 | do_prob=args.decoder_dropout, 242 | skip_first=args.skip_first, 243 | init_type=args.decoder_init_type) 244 | 245 | elif args.decoder == 'stationary': 246 | decoder = StationaryDecoder() 247 | 248 | elif args.decoder == 'velocity': 249 | decoder = VelocityStepDecoder() 250 | 251 | if args.load_folder: 252 | print('Loading model from: '+args.load_folder) 253 | encoder_file = os.path.join(args.load_folder, 'encoder.pt') 254 | decoder_file = os.path.join(args.load_folder, 'decoder.pt') 255 | if not args.cuda: 256 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 257 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 258 | else: 259 | encoder.load_state_dict(torch.load(encoder_file)) 260 | decoder.load_state_dict(torch.load(decoder_file)) 261 | args.save_folder = False 262 | 263 | optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), 264 | lr=args.lr) 265 | 266 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay, 267 | gamma=args.gamma) 268 | 269 | 270 | if args.cuda: 271 | encoder.cuda() 272 | decoder.cuda() 273 | rel_rec = rel_rec.cuda() 274 | rel_send = rel_send.cuda() 275 | 276 | rel_rec = Variable(rel_rec) 277 | rel_send = Variable(rel_send) 278 | 279 | 280 | def train(epoch, best_val_loss): 281 | t = time.time() 282 | nll_train = [] 283 | nll_var_train = [] 284 | mse_train = [] 285 | 286 | kl_train = [] 287 | kl_list_train = [] 288 | kl_var_list_train = [] 289 | 290 | acc_train = [] 291 | acc_var_train = [] 292 | perm_train = [] 293 | 
acc_var_blocks_train = [] 294 | acc_blocks_train = [] 295 | 296 | KLb_train = [] 297 | KLb_blocks_train = [] 298 | 299 | encoder.train() 300 | decoder.train() 301 | scheduler.step() 302 | if not args.plot: 303 | for batch_idx, (data, relations) in enumerate(train_loader): # relations are the ground truth interactions graphs 304 | if args.cuda: 305 | data, relations = data.cuda(), relations.cuda() 306 | data, relations = Variable(data), Variable(relations) 307 | 308 | if args.dont_split_data: 309 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 310 | data_decoder = data[:, :, :args.timesteps, :].contiguous() 311 | elif args.split_enc_only: 312 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 313 | data_decoder = data 314 | else: 315 | assert (data.size(2) - args.timesteps) >= args.timesteps 316 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 317 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 318 | 319 | optimizer.zero_grad() 320 | 321 | logits = encoder(data_encoder, rel_rec, rel_send) 322 | 323 | if args.NRI: 324 | # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles 325 | edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard) 326 | prob = my_softmax(logits, -1) 327 | 328 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 329 | loss_kl_split = [ loss_kl ] 330 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, edge_types) ] 331 | 332 | KLb_train.append( 0 ) 333 | KLb_blocks_train.append([0]) 334 | 335 | if args.no_edge_acc: 336 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array([0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(len(args.edge_types_list)) 337 | else: 338 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(logits, relations, args.edge_types_list) 339 | 340 | else: 341 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. 
of particles 342 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 343 | edges_split = tuple([gumbel_softmax(logits_i, tau=args.temp, hard=args.hard) 344 | for logits_i in logits_split ]) 345 | edges = torch.cat(edges_split, dim=-1) 346 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 347 | 348 | if args.prior: 349 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], args.num_atoms) 350 | for type_idx in range(len(args.edge_types_list)) ] 351 | loss_kl = sum(loss_kl_split) 352 | else: 353 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 354 | args.edge_types_list[type_idx]) 355 | for type_idx in range(len(args.edge_types_list)) ] 356 | loss_kl = sum(loss_kl_split) 357 | 358 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 359 | args.edge_types_list[type_idx]) 360 | for type_idx in range(len(args.edge_types_list)) ] 361 | 362 | if args.no_edge_acc: 363 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array([0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(len(args.edge_types_list)) 364 | else: 365 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(logits_split, relations, 366 | args.edge_types_list, args.skip_first) 367 | 368 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 369 | KLb_train.append(sum(KLb_blocks).data.item()) 370 | KLb_blocks_train.append([KL.data.item() for KL in KLb_blocks]) 371 | 372 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 373 | output = decoder(data_decoder, edges, rel_rec, rel_send, args.prediction_steps) 374 | 375 | loss_nll = nll_gaussian(output, target, args.var) 376 | loss_nll_var = nll_gaussian_var(output, target, args.var) 377 | 378 | 379 | if args.mse_loss: 380 | loss = F.mse_loss(output, target) 381 | else: 382 | loss = loss_nll 383 | if not math.isclose(args.beta, 0, rel_tol=1e-6): 384 | loss += args.beta*loss_kl 385 | 386 | perm_train.append(perm) 387 | acc_train.append(acc_perm) 388 | acc_blocks_train.append(acc_blocks) 389 | acc_var_train.append(acc_var) 390 | acc_var_blocks_train.append(acc_var_blocks) 391 | 392 | loss.backward() 393 | optimizer.step() 394 | 395 | mse_train.append(F.mse_loss(output, target).data.item()) 396 | nll_train.append(loss_nll.data.item()) 397 | kl_train.append(loss_kl.data.item()) 398 | kl_list_train.append([kl.data.item() for kl in loss_kl_split]) 399 | 400 | nll_var_train.append(loss_nll_var.data.item()) 401 | kl_var_list_train.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 402 | 403 | 404 | nll_val = [] 405 | nll_var_val = [] 406 | mse_val = [] 407 | 408 | kl_val = [] 409 | kl_list_val = [] 410 | kl_var_list_val = [] 411 | 412 | acc_val = [] 413 | acc_var_val = [] 414 | acc_blocks_val = [] 415 | acc_var_blocks_val = [] 416 | perm_val = [] 417 | 418 | KLb_val = [] 419 | KLb_blocks_val = [] # KL between blocks list 420 | 421 | nll_M_val = [] 422 | nll_M_var_val = [] 423 | 424 | encoder.eval() 425 | decoder.eval() 426 | for batch_idx, (data, relations) in enumerate(valid_loader): 427 | with torch.no_grad(): 428 | if args.cuda: 429 | data, relations = data.cuda(), relations.cuda() 430 | 431 | if args.dont_split_data: 432 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 433 | data_decoder = data[:, :, :args.timesteps, :].contiguous() 434 | elif args.split_enc_only: 435 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 436 | data_decoder = data 437 | else: 438 | 
assert (data.size(2) - args.timesteps) >= args.timesteps 439 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 440 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 441 | 442 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles 443 | logits = encoder(data_encoder, rel_rec, rel_send) 444 | 445 | if args.NRI: 446 | # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles 447 | edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard) # uses concrete distribution (for hard=False) to sample edge types 448 | prob = my_softmax(logits, -1) # my_softmax returns the softmax over the edgetype dim 449 | 450 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 451 | loss_kl_split = [ loss_kl ] 452 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, edge_types) ] 453 | 454 | KLb_val.append( 0 ) 455 | KLb_blocks_val.append([0]) 456 | 457 | if args.no_edge_acc: 458 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array([0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(len(args.edge_types_list)) 459 | else: 460 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(logits, relations, args.edge_types_list) 461 | 462 | else: 463 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles 464 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 465 | edges_split = tuple([gumbel_softmax(logits_i, tau=args.temp, hard=args.hard) 466 | for logits_i in logits_split ]) 467 | edges = torch.cat(edges_split, dim=-1) 468 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 469 | 470 | if args.prior: 471 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], args.num_atoms) 472 | for type_idx in range(len(args.edge_types_list)) ] 473 | loss_kl = sum(loss_kl_split) 474 | else: 475 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 476 | args.edge_types_list[type_idx]) 477 | for type_idx in range(len(args.edge_types_list)) ] 478 | loss_kl = sum(loss_kl_split) 479 | 480 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 481 | args.edge_types_list[type_idx]) 482 | for type_idx in range(len(args.edge_types_list)) ] 483 | 484 | if args.no_edge_acc: 485 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array([0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(len(args.edge_types_list)) 486 | else: 487 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(logits_split, relations, 488 | args.edge_types_list, args.skip_first) 489 | 490 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 491 | KLb_val.append(sum(KLb_blocks).data.item()) 492 | KLb_blocks_val.append([KL.data.item() for KL in KLb_blocks]) 493 | 494 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 495 | output = decoder(data_decoder, edges, rel_rec, rel_send, 1) 496 | 497 | if args.plot: 498 | import matplotlib.pyplot as plt 499 | output_plot = decoder(data_decoder, edges, rel_rec, rel_send, 49) 500 | 501 | if args.NRI: 502 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_NRI_batch(logits, relations, 503 | args.edge_types_list) 504 | else: 505 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_fNRI_batch(logits_split, relations, 506 | args.edge_types_list) 507 | 508 | from trajectory_plot import draw_lines 509 | 
for i in range(args.batch_size): 510 | fig = plt.figure(figsize=(7, 7)) 511 | ax = fig.add_axes([0, 0, 1, 1]) 512 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 513 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 514 | 515 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 516 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 517 | ax.set_xticks([]) 518 | ax.set_yticks([]) 519 | block_names = [ 'layer ' + str(j) for j in range(len(args.edge_types_list)) ] 520 | #block_names = [ 'springs', 'charges' ] 521 | acc_text = [ block_names[j] + ' acc: {:02.0f}%'.format(100*acc_blocks_batch[i,j]) 522 | for j in range(acc_blocks_batch.shape[1]) ] 523 | acc_text = ', '.join(acc_text) 524 | plt.text( 0.5, 0.95, acc_text, horizontalalignment='center', transform=ax.transAxes ) 525 | #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true.png'), dpi=300) 526 | plt.show() 527 | 528 | loss_nll = nll_gaussian(output, target, args.var) 529 | loss_nll_var = nll_gaussian_var(output, target, args.var) 530 | 531 | output_M = decoder(data_decoder, edges, rel_rec, rel_send, args.prediction_steps) 532 | loss_nll_M = nll_gaussian(output_M, target, args.var) 533 | loss_nll_M_var = nll_gaussian_var(output_M, target, args.var) 534 | 535 | perm_val.append(perm) 536 | acc_val.append(acc_perm) 537 | acc_blocks_val.append(acc_blocks) 538 | acc_var_val.append(acc_var) 539 | acc_var_blocks_val.append(acc_var_blocks) 540 | 541 | mse_val.append(F.mse_loss(output_M, target).data.item()) 542 | nll_val.append(loss_nll.data.item()) 543 | nll_var_val.append(loss_nll_var.data.item()) 544 | 545 | kl_val.append(loss_kl.data.item()) 546 | kl_list_val.append([kl_loss.data.item() for kl_loss in loss_kl_split]) 547 | kl_var_list_val.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 548 | 549 | nll_M_val.append(loss_nll_M.data.item()) 550 | nll_M_var_val.append(loss_nll_M_var.data.item()) 551 | 552 | print('Epoch: {:03d}'.format(epoch), 553 | 'perm_val: '+str( np.around(np.mean(np.array(perm_val),axis=0),4 ) ), 554 | 'time: {:.1f}s'.format(time.time() - t)) 555 | print('nll_trn: {:.2f}'.format(np.mean(nll_train)), 556 | 'kl_trn: {:.5f}'.format(np.mean(kl_train)), 557 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 558 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 559 | 'KLb_trn: {:.5f}'.format(np.mean(KLb_train)) 560 | ) 561 | print('acc_b_trn: '+str( np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ), 562 | 'kl_trn: '+str( np.around(np.mean(np.array(kl_list_train),axis=0),4 ) ) 563 | ) 564 | print('nll_val: {:.2f}'.format(np.mean(nll_M_val)), 565 | 'kl_val: {:.5f}'.format(np.mean(kl_val)), 566 | 'mse_val: {:.10f}'.format(np.mean(mse_val)), 567 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 568 | 'KLb_val: {:.5f}'.format(np.mean(KLb_val)) 569 | ) 570 | print('acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ), 571 | 'kl_val: '+str( np.around(np.mean(np.array(kl_list_val),axis=0),4 ) ) 572 | ) 573 | print('Epoch: {:04d}'.format(epoch), 574 | 'perm_val: '+str( np.around(np.mean(np.array(perm_val),axis=0),4 ) ), 575 | 'time: {:.4f}s'.format(time.time() - t), 576 | file=log) 577 | print('nll_trn: {:.5f}'.format(np.mean(nll_train)), 578 | 'kl_trn: {:.5f}'.format(np.mean(kl_train)), 579 | 'mse_trn: {:.10f}'.format(np.mean(mse_train)), 580 | 'acc_trn: {:.5f}'.format(np.mean(acc_train)), 581 | 'KLb_trn: {:.5f}'.format(np.mean(KLb_train)), 582 | 'acc_b_trn: '+str( 
np.around(np.mean(np.array(acc_blocks_train),axis=0),4 ) ), 583 | 'kl_trn: '+str( np.around(np.mean(np.array(kl_list_train),axis=0),4 ) ), 584 | file=log ) 585 | print('nll_val: {:.5f}'.format(np.mean(nll_M_val)), 586 | 'kl_val: {:.5f}'.format(np.mean(kl_val)), 587 | 'mse_val: {:.10f}'.format(np.mean(mse_val)), 588 | 'acc_val: {:.5f}'.format(np.mean(acc_val)), 589 | 'KLb_val: {:.5f}'.format(np.mean(KLb_val)), 590 | 'acc_b_val: '+str( np.around(np.mean(np.array(acc_blocks_val),axis=0),4 ) ), 591 | 'kl_val: '+str( np.around(np.mean(np.array(kl_list_val),axis=0),4 ) ), 592 | file=log) 593 | if epoch == 0: 594 | labels = [ 'epoch', 'nll trn', 'kl trn', 'mse train', 'KLb trn', 'acc trn'] 595 | labels += [ 'b'+str(i)+ ' acc trn' for i in range( len(args.edge_types_list) ) ] + [ 'nll var trn' ] 596 | labels += [ 'b'+str(i)+ ' kl trn' for i in range( len(kl_list_train[0]) ) ] 597 | labels += [ 'b'+str(i)+' kl var trn' for i in range( len(kl_list_train[0]) ) ] 598 | labels += [ 'acc var trn'] + [ 'b'+str(i)+' acc var trn' for i in range( len(args.edge_types_list) ) ] 599 | labels += [ 'nll val', 'nll_M_val', 'kl val', 'mse val', 'KLb val', 'acc val' ] 600 | labels += [ 'b'+str(i)+ ' acc val' for i in range( len(args.edge_types_list) ) ] 601 | labels += [ 'nll var val', 'nll_M var val' ] 602 | labels += [ 'b'+str(i)+ ' kl val' for i in range( len(kl_list_val[0]) ) ] 603 | labels += [ 'b'+str(i)+' kl var val' for i in range( len(kl_list_val[0]) ) ] 604 | labels += [ 'acc var val'] + [ 'b'+str(i)+' acc var val' for i in range( len(args.edge_types_list) ) ] 605 | csv_writer.writerow( labels ) 606 | 607 | labels = [ 'trn '+str(i) for i in range(len(perm_train[0])) ] 608 | labels += [ 'val '+str(i) for i in range(len(perm_val[0])) ] 609 | perm_writer.writerow( labels ) 610 | 611 | csv_writer.writerow( [epoch, np.mean(nll_train), np.mean(kl_train), 612 | np.mean(mse_train), np.mean(KLb_train), np.mean(acc_train)] + 613 | list(np.mean(np.array(acc_blocks_train),axis=0)) + 614 | [np.mean(nll_var_train)] + 615 | list(np.mean(np.array(kl_list_train),axis=0)) + 616 | list(np.mean(np.array(kl_var_list_train),axis=0)) + 617 | #list(np.mean(np.array(KLb_blocks_train),axis=0)) + 618 | [np.mean(acc_var_train)] + list(np.mean(np.array(acc_var_blocks_train),axis=0)) + 619 | [np.mean(nll_val), np.mean(nll_M_val), np.mean(kl_val), np.mean(mse_val), 620 | np.mean(KLb_val), np.mean(acc_val) ] + 621 | list(np.mean(np.array(acc_blocks_val ),axis=0)) + 622 | [np.mean(nll_var_val), np.mean(nll_M_var_val)] + 623 | list(np.mean(np.array(kl_list_val),axis=0)) + 624 | list(np.mean(np.array(kl_var_list_val),axis=0)) + 625 | #list(np.mean(np.array(KLb_blocks_val),axis=0)) 626 | [np.mean(acc_var_val)] + list(np.mean(np.array(acc_var_blocks_val),axis=0)) 627 | ) 628 | perm_writer.writerow( list(np.mean(np.array(perm_train),axis=0)) + 629 | list(np.mean(np.array(perm_val),axis=0)) 630 | ) 631 | 632 | log.flush() 633 | if args.save_folder and np.mean(nll_M_val) < best_val_loss: 634 | torch.save(encoder.state_dict(), encoder_file) 635 | torch.save(decoder.state_dict(), decoder_file) 636 | print('Best model so far, saving...') 637 | return np.mean(nll_M_val) 638 | 639 | 640 | def test(): 641 | t = time.time() 642 | nll_test = [] 643 | nll_var_test = [] 644 | 645 | mse_1_test = [] 646 | mse_10_test = [] 647 | mse_20_test = [] 648 | 649 | kl_test = [] 650 | kl_list_test = [] 651 | kl_var_list_test = [] 652 | 653 | acc_test = [] 654 | acc_var_test = [] 655 | acc_blocks_test = [] 656 | acc_var_blocks_test = [] 657 | perm_test = [] 658 | 659 
| KLb_test = [] 660 | KLb_blocks_test = [] # KL between blocks list 661 | 662 | nll_M_test = [] 663 | nll_M_var_test = [] 664 | 665 | encoder.eval() 666 | decoder.eval() 667 | if not args.cuda: 668 | encoder.load_state_dict(torch.load(encoder_file,map_location='cpu')) 669 | decoder.load_state_dict(torch.load(decoder_file,map_location='cpu')) 670 | else: 671 | encoder.load_state_dict(torch.load(encoder_file)) 672 | decoder.load_state_dict(torch.load(decoder_file)) 673 | 674 | for batch_idx, (data, relations) in enumerate(test_loader): 675 | with torch.no_grad(): 676 | if args.cuda: 677 | data, relations = data.cuda(), relations.cuda() 678 | 679 | assert (data.size(2) - args.timesteps) >= args.timesteps 680 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 681 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 682 | 683 | # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles 684 | logits = encoder(data_encoder, rel_rec, rel_send) 685 | 686 | if args.NRI: 687 | edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard) 688 | prob = my_softmax(logits, -1) 689 | 690 | loss_kl = kl_categorical_uniform(prob, args.num_atoms, edge_types) 691 | loss_kl_split = [ loss_kl ] 692 | loss_kl_var_split = [ kl_categorical_uniform_var(prob, args.num_atoms, edge_types) ] 693 | 694 | KLb_test.append( 0 ) 695 | KLb_blocks_test.append([0]) 696 | 697 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(logits, relations, args.edge_types_list) 698 | 699 | else: 700 | logits_split = torch.split(logits, args.edge_types_list, dim=-1) 701 | edges_split = tuple([gumbel_softmax(logits_i, tau=args.temp, hard=args.hard) for logits_i in logits_split ]) 702 | edges = torch.cat(edges_split, dim=-1) 703 | prob_split = [my_softmax(logits_i, -1) for logits_i in logits_split ] 704 | 705 | if args.prior: 706 | loss_kl_split = [kl_categorical(prob_split[type_idx], log_prior[type_idx], 707 | args.num_atoms) for type_idx in range(len(args.edge_types_list)) ] 708 | loss_kl = sum(loss_kl_split) 709 | else: 710 | loss_kl_split = [ kl_categorical_uniform(prob_split[type_idx], args.num_atoms, 711 | args.edge_types_list[type_idx]) 712 | for type_idx in range(len(args.edge_types_list)) ] 713 | loss_kl = sum(loss_kl_split) 714 | 715 | loss_kl_var_split = [ kl_categorical_uniform_var(prob_split[type_idx], args.num_atoms, 716 | args.edge_types_list[type_idx]) 717 | for type_idx in range(len(args.edge_types_list)) ] 718 | 719 | acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(logits_split, relations, 720 | args.edge_types_list, args.skip_first) 721 | 722 | KLb_blocks = KL_between_blocks(prob_split, args.num_atoms) 723 | KLb_test.append(sum(KLb_blocks).data.item()) 724 | KLb_blocks_test.append([KL.data.item() for KL in KLb_blocks]) 725 | 726 | target = data_decoder[:, :, 1:, :] # dimensions are [batch, particle, time, state] 727 | output = decoder(data_decoder, edges, rel_rec, rel_send, 1) 728 | 729 | 730 | if args.plot: 731 | import matplotlib.pyplot as plt 732 | output_plot = decoder(data_decoder, edges, rel_rec, rel_send, 49) 733 | 734 | output_plot_en = decoder(data_encoder, edges, rel_rec, rel_send, 49) 735 | from trajectory_plot import draw_lines 736 | 737 | if args.NRI: 738 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_NRI_batch(logits, relations, 739 | args.edge_types_list) 740 | else: 741 | acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_fNRI_batch(logits_split, relations, 742 | 
args.edge_types_list) 743 | 744 | for i in range(args.batch_size): 745 | fig = plt.figure(figsize=(7, 7)) 746 | ax = fig.add_axes([0, 0, 1, 1]) 747 | xmin_t, ymin_t, xmax_t, ymax_t = draw_lines( target, i, linestyle=':', alpha=0.6 ) 748 | xmin_o, ymin_o, xmax_o, ymax_o = draw_lines( output_plot.detach().numpy(), i, linestyle='-' ) 749 | 750 | ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)]) 751 | ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)]) 752 | ax.set_xticks([]) 753 | ax.set_yticks([]) 754 | block_names = [ str(j) for j in range(len(args.edge_types_list)) ] 755 | acc_text = [ 'layer ' + block_names[j] + ' acc: {:02.0f}%'.format(100*acc_blocks_batch[i,j]) 756 | for j in range(acc_blocks_batch.shape[1]) ] 757 | acc_text = ', '.join(acc_text) 758 | plt.text( 0.5, 0.95, acc_text, horizontalalignment='center', transform=ax.transAxes ) 759 | plt.show() 760 | 761 | loss_nll = nll_gaussian(output, target, args.var) # compute the reconstruction loss. nll_gaussian is from utils.py 762 | loss_nll_var = nll_gaussian_var(output, target, args.var) 763 | 764 | output_M = decoder(data_decoder, edges, rel_rec, rel_send, args.prediction_steps) 765 | loss_nll_M = nll_gaussian(output_M, target, args.var) 766 | loss_nll_M_var = nll_gaussian_var(output_M, target, args.var) 767 | 768 | perm_test.append(perm) 769 | acc_test.append(acc_perm) 770 | acc_blocks_test.append(acc_blocks) 771 | acc_var_test.append(acc_var) 772 | acc_var_blocks_test.append(acc_var_blocks) 773 | 774 | output_10 = decoder(data_decoder, edges, rel_rec, rel_send, 10) 775 | output_20 = decoder(data_decoder, edges, rel_rec, rel_send, 20) 776 | mse_1_test.append(F.mse_loss(output, target).data.item()) 777 | mse_10_test.append(F.mse_loss(output_10, target).data.item()) 778 | mse_20_test.append(F.mse_loss(output_20, target).data.item()) 779 | 780 | nll_test.append(loss_nll.data.item()) 781 | kl_test.append(loss_kl.data.item()) 782 | kl_list_test.append([kl_loss.data.item() for kl_loss in loss_kl_split]) 783 | 784 | nll_var_test.append(loss_nll_var.data.item()) 785 | kl_var_list_test.append([kl_var.data.item() for kl_var in loss_kl_var_split]) 786 | 787 | nll_M_test.append(loss_nll_M.data.item()) 788 | nll_M_var_test.append(loss_nll_M_var.data.item()) 789 | 790 | 791 | print('--------------------------------') 792 | print('------------Testing-------------') 793 | print('--------------------------------') 794 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 795 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 796 | 'kl_test: {:.5f}'.format(np.mean(kl_test)), 797 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 798 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 799 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 800 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 801 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 802 | 'KLb_test: {:.5f}'.format(np.mean(KLb_test)), 803 | 'time: {:.1f}s'.format(time.time() - t)) 804 | print('acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 805 | 'acc_var_b: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ), 806 | 'kl_test: '+str( np.around(np.mean(np.array(kl_list_test),axis=0),4 ) ) 807 | ) 808 | if args.save_folder: 809 | print('--------------------------------', file=log) 810 | print('------------Testing-------------', file=log) 811 | print('--------------------------------', file=log) 812 | print('nll_test: {:.2f}'.format(np.mean(nll_test)), 813 | 'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)), 814 | 'kl_test: 
{:.5f}'.format(np.mean(kl_test)), 815 | 'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)), 816 | 'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)), 817 | 'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)), 818 | 'acc_test: {:.5f}'.format(np.mean(acc_test)), 819 | 'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 820 | 'KLb_test: {:.5f}'.format(np.mean(KLb_test)), 821 | 'time: {:.1f}s'.format(time.time() - t), 822 | file=log) 823 | print('acc_b_test: '+str( np.around(np.mean(np.array(acc_blocks_test),axis=0),4 ) ), 824 | 'acc_var_b_test: '+str( np.around(np.mean(np.array(acc_var_blocks_test),axis=0),4 ) ), 825 | 'kl_test: '+str( np.around(np.mean(np.array(kl_list_test),axis=0),4 ) ), 826 | file=log) 827 | log.flush() 828 | 829 | 830 | # Train model 831 | if not args.test: 832 | t_total = time.time() 833 | best_val_loss = np.inf 834 | best_epoch = 0 835 | for epoch in range(args.epochs): 836 | val_loss = train(epoch, best_val_loss) 837 | if val_loss < best_val_loss: 838 | best_val_loss = val_loss 839 | best_epoch = epoch 840 | if epoch - best_epoch > args.patience and epoch > 99: 841 | break 842 | print("Optimization Finished!") 843 | print("Best Epoch: {:04d}".format(best_epoch)) 844 | if args.save_folder: 845 | print("Best Epoch: {:04d}".format(best_epoch), file=log) 846 | log.flush() 847 | 848 | test() 849 | if log is not None: 850 | print(save_folder) 851 | log.close() 852 | log_csv.close() 853 | perm_csv.close() --------------------------------------------------------------------------------
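
The validation and test loops above both reduce to the same two ingredients: a Gumbel-softmax (concrete) sample of each latent interaction layer, drawn block-by-block from the encoder logits, and a KL term between each block's edge posterior and a uniform prior. The following is a minimal, self-contained sketch of those two steps, written for illustration only: it re-implements the sampling and the KL directly rather than importing `gumbel_softmax` and `kl_categorical_uniform` from `utils.py`, and the toy shapes (5 particles, two interaction layers with 2 edge types each) are assumptions chosen purely for the example.

```
import torch
import torch.nn.functional as F

def concrete_sample(logits, tau=0.5, hard=True, eps=1e-10):
    """Gumbel-softmax / concrete sample over the last dimension.
    With hard=True a one-hot sample is returned, with gradients passed
    through the soft sample (straight-through estimator)."""
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + eps) + eps)          # Gumbel(0, 1) noise
    y_soft = F.softmax((logits + gumbel) / tau, dim=-1)
    if not hard:
        return y_soft
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return y_hard - y_soft.detach() + y_soft

def kl_to_uniform(prob, num_atoms, eps=1e-16):
    """KL(q || uniform) up to the constant log(K) term, summed over edges and
    edge types and normalised by batch size and number of particles; this
    mirrors the form of the kl_categorical_uniform loss used above."""
    return (prob * torch.log(prob + eps)).sum() / (num_atoms * prob.size(0))

# Toy setup: 5 particles -> N^2 - N = 20 directed edges,
# factorised into two interaction layers with 2 edge types each.
batch_size, num_atoms = 8, 5
num_edges = num_atoms * (num_atoms - 1)
edge_types_list = [2, 2]

logits = torch.randn(batch_size, num_edges, sum(edge_types_list))

# Split the logits per layer, sample each layer independently, then concatenate,
# mirroring the fNRI branch of the validation loop.
logits_split = torch.split(logits, edge_types_list, dim=-1)
edges = torch.cat([concrete_sample(l) for l in logits_split], dim=-1)

# KL of each layer's posterior against a uniform prior, summed over layers.
prob_split = [F.softmax(l, dim=-1) for l in logits_split]
loss_kl = sum(kl_to_uniform(p, num_atoms) for p in prob_split)

print(edges.shape, float(loss_kl))
```

Sampling each block separately is what lets every layer of the multiplex graph select its own edge type independently; concatenating the per-layer one-hot samples yields the `edges` tensor that the decoder consumes, as in the fNRI branch of `train.py` above.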