├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── .gitignore ├── config.py ├── create_billards_data.py ├── figures ├── 2.png └── VIN-example.gif ├── img └── .gitignore ├── load_data.py ├── model.py ├── requirements.txt ├── train.py ├── vin.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | /vin/ 103 | .idea/ 104 | *.mat 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Karl Stelzner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Visual-Interaction-Networks
2 | A PyTorch implementation of DeepMind's Visual Interaction Networks.
3 |
4 | > From just a glance, humans can make rich predictions about the future state of a wide range of physical systems. On the other hand, modern approaches from engineering, robotics, and graphics are often restricted to narrow domains and require direct measurements of the underlying states. We introduce the Visual Interaction Network, a general-purpose model for learning the dynamics of a physical system from raw visual observations. Our model consists of a perceptual front-end based on convolutional neural networks and a dynamics predictor based on interaction networks. Through joint training, the perceptual front-end learns to parse a dynamic visual scene into a set of factored latent object representations. The dynamics predictor learns to roll these states forward in time by computing their interactions and dynamics, producing a predicted physical trajectory of arbitrary length. We found that from just six input video frames the Visual Interaction Network can generate accurate future trajectories of hundreds of time steps on a wide range of physical systems. Our model can also be applied to scenes with invisible objects, inferring their future states from their effects on the visible objects, and can implicitly infer the unknown mass of objects. Our results demonstrate that the perceptual module and the object-based dynamics predictor module can induce factored latent representations that support accurate dynamical predictions. This work opens new opportunities for model-based decision-making and planning from raw sensory observations in complex physical environments.
5 |
6 | [Watters, N., Tacchetti, A., Weber, T., Pascanu, R., Battaglia, P.W., & Zoran, D. (2017). Visual Interaction Networks. CoRR, abs/1706.01433.](https://arxiv.org/abs/1706.01433)
7 |
8 |
9 | ![Example rollout](figures/VIN-example.gif)
10 |
11 | 12 | 13 | ## Architecture 14 |
15 | ![Architecture of the Visual Interaction Network](figures/2.png)
16 |
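The visual encoder turns each window of six input frames into four object state codes, one per triple of consecutive frames. Three interaction-network cores then roll the most recent state codes forward with temporal offsets of 1, 2, and 4 frames, an aggregator MLP merges their predictions into the next state code, and a linear decoder maps every state code to one position/velocity vector per object. Below is a minimal sketch of this interface as exposed by `model.Net`, using random frames in place of real data (a CUDA device is required, since the model hard-codes `.cuda()` calls):

```python
import torch

from config import VinConfig
from model import Net

net = Net(VinConfig).cuda().double()

# A batch of clips with six observed frames each:
# (batch_size, num_visible, channels, height, width)
frames = torch.rand(VinConfig.batch_size, VinConfig.num_visible,
                    VinConfig.channels, VinConfig.height, VinConfig.width,
                    dtype=torch.float64).cuda()

# rollout_states: (batch, num_rollout, num_obj, 4) predicted future states
# present_states: (batch, num_visible - 2, num_obj, 4) decoded present states,
# each state being (x, y, vx, vy) per object
rollout_states, present_states = net(frames, num_rollout=VinConfig.num_rollout)
```

Training (`train.py`) minimizes the mean squared error of the decoded present states plus a discounted sum of the rollout errors, where the error of the t-th rollout step is weighted by `discount_factor ** t` (0.9 by default, see `config.py`).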
17 |
18 |
19 | ### Data
20 | Run `create_billards_data.py` to create a dataset of bouncing billiard balls, or supply your own data.
21 |
22 | ### Dependencies
23 | Dependencies can be installed using `pip install -r requirements.txt`. They consist of the following required packages:
24 | ```
25 | Python 3.6
26 | pytorch 0.4
27 | numpy 1.15
28 | scipy 1.1
29 | ```
30 | as well as these optional packages for visualization:
31 | ```
32 | matplotlib 3.0
33 | visdom 0.1
34 | imageio 2.4
35 | ```
36 |
37 | ### Run
38 | - Edit the configuration file `config.py` to your needs.
39 | - Supply the data, for instance by running `create_billards_data.py`.
40 | - Run `vin.py`.
41 |
42 | ### Thank you!
43 | This repository was primarily created by refactoring and fixing:
44 | * https://github.com/MrGemy95/visual-interaction-networks-pytorch
45 |
46 | Indirect sources include:
47 | * https://github.com/jaesik817/visual-interaction-networks_tensorflow
48 | * Ilya Sutskever's code for the [billiard data](http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar)
49 |
50 |
51 |
-------------------------------------------------------------------------------- /checkpoint/.gitignore: --------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
-------------------------------------------------------------------------------- /config.py: --------------------------------------------------------------------------------
1 | class VinConfig:
2 |
3 | # Directories
4 | img_folder = "./img/" # image folder
5 | traindata = 'billards_balls_training_data.mat'
6 | testdata = 'billards_balls_testing_data.mat'
7 | checkpoint_dir = "./checkpoint/"
8 | log_dir = "./log"
9 |
10 | # Model/training config
11 | visual = True # If False, use states as input instead of images
12 | load = False # Load parameters from checkpoint file
13 | num_visible = 6 # Number of visible frames
14 | num_rollout = 8 # Number of rollout frames
15 | frame_step = 1 # Stepsize when observing frames
16 | batch_size = 100
17 | cl = 16 # state code length per object
18 | discount_factor = 0.9 # discount factor for loss from rollouts
19 |
20 | # Data config
21 | num_episodes = 1000 # The number of episodes
22 | num_frames = 100 # The number of frames per episode
23 | width = 32
24 | height = 32
25 | channels = 3
26 | num_obj = 3 # the number of objects
27 |
-------------------------------------------------------------------------------- /create_billards_data.py: --------------------------------------------------------------------------------
1 | """
2 | This script comes from the RTRBM code by Ilya Sutskever from
3 | http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar
4 | """
5 |
6 | from numpy import *
7 | from scipy import *
8 | import scipy.io
9 | from numpy.random import rand, randn # explicit import: newer SciPy no longer re-exports rand/randn
10 |
11 | import matplotlib
12 |
13 | matplotlib.use('Agg')
14 | import matplotlib.pyplot as plt
15 |
16 | shape_std = shape
17 |
18 |
19 | def shape(A):
20 | if isinstance(A, ndarray):
21 | return shape_std(A)
22 | else:
23 | return A.shape()
24 |
25 |
26 | size_std = size
27 |
28 |
29 | def size(A):
30 | if isinstance(A, ndarray):
31 | return size_std(A)
32 | else:
33 | return A.size()
34 |
35 |
36 | det = linalg.det
37 |
38 | # speeds after a 1-D elastic collision (momentum and kinetic energy conserved)
39 | def new_speeds(m1, m2, v1, v2):
40 | new_v2 = (2 * m1 * v1 + v2 * (m2 - m1)) / (m1 + m2)
41 | new_v1 = new_v2 + (v2 - v1)
42 | return new_v1, new_v2
43 |
44 |
45 | def norm(x): return sqrt((x ** 2).sum())
46 |
47 |
48 | def sigmoid(x): return 1. / (1.
+ exp(-x)) 48 | 49 | 50 | SIZE = 10 51 | 52 | 53 | # size of bounding box: SIZE X SIZE. 54 | 55 | def bounce_n(T=128, n=2, r=None, m=None): 56 | if r is None: 57 | r = array([1.2] * n) 58 | if m is None: 59 | m = array([1] * n) 60 | # r is to be rather small. 61 | X = zeros((T, n, 2), dtype='float') 62 | y = zeros((T, n, 2), dtype='float') 63 | v = randn(n, 2) 64 | v = v / norm(v) * .5 65 | good_config = False 66 | while not good_config: 67 | x = 2 + rand(n, 2) * 8 68 | good_config = True 69 | for i in range(n): 70 | for z in range(2): 71 | if x[i][z] - r[i] < 0: 72 | good_config = False 73 | if x[i][z] + r[i] > SIZE: 74 | good_config = False 75 | 76 | # that's the main part. 77 | for i in range(n): 78 | for j in range(i): 79 | if norm(x[i] - x[j]) < r[i] + r[j]: 80 | good_config = False 81 | 82 | eps = .5 83 | for t in range(T): 84 | # for how long do we show small simulation 85 | 86 | v_prev = copy(v) 87 | 88 | for i in range(n): 89 | X[t, i] = x[i] 90 | y[t, i] = v[i] 91 | 92 | for mu in range(int(1 / eps)): 93 | 94 | for i in range(n): 95 | x[i] += eps * v[i] 96 | 97 | for i in range(n): 98 | for z in range(2): 99 | if x[i][z] - r[i] < 0: 100 | v[i][z] = abs(v[i][z]) # want positive 101 | if x[i][z] + r[i] > SIZE: 102 | v[i][z] = -abs(v[i][z]) # want negative 103 | 104 | for i in range(n): 105 | for j in range(i): 106 | if norm(x[i] - x[j]) < r[i] + r[j]: 107 | # the bouncing off part: 108 | w = x[i] - x[j] 109 | w = w / norm(w) 110 | 111 | v_i = dot(w.transpose(), v[i]) 112 | v_j = dot(w.transpose(), v[j]) 113 | 114 | new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j) 115 | 116 | v[i] += w * (new_v_i - v_i) 117 | v[j] += w * (new_v_j - v_j) 118 | 119 | return X, y 120 | 121 | 122 | def ar(x, y, z): 123 | return z / 2 + arange(x, y, z, dtype='float') 124 | 125 | 126 | def draw_image(X, res, r=None): 127 | T, n = shape(X)[0:2] 128 | if r is None: 129 | r = array([1.2] * n) 130 | 131 | A = zeros((T, res, res, 3), dtype='float') 132 | 133 | [I, J] = meshgrid(ar(0, 1, 1. / res) * SIZE, ar(0, 1, 1. 
/ res) * SIZE)
134 |
135 | for t in range(T):
136 | for i in range(n):
137 | A[t, :, :, i] += exp(-(((I - X[t, i, 0]) ** 2 +
138 | (J - X[t, i, 1]) ** 2) /
139 | (r[i] ** 2)) ** 4)
140 |
141 | A[t][A[t] > 1] = 1
142 | return A
143 |
144 |
145 | def bounce_mat(res, n=2, T=128, r=None):
146 | if r is None:
147 | r = array([1.2] * n)
148 | x, y = bounce_n(T, n, r)
149 | A = draw_image(x, res, r)
150 | return A, y
151 |
152 |
153 | def bounce_vec(res, n=2, T=128, r=None, m=None):
154 | if r is None:
155 | r = array([1.2] * n)
156 | x, y = bounce_n(T, n, r, m)
157 | V = draw_image(x, res, r)
158 | y = concatenate((x, y), axis=2)
159 | return V.reshape(T, res, res, 3), y
160 |
161 |
162 | # make sure you have this folder
163 | logdir = './img'
164 |
165 |
166 | def show_sample(V):
167 | T = V.shape[0]
168 | for t in range(T):
169 | plt.imshow(V[t])
170 | # Save it
171 | fname = logdir + '/' + str(t) + '.png'
172 | plt.savefig(fname)
173 |
174 |
175 | if __name__ == "__main__":
176 | res = 32
177 | T = 100
178 | N = 1000
179 | dat = empty((N, T, res, res, 3), dtype=float)
180 | dat_y = empty((N, T, 3, 4), dtype=float)
181 | for i in range(N):
182 | dat[i], dat_y[i] = bounce_vec(res=res, n=3, T=T)
183 | print('training example {} / {}'.format(i, N))
184 | data = dict()
185 | data['X'] = dat
186 | data['y'] = dat_y
187 | scipy.io.savemat('billards_balls_training_data.mat', data)
188 |
189 | N = 200
190 | dat = empty((N, T, res, res, 3), dtype=float)
191 | dat_y = empty((N, T, 3, 4), dtype=float)
192 | for i in range(N):
193 | dat[i], dat_y[i] = bounce_vec(res=res, n=3, T=T)
194 | print('test example {} / {}'.format(i, N))
195 | data = dict()
196 | data['X'] = dat
197 | data['y'] = dat_y
198 | scipy.io.savemat('billards_balls_testing_data.mat', data)
199 |
200 | # show one video
201 | show_sample(dat[0])
202 | print(dat_y[0, :])
203 |
-------------------------------------------------------------------------------- /figures/2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stelzner/Visual-Interaction-Networks/aec463df90396506aa6dbe2a31e13e06ac1fe784/figures/2.png
-------------------------------------------------------------------------------- /figures/VIN-example.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stelzner/Visual-Interaction-Networks/aec463df90396506aa6dbe2a31e13e06ac1fe784/figures/VIN-example.gif
-------------------------------------------------------------------------------- /img/.gitignore: --------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
-------------------------------------------------------------------------------- /load_data.py: --------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from scipy.io import loadmat
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def clips_from_episodes(images, labels, visible_l, rollout_l, step):
8 | """
9 | Rearrange episodic observations into shorter clips
10 | :param images: Episodes of images of shape (n, fr, h, w, c)
11 | :param labels: Episodes of accompanying data for the images (n, fr, obj, d)
12 | :param visible_l: Number of frames in each clip
13 | :param rollout_l: Number of future frames for which labels are returned
14 | :param step: Stepsize for taking frames from the given episodes
15 | :return: A
number of shorter clips (_, visible_l, h, w, c),
16 | the corresponding labels (_, visible_l - 2, obj, d),
17 | and future labels (_, rollout_l, obj, d).
18 | """
19 | (num_episodes, num_frames, height, width, channels) = images.shape
20 | num_obj = labels.shape[-2]
21 |
22 | clips_per_episode = num_frames - (rollout_l + visible_l) * step + 1
23 | num_clips = num_episodes * clips_per_episode
24 |
25 | clips = np.zeros((num_clips, visible_l, height, width, channels))
26 | present_labels = np.zeros((num_clips, visible_l - 2, num_obj, 4))
27 | future_labels = np.zeros((num_clips, rollout_l, num_obj, 4))
28 |
29 | for i in range(num_episodes):
30 | for j in range(clips_per_episode):
31 | clip_idx = i * clips_per_episode + j
32 |
33 | end_visible = j + visible_l * step
34 | end_rollout = end_visible + rollout_l * step
35 |
36 | clips[clip_idx] = images[i, j:end_visible:step]
37 | present_labels[clip_idx] = labels[i, j + 2*step:end_visible:step]
38 | future_labels[clip_idx] = labels[i, end_visible:end_rollout:step]
39 |
40 | # shuffle
41 | perm_idx = np.random.permutation(num_clips)
42 | return clips[perm_idx], present_labels[perm_idx], future_labels[perm_idx]
43 |
44 |
45 | class VinDataset(Dataset):
46 | """Dataset of bouncing-ball clips loaded from the .mat files written by create_billards_data.py."""
47 |
48 | def __init__(self, config, transform=None, test=False):
49 | self.config = config
50 | self.transform = transform
51 |
52 | if test:
53 | data = loadmat(config.testdata)
54 | else:
55 | data = loadmat(config.traindata)
56 |
57 | self.total_img = data['X'][:config.num_episodes]
58 | # Transpose, as PyTorch images have shape (c, h, w)
59 | self.total_img = np.transpose(self.total_img, (0, 1, 4, 2, 3))
60 | self.total_data = data['y'][:config.num_episodes]
61 |
62 | num_eps, num_frames = self.total_img.shape[0:2]
63 | clips_per_ep = num_frames - ((config.num_visible +
64 | config.num_rollout) *
65 | config.frame_step) + 1
66 |
67 | idx_ep, idx_fr = np.meshgrid(list(range(num_eps)),
68 | list(range(clips_per_ep)))
69 |
70 | self.idxs = np.reshape(np.stack([idx_ep, idx_fr], 2), (-1, 2))
71 |
72 | def __len__(self):
73 | return len(self.idxs)
74 |
75 | def __getitem__(self, idx):
76 | conf = self.config
77 | step = conf.frame_step
78 |
79 | i, j = self.idxs[idx, 0], self.idxs[idx, 1]
80 |
81 | end_visible = j + conf.num_visible * step
82 | end_rollout = end_visible + conf.num_rollout * step
83 | image = self.total_img[i, j:end_visible:step]
84 | present = self.total_data[i, j + 2 * step:end_visible:step] # the encoder produces no state code for the first two frames
85 | future = self.total_data[i, end_visible:end_rollout:step]
86 |
87 | sample = {'image': torch.from_numpy(image),
88 | 'future_labels': torch.from_numpy(future),
89 | 'present_labels': torch.from_numpy(present)}
90 |
91 | return sample
92 |
-------------------------------------------------------------------------------- /model.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Net(nn.Module):
9 | def __init__(self, config):
10 | super(Net, self).__init__()
11 | self.config = config
12 | self.pool = nn.MaxPool2d(2, 2)
13 | self.x_coord, self.y_coord = self.construct_coord_dims()
14 | cl = config.cl
15 |
16 | # Visual Encoder Modules
17 | self.conv1 = nn.Conv2d(config.channels * 2 + 2, 16, 3, padding=1)
18 | self.conv2 = nn.Conv2d(16, 16, 3, padding=1)
19 | self.conv3 = nn.Conv2d(16, 16, 3, padding=1)
20 | self.conv4 = nn.Conv2d(16, 32, 3, padding=1)
21 | self.conv5 =
nn.Conv2d(32, 32, 3, padding=1) 22 | # shared linear layer to get pair codes of shape N_obj*cl 23 | self.fc1 = nn.Linear(32, 3 * cl) 24 | 25 | # shared MLP to encode pairs of pair codes as state codes N_obj*cl 26 | self.fc2 = nn.Linear(cl * 2, cl) 27 | self.fc3 = nn.Linear(cl, cl) 28 | # end of visual encoder 29 | 30 | # Interaction Net Core Modules 31 | # Self-dynamics MLP 32 | self.self_cores = nn.ModuleList() 33 | for i in range(3): 34 | self.self_cores.append(nn.ModuleList()) 35 | self.self_cores[i].append(nn.Linear(cl, cl).double().cuda()) 36 | self.self_cores[i].append(nn.Linear(cl, cl).double().cuda()) 37 | 38 | # Relation MLP 39 | self.rel_cores = nn.ModuleList() 40 | for i in range(3): 41 | self.rel_cores.append(nn.ModuleList()) 42 | self.rel_cores[i].append(nn.Linear(cl * 2, 2 * cl).double().cuda()) 43 | self.rel_cores[i].append(nn.Linear(2 * cl, cl).double().cuda()) 44 | self.rel_cores[i].append(nn.Linear(cl, cl).double().cuda()) 45 | 46 | # Affector MLP 47 | self.affector = nn.ModuleList() 48 | for i in range(3): 49 | self.affector.append(nn.ModuleList()) 50 | self.affector[i].append(nn.Linear(cl, cl).double().cuda()) 51 | self.affector[i].append(nn.Linear(cl, cl).double().cuda()) 52 | self.affector[i].append(nn.Linear(cl, cl).double().cuda()) 53 | 54 | # Core output MLP 55 | self.out = nn.ModuleList() 56 | for i in range(3): 57 | self.out.append(nn.ModuleList()) 58 | self.out[i].append(nn.Linear(cl + cl, cl).double().cuda()) 59 | self.out[i].append(nn.Linear(cl, cl).double().cuda()) 60 | 61 | # Aggregator MLP for aggregating core predictions 62 | self.aggregator1 = nn.Linear(cl * 3, cl) 63 | self.aggregator2 = nn.Linear(cl, cl) 64 | 65 | # decoder mapping state codes to actual states 66 | self.state_decoder = nn.Linear(cl, 4) 67 | # encoder for the non-visual case 68 | self.state_encoder = nn.Linear(4, cl) 69 | 70 | def construct_coord_dims(self): 71 | """ 72 | Build a meshgrid of x, y coordinates to be used as additional channels 73 | """ 74 | x = np.linspace(0, 1, self.config.width) 75 | y = np.linspace(0, 1, self.config.height) 76 | xv, yv = np.meshgrid(x, y) 77 | xv = np.reshape(xv, [1, 1, self.config.height, self.config.width]) 78 | yv = np.reshape(yv, [1, 1, self.config.height, self.config.width]) 79 | x_coord = Variable(torch.from_numpy(xv)).cuda() 80 | y_coord = Variable(torch.from_numpy(yv)).cuda() 81 | x_coord = x_coord.expand(self.config.batch_size * 5, -1, -1, -1) 82 | y_coord = y_coord.expand(self.config.batch_size * 5, -1, -1, -1) 83 | return x_coord, y_coord 84 | 85 | def core(self, s, core_idx): 86 | """ 87 | Applies an interaction network core 88 | :param s: A state code of shape (n, o, cl) 89 | :param core_idx: The index of the set of parameters to apply (0, 1, 2) 90 | :return: Prediction of a future state code (n, o, cl) 91 | """ 92 | objects = [s[:, i] for i in range(3)] 93 | 94 | self_sd_h1 = F.relu(self.self_cores[core_idx][0](s)) 95 | self_dynamic = self.self_cores[core_idx][1](self_sd_h1) + self_sd_h1 96 | 97 | rel_combination = [] 98 | for i in range(6): 99 | row_idx = i // 2 100 | # pick the two other objects 101 | col_idx = (row_idx + 1 + (i % 2)) % 3 102 | rel_combination.append( 103 | torch.cat([objects[row_idx], objects[col_idx]], 1)) 104 | # 6 combinations of the 3 objects, (n, 6, 2*cl) 105 | rel_combination = torch.stack(rel_combination, 1) 106 | rel_sd_h1 = F.relu(self.rel_cores[core_idx][0](rel_combination)) 107 | rel_sd_h2 = F.relu(self.rel_cores[core_idx][1](rel_sd_h1)) 108 | rel_factors = self.rel_cores[core_idx][2](rel_sd_h2) + rel_sd_h2 
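        # Each object receives one relation effect from each of the two other
        # objects: rel_factors[:, 0:2] hold the effects on object 1 (from
        # objects 2 and 3), rel_factors[:, 2:4] those on object 2 (from
        # objects 3 and 1), and rel_factors[:, 4:6] those on object 3 (from
        # objects 1 and 2). Summing the incoming effects per object below and
        # adding the self-dynamics term from above gives the interaction-net
        # update, which the affector and output MLPs turn into the predicted
        # next state code.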
109 | obj1 = rel_factors[:, 0] + rel_factors[:, 1] 110 | obj2 = rel_factors[:, 2] + rel_factors[:, 3] 111 | obj3 = rel_factors[:, 4] + rel_factors[:, 5] 112 | # relational dynamics per object, (n, o, cl) 113 | rel_dynamic = torch.stack([obj1, obj2, obj3], 1) 114 | # total dynamics 115 | dynamic_pred = self_dynamic + rel_dynamic 116 | 117 | aff1 = F.relu(self.affector[core_idx][0](dynamic_pred)) 118 | aff2 = F.relu(self.affector[core_idx][1](aff1) + aff1) 119 | aff3 = self.affector[core_idx][2](aff2) 120 | 121 | aff_s = torch.cat([aff3, s], 2) 122 | out1 = F.relu(self.out[core_idx][0](aff_s)) 123 | out2 = self.out[core_idx][1](out1) + out1 124 | return out2 125 | 126 | def frames_to_states(self, frames): 127 | """ 128 | Apply visual encoder 129 | :param frames: Groups of six input frames of shape (n, 6, c, w, h) 130 | :return: State codes of shape (n, 4, o, cl) 131 | """ 132 | batch_size = self.config.batch_size 133 | cl = self.config.cl 134 | num_obj = self.config.num_obj 135 | 136 | pairs = [] 137 | for i in range(frames.shape[1] - 1): 138 | # pair consecutive frames (n, 2c, w, h) 139 | pair = torch.cat((frames[:, i], frames[:, i+1]), 1) 140 | pairs.append(pair) 141 | 142 | num_pairs = len(pairs) 143 | pairs = torch.cat(pairs, 0) 144 | # add coord channels (n * num_pairs, 2c + 2, w, h) 145 | pairs = torch.cat([pairs, self.x_coord, self.y_coord], dim=1) 146 | 147 | # apply ConvNet to pairs 148 | ve_h1 = F.relu(self.conv1(pairs)) 149 | ve_h1 = self.pool(ve_h1) 150 | ve_h2 = F.relu(self.conv2(ve_h1)) 151 | ve_h2 = self.pool(ve_h2) 152 | ve_h3 = F.relu(self.conv3(ve_h2)) 153 | ve_h3 = self.pool(ve_h3) 154 | ve_h4 = F.relu(self.conv4(ve_h3)) 155 | ve_h4 = self.pool(ve_h4) 156 | ve_h5 = F.relu(self.conv5(ve_h4)) 157 | ve_h5 = self.pool(ve_h5) 158 | 159 | # pooled to 1x1, 32 channels: (n * num_pairs, 32) 160 | encoded_pairs = torch.squeeze(ve_h5) 161 | # final pair encoding (n * num_pairs, o, cl) 162 | encoded_pairs = self.fc1(encoded_pairs) 163 | encoded_pairs = encoded_pairs.view(batch_size * num_pairs, num_obj, cl) 164 | # chunk pairs encoding, each is (n, o, cl) 165 | encoded_pairs = torch.chunk(encoded_pairs, num_pairs) 166 | 167 | triples = [] 168 | for i in range(num_pairs - 1): 169 | # pair consecutive pairs to obtain encodings for triples 170 | triple = torch.cat([encoded_pairs[i], encoded_pairs[i+1]], 2) 171 | triples.append(triple) 172 | 173 | # the triples together, i.e. (n, num_pairs - 1, o, 2 * cl) 174 | triples = torch.stack(triples, 1) 175 | # apply MLP to triples 176 | shared_h1 = F.relu(self.fc2(triples)) 177 | state_codes = self.fc3(shared_h1) 178 | return state_codes 179 | 180 | def forward(self, x, num_rollout=8, visual=True): 181 | """ 182 | Rollout a given sequence of observations using the model 183 | :param x: The given sequence of observations. 184 | If visual is True, it should be images of shape (n, 6, c, w, h), 185 | otherwise states of shape (n, 4, o, 4). 
186 | :param num_rollout: The number of future states to be predicted
187 | :param visual: Boolean determining the type of input
188 | :return: rollout_states: predicted future states (n, roll_len, o, 4)
189 | present_states: predicted states at the time of
190 | the given observations, (n, 4, o, 4)
191 | """
192 | # get encoded states
193 | if visual:
194 | state_codes = self.frames_to_states(x)
195 | else:
196 | state_codes = self.state_encoder(x)
197 | # the 4 state codes (n, o, cl)
198 | s1, s2, s3, s4 = [state_codes[:, i] for i in range(4)]
199 | rollouts = []
200 | for i in range(num_rollout):
201 | # use cores to predict next state using delta_t = 1, 2, 4
202 | c1 = self.core(s4, 0)
203 | c2 = self.core(s3, 1)
204 | c4 = self.core(s1, 2)
205 | all_c = torch.cat([c1, c2, c4], 2)
206 | aggregator1 = F.relu(self.aggregator1(all_c))
207 | state_prediction = self.aggregator2(aggregator1)
208 | rollouts.append(state_prediction)
209 | s1, s2, s3, s4 = s2, s3, s4, state_prediction
210 | rollouts = torch.stack(rollouts, 1)
211 |
212 | present_states = self.state_decoder(state_codes)
213 | rollout_states = self.state_decoder(rollouts)
214 |
215 | return rollout_states, present_states
216 |
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | # Python > 3.4 required
2 | torch>=0.4
3 | numpy>=1.0
4 | scipy>=1.1
5 | matplotlib>=3.0
6 | visdom>=0.1
7 | imageio>=2.4
8 |
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | from torch import nn
7 | from torch.utils.data import DataLoader
8 |
9 | from load_data import VinDataset
10 | from visualize import plot_positions, animate
11 |
12 |
13 | class Trainer:
14 | def __init__(self, config, net):
15 | self.net = net
16 | self.params = net.parameters()
17 | self.initial_values = {}
18 | self.config = config
19 |
20 | train_dataset = VinDataset(self.config)
21 | self.dataloader = DataLoader(train_dataset,
22 | batch_size=self.config.batch_size,
23 | shuffle=True,
24 | num_workers=4,
25 | drop_last=True)
26 |
27 | self.test_dataset = VinDataset(self.config,
28 | test=True)
29 | self.test_dataloader = DataLoader(self.test_dataset,
30 | batch_size=self.config.batch_size,
31 | shuffle=True,
32 | num_workers=4,
33 | drop_last=True)
34 |
35 | self.optimizer = optim.Adam(self.net.parameters(), lr=0.0005)
36 | if config.load:
37 | self.load()
38 |
39 | def save(self):
40 | torch.save(self.net.state_dict(), os.path.join(
41 | self.config.checkpoint_dir, "checkpoint"))
42 | print('Parameters saved')
43 |
44 | def load(self):
45 | try:
46 | self.net.load_state_dict(torch.load(os.path.join(
47 | self.config.checkpoint_dir, "checkpoint")))
48 | print('Parameters loaded')
49 | except (IOError, RuntimeError): # missing or incompatible checkpoint
50 | print('Loading parameters failed, training from scratch...')
51 |
52 | def compute_loss(self, present_labels, future_labels, recons, preds):
53 | loss = nn.MSELoss()
54 | df = self.config.discount_factor
55 | pred_loss = 0.0
56 | for delta_t in range(0, self.config.num_rollout):
57 | pred_loss += (df ** (delta_t + 1)) * \
58 | loss(preds[:, delta_t], future_labels[:, delta_t])
59 |
60 | recon_loss = loss(recons, present_labels)
61 | total_loss = pred_loss + recon_loss
62 |
63 | return total_loss, pred_loss, recon_loss
64 |
65 | def
train(self): 66 | step_counter = 0 67 | num_rollout = self.config.num_rollout 68 | for epoch in range(100): 69 | print("testing................") 70 | self.test() 71 | for i, data in enumerate(self.dataloader, 0): 72 | step_counter += 1 73 | images = data['image'] 74 | present_labels = data['present_labels'] 75 | future_labels = data['future_labels'] 76 | 77 | images = images.cuda() 78 | present_labels = present_labels.cuda() 79 | future_labels = future_labels.cuda() 80 | 81 | if self.config.visual: 82 | vin_input = images 83 | else: 84 | vin_input = present_labels 85 | 86 | self.optimizer.zero_grad() 87 | state_pred, state_recon = self.net(vin_input, 88 | num_rollout=num_rollout, 89 | visual=self.config.visual) 90 | 91 | total_loss, pred_loss, recon_loss = \ 92 | self.compute_loss(present_labels, future_labels, 93 | state_recon, state_pred) 94 | 95 | total_loss.backward() 96 | self.optimizer.step() 97 | 98 | # print loss 99 | if step_counter % 20 == 0: 100 | print('{:5d} {:5f} {:5f} {:5f}'.format(step_counter, 101 | total_loss.item(), 102 | recon_loss.item(), 103 | pred_loss.item())) 104 | # Draw example 105 | if step_counter % 200 == 0: 106 | real = torch.cat([present_labels[0], future_labels[0]]) 107 | simu = torch.cat([state_recon[0], state_pred[0]]).detach() 108 | plot_positions(real, self.config.img_folder, 'real') 109 | plot_positions(simu, self.config.img_folder, 'rollout') 110 | # Save parameters 111 | if (step_counter + 1) % 1000 == 0: 112 | self.save() 113 | 114 | print("epoch ", epoch, " Finished") 115 | print('Finished Training') 116 | 117 | def test(self): 118 | total_loss = 0.0 119 | for i, data in enumerate(self.test_dataloader, 0): 120 | images, future_labels, present_labels = \ 121 | data['image'], data['future_labels'], data['present_labels'] 122 | 123 | images = images.cuda() 124 | present_labels = present_labels.cuda() 125 | future_labels = future_labels.cuda() 126 | 127 | vin_input = images if self.config.visual else present_labels 128 | 129 | pred, recon = self.net(vin_input, 130 | num_rollout=self.config.num_rollout, 131 | visual=self.config.visual) 132 | 133 | total_loss, pred_loss, recon_loss = \ 134 | self.compute_loss(present_labels, future_labels, 135 | recon, pred) 136 | 137 | print('total test loss {:5f}'.format(total_loss.item())) 138 | 139 | # Create one long rollout and save it as an animated GIF 140 | total_images = self.test_dataset.total_img 141 | total_labels = self.test_dataset.total_data 142 | step = self.config.frame_step 143 | visible = self.config.num_visible 144 | batch_size = self.config.batch_size 145 | 146 | long_rollout_length = self.config.num_frames // step - visible 147 | 148 | if self.config.visual: 149 | vin_input = total_images[:batch_size, :visible*step:step] 150 | else: 151 | vin_input = total_labels[:batch_size, 2*step:visible*step:step] 152 | 153 | vin_input = torch.tensor(vin_input).cuda() 154 | 155 | pred, recon = self.net(vin_input, long_rollout_length, 156 | visual=self.config.visual) 157 | 158 | simu_rollout = pred[0].detach().cpu().numpy() 159 | simu_recon = recon[0].detach().cpu().numpy() 160 | simu = np.concatenate((simu_recon, simu_rollout), axis=0) 161 | 162 | # Saving 163 | print("Make GIFs") 164 | animate(total_labels[0, 2:], self.config.img_folder, 'real') 165 | animate(simu, self.config.img_folder, 'rollout') 166 | print("Done") 167 | -------------------------------------------------------------------------------- /vin.py: -------------------------------------------------------------------------------- 1 | from model 
import Net 2 | from config import VinConfig 3 | from train import Trainer 4 | 5 | 6 | def main(): 7 | net = Net(VinConfig) 8 | net = net.cuda() 9 | net = net.double() 10 | trainer = Trainer(VinConfig, net) 11 | trainer.train() 12 | 13 | 14 | if __name__ == '__main__': 15 | main() 16 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import visdom 4 | import numpy as np 5 | import imageio 6 | 7 | vis = visdom.Visdom() 8 | 9 | 10 | def plot_positions(xy, img_folder, prefix, save=True, size=10): 11 | if not os.path.exists(img_folder): 12 | os.makedirs(img_folder) 13 | fig_num = len(xy) 14 | mydpi = 100 15 | fig = plt.figure(figsize=(128/mydpi, 128/mydpi)) 16 | plt.xlim(0, 10) 17 | plt.ylim(0, 10) 18 | plt.xticks([]) 19 | plt.yticks([]) 20 | 21 | color = ['r', 'b', 'g', 'k', 'y', 'm', 'c'] 22 | for i in range(fig_num): 23 | for j in range(len(xy[0])): 24 | plt.scatter(xy[i, j, 1], xy[i, j, 0], 25 | c=color[j % len(color)], s=size, alpha=(i+1)/fig_num) 26 | 27 | if save: 28 | fig.savefig(img_folder+prefix+".pdf", dpi=mydpi, transparent=True) 29 | vis.matplot(fig) 30 | 31 | fig.canvas.draw() 32 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') 33 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 34 | plt.close() 35 | return image 36 | 37 | 38 | def animate(states, img_folder, prefix): 39 | images = [] 40 | for i in range(len(states)): 41 | images.append(plot_positions(states[i:i + 1], img_folder, 42 | prefix, save=False, size=270)) 43 | imageio.mimsave(img_folder+prefix+'.gif', images, fps=24) 44 | --------------------------------------------------------------------------------
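For a quick sanity check of the data pipeline and the visualization utilities above, a short inspection script along the following lines can be used. It is not part of the repository: it assumes the test data file produced by `create_billards_data.py` is present and that the optional visualization dependencies (`matplotlib`, `visdom`, `imageio`) are installed.

```python
# inspect_data.py -- hypothetical helper script, not included in the repository
from torch.utils.data import DataLoader

from config import VinConfig
from load_data import VinDataset
from visualize import animate

# Wrap the test episodes into clips of 6 visible frames plus 8 rollout frames
dataset = VinDataset(VinConfig, test=True)
loader = DataLoader(dataset, batch_size=VinConfig.batch_size, drop_last=True)

batch = next(iter(loader))
print(batch['image'].shape)           # (batch, num_visible, channels, height, width)
print(batch['present_labels'].shape)  # (batch, num_visible - 2, num_obj, 4)
print(batch['future_labels'].shape)   # (batch, num_rollout, num_obj, 4)

# Render the ground-truth states of the first test episode as img/ground_truth.gif;
# states are (num_frames, num_obj, 4) with positions in the [0, 10] x [0, 10] box
# used by create_billards_data.py.
animate(dataset.total_data[0], VinConfig.img_folder, 'ground_truth')
```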