├── physics_engine.pyc ├── true_gravity_object_2.mp4 ├── true_gravity_object_3.mp4 ├── true_gravity_object_6.mp4 ├── figures ├── gravity_object_2.PNG ├── interaction_net.PNG └── interaction_net_2.PNG ├── modeling_gravity_object_2.mp4 ├── modeling_gravity_object_3.mp4 ├── modeling_gravity_object_6.mp4 ├── LICENSE ├── physics_engine.py ├── README.md └── interaction_network.py /physics_engine.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/physics_engine.pyc -------------------------------------------------------------------------------- /true_gravity_object_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/true_gravity_object_2.mp4 -------------------------------------------------------------------------------- /true_gravity_object_3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/true_gravity_object_3.mp4 -------------------------------------------------------------------------------- /true_gravity_object_6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/true_gravity_object_6.mp4 -------------------------------------------------------------------------------- /figures/gravity_object_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/figures/gravity_object_2.PNG -------------------------------------------------------------------------------- /figures/interaction_net.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/figures/interaction_net.PNG -------------------------------------------------------------------------------- /figures/interaction_net_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/figures/interaction_net_2.PNG -------------------------------------------------------------------------------- /modeling_gravity_object_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/modeling_gravity_object_2.mp4 -------------------------------------------------------------------------------- /modeling_gravity_object_3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/modeling_gravity_object_3.mp4 -------------------------------------------------------------------------------- /modeling_gravity_object_6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/Interaction-networks_tensorflow/HEAD/modeling_gravity_object_6.mp4 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Jaesik Yoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /physics_engine.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import os 8 | import numpy as np 9 | import time 10 | from math import sin, cos, radians, pi 11 | import matplotlib 12 | matplotlib.use("Agg") 13 | import matplotlib.pyplot as plt 14 | import matplotlib.animation as manimation 15 | import cv2 16 | 17 | # 1000 time steps 18 | total_state=1000; 19 | # 5 features on the state [mass,x,y,x_vel,y_vel] 20 | fea_num=5; 21 | # G 22 | #G = 6.67428e-11; 23 | G=10**5; 24 | # time step 25 | diff_t=0.001; 26 | 27 | def init(total_state,n_body,fea_num,orbit): 28 | data=np.zeros((total_state,n_body,fea_num),dtype=float); 29 | if(orbit): 30 | data[0][0][0]=100; 31 | data[0][0][1:5]=0.0; 32 | for i in range(1,n_body): 33 | data[0][i][0]=np.random.rand()*8.98+0.02; 34 | distance=np.random.rand()*90.0+10.0; 35 | theta=np.random.rand()*360; 36 | theta_rad = pi/2 - radians(theta); 37 | data[0][i][1]=distance*cos(theta_rad); 38 | data[0][i][2]=distance*sin(theta_rad); 39 | 
data[0][i][3]=-1*data[0][i][2]/norm(data[0][i][1:3])*(G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000; 40 | data[0][i][4]=data[0][i][1]/norm(data[0][i][1:3])*(G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000; 41 | #data[0][i][3]=np.random.rand()*10.0-5.0; 42 | #data[0][i][4]=np.random.rand()*10.0-5.0; 43 | else: 44 | for i in range(n_body): 45 | data[0][i][0]=np.random.rand()*8.98+0.02; 46 | distance=np.random.rand()*90.0+10.0; 47 | theta=np.random.rand()*360; 48 | theta_rad = pi/2 - radians(theta); 49 | data[0][i][1]=distance*cos(theta_rad); 50 | data[0][i][2]=distance*sin(theta_rad); 51 | data[0][i][3]=np.random.rand()*6.0-3.0; 52 | data[0][i][4]=np.random.rand()*6.0-3.0; 53 | return data; 54 | 55 | def norm(x): 56 | return np.sqrt(np.sum(x**2)); 57 | 58 | def get_f(reciever,sender): 59 | diff=sender[1:3]-reciever[1:3]; 60 | distance=norm(diff); 61 | if(distance<1): 62 | distance=1; 63 | return G*reciever[0]*sender[0]/(distance**3)*diff; 64 | 65 | def calc(cur_state,n_body): 66 | next_state=np.zeros((n_body,fea_num),dtype=float); 67 | f_mat=np.zeros((n_body,n_body,2),dtype=float); 68 | f_sum=np.zeros((n_body,2),dtype=float); 69 | acc=np.zeros((n_body,2),dtype=float); 70 | for i in range(n_body): 71 | for j in range(i+1,n_body): 72 | if(j!=i): 73 | f=get_f(cur_state[i][:3],cur_state[j][:3]); 74 | f_mat[i,j]+=f; 75 | f_mat[j,i]-=f; 76 | f_sum[i]=np.sum(f_mat[i],axis=0); 77 | acc[i]=f_sum[i]/cur_state[i][0]; 78 | next_state[i][0]=cur_state[i][0]; 79 | next_state[i][3:5]=cur_state[i][3:5]+acc[i]*diff_t; 80 | next_state[i][1:3]=cur_state[i][1:3]+next_state[i][3:5]*diff_t; 81 | return next_state; 82 | 83 | def gen(n_body,orbit): 84 | # initialization on just first state 85 | data=init(total_state,n_body,fea_num,orbit); 86 | for i in range(1,total_state): 87 | data[i]=calc(data[i-1],n_body); 88 | return data; 89 | 90 | def make_video(xy,filename): 91 | os.system("rm -rf pics/*"); 92 | FFMpegWriter = manimation.writers['ffmpeg'] 93 | metadata = 
dict(title='Movie Test', artist='Matplotlib', 94 | comment='Movie support!') 95 | writer = FFMpegWriter(fps=15, metadata=metadata) 96 | fig = plt.figure() 97 | plt.xlim(-200, 200) 98 | plt.ylim(-200, 200) 99 | fig_num=len(xy); 100 | color=['ro','bo','go','ko','yo','mo','co']; 101 | with writer.saving(fig, filename, len(xy)): 102 | for i in range(len(xy)): 103 | for j in range(len(xy[0])): 104 | plt.plot(xy[i,j,1],xy[i,j,0],color[j%len(color)]); 105 | writer.grab_frame(); 106 | 107 | if __name__=='__main__': 108 | data=gen(3,True); 109 | xy=data[:,:,1:3]; 110 | make_video(xy,"test.mp4"); 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Interaction Networks 2 | ==================== 3 | 4 | Tensorflow Implementation of Interaction Networks from Google Deepmind. 5 | 6 | Implementation is on Tensorflow r1.2. 7 | 8 | https://arxiv.org/abs/1612.00222 9 | 10 | "Reasoning about objects, relations, and physics is central to human intelligence, and 11 | a key goal of artificial intelligence. Here we introduce the interaction network, a 12 | model which can reason about how objects in complex systems interact, supporting 13 | dynamical predictions, as well as inferences about the abstract properties of the 14 | system." 15 | From the paper 16 | 17 | ![alt tag](https://github.com/jaesik817/Interaction-networks_tensorflow/blob/master/figures/interaction_net.PNG) 18 | 19 | ![alt tag](https://github.com/jaesik817/Interaction-networks_tensorflow/blob/master/figures/interaction_net_2.PNG) 20 | 21 | N-body Simulation 22 | -------------------- 23 | 24 | ` 25 | python interaction_network.py 26 | ` 27 | 28 | ### Data 29 | Data are gathered from implemented physics engine (physics_engine.py), which by using given the number of objects initializes and processes random weights[0.2-9kg], distance[10-100m] and angle[0-360] same as the paper settings. 
30 | The orbit system is implemented with a really big 0-th object at the center (100 kg, same as the paper), and the velocities of the other objects are initialized so as to support a stable orbit. 31 | Object state values of the non-orbit system are randomly initialized as above. 32 | 33 | Before training, the data (not the labels!) are normalized as described below (same as the paper). 34 | The median value is mapped to 0, and the 5% and 95% percentile values are mapped to -1 and 1, respectively. 35 | This normalization is applied separately to mass, positions and velocities. 36 | 37 | ### Settings 38 | Almost all settings are the same as in the paper, which describes them clearly, except for the points below. 39 | 40 | For the state, there is no description, so I used a 5-D vector [mass,x_position,y_position,x_velocity,y_velocity]. 41 | For R_a and X, there is no description for the n-body test, thus I used zero-arrays with D_R=1, D_X=1. 42 | For D_P, the descriptions on page 6 differ (for f_O, D_P=2, and for \phi_A, D_P=10); in my implementation, I used D_P=2 (x and y velocities). 43 | (I checked with the 1st author of the paper, Peter W. Battaglia: D_P is used as 10 for estimating potential energy, and 2 for estimating the next state.) 44 | For b_k, the implementation description says to concatenate OR_r and OR_s; however, the model architecture describes it as the difference vector between them. In my implementation, I used the difference vector. 45 | For the input of function a, the implementation description says to use the object matrix O; however, the model architecture describes it as using just the velocities as input. In my implementation, I just used the velocities as input. 46 | 47 | Except for the above three things, the other settings are the same as in the paper, as follows. 48 | 49 | For the \phi_R function, a 4-hidden-layer MLP with 150 nodes, ReLU activation for the hidden layers, a linear output layer and D_E=50 are used. 50 | For the \phi_O function, a 1-hidden-layer MLP with 100 nodes, ReLU for the hidden layer, a linear output layer and D_P=2 are used.
51 | Adam Optimizer is used with 0.001 learning rate. 52 | L2-regularization is used for matrix E and all parameters (lambda values for each regularization are 0.001). 53 | 54 | I generated 10 samples, which have 1000 frames, and 90%/10% data are used for training and validation. Max epoch is 40,000. For qualititive measure, I newly generated 1 sample and made video files for preidiction from model and true ones. 55 | 56 | I did not use Random noise when traning and balancing in batch. 57 | 58 | ### Results 59 | 60 | ![alt tag](https://github.com/jaesik817/Interaction-networks_tensorflow/blob/master/figures/gravity_object_2.PNG) 61 | 62 | The experiments are 2-object,3-object and 6-object ones. 63 | 64 | Above MSE graph is 2-object experiment one. 65 | 66 | The validation velocities MSE has been saturated to about 0.2~0.3,10 and 20,000 for 2-object, 3-object and 6-object, and video generated new data (not training ones!) is quitely good. 67 | -------------------------------------------------------------------------------- /interaction_network.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import copy 8 | import tensorflow as tf 9 | from sklearn.cluster import KMeans 10 | 11 | import numpy as np 12 | import time 13 | from physics_engine import gen, make_video 14 | 15 | FLAGS = None 16 | 17 | def variable_summaries(var,idx): 18 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 19 | with tf.name_scope('summaries_'+str(idx)): 20 | mean = tf.reduce_mean(var) 21 | tf.summary.scalar('mean', mean) 22 | with tf.name_scope('stddev'): 23 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 24 | tf.summary.scalar('stddev', stddev) 25 | tf.summary.scalar('max', tf.reduce_max(var)) 26 | tf.summary.scalar('min', tf.reduce_min(var)) 27 | 
tf.summary.histogram('histogram', var) 28 | 29 | def m(O,Rr,Rs,Ra): 30 | return tf.concat([(tf.matmul(O,Rr)-tf.matmul(O,Rs)),Ra],1); 31 | #return tf.concat([tf.matmul(O,Rr),tf.matmul(O,Rs),Ra],1); 32 | 33 | 34 | def phi_R(B): 35 | h_size=150; 36 | B_trans=tf.transpose(B,[0,2,1]); 37 | B_trans=tf.reshape(B_trans,[-1,(FLAGS.Ds+FLAGS.Dr)]); 38 | w1 = tf.Variable(tf.truncated_normal([(FLAGS.Ds+FLAGS.Dr), h_size], stddev=0.1), name="r_w1", dtype=tf.float32); 39 | b1 = tf.Variable(tf.zeros([h_size]), name="r_b1", dtype=tf.float32); 40 | h1 = tf.nn.relu(tf.matmul(B_trans, w1) + b1); 41 | w2 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w2", dtype=tf.float32); 42 | b2 = tf.Variable(tf.zeros([h_size]), name="r_b2", dtype=tf.float32); 43 | h2 = tf.nn.relu(tf.matmul(h1, w2) + b2); 44 | w3 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w3", dtype=tf.float32); 45 | b3 = tf.Variable(tf.zeros([h_size]), name="r_b3", dtype=tf.float32); 46 | h3 = tf.nn.relu(tf.matmul(h2, w3) + b3); 47 | w4 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w4", dtype=tf.float32); 48 | b4 = tf.Variable(tf.zeros([h_size]), name="r_b4", dtype=tf.float32); 49 | h4 = tf.nn.relu(tf.matmul(h3, w4) + b4); 50 | w5 = tf.Variable(tf.truncated_normal([h_size, FLAGS.De], stddev=0.1), name="r_w5", dtype=tf.float32); 51 | b5 = tf.Variable(tf.zeros([FLAGS.De]), name="r_b5", dtype=tf.float32); 52 | h5 = tf.matmul(h4, w5) + b5; 53 | h5_trans=tf.reshape(h5,[-1,FLAGS.Nr,FLAGS.De]); 54 | h5_trans=tf.transpose(h5_trans,[0,2,1]); 55 | return(h5_trans); 56 | 57 | def a(O,Rr,X,E): 58 | E_bar=tf.matmul(E,tf.transpose(Rr,[0,2,1])); 59 | O_2=tf.stack(tf.unstack(O,FLAGS.Ds,1)[3:5],1); 60 | return (tf.concat([O_2,X,E_bar],1)); 61 | #return (tf.concat([O,X,E_bar],1)); 62 | 63 | def phi_O(C): 64 | h_size=100; 65 | C_trans=tf.transpose(C,[0,2,1]); 66 | C_trans=tf.reshape(C_trans,[-1,(2+FLAGS.Dx+FLAGS.De)]); 67 | w1 = 
tf.Variable(tf.truncated_normal([(2+FLAGS.Dx+FLAGS.De), h_size], stddev=0.1), name="o_w1", dtype=tf.float32); 68 | b1 = tf.Variable(tf.zeros([h_size]), name="o_b1", dtype=tf.float32); 69 | h1 = tf.nn.relu(tf.matmul(C_trans, w1) + b1); 70 | w2 = tf.Variable(tf.truncated_normal([h_size, FLAGS.Dp], stddev=0.1), name="o_w2", dtype=tf.float32); 71 | b2 = tf.Variable(tf.zeros([FLAGS.Dp]), name="o_b2", dtype=tf.float32); 72 | h2 = tf.matmul(h1, w2) + b2; 73 | h2_trans=tf.reshape(h2,[-1,FLAGS.No,FLAGS.Dp]); 74 | h2_trans=tf.transpose(h2_trans,[0,2,1]); 75 | return(h2_trans); 76 | 77 | def phi_A(P): 78 | h_size=25; 79 | p_bar=tf.reduce_sum(P,2); 80 | w1 = tf.Variable(tf.truncated_normal([FLAGS.Dp, h_size], stddev=0.1), name="a_w1", dtype=tf.float32); 81 | b1 = tf.Variable(tf.zeros([h_size]), name="a_b1", dtype=tf.float32); 82 | h1 = tf.nn.relu(tf.matmul(p_bar, w1) + b1); 83 | w2 = tf.Variable(tf.truncated_normal([h_size, FLAGS.Da], stddev=0.1), name="a_w2", dtype=tf.float32); 84 | b2 = tf.Variable(tf.zeros([FLAGS.Da]), name="a_b2", dtype=tf.float32); 85 | h2 = tf.matmul(h1, w2) + b2; 86 | return(h2); 87 | 88 | def train(): 89 | 90 | # Object Matrix 91 | O = tf.placeholder(tf.float32, [None,FLAGS.Ds,FLAGS.No], name="O"); 92 | # Relation Matrics R= 93 | Rr = tf.placeholder(tf.float32, [None,FLAGS.No,FLAGS.Nr], name="Rr"); 94 | Rs = tf.placeholder(tf.float32, [None,FLAGS.No,FLAGS.Nr], name="Rs"); 95 | Ra = tf.placeholder(tf.float32, [None,FLAGS.Dr,FLAGS.Nr], name="Ra"); 96 | # next velocities 97 | P_label = tf.placeholder(tf.float32, [None,FLAGS.Dp,FLAGS.No], name="P_label"); 98 | # External Effects 99 | X = tf.placeholder(tf.float32, [None,FLAGS.Dx,FLAGS.No], name="X"); 100 | 101 | # marshalling function, m(G)=B, G= 102 | B=m(O,Rr,Rs,Ra); 103 | 104 | # relational modeling phi_R(B)=E 105 | E=phi_R(B); 106 | 107 | # aggregator 108 | C=a(O,Rr,X,E); 109 | 110 | # object modeling phi_O(C)=P 111 | P=phi_O(C); 112 | 113 | # abstract modeling phi_A(P)=q 114 | #q=phi_A(P); 115 | 116 | 
# loss and optimizer 117 | params_list=tf.global_variables(); 118 | for i in range(len(params_list)): 119 | variable_summaries(params_list[i],i); 120 | mse=tf.reduce_mean(tf.reduce_mean(tf.square(P-P_label),[1,2])); 121 | loss = 0.001*tf.nn.l2_loss(E); 122 | for i in params_list: 123 | loss+=0.001*tf.nn.l2_loss(i); 124 | optimizer = tf.train.AdamOptimizer(0.001); 125 | trainer=optimizer.minimize(mse+loss); 126 | 127 | # tensorboard 128 | tf.summary.scalar('mse',mse); 129 | merged=tf.summary.merge_all(); 130 | writer=tf.summary.FileWriter(FLAGS.log_dir); 131 | 132 | sess=tf.InteractiveSession(); 133 | tf.global_variables_initializer().run(); 134 | 135 | # Data Generation 136 | set_num=10; 137 | #set_num=2000; 138 | total_data=np.zeros((999*set_num,FLAGS.Ds,FLAGS.No),dtype=object); 139 | total_label=np.zeros((999*set_num,FLAGS.Dp,FLAGS.No),dtype=object); 140 | for i in range(set_num): 141 | raw_data=gen(FLAGS.No,True); 142 | data=np.zeros((999,FLAGS.Ds,FLAGS.No),dtype=object); 143 | label=np.zeros((999,FLAGS.Dp,FLAGS.No),dtype=object); 144 | for j in range(1000-1): 145 | data[j]=np.transpose(raw_data[j]);label[j]=np.transpose(raw_data[j+1,:,3:5]); 146 | total_data[i*999:(i+1)*999,:]=data; 147 | total_label[i*999:(i+1)*999,:]=label; 148 | 149 | # Shuffle 150 | tr_data_num=999*(set_num-1); 151 | val_data_num=300; 152 | #tr_data_num=1000000; 153 | #val_data_num=200000; 154 | total_idx=range(len(total_data));np.random.shuffle(total_idx); 155 | mixed_data=total_data[total_idx]; 156 | mixed_label=total_label[total_idx]; 157 | # Training/Validation/Test 158 | train_data=mixed_data[:tr_data_num]; 159 | train_label=mixed_label[:tr_data_num]; 160 | val_data=mixed_data[tr_data_num:tr_data_num+val_data_num]; 161 | val_label=mixed_label[tr_data_num:tr_data_num+val_data_num]; 162 | test_data=mixed_data[tr_data_num+val_data_num:]; 163 | test_label=mixed_label[tr_data_num+val_data_num:]; 164 | 165 | # Normalization 166 | 
weights_list=np.sort(np.reshape(train_data[:,0,:],[1,tr_data_num*FLAGS.No])[0]); 167 | weights_median=weights_list[int(len(weights_list)*0.5)]; 168 | weights_min=weights_list[int(len(weights_list)*0.05)]; 169 | weights_max=weights_list[int(len(weights_list)*0.95)]; 170 | position_list=np.sort(np.reshape(train_data[:,1:3,:],[1,tr_data_num*FLAGS.No*2])[0]); 171 | position_median=position_list[int(len(position_list)*0.5)]; 172 | position_min=position_list[int(len(position_list)*0.05)]; 173 | position_max=position_list[int(len(position_list)*0.95)]; 174 | velocity_list=np.sort(np.reshape(train_data[:,3:5,:],[1,tr_data_num*FLAGS.No*2])[0]); 175 | velocity_median=velocity_list[int(len(velocity_list)*0.5)]; 176 | velocity_min=velocity_list[int(len(velocity_list)*0.05)]; 177 | velocity_max=velocity_list[int(len(velocity_list)*0.95)]; 178 | 179 | train_data[:,0,:]=(train_data[:,0,:]-weights_median)*(2/(weights_max-weights_min)); 180 | train_data[:,1:3,:]=(train_data[:,1:3,:]-position_median)*(2/(position_max-position_min)); 181 | train_data[:,3:5,:]=(train_data[:,3:5,:]-velocity_median)*(2/(velocity_max-velocity_min)); 182 | 183 | val_data[:,0,:]=(val_data[:,0,:]-weights_median)*(2/(weights_max-weights_min)); 184 | val_data[:,1:3,:]=(val_data[:,1:3,:]-position_median)*(2/(position_max-position_min)); 185 | val_data[:,3:5,:]=(val_data[:,3:5,:]-velocity_median)*(2/(velocity_max-velocity_min)); 186 | 187 | test_data[:,0,:]=(test_data[:,0,:]-weights_median)*(2/(weights_max-weights_min)); 188 | test_data[:,1:3,:]=(test_data[:,1:3,:]-position_median)*(2/(position_max-position_min)); 189 | test_data[:,3:5,:]=(test_data[:,3:5,:]-velocity_median)*(2/(velocity_max-velocity_min)); 190 | 191 | mini_batch_num=100; 192 | # Set Rr_data, Rs_data, Ra_data and X_data 193 | Rr_data=np.zeros((mini_batch_num,FLAGS.No,FLAGS.Nr),dtype=float); 194 | Rs_data=np.zeros((mini_batch_num,FLAGS.No,FLAGS.Nr),dtype=float); 195 | Ra_data=np.zeros((mini_batch_num,FLAGS.Dr,FLAGS.Nr),dtype=float); 196 | 
X_data=np.zeros((mini_batch_num,FLAGS.Dx,FLAGS.No),dtype=float); 197 | cnt=0; 198 | for i in range(FLAGS.No): 199 | for j in range(FLAGS.No): 200 | if(i!=j): 201 | Rr_data[:,i,cnt]=1.0; 202 | Rs_data[:,j,cnt]=1.0; 203 | cnt+=1; 204 | 205 | # Training 206 | max_epoches=2000*20; 207 | for i in range(max_epoches): 208 | tr_loss=0; 209 | for j in range(int(len(train_data)/mini_batch_num)): 210 | batch_data=train_data[j*mini_batch_num:(j+1)*mini_batch_num]; 211 | batch_label=train_label[j*mini_batch_num:(j+1)*mini_batch_num]; 212 | tr_loss_part,_=sess.run([mse,trainer],feed_dict={O:batch_data,Rr:Rr_data,Rs:Rs_data,Ra:Ra_data,P_label:batch_label,X:X_data}); 213 | tr_loss+=tr_loss_part; 214 | train_idx=range(len(train_data));np.random.shuffle(train_idx); 215 | train_data=train_data[train_idx]; 216 | train_label=train_label[train_idx]; 217 | val_loss=0; 218 | for j in range(int(len(val_data)/mini_batch_num)): 219 | batch_data=val_data[j*mini_batch_num:(j+1)*mini_batch_num]; 220 | batch_label=val_label[j*mini_batch_num:(j+1)*mini_batch_num]; 221 | #if(j==0): 222 | # summary,val_loss_part,estimated=sess.run([merged,mse,P],feed_dict={O:batch_data,Rr:Rr_data,Rs:Rs_data,Ra:Ra_data,P_label:batch_label,X:X_data}); 223 | # writer.add_summary(summary,i); 224 | #else: 225 | val_loss_part,estimated=sess.run([mse,P],feed_dict={O:batch_data,Rr:Rr_data,Rs:Rs_data,Ra:Ra_data,P_label:batch_label,X:X_data}); 226 | val_loss+=val_loss_part; 227 | val_idx=range(len(val_data));np.random.shuffle(val_idx); 228 | val_data=val_data[val_idx]; 229 | val_label=val_label[val_idx]; 230 | print("Epoch "+str(i+1)+" Training MSE: "+str(tr_loss/(int(len(train_data)/mini_batch_num)))+" Validation MSE: "+str(val_loss/(j+1))); 231 | 232 | # Make Video 233 | frame_len=300; 234 | raw_data=gen(FLAGS.No,True); 235 | xy_origin=copy.deepcopy(raw_data[:frame_len,:,1:3]); 236 | estimated_data=np.zeros((frame_len,FLAGS.No,FLAGS.Ds),dtype=float); 237 | 
raw_data[:,:,0]=(raw_data[:,:,0]-weights_median)*(2/(weights_max-weights_min)); 238 | raw_data[:,:,1:3]=(raw_data[:,:,1:3]-position_median)*(2/(position_max-position_min)); 239 | raw_data[:,:,3:5]=(raw_data[:,:,3:5]-velocity_median)*(2/(velocity_max-velocity_min)); 240 | estimated_data[0]=raw_data[0]; 241 | ts_loss=0; 242 | for i in range(1,frame_len): 243 | ts_loss_part,velocities=sess.run([mse,P],feed_dict={O:[np.transpose(raw_data[i-1])],Rr:[Rr_data[0]],Rs:[Rs_data[0]],Ra:[Ra_data[0]],X:[X_data[0]],P_label:[np.transpose(raw_data[i,:,3:5]*(velocity_max-velocity_min)/2+velocity_median)]}); 244 | velocities=velocities[0]; 245 | estimated_data[i,:,0]=estimated_data[i-1][:,0]; 246 | estimated_data[i,:,3:5]=np.transpose(velocities); 247 | estimated_data[i,:,1:3]=(estimated_data[i-1,:,1:3]*(position_max-position_min)/2+position_median)+estimated_data[i,:,3:5]*0.001; 248 | estimated_data[i,:,1:3]=(estimated_data[i,:,1:3]-position_median)*(2/(position_max-position_min)); 249 | ts_loss+=ts_loss_part; 250 | xy_estimated=estimated_data[:,:,1:3]*(position_max-position_min)/2+position_median; 251 | print("Test Loss: "+str(ts_loss)); 252 | print("Video Recording"); 253 | make_video(xy_origin,"true"+str(time.time())+".mp4"); 254 | make_video(xy_estimated,"modeling"+str(time.time())+".mp4"); 255 | print("Done"); 256 | 257 | def main(_): 258 | FLAGS.log_dir+=str(int(time.time())); 259 | if tf.gfile.Exists(FLAGS.log_dir): 260 | tf.gfile.DeleteRecursively(FLAGS.log_dir) 261 | tf.gfile.MakeDirs(FLAGS.log_dir) 262 | train() 263 | 264 | 265 | if __name__ == '__main__': 266 | parser = argparse.ArgumentParser() 267 | parser.add_argument('--log_dir', type=str, default='/tmp/interaction-network/', 268 | help='Summaries log directry') 269 | parser.add_argument('--Ds', type=int, default=5, 270 | help='The State Dimention') 271 | parser.add_argument('--No', type=int, default=6, 272 | help='The Number of Objects') 273 | parser.add_argument('--Nr', type=int, default=30, 274 | help='The Number 
of Relations') 275 | parser.add_argument('--Dr', type=int, default=1, 276 | help='The Relationship Dimension') 277 | parser.add_argument('--Dx', type=int, default=1, 278 | help='The External Effect Dimension') 279 | parser.add_argument('--De', type=int, default=50, 280 | help='The Effect Dimension') 281 | parser.add_argument('--Dp', type=int, default=2, 282 | help='The Object Modeling Output Dimension') 283 | parser.add_argument('--Da', type=int, default=1, 284 | help='The Abstract Modeling Output Dimension') 285 | FLAGS, unparsed = parser.parse_known_args() 286 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 287 | --------------------------------------------------------------------------------