├── .gitignore
├── LICENSE
├── README.md
├── checkpoint
│   └── .gitignore
├── config.py
├── create_billards_data.py
├── figures
│   ├── 2.png
│   └── VIN-example.gif
├── img
│   └── .gitignore
├── load_data.py
├── model.py
├── requirements.txt
├── train.py
├── vin.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 | /vin/
103 | .idea/
104 | *.mat
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Karl Stelzner
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visual-Interaction-Networks
2 | An implementation of DeepMind's Visual Interaction Networks in PyTorch.
3 |
4 | > From just a glance, humans can make rich predictions about the future state of a wide range of physical systems. On the other hand, modern approaches from engineering, robotics, and graphics are often restricted to narrow domains and require direct measurements of the underlying states. We introduce the Visual Interaction Network, a general-purpose model for learning the dynamics of a physical system from raw visual observations. Our model consists of a perceptual front-end based on convolutional neural networks and a dynamics predictor based on interaction networks. Through joint training, the perceptual front-end learns to parse a dynamic visual scene into a set of factored latent object representations. The dynamics predictor learns to roll these states forward in time by computing their interactions and dynamics, producing a predicted physical trajectory of arbitrary length. We found that from just six input video frames the Visual Interaction Network can generate accurate future trajectories of hundreds of time steps on a wide range of physical systems. Our model can also be applied to scenes with invisible objects, inferring their future states from their effects on the visible objects, and can implicitly infer the unknown mass of objects. Our results demonstrate that the perceptual module and the object-based dynamics predictor module can induce factored latent representations that support accurate dynamical predictions. This work opens new opportunities for model-based decision-making and planning from raw sensory observations in complex physical environments.
5 |
6 | [Watters, N., Tacchetti, A., Weber, T., Pascanu, R., Battaglia, P.W., & Zoran, D. (2017). Visual Interaction Networks. CoRR, abs/1706.01433.](https://arxiv.org/abs/1706.01433)
7 |
8 |
9 | ![Example rollout](figures/VIN-example.gif)
10 |
11 |
12 |
13 | ## Architecture
14 |
15 | ![Architecture](figures/2.png)
16 |
17 |
18 |
19 | ### Data
20 | Run `create_billards_data.py` to create a dataset of bouncing billiard balls, or supply your own data as `.mat` files containing images `X` of shape `(episodes, frames, height, width, channels)` and object states `y` of shape `(episodes, frames, objects, 4)` (see `config.py` for the expected file names).
21 |
22 | ### Dependencies
23 | Dependencies can be installed using `pip install -r requirements.txt`. They consist of the following required packages:
24 | ```
25 | Python 3.6
26 | pytorch 0.4
27 | numpy 1.15
28 | scipy 1.1
29 | ```
30 | as well as these optional packages for visualization:
31 | ```
32 | matplotlib 3.0
33 | visdom 0.1
34 | imageio 2.4
35 | ```
36 |
37 | ### Run
38 | - Edit the configuration file (`config.py`) to your needs.
39 | - Supply the data, for instance by running `create_billards_data.py`
40 | - Run `vin.py`
41 |
42 | ### Thank you!
43 | This repository was primarily created by refactoring and fixing
44 | * https://github.com/MrGemy95/visual-interaction-networks-pytorch
45 |
46 | Indirect sources include:
47 | * https://github.com/jaesik817/visual-interaction-networks_tensorflow
48 | * Ilya Sutskever's code for the [billiard data](http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar)
49 |
50 |
51 |
--------------------------------------------------------------------------------
/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | class VinConfig:
2 |
3 | # Directories
4 | img_folder = "./img/" # image folder
5 | traindata = 'billards_balls_training_data.mat'
6 | testdata = 'billards_balls_testing_data.mat'
7 | checkpoint_dir = "./checkpoint/"
8 | log_dir = "./log"
9 |
10 | # Model/training config
11 | visual = True # If False, use states as input instead of images
12 | load = False # Load parameters from checkpoint file
13 | num_visible = 6 # Number of visible frames
14 | num_rollout = 8 # Number of rollout frames
15 | frame_step = 1 # Stepsize when observing frames
16 | batch_size = 100
17 | cl = 16 # state code length per object
18 | discount_factor = 0.9 # discount factor for loss from rollouts
19 |
20 | # Data config
21 | num_episodes = 1000 # The number of episodes
22 | num_frames = 100 # The number of frames per episode
23 | width = 32
24 | height = 32
25 | channels = 3
26 |     num_obj = 3           # the number of objects
27 |
--------------------------------------------------------------------------------
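A quick check of how the settings above interact (a minimal sketch, not part of the repository; it simply reproduces the clip-count arithmetic used in `load_data.py` under the default values shown in `config.py`):

```python
# Sketch: number of overlapping training clips produced by the default VinConfig.
# The formula mirrors the index construction in load_data.VinDataset.
from config import VinConfig

cfg = VinConfig
clips_per_episode = (cfg.num_frames -
                     (cfg.num_visible + cfg.num_rollout) * cfg.frame_step + 1)
total_clips = cfg.num_episodes * clips_per_episode
print(clips_per_episode)  # 87 = 100 - (6 + 8) * 1 + 1
print(total_clips)        # 87000 clips across 1000 episodes
```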
/create_billards_data.py:
--------------------------------------------------------------------------------
1 | """
2 | This script comes from the RTRBM code by Ilya Sutskever from
3 | http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar
4 | """
5 |
6 | from numpy import *
7 | from numpy.random import rand, randn
8 | import scipy.io
9 |
10 | import matplotlib
11 |
12 | matplotlib.use('Agg')
13 | import matplotlib.pyplot as plt
14 |
15 | shape_std = shape
16 |
17 |
18 | def shape(A):
19 | if isinstance(A, ndarray):
20 | return shape_std(A)
21 | else:
22 | return A.shape()
23 |
24 |
25 | size_std = size
26 |
27 |
28 | def size(A):
29 | if isinstance(A, ndarray):
30 | return size_std(A)
31 | else:
32 | return A.size()
33 |
34 |
35 | det = linalg.det
36 |
37 |
38 | def new_speeds(m1, m2, v1, v2):
39 | new_v2 = (2 * m1 * v1 + v2 * (m2 - m1)) / (m1 + m2)
40 | new_v1 = new_v2 + (v2 - v1)
41 | return new_v1, new_v2
42 |
43 |
44 | def norm(x): return sqrt((x ** 2).sum())
45 |
46 |
47 | def sigmoid(x): return 1. / (1. + exp(-x))
48 |
49 |
50 | SIZE = 10
51 |
52 |
53 | # size of bounding box: SIZE X SIZE.
54 |
55 | def bounce_n(T=128, n=2, r=None, m=None):
56 | if r is None:
57 | r = array([1.2] * n)
58 | if m is None:
59 | m = array([1] * n)
60 | # r is to be rather small.
61 | X = zeros((T, n, 2), dtype='float')
62 | y = zeros((T, n, 2), dtype='float')
63 | v = randn(n, 2)
64 | v = v / norm(v) * .5
65 | good_config = False
66 | while not good_config:
67 | x = 2 + rand(n, 2) * 8
68 | good_config = True
69 | for i in range(n):
70 | for z in range(2):
71 | if x[i][z] - r[i] < 0:
72 | good_config = False
73 | if x[i][z] + r[i] > SIZE:
74 | good_config = False
75 |
76 | # that's the main part.
77 | for i in range(n):
78 | for j in range(i):
79 | if norm(x[i] - x[j]) < r[i] + r[j]:
80 | good_config = False
81 |
82 | eps = .5
83 | for t in range(T):
84 | # for how long do we show small simulation
85 |
86 | v_prev = copy(v)
87 |
88 | for i in range(n):
89 | X[t, i] = x[i]
90 | y[t, i] = v[i]
91 |
92 | for mu in range(int(1 / eps)):
93 |
94 | for i in range(n):
95 | x[i] += eps * v[i]
96 |
97 | for i in range(n):
98 | for z in range(2):
99 | if x[i][z] - r[i] < 0:
100 | v[i][z] = abs(v[i][z]) # want positive
101 | if x[i][z] + r[i] > SIZE:
102 | v[i][z] = -abs(v[i][z]) # want negative
103 |
104 | for i in range(n):
105 | for j in range(i):
106 | if norm(x[i] - x[j]) < r[i] + r[j]:
107 | # the bouncing off part:
108 | w = x[i] - x[j]
109 | w = w / norm(w)
110 |
111 | v_i = dot(w.transpose(), v[i])
112 | v_j = dot(w.transpose(), v[j])
113 |
114 | new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j)
115 |
116 | v[i] += w * (new_v_i - v_i)
117 | v[j] += w * (new_v_j - v_j)
118 |
119 | return X, y
120 |
121 |
122 | def ar(x, y, z):
123 | return z / 2 + arange(x, y, z, dtype='float')
124 |
125 |
126 | def draw_image(X, res, r=None):
127 | T, n = shape(X)[0:2]
128 | if r is None:
129 | r = array([1.2] * n)
130 |
131 | A = zeros((T, res, res, 3), dtype='float')
132 |
133 | [I, J] = meshgrid(ar(0, 1, 1. / res) * SIZE, ar(0, 1, 1. / res) * SIZE)
134 |
135 | for t in range(T):
136 | for i in range(n):
137 | A[t, :, :, i] += exp(-(((I - X[t, i, 0]) ** 2 +
138 | (J - X[t, i, 1]) ** 2) /
139 | (r[i] ** 2)) ** 4)
140 |
141 | A[t][A[t] > 1] = 1
142 | return A
143 |
144 |
145 | def bounce_mat(res, n=2, T=128, r=None):
146 | if r is None:
147 | r = array([1.2] * n)
148 | x, y = bounce_n(T, n, r)
149 | A = draw_image(x, res, r)
150 | return A, y
151 |
152 |
153 | def bounce_vec(res, n=2, T=128, r=None, m=None):
154 | if r is None:
155 | r = array([1.2] * n)
156 | x, y = bounce_n(T, n, r, m)
157 | V = draw_image(x, res, r)
158 | y = concatenate((x, y), axis=2)
159 | return V.reshape(T, res, res, 3), y
160 |
161 |
162 | # make sure you have this folder
163 | logdir = './img'
164 |
165 |
166 | def show_sample(V):
167 | T = V.shape[0]
168 | for t in range(T):
169 | plt.imshow(V[t])
170 | # Save it
171 | fname = logdir + '/' + str(t) + '.png'
172 | plt.savefig(fname)
173 |
174 |
175 | if __name__ == "__main__":
176 | res = 32
177 | T = 100
178 | N = 1000
179 | dat = empty((N, T, res, res, 3), dtype=float)
180 | dat_y = empty((N, T, 3, 4), dtype=float)
181 | for i in range(N):
182 | dat[i], dat_y[i] = bounce_vec(res=res, n=3, T=T)
183 | print('training example {} / {}'.format(i, N))
184 | data = dict()
185 | data['X'] = dat
186 | data['y'] = dat_y
187 | scipy.io.savemat('billards_balls_training_data.mat', data)
188 |
189 | N = 200
190 | dat = empty((N, T, res, res, 3), dtype=float)
191 | dat_y = empty((N, T, 3, 4), dtype=float)
192 | for i in range(N):
193 | dat[i], dat_y[i] = bounce_vec(res=res, n=3, T=T)
194 | print('test example {} / {}'.format(i, N))
195 | data = dict()
196 | data['X'] = dat
197 | data['y'] = dat_y
198 | scipy.io.savemat('billards_balls_testing_data.mat', data)
199 |
200 | # show one video
201 | show_sample(dat[0])
202 | print(dat_y[0, :])
203 |
--------------------------------------------------------------------------------
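Once the script has run, the generated files can be inspected as below (a minimal sketch, assuming `billards_balls_training_data.mat` was written to the working directory; the shapes follow the `empty(...)` allocations above, and each state row holds the two positions followed by the two velocities):

```python
# Sketch: load and inspect the generated billiard data.
import scipy.io

data = scipy.io.loadmat('billards_balls_training_data.mat')
images, states = data['X'], data['y']
print(images.shape)  # (1000, 100, 32, 32, 3): episodes, frames, height, width, channels
print(states.shape)  # (1000, 100, 3, 4): episodes, frames, objects, (x, y, vx, vy)
```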
/figures/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stelzner/Visual-Interaction-Networks/aec463df90396506aa6dbe2a31e13e06ac1fe784/figures/2.png
--------------------------------------------------------------------------------
/figures/VIN-example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stelzner/Visual-Interaction-Networks/aec463df90396506aa6dbe2a31e13e06ac1fe784/figures/VIN-example.gif
--------------------------------------------------------------------------------
/img/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from scipy.io import loadmat
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def clips_from_episodes(images, labels, visible_l, rollout_l, step):
8 | """
9 | Rearrange episodic observations into shorter clips
10 | :param images: Episodes of images of shape (n, fr, c, h, w)
11 | :param labels: Episodes of accompanying data for the images (n, fr, obj, d)
12 | :param visible_l: Number of frames in each clip
13 | :param rollout_l: Number of future frames for which labels are returned
14 | :param step: Stepsize for taking frames from the given episodes
15 | :return: A number of shorter clips (_, visible_l, c, h, w),
16 |              the corresponding labels (_, visible_l - 2, obj, d),
17 | and future labels (_, rollout_l, obj, d).
18 | """
19 | (num_episodes, num_frames, height, width, channels) = images.shape
20 | num_obj = labels.shape[-2]
21 |
22 | clips_per_episode = num_frames - (rollout_l + visible_l) * step + 1
23 | num_clips = num_episodes * clips_per_episode
24 |
25 | clips = np.zeros((num_clips, visible_l, height, width, channels))
26 | present_labels = np.zeros((num_clips, visible_l - 2, num_obj, 4))
27 | future_labels = np.zeros((num_clips, rollout_l, num_obj, 4))
28 |
29 | for i in range(num_episodes):
30 | for j in range(clips_per_episode):
31 | clip_idx = i * clips_per_episode + j
32 |
33 | end_visible = j + visible_l * step
34 | end_rollout = end_visible + rollout_l * step
35 |
36 | clips[clip_idx] = images[i, j:end_visible:step]
37 | present_labels[clip_idx] = labels[i, j + 2*step:end_visible:step]
38 | future_labels[clip_idx] = labels[i, end_visible:end_rollout:step]
39 |
40 | # shuffle
41 | perm_idx = np.random.permutation(num_clips)
42 | return clips[perm_idx], present_labels[perm_idx], future_labels[perm_idx]
43 |
44 |
45 | class VinDataset(Dataset):
46 | """Face Landmarks dataset."""
47 |
48 | def __init__(self, config, transform=None, test=False):
49 | self.config = config
50 | self.transform = transform
51 |
52 | if test:
53 | data = loadmat(config.testdata)
54 | else:
55 | data = loadmat(config.traindata)
56 |
57 | self.total_img = data['X'][:config.num_episodes]
58 | # Transpose, as PyTorch images have shape (c, h, w)
59 | self.total_img = np.transpose(self.total_img, (0, 1, 4, 2, 3))
60 | self.total_data = data['y'][:config.num_episodes]
61 |
62 | num_eps, num_frames = self.total_img.shape[0:2]
63 | clips_per_ep = num_frames - ((config.num_visible +
64 | config.num_rollout) *
65 | config.frame_step) + 1
66 |
67 | idx_ep, idx_fr = np.meshgrid(list(range(num_eps)),
68 | list(range(clips_per_ep)))
69 |
70 | self.idxs = np.reshape(np.stack([idx_ep, idx_fr], 2), (-1, 2))
71 |
72 | def __len__(self):
73 | return len(self.idxs)
74 |
75 | def __getitem__(self, idx):
76 | conf = self.config
77 | step = conf.frame_step
78 |
79 | i, j = self.idxs[idx, 0], self.idxs[idx, 1]
80 |
81 | end_visible = j + conf.num_visible * step
82 | end_rollout = end_visible + conf.num_rollout * step
83 | image = self.total_img[i, j:end_visible:step]
84 | present = self.total_data[i, j + 2 * step:end_visible:step]
85 | future = self.total_data[i, end_visible:end_rollout:step]
86 |
87 | sample = {'image': torch.from_numpy(image),
88 | 'future_labels': torch.from_numpy(future),
89 | 'present_labels': torch.from_numpy(present)}
90 |
91 | return sample
92 |
--------------------------------------------------------------------------------
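A minimal usage sketch (assuming the training `.mat` file from `create_billards_data.py` exists and `config.py` is unchanged): each sample pairs the six visible frames with the four present states that can be decoded from them and the eight future states used as rollout targets.

```python
# Sketch: draw one batch from VinDataset and check the tensor shapes.
from torch.utils.data import DataLoader

from config import VinConfig
from load_data import VinDataset

dataset = VinDataset(VinConfig)
loader = DataLoader(dataset, batch_size=VinConfig.batch_size, drop_last=True)
batch = next(iter(loader))
print(batch['image'].shape)           # (100, 6, 3, 32, 32): visible frames
print(batch['present_labels'].shape)  # (100, 4, 3, 4): states for frames 3..6
print(batch['future_labels'].shape)   # (100, 8, 3, 4): rollout targets
```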
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Net(nn.Module):
9 | def __init__(self, config):
10 | super(Net, self).__init__()
11 | self.config = config
12 | self.pool = nn.MaxPool2d(2, 2)
13 | self.x_coord, self.y_coord = self.construct_coord_dims()
14 | cl = config.cl
15 |
16 | # Visual Encoder Modules
17 | self.conv1 = nn.Conv2d(config.channels * 2 + 2, 16, 3, padding=1)
18 | self.conv2 = nn.Conv2d(16, 16, 3, padding=1)
19 | self.conv3 = nn.Conv2d(16, 16, 3, padding=1)
20 | self.conv4 = nn.Conv2d(16, 32, 3, padding=1)
21 | self.conv5 = nn.Conv2d(32, 32, 3, padding=1)
22 | # shared linear layer to get pair codes of shape N_obj*cl
23 | self.fc1 = nn.Linear(32, 3 * cl)
24 |
25 | # shared MLP to encode pairs of pair codes as state codes N_obj*cl
26 | self.fc2 = nn.Linear(cl * 2, cl)
27 | self.fc3 = nn.Linear(cl, cl)
28 | # end of visual encoder
29 |
30 | # Interaction Net Core Modules
31 | # Self-dynamics MLP
32 | self.self_cores = nn.ModuleList()
33 | for i in range(3):
34 | self.self_cores.append(nn.ModuleList())
35 | self.self_cores[i].append(nn.Linear(cl, cl).double().cuda())
36 | self.self_cores[i].append(nn.Linear(cl, cl).double().cuda())
37 |
38 | # Relation MLP
39 | self.rel_cores = nn.ModuleList()
40 | for i in range(3):
41 | self.rel_cores.append(nn.ModuleList())
42 | self.rel_cores[i].append(nn.Linear(cl * 2, 2 * cl).double().cuda())
43 | self.rel_cores[i].append(nn.Linear(2 * cl, cl).double().cuda())
44 | self.rel_cores[i].append(nn.Linear(cl, cl).double().cuda())
45 |
46 | # Affector MLP
47 | self.affector = nn.ModuleList()
48 | for i in range(3):
49 | self.affector.append(nn.ModuleList())
50 | self.affector[i].append(nn.Linear(cl, cl).double().cuda())
51 | self.affector[i].append(nn.Linear(cl, cl).double().cuda())
52 | self.affector[i].append(nn.Linear(cl, cl).double().cuda())
53 |
54 | # Core output MLP
55 | self.out = nn.ModuleList()
56 | for i in range(3):
57 | self.out.append(nn.ModuleList())
58 | self.out[i].append(nn.Linear(cl + cl, cl).double().cuda())
59 | self.out[i].append(nn.Linear(cl, cl).double().cuda())
60 |
61 | # Aggregator MLP for aggregating core predictions
62 | self.aggregator1 = nn.Linear(cl * 3, cl)
63 | self.aggregator2 = nn.Linear(cl, cl)
64 |
65 | # decoder mapping state codes to actual states
66 | self.state_decoder = nn.Linear(cl, 4)
67 | # encoder for the non-visual case
68 | self.state_encoder = nn.Linear(4, cl)
69 |
70 | def construct_coord_dims(self):
71 | """
72 | Build a meshgrid of x, y coordinates to be used as additional channels
73 | """
74 | x = np.linspace(0, 1, self.config.width)
75 | y = np.linspace(0, 1, self.config.height)
76 | xv, yv = np.meshgrid(x, y)
77 | xv = np.reshape(xv, [1, 1, self.config.height, self.config.width])
78 | yv = np.reshape(yv, [1, 1, self.config.height, self.config.width])
79 | x_coord = Variable(torch.from_numpy(xv)).cuda()
80 | y_coord = Variable(torch.from_numpy(yv)).cuda()
81 | x_coord = x_coord.expand(self.config.batch_size * 5, -1, -1, -1)
82 | y_coord = y_coord.expand(self.config.batch_size * 5, -1, -1, -1)
83 | return x_coord, y_coord
84 |
85 | def core(self, s, core_idx):
86 | """
87 | Applies an interaction network core
88 | :param s: A state code of shape (n, o, cl)
89 | :param core_idx: The index of the set of parameters to apply (0, 1, 2)
90 | :return: Prediction of a future state code (n, o, cl)
91 | """
92 | objects = [s[:, i] for i in range(3)]
93 |
94 | self_sd_h1 = F.relu(self.self_cores[core_idx][0](s))
95 | self_dynamic = self.self_cores[core_idx][1](self_sd_h1) + self_sd_h1
96 |
97 | rel_combination = []
98 | for i in range(6):
99 | row_idx = i // 2
100 | # pick the two other objects
101 | col_idx = (row_idx + 1 + (i % 2)) % 3
102 | rel_combination.append(
103 | torch.cat([objects[row_idx], objects[col_idx]], 1))
104 | # 6 combinations of the 3 objects, (n, 6, 2*cl)
105 | rel_combination = torch.stack(rel_combination, 1)
106 | rel_sd_h1 = F.relu(self.rel_cores[core_idx][0](rel_combination))
107 | rel_sd_h2 = F.relu(self.rel_cores[core_idx][1](rel_sd_h1))
108 | rel_factors = self.rel_cores[core_idx][2](rel_sd_h2) + rel_sd_h2
109 | obj1 = rel_factors[:, 0] + rel_factors[:, 1]
110 | obj2 = rel_factors[:, 2] + rel_factors[:, 3]
111 | obj3 = rel_factors[:, 4] + rel_factors[:, 5]
112 | # relational dynamics per object, (n, o, cl)
113 | rel_dynamic = torch.stack([obj1, obj2, obj3], 1)
114 | # total dynamics
115 | dynamic_pred = self_dynamic + rel_dynamic
116 |
117 | aff1 = F.relu(self.affector[core_idx][0](dynamic_pred))
118 | aff2 = F.relu(self.affector[core_idx][1](aff1) + aff1)
119 | aff3 = self.affector[core_idx][2](aff2)
120 |
121 | aff_s = torch.cat([aff3, s], 2)
122 | out1 = F.relu(self.out[core_idx][0](aff_s))
123 | out2 = self.out[core_idx][1](out1) + out1
124 | return out2
125 |
126 | def frames_to_states(self, frames):
127 | """
128 | Apply visual encoder
129 | :param frames: Groups of six input frames of shape (n, 6, c, w, h)
130 | :return: State codes of shape (n, 4, o, cl)
131 | """
132 | batch_size = self.config.batch_size
133 | cl = self.config.cl
134 | num_obj = self.config.num_obj
135 |
136 | pairs = []
137 | for i in range(frames.shape[1] - 1):
138 | # pair consecutive frames (n, 2c, w, h)
139 | pair = torch.cat((frames[:, i], frames[:, i+1]), 1)
140 | pairs.append(pair)
141 |
142 | num_pairs = len(pairs)
143 | pairs = torch.cat(pairs, 0)
144 | # add coord channels (n * num_pairs, 2c + 2, w, h)
145 | pairs = torch.cat([pairs, self.x_coord, self.y_coord], dim=1)
146 |
147 | # apply ConvNet to pairs
148 | ve_h1 = F.relu(self.conv1(pairs))
149 | ve_h1 = self.pool(ve_h1)
150 | ve_h2 = F.relu(self.conv2(ve_h1))
151 | ve_h2 = self.pool(ve_h2)
152 | ve_h3 = F.relu(self.conv3(ve_h2))
153 | ve_h3 = self.pool(ve_h3)
154 | ve_h4 = F.relu(self.conv4(ve_h3))
155 | ve_h4 = self.pool(ve_h4)
156 | ve_h5 = F.relu(self.conv5(ve_h4))
157 | ve_h5 = self.pool(ve_h5)
158 |
159 | # pooled to 1x1, 32 channels: (n * num_pairs, 32)
160 | encoded_pairs = torch.squeeze(ve_h5)
161 | # final pair encoding (n * num_pairs, o, cl)
162 | encoded_pairs = self.fc1(encoded_pairs)
163 | encoded_pairs = encoded_pairs.view(batch_size * num_pairs, num_obj, cl)
164 | # chunk pairs encoding, each is (n, o, cl)
165 | encoded_pairs = torch.chunk(encoded_pairs, num_pairs)
166 |
167 | triples = []
168 | for i in range(num_pairs - 1):
169 | # pair consecutive pairs to obtain encodings for triples
170 | triple = torch.cat([encoded_pairs[i], encoded_pairs[i+1]], 2)
171 | triples.append(triple)
172 |
173 | # the triples together, i.e. (n, num_pairs - 1, o, 2 * cl)
174 | triples = torch.stack(triples, 1)
175 | # apply MLP to triples
176 | shared_h1 = F.relu(self.fc2(triples))
177 | state_codes = self.fc3(shared_h1)
178 | return state_codes
179 |
180 | def forward(self, x, num_rollout=8, visual=True):
181 | """
182 | Rollout a given sequence of observations using the model
183 | :param x: The given sequence of observations.
184 | If visual is True, it should be images of shape (n, 6, c, w, h),
185 | otherwise states of shape (n, 4, o, 4).
186 | :param num_rollout: The number of future states to be predicted
187 | :param visual: Boolean determining the type of input
188 | :return: rollout_states: predicted future states (n, roll_len, o, 4)
189 | present_states: predicted states at the time of
190 | the given observations, (n, 4, o, 4)
191 | """
192 | # get encoded states
193 | if visual:
194 | state_codes = self.frames_to_states(x)
195 | else:
196 | state_codes = self.state_encoder(x)
197 | # the 4 state codes (n, o, cl)
198 | s1, s2, s3, s4 = [state_codes[:, i] for i in range(4)]
199 | rollouts = []
200 | for i in range(num_rollout):
201 | # use cores to predict next state using delta_t = 1, 2, 4
202 | c1 = self.core(s4, 0)
203 | c2 = self.core(s3, 1)
204 | c4 = self.core(s1, 2)
205 | all_c = torch.cat([c1, c2, c4], 2)
206 | aggregator1 = F.relu(self.aggregator1(all_c))
207 | state_prediction = self.aggregator2(aggregator1)
208 | rollouts.append(state_prediction)
209 | s1, s2, s3, s4 = s2, s3, s4, state_prediction
210 | rollouts = torch.stack(rollouts, 1)
211 |
212 | present_states = self.state_decoder(state_codes)
213 | rollout_states = self.state_decoder(rollouts)
214 |
215 | return rollout_states, present_states
216 |
--------------------------------------------------------------------------------
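A minimal smoke test of the forward pass (a sketch, assuming a CUDA device is available, since the model hard-codes `.cuda()` calls and double precision; the dummy batch must match `config.batch_size` because the coordinate channels are pre-expanded to that size):

```python
# Sketch: run Net on random frames and check the predicted state shapes.
import torch

from config import VinConfig
from model import Net

net = Net(VinConfig).double().cuda()
frames = torch.rand(VinConfig.batch_size, 6, 3, 32, 32, dtype=torch.double).cuda()
rollout, present = net(frames, num_rollout=VinConfig.num_rollout, visual=True)
print(rollout.shape)  # (100, 8, 3, 4): eight predicted future states per clip
print(present.shape)  # (100, 4, 3, 4): decoded states for the visible frames
```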
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Requires Python 3.6 or newer
2 | torch>=0.4
3 | numpy>=1.0
4 | scipy>=1.1
5 | matplotlib>=3.0
6 | visdom>=0.1
7 | imageio>=2.4
8 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | from torch import nn
7 | from torch.utils.data import DataLoader
8 |
9 | from load_data import VinDataset
10 | from visualize import plot_positions, animate
11 |
12 |
13 | class Trainer:
14 | def __init__(self, config, net):
15 | self.net = net
16 | self.params = net.parameters()
17 | self.initial_values = {}
18 | self.config = config
19 |
20 | train_dataset = VinDataset(self.config)
21 | self.dataloader = DataLoader(train_dataset,
22 | batch_size=self.config.batch_size,
23 | shuffle=True,
24 | num_workers=4,
25 | drop_last=True)
26 |
27 | self.test_dataset = VinDataset(self.config,
28 | test=True)
29 | self.test_dataloader = DataLoader(self.test_dataset,
30 | batch_size=self.config.batch_size,
31 | shuffle=True,
32 | num_workers=4,
33 | drop_last=True)
34 |
35 | self.optimizer = optim.Adam(self.net.parameters(), lr=0.0005)
36 | if config.load:
37 | self.load()
38 |
39 | def save(self):
40 | torch.save(self.net.state_dict(), os.path.join(
41 | self.config.checkpoint_dir, "checkpoint"))
42 | print('Parameters saved')
43 |
44 | def load(self):
45 | try:
46 | self.net.load_state_dict(torch.load(os.path.join(
47 | self.config.checkpoint_dir, "checkpoint")))
48 | print('Parameters loaded')
49 |         except (FileNotFoundError, RuntimeError):
50 | print('Loading parameters failed, training from scratch...')
51 |
52 | def compute_loss(self, present_labels, future_labels, recons, preds):
53 | loss = nn.MSELoss()
54 | df = self.config.discount_factor
55 | pred_loss = 0.0
56 | for delta_t in range(0, self.config.num_rollout):
57 | pred_loss += (df ** (delta_t + 1)) * \
58 | loss(preds[:, delta_t], future_labels[:, delta_t])
59 |
60 | recon_loss = loss(recons, present_labels)
61 | total_loss = pred_loss + recon_loss
62 |
63 | return total_loss, pred_loss, recon_loss
64 |
65 | def train(self):
66 | step_counter = 0
67 | num_rollout = self.config.num_rollout
68 | for epoch in range(100):
69 | print("testing................")
70 | self.test()
71 | for i, data in enumerate(self.dataloader, 0):
72 | step_counter += 1
73 | images = data['image']
74 | present_labels = data['present_labels']
75 | future_labels = data['future_labels']
76 |
77 | images = images.cuda()
78 | present_labels = present_labels.cuda()
79 | future_labels = future_labels.cuda()
80 |
81 | if self.config.visual:
82 | vin_input = images
83 | else:
84 | vin_input = present_labels
85 |
86 | self.optimizer.zero_grad()
87 | state_pred, state_recon = self.net(vin_input,
88 | num_rollout=num_rollout,
89 | visual=self.config.visual)
90 |
91 | total_loss, pred_loss, recon_loss = \
92 | self.compute_loss(present_labels, future_labels,
93 | state_recon, state_pred)
94 |
95 | total_loss.backward()
96 | self.optimizer.step()
97 |
98 | # print loss
99 | if step_counter % 20 == 0:
100 | print('{:5d} {:5f} {:5f} {:5f}'.format(step_counter,
101 | total_loss.item(),
102 | recon_loss.item(),
103 | pred_loss.item()))
104 | # Draw example
105 | if step_counter % 200 == 0:
106 | real = torch.cat([present_labels[0], future_labels[0]])
107 | simu = torch.cat([state_recon[0], state_pred[0]]).detach()
108 | plot_positions(real, self.config.img_folder, 'real')
109 | plot_positions(simu, self.config.img_folder, 'rollout')
110 | # Save parameters
111 | if (step_counter + 1) % 1000 == 0:
112 | self.save()
113 |
114 | print("epoch ", epoch, " Finished")
115 | print('Finished Training')
116 |
117 | def test(self):
118 | total_loss = 0.0
119 | for i, data in enumerate(self.test_dataloader, 0):
120 | images, future_labels, present_labels = \
121 | data['image'], data['future_labels'], data['present_labels']
122 |
123 |             images = images.cuda()
124 |             present_labels = present_labels.cuda()
125 |             future_labels = future_labels.cuda()
126 |
127 |             vin_input = images if self.config.visual else present_labels
128 |
129 |             pred, recon = self.net(vin_input,
130 |                                    num_rollout=self.config.num_rollout,
131 |                                    visual=self.config.visual)
132 |
133 |             batch_loss, pred_loss, recon_loss = \
134 |                 self.compute_loss(present_labels, future_labels,
135 |                                   recon, pred)
136 |             total_loss += batch_loss.item()
137 |         print('mean test loss {:5f}'.format(total_loss / len(self.test_dataloader)))
138 |
139 | # Create one long rollout and save it as an animated GIF
140 | total_images = self.test_dataset.total_img
141 | total_labels = self.test_dataset.total_data
142 | step = self.config.frame_step
143 | visible = self.config.num_visible
144 | batch_size = self.config.batch_size
145 |
146 | long_rollout_length = self.config.num_frames // step - visible
147 |
148 | if self.config.visual:
149 | vin_input = total_images[:batch_size, :visible*step:step]
150 | else:
151 | vin_input = total_labels[:batch_size, 2*step:visible*step:step]
152 |
153 | vin_input = torch.tensor(vin_input).cuda()
154 |
155 | pred, recon = self.net(vin_input, long_rollout_length,
156 | visual=self.config.visual)
157 |
158 | simu_rollout = pred[0].detach().cpu().numpy()
159 | simu_recon = recon[0].detach().cpu().numpy()
160 | simu = np.concatenate((simu_recon, simu_rollout), axis=0)
161 |
162 | # Saving
163 | print("Make GIFs")
164 | animate(total_labels[0, 2:], self.config.img_folder, 'real')
165 | animate(simu, self.config.img_folder, 'rollout')
166 | print("Done")
167 |
--------------------------------------------------------------------------------
/vin.py:
--------------------------------------------------------------------------------
1 | from model import Net
2 | from config import VinConfig
3 | from train import Trainer
4 |
5 |
6 | def main():
7 | net = Net(VinConfig)
8 | net = net.cuda()
9 | net = net.double()
10 | trainer = Trainer(VinConfig, net)
11 | trainer.train()
12 |
13 |
14 | if __name__ == '__main__':
15 | main()
16 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import visdom
4 | import numpy as np
5 | import imageio
6 |
7 | vis = visdom.Visdom()
8 |
9 |
10 | def plot_positions(xy, img_folder, prefix, save=True, size=10):
11 | if not os.path.exists(img_folder):
12 | os.makedirs(img_folder)
13 | fig_num = len(xy)
14 | mydpi = 100
15 | fig = plt.figure(figsize=(128/mydpi, 128/mydpi))
16 | plt.xlim(0, 10)
17 | plt.ylim(0, 10)
18 | plt.xticks([])
19 | plt.yticks([])
20 |
21 | color = ['r', 'b', 'g', 'k', 'y', 'm', 'c']
22 | for i in range(fig_num):
23 | for j in range(len(xy[0])):
24 | plt.scatter(xy[i, j, 1], xy[i, j, 0],
25 | c=color[j % len(color)], s=size, alpha=(i+1)/fig_num)
26 |
27 | if save:
28 | fig.savefig(img_folder+prefix+".pdf", dpi=mydpi, transparent=True)
29 | vis.matplot(fig)
30 |
31 | fig.canvas.draw()
32 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
33 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
34 | plt.close()
35 | return image
36 |
37 |
38 | def animate(states, img_folder, prefix):
39 | images = []
40 | for i in range(len(states)):
41 | images.append(plot_positions(states[i:i + 1], img_folder,
42 | prefix, save=False, size=270))
43 | imageio.mimsave(img_folder+prefix+'.gif', images, fps=24)
44 |
--------------------------------------------------------------------------------
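A minimal sketch of using these helpers on ground-truth states (assuming the test `.mat` file from `create_billards_data.py` exists; note that importing `visualize` creates a `visdom.Visdom()` client, so a visdom server should be reachable when `save=True` is used):

```python
# Sketch: render the first test episode with the helpers above.
from scipy.io import loadmat

from visualize import animate, plot_positions

states = loadmat('billards_balls_testing_data.mat')['y'][0]  # (100, 3, 4)
plot_positions(states[:10], './img/', 'real_start')  # overlay of the first 10 frames
animate(states, './img/', 'real_episode')             # writes ./img/real_episode.gif
```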