├── __init__.py ├── NRI-MD.png ├── visual.py ├── LICENSE ├── .gitignore ├── README.md ├── postanalysis_path.py ├── postanalysis_visual.py ├── convert_dataset.py ├── main.py ├── utils.py └── modules.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /NRI-MD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juexinwang/NRI-MD/HEAD/NRI-MD.png -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | a = np.load('056_500probs_test.npy') 3 | b = a[:, :, 1] 4 | c = a[:, :, 2] 5 | d = a[:, :, 3] 6 | probs = b+c+d 7 | probs = np.reshape(probs, (56,5852)) 8 | edges_train= probs/56 9 | 10 | results=np.zeros((5852)) 11 | for i in range(56): 12 | results=results+edges_train[i,:] 13 | 14 | index=results<(0.5) 15 | results[index]= 0 16 | 17 | 18 | genes=77 19 | edges_results=np.zeros((genes,genes)) 20 | count=0 21 | for i in range(genes): 22 | for j in range(genes): 23 | if not i==j: 24 | edges_results[i,j]=results[count] 25 | count+=1 26 | else: 27 | edges_results[i,j]=0 28 | 29 | import matplotlib.pyplot as plt 30 | 31 | import seaborn as sns 32 | a = edges_results 33 | ax = sns.heatmap(a, linewidth=0.5, cmap="Blues", vmax=1.0, vmin=0.0) 34 | plt.savefig('probs.png', dpi=600) 35 | plt.show() 36 | 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 juexinwang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | *.npy 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural relational inference to learn allosteric long-range interactions in proteins from molecular dynamics simulations 2 | 3 | **Abstract:** Protein allostery is a biological process facilitated by spatially long-range intra-protein communication, whereby ligand binding or amino acid mutation at a distant site affects the active site remotely. Molecular dynamics (MD) simulation provides a powerful computational approach to probe the allostery effect. However, current MD simulations cannot reach the time scales of whole allostery processes. The advent of deep learning made it possible to evaluate both spatially short and long-range communications for understanding allostery. 
For this purpose, we applied a neural relational inference (NRI) model based on a graph neural network (GNN), which adopts an encoder-decoder architecture to simultaneously infer latent interactions to probe protein allosteric processes as dynamic networks of interacting residues. From the MD trajectories, this model successfully learned the long-range interactions and pathways that can mediate the allosteric communications between the two distant binding sites in the Pin1, SOD1, and MEK1 systems. 4 | 5 | ![Neural Relational Inference (NRI)](NRI-MD.png) 6 | 7 | ### Requirements 8 | * PyTorch 1.2 9 | * Python 3.7 10 | * networkx 2.5 (Optional, only used in post analysis) 11 | 12 | ### Prepare Molecular Simulation Trajectories 13 | 14 | Place your molecular simulation trajectories in the data/pdb folder. ca_1.pdb is provided for the tutorial. 15 | 16 | ``` 17 | python convert_dataset.py 18 | ``` 19 | 20 | This step will separate the train/validation/test datasets and generate .npy files in the data folder. 21 | 22 | ### Run experiments 23 | 24 | From the project's root folder, simply run 25 | ``` 26 | python main.py 27 | ``` 28 | ### Post Analysis 29 | Visualize the inferred interactions between residues and domains 30 | ``` 31 | python postanalysis_visual.py 32 | ``` 33 | 34 | Find the shortest path between residues 35 | ``` 36 | python postanalysis_path.py 37 | ``` 38 | 39 | ### Data availability 40 | All the MD trajectories used in the study can be downloaded from https://doi.org/10.5281/zenodo.5941385 41 | 42 | ### Cite 43 | If you make use of this code in your own work, please cite our paper: 44 | 45 | **Neural relational inference to learn long-range allosteric interactions in proteins from molecular dynamics simulations.** 46 | Jingxuan Zhu, Juexin Wang, Weiwei Han, Dong Xu, 47 | Nature Communications 13, no. 1 (2022): 1-16. 
48 | https://www.nature.com/articles/s41467-022-29331-3 49 | 50 | ### Reference 51 | We thank the official implementation of neural relational inference at 52 | https://github.com/ethanfetaya/NRI 53 | 54 | -------------------------------------------------------------------------------- /postanalysis_path.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import argparse 7 | import networkx as nx 8 | 9 | parser = argparse.ArgumentParser( 10 | 'Find shortest paths along domains in residues.') 11 | parser.add_argument('--num-residues', type=int, default=77, 12 | help='Number of residues of the PDB.') 13 | parser.add_argument('--windowsize', type=int, default=56, 14 | help='window size') 15 | parser.add_argument('--dist-threshold', type=int, default=12, 16 | help='threshold for shortest distance') 17 | parser.add_argument('--filename', type=str, default='logs/out_probs_train.npy', 18 | help='File name of the probs file.') 19 | parser.add_argument('--source-node', type=int, default=46, 20 | help='source residue of the PDB') 21 | parser.add_argument('--outputfilename', type=str, default='logs/source46.txt', 22 | help='File of shortest path from source to targets') 23 | args = parser.parse_args() 24 | 25 | 26 | # According to the distribution of learned edges between residues, we calculated the shortest path 27 | # from mutation site to the residues in the active loop. 
def getEdgeResults(threshold=False):
    """Build the residue-by-residue matrix of averaged edge probabilities.

    Loads the array stored at ``args.filename``, sums channels 1-3 of its
    last axis (channel 0 is treated as the "no edge" type), reshapes the
    result to ``(args.windowsize, N*(N-1))`` with ``N = args.num_residues``
    and averages over the windows.

    Args:
        threshold: when true, averaged values below the cut-off are set to 0.

    Returns:
        An ``(N, N)`` array whose off-diagonal entries hold the averaged
        values in row-major order; the diagonal is 0.
    """
    raw_probs = np.load(args.filename)

    # There are four types of edges, eliminate the first type as the non-edge
    probs = raw_probs[:, :, 1] + raw_probs[:, :, 2] + raw_probs[:, :, 3]
    # For default residue number 77, residueR2 = 77*(77-1)=5852
    residueR2 = args.num_residues * (args.num_residues - 1)
    probs = np.reshape(probs, (args.windowsize, residueR2))

    # Calculate the occurrence of edges (mean over the windows)
    edges_train = probs / args.windowsize
    results = np.zeros(residueR2)
    for window in range(args.windowsize):
        results = results + edges_train[window, :]

    if threshold:
        # This script's parser defines no --threshold option, so reading
        # args.threshold directly would raise AttributeError; fall back to
        # the documented default of 0.6.
        cutoff = getattr(args, 'threshold', 0.6)
        results[results < cutoff] = 0

    # Calculate prob for figures: boolean-mask assignment fills the
    # off-diagonal cells row by row, the same order as a nested i/j loop.
    edges_results = np.zeros((args.num_residues, args.num_residues))
    off_diagonal = ~np.eye(args.num_residues, dtype=bool)
    edges_results[off_diagonal] = results

    return edges_results


def dist_cal(AA1, AA2):
    """
    Calculate the CA-CA distance: the Euclidean distance between the
    'x', 'y' and 'z' fields of the two records.
    """
    dist = np.sqrt(np.square(AA1['x'] - AA2['x'])
                   + np.square(AA1['y'] - AA2['y'])
                   + np.square(AA1['z'] - AA2['z']))
    return dist


# Load distribution of learned edges
edges_results = getEdgeResults()

dists_matrix = list()
for i in range(1, 2):
    tmp_matrix = np.zeros((args.num_residues, args.num_residues))
    # Raw string: '\s' is not a valid escape sequence in a normal literal.
    df = pd.read_csv('data/pdb/ca_%d.pdb' % i, sep=r'\s+', names=[
        'ATOM', 'n1', 'CA', 'AA', 'Chain', 'n2', 'x', 'y', 'z', 'n3', 'n4', 'C'])  # temporary DataFrame
    # Row 0 is skipped - presumably a non-ATOM header record; verify against
    # the PDB files actually used.
    tdf = df.iloc[1:(args.num_residues+1)]
    for ind_1 in range(args.num_residues-1):
        for ind_2 in range(ind_1+1, args.num_residues):
            AA1 = tdf.iloc[ind_1]
            AA2 = tdf.iloc[ind_2]
            tmp_dist = dist_cal(AA1, AA2)
            # Distances are symmetric
            tmp_matrix[ind_1, ind_2] = tmp_dist
            tmp_matrix[ind_2, ind_1] = tmp_dist
    dists_matrix.append(tmp_matrix)
dists_matrix = np.array(dists_matrix)
dists_mean = np.zeros((args.num_residues, args.num_residues))
for ind_1 in range(args.num_residues-1):
    for ind_2 in range(ind_1+1, args.num_residues):
        dists_mean[ind_1, ind_2] = dists_matrix[:, ind_1, ind_2].mean()
        dists_mean[ind_2, ind_1] = dists_matrix[:, ind_2, ind_1].mean()
np.save('logs/dists_mean_12.npy', dists_mean)
# If the mean distance is larger than --dist-threshold (default 12), the
# learned value is ignored and the weight is set to 1.
filtered_edges = np.where(dists_mean > args.dist_threshold, 1, edges_results)
np.save('logs/filtered_edges_12.npy', filtered_edges)

edges = np.load('logs/filtered_edges_12.npy')
# The network is directed
edges_list = list()
# Default: i->j (the weight of edge i->j is read from filtered_edges[j, i])
for i in range(args.num_residues):
    for j in range(args.num_residues):
        if i != j:
            edges_list.append((i, j, {'weight': filtered_edges[j, i]}))
MDG = nx.MultiDiGraph()
MDG.add_edges_from(edges_list)

source_node = args.source_node  # set source node
target_nodes = [61, 62, 63, 64, 65, 66, 67, 68, 69, 70]  # set target nodes
out_file = args.outputfilename

path_dict = dict()
for tn in target_nodes:
    path_dict[tn] = []
    path_nodes = nx.dijkstra_path(MDG, source_node, tn)
    shortest_length = nx.dijkstra_path_length(MDG, source_node, tn)
    path_length_list = [shortest_length]  # save the length of shortest path
    path_dict[tn].append('shortest_path : ' + '->'.join(list(map(str, path_nodes))) +
                         ' : ' + str(shortest_length))
    if len(path_nodes) > 2:
        # Remove each edge of the shortest path in turn and record the best
        # alternative route that avoids it.
        for ipn in range(1, len(path_nodes)):
            tmp_MDG = MDG.copy()
            tmp_MDG.remove_edge(path_nodes[ipn-1], path_nodes[ipn])
            tmp_path_nodes = nx.dijkstra_path(tmp_MDG, source_node, tn)
            tmp_length = nx.dijkstra_path_length(tmp_MDG, source_node, tn)
            path_length_list.append(tmp_length)
            path_dict[tn].append('remove(%d->%d) : ' % (path_nodes[ipn-1], path_nodes[ipn]) + '->'.join(
                list(map(str, tmp_path_nodes))) + ' : ' + str(tmp_length))
    # According to the shortest path length from small to large
    path_dict[tn] = np.array(path_dict[tn])[np.argsort(path_length_list)]
with open(out_file, 'w') as f:
    for k, v in path_dict.items():
        f.write('target node : %d\n' % k)
        for tmp_path in v:
            f.write('\t\t' + tmp_path + '\n')

# --------------------------------------------------------------------------------
# /postanalysis_visual.py:
# --------------------------------------------------------------------------------
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import argparse

parser = argparse.ArgumentParser(
    'Visualize the distribution of learned edges between residues.')
parser.add_argument('--num-residues', type=int, default=77,
                    help='Number of residues of the PDB.')
parser.add_argument('--windowsize', type=int, default=56,
                    help='window size')
parser.add_argument('--threshold', type=float, default=0.6,
                    help='threshold for plotting')
parser.add_argument('--dist-threshold', type=int, default=12,
                    help='threshold for shortest distance')
parser.add_argument('--filename', type=str, default='logs/out_probs_train.npy',
                    help='File name of the probs file.')
args = parser.parse_args()


def getEdgeResults(threshold=False):
    """Build the residue-by-residue matrix of averaged edge probabilities.

    Loads the array stored at ``args.filename``, sums channels 1-3 of its
    last axis (channel 0 is treated as the "no edge" type), reshapes the
    result to ``(args.windowsize, N*(N-1))`` with ``N = args.num_residues``
    and averages over the windows.

    Args:
        threshold: when true, averaged values below ``args.threshold``
            (default 0.6) are set to 0.

    Returns:
        An ``(N, N)`` array whose off-diagonal entries hold the averaged
        values in row-major order; the diagonal is 0.
    """
    raw_probs = np.load(args.filename)

    # There are four types of edges, eliminate the first type as the non-edge
    probs = raw_probs[:, :, 1] + raw_probs[:, :, 2] + raw_probs[:, :, 3]
    # For default residue number 77, residueR2 = 77*(77-1)=5852
    residueR2 = args.num_residues * (args.num_residues - 1)
    probs = np.reshape(probs, (args.windowsize, residueR2))

    # Calculate the occurrence of edges (mean over the windows)
    edges_train = probs / args.windowsize
    results = np.zeros(residueR2)
    for window in range(args.windowsize):
        results = results + edges_train[window, :]

    if threshold:
        # threshold, default 0.6
        results[results < args.threshold] = 0

    # Calculate prob for figures: boolean-mask assignment fills the
    # off-diagonal cells row by row, the same order as a nested i/j loop.
    edges_results = np.zeros((args.num_residues, args.num_residues))
    off_diagonal = ~np.eye(args.num_residues, dtype=bool)
    edges_results[off_diagonal] = results

    return edges_results


def getDomainEdges(edges_results, domainName):
    """Average the edge values between one domain and every domain.

    Columns ``startLoc:endLoc`` of ``edges_results`` are selected by
    ``domainName``; for each of the seven domains the mean over that
    domain's rows and the selected columns is computed. The entry for
    ``domainName`` itself is fixed to 0.

    Args:
        edges_results: square matrix as returned by ``getEdgeResults``.
        domainName: one of 'b1', 'diml', 'disl', 'zl', 'b2', 'el', 'b3'.

    Returns:
        1-D array of seven averages ordered b1, diml, disl, zl, b2, el, b3.

    Raises:
        ValueError: if ``domainName`` is not a known domain.
    """
    # Hard-coded (SOD1-specific) domains as half-open [start, end) ranges.
    domains = (
        ('b1', 0, 25),
        ('diml', 25, 29),
        ('disl', 29, 32),
        ('zl', 32, 43),
        ('b2', 43, 62),
        ('el', 62, 72),
        ('b3', 72, 77),
    )
    bounds = {name: (start, end) for name, start, end in domains}
    if domainName not in bounds:
        raise ValueError('Unknown domain name: %r' % (domainName,))
    startLoc, endLoc = bounds[domainName]

    averages = []
    for name, row_start, row_end in domains:
        if name == domainName:
            # No within-domain value is reported
            averages.append(0)
            continue
        # Row slice and divisor both come from the table above. The earlier
        # code sliced b3 as 72:-1 (dropping the last residue) and divided by
        # 6 although b3 spans five residues.
        block = edges_results[row_start:row_end, startLoc:endLoc]
        edge_num = block.sum(axis=0)
        averages.append(
            edge_num.sum(axis=0) / ((row_end - row_start) * (endLoc - startLoc)))

    edges_to_all = np.hstack(averages)
    return edges_to_all


# Load distribution of learned edges
edges_results_visual = getEdgeResults(threshold=True)
# Step 1: Visualize results
ax = sns.heatmap(edges_results_visual, linewidth=0.5,
                 cmap="Blues", vmax=1.0, vmin=0.0)
plt.savefig('logs/probs.png', dpi=600)
# plt.show()
plt.close()

# Step 2: Get domain specific results
# According to the distribution of learned edges between residues, we integrated adjacent residues as blocks for a more straightforward observation of the interactions.
# For example, the residues in SOD1 structure are divided into seven domains (β1, diml, disl, zl, β2, el, β3).
147 | 148 | edges_results = getEdgeResults(threshold=False) 149 | # SOD1 specific: 150 | b1 = getDomainEdges(edges_results, 'b1') 151 | diml = getDomainEdges(edges_results, 'diml') 152 | disl = getDomainEdges(edges_results, 'disl') 153 | zl = getDomainEdges(edges_results, 'zl') 154 | b2 = getDomainEdges(edges_results, 'b2') 155 | el = getDomainEdges(edges_results, 'el') 156 | b3 = getDomainEdges(edges_results, 'b3') 157 | edges_results = np.vstack((b1, diml, disl, zl, b2, el, b3)) 158 | # print(edges_results) 159 | edges_results_T = edges_results.T 160 | index = edges_results_T < (args.threshold) 161 | edges_results_T[index] = 0 162 | 163 | # Visualize 164 | ax = sns.heatmap(edges_results_T, linewidth=1, 165 | cmap="Blues", vmax=1.0, vmin=0.0) 166 | ax.set_ylim([7, 0]) 167 | plt.savefig('logs/edges_domain.png', dpi=600) 168 | # plt.show() 169 | plt.close() 170 | -------------------------------------------------------------------------------- /convert_dataset.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import argparse 4 | from copy import deepcopy 5 | from scipy import interpolate 6 | 7 | parser = argparse.ArgumentParser('Preprocessing: Generate training/validation/testing features from pdb') 8 | parser.add_argument('--MDfolder', type=str, default="data/pdb/", 9 | help='folder of pdb MD') 10 | parser.add_argument('--pdb-start', type=int, default="1", 11 | help='select pdb file window from start, e.g. 
in tutorial it is ca_1.pdb') 12 | parser.add_argument('--pdb-end', type=int, default="56", 13 | help='select pdb file window to end') 14 | parser.add_argument('--num-residues', type=int, default=77, 15 | help='Number of residues of the MD pdb') 16 | parser.add_argument('--feature-size', type=int, default=6, 17 | help='The number of features used in study( position (X,Y,Z) + velocity (X,Y,Z) ).') 18 | parser.add_argument('--train-interval', type=int, default=60, 19 | help='intervals in trajectory in training') 20 | parser.add_argument('--validate-interval', type=int, default=60, 21 | help='intervals in trajectory in validate') 22 | parser.add_argument('--test-interval', type=int, default=100, 23 | help='intervals in trajectory in test') 24 | args = parser.parse_args() 25 | 26 | 27 | def read_feature_file(filename, feature_size=1, gene_size=10, timestep_size=21): 28 | """ 29 | Read single expriments of all time points 30 | """ 31 | feature = np.zeros((timestep_size, feature_size, gene_size)) 32 | 33 | time_count = -1 34 | with open(filename) as f: 35 | lines = f.readlines() 36 | for line in lines: 37 | line = line.strip() 38 | if(time_count >= 0 and time_count < timestep_size): 39 | words = line.split() 40 | data_count = 0 41 | for word in words: 42 | feature[time_count, 0, data_count] = word 43 | data_count += 1 44 | time_count += 1 45 | f.close() 46 | # Use interpole 47 | feature = timepoint_sim(feature, 4) 48 | return feature 49 | 50 | 51 | def read_feature_Residue_file(filename): 52 | resdict = {} 53 | count = 0 54 | with open(filename) as f: 55 | lines = f.readlines() 56 | for line in lines: 57 | line = line.strip() 58 | words = line.split(",") 59 | if count > 0: 60 | feature = np.zeros((len(words)-1)) 61 | for i in range(len(words)-1): 62 | feature[i] = words[i+1] 63 | resdict[words[0]] = feature 64 | count = count+1 65 | return resdict 66 | 67 | 68 | def read_feature_MD_file(filename, timestep_size, feature_size, num_residues, interval): 69 | """ 70 | Read 
single expriments of all time points 71 | """ 72 | feature = np.zeros((timestep_size, feature_size, num_residues)) 73 | 74 | flag = False 75 | nflag = False 76 | modelNum = 0 77 | with open(filename) as f: 78 | lines = f.readlines() 79 | for line in lines: 80 | line = line.strip() 81 | words = line.split() 82 | if(line.startswith("MODEL")): 83 | modelNum = int(words[1]) 84 | if (modelNum % interval == 1): 85 | flag = True 86 | if (modelNum % interval == 2): 87 | nflag = True 88 | elif(line.startswith("ATOM") and words[2] == "CA" and flag): 89 | numStep = int(modelNum/interval) 90 | feature[numStep, 0, int(words[5])-1] = float(words[6]) 91 | feature[numStep, 1, int(words[5])-1] = float(words[7]) 92 | feature[numStep, 2, int(words[5])-1] = float(words[8]) 93 | elif(line.startswith("ATOM") and words[2] == "CA" and nflag): 94 | numStep = int(modelNum/interval) 95 | feature[numStep, 3, int( 96 | words[5])-1] = float(words[6])-feature[numStep, 0, int(words[5])-1] 97 | feature[numStep, 4, int( 98 | words[5])-1] = float(words[7])-feature[numStep, 1, int(words[5])-1] 99 | feature[numStep, 5, int( 100 | words[5])-1] = float(words[8])-feature[numStep, 2, int(words[5])-1] 101 | elif(line.startswith("ENDMDL") and flag): 102 | flag = False 103 | elif(line.startswith("ENDMDL") and nflag): 104 | nflag = False 105 | f.close() 106 | return feature 107 | 108 | 109 | def read_feature_MD_file_slidingwindow(filename, timestep_size, feature_size, num_residues, interval, window_choose, aa_start, aa_end): 110 | # read single expriments of all time points 111 | feature = np.zeros((timestep_size, feature_size, num_residues)) 112 | 113 | flag = False 114 | nflag = False 115 | modelNum = 0 116 | with open(filename) as f: 117 | lines = f.readlines() 118 | for line in lines: 119 | line = line.strip() 120 | words = line.split() 121 | if(line.startswith("MODEL")): 122 | modelNum = int(words[1]) 123 | if (modelNum % interval == window_choose): 124 | flag = True 125 | if (modelNum % interval == 
(window_choose+1)): 126 | nflag = True 127 | elif(line.startswith("ATOM") and words[2] == "CA" and int(words[4]) >= aa_start and int(words[4]) <= aa_end and flag): 128 | numStep = int(modelNum/interval) 129 | feature[numStep, 0, int(words[4])-aa_start] = float(words[5]) 130 | feature[numStep, 1, int(words[4])-aa_start] = float(words[6]) 131 | feature[numStep, 2, int(words[4])-aa_start] = float(words[7]) 132 | elif(line.startswith("ATOM") and words[2] == "CA" and int(words[4]) >= aa_start and int(words[4]) <= aa_end and nflag): 133 | numStep = int(modelNum/interval) 134 | feature[numStep, 3, int( 135 | words[4])-aa_start] = float(words[5])-feature[numStep, 0, int(words[4])-aa_start] 136 | feature[numStep, 4, int( 137 | words[4])-aa_start] = float(words[6])-feature[numStep, 1, int(words[4])-aa_start] 138 | feature[numStep, 5, int( 139 | words[4])-aa_start] = float(words[7])-feature[numStep, 2, int(words[4])-aa_start] 140 | elif(line.startswith("ENDMDL") and flag): 141 | flag = False 142 | elif(line.startswith("ENDMDL") and nflag): 143 | nflag = False 144 | f.close() 145 | # print(feature.shape) 146 | return feature 147 | 148 | 149 | def read_feature_MD_file_resi(filename, resDict, feature_size, num_residues, timestep_size, interval): 150 | # read single expriments of all time points 151 | feature = np.zeros((timestep_size, feature_size, num_residues)) 152 | 153 | flag = False 154 | nflag = False 155 | modelNum = 0 156 | with open(filename) as f: 157 | lines = f.readlines() 158 | for line in lines: 159 | line = line.strip() 160 | words = line.split() 161 | if(line.startswith("MODEL")): 162 | modelNum = int(words[1]) 163 | if (modelNum % interval == 1): 164 | flag = True 165 | if (modelNum % interval == 2): 166 | nflag = True 167 | elif(line.startswith("ATOM") and words[2] == "CA" and flag): 168 | numStep = int(modelNum/interval) 169 | feature[numStep, 0, int(words[4])-1] = float(words[5]) 170 | feature[numStep, 1, int(words[4])-1] = float(words[6]) 171 | 
feature[numStep, 2, int(words[4])-1] = float(words[7]) 172 | featureResi = resDict[words[3]] 173 | for i in range(6, 6+featureResi.shape[0]): 174 | feature[numStep, i, int(words[4])-1] = featureResi[i-6] 175 | 176 | elif(line.startswith("ATOM") and words[2] == "CA" and nflag): 177 | numStep = int(modelNum/interval) 178 | feature[numStep, 3, int( 179 | words[4])-1] = float(words[5])-feature[numStep, 0, int(words[4])-1] 180 | feature[numStep, 4, int( 181 | words[4])-1] = float(words[6])-feature[numStep, 1, int(words[4])-1] 182 | feature[numStep, 5, int( 183 | words[4])-1] = float(words[7])-feature[numStep, 2, int(words[4])-1] 184 | elif(line.startswith("ENDMDL") and flag): 185 | flag = False 186 | elif(line.startswith("ENDMDL") and nflag): 187 | nflag = False 188 | f.close() 189 | return feature 190 | 191 | 192 | def read_edge_file(filename, gene_size): 193 | edges = np.zeros((gene_size, gene_size)) 194 | count = 0 195 | with open(filename) as f: 196 | lines = f.readlines() 197 | for line in lines: 198 | line = line.strip() 199 | words = line.split() 200 | data_count = 0 201 | for word in words: 202 | edges[count, data_count] = word 203 | data_count += 1 204 | count += 1 205 | f.close() 206 | return edges 207 | 208 | 209 | def convert_dataset(feature_filename, edge_filename, experiment_size=5): 210 | features = list() 211 | 212 | edges = np.zeros((experiment_size, experiment_size)) 213 | 214 | for i in range(1, experiment_size+1): 215 | features.append(read_feature_file(feature_filename+"_"+str(i)+".txt")) 216 | 217 | count = 0 218 | with open(edge_filename) as f: 219 | lines = f.readlines() 220 | for line in lines: 221 | line = line.strip() 222 | words = line.split() 223 | data_count = 0 224 | for word in words: 225 | edges[count, data_count] = word 226 | data_count += 1 227 | count += 1 228 | f.close() 229 | 230 | features = np.stack(features, axis=0) 231 | edges = np.tile(edges, (features.shape[0], 1)).reshape( 232 | features.shape[0], features.shape[3], 
features.shape[3]) 233 | return features, edges 234 | 235 | 236 | def convert_dataset_sim(feature_filename, edge_filename, experiment_size=5, gene_size=5, sim_size=50000): 237 | features = list() 238 | 239 | edges = np.zeros((gene_size, gene_size)) 240 | 241 | for i in range(1, experiment_size+1): 242 | features.append(read_feature_file( 243 | feature_filename+"_"+str(i)+".txt"), gene_size=5) 244 | 245 | count = 0 246 | with open(edge_filename) as f: 247 | lines = f.readlines() 248 | for line in lines: 249 | line = line.strip() 250 | words = line.split() 251 | data_count = 0 252 | for word in words: 253 | edges[count, data_count] = word 254 | data_count += 1 255 | count += 1 256 | f.close() 257 | 258 | features = np.stack(features, axis=0) 259 | 260 | features_out = np.zeros( 261 | (sim_size, features.shape[1], features.shape[2], features.shape[3])) 262 | edges_out = np.zeros((sim_size, gene_size, gene_size)) 263 | 264 | for i in range(sim_size): 265 | index = np.random.permutation(np.arange(experiment_size)) 266 | num = np.random.permutation(np.arange(experiment_size))[0] 267 | features_out[i, :, :, :] = features[num, :, :, :][:, :, index] 268 | edges_out[i, :, :] = edges[index, :][:, index] 269 | 270 | # Add noise 271 | features_out = features_out + \ 272 | np.random.randn( 273 | sim_size, features.shape[1], features.shape[2], features.shape[3]) 274 | return features_out, edges_out 275 | 276 | 277 | def convert_dataset_md(feature_filename, startIndex, experiment_size, timestep_size, feature_size, num_residues, interval): 278 | features = list() 279 | edges = list() 280 | 281 | for i in range(startIndex, experiment_size+1): 282 | print("Start: "+str(i)+"th PDB") 283 | features.append(read_feature_MD_file(feature_filename+"smd"+str(i) + 284 | ".pdb", timestep_size, feature_size, num_residues, interval)) 285 | edges.append(np.zeros((num_residues, num_residues))) 286 | 287 | features = np.stack(features, axis=0) 288 | edges = np.stack(edges, axis=0) 289 | 290 | 
return features, edges 291 | 292 | 293 | def convert_dataset_md_single(MDfolder, startIndex, experiment_size, timestep_size, feature_size, num_residues, interval, pdb_start, pdb_end, aa_start, aa_end): 294 | """ 295 | Convert in single md file in single skeleton 296 | """ 297 | features = list() 298 | edges = list() 299 | 300 | for i in range(startIndex, experiment_size+1): 301 | print("Start: "+str(i)+"th PDB") 302 | for j in range(pdb_start, pdb_end+1): 303 | # print(str(i)+" "+str(j)) 304 | features.append(read_feature_MD_file_slidingwindow(MDfolder+"ca_"+str( 305 | i)+".pdb", timestep_size, feature_size, num_residues, interval, j, aa_start, aa_end)) 306 | edges.append(np.zeros((num_residues, num_residues))) 307 | print("***") 308 | print(len(features)) 309 | print("###") 310 | features = np.stack(features, axis=0) 311 | edges = np.stack(edges, axis=0) 312 | 313 | return features, edges 314 | 315 | 316 | def timepoint_sim(feature, fold): 317 | # hard code now,fold=4 318 | # feature_shape: [timestep, feature_size, gene] 319 | step = 1/fold 320 | timestep = feature.shape[0] 321 | genes = feature.shape[2] 322 | x = np.arange(timestep) 323 | xnew = np.arange(0, (timestep-1)+step, step) 324 | feature_out = np.zeros((xnew.shape[0], 1, genes)) 325 | for gene in range(genes): 326 | y = feature[:, 0, gene] 327 | tck = interpolate.splrep(x, y, s=0) 328 | ynew = interpolate.splev(xnew, tck, der=0) 329 | feature_out[:, 0, gene] = ynew 330 | return feature_out 331 | 332 | 333 | MDfolder = args.MDfolder 334 | feature_size = args.feature_size 335 | num_residues = args.num_residues 336 | pdb_start = args.pdb_start 337 | pdb_end = args.pdb_end 338 | train_interval = args.train_interval 339 | validate_interval = args.validate_interval 340 | test_interval = args.test_interval 341 | 342 | # Generate training/validating/testing 343 | print("Generate Train") 344 | features, edges = convert_dataset_md_single(MDfolder, startIndex=1, experiment_size=1, timestep_size=50, 345 | 
feature_size=feature_size, num_residues=num_residues, interval=train_interval, pdb_start=pdb_start, pdb_end=pdb_end, aa_start=1, aa_end=num_residues) 346 | 347 | np.save('data/features.npy', features) 348 | np.save('data/edges.npy', edges) 349 | 350 | 351 | print("Generate Valid") 352 | features_valid, edges_valid = convert_dataset_md_single(MDfolder, startIndex=1, experiment_size=1, timestep_size=50, 353 | feature_size=feature_size, num_residues=num_residues, interval=validate_interval, pdb_start=pdb_start, pdb_end=pdb_end, aa_start=1, aa_end=num_residues) 354 | 355 | np.save('data/features_valid.npy', features_valid) 356 | np.save('data/edges_valid.npy', edges_valid) 357 | 358 | 359 | print("Generate Test") 360 | features_test, edges_test = convert_dataset_md_single(MDfolder, startIndex=1, experiment_size=1, timestep_size=50, 361 | feature_size=feature_size, num_residues=num_residues, interval=test_interval, pdb_start=pdb_start, pdb_end=pdb_end, aa_start=1, aa_end=num_residues) 362 | np.save('data/features_test.npy', features_test) 363 | np.save('data/edges_test.npy', edges_test) 364 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pickle 4 | import os 5 | import datetime 6 | import torch 7 | import torch.optim as optim 8 | from torch.optim import lr_scheduler 9 | from utils import * 10 | from modules import * 11 | 12 | parser = argparse.ArgumentParser( 13 | 'Neral relational inference for molecular dynamics simulations') 14 | parser.add_argument('--num-residues', type=int, default=77, 15 | help='Number of residues of the PDB.') 16 | parser.add_argument('--save-folder', type=str, default='logs', 17 | help='Where to save the trained model, leave empty to not save anything.') 18 | parser.add_argument('--load-folder', type=str, default='', 19 | help='Where to load the trained model if finetunning. 
' + 20 | 'Leave empty to train from scratch') 21 | parser.add_argument('--edge-types', type=int, default=4, 22 | help='The number of edge types to infer.') 23 | parser.add_argument('--dims', type=int, default=6, 24 | help='The number of input dimensions used in study( position (X,Y,Z) + velocity (X,Y,Z) ). ') 25 | parser.add_argument('--timesteps', type=int, default=50, 26 | help='The number of time steps per sample. Actually is 50') 27 | parser.add_argument('--prediction-steps', type=int, default=1, metavar='N', 28 | help='Num steps to predict before re-using teacher forcing.') 29 | parser.add_argument('--no-cuda', action='store_true', default=False, 30 | help='Disables CUDA training.') 31 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 32 | parser.add_argument('--epochs', type=int, default=500, 33 | help='Number of epochs to train.') 34 | parser.add_argument('--batch-size', type=int, default=1, 35 | help='Number of samples per batch.') 36 | parser.add_argument('--lr', type=float, default=0.0005, 37 | help='Initial learning rate.') 38 | parser.add_argument('--encoder-hidden', type=int, default=256, 39 | help='Number of hidden units in encoder.') 40 | parser.add_argument('--decoder-hidden', type=int, default=256, 41 | help='Number of hidden units in decoder.') 42 | parser.add_argument('--temp', type=float, default=0.5, 43 | help='Temperature for Gumbel softmax.') 44 | parser.add_argument('--encoder', type=str, default='mlp', 45 | help='Type of path encoder model (mlp or cnn).') 46 | parser.add_argument('--decoder', type=str, default='rnn', 47 | help='Type of decoder model (mlp, rnn, or sim).') 48 | parser.add_argument('--no-factor', action='store_true', default=False, 49 | help='Disables factor graph model.') 50 | parser.add_argument('--encoder-dropout', type=float, default=0.0, 51 | help='Dropout rate (1 - keep probability) in encoder.') 52 | parser.add_argument('--decoder-dropout', type=float, default=0.0, 53 | help='Dropout rate (1 - 
keep probability) in decoder.') 54 | parser.add_argument('--lr-decay', type=int, default=200, 55 | help='After how epochs to decay LR by a factor of gamma.') 56 | parser.add_argument('--gamma', type=float, default=0.5, 57 | help='LR decay factor.') 58 | parser.add_argument('--skip-first', action='store_true', default=True, 59 | help='Skip first edge type in decoder, i.e. it represents no-edge.') 60 | parser.add_argument('--var', type=float, default=5e-5, 61 | help='Output variance.') 62 | parser.add_argument('--hard', action='store_true', default=True, 63 | help='Uses discrete samples in training forward pass.') 64 | parser.add_argument('--prior', action='store_true', default=True, 65 | help='Whether to use sparsity prior.') 66 | parser.add_argument('--dynamic-graph', action='store_true', default=True, 67 | help='Whether test with dynamically re-computed graph.') 68 | parser.add_argument('--number-expstart', type=int, default=0, 69 | help='start number of experiments.') 70 | parser.add_argument('--number-exp', type=int, default=56, 71 | help='number of experiments.') 72 | 73 | args = parser.parse_args() 74 | args.cuda = not args.no_cuda and torch.cuda.is_available() 75 | args.factor = not args.no_factor 76 | # print all arguments 77 | print(args) 78 | 79 | np.random.seed(args.seed) 80 | torch.manual_seed(args.seed) 81 | if args.cuda: 82 | torch.cuda.manual_seed(args.seed) 83 | 84 | if args.dynamic_graph: 85 | print("Testing with dynamically re-computed graph.") 86 | 87 | # Save model and meta-data. Always saves in a new sub-folder. 
if args.save_folder:
    exp_counter = 0
    now = datetime.datetime.now()
    timestamp = now.isoformat()
    save_folder = args.save_folder+'/'
    # makedirs instead of mkdir: works for nested paths and does not fail
    # when the folder already exists.
    os.makedirs(save_folder, exist_ok=True)
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    encoder_file = os.path.join(save_folder, 'encoder.pt')
    decoder_file = os.path.join(save_folder, 'decoder.pt')

    log_file = os.path.join(save_folder, 'log.txt')
    # Kept open for the whole run; it is closed at the very end of the script.
    log = open(log_file, 'w')

    # Close the metadata file deterministically instead of leaking the handle.
    with open(meta_file, "wb") as meta_handle:
        pickle.dump({'args': args}, meta_handle)
else:
    # Bind `log` so the final `if log is not None` check at the end of the
    # script does not raise NameError when nothing is being saved.
    log = None
    print("WARNING: No save_folder provided! " +
          "Testing (within this script) will throw an error.")

# load data
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_dataset_train_valid_test(
    args.batch_size, args.number_exp, args.number_expstart, args.dims)


# Generate off-diagonal interaction graph
off_diag = np.ones([args.num_residues, args.num_residues]
                   ) - np.eye(args.num_residues)

# One-hot rows selecting, for every off-diagonal pair, the index taken from
# np.where(off_diag)[1] (rel_rec) and np.where(off_diag)[0] (rel_send).
rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)

if args.encoder == 'mlp':
    encoder = MLPEncoder(args.timesteps * args.dims, args.encoder_hidden,
                         args.edge_types,
                         args.encoder_dropout, args.factor)
elif args.encoder == 'cnn':
    encoder = CNNEncoder(args.dims, args.encoder_hidden,
                         args.edge_types,
                         args.encoder_dropout, args.factor)
else:
    # Fail here with a clear message rather than with a NameError on
    # `encoder` further down.
    raise ValueError(
        "Unknown encoder type '{}' (expected 'mlp' or 'cnn').".format(
            args.encoder))

if args.decoder == 'mlp':
    decoder = MLPDecoder(n_in_node=args.dims,
                         edge_types=args.edge_types,
                         msg_hid=args.decoder_hidden,
                         msg_out=args.decoder_hidden,
                         n_hid=args.decoder_hidden,
                         do_prob=args.decoder_dropout,
                         skip_first=args.skip_first)
elif args.decoder == 'rnn':
    decoder = RNNDecoder(n_in_node=args.dims,
                         edge_types=args.edge_types,
                         n_hid=args.decoder_hidden,
                         do_prob=args.decoder_dropout,
                         skip_first=args.skip_first)
elif args.decoder == 'sim':
    # The argument parser defines no --suffix option, so reading args.suffix
    # directly raises AttributeError; fall back to an empty suffix.
    decoder = SimulationDecoder(
        loc_max, loc_min, vel_max, vel_min, getattr(args, 'suffix', ''))
else:
    raise ValueError(
        "Unknown decoder type '{}' (expected 'mlp', 'rnn' or 'sim').".format(
            args.decoder))

if args.load_folder:
    # map_location='cpu' lets a checkpoint written on a GPU machine be
    # restored on a CPU-only host; the modules are moved to CUDA below
    # when args.cuda is set.
    encoder_file = os.path.join(args.load_folder, 'encoder.pt')
    encoder.load_state_dict(torch.load(encoder_file, map_location='cpu'))
    decoder_file = os.path.join(args.load_folder, 'decoder.pt')
    decoder.load_state_dict(torch.load(decoder_file, map_location='cpu'))

    # Disables checkpoint saving and file logging for the rest of the run.
    args.save_folder = False

optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                       lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay,
                                gamma=args.gamma)

# Linear indices of an upper triangular mx, used for acc calculation
triu_indices = get_triu_offdiag_indices(args.num_residues)
tril_indices = get_tril_offdiag_indices(args.num_residues)

if args.prior:
    prior = np.array([0.91, 0.03, 0.03, 0.03])  # TODO: hard coded for now
    # The prior has one entry per edge type; reject a mismatching
    # --edge-types up front instead of failing (or silently broadcasting)
    # inside the KL term.
    if prior.shape[0] != args.edge_types:
        raise ValueError(
            "The hard-coded prior has {} entries but --edge-types is {}.".format(
                prior.shape[0], args.edge_types))
    print("Using prior")
    print(prior)
    log_prior = torch.FloatTensor(np.log(prior))
    # Two leading singleton dimensions are added in front of the class axis.
    log_prior = torch.unsqueeze(log_prior, 0)
    log_prior = torch.unsqueeze(log_prior, 0)
    log_prior = Variable(log_prior)

    if args.cuda:
        log_prior = log_prior.cuda()

if args.cuda:
    encoder.cuda()
    decoder.cuda()
    rel_rec = rel_rec.cuda()
    rel_send = rel_send.cuda()
    triu_indices = triu_indices.cuda()
    tril_indices = tril_indices.cuda()

rel_rec = Variable(rel_rec)
rel_send = Variable(rel_send)


def train(epoch, best_val_loss):
    """Run one training epoch followed by one validation pass.

    Uses the module-level encoder, decoder, optimizer, scheduler and data
    loaders. When ``args.save_folder`` is set and the mean validation NLL is
    lower than ``best_val_loss``, both state dicts are written to
    ``encoder_file`` / ``decoder_file`` and the epoch summary is appended to
    the log file.

    Args:
        epoch: Index of the current epoch (only used for reporting).
        best_val_loss: Lowest mean validation NLL seen so far.

    Returns:
        Tuple ``(encoder, decoder, edges_train, probs_train, nll_val_mean)``
        where ``edges_train`` is the concatenation of the per-batch argmax of
        the sampled edges, ``probs_train`` the concatenation of the per-batch
        softmax of the encoder logits, and ``nll_val_mean`` the mean
        validation NLL of this epoch.
    """
    t = time.time()
    nll_train = []
    acc_train = []
    kl_train = []
    mse_train = []
    edges_train = []
    probs_train = []

    encoder.train()
    decoder.train()

    for batch_idx, (data, relations) in enumerate(train_loader):

        if args.cuda:
            data, relations = data.cuda(), relations.cuda()
        data, relations = Variable(data), Variable(relations)

        optimizer.zero_grad()

        logits = encoder(data, rel_rec, rel_send)
        edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard)
        prob = my_softmax(logits, -1)

        if args.decoder == 'rnn':
            # NOTE(review): the literal 50 equals the default of --timesteps;
            # confirm whether it should follow args.timesteps.
            output = decoder(data, edges, rel_rec, rel_send, 50,
                             burn_in=True,
                             burn_in_steps=args.timesteps - args.prediction_steps)
        else:
            output = decoder(data, edges, rel_rec, rel_send,
                             args.prediction_steps)

        # Targets are the inputs shifted by one step along dimension 2.
        target = data[:, :, 1:, :]

        loss_nll = nll_gaussian(output, target, args.var)

        if args.prior:
            loss_kl = kl_categorical(prob, log_prior, args.num_residues)
        else:
            loss_kl = kl_categorical_uniform(prob, args.num_residues,
                                             args.edge_types)

        loss = loss_nll + loss_kl

        acc = edge_accuracy(logits, relations)
        acc_train.append(acc)

        loss.backward()
        optimizer.step()

        mse_train.append(F.mse_loss(output, target).item())
        nll_train.append(loss_nll.item())
        kl_train.append(loss_kl.item())
        _, edges_t = edges.max(-1)
        edges_train.append(edges_t.data.cpu().numpy())
        probs_train.append(prob.data.cpu().numpy())

    # Step the LR schedule once per epoch, after this epoch's optimizer steps.
    scheduler.step()
    nll_val = []
    acc_val = []
    kl_val = []
    mse_val = []

    encoder.eval()
    decoder.eval()
    for batch_idx, (data, relations) in enumerate(valid_loader):
        if args.cuda:
            data, relations = data.cuda(), relations.cuda()
        with torch.no_grad():

            logits = encoder(data, rel_rec, rel_send)
            edges = gumbel_softmax(logits, tau=args.temp, hard=True)
            prob = my_softmax(logits, -1)

            # validation output uses teacher forcing
            output = decoder(data, edges, rel_rec, rel_send, 1)

            target = data[:, :, 1:, :]
            loss_nll = nll_gaussian(output, target, args.var)
            # Validation always reports the uniform-prior KL term, even when
            # training uses the sparsity prior.
            loss_kl = kl_categorical_uniform(
                prob, args.num_residues, args.edge_types)

            acc = edge_accuracy(logits, relations)
            acc_val.append(acc)

            mse_val.append(F.mse_loss(output, target).item())
            nll_val.append(loss_nll.item())
            kl_val.append(loss_kl.item())

    # Compute the validation mean once; it is reported, compared and returned.
    nll_val_mean = np.mean(np.array(nll_val))

    # Format the epoch summary once so console and log file show the same line.
    summary = ['Epoch: {:04d}'.format(epoch),
               'nll_train: {:.10f}'.format(np.mean(np.array(nll_train))),
               'kl_train: {:.10f}'.format(np.mean(np.array(kl_train))),
               'mse_train: {:.10f}'.format(np.mean(np.array(mse_train))),
               'acc_train: {:.10f}'.format(np.mean(np.array(acc_train))),
               'nll_val: {:.10f}'.format(nll_val_mean),
               'kl_val: {:.10f}'.format(np.mean(np.array(kl_val))),
               'mse_val: {:.10f}'.format(np.mean(np.array(mse_val))),
               'acc_val: {:.10f}'.format(np.mean(np.array(acc_val))),
               'time: {:.4f}s'.format(time.time() - t)]
    print(*summary)
    edges_train = np.concatenate(edges_train)
    probs_train = np.concatenate(probs_train)
    if args.save_folder and nll_val_mean < best_val_loss:
        torch.save(encoder.state_dict(), encoder_file)
        torch.save(decoder.state_dict(), decoder_file)
        print('Best model so far, saving...')
        print(*summary, file=log)
        log.flush()

    return encoder, decoder, edges_train, probs_train, nll_val_mean

308 | 309 | def test(): 310 | acc_test = [] 311 | nll_test = [] 312 | kl_test = [] 313 | mse_test = [] 314 | edges_test = [] 315 | probs_test = [] 316 | tot_mse = 0 317 | counter = 0 318 | 319 | encoder.eval() 320 | decoder.eval() 321 | encoder.load_state_dict(torch.load(encoder_file)) 322 | decoder.load_state_dict(torch.load(decoder_file)) 323 | 324 | for batch_idx, (data, relations) in enumerate(test_loader): 325 | if args.cuda: 326 | data, relations = data.cuda(), relations.cuda() 327 | 328 | with torch.no_grad(): 329 | # assert (data.size(2) - args.timesteps) >= args.timesteps 330 | assert (data.size(2)) >= args.timesteps 331 | 332 | data_encoder = data[:, :, :args.timesteps, :].contiguous() 333 | data_decoder = data[:, :, -args.timesteps:, :].contiguous() 334 | 335 | logits = encoder(data_encoder, rel_rec, rel_send) 336 | edges = gumbel_softmax(logits, tau=args.temp, hard=True) 337 | 338 | prob = my_softmax(logits, -1) 339 | 340 | output = decoder(data_decoder, edges, rel_rec, rel_send, 1) 341 | 342 | target = data_decoder[:, :, 1:, :] 343 | loss_nll = nll_gaussian(output, target, args.var) 344 | loss_kl = kl_categorical_uniform( 345 | prob, args.num_residues, args.edge_types) 346 | 347 | acc = edge_accuracy(logits, relations) 348 | acc_test.append(acc) 349 | 350 | mse_test.append(F.mse_loss(output, target).item()) 351 | nll_test.append(loss_nll.item()) 352 | kl_test.append(loss_kl.item()) 353 | _, edges_t = edges.max(-1) 354 | edges_test.append(edges_t.data.cpu().numpy()) 355 | probs_test.append(prob.data.cpu().numpy()) 356 | 357 | # For plotting purposes 358 | if args.decoder == 'rnn': 359 | if args.dynamic_graph: 360 | output = decoder(data, edges, rel_rec, rel_send, 50, 361 | burn_in=False, burn_in_steps=args.timesteps, 362 | dynamic_graph=True, encoder=encoder, 363 | temp=args.temp) 364 | else: 365 | output = decoder(data, edges, rel_rec, rel_send, 50, 366 | burn_in=True, burn_in_steps=args.timesteps) 367 | 368 | target = data[:, :, 1:, :] 369 | 370 | 
else: 371 | data_plot = data[:, :, 0:0 + 21, 372 | :].contiguous() 373 | output = decoder(data_plot, edges, rel_rec, rel_send, 20) 374 | target = data_plot[:, :, 1:, :] 375 | 376 | mse = ((target - output) ** 2).mean(dim=0).mean(dim=0).mean(dim=-1) 377 | tot_mse += mse.data.cpu().numpy() 378 | counter += 1 379 | 380 | mean_mse = tot_mse / counter 381 | mse_str = '[' 382 | for mse_step in mean_mse[:-1]: 383 | mse_str += " {:.12f} ,".format(mse_step) 384 | mse_str += " {:.12f} ".format(mean_mse[-1]) 385 | mse_str += ']' 386 | 387 | print('--------------------------------') 388 | print('--------Testing-----------------') 389 | print('--------------------------------') 390 | print('nll_test: {:.10f}'.format(np.mean(nll_test)), 391 | 'kl_test: {:.10f}'.format(np.mean(kl_test)), 392 | 'mse_test: {:.10f}'.format(np.mean(mse_test)), 393 | 'acc_test: {:.10f}'.format(np.mean(acc_test))) 394 | print('MSE: {}'.format(mse_str)) 395 | edges_test = np.concatenate(edges_test) 396 | probs_test = np.concatenate(probs_test) 397 | 398 | if args.save_folder: 399 | print('--------------------------------', file=log) 400 | print('--------Testing-----------------', file=log) 401 | print('--------------------------------', file=log) 402 | print('nll_test: {:.10f}'.format(np.mean(nll_test)), 403 | 'kl_test: {:.10f}'.format(np.mean(kl_test)), 404 | 'mse_test: {:.10f}'.format(np.mean(mse_test)), 405 | 'acc_test: {:.10f}'.format(np.mean(acc_test)), 406 | file=log) 407 | print('MSE: {}'.format(mse_str), file=log) 408 | log.flush() 409 | return edges_test, probs_test 410 | 411 | 412 | # Train model 413 | print("Start Training...") 414 | t_total = time.time() 415 | best_val_loss = np.inf 416 | best_epoch = 0 417 | for epoch in range(args.epochs): 418 | encoder, decoder, edges_train, probs_train, val_loss = train( 419 | epoch, best_val_loss) 420 | # print('Epoch '+str(epoch)+' with val loss:'+str(val_loss)) 421 | if val_loss < best_val_loss: 422 | best_val_loss = val_loss 423 | best_epoch = epoch 
424 | np.save(str(args.save_folder)+'/out_edges_train.npy', edges_train) 425 | np.save(str(args.save_folder)+'/out_probs_train.npy', probs_train) 426 | print("Optimization Finished!") 427 | print("Best Epoch: {:04d}".format(best_epoch)) 428 | if args.save_folder: 429 | print("Best Epoch: {:04d}".format(best_epoch), file=log) 430 | log.flush() 431 | 432 | # Test 433 | edges_test, probs_test = test() 434 | if log is not None: 435 | print(save_folder) 436 | log.close() 437 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data.dataset import TensorDataset 4 | from torch.utils.data import DataLoader 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | def my_softmax(input, axis=1): 10 | trans_input = input.transpose(axis, 0).contiguous() 11 | soft_max_1d = F.softmax(trans_input, dim=0) 12 | return soft_max_1d.transpose(axis, 0) 13 | 14 | 15 | def binary_concrete(logits, tau=1, hard=False, eps=1e-10): 16 | y_soft = binary_concrete_sample(logits, tau=tau, eps=eps) 17 | if hard: 18 | y_hard = (y_soft > 0.5).float() 19 | y = Variable(y_hard.data - y_soft.data) + y_soft 20 | else: 21 | y = y_soft 22 | return y 23 | 24 | 25 | def binary_concrete_sample(logits, tau=1, eps=1e-10): 26 | logistic_noise = sample_logistic(logits.size(), eps=eps) 27 | if logits.is_cuda: 28 | logistic_noise = logistic_noise.cuda() 29 | y = logits + Variable(logistic_noise) 30 | return F.sigmoid(y / tau) 31 | 32 | 33 | def sample_logistic(shape, eps=1e-10): 34 | uniform = torch.rand(shape).float() 35 | return torch.log(uniform + eps) - torch.log(1 - uniform + eps) 36 | 37 | 38 | def sample_gumbel(shape, eps=1e-10): 39 | """ 40 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 41 | 42 | Sample from Gumbel(0, 1) 43 | 44 | 
based on 45 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 46 | (MIT license) 47 | """ 48 | U = torch.rand(shape).float() 49 | return - torch.log(eps - torch.log(U + eps)) 50 | 51 | 52 | def gumbel_softmax_sample(logits, tau=1, eps=1e-10): 53 | """ 54 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 55 | 56 | Draw a sample from the Gumbel-Softmax distribution 57 | 58 | based on 59 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb 60 | (MIT license) 61 | """ 62 | gumbel_noise = sample_gumbel(logits.size(), eps=eps) 63 | if logits.is_cuda: 64 | gumbel_noise = gumbel_noise.cuda() 65 | y = logits + Variable(gumbel_noise) 66 | return my_softmax(y / tau, axis=-1) 67 | 68 | 69 | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): 70 | """ 71 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 72 | 73 | Sample from the Gumbel-Softmax distribution and optionally discretize. 74 | Args: 75 | logits: [batch_size, n_class] unnormalized log-probs 76 | tau: non-negative scalar temperature 77 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 78 | Returns: 79 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
80 | If hard=True, then the returned sample will be one-hot, otherwise it will 81 | be a probability distribution that sums to 1 across classes 82 | 83 | Constraints: 84 | - this implementation only works on batch_size x num_features tensor for now 85 | 86 | based on 87 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 88 | (MIT license) 89 | """ 90 | y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) 91 | if hard: 92 | shape = logits.size() 93 | _, k = y_soft.data.max(-1) 94 | # this bit is based on 95 | # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5 96 | y_hard = torch.zeros(*shape) 97 | if y_soft.is_cuda: 98 | y_hard = y_hard.cuda() 99 | y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) 100 | # this cool bit of code achieves two things: 101 | # - makes the output value exactly one-hot (since we add then 102 | # subtract y_soft value) 103 | # - makes the gradient equal to y_soft gradient (since we strip 104 | # all other gradients) 105 | y = Variable(y_hard - y_soft.data) + y_soft 106 | else: 107 | y = y_soft 108 | return y 109 | 110 | 111 | def binary_accuracy(output, labels): 112 | preds = output > 0.5 113 | correct = preds.type_as(labels).eq(labels).double() 114 | correct = correct.sum() 115 | return correct / len(labels) 116 | 117 | 118 | def load_dataset_train(batch_size=1, suffix='', number_exp=1, dims=6): 119 | feat_train = np.load('data/features.npy') 120 | edges_train = np.load('data/edges.npy') 121 | 122 | assert(feat_train.shape[0] >= number_exp) 123 | assert(feat_train.shape[2] >= dims) 124 | feat_train = feat_train[0:number_exp, :, 0:dims, :] 125 | edges_train = edges_train[0:number_exp, :, :] 126 | 127 | # [num_samples, num_timesteps, num_dims, num_genes] 128 | num_genes = feat_train.shape[3] 129 | 130 | feat_max = feat_train.max() 131 | feat_min = feat_train.min() 132 | 133 | # Normalize to [-1, 1] 134 | feat_train = (feat_train - 
feat_min) * 2 / (feat_max - feat_min) - 1 135 | 136 | # Reshape to: [num_samples, num_genes, num_timesteps, num_dims] 137 | feat_train = np.transpose(feat_train, [0, 3, 1, 2]) 138 | edges_train = np.reshape(edges_train, [-1, num_genes ** 2]) 139 | edges_train = np.array((edges_train + 1) / 2, dtype=np.int64) 140 | 141 | # Exclude self edges 142 | off_diag_idx = np.ravel_multi_index( 143 | np.where(np.ones((num_genes, num_genes)) - np.eye(num_genes)), 144 | [num_genes, num_genes]) 145 | edges_train = edges_train[:, off_diag_idx] 146 | 147 | feat_train = torch.FloatTensor(feat_train) 148 | edges_train = torch.LongTensor(edges_train) 149 | 150 | train_data = TensorDataset(feat_train, edges_train) 151 | 152 | train_data_loader = DataLoader(train_data, batch_size=batch_size) 153 | 154 | return train_data_loader, feat_max, feat_min 155 | 156 | 157 | def load_dataset_train_valid_test(batch_size=1, number_exp=1, number_expstart=0, dims=6): 158 | feat_train = np.load('data/features.npy') 159 | edges_train = np.load('data/edges.npy') 160 | feat_valid = np.load('data/features_valid.npy') 161 | edges_valid = np.load('data/edges_valid.npy') 162 | feat_test = np.load('data/features_test.npy') 163 | edges_test = np.load('data/edges_test.npy') 164 | 165 | assert (feat_train.shape[0] >= number_exp) 166 | assert (feat_train.shape[2] >= dims) 167 | feat_train = feat_train[number_expstart:number_exp, :, 0:dims, :] 168 | edges_train = edges_train[number_expstart:number_exp, :, :] 169 | feat_valid = feat_valid[number_expstart:number_exp, :, 0:dims, :] 170 | edges_valid = edges_valid[number_expstart:number_exp, :, :] 171 | feat_test = feat_test[:, :, 0:dims, :] 172 | 173 | # [num_samples, num_timesteps, num_dims, num_genes] 174 | num_genes = feat_train.shape[3] 175 | 176 | loc_max = feat_train[:, :, 0:3, :].max() 177 | loc_min = feat_train[:, :, 0:3, :].min() 178 | vel_max = feat_train[:, :, 3:6, :].max() 179 | vel_min = feat_train[:, :, 3:6, :].min() 180 | 181 | # Normalize to [-1, 1] 
182 | loc_train = (feat_train[:, :, 0:3, :] - loc_min) * \ 183 | 2 / (loc_max - loc_min) - 1 184 | vel_train = (feat_train[:, :, 3:6, :] - vel_min) * \ 185 | 2 / (vel_max - vel_min) - 1 186 | 187 | loc_valid = (feat_valid[:, :, 0:3, :] - loc_min) * \ 188 | 2 / (loc_max - loc_min) - 1 189 | vel_valid = (feat_valid[:, :, 3:6, :] - vel_min) * \ 190 | 2 / (vel_max - vel_min) - 1 191 | 192 | loc_test = (feat_test[:, :, 0:3, :] - loc_min) * \ 193 | 2 / (loc_max - loc_min) - 1 194 | vel_test = (feat_test[:, :, 3:6, :] - vel_min) * \ 195 | 2 / (vel_max - vel_min) - 1 196 | 197 | feat_train = np.concatenate((loc_train, vel_train), axis=2) 198 | feat_valid = np.concatenate((loc_valid, vel_valid), axis=2) 199 | feat_test = np.concatenate((loc_test, vel_test), axis=2) 200 | 201 | # Reshape to: [num_samples, num_genes, num_timesteps, num_dims] 202 | feat_train = np.transpose(feat_train, [0, 3, 1, 2]) 203 | edges_train = np.reshape(edges_train, [-1, num_genes ** 2]) 204 | edges_train = np.array((edges_train + 1) / 2, dtype=np.int64) 205 | 206 | feat_valid = np.transpose(feat_valid, [0, 3, 1, 2]) 207 | edges_valid = np.reshape(edges_valid, [-1, num_genes ** 2]) 208 | edges_valid = np.array((edges_valid + 1) / 2, dtype=np.int64) 209 | 210 | feat_test = np.transpose(feat_test, [0, 3, 1, 2]) 211 | edges_test = np.reshape(edges_test, [-1, num_genes ** 2]) 212 | edges_test = np.array((edges_test + 1) / 2, dtype=np.int64) 213 | 214 | # Exclude self edges 215 | off_diag_idx = np.ravel_multi_index( 216 | np.where(np.ones((num_genes, num_genes)) - np.eye(num_genes)), 217 | [num_genes, num_genes]) 218 | edges_train = edges_train[:, off_diag_idx] 219 | edges_valid = edges_valid[:, off_diag_idx] 220 | edges_test = edges_test[:, off_diag_idx] 221 | 222 | feat_train = torch.FloatTensor(feat_train) 223 | edges_train = torch.LongTensor(edges_train) 224 | feat_valid = torch.FloatTensor(feat_valid) 225 | edges_valid = torch.LongTensor(edges_valid) 226 | feat_test = torch.FloatTensor(feat_test) 227 
| edges_test = torch.LongTensor(edges_test) 228 | 229 | train_data = TensorDataset(feat_train, edges_train) 230 | valid_data = TensorDataset(feat_valid, edges_valid) 231 | test_data = TensorDataset(feat_test, edges_test) 232 | 233 | train_data_loader = DataLoader(train_data, batch_size=batch_size) 234 | valid_data_loader = DataLoader(valid_data, batch_size=batch_size) 235 | test_data_loader = DataLoader(test_data, batch_size=batch_size) 236 | 237 | return train_data_loader, valid_data_loader, test_data_loader, loc_max, loc_min, vel_max, vel_min 238 | 239 | 240 | def load_dataset_train_test(batch_size=1, number_exp=1, number_expstart=0, dims=6): 241 | feat_train = np.load('data/features.npy') 242 | edges_train = np.load('data/edges.npy') 243 | feat_test = np.load('data/features_test.npy') 244 | edges_test = np.load('data/edges_test.npy') 245 | 246 | assert(feat_train.shape[0] >= number_exp) 247 | assert(feat_train.shape[2] >= dims) 248 | feat_train = feat_train[number_expstart:number_exp, :, 0:dims, :] 249 | edges_train = edges_train[number_expstart:number_exp, :, :] 250 | feat_test = feat_test[:, :, 0:dims, :] 251 | 252 | # [num_samples, num_timesteps, num_dims, num_genes] 253 | num_genes = feat_train.shape[3] 254 | 255 | loc_max = feat_train[:, :, 0:3, :].max() 256 | loc_min = feat_train[:, :, 0:3, :].min() 257 | vel_max = feat_train[:, :, 3:6, :].max() 258 | vel_min = feat_train[:, :, 3:6, :].min() 259 | 260 | # Normalize to [-1, 1] 261 | loc_train = (feat_train[:, :, 0:3, :] - loc_min) * \ 262 | 2 / (loc_max - loc_min) - 1 263 | vel_train = (feat_train[:, :, 3:6, :] - vel_min) * \ 264 | 2 / (vel_max - vel_min) - 1 265 | 266 | loc_test = (feat_test[:, :, 0:3, :] - loc_min) * \ 267 | 2 / (loc_max - loc_min) - 1 268 | vel_test = (feat_test[:, :, 3:6, :] - vel_min) * \ 269 | 2 / (vel_max - vel_min) - 1 270 | 271 | feat_train = np.concatenate((loc_train, vel_train), axis=2) 272 | feat_test = np.concatenate((loc_test, vel_test), axis=2) 273 | 274 | # Reshape to: 
def load_data(batch_size=1, suffix=''):
    """Build train/valid/test ``DataLoader`` objects from ``.npy`` files in ``data/``.

    For every split (``train``, ``valid``, ``test``) the arrays
    ``data/loc_<split><suffix>.npy``, ``data/vel_<split><suffix>.npy`` and
    ``data/edges_<split><suffix>.npy`` are read.  Positions and velocities
    are rescaled to [-1, 1] using the min/max of the *training* arrays only,
    transposed so the atom axis comes second, and concatenated along the last
    axis.  Edge labels are mapped through ``(e + 1) / 2`` and cast to int64,
    flattened per sample, and the self-edge (diagonal) entries are dropped.

    Args:
        batch_size: Batch size handed to every ``DataLoader``.
        suffix: String inserted between the split name and ``.npy``.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader, loc_max, loc_min,
        vel_max, vel_min)``; the four scalars are the training-set statistics
        used for the normalization.
    """
    split_names = ('train', 'valid', 'test')

    # Files are read split by split, and within a split as loc, vel, edges.
    arrays = {}
    for split in split_names:
        arrays[split] = [
            np.load('data/' + kind + '_' + split + suffix + '.npy')
            for kind in ('loc', 'vel', 'edges')
        ]

    train_loc, train_vel, _ = arrays['train']

    # Raw layout: [num_samples, num_timesteps, num_dims, num_atoms]
    num_atoms = train_loc.shape[3]

    loc_max, loc_min = train_loc.max(), train_loc.min()
    vel_max, vel_min = train_vel.max(), train_vel.min()

    # Flat (row-major) positions of the off-diagonal entries, i.e. every
    # edge except the self edges.
    off_diag_idx = np.flatnonzero(
        np.ones((num_atoms, num_atoms)) - np.eye(num_atoms))

    loaders = []
    for split in split_names:
        loc, vel, edges = arrays[split]

        # Normalize to [-1, 1] with the training statistics.
        loc = (loc - loc_min) * 2 / (loc_max - loc_min) - 1
        vel = (vel - vel_min) * 2 / (vel_max - vel_min) - 1

        # Reshape to: [num_sims, num_atoms, num_timesteps, num_dims]
        feat = np.concatenate(
            [np.transpose(loc, [0, 3, 1, 2]),
             np.transpose(vel, [0, 3, 1, 2])],
            axis=3)

        edges = np.reshape(edges, [-1, num_atoms ** 2])
        edges = ((edges + 1) / 2).astype(np.int64)

        dataset = TensorDataset(
            torch.FloatTensor(feat),
            torch.LongTensor(edges)[:, off_diag_idx])
        loaders.append(DataLoader(dataset, batch_size=batch_size))

    train_data_loader, valid_data_loader, test_data_loader = loaders

    return train_data_loader, valid_data_loader, test_data_loader, loc_max, loc_min, vel_max, vel_min
def to_2d_idx(idx, num_cols):
    """Split flat row-major indices into (column, row) index arrays.

    Args:
        idx: Scalar or array-like of flat integer indices.
        num_cols: Number of columns of the 2-D grid being indexed.

    Returns:
        Tuple ``(x_idx, y_idx)`` where ``x_idx = idx % num_cols`` is the
        column and ``y_idx = idx // num_cols`` is the row, with ``y_idx``
        as an int64 array.
    """
    idx = np.array(idx, dtype=np.int64)
    # Integer floor division keeps full int64 precision; going through a
    # float quotient silently corrupts indices above 2**53.
    y_idx = np.array(idx // num_cols, dtype=np.int64)
    x_idx = idx % num_cols
    return x_idx, y_idx
def get_buckets(dist, num_buckets):
    """Group the entries of ``dist`` into ``num_buckets`` equal-width buckets.

    The bucket width is ``(max(dist) - min(dist)) / num_buckets`` and the
    lower thresholds are ``width * [0, 1, ..., num_buckets - 1]``.  Bucket
    ``k`` holds the indices whose value is strictly greater than threshold
    ``k`` and at most threshold ``k + 1``; the last bucket has no upper bound.

    NOTE(review): the thresholds start at 0, not at ``min(dist)``, although
    the width is derived from the min-to-max range -- confirm this is
    intended for inputs whose minimum is not 0.

    Args:
        dist: Tensor of distances (moved to CPU and converted to NumPy).
        num_buckets: Number of buckets to create.

    Returns:
        Tuple ``(bucket_idx, thresholds)``: a list with one index array per
        bucket, and the array of lower thresholds.
    """
    values = dist.cpu().data.numpy()

    lowest = np.min(values)
    highest = np.max(values)
    width = (highest - lowest) / num_buckets
    thresholds = width * np.arange(num_buckets)

    final = num_buckets - 1
    bucket_idx = []
    for k, lower in enumerate(thresholds):
        above = values > lower
        if k == final:
            # Open-ended last bucket: everything above the last threshold.
            members = np.where(above)[0]
        else:
            stacked = np.vstack((above, values <= thresholds[k + 1]))
            members = np.where(np.all(stacked, 0))[0]
        bucket_idx.append(members)

    return bucket_idx, thresholds
def edge_accuracy(preds, target):
    """Return the fraction of edges whose arg-max prediction matches ``target``.

    Args:
        preds: Tensor of per-edge scores; the predicted edge type is the
            arg-max over the last dimension.
        target: Tensor of ground-truth edge types with at least two
            dimensions and as many elements as there are predictions.

    Returns:
        Python ``float``: number of correct predictions divided by
        ``target.size(0) * target.size(1)``.
    """
    _, predicted = preds.max(-1)
    matches = predicted.float().eq(target.float().view_as(predicted))
    correct = matches.cpu().sum()
    # Builtin float(): the ``np.float`` alias was removed in NumPy 1.24.
    return float(correct) / (target.size(0) * target.size(1))
class CNN(nn.Module):
    """Two-layer 1-D conv net that reduces a sequence to one vector per item.

    Two ``Conv1d`` (kernel 5) + ReLU + batch-norm stages, with dropout and
    max-pooling (kernel 2) between them, are followed by two 1x1
    convolutions: one producing ``n_out`` prediction channels and one
    producing a single attention channel.  The output is the mean over the
    last axis of ``prediction * attention``.
    """

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super(CNN, self).__init__()
        # Sub-modules are registered in this order on purpose: init_weights
        # walks them in registration order when drawing random weights.
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv1 = nn.Conv1d(n_in, n_hid, kernel_size=5)
        self.bn1 = nn.BatchNorm1d(n_hid)
        self.conv2 = nn.Conv1d(n_hid, n_hid, kernel_size=5)
        self.bn2 = nn.BatchNorm1d(n_hid)
        self.conv_predict = nn.Conv1d(n_hid, n_out, kernel_size=1)
        self.conv_attention = nn.Conv1d(n_hid, 1, kernel_size=1)
        self.dropout_prob = do_prob

        self.init_weights()

    def init_weights(self):
        """Re-initialise conv weights ~ N(0, sqrt(2 / (k * out))) and reset BN."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv1d):
                fan = module.kernel_size[0] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
                module.bias.data.fill_(0.1)

    def forward(self, inputs):
        # Input shape: [num_sims * num_edges, num_dims, num_timesteps]
        hidden = self.bn1(F.relu(self.conv1(inputs)))
        hidden = F.dropout(hidden, self.dropout_prob, training=self.training)
        hidden = self.pool(hidden)
        hidden = self.bn2(F.relu(self.conv2(hidden)))

        scores = self.conv_predict(hidden)
        # my_softmax is defined elsewhere; presumably a softmax along the
        # given axis (here the last one) -- confirm against its definition.
        weights = my_softmax(self.conv_attention(hidden), axis=2)

        return (scores * weights).mean(dim=2)
121 | receivers = torch.matmul(rel_rec, x) 122 | senders = torch.matmul(rel_send, x) 123 | edges = torch.cat([receivers, senders], dim=2) 124 | return edges 125 | 126 | def forward(self, inputs, rel_rec, rel_send): 127 | # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] 128 | x = inputs.view(inputs.size(0), inputs.size(1), -1) 129 | # New shape: [num_sims, num_atoms, num_timesteps*num_dims] 130 | 131 | x = self.mlp1(x) # 2-layer ELU net per node 132 | 133 | x = self.node2edge(x, rel_rec, rel_send) 134 | x = self.mlp2(x) 135 | x_skip = x 136 | 137 | if self.factor: 138 | x = self.edge2node(x, rel_rec, rel_send) 139 | x = self.mlp3(x) 140 | x = self.node2edge(x, rel_rec, rel_send) 141 | x = torch.cat((x, x_skip), dim=2) # Skip connection 142 | x = self.mlp4(x) 143 | else: 144 | x = self.mlp3(x) 145 | x = torch.cat((x, x_skip), dim=2) # Skip connection 146 | x = self.mlp4(x) 147 | 148 | return self.fc_out(x) 149 | 150 | 151 | class CNNEncoder(nn.Module): 152 | def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): 153 | super(CNNEncoder, self).__init__() 154 | self.dropout_prob = do_prob 155 | 156 | self.factor = factor 157 | 158 | self.cnn = CNN(n_in * 2, n_hid, n_hid, do_prob) 159 | self.mlp1 = MLP(n_hid, n_hid, n_hid, do_prob) 160 | self.mlp2 = MLP(n_hid, n_hid, n_hid, do_prob) 161 | self.mlp3 = MLP(n_hid * 3, n_hid, n_hid, do_prob) 162 | self.fc_out = nn.Linear(n_hid, n_out) 163 | 164 | if self.factor: 165 | print("Using factor graph CNN encoder.") 166 | else: 167 | print("Using CNN encoder.") 168 | 169 | self.init_weights() 170 | 171 | def init_weights(self): 172 | for m in self.modules(): 173 | if isinstance(m, nn.Linear): 174 | nn.init.xavier_normal_(m.weight.data) 175 | m.bias.data.fill_(0.1) 176 | 177 | def node2edge_temporal(self, inputs, rel_rec, rel_send): 178 | # NOTE: Assumes that we have the same graph across all samples. 
class SimulationDecoder(nn.Module):
    """Simulation-based decoder.

    Advances 2-D particle positions/velocities by integrating a fixed
    physical model (springs or charged particles) whose pairwise interaction
    strengths are taken from the supplied edge probabilities.  The physical
    constants are selected from the dataset ``suffix``.
    """

    def __init__(self, loc_max, loc_min, vel_max, vel_min, suffix):
        """Store the normalization statistics and pick simulation constants.

        Args:
            loc_max, loc_min, vel_max, vel_min: Statistics used to map
                between normalized ([-1, 1]) and physical coordinates.
            suffix: Dataset suffix; must contain ``'_springs'``,
                ``'_charged_short'`` or ``'_charged'``.

        Raises:
            ValueError: If the simulation type cannot be inferred from
                ``suffix``.
        """
        super(SimulationDecoder, self).__init__()

        self.loc_max = loc_max
        self.loc_min = loc_min
        self.vel_max = vel_max
        self.vel_min = vel_min

        self.interaction_type = suffix

        # '_charged_short' has to be tested before '_charged': the latter is
        # a substring of the former and would otherwise shadow it.
        if '_springs' in self.interaction_type:
            print('Using spring simulation decoder.')
            self.interaction_strength = .1
            self.sample_freq = 1
            self._delta_T = 0.1
            self.box_size = 5.
        elif '_charged_short' in self.interaction_type:
            print('Using charged particle simulation decoder.')
            self.interaction_strength = .1
            self.sample_freq = 10
            self._delta_T = 0.001
            self.box_size = 1.
        elif '_charged' in self.interaction_type:
            print('Using charged particle simulation decoder.')
            self.interaction_strength = 1.
            self.sample_freq = 100
            self._delta_T = 0.001
            self.box_size = 5.
        else:
            # Without the constants below nothing can be simulated; fail
            # with a clear error instead of an AttributeError on _delta_T.
            raise ValueError(
                "Simulation type could not be inferred from suffix.")

        self.out = None

        # NOTE: For exact reproduction, choose sample_freq=100, delta_T=0.001

        self._max_F = 0.1 / self._delta_T

    def unnormalize(self, loc, vel):
        """Map ``loc``/``vel`` from [-1, 1] back to physical coordinates."""
        loc = 0.5 * (loc + 1) * (self.loc_max - self.loc_min) + self.loc_min
        vel = 0.5 * (vel + 1) * (self.vel_max - self.vel_min) + self.vel_min
        return loc, vel

    def renormalize(self, loc, vel):
        """Map physical ``loc``/``vel`` to [-1, 1] (inverse of unnormalize)."""
        loc = 2 * (loc - self.loc_min) / (self.loc_max - self.loc_min) - 1
        vel = 2 * (vel - self.vel_min) / (self.vel_max - self.vel_min) - 1
        return loc, vel

    def clamp(self, loc, vel):
        """Reflect coordinates outside ``[-box_size, box_size]`` back inside.

        Coordinates beyond a wall are mirrored at that wall and the matching
        velocity component is made to point inwards.  ``loc`` and ``vel`` are
        modified in place and also returned.
        """
        over = loc > self.box_size
        loc[over] = 2 * self.box_size - loc[over]
        vel[over] = -torch.abs(vel[over])

        under = loc < -self.box_size
        loc[under] = -2 * self.box_size - loc[under]
        vel[under] = torch.abs(vel[under])

        return loc, vel

    def set_diag_to_zero(self, x):
        """Hack to set diagonal of a tensor to zero.

        ``x`` is a batch of square matrices ``[batch, n, n]``.
        """
        eye = torch.eye(x.size(1), device=x.device).unsqueeze(0).expand_as(x)
        return (1 - eye) * x

    def set_diag_to_one(self, x):
        """Hack to set diagonal of a tensor to one.

        ``x`` is a batch of square matrices ``[batch, n, n]``.
        """
        eye = torch.eye(x.size(1), device=x.device).unsqueeze(0).expand_as(x)
        return eye + (1 - eye) * x

    def pairwise_sq_dist(self, x):
        """Return squared Euclidean distances between all rows of each batch item."""
        xx = torch.bmm(x, x.transpose(1, 2))
        rx = (x ** 2).sum(2).unsqueeze(-1).expand_as(xx)
        # abs() guards against tiny negative values from round-off.
        return torch.abs(rx.transpose(1, 2) + rx - 2 * xx)

    def forward(self, inputs, relations, rel_rec, rel_send, pred_steps=1):
        """Simulate one sampled step ahead from every input time step.

        Args:
            inputs: Tensor ``[num_sims, num_things, num_timesteps, num_dims]``
                whose last axis holds 2-D position followed by 2-D velocity.
            relations: Edge-type probabilities; only index 1 of the last
                axis is used, and the result is written to the off-diagonal
                entries of the interaction matrix (so presumably
                ``[num_sims, num_things * (num_things - 1), num_types]`` --
                confirm against the caller).
            rel_rec, rel_send, pred_steps: Accepted for interface parity
                with the other decoders; not used here.

        Returns:
            Tensor ``[num_sims, num_things, num_timesteps - 1, num_dims]``.
        """
        num_sims = inputs.size(0)
        num_things = inputs.size(1)
        num_steps = inputs.size(2) - 1

        # Only keep single dimension of softmax output
        relations = relations[:, :, 1]

        loc = inputs[:, :, :-1, :2].contiguous()
        vel = inputs[:, :, :-1, 2:].contiguous()

        # Broadcasting/shape tricks for parallel processing of time steps
        loc = loc.permute(0, 2, 1, 3).contiguous()
        vel = vel.permute(0, 2, 1, 3).contiguous()
        loc = loc.view(num_sims * num_steps, num_things, 2)
        vel = vel.view(num_sims * num_steps, num_things, 2)

        loc, vel = self.unnormalize(loc, vel)

        # Scatter the per-edge values into a dense matrix with empty diagonal.
        offdiag_indices = get_offdiag_indices(num_things).to(inputs.device)
        edges = torch.zeros(relations.size(0), num_things * num_things,
                            device=inputs.device)
        edges[:, offdiag_indices] = relations.float()
        edges = edges.view(relations.size(0), num_things, num_things)

        self.out = []

        for _ in range(0, self.sample_freq):
            x = loc[:, :, 0].unsqueeze(-1)
            y = loc[:, :, 1].unsqueeze(-1)

            xx = x.expand(x.size(0), x.size(1), x.size(1))
            yy = y.expand(y.size(0), y.size(1), y.size(1))
            dist_x = xx - xx.transpose(1, 2)
            dist_y = yy - yy.transpose(1, 2)

            pair_dist = torch.cat(
                (dist_x.unsqueeze(-1), dist_y.unsqueeze(-1)), -1)
            # Tricks for parallel processing of time steps
            pair_dist = pair_dist.view(num_sims, num_steps,
                                       num_things, num_things, 2)

            if '_springs' in self.interaction_type:
                forces_size = -self.interaction_strength * edges
                forces = (
                    forces_size.unsqueeze(-1).unsqueeze(1) * pair_dist
                ).sum(3)
            else:  # charged particle sim
                e = (-1) * (edges * 2 - 1)
                forces_size = -self.interaction_strength * e

                l2_dist_power3 = torch.pow(self.pairwise_sq_dist(loc),
                                           3. / 2.)
                # Diagonal set to one so self-pairs do not divide by zero.
                l2_dist_power3 = self.set_diag_to_one(l2_dist_power3)
                l2_dist_power3 = l2_dist_power3.view(num_sims, num_steps,
                                                     num_things, num_things)
                forces_size = forces_size.unsqueeze(1) / (
                    l2_dist_power3 + _EPS)

                forces = (forces_size.unsqueeze(-1) * pair_dist).sum(3)

            forces = forces.view(num_sims * num_steps, num_things, 2)

            if '_charged' in self.interaction_type:  # charged particle sim
                # Clip forces
                forces[forces > self._max_F] = self._max_F
                forces[forces < -self._max_F] = -self._max_F

            # Leapfrog integration step
            vel = vel + self._delta_T * forces
            loc = loc + self._delta_T * vel

            # Handle box boundaries
            loc, vel = self.clamp(loc, vel)

        loc, vel = self.renormalize(loc, vel)

        loc = loc.view(num_sims, num_steps, num_things, 2)
        vel = vel.view(num_sims, num_steps, num_things, 2)

        loc = loc.permute(0, 2, 1, 3)
        vel = vel.permute(0, 2, 1, 3)

        out = torch.cat((loc, vel), dim=-1)
        # Output has shape: [num_sims, num_things, num_timesteps-1, num_dims]

        return out
self.skip_first_edge_type: 460 | start_idx = 1 461 | else: 462 | start_idx = 0 463 | 464 | # Run separate MLP for every edge type 465 | # NOTE: To exlude one edge type, simply offset range by 1 466 | for i in range(start_idx, len(self.msg_fc2)): 467 | msg = F.relu(self.msg_fc1[i](pre_msg)) 468 | msg = F.dropout(msg, p=self.dropout_prob) 469 | msg = F.relu(self.msg_fc2[i](msg)) 470 | msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] 471 | all_msgs += msg 472 | 473 | # Aggregate all msgs to receiver 474 | agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) 475 | agg_msgs = agg_msgs.contiguous() 476 | 477 | # Skip connection 478 | aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1) 479 | 480 | # Output MLP 481 | pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob) 482 | pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob) 483 | pred = self.out_fc3(pred) 484 | 485 | # Predict position/velocity difference 486 | return single_timestep_inputs + pred 487 | 488 | def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1): 489 | # NOTE: Assumes that we have the same graph across all samples. 490 | 491 | inputs = inputs.transpose(1, 2).contiguous() 492 | 493 | sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1), 494 | rel_type.size(2)] 495 | rel_type = rel_type.unsqueeze(1).expand(sizes) 496 | 497 | time_steps = inputs.size(1) 498 | assert (pred_steps <= time_steps) 499 | preds = [] 500 | 501 | # Only take n-th timesteps as starting points (n: pred_steps) 502 | last_pred = inputs[:, 0::pred_steps, :, :] 503 | curr_rel_type = rel_type[:, 0::pred_steps, :, :] 504 | # NOTE: Assumes rel_type is constant (i.e. same across all time steps). 
class RNNDecoder(nn.Module):
    """Recurrent decoder module.

    Keeps one hidden state per node.  At every time step, messages are
    computed per edge from the hidden states of its two endpoints (one small
    MLP per edge type, weighted by ``rel_type``), aggregated at the
    receiving node, combined with the current input through GRU-style
    gates, and mapped by an output MLP to a delta added to the input.
    """

    def __init__(self, n_in_node, edge_types, n_hid,
                 do_prob=0., skip_first=False):
        """
        Args:
            n_in_node: Number of input features per node.
            edge_types: Number of edge types (one message MLP each).
            n_hid: Size of the hidden state and of all hidden layers.
            do_prob: Dropout probability (applied only in training mode).
            skip_first: If true, edge type 0 sends no messages.
        """
        super(RNNDecoder, self).__init__()
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)])
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(n_hid, n_hid) for _ in range(edge_types)])
        self.msg_out_shape = n_hid
        self.skip_first_edge_type = skip_first

        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(n_in_node, n_hid, bias=True)
        self.input_i = nn.Linear(n_in_node, n_hid, bias=True)
        self.input_n = nn.Linear(n_in_node, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, n_in_node)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

    def single_step_forward(self, inputs, rel_rec, rel_send,
                            rel_type, hidden):
        """Advance the model by one time step.

        Args:
            inputs: Node features for this step, ``[batch, num_atoms, dims]``.
            rel_rec, rel_send: Matrices multiplied with the hidden state to
                pick, per edge, the receiving and sending node.
            rel_type: Per-edge weights, last axis indexed by edge type.
            hidden: Current hidden state, ``[batch, num_atoms, n_hid]``.

        Returns:
            Tuple ``(pred, hidden)``: prediction for the next step and the
            updated hidden state.
        """
        # node2edge
        receivers = torch.matmul(rel_rec, hidden)
        senders = torch.matmul(rel_send, hidden)
        pre_msg = torch.cat([receivers, senders], dim=-1)

        all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1),
                               self.msg_out_shape, device=inputs.device)

        if self.skip_first_edge_type:
            start_idx = 1
            norm = float(len(self.msg_fc2)) - 1.
        else:
            start_idx = 0
            norm = float(len(self.msg_fc2))

        # Run separate MLP for every edge type
        # NOTE: To exclude one edge type, simply offset range by 1
        for i in range(start_idx, len(self.msg_fc2)):
            msg = torch.tanh(self.msg_fc1[i](pre_msg))
            # training=self.training: F.dropout defaults to training=True
            # and would otherwise stay active in eval() mode.
            msg = F.dropout(msg, p=self.dropout_prob,
                            training=self.training)
            msg = torch.tanh(self.msg_fc2[i](msg))
            msg = msg * rel_type[:, :, i:i + 1]
            all_msgs += msg / norm

        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2,
                                                                        -1)
        # NOTE(review): inputs.size(2) is the feature dimension of `inputs`
        # here, not the number of nodes -- confirm this divisor is intended.
        agg_msgs = agg_msgs.contiguous() / inputs.size(2)  # Average

        # GRU-style gated aggregation
        r = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs))
        i = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs))
        n = torch.tanh(self.input_n(inputs) + r * self.hidden_h(agg_msgs))
        hidden = (1 - i) * n + i * hidden

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(hidden)), p=self.dropout_prob,
                         training=self.training)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob,
                         training=self.training)
        pred = self.out_fc3(pred)

        # Predict position/velocity difference
        pred = inputs + pred

        return pred, hidden

    def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1,
                burn_in=False, burn_in_steps=1, dynamic_graph=False,
                encoder=None, temp=None):
        """Roll the decoder over the trajectory.

        Args:
            data: Tensor ``[batch_size, num_atoms, num_timesteps, num_dims]``.
            rel_type: Tensor
                ``[batch_size, num_atoms*(num_atoms-1), num_edge_types]``.
            rel_rec, rel_send: Receiver / sender selection matrices.
            pred_steps: Without ``burn_in``, ground truth is fed every
                ``pred_steps``-th step and the previous prediction otherwise.
            burn_in: If true, ground truth is fed for steps
                ``0..burn_in_steps`` and predictions afterwards.
            burn_in_steps: Length of the burn-in phase.
            dynamic_graph: If true, ``rel_type`` is re-estimated with
                ``encoder`` from the preceding ``burn_in_steps`` frames once
                ``step >= burn_in_steps``.
            encoder: Encoder used when ``dynamic_graph`` is set.
            temp: Temperature passed to ``gumbel_softmax`` in that case.

        Returns:
            Tensor ``[batch_size, num_atoms, num_timesteps - 1, num_dims]``
            of one-step-ahead predictions.
        """
        inputs = data.transpose(1, 2).contiguous()

        time_steps = inputs.size(1)

        # inputs has shape
        # [batch_size, num_timesteps, num_atoms, num_dims]

        # rel_type has shape:
        # [batch_size, num_atoms*(num_atoms-1), num_edge_types]

        hidden = torch.zeros(inputs.size(0), inputs.size(2),
                             self.msg_out_shape, device=inputs.device)

        pred_all = []

        for step in range(0, inputs.size(1) - 1):

            if burn_in:
                if step <= burn_in_steps:
                    ins = inputs[:, step, :, :]
                else:
                    ins = pred_all[step - 1]
            else:
                assert (pred_steps <= time_steps)
                # Use ground truth trajectory input vs. last prediction
                if not step % pred_steps:
                    ins = inputs[:, step, :, :]
                else:
                    ins = pred_all[step - 1]

            if dynamic_graph and step >= burn_in_steps:
                # NOTE: Assumes burn_in_steps = args.timesteps
                logits = encoder(
                    data[:, :, step - burn_in_steps:step, :].contiguous(),
                    rel_rec, rel_send)
                rel_type = gumbel_softmax(logits, tau=temp, hard=True)

            pred, hidden = self.single_step_forward(ins, rel_rec, rel_send,
                                                    rel_type, hidden)
            pred_all.append(pred)

        preds = torch.stack(pred_all, dim=1)

        return preds.transpose(1, 2).contiguous()