├── assets ├── poster.png └── overview.png ├── motion ├── dataset │ ├── motion.pkl │ ├── split.pkl │ ├── motion_run.pkl │ └── split_run.pkl ├── preprocess.py ├── dataset.py └── amc_parser.py ├── start_mocap.sh ├── start_simulation.sh ├── start_md.sh ├── start_eval_mdanalysis.sh ├── config.json ├── LICENSE ├── .gitignore ├── simulation ├── datagen │ ├── run.sh │ ├── generate_dataset_complex.py │ ├── physical_objects.py │ └── system.py └── dataset.py ├── README.md ├── mdanalysis ├── preprocess.py └── dataset.py ├── utils.py ├── eval_mdanalysis.py ├── eval_mocap.py ├── main_simulation.py ├── eval_simulation.py ├── main_mocap.py ├── main_mdanalysis.py └── model ├── eghn.py └── basic.py /assets/poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/assets/poster.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/assets/overview.png -------------------------------------------------------------------------------- /motion/dataset/motion.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/motion/dataset/motion.pkl -------------------------------------------------------------------------------- /motion/dataset/split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/motion/dataset/split.pkl -------------------------------------------------------------------------------- /motion/dataset/motion_run.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/motion/dataset/motion_run.pkl -------------------------------------------------------------------------------- /motion/dataset/split_run.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanjq17/EGHN/HEAD/motion/dataset/split_run.pkl -------------------------------------------------------------------------------- /start_mocap.sh: -------------------------------------------------------------------------------- 1 | log_dir=YOUR_LOG_DIR 2 | mkdir -p $log_dir 3 | python3 -u main_mocap.py --outf $log_dir 2>&1 | tee $log_dir/out.log 4 | 5 | echo "Success" 6 | echo "END" -------------------------------------------------------------------------------- /start_simulation.sh: -------------------------------------------------------------------------------- 1 | log_dir=YOUR_LOG_DIR 2 | mkdir -p $log_dir 3 | python3 -u main_simulation.py --outf $log_dir 2>&1 | tee $log_dir/out.log 4 | 5 | echo "Success" 6 | -------------------------------------------------------------------------------- /start_md.sh: -------------------------------------------------------------------------------- 1 | log_dir=YOUR_LOG_DIR 2 | mkdir -p $log_dir 3 | python3 -u main_mdanalysis.py --outf $log_dir 2>&1 | tee $log_dir/out.log 4 | 5 | echo "Success" 6 | echo "END" -------------------------------------------------------------------------------- /start_eval_mdanalysis.sh: -------------------------------------------------------------------------------- 1 | log_dir=YOUR_LOG_DIR 2 | mkdir -p $log_dir 3 | python3 -u eval_mdanalysis.py --outf $log_dir --model_dir ${MODEL_PATH} 2>&1 | tee $log_dir/out.log 
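# Illustrative usage note (not part of the original script): replace YOUR_LOG_DIR above with a real
# output directory and make MODEL_PATH point at a trained model before running, e.g.
#   MODEL_PATH=logs/md/checkpoint.pt sh start_eval_mdanalysis.sh
# Both the log directory and the checkpoint file name here are hypothetical examples.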
4 | echo "Success" 5 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "md", 3 | "backbone": true, 4 | "batch_size": 8, 5 | "epochs": 15000, 6 | "no_cuda": false, 7 | "seed": 42, 8 | "lr": 0.00005, 9 | "nf": 128, 10 | "model": "hier", 11 | "n_layers": 4, 12 | "data_dir": "YOUR_DATA_DIR", 13 | "weight_decay": 0.0001, 14 | "dropout": 0.5, 15 | "lambda_link": 2, 16 | "interaction_layer": 4, 17 | "pooling_layer": 4, 18 | "decoder_layer": 2, 19 | "n_cluster": 20, 20 | "flat": false, 21 | "n_workers": 20, 22 | "load_cached": true, 23 | "test_rot": false, 24 | "test_trans": false, 25 | "delta_frame": 15 26 | } -------------------------------------------------------------------------------- /motion/preprocess.py: -------------------------------------------------------------------------------- 1 | import amc_parser as amc 2 | from glob import glob 3 | import numpy as np 4 | import pickle as pkl 5 | 6 | 7 | # BASE_DIR = 'data' 8 | # asf_name = '35.asf' 9 | BASE_DIR = '.' 10 | asf_name = '09.asf' 11 | 12 | joints = amc.parse_asf(BASE_DIR + '/' + asf_name) 13 | joints['root'].get_name_to_idx() 14 | edges = joints['root'].output_edges() 15 | print('All edges:', len(edges)) 16 | print(edges) 17 | 18 | all_X = [] 19 | 20 | for amc_name in glob(BASE_DIR + '/*.amc'): 21 | motions = amc.parse_amc(amc_name) 22 | if amc_name.split('.')[-2].split('_')[-1] == '10': 23 | print(amc_name, ' is the special case!!!') 24 | motions = motions[6:] 25 | T = len(motions) 26 | print('Frame:', T) 27 | XX = [] 28 | for i in range(T): 29 | joints['root'].set_motion(motions[i]) 30 | X = joints['root'].output_coord() 31 | XX.append(X) 32 | XX = np.array(XX) 33 | print(XX.shape) 34 | all_X.append(XX) 35 | 36 | with open('motion.pkl', 'wb') as f: 37 | pkl.dump((edges, all_X), f) 38 | 39 | print('Saved to motion.pkl') 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AlexHan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | */.DS_Store 7 | .idea/ 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /simulation/datagen/run.sh: -------------------------------------------------------------------------------- 1 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 5 --n_stick 0 --n_hinge 0 --n_workers 50 2 | # 3 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 1 --n_stick 2 --n_hinge 0 --n_workers 50 4 | # 5 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 2 --n_stick 0 --n_hinge 1 --n_workers 50 6 | # 7 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 10 --n_stick 0 --n_hinge 0 --n_workers 50 8 | # 9 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 4 --n_stick 3 --n_hinge 0 --n_workers 50 10 | # 11 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 2 --n_stick 4 --n_hinge 0 --n_workers 50 12 | # 13 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 0 --n_stick 5 --n_hinge 0 --n_workers 50 14 | # 15 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 7 --n_stick 0 --n_hinge 1 --n_workers 50 16 | # 17 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 4 --n_stick 0 --n_hinge 2 --n_workers 50 18 | # 19 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 1 --n_stick 0 --n_hinge 3 --n_workers 50 20 | # 21 | # 22 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 20 --n_stick 0 --n_hinge 0 --n_workers 50 23 | # 24 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 10 --n_stick 5 --n_hinge 0 --n_workers 50 25 | # 26 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 8 --n_stick 6 --n_hinge 0 --n_workers 50 27 | # 28 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 4 --n_stick 8 --n_hinge 0 --n_workers 50 29 | # 30 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 0 --n_stick 10 --n_hinge 0 --n_workers 50 31 | # 32 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 14 --n_stick 0 --n_hinge 2 --n_workers 50 33 | # 34 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 8 --n_stick 0 --n_hinge 4 --n_workers 50 35 | # 36 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 2 --n_stick 0 --n_hinge 6 --n_workers 50 37 | # 38 | # 39 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 3 --n_stick 2 --n_hinge 1 --n_workers 50 40 | # 41 | #python -u generate_dataset.py --num-train 5000 --seed 43 --n_isolated 5 --n_stick 3 --n_hinge 3 --n_workers 50 42 | # 43 | # 44 | # 45 | 46 | # Small 47 | python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 5 --average_complex_size 3 --n_workers 50 48 | 49 | python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 8 --average_complex_size 5 
--n_workers 50 50 | 51 | python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 10 --average_complex_size 10 --n_workers 50 52 | 53 | # Median 54 | python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 5 --average_complex_size 10 --n_workers 50 55 | 56 | python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 5 --average_complex_size 10 --n_workers 50 57 | 58 | # V100 96 59 | # A100 192 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Equivariant Graph Hierarchy-Based Neural Networks (NeurIPS 2022) 2 | 3 | Jiaqi Han, Wenbing Huang, Tingyang Xu, Yu Rong 4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/hanjq17/EGHN/blob/main/LICENSE) 6 | 7 | [**[Paper]**](https://arxiv.org/pdf/2202.10643.pdf) [**[Poster]**](assets/poster.png) 8 | 9 | Equivariant Graph Hierarchy-Based Neural Networks (EGHNs) are novel graph networks that incorporate automatic hierarchical modeling into equivariant GNNs. The model performs promisingly on various types of complex physical/biochemical systems (e.g., protein dynamics), achieving lower simulation error while also producing visually interpretable cluster assignments. Please refer to our paper for more details. 10 | 11 | ![Overview](assets/overview.png "Overview") 12 | 13 | ## Dependencies 14 | 15 | ``` 16 | python==3.8.10 17 | torch==1.8.0 18 | torch-geometric==2.0.1 19 | scikit-learn==0.24.2 20 | networkx==2.5.1 21 | ``` 22 | You may also need `MDAnalysis` and `MDAnalysisData` if you want to process the protein MD data. 23 | 24 | 25 | ## Data Preparation 26 | 27 | **1. Simulation dataset** 28 | 29 | Under the `simulation/datagen` directory, run the following command: 30 | 31 | ```bash 32 | python -u generate_dataset_complex.py --num-train 5000 --seed 43 --n_complex 5 --average_complex_size 3 --system_types 5 33 | ``` 34 | 35 | where `n_complex` is the number of complexes $M$, `average_complex_size` is the expected size of each complex, and `system_types` indicates the total number of system types. 36 | 37 | **2. Motion capture dataset** 38 | 39 | We provide our pre-processed dataset as well as the splits in the `motion/dataset` folder, which can also be found in the repo of [GMN](https://github.com/hanjq17/GMN). 40 | 41 | **3. Protein MD** 42 | 43 | We provide the data preprocessing code in `mdanalysis/preprocess.py`. One can simply run 44 | 45 | ```bash 46 | python mdanalysis/preprocess.py 47 | ``` 48 | 49 | after setting the correct data path via the `--dir` argument of `preprocess.py`. 50 | 51 | 52 | ## Model Training 53 | 54 | **1. Simulation dataset** 55 | 56 | ```bash 57 | sh start_simulation.sh 58 | ``` 59 | 60 | **2. Motion capture** 61 | 62 | ```bash 63 | sh start_mocap.sh 64 | ``` 65 | 66 | **3. Protein MD** 67 | 68 | ```bash 69 | sh start_md.sh 70 | ``` 71 | 72 | 73 | ## Evaluation 74 | 75 | For Simulation and Motion datasets, the evaluation (validation and testing) is conducted along with training.
For protein MD, we additionally provide an evaluation script: 76 | 77 | **Protein MD** 78 | 79 | ```bash 80 | sh start_eval_mdanalysis.sh 81 | ``` 82 | 83 | ## Citation 84 | Please consider citing our work if you find it useful: 85 | ``` 86 | @inproceedings{ 87 | han2022equivariant, 88 | title={Equivariant Graph Hierarchy-Based Neural Networks}, 89 | author={Jiaqi Han and Wenbing Huang and Tingyang Xu and Yu Rong}, 90 | booktitle={Advances in Neural Information Processing Systems}, 91 | year={2022}, 92 | url={https://openreview.net/forum?id=ywxtmG1nU_6} 93 | } 94 | ``` 95 | 96 | ## Contact 97 | 98 | If you have any questions, feel free to contact me at: 99 | 100 | Jiaqi Han: alexhan99max@gmail.com 101 | -------------------------------------------------------------------------------- /mdanalysis/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from joblib import Parallel, delayed 7 | from scipy.sparse import coo_matrix 8 | 9 | from MDAnalysisData import datasets 10 | import MDAnalysis 11 | from MDAnalysis.analysis import distances 12 | 13 | 14 | def compute_ele(ts, index, cutoff): 15 | edge = coo_matrix(distances.contact_matrix(ts.positions[index], cutoff=cutoff, returntype="sparse")) 16 | edge.setdiag(False) 17 | edge.eliminate_zeros() 18 | edge_global = [torch.tensor(edge.row, dtype=torch.long), torch.tensor(edge.col, dtype=torch.long)] 19 | global_edge_attr = torch.norm(torch.tensor(ts.positions[index[edge.row], :] - ts.positions[index[edge.col], :]), 20 | p=2, dim=1) 21 | return edge_global, global_edge_attr 22 | 23 | 24 | # delta_frame = 50 25 | backbone = True 26 | cut_off = 8 27 | # train_valid_test_ratio = [0.6, 0.2, 0.2] 28 | is_save = True 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--dir', type=str, default='mdanalysis/dataset/') 32 | parser.add_argument('--top_file', type=str, default=None, 33 | help="topology file name in the directory") 34 | parser.add_argument('--traj_file', type=str, default=None, 35 | help="trajectory file name in the directory") 36 | args = parser.parse_args() 37 | 38 | tmp_dir = args.dir 39 | 40 | if args.top_file is not None and args.traj_file is not None: 41 | top_path = os.path.join(args.dir, args.top_file) 42 | traj_path = os.path.join(args.dir, args.traj_file) 43 | data = MDAnalysis.Universe(top_path, traj_path) 44 | else: 45 | print("Warning: No topology or trajectory file given.
Using default adk dataset.") 46 | adk = datasets.fetch_adk_equilibrium(data_home=tmp_dir) 47 | data = MDAnalysis.Universe(adk.topology, adk.trajectory) 48 | if backbone: 49 | ag = data.select_atoms('backbone') 50 | else: 51 | ag = data.atoms 52 | 53 | 54 | charges = torch.tensor(data.atoms[ag.ix].charges) 55 | bonds = np.stack([bond.indices for bond in data.bonds if bond.indices[0] in ag.ix and bond.indices[1] in ag.ix]) 56 | map_dict = {v:k for k,v in enumerate(ag.ix)} 57 | bonds = np.vectorize(map_dict.get)(bonds) 58 | edges = [torch.tensor(bonds[:, 0], dtype=torch.long), 59 | torch.tensor(bonds[:, 1], dtype=torch.long)] 60 | 61 | edge_attr = torch.tensor([bond.length() for bond in data.bonds 62 | if bond.indices[0] in ag.ix and bond.indices[1] in ag.ix]) 63 | 64 | loc = [] 65 | vel = [] 66 | 67 | for i in tqdm(range(len(data.trajectory) - 1)): 68 | loc.append(torch.tensor(data.trajectory[i].positions[ag.ix])) 69 | vel.append(torch.tensor(data.trajectory[i + 1].positions[ag.ix] - data.trajectory[i].positions[ag.ix])) 70 | 71 | if backbone: 72 | save_path = os.path.join(tmp_dir, 'adk_backbone_processed', 'adk.pkl') 73 | os.makedirs(os.path.join(tmp_dir, 'adk_backbone_processed'), exist_ok=True) 74 | else: 75 | save_path = os.path.join(tmp_dir, 'adk_processed', 'adk.pkl') 76 | os.makedirs(os.path.join(tmp_dir, 'adk_processed'), exist_ok=True) 77 | if is_save: 78 | torch.save((edges, edge_attr, charges, len(data.trajectory) - 1), save_path) 79 | 80 | edges_global, edges_global_attr = zip(*Parallel(n_jobs=-1)(delayed(lambda a: compute_ele(a, ag.ix, cut_off))(_) 81 | for _ in tqdm(data.trajectory))) 82 | edges_global = edges_global[:-1] 83 | edges_global_attr = edges_global_attr[:-1] 84 | 85 | 86 | if backbone: 87 | save_path = os.path.join(tmp_dir, 'adk_backbone_processed') 88 | else: 89 | save_path = os.path.join(tmp_dir, 'adk_processed') 90 | 91 | if is_save: 92 | for i in tqdm(range(len(loc))): 93 | try: 94 | torch.save((loc[i], vel[i], edges_global[i], edges_global_attr[i]), 95 | os.path.join(save_path, f'adk_{i}.pkl')) 96 | except RuntimeError: 97 | print(i) 98 | 99 | -------------------------------------------------------------------------------- /motion/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pickle as pkl 4 | import os 5 | 6 | 7 | class MotionDataset(): 8 | """ 9 | Motion Dataset 10 | 11 | """ 12 | 13 | def __init__(self, partition, max_samples, delta_frame, data_dir, case='walk'): 14 | if case == 'walk': 15 | with open(os.path.join(data_dir, 'motion.pkl'), 'rb') as f: 16 | edges, X = pkl.load(f) 17 | elif case == 'run': 18 | with open(os.path.join(data_dir, 'motion_run.pkl'), 'rb') as f: 19 | edges, X = pkl.load(f) 20 | else: 21 | raise RuntimeError('Unknown case') 22 | 23 | V = [] 24 | for i in range(len(X)): 25 | V.append(X[i][1:] - X[i][:-1]) 26 | X[i] = X[i][:-1] 27 | 28 | 29 | N = X[0].shape[1] 30 | 31 | if case == 'walk': 32 | train_case_id = [20, 1, 17, 13, 14, 9, 4, 2, 7, 5, 16] 33 | val_case_id = [3, 8, 11, 12, 15, 18] 34 | test_case_id = [6, 19, 21, 0, 22, 10] 35 | split_dir = os.path.join(data_dir, 'split.pkl') 36 | elif case == 'run': 37 | train_case_id = [1, 2, 5, 6, 10] 38 | val_case_id = [0, 4, 9] 39 | test_case_id = [3, 7, 8] 40 | split_dir = os.path.join(data_dir, 'split_run.pkl') 41 | else: 42 | raise RuntimeError('Unknown case') 43 | 44 | self.partition = partition 45 | 46 | try: 47 | with open(split_dir, 'rb') as f: 48 | print('Got Split!') 49 | split = 
pkl.load(f) 50 | except: 51 | np.random.seed(100) 52 | 53 | # sample 100 for each case 54 | if case == 'walk': 55 | itv = 300 56 | elif case == 'run': 57 | itv = 90 58 | else: 59 | raise RuntimeError('Unknown case') 60 | train_mapping = {} 61 | for i in train_case_id: 62 | # cur_x = X[i][:itv] 63 | sampled = np.random.choice(np.arange(itv), size=80 if case == 'run' else 100, replace=False) 64 | train_mapping[i] = sampled 65 | val_mapping = {} 66 | for i in val_case_id: 67 | # cur_x = X[i][:itv] 68 | sampled = np.random.choice(np.arange(itv), size=80 if case == 'run' else 100, replace=False) 69 | val_mapping[i] = sampled 70 | test_mapping = {} 71 | for i in test_case_id: 72 | # cur_x = X[i][:itv] 73 | sampled = np.random.choice(np.arange(itv), size=80 if case == 'run' else 100, replace=False) 74 | test_mapping[i] = sampled 75 | 76 | with open(split_dir, 'wb') as f: 77 | pkl.dump((train_mapping, val_mapping, test_mapping), f) 78 | 79 | print('Generate and save split!') 80 | split = (train_mapping, val_mapping, test_mapping) 81 | 82 | if partition == 'train': 83 | mapping = split[0] 84 | elif partition == 'val': 85 | mapping = split[1] 86 | elif partition == 'test': 87 | mapping = split[2] 88 | else: 89 | raise NotImplementedError() 90 | 91 | each_len = max_samples // len(mapping) 92 | 93 | x_0, v_0, x_t, v_t = [], [], [], [] 94 | for i in mapping: 95 | st = mapping[i][:each_len] 96 | cur_x_0 = X[i][st] 97 | cur_v_0 = V[i][st] 98 | cur_x_t = X[i][st + delta_frame] 99 | cur_v_t = V[i][st + delta_frame] 100 | x_0.append(cur_x_0) 101 | v_0.append(cur_v_0) 102 | x_t.append(cur_x_t) 103 | v_t.append(cur_v_t) 104 | x_0 = np.concatenate(x_0, axis=0) 105 | v_0 = np.concatenate(v_0, axis=0) 106 | x_t = np.concatenate(x_t, axis=0) 107 | v_t = np.concatenate(v_t, axis=0) 108 | 109 | print('Got {:d} samples!'.format(x_0.shape[0])) 110 | 111 | self.n_node = N 112 | 113 | atom_edges = torch.zeros(N, N).int() 114 | for edge in edges: 115 | atom_edges[edge[0], edge[1]] = 1 116 | atom_edges[edge[1], edge[0]] = 1 117 | 118 | atom_edges2 = atom_edges @ atom_edges 119 | self.atom_edge = atom_edges 120 | self.atom_edge2 = atom_edges2 121 | edge_attr = [] 122 | # Initialize edges and edge_attributes 123 | rows, cols = [], [] 124 | for i in range(N): 125 | for j in range(N): 126 | if i != j: 127 | if self.atom_edge[i][j]: 128 | rows.append(i) 129 | cols.append(j) 130 | edge_attr.append([1]) 131 | elif self.atom_edge2[i][j]: 132 | rows.append(i) 133 | cols.append(j) 134 | edge_attr.append([2]) 135 | else: 136 | pass # TODO: Do we need to add the rest of edges here? 
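        # Note on the loops above: they assemble a directed edge list without self-loops, where
        # edge attribute 1 marks direct skeleton bonds (atom_edge) and attribute 2 marks two-hop
        # neighbours taken from atom_edge2 = atom_edges @ atom_edges; all remaining node pairs
        # are left unconnected, keeping the message-passing graph sparse.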
137 | 138 | edges = [rows, cols] # edges for equivariant message passing 139 | edge_attr = torch.Tensor(np.array(edge_attr)) # [edge, 3] 140 | self.edge_attr = edge_attr # [edge, 3] 141 | self.edges = torch.LongTensor(np.array(edges)) # [2, edge] 142 | 143 | self.x_0, self.v_0, self.x_t, self.v_t = torch.Tensor(x_0), torch.Tensor(v_0), torch.Tensor(x_t), torch.Tensor( 144 | v_t) 145 | mole_idx = np.ones(N) 146 | self.mole_idx = torch.Tensor(mole_idx) # the node feature 147 | 148 | def __getitem__(self, i): 149 | edges = self.edges 150 | edge_attr = self.edge_attr 151 | local_edge_mask = edge_attr[..., -1] == 1 152 | local_edges = edges[..., local_edge_mask] 153 | local_edge_attr = edge_attr[local_edge_mask] 154 | 155 | # add z to node feature 156 | node_fea = self.x_0[i][..., 1].unsqueeze(-1) / 10 157 | 158 | return self.x_0[i], self.v_0[i], edges, edge_attr, local_edges, local_edge_attr, \ 159 | node_fea, self.x_t[i], self.v_t[i] 160 | 161 | def __len__(self): 162 | return len(self.x_0) 163 | 164 | 165 | if __name__ == '__main__': 166 | data = MotionDataset(partition='train', max_samples=500, delta_frame=30, data_dir='') -------------------------------------------------------------------------------- /simulation/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import pickle as pkl 5 | from functools import reduce 6 | 7 | 8 | class SimulationDataset(): 9 | """ 10 | NBodyDataset 11 | 12 | """ 13 | def __init__(self, partition='train', max_samples=1e8, 14 | data_dir='', n_complex=5, average_complex_size=3, system_types=5): 15 | self.partition = partition 16 | self.data_dir = data_dir 17 | self.n_complex = n_complex 18 | self.average_complex_size = average_complex_size 19 | self.system_types = system_types 20 | 21 | if self.partition == 'val': 22 | self.suffix = 'valid' 23 | else: 24 | self.suffix = self.partition 25 | 26 | self.suffix += '_charged{:d}_{:d}_{:d}'.format(n_complex, average_complex_size, system_types) 27 | # self.suffix += '_charged0_0_0_3' 28 | 29 | self.max_samples = int(max_samples) 30 | self.loc, self.vel, self.charges, self.edges, self.cfg = self.load() 31 | # self.data, self.edges, self.cfg = self.load() 32 | 33 | def load(self): 34 | # loc = np.load(self.data_dir + '/' + 'loc_' + self.suffix + '.npy') # [N_SAMPLE, N_FRAME, N_NODE, 3] 35 | # vel = np.load(self.data_dir + '/' + 'vel_' + self.suffix + '.npy') 36 | # charges = np.load(self.data_dir + '/' + 'charges_' + self.suffix + '.npy') 37 | # edges = np.load(self.data_dir + '/' + 'edges_' + self.suffix + '.npy') 38 | 39 | with open(self.data_dir + '/' + 'loc_' + self.suffix + '.pkl', 'rb') as f: # [N_SAMPLE, N_FRAME, N_NODE, 3] 40 | loc = pkl.load(f) 41 | with open(self.data_dir + '/' + 'vel_' + self.suffix + '.pkl', 'rb') as f: 42 | vel = pkl.load(f) 43 | with open(self.data_dir + '/' + 'charges_' + self.suffix + '.pkl', 'rb') as f: 44 | charges = pkl.load(f) 45 | with open(self.data_dir + '/' + 'edges_' + self.suffix + '.pkl', 'rb') as f: 46 | edges = pkl.load(f) 47 | with open(self.data_dir + '/' + 'cfg_' + self.suffix + '.pkl', 'rb') as f: 48 | cfg = pkl.load(f) 49 | 50 | loc = loc[:self.max_samples] 51 | vel = vel[:self.max_samples] 52 | charges = charges[: self.max_samples] 53 | edges = edges[: self.max_samples] 54 | 55 | return loc, vel, charges, edges, cfg 56 | 57 | # loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, charges) 58 | # return (loc, vel, edge_attr, charges), edges, cfg 59 | 
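        # Note: the pickle files loaded above (e.g. 'loc_train_charged{n_complex}_{size}_{types}.pkl')
        # are assumed to be the ones written by simulation/datagen/generate_dataset_complex.py, so
        # data_dir should point at the --path directory used when generating that dataset.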
60 | # def preprocess(self, loc, vel, edges, charges): 61 | # loc, vel = torch.Tensor(loc), torch.Tensor(vel) # remove transpose this time 62 | # n_nodes = loc.size(2) 63 | # loc = loc[:self.max_samples, :, :, :] # limit number of samples 64 | # vel = vel[:self.max_samples, :, :, :] # speed when starting the trajectory 65 | # charges = charges[: self.max_samples] 66 | # # edges: charge_i * charge_j (edge_attr) 67 | # edges = edges[: self.max_samples, ...] # add here for better consistency 68 | # edge_attr = [] 69 | # 70 | # # Initialize edges and edge_attributes 71 | # rows, cols = [], [] 72 | # for i in range(n_nodes): 73 | # for j in range(n_nodes): 74 | # if i != j: # remove self loop 75 | # edge_attr.append(edges[:, i, j]) 76 | # rows.append(i) 77 | # cols.append(j) 78 | # edges = [rows, cols] 79 | # 80 | # # swap n_nodes <--> batch_size and add nf dimension 81 | # edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2) # [B, N*(N-1), 1] 82 | # 83 | # return torch.Tensor(loc), torch.Tensor(vel), torch.Tensor(edge_attr), edges, torch.Tensor(charges) 84 | # 85 | # def set_max_samples(self, max_samples): 86 | # self.max_samples = int(max_samples) 87 | # self.data, self.edges, self.cfg = self.load() 88 | 89 | def __getitem__(self, i): 90 | loc, vel, charges, edges, cfg = self.loc[i], self.vel[i], self.charges[i], self.edges[i], self.cfg[i] 91 | 92 | # frame_0, frame_T = 30, 40 93 | frame_0, frame_T = 10, 25 94 | 95 | edge_attr = [] 96 | n_nodes = loc.shape[1] 97 | # Initialize edges and edge_attributes for interaction forces 98 | rows, cols = [], [] 99 | for i in range(n_nodes): 100 | for j in range(n_nodes): 101 | if i != j: # remove self loop 102 | edge_attr.append(edges[i, j]) 103 | rows.append(i) 104 | cols.append(j) 105 | edges = [rows, cols] 106 | edge_attr = torch.Tensor(edge_attr).unsqueeze(-1) # [N*(N-1), 1] 107 | 108 | assert 'Stick' not in cfg and 'Hinge' not in cfg # Currently, only want to support isolated and complex bodies 109 | # add edge indicator for Complex 110 | stick_ind = torch.zeros_like(edge_attr)[..., -1].unsqueeze(-1) 111 | if 'Complex' in cfg: 112 | configs = cfg['Complex'] 113 | for comp in configs: 114 | # add fully connected graph over the complex body 115 | for _i in range(len(comp)): 116 | for _j in range(len(comp)): 117 | if _i != _j: 118 | idi, idj = comp[_i], comp[_j] 119 | n_node = loc.shape[1] 120 | edge_idx = idi * (n_node - 1) + idj 121 | if idj > idi: 122 | edge_idx -= 1 123 | assert edges[0][edge_idx] == idi and edges[1][edge_idx] == idj 124 | stick_ind[edge_idx] = 1 125 | edge_attr = torch.cat((edge_attr, stick_ind), dim=-1) 126 | 127 | edges = torch.from_numpy(np.array(edges)) 128 | local_edge_mask = edge_attr[..., -1] == 1 129 | 130 | return torch.Tensor(loc[frame_0]), torch.Tensor(vel[frame_0]), edges, edge_attr, local_edge_mask,\ 131 | torch.Tensor(charges), torch.Tensor(loc[frame_T]), torch.Tensor(vel[frame_T]) 132 | 133 | def __len__(self): 134 | return len(self.loc) 135 | 136 | # def get_edges(self, batch_size, n_nodes): 137 | # edges = [torch.LongTensor(self.edges[0]), torch.LongTensor(self.edges[1])] 138 | # if batch_size == 1: 139 | # return edges 140 | # elif batch_size > 1: 141 | # offset = torch.arange(batch_size) * n_nodes 142 | # row = edges[0].unsqueeze(0).repeat(batch_size, 1) 143 | # row = row + offset.unsqueeze(-1).expand_as(row) 144 | # col = edges[1].unsqueeze(0).repeat(batch_size, 1) 145 | # col = col + offset.unsqueeze(-1).expand_as(col) 146 | # edges = [row.reshape(-1), col.reshape(-1)] 147 | # return edges 148 | # 
149 | # @staticmethod 150 | # def get_cfg(batch_size, n_nodes, cfg): 151 | # offset = torch.arange(batch_size) * n_nodes 152 | # for type in cfg: 153 | # index = cfg[type] # [B, n_type, node_per_type] 154 | # cfg[type] = (index + offset.unsqueeze(-1).unsqueeze(-1).expand_as(index)).reshape(-1, index.shape[-1]) 155 | # if type == 'Isolated': 156 | # cfg[type] = cfg[type].squeeze(-1) 157 | # return cfg 158 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from motion.dataset import MotionDataset 5 | 6 | 7 | def collector(batch): 8 | """ 9 | Rebatch the input and padding zeros for loc, vel, loc_end, vel_end. 10 | Add the additional mask (B*N, 1) at the last. 11 | :param batch: 12 | :return: the re_batched list. 13 | """ 14 | re_batch = [[] for _ in range(len(batch[0]))] 15 | for b in batch: 16 | [re_batch[i].append(d) for i, d in enumerate(b)] 17 | 18 | loc, vel, edge_attr, charges, loc_end, vel_end = re_batch[:6] 19 | res = [] 20 | padding = [True, True, False, False, True, True, False, False, False] 21 | for b, p in zip(re_batch[:-1], padding[:len(re_batch) -1]): 22 | res.append(do_padding(b, padding=p)) 23 | mask = generate_mask(loc) 24 | res.append(re_batch[-1]) 25 | res.append(mask) 26 | return res 27 | 28 | 29 | def collector_simulation(batch): 30 | """ 31 | Rebatch the input and padding zeros for loc, vel, loc_end, vel_end. 32 | Add the additional mask (B*N, 1) at the last. 33 | :param batch: 34 | :return: the re_batched list. 35 | """ 36 | re_batch = [[] for _ in range(len(batch[0]))] 37 | for b in batch: 38 | [re_batch[i].append(d) for i, d in enumerate(b)] 39 | 40 | assert len(re_batch) == 8 41 | loc, vel, edges, edge_attr, local_edge_mask, charges, loc_end, vel_end = re_batch 42 | max_size = max([x.size(0) for x in loc]) 43 | node_nums = torch.tensor([x.size(0) for x in loc]) 44 | mask = generate_mask(loc) 45 | loc = _padding(loc, max_size) 46 | vel = _padding(vel, max_size) 47 | edges = _pack_edges(edges, max_size) 48 | edge_attr = torch.cat(edge_attr, dim=0) 49 | local_edge_mask = torch.cat(local_edge_mask, dim=0) 50 | charges = _padding(charges, max_size) 51 | loc_end = _padding(loc_end, max_size) 52 | vel_end = _padding(vel_end, max_size) 53 | return loc, vel, edges, edge_attr, local_edge_mask, charges, loc_end, vel_end, mask, node_nums, max_size 54 | 55 | 56 | def _padding(tensor_list, max_size): 57 | res = [torch.cat([r, torch.zeros([max_size - r.size(0), r.size(1)])]) for r in tensor_list] 58 | res = torch.cat(res, dim=0) 59 | return res 60 | 61 | 62 | def _pack_edges(edge_list, n_node): 63 | for idx, edge in enumerate(edge_list): 64 | edge[0] += idx * n_node 65 | edge[1] += idx * n_node 66 | return torch.cat(edge_list, dim=1) # [2, BM] 67 | 68 | 69 | def do_padding(tensor_list, padding=True): 70 | """ 71 | Pad the input tensor_list ad 72 | :param tensor_list: list(B, tensor[N, *]) 73 | :return: padded tensor [B*max_N, *] 74 | """ 75 | if padding: 76 | max_size = max([x.size(0) for x in tensor_list]) 77 | res = [torch.cat([r, torch.zeros([max_size - r.size(0), r.size(1)])]) for r in tensor_list] 78 | else: 79 | res = tensor_list 80 | res = torch.cat(res, dim=0) 81 | return res 82 | 83 | 84 | def generate_mask(tensor_list): 85 | max_size = max([x.size(0) for x in tensor_list]) 86 | res = [torch.cat([torch.ones([r.size(0)]), torch.zeros([max_size - r.size(0)])]) for r in tensor_list] 87 
| res = torch.cat(res, dim=0) 88 | return res 89 | 90 | 91 | def test_do_padding(): 92 | tensor_list = [torch.ones([2, 3]), torch.zeros([4, 3])] 93 | res = do_padding(tensor_list) 94 | 95 | # tensor([[1., 1., 1.], 96 | # [1., 1., 1.], 97 | # [0., 0., 0.], 98 | # [0., 0., 0.], 99 | # [0., 0., 0.], 100 | # [0., 0., 0.], 101 | # [0., 0., 0.], 102 | # [0., 0., 0.]]) 103 | 104 | 105 | def test_generate_mask(): 106 | tensor_list = [torch.rand([2, 3]), torch.rand([4, 3])] 107 | res = generate_mask(tensor_list) 108 | print(res) 109 | 110 | 111 | def test_collector(): 112 | data_train = MotionDataset(partition='train', max_samples=100, delta_frame=30, data_dir='motion/dataset') 113 | loader_train = torch.utils.data.DataLoader(data_train, batch_size=2, shuffle=True, drop_last=True, 114 | num_workers=1, collate_fn=collector) 115 | for batch_idx, data in enumerate(loader_train): 116 | print(data) 117 | 118 | 119 | class MaskMSELoss(nn.Module): 120 | def __init__(self): 121 | super(MaskMSELoss, self).__init__() 122 | self.loss = nn.MSELoss(reduction="none") 123 | 124 | def forward(self, pred, target, mask, grouped_size=None): 125 | """ 126 | 127 | :param pred: [N, d] 128 | :param target: [N, d] 129 | :param mask: [N, 1] 130 | :param grouped_size: [B, K], B * K = N 131 | :return: 132 | """ 133 | assert grouped_size is None or (mask.size(0) % grouped_size.size(0) == 0) 134 | loss = self.loss(pred, target) 135 | # Looks strange, do I miss something? 136 | loss = (loss.T * mask).T 137 | if grouped_size is not None: 138 | loss = loss.reshape([grouped_size.size(0), -1, pred.size(-1)]) 139 | # average loss by grouped_size on dim=1 140 | loss = torch.sum(loss, dim=1) / grouped_size.unsqueeze(dim=1) 141 | loss = torch.mean(loss) 142 | else: 143 | loss = torch.sum(loss) / (torch.sum(mask) * loss.size(-1)) 144 | return loss 145 | 146 | 147 | def test_MaskMSELoss(): 148 | input = torch.rand([6, 2]) 149 | target = torch.rand([6, 2]) 150 | mask = torch.tensor([1, 0, 1, 0, 1, 1]) 151 | grouped_size = torch.tensor([1, 1, 2]) 152 | loss = MaskMSELoss() 153 | print(loss(input, target, mask, grouped_size)) 154 | print(loss(input, target, mask)) 155 | 156 | 157 | class EarlyStopping: 158 | """Early stops the training if validation loss doesn't improve after a given patience.""" 159 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 160 | """ 161 | Args: 162 | patience (int): How long to wait after last time validation loss improved. 163 | Default: 7 164 | verbose (bool): If True, prints a message for each validation loss improvement. 165 | Default: False 166 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 167 | Default: 0 168 | path (str): Path for the checkpoint to be saved to. 169 | Default: 'checkpoint.pt' 170 | trace_func (function): trace print function. 
171 | Default: print 172 | """ 173 | self.patience = patience 174 | self.verbose = verbose 175 | self.counter = 0 176 | self.best_score = None 177 | self.early_stop = False 178 | self.val_loss_min = np.Inf 179 | self.delta = delta 180 | self.path = path 181 | self.trace_func = trace_func 182 | def __call__(self, val_loss, model, master_worker=True): 183 | 184 | score = -val_loss 185 | 186 | if self.best_score is None: 187 | self.best_score = score 188 | self.save_checkpoint(val_loss, model, master_worker) 189 | elif score < self.best_score + self.delta: 190 | self.counter += 1 191 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 192 | if self.counter >= self.patience: 193 | self.early_stop = True 194 | else: 195 | self.best_score = score 196 | self.save_checkpoint(val_loss, model, master_worker) 197 | self.counter = 0 198 | 199 | def save_checkpoint(self, val_loss, model, master_worker=True): 200 | '''Saves model when validation loss decrease.''' 201 | if not master_worker: 202 | return 203 | if self.verbose and master_worker: 204 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 205 | torch.save(model.state_dict(), self.path) 206 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /motion/amc_parser.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from transforms3d.euler import euler2mat 4 | from mpl_toolkits.mplot3d import Axes3D 5 | 6 | 7 | class Joint: 8 | def __init__(self, name, direction, length, axis, dof, limits): 9 | """ 10 | Definition of basic joint. The joint also contains the information of the 11 | bone between it's parent joint and itself. Refer 12 | [here](https://research.cs.wisc.edu/graphics/Courses/cs-838-1999/Jeff/ASF-AMC.html) 13 | for detailed description for asf files. 14 | Parameter 15 | --------- 16 | name: Name of the joint defined in the asf file. There should always be one 17 | root joint. String. 18 | direction: Default direction of the joint(bone). The motions are all defined 19 | based on this default pose. 20 | length: Length of the bone. 21 | axis: Axis of rotation for the bone. 22 | dof: Degree of freedom. Specifies the number of motion channels and in what 23 | order they appear in the AMC file. 
24 | limits: Limits on each of the channels in the dof specification 25 | """ 26 | self.name = name 27 | self.direction = np.reshape(direction, [3, 1]) 28 | self.length = length 29 | axis = np.deg2rad(axis) 30 | self.C = euler2mat(*axis) 31 | self.Cinv = np.linalg.inv(self.C) 32 | self.limits = np.zeros([3, 2]) 33 | for lm, nm in zip(limits, dof): 34 | if nm == 'rx': 35 | self.limits[0] = lm 36 | elif nm == 'ry': 37 | self.limits[1] = lm 38 | else: 39 | self.limits[2] = lm 40 | self.parent = None 41 | self.children = [] 42 | self.coordinate = None 43 | self.matrix = None 44 | 45 | def set_motion(self, motion): 46 | if self.name == 'root': 47 | self.coordinate = np.reshape(np.array(motion['root'][:3]), [3, 1]) 48 | rotation = np.deg2rad(motion['root'][3:]) 49 | self.matrix = self.C.dot(euler2mat(*rotation)).dot(self.Cinv) 50 | else: 51 | idx = 0 52 | rotation = np.zeros(3) 53 | for axis, lm in enumerate(self.limits): 54 | if not np.array_equal(lm, np.zeros(2)): 55 | rotation[axis] = motion[self.name][idx] 56 | idx += 1 57 | rotation = np.deg2rad(rotation) 58 | self.matrix = self.parent.matrix.dot(self.C).dot(euler2mat(*rotation)).dot(self.Cinv) 59 | self.coordinate = self.parent.coordinate + self.length * self.matrix.dot(self.direction) 60 | for child in self.children: 61 | child.set_motion(motion) 62 | 63 | def get_name_to_idx(self): 64 | joints = self.to_dict() 65 | name_to_idx = {} 66 | for idx, joint in enumerate(joints.values()): 67 | assert joint.name not in name_to_idx 68 | name_to_idx[joint.name] = idx 69 | self.name_to_idx = name_to_idx 70 | 71 | def output_edges(self): 72 | joints = self.to_dict() 73 | name_to_idx = self.name_to_idx 74 | edges = [] 75 | for idx, joint in enumerate(joints.values()): 76 | child = joint 77 | if child.parent is not None: 78 | parent = child.parent 79 | edges.append([name_to_idx[child.name], name_to_idx[parent.name]]) 80 | return edges 81 | 82 | def output_coord(self): 83 | N = len(self.name_to_idx) 84 | X = np.zeros((N, 3)) 85 | joints = self.to_dict() 86 | name_to_idx = self.name_to_idx 87 | for idx, joint in enumerate(joints.values()): 88 | X[name_to_idx[joint.name]] = joint.coordinate.reshape(-1) 89 | return X 90 | 91 | def draw(self): 92 | joints = self.to_dict() 93 | fig = plt.figure() 94 | ax = Axes3D(fig) 95 | 96 | ax.set_xlim3d(-50, 10) 97 | ax.set_ylim3d(-20, 40) 98 | ax.set_zlim3d(-20, 40) 99 | 100 | xs, ys, zs = [], [], [] 101 | for joint in joints.values(): 102 | xs.append(joint.coordinate[0, 0]) 103 | ys.append(joint.coordinate[1, 0]) 104 | zs.append(joint.coordinate[2, 0]) 105 | plt.plot(zs, xs, ys, 'b.') 106 | 107 | for joint in joints.values(): 108 | child = joint 109 | if child.parent is not None: 110 | parent = child.parent 111 | xs = [child.coordinate[0, 0], parent.coordinate[0, 0]] 112 | ys = [child.coordinate[1, 0], parent.coordinate[1, 0]] 113 | zs = [child.coordinate[2, 0], parent.coordinate[2, 0]] 114 | plt.plot(zs, xs, ys, 'r') 115 | plt.show() 116 | 117 | def to_dict(self): 118 | ret = {self.name: self} 119 | for child in self.children: 120 | ret.update(child.to_dict()) 121 | return ret 122 | 123 | def pretty_print(self): 124 | print('===================================') 125 | print('joint: %s' % self.name) 126 | print('direction:') 127 | print(self.direction) 128 | print('limits:', self.limits) 129 | print('parent:', self.parent) 130 | print('children:', self.children) 131 | 132 | 133 | def read_line(stream, idx): 134 | if idx >= len(stream): 135 | return None, idx 136 | line = stream[idx].strip().split() 137 | idx += 1 138 
| return line, idx 139 | 140 | 141 | def parse_asf(file_path): 142 | '''read joint data only''' 143 | with open(file_path) as f: 144 | content = f.read().splitlines() 145 | 146 | for idx, line in enumerate(content): 147 | # meta infomation is ignored 148 | if line == ':bonedata': 149 | content = content[idx + 1:] 150 | break 151 | 152 | # read joints 153 | joints = {'root': Joint('root', np.zeros(3), 0, np.zeros(3), [], [])} 154 | idx = 0 155 | while True: 156 | # the order of each section is hard-coded 157 | 158 | line, idx = read_line(content, idx) 159 | 160 | if line[0] == ':hierarchy': 161 | break 162 | 163 | assert line[0] == 'begin' 164 | 165 | line, idx = read_line(content, idx) 166 | assert line[0] == 'id' 167 | 168 | line, idx = read_line(content, idx) 169 | assert line[0] == 'name' 170 | name = line[1] 171 | 172 | line, idx = read_line(content, idx) 173 | assert line[0] == 'direction' 174 | direction = np.array([float(axis) for axis in line[1:]]) 175 | 176 | # skip length 177 | line, idx = read_line(content, idx) 178 | assert line[0] == 'length' 179 | length = float(line[1]) 180 | 181 | line, idx = read_line(content, idx) 182 | assert line[0] == 'axis' 183 | assert line[4] == 'XYZ' 184 | 185 | axis = np.array([float(axis) for axis in line[1:-1]]) 186 | 187 | dof = [] 188 | limits = [] 189 | 190 | line, idx = read_line(content, idx) 191 | if line[0] == 'dof': 192 | dof = line[1:] 193 | for i in range(len(dof)): 194 | line, idx = read_line(content, idx) 195 | if i == 0: 196 | assert line[0] == 'limits' 197 | line = line[1:] 198 | assert len(line) == 2 199 | mini = float(line[0][1:]) 200 | maxi = float(line[1][:-1]) 201 | limits.append((mini, maxi)) 202 | 203 | line, idx = read_line(content, idx) 204 | 205 | assert line[0] == 'end' 206 | joints[name] = Joint( 207 | name, 208 | direction, 209 | length, 210 | axis, 211 | dof, 212 | limits 213 | ) 214 | 215 | # read hierarchy 216 | assert line[0] == ':hierarchy' 217 | 218 | line, idx = read_line(content, idx) 219 | 220 | assert line[0] == 'begin' 221 | 222 | while True: 223 | line, idx = read_line(content, idx) 224 | if line[0] == 'end': 225 | break 226 | assert len(line) >= 2 227 | for joint_name in line[1:]: 228 | joints[line[0]].children.append(joints[joint_name]) 229 | for nm in line[1:]: 230 | joints[nm].parent = joints[line[0]] 231 | 232 | return joints 233 | 234 | 235 | def parse_amc(file_path): 236 | with open(file_path) as f: 237 | content = f.read().splitlines() 238 | 239 | for idx, line in enumerate(content): 240 | if line == ':DEGREES': 241 | content = content[idx + 1:] 242 | break 243 | 244 | frames = [] 245 | idx = 0 246 | line, idx = read_line(content, idx) 247 | assert line[0].isnumeric(), line 248 | EOF = False 249 | while not EOF: 250 | joint_degree = {} 251 | while True: 252 | line, idx = read_line(content, idx) 253 | if line is None: 254 | EOF = True 255 | break 256 | if line[0].isnumeric(): 257 | break 258 | joint_degree[line[0]] = [float(deg) for deg in line[1:]] 259 | frames.append(joint_degree) 260 | return frames 261 | -------------------------------------------------------------------------------- /simulation/datagen/generate_dataset_complex.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | import argparse 4 | import json 5 | from system import System 6 | from tqdm import tqdm 7 | import os 8 | from joblib import Parallel, delayed 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--simulation', 
type=str, default='charged', 13 | help='What simulation to generate.') 14 | parser.add_argument('--path', type=str, default='data', 15 | help='Path to save.') 16 | parser.add_argument('--num_train', type=int, default=2000, 17 | help='Number of training simulations to generate.') 18 | parser.add_argument('--num_valid', type=int, default=2000, 19 | help='Number of validation simulations to generate.') 20 | parser.add_argument('--num_test', type=int, default=2000, 21 | help='Number of test simulations to generate.') 22 | parser.add_argument('--length', type=int, default=5000, 23 | help='Length of trajectory.') 24 | parser.add_argument('--length_test', type=int, default=5000, 25 | help='Length of test set trajectory.') 26 | parser.add_argument('--sample-freq', type=int, default=100, 27 | help='How often to sample the trajectory.') 28 | parser.add_argument('--n_complex', type=int, default=5, 29 | help='Number of complex in the simulation.') 30 | parser.add_argument('--average_complex_size', type=int, default=3, 31 | help='The expected size of the complex bodies in a system.') 32 | parser.add_argument('--system_types', type=int, default=5, 33 | help="The total number of system types.") 34 | parser.add_argument('--seed', type=int, default=42, 35 | help='Random seed.') 36 | parser.add_argument('--suffix', type=str, default="", 37 | help='add a suffix to the name') 38 | parser.add_argument('--n_workers', type=int, default=1, 39 | help="Number of workers") 40 | parser.add_argument('--box_size', type=float, default=None, 41 | help="The size of the box.") 42 | parser.add_argument("--config_by_file", default=False, action="store_true", ) 43 | 44 | args = parser.parse_args() 45 | 46 | if args.config_by_file: 47 | job_param_path = './job_param.json' 48 | with open(job_param_path, 'r') as f: 49 | hyper_params = json.load(f) 50 | args.num_train = hyper_params["num_train"] 51 | args.num_valid = hyper_params["num_valid"] 52 | args.num_test = hyper_params["num_test"] 53 | args.path = hyper_params["path"] 54 | args.seed = hyper_params["seed"] 55 | args.n_complex = hyper_params["n_complex"] 56 | args.average_complex_size = hyper_params["average_complex_size"] 57 | args.system_types = hyper_params["system_types"] 58 | args.n_workers = hyper_params["n_workers"] 59 | 60 | suffix = '_charged' 61 | 62 | suffix += str(args.n_complex) + '_' + str(args.average_complex_size) + '_' + str(args.system_types) + args.suffix 63 | np.random.seed(args.seed) 64 | 65 | print(suffix) 66 | 67 | 68 | def para_comp(length, sample_freq, all_sizes): 69 | while True: 70 | X, V = [], [] 71 | system_tyes = len(all_sizes) 72 | chosen = np.random.randint(0, system_tyes) 73 | sizes = all_sizes[chosen] 74 | assert len(sizes) == args.n_complex 75 | system = System(n_isolated=0, n_stick=0, n_hinge=0, 76 | n_complex=args.n_complex, complex_sizes=list(sizes), box_size=args.box_size) 77 | 78 | for t in range(length): 79 | system.simulate_one_step() 80 | if t % sample_freq == 0: 81 | X.append(system.X.copy()) 82 | V.append(system.V.copy()) 83 | # X[t // sample_freq] = system.X.copy() 84 | # V[t // sample_freq] = system.V.copy() 85 | system.check() 86 | # if system.is_valid() # currently do not apply constraint 87 | if system.is_valid(): 88 | cfg = system.configuration() 89 | X = np.array(X) 90 | V = np.array(V) 91 | return cfg, X, V, system.edges, system.charges 92 | else: 93 | print('Velocity too large, retry') 94 | 95 | 96 | def generate_dataset(num_sims, length, sample_freq, all_sizes): 97 | results = 
Parallel(n_jobs=args.n_workers)(delayed(para_comp)(length, sample_freq, all_sizes) for i in tqdm(range(num_sims))) 98 | cfg_all, loc_all, vel_all, edges_all, charges_all = zip(*results) 99 | # print(f'total trials: {cnt:d}, samples: {len(loc_all):d}', cnt) 100 | 101 | return loc_all, vel_all, edges_all, charges_all, cfg_all 102 | 103 | 104 | if __name__ == "__main__": 105 | if not os.path.exists(args.path): 106 | os.mkdir(args.path) 107 | 108 | system_types = args.system_types 109 | all_sizes = [] 110 | # rand_sizes = np.random.randint(1, args.average_complex_size * 2, args.n_complex) 111 | targets = np.arange(2, args.average_complex_size * 2 - 1) 112 | weights = targets[::-1] + 4 113 | try: 114 | weights[0] = weights[1] 115 | weights[2] = weights[1] 116 | except: 117 | pass 118 | weights = weights / weights.sum() 119 | for _ in range(system_types): 120 | rand_sizes = np.random.choice(targets, size=args.n_complex, replace=False, p=weights) 121 | all_sizes.append(rand_sizes) 122 | 123 | print('All sizes:') 124 | print(all_sizes) 125 | 126 | # exit(0) 127 | # rand_sizes = [2, 3, 4] 128 | # rand_sizes = [2, 2, 3, 3] 129 | # rand_sizes = [3, 4, 5, 6, 7] 130 | 131 | 132 | print("Generating {} training simulations".format(args.num_train)) 133 | loc_train, vel_train, edges_train, charges_train, cfg_train = generate_dataset(args.num_train, 134 | args.length, 135 | args.sample_freq, 136 | all_sizes) 137 | # np.save(os.path.join(args.path, 'loc_train' + suffix + '.npy'), loc_train) 138 | # np.save(os.path.join(args.path, 'vel_train' + suffix + '.npy'), vel_train) 139 | # np.save(os.path.join(args.path, 'edges_train' + suffix + '.npy'), edges_train) 140 | # np.save(os.path.join(args.path, 'charges_train' + suffix + '.npy'), charges_train) 141 | with open(os.path.join(args.path, 'loc_train' + suffix + '.pkl'), 'wb') as f: 142 | pkl.dump(loc_train, f) 143 | with open(os.path.join(args.path, 'vel_train' + suffix + '.pkl'), 'wb') as f: 144 | pkl.dump(vel_train, f) 145 | with open(os.path.join(args.path, 'edges_train' + suffix + '.pkl'), 'wb') as f: 146 | pkl.dump(edges_train, f) 147 | with open(os.path.join(args.path, 'charges_train' + suffix + '.pkl'), 'wb') as f: 148 | pkl.dump(charges_train, f) 149 | with open(os.path.join(args.path, 'cfg_train' + suffix + '.pkl'), 'wb') as f: 150 | pkl.dump(cfg_train, f) 151 | 152 | 153 | print("Generating {} validation simulations".format(args.num_valid)) 154 | loc_valid, vel_valid, edges_valid, charges_valid, cfg_valid = generate_dataset(args.num_valid, 155 | args.length, 156 | args.sample_freq, 157 | all_sizes) 158 | # np.save(os.path.join(args.path, 'loc_valid' + suffix + '.npy'), loc_valid) 159 | # np.save(os.path.join(args.path, 'vel_valid' + suffix + '.npy'), vel_valid) 160 | # np.save(os.path.join(args.path, 'edges_valid' + suffix + '.npy'), edges_valid) 161 | # np.save(os.path.join(args.path, 'charges_valid' + suffix + '.npy'), charges_valid) 162 | with open(os.path.join(args.path, 'loc_valid' + suffix + '.pkl'), 'wb') as f: 163 | pkl.dump(loc_valid, f) 164 | with open(os.path.join(args.path, 'vel_valid' + suffix + '.pkl'), 'wb') as f: 165 | pkl.dump(vel_valid, f) 166 | with open(os.path.join(args.path, 'edges_valid' + suffix + '.pkl'), 'wb') as f: 167 | pkl.dump(edges_valid, f) 168 | with open(os.path.join(args.path, 'charges_valid' + suffix + '.pkl'), 'wb') as f: 169 | pkl.dump(charges_valid, f) 170 | with open(os.path.join(args.path, 'cfg_valid' + suffix + '.pkl'), 'wb') as f: 171 | pkl.dump(cfg_valid, f) 172 | 173 | print("Generating {} test 
simulations".format(args.num_test)) 174 | loc_test, vel_test, edges_test, charges_test, cfg_test = generate_dataset(args.num_test, 175 | args.length_test, 176 | args.sample_freq, 177 | all_sizes) 178 | # np.save(os.path.join(args.path, 'loc_test' + suffix + '.npy'), loc_test) 179 | # np.save(os.path.join(args.path, 'vel_test' + suffix + '.npy'), vel_test) 180 | # np.save(os.path.join(args.path, 'edges_test' + suffix + '.npy'), edges_test) 181 | # np.save(os.path.join(args.path, 'charges_test' + suffix + '.npy'), charges_test) 182 | with open(os.path.join(args.path, 'loc_test' + suffix + '.pkl'), 'wb') as f: 183 | pkl.dump(loc_test, f) 184 | with open(os.path.join(args.path, 'vel_test' + suffix + '.pkl'), 'wb') as f: 185 | pkl.dump(vel_test, f) 186 | with open(os.path.join(args.path, 'edges_test' + suffix + '.pkl'), 'wb') as f: 187 | pkl.dump(edges_test, f) 188 | with open(os.path.join(args.path, 'charges_test' + suffix + '.pkl'), 'wb') as f: 189 | pkl.dump(charges_test, f) 190 | with open(os.path.join(args.path, 'cfg_test' + suffix + '.pkl'), 'wb') as f: 191 | pkl.dump(cfg_test, f) 192 | print('Finished!') 193 | 194 | # python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 5 --average_complex_size 3 --n_workers 50 195 | # python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 10 --average_complex_size 5 --n_workers 50 196 | # python -u generate_dataset_complex.py --num-train 2000 --seed 43 --n_complex 15 --average_complex_size 8 --n_workers 50 -------------------------------------------------------------------------------- /mdanalysis/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from scipy.sparse import coo_matrix 5 | import torch 6 | from pytorch3d import transforms 7 | from torch.utils.data import Dataset 8 | 9 | from MDAnalysisData import datasets 10 | import MDAnalysis as mda 11 | from MDAnalysis import transformations 12 | from MDAnalysis.analysis import distances 13 | 14 | 15 | class MDAnalysisDataset(Dataset): 16 | """ 17 | NBodyDataset 18 | 19 | """ 20 | 21 | def __init__(self, dataset_name, partition='train', tmp_dir=None, delta_frame=1, train_valid_test_ratio=None, 22 | test_rot=False, test_trans=False, load_cached=False, cut_off=6, backbone=False): 23 | super().__init__() 24 | self.delta_frame = delta_frame 25 | self.dataset = dataset_name 26 | self.partition = partition 27 | self.load_cached = load_cached 28 | self.test_rot = test_rot 29 | self.test_trans = test_trans 30 | self.cut_off = cut_off 31 | self.backbone = backbone 32 | if load_cached: 33 | print(f'Loading {dataset_name} from cached data for {partition}...') 34 | if backbone: 35 | tmp_dir = os.path.join(tmp_dir, 'adk_backbone_processed') 36 | else: 37 | tmp_dir = os.path.join(tmp_dir, 'adk_processed') 38 | self.tmp_dir = tmp_dir 39 | if train_valid_test_ratio is None: 40 | train_valid_test_ratio = [0.6, 0.2, 0.2] 41 | assert sum(train_valid_test_ratio) <= 1 42 | 43 | if load_cached: 44 | edges, self.edge_attr, self.charges, self.n_frames = torch.load(os.path.join(tmp_dir, 45 | f'{dataset_name}.pkl')) 46 | self.edges = torch.stack(edges, dim=0) 47 | self.train_valid_test = [int(train_valid_test_ratio[0] * (self.n_frames - delta_frame)), 48 | int(sum(train_valid_test_ratio[:2]) * (self.n_frames - delta_frame))] 49 | return 50 | 51 | assert not self.backbone, NotImplementedError("Use load_cached for backbone case.") 52 | if dataset_name.lower() == 'adk': 53 | adk = 
datasets.fetch_adk_equilibrium(data_home=tmp_dir) 54 | self.data = mda.Universe(adk.topology, adk.trajectory) 55 | else: 56 | raise NotImplementedError(f'{dataset_name} is not available in MDAnalysisData.') 57 | 58 | # Local Graph information 59 | try: 60 | self.charges = torch.tensor(self.data.atoms.charges) 61 | except OSError: 62 | print(f'Charge error') 63 | try: 64 | self.edges = torch.stack([torch.tensor(self.data.bonds.indices[:, 0], dtype=torch.long), 65 | torch.tensor(self.data.bonds.indices[:, 1], dtype=torch.long)], dim=0) 66 | except OSError: 67 | print(f'edges error') 68 | try: 69 | self.edge_attr = torch.tensor([bond.length() for bond in self.data.bonds]) 70 | except OSError: 71 | print(f'edge_attr error') 72 | 73 | self.train_valid_test = [int(train_valid_test_ratio[0] * (len(self.data.trajectory) - delta_frame)), 74 | int(sum(train_valid_test_ratio[:2]) * (len(self.data.trajectory) - delta_frame))] 75 | 76 | def __getitem__(self, i): 77 | 78 | charges, edges, edge_attr = self.charges, self.edges, self.edge_attr 79 | if len(charges.size()) == 1: 80 | charges = charges.unsqueeze(-1) 81 | if len(edge_attr.size()) == 1: 82 | edge_attr = edge_attr.unsqueeze(-1) 83 | 84 | if self.partition == "valid": 85 | i = i + self.train_valid_test[0] 86 | elif self.partition == "test": 87 | i = i + self.train_valid_test[1] 88 | 89 | # Frames 90 | frame_0, frame_t = i, i + self.delta_frame 91 | 92 | if self.load_cached: 93 | loc_0, vel_0, edge_global, edge_global_attr = torch.load(os.path.join(self.tmp_dir, 94 | f'{self.dataset}_{frame_0}.pkl')) 95 | edge_global = torch.stack(edge_global, dim=0) 96 | if len(edge_global_attr.size()) == 1: 97 | edge_global_attr = edge_global_attr.unsqueeze(-1) 98 | 99 | loc_t, vel_t, _, _ = torch.load(os.path.join(self.tmp_dir, 100 | f'{self.dataset}_{frame_t}.pkl')) 101 | if self.test_rot and self.partition == 'test': 102 | rot = transforms.random_rotation() 103 | loc_0 = torch.tensor(np.matmul(loc_0.detach().numpy(), rot.detach().numpy())) 104 | vel_0 = torch.tensor(np.matmul(vel_0.detach().numpy(), rot.detach().numpy())) 105 | loc_t = torch.tensor(np.matmul(loc_t.detach().numpy(), rot.detach().numpy())) 106 | vel_t = torch.tensor(np.matmul(vel_t.detach().numpy(), rot.detach().numpy())) 107 | if self.test_trans and self.partition == 'test': 108 | dimension = loc_t.max(dim=0)[0] - loc_t.min(dim=0)[0] 109 | trans = torch.randn(3) * dimension / 2 110 | loc_0 += trans 111 | loc_t += trans 112 | return loc_0, vel_0, edge_global, edge_global_attr, edges, edge_attr, charges, loc_t, vel_t 113 | 114 | assert not self.backbone, NotImplementedError("Use load_cached for backbone case.") 115 | 116 | ts_0, ts_t, d, angle, trans = None, None, [0, 0, 1], 0, [0, 0, 0] 117 | # Initial frame 118 | retry_0 = 0 119 | while retry_0 < 10: 120 | try: 121 | ts_0 = self.data.trajectory[frame_0].copy() 122 | if not ts_0.has_velocities: 123 | ts_0.velocities = self.data.trajectory[frame_0 + 1].positions - \ 124 | self.data.trajectory[frame_0].positions 125 | retry_0 = 11 126 | except OSError: 127 | print(f'Reading error at {frame_0}') 128 | retry_0 += 1 129 | assert retry_0 != 10, OSError(f'Falied to read positions by 10 times') 130 | 131 | # Rotations and Translations 132 | if self.test_rot and self.partition == "test": 133 | d = np.random.randn(3) 134 | d = d / np.linalg.norm(d) 135 | angle = random.randint(0, 360) 136 | ts_0 = transformations.rotate.rotateby(angle, direction=d, ag=self.data.atoms)(ts_0) 137 | if self.test_trans and self.partition == 'test': 138 | trans = 
np.random.randn(3) * ts_0.dimensions[:3] / 2 139 | ts_0 = transformations.translate(trans)(ts_0) 140 | loc_0 = torch.tensor(ts_0.positions) 141 | vel_0 = torch.tensor(ts_0.velocities) 142 | 143 | # Global edges 144 | edge_global = coo_matrix(distances.contact_matrix(loc_0.detach().numpy(), 145 | cutoff=self.cut_off, returntype="sparse")) 146 | edge_global.setdiag(False) 147 | edge_global.eliminate_zeros() 148 | edge_global = torch.stack([torch.tensor(edge_global.row, dtype=torch.long), 149 | torch.tensor(edge_global.col, dtype=torch.long)], dim=0) 150 | edge_global_attr = torch.norm(loc_0[edge_global[0], :] - loc_0[edge_global[1], :], p=2, dim=1).unsqueeze(-1) 151 | 152 | # Final frames 153 | retry_t = 0 154 | while retry_t < 10: 155 | try: 156 | ts_t = self.data.trajectory[frame_t].copy() 157 | if not ts_t.has_velocities: 158 | ts_t.velocities = self.data.trajectory[frame_t + 1].positions - \ 159 | self.data.trajectory[frame_t].positions 160 | retry_t = 11 161 | except OSError: 162 | print(f'Reading error at {frame_t} t') 163 | retry_t += 1 164 | assert retry_t != 10, OSError(f'Falied to read velocity by 10 times') 165 | 166 | # Rotations and Translations 167 | if self.test_rot and self.partition == "test": 168 | ts_t = transformations.rotate.rotateby(angle, direction=d, ag=self.data.atoms)(ts_t) 169 | if self.test_trans and self.partition == 'test': 170 | ts_t = transformations.translate(trans)(ts_t) 171 | loc_t = torch.tensor(ts_t.positions) 172 | vel_t = torch.tensor(ts_t.velocities) 173 | 174 | return loc_0, vel_0, edge_global, edge_global_attr, edges, edge_attr, charges, loc_t, vel_t 175 | 176 | def __len__(self): 177 | if self.load_cached: 178 | total_len = max(0, self.n_frames - self.delta_frame) 179 | else: 180 | total_len = max(0, len(self.data.trajectory) - 1 - self.delta_frame) 181 | if self.partition == 'train': 182 | return min(total_len, self.train_valid_test[0]) 183 | if self.partition == 'valid': 184 | return max(0, min(total_len, self.train_valid_test[1]) - self.train_valid_test[0]) 185 | if self.partition == 'test': 186 | return max(0, total_len - self.train_valid_test[1]) 187 | 188 | @staticmethod 189 | def get_cfg(batch_size, n_nodes, cfg): 190 | offset = torch.arange(batch_size) * n_nodes 191 | for type in cfg: 192 | index = cfg[type] # [B, n_type, node_per_type] 193 | cfg[type] = (index + offset.unsqueeze(-1).unsqueeze(-1).expand_as(index)).reshape(-1, index.shape[-1]) 194 | if type == 'Isolated': 195 | cfg[type] = cfg[type].squeeze(-1) 196 | return cfg 197 | 198 | 199 | def collate_mda(data): 200 | loc_0, vel_0, edge_global, edge_global_attr, edges, edge_attr, charges, loc_t, vel_t = zip(*data) 201 | 202 | # edges 203 | offset = torch.cumsum(torch.tensor([0] + [_.size(0) for _ in loc_0], dtype=torch.long), dim=0) 204 | edge_global = torch.cat(list(map(lambda _: _[0] + _[1], zip(edge_global, offset))), dim=-1) 205 | edges = torch.cat(list(map(lambda _: _[0] + _[1], zip(edges, offset))), dim=-1) 206 | edge_global_attr = torch.cat(edge_global_attr, dim=0).type(torch.float) 207 | edge_attr = torch.cat(edge_attr, dim=0).type(torch.float) 208 | 209 | loc_0 = torch.stack(loc_0, dim=0).type(torch.float) 210 | vel_0 = torch.stack(vel_0, dim=0).view(-1, vel_0[0].size(-1)).type(torch.float) 211 | loc_t = torch.stack(loc_t, dim=0).view(-1, loc_t[0].size(-1)).type(torch.float) 212 | vel_t = torch.stack(vel_t, dim=0).view(-1, vel_t[0].size(-1)).type(torch.float) 213 | charges = torch.stack(charges, dim=0).view(-1, charges[0].size(-1)).type(torch.float) 214 | 215 | return loc_0, 
vel_0, edge_global, edge_global_attr, edges, edge_attr, charges, loc_t, vel_t 216 | -------------------------------------------------------------------------------- /eval_mdanalysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | import torch 4 | import torch.utils.data 5 | from tqdm import tqdm 6 | from mdanalysis.dataset import MDAnalysisDataset, collate_mda 7 | from model.eghn import EGHN 8 | import os 9 | from torch import nn, optim 10 | import json 11 | 12 | import random 13 | import numpy as np 14 | 15 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 16 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 17 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 18 | help='input batch size for training (default: 128)') 19 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 20 | help='number of epochs to train (default: 10)') 21 | parser.add_argument('--no-cuda', action='store_true', default=False, 22 | help='enables CUDA training') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', 24 | help='random seed (default: 1)') 25 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 26 | help='how many batches to wait before logging training status') 27 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 28 | help='how many epochs to wait before logging test') 29 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 30 | help='folder to output the json log file') 31 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 32 | help='learning rate') 33 | parser.add_argument('--nf', type=int, default=64, metavar='N', 34 | help='hidden dim') 35 | parser.add_argument('--model', type=str, default='hier', metavar='N') 36 | parser.add_argument('--attention', type=int, default=0, metavar='N', 37 | help='attention in the ae model') 38 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 39 | help='number of layers for the autoencoder') 40 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 41 | help='maximum amount of training samples') 42 | parser.add_argument('--dataset', type=str, default="nbody_small", metavar='N', 43 | help='nbody_small, nbody') 44 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 45 | help='timing experiment') 46 | parser.add_argument('--delta_frame', type=int, default=50, 47 | help='Number of frames delta.') 48 | parser.add_argument('--data_dir', type=str, default='spatial_graph/md17', 49 | help='Data directory.') 50 | parser.add_argument('--model_dir', type=str, default='spatial_graph/md17', 51 | help='Data directory.') 52 | parser.add_argument('--dropout', type=float, default=0.5, 53 | help='Dropout rate (1 - keep probability).') 54 | parser.add_argument("--config_by_file", default=False, action="store_true", ) 55 | 56 | parser.add_argument('--lambda_link', type=float, default=1, 57 | help='The weight of the linkage loss.') 58 | parser.add_argument('--n_cluster', type=int, default=3, 59 | help='The number of clusters.') 60 | parser.add_argument('--flat', action='store_true', default=False, 61 | help='flat MLP') 62 | parser.add_argument('--interaction_layer', type=int, default=3, 63 | help='The number of interaction layers per block.') 64 | parser.add_argument('--pooling_layer', type=int, default=3, 65 | 
help='The number of pooling layers in EGPN.') 66 | parser.add_argument('--decoder_layer', type=int, default=1, 67 | help='The number of decoder layers.') 68 | 69 | parser.add_argument("--n_workers", '-n', type=int, default=8, help="Number of workers.") 70 | parser.add_argument("--load_cached", action="store_true", help="Load cached dataset.") 71 | parser.add_argument("--test_rot", action="store_true", help="Rotate the test") 72 | parser.add_argument("--test_trans", action="store_true", help="Translate the test") 73 | parser.add_argument("--top_k", type=int, default=None, help="Translate the test") 74 | 75 | 76 | time_exp_dic = {'time': 0, 'counter': 0} 77 | 78 | args = parser.parse_args() 79 | 80 | if args.config_by_file: 81 | job_param_path = './job_param.json' 82 | with open(job_param_path, 'r') as f: 83 | hyper_params = json.load(f) 84 | # Only update existing keys 85 | args = vars(args) 86 | args.update((k, v) for k, v in hyper_params.items() if k in args) 87 | args = Namespace(**args) 88 | 89 | # Place the checkpoint file here 90 | ckpt_file = os.path.join(args.model_dir, args.exp_name, 'saved_model.pth') 91 | 92 | args.cuda = not args.no_cuda and torch.cuda.is_available() 93 | 94 | 95 | device = torch.device("cuda" if args.cuda else "cpu") 96 | loss_mse = nn.MSELoss() 97 | loss_all = nn.MSELoss(reduction='none') 98 | 99 | print(args) 100 | 101 | 102 | def main(): 103 | # fix seed 104 | seed = args.seed 105 | random.seed(seed) 106 | np.random.seed(seed) 107 | torch.manual_seed(seed) 108 | torch.cuda.manual_seed(seed) 109 | 110 | dataset_train = MDAnalysisDataset('adk', partition='train', tmp_dir=args.data_dir, 111 | delta_frame=args.delta_frame, load_cached=args.load_cached) 112 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=False, 113 | num_workers=args.n_workers, collate_fn=collate_mda) 114 | 115 | dataset_test = MDAnalysisDataset('adk', partition='test', tmp_dir=args.data_dir, 116 | delta_frame=args.delta_frame, load_cached=args.load_cached, 117 | test_rot=args.test_rot, test_trans=args.test_trans) 118 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False, 119 | num_workers=args.n_workers, collate_fn=collate_mda) 120 | 121 | if args.model == 'hier': 122 | model = EGHN(in_node_nf=2, in_edge_nf=2, hidden_nf=args.nf, device=device, 123 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 124 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), 125 | layer_decoder=args.decoder_layer) 126 | model.load_state_dict(torch.load(ckpt_file)) 127 | print('loaded from ', ckpt_file) 128 | else: 129 | raise Exception("Wrong model specified") 130 | 131 | print(model) 132 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 133 | 134 | model.eval() 135 | save_name = os.path.join(args.outf, f'{args.exp_name}_eval_train.pkl') 136 | train_loss = train(model, optimizer, 0, loader_train, backprop=False, save_name=save_name) 137 | save_name = os.path.join(args.outf, f'{args.exp_name}_eval_test.pkl') 138 | test_loss = train(model, optimizer, 0, loader_test, backprop=False, save_name=save_name) 139 | exit(0) 140 | 141 | return best_train_loss, best_val_loss, best_test_loss, best_epoch 142 | 143 | 144 | def train(model, optimizer, epoch, loader, backprop=True, save_name=None): 145 | all_loc, all_loc_pred, all_loc_end, all_loss = None, None, None, None 146 | all_pooling_plan = None 147 | if backprop: 148 | 
model.train() 149 | else: 150 | model.eval() 151 | 152 | res = {'epoch': epoch, 'loss': 0, 'counter': 0} 153 | 154 | for batch_idx, data in enumerate(tqdm(loader)): 155 | batch_size, n_nodes, _ = data[0].size() 156 | data = [d.to(device) for d in data] 157 | # data = [d.view(-1, d.size(2)) for d in data] # construct mini-batch graphs 158 | loc, vel, edges, edge_attr, local_edge_index, local_edge_fea, Z, loc_end, vel_end = data 159 | # convert into graph minibatch 160 | loc = loc.view(-1, loc.size(2)) 161 | 162 | if all_loc is None: 163 | all_loc = loc.detach().cpu() 164 | else: 165 | all_loc = torch.cat((all_loc, loc.detach().cpu()), dim=0) 166 | 167 | optimizer.zero_grad() 168 | 169 | if args.model == 'hier': 170 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 171 | nodes = torch.cat((nodes, Z / Z.max()), dim=-1) 172 | rows, cols = edges 173 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 174 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 175 | loc_dist1 = torch.sum((loc[local_edge_index[0]] - loc[local_edge_index[1]])**2, 1).unsqueeze(1) 176 | local_edge_fea = torch.cat([local_edge_fea, loc_dist1], 1).detach() # concatenate all edge properties 177 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 178 | n_node=n_nodes, v=vel, node_mask=None, node_nums=None) 179 | else: 180 | raise Exception("Wrong model") 181 | 182 | if all_loc_pred is None: 183 | all_loc_pred = loc_pred.detach().cpu() 184 | else: 185 | all_loc_pred = torch.cat((all_loc_pred, loc_pred.detach().cpu()), dim=0) 186 | 187 | if all_loc_end is None: 188 | all_loc_end = loc_end.detach().cpu() 189 | else: 190 | all_loc_end = torch.cat((all_loc_end, loc_end.detach().cpu()), dim=0) 191 | 192 | cur_pooling_plan = model.current_pooling_plan 193 | if all_pooling_plan is None: 194 | all_pooling_plan = cur_pooling_plan.detach().cpu() 195 | else: 196 | all_pooling_plan = torch.cat((all_pooling_plan, cur_pooling_plan.detach().cpu()), dim=0) 197 | 198 | loss = loss_mse(loc_pred, loc_end) 199 | 200 | if all_loss is None: 201 | all_loss = loss_all(loc_pred, loc_end).sum(dim=1).detach().cpu() 202 | else: 203 | all_loss = torch.cat((all_loss, loss_all(loc_pred, loc_end).sum(dim=1).detach().cpu()), dim=0) 204 | 205 | if backprop: 206 | loss.backward() 207 | optimizer.step() 208 | pass 209 | res['loss'] += loss.item()*batch_size 210 | res['counter'] += batch_size 211 | import pickle as pkl 212 | with open(save_name, 'wb') as f: 213 | pkl.dump((all_loc.numpy(), 214 | all_loc_end.numpy(), 215 | all_loc_pred.numpy(), 216 | all_pooling_plan.numpy(), 217 | all_loss.numpy() 218 | ), f) 219 | print('Saved to ', save_name) 220 | 221 | if not backprop: 222 | prefix = "==> " 223 | else: 224 | prefix = "" 225 | print('%s epoch %d avg loss: %.5f' 226 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'])) 227 | 228 | return res['loss'] / res['counter'] 229 | 230 | 231 | if __name__ == "__main__": 232 | best_train_loss, best_val_loss, best_test_loss, best_epoch = main() 233 | print("best_train = %.6f" % best_train_loss) 234 | print("best_val = %.6f" % best_val_loss) 235 | print("best_test = %.6f" % best_test_loss) 236 | print("best_epoch = %d" % best_epoch) 237 | print("best_train = %.6f, best_val = %.6f, best_test = %.6f, best_epoch = %d" % (best_train_loss, best_val_loss, best_test_loss, best_epoch)) 238 | 239 | 240 | 241 | 242 | 243 | 
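The train() loop above pickles one tuple of NumPy arrays per evaluation pass, in the same order as the pkl.dump call: (all_loc, all_loc_end, all_loc_pred, all_pooling_plan, all_loss). A minimal sketch of how such a dump could be inspected offline is shown below; it is not part of the repository, and the path assumes the default --outf ('exp_results') and --exp_name ('exp_1'), so adjust it to your own run.

import pickle as pkl

# Tuple order follows the pkl.dump call in eval_mdanalysis.py.
with open('exp_results/exp_1_eval_test.pkl', 'rb') as f:
    loc_0, loc_end, loc_pred, pooling_plan, per_node_loss = pkl.load(f)

# Positions are flattened across the batch (one row per node, 3 coordinates).
print('initial positions:', loc_0.shape)
print('ground-truth end positions:', loc_end.shape)
print('predicted end positions:', loc_pred.shape)
print('pooling assignments:', pooling_plan.shape)
# per_node_loss is the squared error summed over the 3 coordinates of each node.
print('mean per-node squared error:', per_node_loss.mean())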
-------------------------------------------------------------------------------- /eval_mocap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.utils.data 4 | from motion.dataset import MotionDataset 5 | from model.eghn import EGHN 6 | import os 7 | from torch import nn, optim 8 | import json 9 | 10 | import random 11 | import numpy as np 12 | 13 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 14 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 15 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 16 | help='input batch size for training (default: 128)') 17 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 18 | help='number of epochs to train (default: 10)') 19 | parser.add_argument('--no-cuda', action='store_true', default=False, 20 | help='enables CUDA training') 21 | parser.add_argument('--seed', type=int, default=1, metavar='S', 22 | help='random seed (default: 1)') 23 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 24 | help='how many batches to wait before logging training status') 25 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 26 | help='how many epochs to wait before logging test') 27 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 28 | help='folder to output the json log file') 29 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 30 | help='learning rate') 31 | parser.add_argument('--nf', type=int, default=64, metavar='N', 32 | help='hidden dim') 33 | parser.add_argument('--model', type=str, default='hier', metavar='N') 34 | parser.add_argument('--attention', type=int, default=0, metavar='N', 35 | help='attention in the ae model') 36 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 37 | help='number of layers for the autoencoder') 38 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 39 | help='maximum amount of training samples') 40 | parser.add_argument('--dataset', type=str, default="nbody_small", metavar='N', 41 | help='nbody_small, nbody') 42 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 43 | help='timing experiment') 44 | parser.add_argument('--delta_frame', type=int, default=50, 45 | help='Number of frames delta.') 46 | parser.add_argument('--data_dir', type=str, default='spatial_graph/md17', 47 | help='Data directory.') 48 | parser.add_argument('--dropout', type=float, default=0.5, 49 | help='Dropout rate (1 - keep probability).') 50 | parser.add_argument("--config_by_file", default=False, action="store_true", ) 51 | 52 | parser.add_argument('--lambda_link', type=float, default=1, 53 | help='The weight of the linkage loss.') 54 | parser.add_argument('--n_cluster', type=int, default=3, 55 | help='The number of clusters.') 56 | parser.add_argument('--flat', action='store_true', default=False, 57 | help='flat MLP') 58 | parser.add_argument('--interaction_layer', type=int, default=3, 59 | help='The number of interaction layers per block.') 60 | parser.add_argument('--pooling_layer', type=int, default=3, 61 | help='The number of pooling layers in EGPN.') 62 | parser.add_argument('--decoder_layer', type=int, default=1, 63 | help='The number of decoder layers.') 64 | 65 | parser.add_argument('--case', type=str, default='walk', 66 | help='The case, walk or run.') 67 | 68 | time_exp_dic = 
{'time': 0, 'counter': 0} 69 | 70 | args = parser.parse_args() 71 | # Place the checkpoint file here 72 | ckpt_file = args.outf + '/' + args.exp_name + '/' + 'saved_model.pth' 73 | 74 | if args.config_by_file: 75 | job_param_path = './job_param.json' 76 | with open(job_param_path, 'r') as f: 77 | hyper_params = json.load(f) 78 | args.exp_name = hyper_params["exp_name"] 79 | args.batch_size = hyper_params["batch_size"] 80 | args.epochs = hyper_params["epochs"] 81 | args.seed = hyper_params["seed"] 82 | args.lr = hyper_params["lr"] 83 | args.nf = hyper_params["nf"] 84 | args.model = hyper_params["model"] 85 | args.n_layers = hyper_params["n_layers"] 86 | args.max_training_samples = hyper_params["max_training_samples"] 87 | # Do not necessary in practice. 88 | args.data_dir = hyper_params["data_dir"] 89 | args.weight_decay = hyper_params["weight_decay"] 90 | 91 | args.dropout = hyper_params["dropout"] 92 | args.lambda_link = hyper_params["lambda_link"] 93 | args.n_cluster = hyper_params["n_cluster"] 94 | args.flat = hyper_params["flat"] 95 | args.interaction_layer = hyper_params["interaction_layer"] 96 | args.pooling_layer = hyper_params["pooling_layer"] 97 | args.decoder_layer = hyper_params["decoder_layer"] 98 | 99 | args.case = hyper_params["case"] 100 | 101 | args.cuda = not args.no_cuda and torch.cuda.is_available() 102 | 103 | 104 | device = torch.device("cuda" if args.cuda else "cpu") 105 | loss_mse = nn.MSELoss() 106 | 107 | print(args) 108 | 109 | 110 | def main(): 111 | # fix seed 112 | seed = args.seed 113 | random.seed(seed) 114 | np.random.seed(seed) 115 | torch.manual_seed(seed) 116 | torch.cuda.manual_seed(seed) 117 | 118 | dataset_train = MotionDataset(partition='train', max_samples=args.max_training_samples, data_dir=args.data_dir, 119 | delta_frame=args.delta_frame, case=args.case) 120 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=False, 121 | num_workers=8) 122 | 123 | dataset_val = MotionDataset(partition='val', max_samples=600, data_dir=args.data_dir, 124 | delta_frame=args.delta_frame, case=args.case) 125 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False, 126 | num_workers=8) 127 | 128 | dataset_test = MotionDataset(partition='test', max_samples=600, data_dir=args.data_dir, 129 | delta_frame=args.delta_frame, case=args.case) 130 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False, 131 | num_workers=8) 132 | 133 | if args.model == 'hier': 134 | model = EGHN(in_node_nf=2, in_edge_nf=2, hidden_nf=args.nf, device=device, 135 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 136 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), 137 | layer_decoder=args.decoder_layer) 138 | model.load_state_dict(torch.load(ckpt_file)) 139 | print('loaded from ', ckpt_file) 140 | else: 141 | raise Exception("Wrong model specified") 142 | 143 | print(model) 144 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 145 | 146 | model.eval() 147 | save_name = args.outf + '/' + args.exp_name + '/' + 'eval_train.pkl' 148 | train_loss = train(model, optimizer, 0, loader_train, backprop=False, save_name=save_name) 149 | save_name = args.outf + '/' + args.exp_name + '/' + 'eval_test.pkl' 150 | test_loss = train(model, optimizer, 0, loader_test, backprop=False, save_name=save_name) 151 | exit(0) 152 | 153 | return best_train_loss, 
best_val_loss, best_test_loss, best_epoch 154 | 155 | 156 | def train(model, optimizer, epoch, loader, backprop=True, save_name=None): 157 | all_loc, all_loc_pred, all_loc_end = None, None, None 158 | all_pooling_plan = None 159 | if backprop: 160 | model.train() 161 | else: 162 | model.eval() 163 | 164 | res = {'epoch': epoch, 'loss': 0, 'counter': 0} 165 | 166 | for batch_idx, data in enumerate(loader): 167 | batch_size, n_nodes, _ = data[0].size() 168 | data = [d.to(device) for d in data] 169 | # data = [d.view(-1, d.size(2)) for d in data] # construct mini-batch graphs 170 | loc, vel, edges, edge_attr, local_edges, local_edge_fea, Z, loc_end, vel_end = data 171 | # convert into graph minibatch 172 | loc = loc.view(-1, loc.size(2)) 173 | vel = vel.view(-1, vel.size(2)) 174 | offset = (torch.arange(batch_size) * n_nodes).unsqueeze(-1).unsqueeze(-1).to(edges.device) 175 | edges = torch.cat(list(edges + offset), dim=-1) # [2, BM] 176 | edge_attr = torch.cat(list(edge_attr), dim=0) # [BM, ] 177 | local_edge_index = torch.cat(list(local_edges + offset), dim=-1) # [2, BM] 178 | local_edge_fea = torch.cat(list(local_edge_fea), dim=0) # [BM, ] 179 | # local_edge_mask = torch.cat(list(local_edge_mask), dim=0) # [BM, ] 180 | Z = Z.view(-1, Z.size(2)) 181 | loc_end = loc_end.view(-1, loc_end.size(2)) 182 | vel_end = vel_end.view(-1, vel_end.size(2)) 183 | 184 | if all_loc is None: 185 | all_loc = loc 186 | else: 187 | all_loc = torch.cat((all_loc, loc), dim=0) 188 | 189 | optimizer.zero_grad() 190 | 191 | if args.model == 'hier': 192 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 193 | nodes = torch.cat((nodes, Z / Z.max()), dim=-1) 194 | rows, cols = edges 195 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 196 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 197 | loc_dist1 = torch.sum((loc[local_edge_index[0]] - loc[local_edge_index[1]])**2, 1).unsqueeze(1) 198 | local_edge_fea = torch.cat([local_edge_fea, loc_dist1], 1).detach() # concatenate all edge properties 199 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 200 | n_node=n_nodes, v=vel, node_mask=None, node_nums=None) 201 | else: 202 | raise Exception("Wrong model") 203 | 204 | if all_loc_pred is None: 205 | all_loc_pred = loc_pred 206 | else: 207 | all_loc_pred = torch.cat((all_loc_pred, loc_pred), dim=0) 208 | 209 | if all_loc_end is None: 210 | all_loc_end = loc_end 211 | else: 212 | all_loc_end = torch.cat((all_loc_end, loc_end), dim=0) 213 | 214 | cur_pooling_plan = model.current_pooling_plan 215 | if all_pooling_plan is None: 216 | all_pooling_plan = cur_pooling_plan 217 | else: 218 | all_pooling_plan = torch.cat((all_pooling_plan, cur_pooling_plan), dim=0) 219 | 220 | loss = loss_mse(loc_pred, loc_end) 221 | 222 | if backprop: 223 | loss.backward() 224 | optimizer.step() 225 | pass 226 | res['loss'] += loss.item()*batch_size 227 | res['counter'] += batch_size 228 | import pickle as pkl 229 | with open(save_name, 'wb') as f: 230 | pkl.dump((all_loc.detach().cpu().numpy(), 231 | all_loc_end.detach().cpu().numpy(), 232 | all_loc_pred.detach().cpu().numpy(), 233 | all_pooling_plan.detach().cpu().numpy() 234 | ), f) 235 | print('Saved to ', save_name) 236 | 237 | if not backprop: 238 | prefix = "==> " 239 | else: 240 | prefix = "" 241 | print('%s epoch %d avg loss: %.5f' 242 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'])) 243 | 244 | return 
res['loss'] / res['counter'] 245 | 246 | 247 | if __name__ == "__main__": 248 | best_train_loss, best_val_loss, best_test_loss, best_epoch = main() 249 | print("best_train = %.6f" % best_train_loss) 250 | print("best_val = %.6f" % best_val_loss) 251 | print("best_test = %.6f" % best_test_loss) 252 | print("best_epoch = %d" % best_epoch) 253 | print("best_train = %.6f, best_val = %.6f, best_test = %.6f, best_epoch = %d" % (best_train_loss, best_val_loss, best_test_loss, best_epoch)) 254 | 255 | 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /main_simulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.utils.data 4 | from simulation.dataset import SimulationDataset 5 | from model.eghn import EGHN 6 | from utils import collector_simulation as collector, MaskMSELoss, EarlyStopping 7 | import os 8 | from torch import nn, optim 9 | import json 10 | 11 | import random 12 | import numpy as np 13 | 14 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 15 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 16 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 17 | help='input batch size for training (default: 128)') 18 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 19 | help='number of epochs to train (default: 10)') 20 | parser.add_argument('--no-cuda', action='store_true', default=False, 21 | help='enables CUDA training') 22 | parser.add_argument('--seed', type=int, default=1, metavar='S', 23 | help='random seed (default: 1)') 24 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 25 | help='how many batches to wait before logging training status') 26 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 27 | help='how many epochs to wait before logging test') 28 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 29 | help='folder to output the json log file') 30 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 31 | help='learning rate') 32 | parser.add_argument('--nf', type=int, default=64, metavar='N', 33 | help='hidden dim') 34 | parser.add_argument('--model', type=str, default='hier', metavar='N') 35 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 36 | help='number of layers for the autoencoder') 37 | parser.add_argument('--max_training_samples', type=int, default=1000, metavar='N', 38 | help='maximum amount of training samples') 39 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 40 | help='timing experiment') 41 | parser.add_argument('--data_dir', type=str, default='spatial_graph/md17', 42 | help='Data directory.') 43 | parser.add_argument('--dropout', type=float, default=0.5, 44 | help='Dropout rate (1 - keep probability).') 45 | parser.add_argument("--config_by_file", default=False, action="store_true", ) 46 | 47 | parser.add_argument('--n_complex', type=int, default=5, 48 | help='Number of complex bodies.') 49 | parser.add_argument('--average_complex_size', type=int, default=3, 50 | help='Average size of complex bodies.') 51 | parser.add_argument('--system_types', type=int, default=5, 52 | help="The total number of system types.") 53 | 54 | parser.add_argument('--lambda_link', type=float, default=1, 55 | help='The weight of the linkage loss.') 56 | parser.add_argument('--n_cluster', 
type=int, default=3, 57 | help='The number of clusters.') 58 | parser.add_argument('--flat', action='store_true', default=False, 59 | help='flat MLP') 60 | parser.add_argument('--interaction_layer', type=int, default=3, 61 | help='The number of interaction layers per block.') 62 | parser.add_argument('--pooling_layer', type=int, default=3, 63 | help='The number of pooling layers in EGPN.') 64 | parser.add_argument('--decoder_layer', type=int, default=1, 65 | help='The number of decoder layers.') 66 | parser.add_argument('--norm', action='store_true', default=False, 67 | help='Use norm in EGNN') 68 | 69 | time_exp_dic = {'time': 0, 'counter': 0} 70 | 71 | 72 | args = parser.parse_args() 73 | if args.config_by_file: 74 | job_param_path = './job_param.json' 75 | with open(job_param_path, 'r') as f: 76 | hyper_params = json.load(f) 77 | args.exp_name = hyper_params["exp_name"] 78 | args.batch_size = hyper_params["batch_size"] 79 | args.epochs = hyper_params["epochs"] 80 | #args.no_cuda = hyper_params["no_cuda"] 81 | args.seed = hyper_params["seed"] 82 | args.lr = hyper_params["lr"] 83 | args.nf = hyper_params["nf"] 84 | args.model = hyper_params["model"] 85 | args.n_layers = hyper_params["n_layers"] 86 | args.max_training_samples = hyper_params["max_training_samples"] 87 | # Do not necessary in practice. 88 | args.data_dir = hyper_params["data_dir"] 89 | args.weight_decay = hyper_params["weight_decay"] 90 | args.dropout = hyper_params["dropout"] 91 | args.n_complex = hyper_params["n_complex"] 92 | args.average_complex_size = hyper_params["average_complex_size"] 93 | args.system_types = hyper_params["system_types"] 94 | args.lambda_link = hyper_params["lambda_link"] 95 | args.n_cluster = hyper_params["n_cluster"] 96 | args.flat = hyper_params["flat"] 97 | args.interaction_layer = hyper_params["interaction_layer"] 98 | args.pooling_layer = hyper_params["pooling_layer"] 99 | args.decoder_layer = hyper_params["decoder_layer"] 100 | args.norm = hyper_params["norm"] 101 | 102 | args.cuda = not args.no_cuda and torch.cuda.is_available() 103 | 104 | 105 | device = torch.device("cuda" if args.cuda else "cpu") 106 | loss_mse = MaskMSELoss() 107 | 108 | print(args) 109 | try: 110 | os.makedirs(args.outf) 111 | except OSError: 112 | pass 113 | 114 | try: 115 | os.makedirs(args.outf + "/" + args.exp_name) 116 | except OSError: 117 | pass 118 | 119 | # torch.autograd.set_detect_anomaly(True) 120 | 121 | def main(): 122 | # fix seed 123 | seed = args.seed 124 | random.seed(seed) 125 | np.random.seed(seed) 126 | torch.manual_seed(seed) 127 | torch.cuda.manual_seed(seed) 128 | 129 | n_complex, average_complex_size, system_types = args.n_complex, args.average_complex_size, args.system_types 130 | 131 | dataset_train = SimulationDataset(partition='train', max_samples=args.max_training_samples, n_complex=n_complex, 132 | average_complex_size=average_complex_size, system_types=system_types, 133 | data_dir=args.data_dir) 134 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True, 135 | num_workers=8, collate_fn=collector) 136 | 137 | dataset_val = SimulationDataset(partition='val', n_complex=n_complex, 138 | average_complex_size=average_complex_size, system_types=system_types, 139 | data_dir=args.data_dir) 140 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, drop_last=False, 141 | num_workers=8, collate_fn=collector) 142 | 143 | dataset_test = SimulationDataset(partition='test', n_complex=n_complex, 144 | 
average_complex_size=average_complex_size, system_types=system_types, 145 | data_dir=args.data_dir) 146 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=True, drop_last=False, 147 | num_workers=8, collate_fn=collector) 148 | 149 | if args.model == 'hier': 150 | model = EGHN(in_node_nf=1, in_edge_nf=2 + 1, hidden_nf=args.nf, device=device, 151 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 152 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), norm=args.norm, 153 | layer_decoder=args.decoder_layer) 154 | else: 155 | raise NotImplementedError('Unknown model:', args.model) 156 | 157 | print(model) 158 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 159 | # 250 epoch no improvement. We will stop. 160 | model_save_path = args.outf + '/' + args.exp_name + '/' + 'saved_model.pth' 161 | early_stopping = EarlyStopping(patience=50, verbose=True, path=model_save_path) 162 | 163 | results = {'eval epoch': [], 'val loss': [], 'test loss': [], 'train loss': []} 164 | best_val_loss = 1e8 165 | best_test_loss = 1e8 166 | best_epoch = 0 167 | best_train_loss = 1e8 168 | for epoch in range(0, args.epochs): 169 | train_loss = train(model, optimizer, epoch, loader_train) 170 | results['train loss'].append(train_loss) 171 | if epoch % args.test_interval == 0: 172 | val_loss = train(model, optimizer, epoch, loader_val, backprop=False) 173 | test_loss = train(model, optimizer, epoch, loader_test, backprop=False) 174 | 175 | results['eval epoch'].append(epoch) 176 | results['val loss'].append(val_loss) 177 | results['test loss'].append(test_loss) 178 | if val_loss < best_val_loss: 179 | best_val_loss = val_loss 180 | best_test_loss = test_loss 181 | best_train_loss = train_loss 182 | best_epoch = epoch 183 | # Save model is move to early stopping. 
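# Note (added comment, not in the original file): checkpointing is delegated to the
# EarlyStopping helper imported from utils.py. early_stopping(val_loss, model) below
# is expected to write saved_model.pth to args.outf/args.exp_name whenever the
# validation loss improves; this is the checkpoint later loaded by eval_simulation.py.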
184 | print("*** Best Val Loss: %.5f \t Best Test Loss: %.5f \t Best epoch %d" 185 | % (best_val_loss, best_test_loss, best_epoch)) 186 | early_stopping(val_loss, model) 187 | if early_stopping.early_stop: 188 | print("Early Stopping.") 189 | break 190 | 191 | json_object = json.dumps(results, indent=4) 192 | with open(args.outf + "/" + args.exp_name + "/loss.json", "w") as outfile: 193 | outfile.write(json_object) 194 | return best_train_loss, best_val_loss, best_test_loss, best_epoch 195 | 196 | 197 | def train(model, optimizer, epoch, loader, backprop=True): 198 | if backprop: 199 | model.train() 200 | else: 201 | model.eval() 202 | 203 | res = {'epoch': epoch, 'loss': 0, 'counter': 0, 'lp_loss': 0} 204 | 205 | for batch_idx, data in enumerate(loader): 206 | data = [d.to(device) for d in data[:-1]] + [data[-1]] 207 | loc, vel, edges, edge_attr, local_edge_mask, charges, loc_end, vel_end, mask, node_nums, n_nodes = data 208 | batch_size = loc.shape[0] // n_nodes 209 | 210 | optimizer.zero_grad() 211 | 212 | if args.model == 'hier': 213 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 214 | rows, cols = edges 215 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 216 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 217 | local_edge_index, local_edge_fea = [edges[0][local_edge_mask], edges[1][local_edge_mask]], edge_attr[ 218 | local_edge_mask] 219 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 220 | n_node=n_nodes, v=vel, node_mask=mask, node_nums=node_nums) 221 | else: 222 | raise Exception("Wrong model") 223 | 224 | loss = loss_mse(loc_pred, loc_end, mask) 225 | # loss = loss_mse(loc_pred, loc_end) 226 | if args.model == 'hier': 227 | lp_loss = model.cut_loss 228 | res['lp_loss'] += lp_loss.item() * batch_size 229 | 230 | if backprop: 231 | # link prediction loss 232 | if args.model == 'hier': 233 | _lambda = args.lambda_link 234 | (loss + _lambda * lp_loss).backward() 235 | else: 236 | loss.backward() 237 | optimizer.step() 238 | res['loss'] += loss.item() * batch_size 239 | res['counter'] += batch_size 240 | 241 | # check the current pooling distribution 242 | if args.model == 'hier': 243 | model.inspect_pooling_plan() 244 | 245 | if not backprop: 246 | prefix = "==> " 247 | else: 248 | prefix = "" 249 | print('%s epoch %d avg loss: %.5f avg lploss: %.5f' 250 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'], res['lp_loss'] / res['counter'])) 251 | 252 | return res['loss'] / res['counter'] 253 | 254 | 255 | if __name__ == "__main__": 256 | best_train_loss, best_val_loss, best_test_loss, best_epoch = main() 257 | print("best_train = %.6f" % best_train_loss) 258 | print("best_val = %.6f" % best_val_loss) 259 | print("best_test = %.6f" % best_test_loss) 260 | print("best_epoch = %d" % best_epoch) 261 | print("best_train = %.6f, best_val = %.6f, best_test = %.6f, best_epoch = %d" 262 | % (best_train_loss, best_val_loss, best_test_loss, best_epoch)) 263 | 264 | -------------------------------------------------------------------------------- /eval_simulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.utils.data 4 | from motion.dataset import MotionDataset 5 | from simulation.dataset import SimulationDataset 6 | from model.eghn import EGHN 7 | from utils import collector_simulation as collector, MaskMSELoss 8 
| from tqdm import tqdm 9 | import os 10 | from torch import nn, optim 11 | import json 12 | 13 | import random 14 | import numpy as np 15 | 16 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 17 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 18 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 19 | help='input batch size for training (default: 128)') 20 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 21 | help='number of epochs to train (default: 10)') 22 | parser.add_argument('--no-cuda', action='store_true', default=False, 23 | help='enables CUDA training') 24 | parser.add_argument('--seed', type=int, default=1, metavar='S', 25 | help='random seed (default: 1)') 26 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 27 | help='how many batches to wait before logging training status') 28 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 29 | help='how many epochs to wait before logging test') 30 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 31 | help='folder to output the json log file') 32 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 33 | help='learning rate') 34 | parser.add_argument('--nf', type=int, default=64, metavar='N', 35 | help='hidden dim') 36 | parser.add_argument('--model', type=str, default='hier', metavar='N') 37 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 38 | help='number of layers for the autoencoder') 39 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 40 | help='maximum amount of training samples') 41 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 42 | help='timing experiment') 43 | parser.add_argument('--data_dir', type=str, default='spatial_graph/md17', 44 | help='Data directory.') 45 | parser.add_argument('--dropout', type=float, default=0.5, 46 | help='Dropout rate (1 - keep probability).') 47 | parser.add_argument("--config_by_file", default=False, action="store_true", ) 48 | 49 | parser.add_argument('--n_complex', type=int, default=5, 50 | help='Number of complex bodies.') 51 | parser.add_argument('--average_complex_size', type=int, default=3, 52 | help='Average size of complex bodies.') 53 | parser.add_argument('--system_types', type=int, default=5, 54 | help="The total number of system types.") 55 | 56 | parser.add_argument('--lambda_link', type=float, default=1, 57 | help='The weight of the linkage loss.') 58 | parser.add_argument('--n_cluster', type=int, default=3, 59 | help='The number of clusters.') 60 | parser.add_argument('--flat', action='store_true', default=False, 61 | help='flat MLP') 62 | parser.add_argument('--interaction_layer', type=int, default=3, 63 | help='The number of interaction layers per block.') 64 | parser.add_argument('--pooling_layer', type=int, default=3, 65 | help='The number of pooling layers in EGPN.') 66 | parser.add_argument('--decoder_layer', type=int, default=1, 67 | help='The number of decoder layers.') 68 | parser.add_argument('--norm', action='store_true', default=False, 69 | help='Use norm in EGNN') 70 | 71 | time_exp_dic = {'time': 0, 'counter': 0} 72 | 73 | args = parser.parse_args() 74 | # Place the checkpoint file here 75 | ckpt_file = args.outf + '/' + args.exp_name + '/' + 'saved_model.pth' 76 | 77 | if args.config_by_file: 78 | job_param_path = './job_param.json' 79 | with open(job_param_path, 'r') as f: 
80 | hyper_params = json.load(f) 81 | args.exp_name = hyper_params["exp_name"] 82 | args.batch_size = hyper_params["batch_size"] 83 | args.epochs = hyper_params["epochs"] 84 | #args.no_cuda = hyper_params["no_cuda"] 85 | args.seed = hyper_params["seed"] 86 | args.lr = hyper_params["lr"] 87 | args.nf = hyper_params["nf"] 88 | args.model = hyper_params["model"] 89 | args.n_layers = hyper_params["n_layers"] 90 | args.max_training_samples = hyper_params["max_training_samples"] 91 | # Do not necessary in practice. 92 | args.data_dir = hyper_params["data_dir"] 93 | args.weight_decay = hyper_params["weight_decay"] 94 | 95 | args.dropout = hyper_params["dropout"] 96 | args.n_complex = hyper_params["n_complex"] 97 | args.average_complex_size = hyper_params["average_complex_size"] 98 | args.system_types = hyper_params["system_types"] 99 | 100 | args.lambda_link = hyper_params["lambda_link"] 101 | args.n_cluster = hyper_params["n_cluster"] 102 | args.flat = hyper_params["flat"] 103 | args.interaction_layer = hyper_params["interaction_layer"] 104 | args.pooling_layer = hyper_params["pooling_layer"] 105 | args.decoder_layer = hyper_params["decoder_layer"] 106 | args.norm = hyper_params["norm"] 107 | 108 | args.cuda = not args.no_cuda and torch.cuda.is_available() 109 | 110 | 111 | device = torch.device("cuda" if args.cuda else "cpu") 112 | loss_mse = MaskMSELoss() 113 | 114 | print(args) 115 | # torch.autograd.set_detect_anomaly(True) 116 | 117 | def main(): 118 | # fix seed 119 | seed = args.seed 120 | random.seed(seed) 121 | np.random.seed(seed) 122 | torch.manual_seed(seed) 123 | torch.cuda.manual_seed(seed) 124 | 125 | n_complex, average_complex_size, system_types = args.n_complex, args.average_complex_size, args.system_types 126 | 127 | dataset_train = SimulationDataset(partition='train', max_samples=args.max_training_samples, n_complex=n_complex, 128 | average_complex_size=average_complex_size, system_types=system_types, 129 | data_dir=args.data_dir) 130 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=False, 131 | num_workers=8, collate_fn=collector) 132 | 133 | dataset_val = SimulationDataset(partition='val', n_complex=n_complex, 134 | average_complex_size=average_complex_size, system_types=system_types, 135 | data_dir=args.data_dir) 136 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, drop_last=False, 137 | num_workers=8, collate_fn=collector) 138 | 139 | dataset_test = SimulationDataset(partition='test', n_complex=n_complex, 140 | average_complex_size=average_complex_size, system_types=system_types, 141 | data_dir=args.data_dir) 142 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=True, drop_last=False, 143 | num_workers=8, collate_fn=collector) 144 | 145 | if args.model == 'hier': 146 | model = EGHN(in_node_nf=1, in_edge_nf=2 + 1, hidden_nf=args.nf, device=device, 147 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 148 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), norm=args.norm, 149 | layer_decoder=args.decoder_layer) 150 | model.load_state_dict(torch.load(ckpt_file)) 151 | print('loaded from ', ckpt_file) 152 | else: 153 | raise NotImplementedError('Unknown model:', args.model) 154 | 155 | print(model) 156 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 157 | 158 | model.eval() 159 | save_name = args.outf + '/' + args.exp_name + '/' + 
'eval_train.pkl' 160 | train_loss = train(model, optimizer, 0, loader_train, backprop=False, save_name=save_name) 161 | save_name = args.outf + '/' + args.exp_name + '/' + 'eval_test.pkl' 162 | test_loss = train(model, optimizer, 0, loader_test, backprop=False, save_name=save_name) 163 | exit(0) 164 | 165 | return best_train_loss, best_val_loss, best_test_loss, best_epoch 166 | 167 | 168 | def train(model, optimizer, epoch, loader, backprop=True, save_name=None): 169 | all_loc, all_loc_pred, all_loc_end = [], [], [] 170 | all_mask = [] 171 | all_local_edges = [] 172 | all_pooling_plan = [] 173 | 174 | if backprop: 175 | model.train() 176 | else: 177 | model.eval() 178 | 179 | res = {'epoch': epoch, 'loss': 0, 'counter': 0, 'lp_loss': 0} 180 | 181 | for batch_idx, data in tqdm(enumerate(loader)): 182 | data = [d.to(device) for d in data[:-1]] + [data[-1]] 183 | loc, vel, edges, edge_attr, local_edge_mask, charges, loc_end, vel_end, mask, node_nums, n_nodes = data 184 | batch_size = loc.shape[0] // n_nodes 185 | 186 | all_loc.extend(list(loc.reshape(batch_size, n_nodes, -1).detach().cpu().numpy())) 187 | all_loc_end.extend(list(loc_end.reshape(batch_size, n_nodes, -1).detach().cpu().numpy())) 188 | all_mask.extend(list(mask.reshape(batch_size, n_nodes, -1).bool().detach().cpu().numpy())) 189 | 190 | local_edges = [edges[0][local_edge_mask], edges[1][local_edge_mask]] 191 | new_local_edges = [[] for _ in range(batch_size)] 192 | for i in range(len(local_edges[0])): 193 | cur_row, cur_col = local_edges[0][i], local_edges[1][i] 194 | idx = cur_row // n_nodes 195 | new_local_edges[idx].append((cur_row - idx * n_nodes, cur_col - idx * n_nodes)) 196 | all_local_edges.extend(new_local_edges) 197 | 198 | optimizer.zero_grad() 199 | 200 | if args.model == 'hier': 201 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 202 | rows, cols = edges 203 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 204 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 205 | local_edge_index, local_edge_fea = [edges[0][local_edge_mask], edges[1][local_edge_mask]], edge_attr[ 206 | local_edge_mask] 207 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 208 | n_node=n_nodes, v=vel, node_mask=mask, node_nums=node_nums) 209 | else: 210 | raise Exception("Wrong model") 211 | 212 | all_loc_pred.extend(list(loc_pred.reshape(batch_size, n_nodes, -1).detach().cpu().numpy())) 213 | 214 | cur_pooling_plan = model.current_pooling_plan 215 | all_pooling_plan.extend(list(cur_pooling_plan.reshape(batch_size, n_nodes, -1).detach().cpu().numpy())) 216 | 217 | loss = loss_mse(loc_pred, loc_end, mask) 218 | if args.model == 'hier': 219 | lp_loss = model.cut_loss 220 | # lp_loss = model.link_prediction_loss 221 | res['lp_loss'] += lp_loss.item() * batch_size 222 | 223 | if backprop: 224 | loss.backward() 225 | optimizer.step() 226 | res['loss'] += loss.item()*batch_size 227 | res['counter'] += batch_size 228 | 229 | import pickle as pkl 230 | with open(save_name, 'wb') as f: 231 | pkl.dump((all_loc, all_loc_end, all_loc_pred, all_pooling_plan, all_local_edges, all_mask), f) 232 | # pkl.dump((all_loc.detach().cpu().numpy(), 233 | # all_loc_end.detach().cpu().numpy(), 234 | # all_loc_pred.detach().cpu().numpy(), 235 | # all_pooling_plan.detach().cpu().numpy(), 236 | # all_cfg 237 | # ), f) 238 | print('Saved to ', save_name) 239 | 240 | if not backprop: 241 | prefix = "==> " 242 | 
else: 243 | prefix = "" 244 | print('%s epoch %d avg loss: %.5f' 245 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'],)) 246 | 247 | return res['loss'] / res['counter'] 248 | 249 | 250 | if __name__ == "__main__": 251 | best_train_loss, best_val_loss, best_test_loss, best_epoch = main() 252 | print("best_train = %.6f" % best_train_loss) 253 | print("best_val = %.6f" % best_val_loss) 254 | print("best_test = %.6f" % best_test_loss) 255 | print("best_apoch = %d" % best_epoch) 256 | print("best_train = %.6f, best_val = %.6f, best_test = %.6f, best_apoch = %d" 257 | % (best_train_loss, best_val_loss, best_test_loss, best_epoch)) 258 | 259 | -------------------------------------------------------------------------------- /main_mocap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.utils.data 4 | from motion.dataset import MotionDataset 5 | from model.eghn import EGHN 6 | import os 7 | from torch import nn, optim 8 | import json 9 | 10 | import random 11 | import numpy as np 12 | 13 | from utils import EarlyStopping 14 | 15 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 16 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 17 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 18 | help='input batch size for training (default: 128)') 19 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 20 | help='number of epochs to train (default: 10)') 21 | parser.add_argument('--no-cuda', action='store_true', default=False, 22 | help='enables CUDA training') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', 24 | help='random seed (default: 1)') 25 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 26 | help='how many batches to wait before logging training status') 27 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 28 | help='how many epochs to wait before logging test') 29 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 30 | help='folder to output the json log file') 31 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 32 | help='learning rate') 33 | parser.add_argument('--nf', type=int, default=64, metavar='N', 34 | help='hidden dim') 35 | parser.add_argument('--model', type=str, default='hier', metavar='N') 36 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 37 | help='number of layers for the autoencoder') 38 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 39 | help='maximum amount of training samples') 40 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 41 | help='timing experiment') 42 | parser.add_argument('--delta_frame', type=int, default=30, 43 | help='Number of frames delta.') 44 | parser.add_argument('--data_dir', type=str, default='spatial_graph/md17', 45 | help='Data directory.') 46 | parser.add_argument('--dropout', type=float, default=0.5, 47 | help='Dropout rate (1 - keep probability).') 48 | parser.add_argument("--config_by_file", default=None, nargs="?", const='', type=str, ) 49 | 50 | parser.add_argument('--lambda_link', type=float, default=1, 51 | help='The weight of the linkage loss.') 52 | parser.add_argument('--n_cluster', type=int, default=3, 53 | help='The number of clusters.') 54 | parser.add_argument('--flat', action='store_true', default=False, 55 | 
help='flat MLP') 56 | parser.add_argument('--interaction_layer', type=int, default=3, 57 | help='The number of interaction layers per block.') 58 | parser.add_argument('--pooling_layer', type=int, default=3, 59 | help='The number of pooling layers in EGPN.') 60 | parser.add_argument('--decoder_layer', type=int, default=1, 61 | help='The number of decoder layers.') 62 | 63 | parser.add_argument('--case', type=str, default='walk', 64 | help='The case, walk or run.') 65 | 66 | 67 | time_exp_dic = {'time': 0, 'counter': 0} 68 | 69 | 70 | args = parser.parse_args() 71 | if args.config_by_file is not None: 72 | if len(args.config_by_file) == 0: 73 | job_param_path = './job_param.json' 74 | else: 75 | job_param_path = args.config_by_file 76 | with open(job_param_path, 'r') as f: 77 | hyper_params = json.load(f) 78 | args.exp_name = hyper_params["exp_name"] 79 | args.batch_size = hyper_params["batch_size"] 80 | args.epochs = hyper_params["epochs"] 81 | args.seed = hyper_params["seed"] 82 | args.lr = hyper_params["lr"] 83 | args.nf = hyper_params["nf"] 84 | args.model = hyper_params["model"] 85 | args.n_layers = hyper_params["n_layers"] 86 | args.max_training_samples = hyper_params["max_training_samples"] 87 | # Do not necessary in practice. 88 | args.data_dir = hyper_params["data_dir"] 89 | args.weight_decay = hyper_params["weight_decay"] 90 | 91 | args.dropout = hyper_params["dropout"] 92 | args.lambda_link = hyper_params["lambda_link"] 93 | args.n_cluster = hyper_params["n_cluster"] 94 | args.flat = hyper_params["flat"] 95 | args.interaction_layer = hyper_params["interaction_layer"] 96 | args.pooling_layer = hyper_params["pooling_layer"] 97 | args.decoder_layer = hyper_params["decoder_layer"] 98 | 99 | args.case = hyper_params["case"] 100 | 101 | args.cuda = not args.no_cuda and torch.cuda.is_available() 102 | 103 | device = torch.device("cuda" if args.cuda else "cpu") 104 | loss_mse = nn.MSELoss() 105 | 106 | print(args) 107 | try: 108 | os.makedirs(args.outf) 109 | except OSError: 110 | pass 111 | 112 | try: 113 | os.makedirs(args.outf + "/" + args.exp_name) 114 | except OSError: 115 | pass 116 | 117 | # torch.autograd.set_detect_anomaly(True) 118 | 119 | def main(): 120 | # fix seed 121 | seed = args.seed 122 | random.seed(seed) 123 | np.random.seed(seed) 124 | torch.manual_seed(seed) 125 | torch.cuda.manual_seed(seed) 126 | 127 | dataset_train = MotionDataset(partition='train', max_samples=args.max_training_samples, data_dir=args.data_dir, 128 | delta_frame=args.delta_frame, case=args.case) 129 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True, 130 | num_workers=8) 131 | 132 | dataset_val = MotionDataset(partition='val', max_samples=600, data_dir=args.data_dir, 133 | delta_frame=args.delta_frame, case=args.case) 134 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False, 135 | num_workers=8) 136 | 137 | dataset_test = MotionDataset(partition='test', max_samples=600, data_dir=args.data_dir, 138 | delta_frame=args.delta_frame, case=args.case) 139 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False, 140 | num_workers=8) 141 | 142 | if args.model == 'hier': 143 | model = EGHN(in_node_nf=2, in_edge_nf=2, hidden_nf=args.nf, device=device, 144 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 145 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), 146 | 
layer_decoder=args.decoder_layer) 147 | else: 148 | raise NotImplementedError('Unknown model:', args.model) 149 | 150 | print(model) 151 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 152 | model_save_path = os.path.join(args.outf, args.exp_name, 'saved_model.pth') 153 | early_stopping = EarlyStopping(patience=50, verbose=True, path=model_save_path) 154 | 155 | results = {'eval epoch': [], 'val loss': [], 'test loss': [], 'train loss': []} 156 | best_val_loss = 1e8 157 | best_test_loss = 1e8 158 | best_epoch = 0 159 | best_train_loss = 1e8 160 | best_lp_loss = 1e8 161 | for epoch in range(0, args.epochs): 162 | train_loss, lp_loss = train(model, optimizer, epoch, loader_train) 163 | results['train loss'].append(train_loss) 164 | if epoch % args.test_interval == 0: 165 | val_loss, _ = train(model, optimizer, epoch, loader_val, backprop=False) 166 | test_loss, _ = train(model, optimizer, epoch, loader_test, backprop=False) 167 | 168 | results['eval epoch'].append(epoch) 169 | results['val loss'].append(val_loss) 170 | results['test loss'].append(test_loss) 171 | if val_loss < best_val_loss: 172 | best_val_loss = val_loss 173 | best_test_loss = test_loss 174 | best_train_loss = train_loss 175 | best_epoch = epoch 176 | best_lp_loss = lp_loss 177 | # Saving the model is handled by early stopping. 178 | print("*** Best Val Loss: %.5f \t Best Test Loss: %.5f \t Best epoch %d" 179 | % (best_val_loss, best_test_loss, best_epoch)) 180 | early_stopping(val_loss, model) 181 | if early_stopping.early_stop: 182 | print("Early Stopping.") 183 | break 184 | 185 | json_object = json.dumps(results, indent=4) 186 | with open(args.outf + "/" + args.exp_name + "/loss.json", "w") as outfile: 187 | outfile.write(json_object) 188 | return best_train_loss, best_val_loss, best_test_loss, best_epoch, best_lp_loss 189 | 190 | 191 | def train(model, optimizer, epoch, loader, backprop=True): 192 | if backprop: 193 | model.train() 194 | else: 195 | model.eval() 196 | 197 | res = {'epoch': epoch, 'loss': 0, 'counter': 0, 'lp_loss': 0} 198 | 199 | for batch_idx, data in enumerate(loader): 200 | batch_size, n_nodes, _ = data[0].size() 201 | data = [d.to(device) for d in data] 202 | # data = [d.view(-1, d.size(2)) for d in data] # construct mini-batch graphs 203 | loc, vel, edges, edge_attr, local_edges, local_edge_fea, Z, loc_end, vel_end = data 204 | # convert into graph minibatch 205 | loc = loc.view(-1, loc.size(2)) 206 | vel = vel.view(-1, vel.size(2)) 207 | offset = (torch.arange(batch_size) * n_nodes).unsqueeze(-1).unsqueeze(-1).to(edges.device) 208 | edges = torch.cat(list(edges + offset), dim=-1) # [2, BM] 209 | edge_attr = torch.cat(list(edge_attr), dim=0) # [BM, ] 210 | local_edge_index = torch.cat(list(local_edges + offset), dim=-1) # [2, BM] 211 | local_edge_fea = torch.cat(list(local_edge_fea), dim=0) # [BM, ] 212 | # local_edge_mask = torch.cat(list(local_edge_mask), dim=0) # [BM, ] 213 | Z = Z.view(-1, Z.size(2)) 214 | loc_end = loc_end.view(-1, loc_end.size(2)) 215 | vel_end = vel_end.view(-1, vel_end.size(2)) 216 | 217 | optimizer.zero_grad() 218 | 219 | if args.model == 'hier': 220 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 221 | nodes = torch.cat((nodes, Z / Z.max()), dim=-1) 222 | rows, cols = edges 223 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 224 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 225 | loc_dist1 =
torch.sum((loc[local_edge_index[0]] - loc[local_edge_index[1]])**2, 1).unsqueeze(1) 226 | local_edge_fea = torch.cat([local_edge_fea, loc_dist1], 1).detach() # concatenate all edge properties 227 | # loc_pred, vel_pred, _ = model(loc, nodes, edges, n_node=n_nodes, edge_fea=edge_attr, v=vel) 228 | # local_edge_index, local_edge_fea = [edges[0][local_edge_mask], edges[1][local_edge_mask]], edge_attr[ 229 | # local_edge_mask] 230 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 231 | n_node=n_nodes, v=vel, node_mask=None, node_nums=None) 232 | else: 233 | raise Exception("Wrong model") 234 | 235 | loss = loss_mse(loc_pred, loc_end) 236 | 237 | if args.model == 'hier': 238 | lp_loss = model.cut_loss 239 | res['lp_loss'] += lp_loss.item() * batch_size 240 | 241 | if backprop: 242 | # link prediction loss 243 | if args.model == 'hier': 244 | _lambda = args.lambda_link 245 | (loss + _lambda * lp_loss).backward() 246 | else: 247 | loss.backward() 248 | optimizer.step() 249 | res['loss'] += loss.item()*batch_size 250 | res['counter'] += batch_size 251 | 252 | # check the current pooling distribution 253 | if args.model == 'hier': 254 | model.inspect_pooling_plan() 255 | 256 | if not backprop: 257 | prefix = "==> " 258 | else: 259 | prefix = "" 260 | print('%s epoch %d avg loss: %.5f avg lploss: %.5f' 261 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'], res['lp_loss'] / res['counter'])) 262 | 263 | return res['loss'] / res['counter'], res['lp_loss'] / res['counter'] 264 | 265 | 266 | if __name__ == "__main__": 267 | best_train_loss, best_val_loss, best_test_loss, best_epoch, best_lp_loss = main() 268 | print("best_train = %.6f" % best_train_loss) 269 | print("best_lp = %.6f" % best_lp_loss) 270 | print("best_val = %.6f" % best_val_loss) 271 | print("best_test = %.6f" % best_test_loss) 272 | print("best_epoch = %d" % best_epoch) 273 | print("best_train = %.6f, best_lp = %.6f, best_val = %.6f, best_test = %.6f, best_epoch = %d" 274 | % (best_train_loss, best_lp_loss, best_val_loss, best_test_loss, best_epoch)) 275 | 276 | -------------------------------------------------------------------------------- /simulation/datagen/physical_objects.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | eps = 1e-6 4 | 5 | 6 | def projection(va, vb): 7 | return np.dot(va, vb.T) / np.dot(vb, vb.T) * vb 8 | 9 | 10 | def get_rotation_matrix(theta, d): 11 | x, y, z = d[0], d[1], d[2] 12 | M = np.zeros((3, 3)) 13 | cos, sin = np.cos(theta), np.sin(theta) 14 | M[0][0] = cos + (1 - cos) * x * x 15 | M[0][1] = (1 - cos) * x * y - sin * z 16 | M[0][2] = (1 - cos) * x * z + sin * y 17 | M[1][0] = (1 - cos) * x * y + sin * z 18 | M[1][1] = cos + (1 - cos) * y * y 19 | M[1][2] = (1 - cos) * y * z - sin * x 20 | M[2][0] = (1 - cos) * x * z - sin * y 21 | M[2][1] = (1 - cos) * y * z + sin * x 22 | M[2][2] = cos + (1 - cos) * z * z 23 | return M 24 | 25 | 26 | class PhysicalObject: 27 | def __init__(self, n_balls, node_idx, charge, type): 28 | self.n_balls, self.node_idx, self.type = n_balls, node_idx, type 29 | self.charge = charge 30 | assert len(node_idx) == n_balls == len(charge) 31 | 32 | def initialize(self, X, V): 33 | raise NotImplementedError() 34 | 35 | def update(self, X, V, F, delta_t): 36 | raise NotImplementedError() 37 | 38 | def check(self, X, V): 39 | raise NotImplementedError() 40 | 41 | 42 | class Isolated(PhysicalObject): 43 | def __init__(self, n_balls, node_idx, charge, 
type): 44 | super().__init__(n_balls, node_idx, charge, type) 45 | 46 | def initialize(self, X, V): 47 | return X, V 48 | 49 | def update(self, X, V, F, delta_t): 50 | x, v, f = X[self.node_idx[0]], V[self.node_idx[0]], F[self.node_idx[0]] 51 | a = f / 1. 52 | 53 | # a[a > 100] = 100 54 | # a[a < - 100] = - 100 55 | 56 | v = v + a * delta_t 57 | x = x + v * delta_t 58 | X[self.node_idx[0]] = x 59 | V[self.node_idx[0]] = v 60 | return X, V 61 | 62 | def check(self, X, V): 63 | return True 64 | 65 | 66 | class Stick(PhysicalObject): 67 | def __init__(self, n_balls, node_idx, charge, type): 68 | super().__init__(n_balls, node_idx, charge, type) 69 | self.xc, self.vc, self.wc = None, None, None 70 | self.length = None 71 | 72 | def initialize(self, X, V): 73 | # check and adjust the initial conditions 74 | x, v = X[self.node_idx], V[self.node_idx] 75 | x0, x1 = x[0], x[1] 76 | v0, v1 = v[0], v[1] 77 | m0, m1 = 1., 1. 78 | # the velocity along the stick should be the same for two nodes (0, 1) 79 | d = x1 - x0 80 | v0_pro, v1_pro = projection(v0, d), projection(v1, d) 81 | v0_vert, v1_vert = v0 - v0_pro, v1 - v1_pro 82 | average_v_pro = (v0_pro + v1_pro) / 2 83 | v0, v1 = v0_vert + average_v_pro, v1_vert + average_v_pro 84 | 85 | xc = (m0 * x0 + m1 * x1) / (m0 + m1) 86 | vc = (m0 * v0 + m1 * v1) / (m0 + m1) 87 | relative_v0, relative_v1 = v0 - vc, v1 - vc 88 | r0, r1 = x0 - xc, x1 - xc 89 | w0, w1 = np.cross(r0, relative_v0) / np.dot(r0, r0.T), np.cross(r1, relative_v1) / np.dot(r1, r1.T) 90 | assert np.sum(np.abs(w0 - w1)) < 1e-5 91 | # book-keeping 92 | self.xc, self.vc, self.wc = xc, vc, w0 93 | self.length = np.sqrt(np.sum(d ** 2)) 94 | X[self.node_idx[0]], X[self.node_idx[1]] = x0, x1 95 | V[self.node_idx[0]], V[self.node_idx[1]] = v0, v1 96 | 97 | return X, V 98 | 99 | def update(self, X, V, F, delta_t): 100 | x, v, f = X[self.node_idx], V[self.node_idx], F[self.node_idx] 101 | x0, x1 = x[0], x[1] 102 | v0, v1 = v[0], v[1] 103 | f0, f1 = f[0], f[1] 104 | m0, m1 = 1., 1. 
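# The remainder of Stick.update advances the rod as a rigid body: the center
# of mass is translated by the net force, and the two endpoints are rotated
# about it with the updated angular velocity. In the notation of the code:
#   ac    = (f0 + f1) / (m0 + m1)        -> acceleration of the center of mass
#   J     = m0*|r0|^2 + m1*|r1|^2        -> moment of inertia about the center
#   beta  = (r0 x f0 + r1 x f1) / J      -> angular acceleration
#   theta = |wc| * delta_t               -> rotation angle for this step
# Positions and velocities are then recomposed as x_i = xc + R(theta) r_i and
# v_i = vc + wc x R(theta) r_i, with R built by get_rotation_matrix.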
105 | xc, vc, wc = self.xc, self.vc, self.wc 106 | r0, r1 = x0 - xc, x1 - xc 107 | ac = (f0 + f1) / (m0 + m1) 108 | 109 | # ac[ac > 100] = 100 110 | # ac[ac < - 100] = - 100 111 | 112 | # update vc, xc 113 | vc = vc + ac * delta_t 114 | xc = xc + vc * delta_t 115 | 116 | # update wc 117 | J = m0 * np.dot(r0, r0.T) + m1 * np.dot(r1, r1.T) 118 | M = np.cross(r0, f0) + np.cross(r1, f1) 119 | beta = M / J 120 | 121 | # beta[beta > 100] = 100 122 | # beta[beta < - 100] = - 100 123 | 124 | wc = wc + beta * delta_t 125 | 126 | wc_norm = np.sqrt(np.dot(wc, wc.T)) 127 | theta = wc_norm * delta_t 128 | 129 | M = get_rotation_matrix(theta, wc / wc_norm) 130 | _r0 = np.matmul(M, r0.T).T 131 | _r1 = np.matmul(M, r1.T).T 132 | 133 | # update x and v 134 | x0, x1 = xc + _r0, xc + _r1 135 | v0, v1 = vc + np.cross(wc, _r0), vc + np.cross(wc, _r1) # here, we use the updated r (instead of original r) 136 | 137 | # book-keeping 138 | self.xc, self.vc, self.wc = xc, vc, wc 139 | X[self.node_idx[0]], X[self.node_idx[1]] = x0, x1 140 | V[self.node_idx[0]], V[self.node_idx[1]] = v0, v1 141 | 142 | return X, V 143 | 144 | def check(self, X, V): 145 | x, v = X[self.node_idx], V[self.node_idx] 146 | x0, x1 = x[0], x[1] 147 | v0, v1 = v[0], v[1] 148 | 149 | d = x1 - x0 150 | v0_pro, v1_pro = projection(v0, d), projection(v1, d) 151 | 152 | assert np.sum(np.abs(v0_pro - v1_pro)) < eps 153 | length = np.sqrt(np.sum(d ** 2)) 154 | assert np.abs(length - self.length) < eps 155 | 156 | 157 | class Hinge(PhysicalObject): 158 | def __init__(self, n_balls, node_idx, charge, type): 159 | super().__init__(n_balls, node_idx, charge, type) 160 | self.w1, self.w2 = None, None 161 | self.length1, self.length2 = None, None 162 | 163 | def initialize(self, X, V): 164 | # check and adjust the initial conditions 165 | x, v = X[self.node_idx], V[self.node_idx] 166 | x0, x1, x2 = x[0], x[1], x[2] 167 | v0, v1, v2 = v[0], v[1], v[2] 168 | # the velocity along the two beams should be the same for nodes (0, 1) and (0, 2), respectively 169 | d1, d2 = x1 - x0, x2 - x0 170 | v0_pro1, v0_pro2 = projection(v0, d1), projection(v0, d2) 171 | v1_pro, v2_pro = projection(v1, d1), projection(v2, d2) 172 | v1_vert, v2_vert = v1 - v1_pro, v2 - v2_pro 173 | v1, v2 = v0_pro1 + v1_vert, v0_pro2 + v2_vert 174 | 175 | r1, r2 = x1 - x0, x2 - x0 176 | v01, v02 = v1 - v0, v2 - v0 177 | w1, w2 = np.cross(r1, v01) / np.dot(r1, r1.T), np.cross(r2, v02) / np.dot(r2, r2.T) 178 | 179 | # book-keeping 180 | self.w1, self.w2 = w1, w2 181 | X[self.node_idx[0]], X[self.node_idx[1]], X[self.node_idx[2]] = x0, x1, x2 182 | V[self.node_idx[0]], V[self.node_idx[1]], V[self.node_idx[2]] = v0, v1, v2 183 | 184 | self.length1, self.length2 = np.sqrt(np.sum(d1 ** 2)), np.sqrt(np.sum(d2 ** 2)) 185 | 186 | return X, V 187 | 188 | def update(self, X, V, F, delta_t): 189 | x, v, f = X[self.node_idx], V[self.node_idx], F[self.node_idx] 190 | x0, x1, x2 = x[0], x[1], x[2] 191 | v0, v1, v2 = v[0], v[1], v[2] 192 | f0, f1, f2 = f[0], f[1], f[2] 193 | 194 | _f = f0 + f1 + f2 195 | r01, r02 = x1 - x0, x2 - x0 196 | v01, v02 = v1 - v0, v2 - v0 197 | w1, w2 = self.w1, self.w2 198 | e01, e02 = r01 / np.sqrt(np.dot(r01, r01.T)), r02 / np.sqrt(np.dot(r02, r02.T)) 199 | e01, e02 = e01.reshape(1, -1), e02.reshape(1, -1) 200 | A = np.eye(3) + np.matmul(e01.T, e01) + np.matmul(e02.T, e02) 201 | a = _f / 1. - np.cross(w1, v01) - np.cross(w2, v02) 202 | a = a - np.matmul((np.eye(3) - np.matmul(e01.T, e01)), f1 / 1.) - np.matmul((np.eye(3) - np.matmul(e02.T, e02)), 203 | f2 / 1.) 
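# Hinge.update recovers the pivot (node 0) acceleration a0 below by solving the
# 3x3 linear system A @ a0 = a, where A = I + e01 e01^T + e02 e02^T encodes the
# rigid-beam constraints along the two arms. Each arm is then rotated about the
# pivot with its own updated angular velocity (w1 for node 1, w2 for node 2),
# analogously to the rigid rotation in Stick.update.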
204 | a0 = np.matmul(np.linalg.inv(A), a) 205 | 206 | # update x0, v0 207 | v0 = v0 + a0 * delta_t 208 | x0 = x0 + v0 * delta_t 209 | 210 | # update w1, w2 211 | beta1 = np.cross(r01, f1 - 1. * a0) / (1. * np.dot(r01, r01.T)) 212 | beta2 = np.cross(r02, f2 - 1. * a0) / (1. * np.dot(r02, r02.T)) 213 | w1 = w1 + beta1 * delta_t 214 | w2 = w2 + beta2 * delta_t 215 | 216 | # update x, v 217 | w1_norm = np.sqrt(np.dot(w1, w1.T)) 218 | theta = w1_norm * delta_t 219 | M = get_rotation_matrix(theta, w1 / w1_norm) 220 | _r01 = np.matmul(M, r01.T).T 221 | x1 = x0 + _r01 222 | 223 | w2_norm = np.sqrt(np.dot(w2, w2.T)) 224 | theta = w2_norm * delta_t 225 | M = get_rotation_matrix(theta, w2 / w2_norm) 226 | _r02 = np.matmul(M, r02.T).T 227 | x2 = x0 + _r02 228 | 229 | v1, v2 = v0 + np.cross(w1, _r01), v0 + np.cross(w2, _r02) 230 | 231 | # book-keeping 232 | self.w1, self.w2 = w1, w2 233 | X[self.node_idx[0]], X[self.node_idx[1]], X[self.node_idx[2]] = x0, x1, x2 234 | V[self.node_idx[0]], V[self.node_idx[1]], V[self.node_idx[2]] = v0, v1, v2 235 | 236 | return X, V 237 | 238 | def check(self, X, V): 239 | x, v = X[self.node_idx], V[self.node_idx] 240 | x0, x1, x2 = x[0], x[1], x[2] 241 | v0, v1, v2 = v[0], v[1], v[2] 242 | d1 = x1 - x0 243 | d2 = x2 - x0 244 | 245 | length1, length2 = np.sqrt(np.sum(d1 ** 2)), np.sqrt(np.sum(d2 ** 2)) 246 | assert np.abs(length1 - self.length1) < eps 247 | assert np.abs(length2 - self.length2) < eps 248 | 249 | v0_pro1, v0_pro2 = projection(v0, d1), projection(v0, d2) 250 | v1_pro, v2_pro = projection(v1, d1), projection(v2, d2) 251 | assert np.sum(np.abs(v0_pro1 - v1_pro)) < eps 252 | assert np.sum(np.abs(v0_pro2 - v2_pro)) < eps 253 | 254 | 255 | class Complex(PhysicalObject): 256 | def __init__(self, n_balls, node_idx, charge, type): 257 | super().__init__(n_balls, node_idx, charge, type) 258 | self.xc, self.vc, self.wc = None, None, None 259 | self.length = None 260 | 261 | def initialize(self, X, V, X_c, rr): 262 | x, v = X[self.node_idx], V[self.node_idx] # [N, 3] 263 | if x.shape[0] == 1: 264 | # isolated particles 265 | return X, V 266 | 267 | # adjust the positions 268 | # xc = x[np.random.choice(np.arange(x.shape[0]), size=1, replace=False)] # [1, 3] 269 | rand_r = np.random.randn(x.shape[0], x.shape[1]) # [N, 3] 270 | rand_d = (np.random.rand(x.shape[0], 1) / 2 + 0.6) * rr # [N, 1] 271 | rand_r = rand_r / np.sqrt(np.sum(rand_r ** 2, axis=-1).reshape(-1, 1)) * rand_d # [N, 3] 272 | x = X_c + rand_r 273 | 274 | m = np.ones(x.shape[0]).reshape(-1, 1) # [N,] 275 | xc = np.sum(x * m, axis=0) / np.sum(m) # [3,] 276 | rc = x - xc # [N, 3] 277 | vc = np.sum(v * m, axis=0) / np.sum(m) # [3,] 278 | relative_v = v - vc # [N, 3] 279 | w = np.cross(rc, relative_v) / np.sum(rc ** 2, axis=-1).reshape(-1, 1) # [N, 3] 280 | # pooling over w 281 | wc = np.mean(w, axis=0) # [3,] 282 | self.xc, self.vc, self.wc = xc, vc, wc # [3,] 283 | ds = np.sqrt(np.sum(rc ** 2, axis=-1)) # [N,] 284 | self.length = ds 285 | new_v = np.cross(wc, rc) # [N, 3] 286 | new_x = x # [N, 3] 287 | X[self.node_idx] = new_x 288 | V[self.node_idx] = new_v 289 | 290 | return X, V 291 | 292 | def update(self, X, V, F, delta_t): 293 | x, v, f = X[self.node_idx], V[self.node_idx], F[self.node_idx] 294 | if x.shape[0] == 1: 295 | # isolated particles 296 | a = f / 1.0 297 | v = v + a * delta_t 298 | x = x + v * delta_t 299 | X[self.node_idx] = x 300 | V[self.node_idx] = v 301 | return X, V 302 | m = np.ones(x.shape[0]) # [N,] 303 | xc, vc, wc = self.xc, self.vc, self.wc # [3,] 304 | r = x - xc # [N, 3] 305 | 
# I did some modification here 306 | ac = np.mean(f, axis=0) / np.sum(m) # [3,] 307 | # ac = np.mean(f, axis=0) 308 | vc = vc + ac * delta_t 309 | xc = xc + vc * delta_t 310 | 311 | # I did some modification here 312 | # J = np.sum(np.sum(r ** 2, axis=-1) * m, axis=0) # [3,] 313 | temp = np.repeat(wc.T.reshape(1, -1), repeats=r.shape[0], axis=0) 314 | r_pro = (r * temp).sum(axis=-1) / (wc ** 2).sum() 315 | r_pro = r_pro.reshape(-1, 1) * np.repeat(wc.reshape(1, -1), repeats=r.shape[0], axis=0) 316 | r_d = r - r_pro 317 | J = np.sum(np.mean(r_d ** 2, axis=-1), axis=0) 318 | M = np.sum(np.cross(r, f), axis=0) # [3,] 319 | beta = M / J # [3,] 320 | 321 | wc = wc + beta * delta_t 322 | wc_norm = np.sqrt(np.dot(wc, wc.T)) 323 | theta = wc_norm * delta_t 324 | 325 | M = get_rotation_matrix(theta, wc / wc_norm) # [3, 3] 326 | _r = np.matmul(M, r.T).T # [N, 3] 327 | new_x = xc + _r # [N, 3] 328 | new_v = vc + np.cross(wc.reshape(1, -1), _r) # [N, 3] 329 | 330 | self.xc, self.vc, self.wc = xc, vc, wc 331 | X[self.node_idx] = new_x 332 | V[self.node_idx] = new_v 333 | 334 | return X, V 335 | 336 | def check(self, X, V): 337 | x, v = X[self.node_idx], V[self.node_idx] 338 | if x.shape[0] == 1: 339 | # isolated particles 340 | return 341 | m = np.ones(x.shape[0]).reshape(-1, 1) # [N,] 342 | xc = np.sum(x * m, axis=0) / np.sum(m) # [3,] 343 | rc = x - xc # [N, 3] 344 | ds = np.sqrt(np.sum(rc ** 2, axis=-1)) # [N,] 345 | assert np.sum(np.abs(ds - self.length)) < eps 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | -------------------------------------------------------------------------------- /main_mdanalysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | import torch 4 | from torch.utils.data import DistributedSampler 5 | import torch.utils.data 6 | from mdanalysis.dataset import MDAnalysisDataset, collate_mda 7 | from model.eghn import EGHN 8 | import os 9 | from torch import nn, optim 10 | import json 11 | import time 12 | import random 13 | import numpy as np 14 | 15 | from utils import EarlyStopping 16 | 17 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 18 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 19 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 20 | help='input batch size for training (default: 128)') 21 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 22 | help='number of epochs to train (default: 10)') 23 | parser.add_argument('--no-cuda', action='store_true', default=False, 24 | help='enables CUDA training') 25 | parser.add_argument('--seed', type=int, default=42, metavar='S', 26 | help='random seed (default: 1)') 27 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 28 | help='how many batches to wait before logging training status') 29 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 30 | help='how many epochs to wait before logging test') 31 | parser.add_argument('--outf', type=str, default='exp_results', metavar='N', 32 | help='folder to output the json log file') 33 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 34 | help='learning rate') 35 | parser.add_argument('--nf', type=int, default=64, metavar='N', 36 | help='hidden dim') 37 | parser.add_argument('--model', type=str, default='hier', metavar='N') 38 | parser.add_argument('--attention', type=int, default=0, metavar='N', 39 | 
help='attention in the ae model') 40 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 41 | help='number of layers for the autoencoder') 42 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 43 | help='maximum amount of training samples') 44 | parser.add_argument('--dataset', type=str, default="nbody_small", metavar='N', 45 | help='nbody_small, nbody') 46 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 47 | help='timing experiment') 48 | parser.add_argument('--delta_frame', type=int, default=50, 49 | help='Number of frames delta.') 50 | parser.add_argument('--data_dir', type=str, default='YOUR_DATA_DIR', 51 | help='Data directory.') 52 | parser.add_argument('--dropout', type=float, default=0.5, 53 | help='Dropout rate (1 - keep probability).') 54 | parser.add_argument("--backbone", action="store_true", 55 | help="Load backbone data of protein") 56 | 57 | parser.add_argument('--lambda_link', type=float, default=1, help='The weight of the linkage loss.') 58 | parser.add_argument('--interaction_layer', type=int, default=3, help='The number of interaction layers per block.') 59 | parser.add_argument('--pooling_layer', type=int, default=3, help='The number of pooling layers in EGPN.') 60 | parser.add_argument('--decoder_layer', type=int, default=1, help='The number of decoder layers.') 61 | parser.add_argument('--n_cluster', type=int, default=20, help='The number of clusters.') 62 | parser.add_argument('--flat', action='store_true', default=False, help='flat MLP') 63 | parser.add_argument("--config_by_file", default=None, nargs='?', const='') 64 | parser.add_argument("--n_workers", '-n', type=int, default=8, help="Number of workers.") 65 | parser.add_argument("--load_cached", action="store_true", help="Load cached dataset.") 66 | parser.add_argument("--test_rot", action="store_true", help="Rotate the test") 67 | parser.add_argument("--test_trans", action="store_true", help="Translate the test") 68 | parser.add_argument("--enable_multi_gpus", action="store_true", help="Multi GPUs") 69 | 70 | time_exp_dic = {'time': 0, 'counter': 0} 71 | 72 | 73 | args = parser.parse_args() 74 | if args.config_by_file is not None: 75 | if len(args.config_by_file) == 0: 76 | job_param_path = './job_param.json' 77 | else: 78 | job_param_path = args.config_by_file 79 | with open(job_param_path, 'r') as f: 80 | hyper_params = json.load(f) 81 | # Only update existing keys 82 | args = vars(args) 83 | args.update((k, v) for k, v in hyper_params.items() if k in args) 84 | args = Namespace(**args) 85 | 86 | args.cuda = not args.no_cuda and torch.cuda.is_available() 87 | 88 | device = torch.device("cuda" if args.cuda else "cpu") 89 | master_worker = True 90 | 91 | loss_mse = nn.MSELoss() 92 | 93 | print(args) 94 | try: 95 | os.makedirs(args.outf) 96 | except OSError: 97 | pass 98 | 99 | try: 100 | os.makedirs(args.outf + "/" + args.exp_name) 101 | except OSError: 102 | pass 103 | 104 | # torch.autograd.set_detect_anomaly(True) 105 | 106 | def main(): 107 | # fix seed 108 | seed = args.seed 109 | random.seed(seed) 110 | np.random.seed(seed) 111 | torch.manual_seed(seed) 112 | torch.cuda.manual_seed(seed) 113 | 114 | dataset_train = MDAnalysisDataset('adk', partition='train', tmp_dir=args.data_dir, 115 | delta_frame=args.delta_frame, load_cached=args.load_cached, 116 | backbone=args.backbone) 117 | sampler = None 118 | shuffle = True 119 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, 120 | 
shuffle=shuffle, sampler=sampler, drop_last=True, 121 | num_workers=args.n_workers, collate_fn=collate_mda) 122 | 123 | dataset_val = MDAnalysisDataset('adk', partition='valid', tmp_dir=args.data_dir, 124 | delta_frame=args.delta_frame, load_cached=args.load_cached, 125 | backbone=args.backbone) 126 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, 127 | drop_last=False, num_workers=args.n_workers, collate_fn=collate_mda) 128 | 129 | # Val and test do not need sampler. 130 | dataset_test = MDAnalysisDataset('adk', partition='test', tmp_dir=args.data_dir, 131 | delta_frame=args.delta_frame, load_cached=args.load_cached, 132 | test_rot=False, test_trans=False, 133 | backbone=args.backbone) 134 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, 135 | shuffle=False, drop_last=False, 136 | num_workers=args.n_workers, collate_fn=collate_mda) 137 | 138 | dataset_test_hard = MDAnalysisDataset('adk', partition='test', tmp_dir=args.data_dir, 139 | delta_frame=args.delta_frame, load_cached=args.load_cached, 140 | test_rot=True, test_trans=True) 141 | loader_test_hard = torch.utils.data.DataLoader(dataset_test_hard, batch_size=args.batch_size, 142 | shuffle=False, drop_last=False, 143 | num_workers=args.n_workers, collate_fn=collate_mda) 144 | 145 | if args.load_cached and master_worker: 146 | print("Data loading finished.") 147 | 148 | if args.model == 'hier': 149 | model = EGHN(in_node_nf=2, in_edge_nf=2, hidden_nf=args.nf, device=device, 150 | n_cluster=args.n_cluster, flat=args.flat, layer_per_block=args.interaction_layer, 151 | layer_pooling=args.pooling_layer, activation=nn.SiLU(), 152 | layer_decoder=args.decoder_layer) 153 | else: 154 | raise NotImplementedError('Unknown model:', args.model) 155 | 156 | if master_worker: 157 | print(model) 158 | 159 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 160 | 161 | model_save_path = os.path.join(args.outf, args.exp_name, 'saved_model.pth') 162 | early_stopping = EarlyStopping(patience=50, verbose=True, path=model_save_path) 163 | 164 | 165 | results = {'eval epoch': [], 'val loss': [], 'test loss': [], 'train loss': [], 'test loss hard':[]} 166 | best_val_loss = 1e8 167 | best_test_loss = 1e8 168 | best_test_loss_hard = 1e8 169 | best_epoch = 0 170 | best_train_loss = 1e8 171 | best_lp_loss = 1e8 172 | for epoch in range(0, args.epochs): 173 | train_loss, lp_loss = train(model, optimizer, epoch, loader_train) 174 | results['train loss'].append(train_loss) 175 | if epoch % args.test_interval == 0: 176 | # every worker need evaluate this part! 177 | val_loss, _ = train(model, optimizer, epoch, loader_val, backprop=False) 178 | test_loss, _ = train(model, optimizer, epoch, loader_test, backprop=False) 179 | # test_loss_hard, _ = train(model, optimizer, epoch, loader_test_hard, backprop=False) 180 | test_loss_hard = 0 181 | 182 | results['eval epoch'].append(epoch) 183 | results['val loss'].append(val_loss) 184 | results['test loss'].append(test_loss) 185 | results['test loss hard'].append(test_loss_hard) 186 | if val_loss < best_val_loss: 187 | best_val_loss = val_loss 188 | best_test_loss = test_loss 189 | best_train_loss = train_loss 190 | best_test_loss_hard = test_loss_hard 191 | best_epoch = epoch 192 | best_lp_loss = lp_loss 193 | # Save model is move to early stopping. 
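# The checkpoint written by EarlyStopping at model_save_path can be restored
# later for evaluation (see eval_mdanalysis.py). A minimal sketch, assuming
# EarlyStopping saves the model's state_dict via torch.save:
#   model.load_state_dict(torch.load(model_save_path, map_location=device))
#   model.eval()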
194 | if master_worker: 195 | print("*** Best Val Loss: %.5f \t Best Test Loss: %.5f \t Best Hard Test Loss: %.5f \t Best epoch %d" 196 | % (best_val_loss, best_test_loss, best_test_loss_hard, best_epoch)) 197 | # only master worker will store the model. 198 | early_stopping(val_loss, model, master_worker) 199 | if early_stopping.early_stop: 200 | # This state is consistent for all workers. 201 | print("Early Stopping.") 202 | break 203 | 204 | if master_worker: 205 | json_object = json.dumps(results, indent=4) 206 | with open(args.outf + "/" + args.exp_name + "/loss.json", "w") as outfile: 207 | outfile.write(json_object) 208 | return best_train_loss, best_lp_loss, best_val_loss, best_test_loss, best_test_loss_hard, best_epoch 209 | 210 | 211 | def train(model, optimizer, epoch, loader, backprop=True): 212 | s = time.time() 213 | if backprop: 214 | model.train() 215 | else: 216 | model.eval() 217 | 218 | res = {'epoch': epoch, 'loss': 0, 'counter': 0, 'lp_loss': 0} 219 | 220 | #tqdm_loader = tqdm(loader, desc=f'Epoch {epoch}') 221 | for batch_idx, data in enumerate(loader): 222 | batch_size, n_nodes, _ = data[0].size() 223 | data = [d.to(device) for d in data] 224 | # data = [d.view(-1, d.size(2)) for d in data] # construct mini-batch graphs 225 | loc, vel, edges, edge_attr, local_edge_index, local_edge_fea, Z, loc_end, vel_end = data 226 | # convert into graph minibatch 227 | loc = loc.view(-1, loc.size(2)) 228 | # vel = vel.view(-1, vel.size(2)) 229 | # offset = (torch.arange(batch_size) * n_nodes).unsqueeze(-1).unsqueeze(-1).to(edges.device) 230 | # edges = torch.cat(list(edges + offset), dim=-1) # [2, BM] 231 | # edge_attr = torch.cat(list(edge_attr), dim=0) # [BM, ] 232 | # local_edge_index = torch.cat(list(local_edges + offset), dim=-1) # [2, BM] 233 | # local_edge_fea = torch.cat(list(local_edge_fea), dim=0) # [BM, ] 234 | # # local_edge_mask = torch.cat(list(local_edge_mask), dim=0) # [BM, ] 235 | # Z = Z.view(-1, Z.size(2)) 236 | # loc_end = loc_end.view(-1, loc_end.size(2)) 237 | # vel_end = vel_end.view(-1, vel_end.size(2)) 238 | 239 | optimizer.zero_grad() 240 | 241 | if args.model == 'hier': 242 | nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach() 243 | nodes = torch.cat((nodes, Z / Z.max()), dim=-1) 244 | rows, cols = edges 245 | loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1) # relative distances among locations 246 | edge_attr = torch.cat([edge_attr, loc_dist], 1).detach() # concatenate all edge properties 247 | loc_dist1 = torch.sum((loc[local_edge_index[0]] - loc[local_edge_index[1]])**2, 1).unsqueeze(1) 248 | local_edge_fea = torch.cat([local_edge_fea, loc_dist1], 1).detach() # concatenate all edge properties 249 | loc_pred, vel_pred, _ = model(loc, nodes, edges, edge_attr, local_edge_index, local_edge_fea, 250 | n_node=n_nodes, v=vel, node_mask=None, node_nums=None) 251 | else: 252 | raise Exception("Wrong model") 253 | 254 | loss = loss_mse(loc_pred, loc_end) 255 | 256 | if args.model == 'hier': 257 | lp_loss = model.cut_loss 258 | res['lp_loss'] += lp_loss.item() * batch_size 259 | 260 | if backprop: 261 | # link prediction loss 262 | if args.model == 'hier': 263 | _lambda = args.lambda_link 264 | (loss + _lambda * lp_loss).backward() 265 | else: 266 | loss.backward() 267 | optimizer.step() 268 | res['loss'] += loss.item()*batch_size 269 | res['counter'] += batch_size 270 | 271 | # check the current pooling distribution 272 | if args.model == 'hier': 273 | model.inspect_pooling_plan() 274 | 275 | if not backprop: 276 | prefix = "==> " 277 | time_prefix = "val time" 278 | else: 279 | prefix = "" 280 | time_prefix = "training time" 281 | print('%s epoch %d avg loss: %.5f avg lploss: %.5f, %s: %.5f' 282 | % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'], res['lp_loss'] / res['counter'], time_prefix, time.time() - s)) 283 | 284 | return res['loss'] / res['counter'], res['lp_loss'] / res['counter'] 285 | 286 | 287 | if __name__ == "__main__": 288 | best_train_loss, best_lp_loss, best_val_loss, best_test_loss, best_test_loss_hard, best_epoch = main() 289 | if master_worker: 290 | print("best_train = %.6f" % best_train_loss) 291 | print("best_lp = %.6f" % best_lp_loss) 292 | print("best_val = %.6f" % best_val_loss) 293 | print("best_test = %.6f" % best_test_loss) 294 | print("best_test_hard = %.6f" % best_test_loss_hard) 295 | print("best_epoch = %d" % best_epoch) 296 | print("best_train = %.6f, best_lp = %.6f, best_val = %.6f, best_test = %.6f, best_test_hard = %.6f, best_epoch = %d" 297 | % (best_train_loss, best_lp_loss, best_val_loss, best_test_loss, best_test_loss_hard, best_epoch)) 298 | 299 |
-------------------------------------------------------------------------------- /simulation/datagen/system.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from physical_objects import Isolated, Stick, Hinge, Complex 3 | from tqdm import tqdm 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | 10 | class System: 11 | def __init__(self, n_isolated, n_stick, n_hinge, n_complex=0, complex_sizes=None, delta_t=0.001, 12 | box_size=None, loc_std=1., vel_norm=0.5, 13 | interaction_strength=1., charge_types=None, 14 | ): 15 | self.n_isolated, self.n_stick, self.n_hinge = n_isolated, n_stick, n_hinge 16 | self.n_complex = n_complex 17 | self.complex_sizes = complex_sizes 18 | self.delta_t = delta_t 19 | self._max_F = 0.1 / self.delta_t # tentative setting 20 | self.box_size = box_size 21 | self.vel_norm = vel_norm 22 | self.interaction_strength = interaction_strength 23 | self.dim = 3 24 | 25 | if self.n_complex > 0: 26 | if complex_sizes is None: 27 | raise NotImplementedError('Really want complex bodies with random sizes?') 28 | rand_size = complex_sizes 29 | assert len(complex_sizes) == n_complex 30 | offset = np.sum(rand_size) 31 | else: 32 | rand_size = None 33 | offset = 0 34 | 35 | self.n_balls = n_isolated * 1 + n_stick * 2 + n_hinge * 3 36 | 37 | if self.n_complex > 0: 38 | self.n_balls += offset 39 | 40 | n = self.n_balls 41 | # self.loc_std = loc_std * (float(self.n_balls) / 5.) ** (1 / 3) 42 | self.loc_std = loc_std * (float(self.n_balls) / 5.)
** (1 / 3) + 0.5 43 | 44 | if charge_types is None: 45 | charge_types = [1.0, -1.0] 46 | self.charge_types = charge_types 47 | 48 | self.diag_mask = np.ones((n, n), dtype=bool) 49 | np.fill_diagonal(self.diag_mask, 0) 50 | 51 | charges = np.random.choice(self.charge_types, size=(self.n_balls, 1)) 52 | self.charges = charges 53 | edges = charges.dot(charges.transpose()) 54 | self.edges = edges 55 | 56 | assert self.n_isolated == 0 57 | # Hyper-parameters 58 | eps = 1.6 # the minimum distance between centers 59 | max_try = 20 60 | X_c = np.zeros((self.n_complex, self.dim)) 61 | for i in range(X_c.shape[0]): 62 | # sample 63 | counter = 0 64 | while True: 65 | xx = 2 * (np.random.rand(self.dim) - 0.5) * self.loc_std 66 | ok = True 67 | for j in range(i): 68 | d = np.sqrt(np.sum((X_c[j] - xx) ** 2, axis=-1)) 69 | if d < eps: 70 | ok = False 71 | break 72 | if ok: 73 | X_c[i] = xx 74 | break 75 | else: 76 | counter += 1 77 | if counter >= max_try: 78 | # have to accept 79 | print('max try, have to accept') 80 | X_c[i] = xx 81 | break 82 | min_d = 1e10 83 | for i in range(X_c.shape[0]): 84 | for j in range(i + 1, X_c.shape[0]): 85 | dd = np.sqrt(np.sum((X_c[i] - X_c[j]) ** 2, axis=-1)) 86 | min_d = min(min_d, dd) 87 | min_d = min_d / 1.6 88 | # print(min_d) 89 | # Initialize location and velocity 90 | # X = np.random.randn(n, self.dim) * self.loc_std # N(0, loc_std) 91 | X = np.random.randn(n, self.dim) * self.loc_std # N(0, loc_std) 92 | V = np.random.randn(n, self.dim) # N(0, 1) 93 | v_norm = np.sqrt((V ** 2).sum(axis=-1)).reshape(-1, 1) 94 | V = V / v_norm * self.vel_norm 95 | 96 | # initialize physical objects 97 | self.physical_objects = [] 98 | # node_idx = 0 99 | selected = [] 100 | # for _ in range(n_isolated): 101 | # rest = [idx for idx in range(self.n_balls) if idx not in selected] 102 | # node_idx = list(np.random.choice(rest, size=1, replace=False)) 103 | # current_obj = Isolated(n_balls=1, node_idx=node_idx, 104 | # charge=[charges[node_idx[0]]], type='Isolated') 105 | # selected.extend(node_idx) 106 | # 107 | # # current_obj = Isolated(n_balls=1, node_idx=[node_idx], 108 | # # charge=[charges[node_idx]], type='Isolated') 109 | # self.physical_objects.append(current_obj) 110 | # # node_idx += 1 111 | # 112 | # for _ in range(n_stick): 113 | # rest = [idx for idx in range(self.n_balls) if idx not in selected] 114 | # node_idx = list(np.random.choice(rest, size=2, replace=False)) 115 | # current_obj = Stick(n_balls=2, node_idx=node_idx, 116 | # charge=[charges[node_idx[0]], charges[node_idx[1]]], type='Stick') 117 | # selected.extend(node_idx) 118 | # 119 | # # current_obj = Stick(n_balls=2, node_idx=[node_idx, node_idx + 1], 120 | # # charge=[charges[node_idx], charges[node_idx + 1]], type='Stick') 121 | # self.physical_objects.append(current_obj) 122 | # # node_idx += 2 123 | # 124 | # for _ in range(n_hinge): 125 | # rest = [idx for idx in range(self.n_balls) if idx not in selected] 126 | # node_idx = list(np.random.choice(rest, size=3, replace=False)) 127 | # current_obj = Hinge(n_balls=3, node_idx=node_idx, 128 | # charge=[charges[node_idx[0]], charges[node_idx[1]], charges[node_idx[2]]], type='Hinge') 129 | # selected.extend(node_idx) 130 | # 131 | # # current_obj = Hinge(n_balls=3, node_idx=[node_idx, node_idx + 1, node_idx + 2], 132 | # # charge=[charges[node_idx], charges[node_idx + 1], charges[node_idx + 2]], type='Hinge') 133 | # self.physical_objects.append(current_obj) 134 | # # node_idx += 3 135 | 136 | for _ in range(n_complex): 137 | size = rand_size[_] 138 | rest = 
[idx for idx in range(self.n_balls) if idx not in selected] 139 | node_idx = list(np.random.choice(rest, size=size, replace=False)) 140 | current_obj = Complex(n_balls=size, node_idx=node_idx, charge=[charges[node_idx[i]] for i in range(size)], 141 | type='Complex') 142 | selected.extend(node_idx) 143 | self.physical_objects.append(current_obj) 144 | 145 | assert n == self.n_balls == len(selected) == len(np.unique(selected)) 146 | 147 | assert len(self.physical_objects) == X_c.shape[0] 148 | # check and adjust initial conditions 149 | for idx, obj in enumerate(self.physical_objects): 150 | # X, V = obj.initialize(X, V) 151 | X, V = obj.initialize(X, V, X_c=X_c[idx], rr=min_d) 152 | 153 | # book-keeping x and v 154 | self.X, self.V = X, V 155 | 156 | @staticmethod 157 | def _l2(A, B): 158 | A_norm = (A ** 2).sum(axis=1).reshape(A.shape[0], 1) 159 | B_norm = (B ** 2).sum(axis=1).reshape(1, B.shape[0]) 160 | dist = A_norm + B_norm - 2 * A.dot(B.transpose()) 161 | return dist 162 | 163 | def compute_F(self, X, V): 164 | n = self.n_balls 165 | with np.errstate(divide='ignore'): 166 | # half step leapfrog 167 | l2_dist_power3 = np.power( 168 | self._l2(X, X), 3. / 2.) 169 | 170 | # size of forces up to a 1/|r| factor 171 | # since I later multiply by an unnormalized r vector 172 | forces_size = self.interaction_strength * self.edges / l2_dist_power3 173 | np.fill_diagonal(forces_size, 0) # self forces are zero (fixes division by zero) 174 | assert (np.abs(forces_size[self.diag_mask]).min() > 1e-10) 175 | 176 | # here for minor precision issue with respect to the original script 177 | _X = X.T 178 | F = (forces_size.reshape(1, n, n) * 179 | np.concatenate(( 180 | np.subtract.outer(_X[0, :], 181 | _X[0, :]).reshape(1, n, n), 182 | np.subtract.outer(_X[1, :], 183 | _X[1, :]).reshape(1, n, n), 184 | np.subtract.outer(_X[2, :], 185 | _X[2, :]).reshape(1, n, n)))).sum(axis=-1) 186 | F = F.T 187 | 188 | # adjust F scale 189 | F[F > self._max_F] = self._max_F 190 | F[F < -self._max_F] = -self._max_F 191 | 192 | return F 193 | 194 | def simulate_one_step(self): 195 | X, V = self.X, self.V 196 | F = self.compute_F(X, V) 197 | for obj in self.physical_objects: 198 | X, V = obj.update(X, V, F, self.delta_t) 199 | self.X, self.V = X, V 200 | return X, V 201 | 202 | def check(self): 203 | for obj in self.physical_objects: 204 | obj.check(self.X, self.V) 205 | 206 | def is_valid(self): 207 | if self.box_size: 208 | return np.all(self.X <= self.box_size) and np.all(self.X >= - self.box_size) 209 | else: 210 | if (self.V > 5).any(): 211 | return False 212 | return True # no box size 213 | 214 | def configuration(self): 215 | cfg = {} 216 | for obj in self.physical_objects: 217 | _type = obj.type 218 | _node_idx = obj.node_idx 219 | if _type in cfg: 220 | cfg[_type].append(_node_idx) 221 | else: 222 | cfg[_type] = [_node_idx] 223 | return cfg 224 | 225 | def print(self): 226 | print('X:') 227 | print(self.X) 228 | print('V:') 229 | print(self.V) 230 | 231 | 232 | def visualize(): 233 | np.random.seed(89) 234 | sizes = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5] 235 | system = System(n_isolated=0, n_complex=len(sizes), complex_sizes=sizes, n_hinge=0, n_stick=0) 236 | steps = 5001 237 | ret = [] 238 | for step in tqdm(range(steps)): 239 | system.simulate_one_step() 240 | ret.append((system.X.copy(), system.V.copy())) 241 | system.check() 242 | print(len(ret)) 243 | cfg = system.configuration() 244 | print(cfg) 245 | # exit(0) 246 | edges = [] 247 | # construct edges 248 | for obj_type in cfg: 249 | if obj_type == 'Isolated': 
250 | pass 251 | elif obj_type == 'Complex': 252 | nodes = cfg[obj_type] 253 | # put together 254 | _nodes = [] 255 | for n in nodes: 256 | _nodes += n 257 | nodes = _nodes 258 | 259 | # print(nodes) 260 | split = sizes 261 | st = 0 262 | for sp in split: 263 | cur_nodes = nodes[st: st + sp] 264 | # fully connected 265 | for i in range(len(cur_nodes)): 266 | for j in range(len(cur_nodes)): 267 | if i != j: 268 | edges.append([cur_nodes[i], cur_nodes[j]]) 269 | st = st + sp 270 | else: 271 | raise NotImplementedError('Unknown object type:', obj_type) 272 | print(edges) 273 | # exit(0) 274 | x_start = ret[1100][0] 275 | print(x_start) 276 | print(x_start.shape) 277 | # exit(0) 278 | x_center = np.mean(x_start[..., 0]) 279 | y_center = np.mean(x_start[..., 1]) 280 | z_center = np.mean(x_start[..., 2]) 281 | 282 | fig = plt.figure() 283 | ax = Axes3D(fig) 284 | figure_3D_size = 8 285 | 286 | ax.set_xlim3d(z_center - figure_3D_size / 2, z_center + figure_3D_size / 2) 287 | ax.set_ylim3d(x_center - figure_3D_size / 2, x_center + figure_3D_size / 2) 288 | ax.set_zlim3d(y_center - figure_3D_size / 2, y_center + figure_3D_size / 2) 289 | 290 | N_NODE = sum(sizes) 291 | # plot the start position 292 | xs, ys, zs = [], [], [] 293 | for i in range(N_NODE): 294 | xs.append(x_start[i][0]) 295 | ys.append(x_start[i][1]) 296 | zs.append(x_start[i][2]) 297 | plt.plot(zs, xs, ys, 'b.') 298 | for edge in edges: 299 | xs = [x_start[edge[0]][0], x_start[edge[1]][0]] 300 | ys = [x_start[edge[0]][1], x_start[edge[1]][1]] 301 | zs = [x_start[edge[0]][2], x_start[edge[1]][2]] 302 | plt.plot(zs, xs, ys, 'b') 303 | # plot the end position 304 | xs, ys, zs = [], [], [] 305 | x_end = ret[2600][0] 306 | print(x_end) 307 | for i in range(N_NODE): 308 | xs.append(x_end[i][0]) 309 | ys.append(x_end[i][1]) 310 | zs.append(x_end[i][2]) 311 | plt.plot(zs, xs, ys, 'r.') 312 | for edge in edges: 313 | xs = [x_end[edge[0]][0], x_end[edge[1]][0]] 314 | ys = [x_end[edge[0]][1], x_end[edge[1]][1]] 315 | zs = [x_end[edge[0]][2], x_end[edge[1]][2]] 316 | plt.plot(zs, xs, ys, 'red') 317 | plt.show() 318 | 319 | print(np.sum((x_start[0] - x_start[1]) ** 2)) 320 | print(np.sum((x_end[0] - x_end[1]) ** 2)) 321 | 322 | 323 | def test(): 324 | np.random.seed(10) 325 | # system = System(n_isolated=4, n_stick=0, n_hinge=2) 326 | # n_balls = 20 327 | 328 | # system = System(n_isolated=10, n_stick=5, n_hinge=0) 329 | # np.random.seed(10) 330 | # system.X = np.random.rand(20, 3) 331 | # system.V = np.random.rand(20, 3) 332 | # charges = np.random.choice([1, -1], size=20).reshape(-1, 1) 333 | # system.edges = charges.dot(charges.transpose()) 334 | # system.charges = charges 335 | # for obj in system.physical_objects: 336 | # system.X, system.V = obj.initialize(system.X, system.V) 337 | 338 | # system.X = np.array([ 339 | # [-0.4400, 1.8563, -0.8407], 340 | # [ 1.8749, 0.8352, 0.7475], 341 | # [ 1.8236, -0.2337, 0.8648], 342 | # [0.2393, 0.3604, 0.9857], 343 | # [-0.9331, -1.2261, 2.5555], 344 | # ]) 345 | # system.V = np.array([ 346 | # [-0.1481, 0.1102, -0.1308], 347 | # [ 0.6803, 0.1725, -0.2684], 348 | # [ 0.2594, 0.2284, 0.0572], 349 | # [1.0518, -0.3935, 0.7333], 350 | # [-0.3815, 0.0661, 0.1284], 351 | # ]) 352 | # charges = np.array([-1, 1, 1, 1, 1]).reshape(-1, 1) 353 | # system.edges = charges.dot(charges.transpose()) 354 | # system.charges = charges 355 | # 356 | # system.print() 357 | # 358 | # for obj in system.physical_objects: 359 | # system.X, system.V = obj.initialize(system.X, system.V) 360 | # system.check() 361 | 362 | system 
= System(n_isolated=3, n_stick=5, n_hinge=0, n_complex=2, complex_sizes=[3, 4]) 363 | 364 | system.print() 365 | steps = 5001 366 | ret = [] 367 | for step in tqdm(range(steps)): 368 | system.simulate_one_step() 369 | ret.append((system.X.copy(), system.V.copy())) 370 | # system.check() 371 | system.print() 372 | # print(system.X.dtype) 373 | # print(system.charges) 374 | # print(system.is_valid()) 375 | return ret 376 | 377 | 378 | if __name__ == '__main__': 379 | # for i, F in enumerate() 380 | # ret = test() 381 | visualize() 382 | 383 | 384 | -------------------------------------------------------------------------------- /model/eghn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch_sparse import spmm 5 | from model.basic import EGNN, EquivariantScalarNet, BaseMLP, aggregate, EGMN 6 | 7 | 8 | class EquivariantEdgeScalarNet(nn.Module): 9 | def __init__(self, n_vector_input, hidden_dim, activation, n_scalar_input=0, norm=True, flat=False): 10 | """ 11 | The universal O(n) equivariant network using scalars. 12 | :param n_input: The total number of input vectors. 13 | :param hidden_dim: The hidden dim of the network. 14 | :param activation: The activation function. 15 | """ 16 | super(EquivariantEdgeScalarNet, self).__init__() 17 | self.input_dim = n_vector_input * n_vector_input + n_scalar_input 18 | self.hidden_dim = hidden_dim 19 | self.output_dim = hidden_dim 20 | # self.output_dim = n_vector_input 21 | self.activation = activation 22 | self.norm = norm 23 | self.in_scalar_net = BaseMLP(self.input_dim, self.hidden_dim, self.hidden_dim, self.activation, last_act=True, 24 | flat=flat) 25 | self.out_vector_net = BaseMLP(self.hidden_dim, self.hidden_dim, n_vector_input * n_vector_input, 26 | self.activation, flat=flat) 27 | 28 | def forward(self, vectors_i, vectors_j, scalars=None): 29 | """ 30 | :param vectors: torch.Tensor with shape [N, 3, K] or a list of torch.Tensor 31 | :param scalars: torch.Tensor with shape [N, L] (Optional) 32 | :return: A vector that is equivariant to the O(n) transformations of input vectors with shape [N, 3] 33 | """ 34 | Z_i, Z_j = vectors_i, vectors_j # [N, 3, K] 35 | K = Z_i.shape[-1] 36 | Z_j_T = Z_j.transpose(-1, -2) # [N, K, 3] 37 | scalar = torch.einsum('bij,bjk->bik', Z_j_T, Z_i) # [N, K, K] 38 | scalar = scalar.reshape(-1, K * K) # [N, KK] 39 | if self.norm: 40 | scalar = F.normalize(scalar, p=2, dim=-1) # [N, KK] 41 | if scalars is not None: 42 | scalar = torch.cat((scalar, scalars), dim=-1) # [N, KK + L] 43 | scalar = self.in_scalar_net(scalar) # [N, H] 44 | vec_scalar = self.out_vector_net(scalar) # [N, KK] 45 | vec_scalar = vec_scalar.reshape(-1, Z_j.shape[-1], Z_i.shape[-1]) # [N, K, K] 46 | vector = torch.einsum('bij,bjk->bik', Z_j, vec_scalar) # [N, 3, K] 47 | return vector, scalar 48 | 49 | 50 | class PoolingLayer(nn.Module): 51 | def __init__(self, in_edge_nf, hidden_nf, n_vector_input, activation=nn.SiLU(), flat=False): 52 | super(PoolingLayer, self).__init__() 53 | self.edge_message_net = EquivariantEdgeScalarNet(n_vector_input=n_vector_input, hidden_dim=hidden_nf, 54 | activation=activation, n_scalar_input=2 * hidden_nf + in_edge_nf, 55 | norm=True, flat=flat) 56 | self.node_net = BaseMLP(input_dim=hidden_nf + hidden_nf, hidden_dim=hidden_nf, output_dim=hidden_nf, 57 | activation=activation, flat=flat) 58 | 59 | def forward(self, vectors, h, edge_index, edge_fea): 60 | """ 61 | :param vectors: the node vectors with 
shape: [BN, 3, V] where V is the number of vectors 62 | :param h: the scalar node feature with shape: [BN, K] 63 | :param edge_index: the edge index with shape [2, BM] 64 | :param edge_fea: the edge feature with shape: [BM, T] 65 | :return: the updated node vectors [BN, 3, V] and node scalar feature [BN, K] 66 | """ 67 | row, col = edge_index 68 | hij = torch.cat((h[row], h[col], edge_fea), dim=-1) # [BM, 2K+T] 69 | vectors_i, vectors_j = vectors[row], vectors[col] # [BM, 3, V] 70 | vectors_out, message = self.edge_message_net(vectors_i=vectors_i, vectors_j=vectors_j, scalars=hij) # [BM, 3, V] 71 | DIM, V = vectors_out.shape[-2], vectors_out.shape[-1] 72 | vectors_out = vectors_out.reshape(-1, DIM * V) # [BM, 3V] 73 | vectors_out = aggregate(message=vectors_out, row_index=row, n_node=h.shape[0], aggr='mean') # [BN, 3V] 74 | vectors_out = vectors_out.reshape(-1, DIM, V) # [BN, 3, V] 75 | vectors_out = vectors + vectors_out # [BN, 3, V] 76 | tot_message = aggregate(message=message, row_index=row, n_node=h.shape[0], aggr='sum') # [BN, K] 77 | node_message = torch.cat((h, tot_message), dim=-1) # [BN, K+K] 78 | h = self.node_net(node_message) + h # [BN, K] 79 | return vectors_out, h 80 | 81 | 82 | class PoolingNet(nn.Module): 83 | def __init__(self, n_layers, in_edge_nf, n_vector_input, 84 | hidden_nf, output_nf, activation=nn.SiLU(), device='cpu', flat=False): 85 | super(PoolingNet, self).__init__() 86 | self.layers = nn.ModuleList() 87 | self.n_layers = n_layers 88 | for i in range(self.n_layers): 89 | layer = PoolingLayer(in_edge_nf, hidden_nf, n_vector_input=n_vector_input, activation=activation, flat=flat) 90 | self.layers.append(layer) 91 | self.pooling = nn.Sequential( 92 | nn.Linear(hidden_nf, 8 * hidden_nf), 93 | nn.Tanh(), 94 | nn.Linear(8 * hidden_nf, output_nf) 95 | ) 96 | self.to(device) 97 | 98 | def forward(self, vectors, h, edge_index, edge_fea): 99 | if type(vectors) == list: 100 | vectors = torch.stack(vectors, dim=-1) # [BN, 3, V] 101 | for i in range(self.n_layers): 102 | vectors, h = self.layers[i](vectors, h, edge_index, edge_fea) 103 | pooling = self.pooling(h) 104 | return pooling # [BN, P] 105 | 106 | 107 | class EGHN(nn.Module): 108 | def __init__(self, in_node_nf, in_edge_nf, hidden_nf, n_cluster, layer_per_block=3, layer_pooling=3, layer_decoder=1, 109 | flat=False, activation=nn.SiLU(), device='cpu', norm=False, with_v=True): 110 | super(EGHN, self).__init__() 111 | node_hidden_dim = hidden_nf 112 | # input feature mapping 113 | self.embedding = nn.Linear(in_node_nf, hidden_nf) 114 | self.current_pooling_plan = None 115 | self.n_cluster = n_cluster # 4 for simulation and 5 for mocap 116 | self.n_layer_per_block = layer_per_block 117 | self.n_layer_pooling = layer_pooling 118 | self.n_layer_decoder = layer_decoder 119 | self.flat = flat 120 | self.with_v = with_v 121 | # low-level force net 122 | self.low_force_net = EGNN(n_layers=self.n_layer_per_block, 123 | in_node_nf=hidden_nf, in_edge_nf=in_edge_nf, hidden_nf=hidden_nf, 124 | activation=activation, device=device, with_v=with_v, flat=flat, norm=norm) 125 | self.low_pooling = PoolingNet(n_vector_input=3, hidden_nf=hidden_nf, output_nf=self.n_cluster, 126 | activation=activation, in_edge_nf=in_edge_nf, n_layers=self.n_layer_pooling, flat=flat) 127 | self.high_force_net = EGNN(n_layers=self.n_layer_per_block, 128 | in_node_nf=hidden_nf, in_edge_nf=1, hidden_nf=hidden_nf, 129 | activation=activation, device=device, with_v=with_v, flat=flat) 130 | _n_vector_input = 4 if self.with_v else 3 131 | if self.n_layer_decoder 
== 1: 132 | self.kinematics_net = EquivariantScalarNet(n_vector_input=_n_vector_input, 133 | hidden_dim=hidden_nf, 134 | activation=activation, 135 | n_scalar_input=node_hidden_dim + node_hidden_dim, 136 | norm=True, 137 | flat=flat) 138 | else: 139 | self.kinematics_net = EGMN(n_vector_input=_n_vector_input, hidden_dim=hidden_nf, activation=activation, 140 | n_scalar_input=node_hidden_dim + node_hidden_dim, norm=True, flat=flat, 141 | n_layers=self.n_layer_decoder) 142 | 143 | self.to(device) 144 | 145 | def forward(self, x, h, edge_index, edge_fea, local_edge_index, local_edge_fea, n_node, v=None, node_mask=None, node_nums=None): 146 | """ 147 | :param x: input positions [B * N, 3] 148 | :param h: input node feature [B * N, R] 149 | :param edge_index: edge index of the graph [2, B * M] 150 | :param edge_fea: input edge feature [B* M, T] 151 | :param local_edge_index: the edges used in pooling network [B * M'] 152 | :param local_edge_fea: the feature of local edges [B * M', T] 153 | :param n_node: number of nodes per graph [1, ] 154 | :param v: input velocities [B * N, 3] (Optional) 155 | :param node_mask: the node mask when number of nodes are different in graphs [B * N, ] (Optional) 156 | :param node_nums: the real number of nodes in each graph 157 | :return: 158 | """ 159 | h = self.embedding(h) # [R, K] 160 | row, col = edge_index 161 | 162 | ''' low level force ''' 163 | new_x, new_v, h = self.low_force_net(x, h, edge_index, edge_fea, v=v) # [BN, 3] 164 | nf = new_x - x # [BN, 3] 165 | 166 | ''' pooling network ''' 167 | if node_nums is None: 168 | x_mean = torch.mean(x.reshape(-1, n_node, x.shape[-1]), dim=1, keepdim=True).expand(-1, n_node, -1).reshape( 169 | -1, x.shape[-1]) 170 | else: 171 | pooled_mean = (torch.sum(x.reshape(-1, n_node, x.shape[-1]), dim=1).T/node_nums).T.unsqueeze(dim=1) #[B,1,3] 172 | x_mean = pooled_mean.expand(-1, n_node, -1).reshape(-1, x.shape[-1]) 173 | 174 | pooling_fea = self.low_pooling(vectors=[x - x_mean, nf, v], h=h, 175 | edge_index=local_edge_index, edge_fea=local_edge_fea) # [BN, P] 176 | 177 | hard_pooling = pooling_fea.argmax(dim=-1) 178 | hard_pooling = F.one_hot(hard_pooling, num_classes=self.n_cluster).float() 179 | pooling = F.softmax(pooling_fea, dim=1) 180 | self.current_pooling_plan = hard_pooling # record the pooling plan 181 | 182 | ''' derive high-level information (be careful with graph mini-batch) ''' 183 | s = pooling.reshape(-1, n_node, pooling.shape[-1]) # [B, N, P] 184 | 185 | sT = s.transpose(-2, -1) # [B, P, N] 186 | p_index = torch.ones_like(nf)[..., 0] # [BN, ] 187 | if node_mask is not None: 188 | p_index = p_index * node_mask 189 | p_index = p_index.reshape(-1, n_node, 1) # [B, N, 1] 190 | count = torch.einsum('bij,bjk->bik', sT, p_index).clamp_min(1e-5) # [B, P, 1] 191 | _x, _h, _nf = x.reshape(-1, n_node, x.shape[-1]), h.reshape(-1, n_node, h.shape[-1]), nf.reshape(-1, n_node, nf.shape[-1]) 192 | # [B, N, 3], [B, N, K], [B, N, 3] 193 | X, H, NF = torch.einsum('bij,bjk->bik', sT, _x), torch.einsum('bij,bjk->bik', sT, _h), torch.einsum('bij,bjk->bik', sT, _nf) 194 | if v is not None: 195 | _v = v.reshape(-1, n_node, v.shape[-1]) 196 | V = torch.einsum('bij,bjk->bik', sT, _v) 197 | V = V / count 198 | V = V.reshape(-1, V.shape[-1]) 199 | else: 200 | V = None 201 | X, H, NF = X / count, H / count, NF / count # [B, P, 3], [B, P, K], [B, P, 3] 202 | X, H, NF = X.reshape(-1, X.shape[-1]), H.reshape(-1, H.shape[-1]), NF.reshape(-1, NF.shape[-1]) # [BP, 3] 203 | 204 | a = spmm(torch.stack((local_edge_index[0], 
local_edge_index[1]), dim=0), 205 | torch.ones_like(local_edge_index[0]), x.shape[0], x.shape[0], pooling) # [BN, P] 206 | a = a.reshape(-1, n_node, a.shape[-1]) # [B, N, P] 207 | A = torch.einsum('bij,bjk->bik', sT, a) # [B, P, P] 208 | self.cut_loss = self.get_cut_loss(A) 209 | aa = spmm(torch.stack((row, col), dim=0), torch.ones_like(row), x.shape[0], x.shape[0], pooling) # [BN, P] 210 | aa = aa.reshape(-1, n_node, aa.shape[-1]) # [B, N, P] 211 | AA = torch.einsum('bij,bjk->bik', sT, aa) # [B, P, P] 212 | 213 | # construct high-level edges 214 | h_row, h_col, h_edge_fea, h_edge_mask = self.construct_edges(AA, AA.shape[-1]) # [BPP] 215 | ''' high-level message passing ''' 216 | h_new_x, h_new_v, h_new_h = self.high_force_net(X, H, (h_row, h_col), h_edge_fea.unsqueeze(-1), v=V) 217 | h_nf = h_new_x - X 218 | 219 | ''' high-level kinematics update ''' 220 | _X = X + h_nf # [BP, 3] 221 | _V = h_new_v # [BP, 3] 222 | _H = h_new_h # [BP, K] 223 | 224 | ''' low-level kinematics update ''' 225 | l_nf = h_nf.reshape(-1, AA.shape[1], h_nf.shape[-1]) # [B, P, 3] 226 | l_nf = torch.einsum('bij,bjk->bik', s, l_nf).reshape(-1, l_nf.shape[-1]) # [BN, 3] 227 | l_X = X.reshape(-1, AA.shape[1], X.shape[-1]) # [B, P, 3] 228 | l_X = torch.einsum('bij,bjk->bik', s, l_X).reshape(-1, l_X.shape[-1]) # [BN, 3] 229 | if v is not None: 230 | l_V = V.reshape(-1, AA.shape[1], V.shape[-1]) # [B, P, 3] 231 | l_V = torch.einsum('bij,bjk->bik', s, l_V).reshape(-1, l_V.shape[-1]) # [BN, 3] 232 | vectors = [l_nf, x - l_X, v - l_V, nf] 233 | else: 234 | vectors = [l_nf, x - l_X, nf] 235 | l_H = _H.reshape(-1, AA.shape[1], _H.shape[-1]) # [B, P, K] 236 | l_H = torch.einsum('bij,bjk->bik', s, l_H).reshape(-1, l_H.shape[-1]) # [BN, K] 237 | 238 | l_kinematics, h_out = self.kinematics_net(vectors=vectors, 239 | scalars=torch.cat((h, l_H), dim=-1)) # [BN, 3] 240 | _l_X = _X.reshape(-1, AA.shape[1], _X.shape[-1]) # [B, P, 3] 241 | _l_X = torch.einsum('bij,bjk->bik', s, _l_X).reshape(-1, _l_X.shape[-1]) # [BN, 3] 242 | x_out = _l_X + l_kinematics # [BN, 3] 243 | 244 | return (x_out, v, h_out) if v is not None else (x_out, h_out) 245 | 246 | def inspect_pooling_plan(self): 247 | plan = self.current_pooling_plan # [BN, P] 248 | if plan is None: 249 | print('No pooling plan!') 250 | return 251 | dist = torch.sum(plan, dim=0) # [P,] 252 | # print(dist) 253 | dist = F.normalize(dist, p=1, dim=0) # [P,] 254 | print('Pooling plan:', dist.detach().cpu().numpy()) 255 | return 256 | 257 | def get_cut_loss(self, A): 258 | A = F.normalize(A, p=2, dim=2) 259 | return torch.norm(A - torch.eye(A.shape[-1]).to(A.device), p="fro", dim=[1, 2]).mean() 260 | 261 | @staticmethod 262 | def construct_edges(A, n_node): 263 | h_edge_fea = A.reshape(-1) # [BPP] 264 | h_row = torch.arange(A.shape[1]).unsqueeze(-1).expand(-1, A.shape[1]).reshape(-1).to(A.device) 265 | h_col = torch.arange(A.shape[1]).unsqueeze(0).expand(A.shape[1], -1).reshape(-1).to(A.device) 266 | h_row = h_row.unsqueeze(0).expand(A.shape[0], -1) 267 | h_col = h_col.unsqueeze(0).expand(A.shape[0], -1) 268 | offset = (torch.arange(A.shape[0]) * n_node).unsqueeze(-1).to(A.device) 269 | h_row, h_col = (h_row + offset).reshape(-1), (h_col + offset).reshape(-1) # [BPP] 270 | h_edge_mask = torch.ones_like(h_row) # [BPP] 271 | h_edge_mask[torch.arange(A.shape[1]) * (A.shape[1] + 1)] = 0 272 | return h_row, h_col, h_edge_fea, h_edge_mask 273 | 274 | 275 | -------------------------------------------------------------------------------- /model/basic.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def aggregate(message, row_index, n_node, aggr='sum', mask=None): 7 | """ 8 | The aggregation function (aggregate edge messages towards nodes) 9 | :param message: The edge message with shape [M, K] 10 | :param row_index: The row index of edges with shape [M] 11 | :param n_node: The number of nodes, N 12 | :param aggr: aggregation type, sum or mean 13 | :param mask: the edge mask (used in mean aggregation for counting degree) 14 | :return: The aggregated node-wise information with shape [N, K] 15 | """ 16 | result_shape = (n_node, message.shape[1]) 17 | result = message.new_full(result_shape, 0) # [N, K] 18 | row_index = row_index.unsqueeze(-1).expand(-1, message.shape[1]) # [M, K] 19 | result.scatter_add_(0, row_index, message) # [N, K] 20 | if aggr == 'sum': 21 | pass 22 | elif aggr == 'mean': 23 | count = message.new_full(result_shape, 0) 24 | ones = torch.ones_like(message) 25 | if mask is not None: 26 | ones = ones * mask.unsqueeze(-1) 27 | count.scatter_add_(0, row_index, ones) 28 | result = result / count.clamp(min=1) 29 | else: 30 | raise NotImplementedError('Unknown aggregation method:', aggr) 31 | return result # [N, K] 32 | 33 | 34 | class BaseMLP(nn.Module): 35 | def __init__(self, input_dim, hidden_dim, output_dim, activation, residual=False, last_act=False, flat=False): 36 | super(BaseMLP, self).__init__() 37 | self.residual = residual 38 | if flat: 39 | activation = nn.Tanh() 40 | hidden_dim = 4 * hidden_dim 41 | if residual: 42 | assert output_dim == input_dim 43 | if last_act: 44 | self.mlp = nn.Sequential( 45 | nn.Linear(input_dim, hidden_dim), 46 | activation, 47 | nn.Linear(hidden_dim, output_dim), 48 | activation 49 | ) 50 | else: 51 | self.mlp = nn.Sequential( 52 | nn.Linear(input_dim, hidden_dim), 53 | activation, 54 | nn.Linear(hidden_dim, output_dim) 55 | ) 56 | 57 | def forward(self, x): 58 | return self.mlp(x) if not self.residual else self.mlp(x) + x 59 | 60 | 61 | class EquivariantScalarNet(nn.Module): 62 | def __init__(self, n_vector_input, hidden_dim, activation, n_scalar_input=0, norm=True, flat=True): 63 | """ 64 | The universal O(n) equivariant network using scalars. 65 | :param n_vector_input: The total number of input vectors. 66 | :param hidden_dim: The hidden dim of the network. 67 | :param activation: The activation function.
68 | """ 69 | super(EquivariantScalarNet, self).__init__() 70 | self.input_dim = n_vector_input * n_vector_input + n_scalar_input 71 | self.hidden_dim = hidden_dim 72 | self.output_dim = hidden_dim 73 | # self.output_dim = n_vector_input 74 | self.activation = activation 75 | self.norm = norm 76 | self.in_scalar_net = BaseMLP(self.input_dim, self.hidden_dim, self.hidden_dim, self.activation, last_act=True, 77 | flat=flat) 78 | self.out_vector_net = BaseMLP(self.hidden_dim, self.hidden_dim, n_vector_input, self.activation, flat=flat) 79 | self.out_scalar_net = BaseMLP(self.hidden_dim, self.hidden_dim, self.output_dim, self.activation, flat=flat) 80 | 81 | def forward(self, vectors, scalars=None): 82 | """ 83 | :param vectors: torch.Tensor with shape [N, 3, K] or a list of torch.Tensor 84 | :param scalars: torch.Tensor with shape [N, L] (Optional) 85 | :return: A vector that is equivariant to the O(n) transformations of input vectors with shape [N, 3], and an invariant scalar feature with shape [N, H] 86 | """ 87 | if type(vectors) == list: 88 | Z = torch.stack(vectors, dim=-1) # [N, 3, K] 89 | else: 90 | Z = vectors 91 | K = Z.shape[-1] 92 | Z_T = Z.transpose(-1, -2) # [N, K, 3] 93 | scalar = torch.einsum('bij,bjk->bik', Z_T, Z) # [N, K, K] 94 | scalar = scalar.reshape(-1, K * K) # [N, KK] 95 | if self.norm: 96 | scalar = F.normalize(scalar, p=2, dim=-1) # [N, KK] 97 | if scalars is not None: 98 | scalar = torch.cat((scalar, scalars), dim=-1) # [N, KK + L] 99 | scalar = self.in_scalar_net(scalar) # [N, H] 100 | vec_scalar = self.out_vector_net(scalar) # [N, K] 101 | vector = torch.einsum('bij,bj->bi', Z, vec_scalar) # [N, 3] 102 | scalar = self.out_scalar_net(scalar) # [N, H] 103 | 104 | return vector, scalar 105 | 106 | 107 | class InvariantScalarNet(nn.Module): 108 | def __init__(self, n_vector_input, hidden_dim, output_dim, activation, n_scalar_input=0, norm=True, last_act=False, 109 | flat=False): 110 | """ 111 | The universal O(n) invariant network using scalars. 112 | :param n_vector_input: The total number of input vectors. 113 | :param hidden_dim: The hidden dim of the network. 114 | :param activation: The activation function.
115 | """ 116 | super(InvariantScalarNet, self).__init__() 117 | self.input_dim = n_vector_input * n_vector_input + n_scalar_input 118 | self.hidden_dim = hidden_dim 119 | self.output_dim = output_dim 120 | self.activation = activation 121 | self.norm = norm 122 | self.scalar_net = BaseMLP(self.input_dim, self.hidden_dim, self.output_dim, self.activation, last_act=last_act, 123 | flat=flat) 124 | 125 | def forward(self, vectors, scalars=None): 126 | """ 127 | :param vectors: torch.Tensor with shape [N, 3, K] or a list of torch.Tensor with shape [N, 3] 128 | :param scalars: torch.Tensor with shape [N, L] (Optional) 129 | :return: A scalar that is invariant to the O(n) transformations of input vectors with shape [N, output_dim] 130 | """ 131 | if type(vectors) == list: 132 | Z = torch.stack(vectors, dim=-1) # [N, 3, K] 133 | else: 134 | Z = vectors 135 | K = Z.shape[-1] 136 | Z_T = Z.transpose(-1, -2) # [N, K, 3] 137 | scalar = torch.einsum('bij,bjk->bik', Z_T, Z) # [N, K, K] 138 | scalar = scalar.reshape(-1, K * K) # [N, KK] 139 | if self.norm: 140 | scalar = F.normalize(scalar, p=2, dim=-1) # [N, KK] 141 | if scalars is not None: 142 | scalar = torch.cat((scalar, scalars), dim=-1) # [N, KK + L] 143 | scalar = self.scalar_net(scalar) # [N, output_dim] 144 | return scalar 145 | 146 | 147 | class EGNN_Layer(nn.Module): 148 | def __init__(self, in_edge_nf, hidden_nf, activation=nn.SiLU(), with_v=False, flat=False, norm=False): 149 | super(EGNN_Layer, self).__init__() 150 | self.with_v = with_v 151 | self.edge_message_net = InvariantScalarNet(n_vector_input=1, hidden_dim=hidden_nf, output_dim=hidden_nf, 152 | activation=activation, n_scalar_input=2 * hidden_nf + in_edge_nf, 153 | norm=norm, last_act=True, flat=flat) 154 | self.coord_net = BaseMLP(input_dim=hidden_nf, hidden_dim=hidden_nf, output_dim=1, activation=activation, 155 | flat=flat) 156 | self.node_net = BaseMLP(input_dim=hidden_nf + hidden_nf, hidden_dim=hidden_nf, output_dim=hidden_nf, 157 | activation=activation, flat=flat) 158 | if self.with_v: 159 | self.node_v_net = BaseMLP(input_dim=hidden_nf, hidden_dim=hidden_nf, output_dim=1, activation=activation, 160 | flat=flat) 161 | else: 162 | self.node_v_net = None 163 | 164 | def forward(self, x, h, edge_index, edge_fea, v=None): 165 | row, col = edge_index 166 | rij = x[row] - x[col] # [BM, 3] 167 | hij = torch.cat((h[row], h[col], edge_fea), dim=-1) # [BM, 2K+T] 168 | message = self.edge_message_net(vectors=[rij], scalars=hij) # [BM, K] 169 | coord_message = self.coord_net(message) # [BM, 1] 170 | f = (x[row] - x[col]) * coord_message # [BM, 3] 171 | tot_f = aggregate(message=f, row_index=row, n_node=x.shape[0], aggr='mean') # [BN, 3] 172 | tot_f = torch.clamp(tot_f, min=-100, max=100) 173 | 174 | if v is not None: 175 | x = x + self.node_v_net(h) * v + tot_f 176 | else: 177 | x = x + tot_f # [BN, 3] 178 | 179 | tot_message = aggregate(message=message, row_index=row, n_node=x.shape[0], aggr='sum') # [BN, K] 180 | node_message = torch.cat((h, tot_message), dim=-1) # [BN, K+K] 181 | h = self.node_net(node_message) # [BN, K] 182 | return x, v, h 183 | 184 | 185 | class EGNN(nn.Module): 186 | def __init__(self, n_layers, in_node_nf, in_edge_nf, hidden_nf, activation=nn.SiLU(), device='cpu', with_v=False, 187 | flat=False, norm=False): 188 | super(EGNN, self).__init__() 189 | self.layers = nn.ModuleList() 190 | self.n_layers = n_layers 191 | self.with_v = with_v 192 | # input feature mapping 193 | self.embedding = nn.Linear(in_node_nf, hidden_nf) 194 | for i in range(self.n_layers): 195 | layer =
EGNN_Layer(in_edge_nf, hidden_nf, activation=activation, with_v=with_v, flat=flat, norm=norm) 196 | self.layers.append(layer) 197 | self.to(device) 198 | 199 | def forward(self, x, h, edge_index, edge_fea, v=None): 200 | h = self.embedding(h) 201 | for i in range(self.n_layers): 202 | x, v, h = self.layers[i](x, h, edge_index, edge_fea, v=v) 203 | return (x, v, h) if v is not None else (x, h) 204 | 205 | 206 | class EGMN(nn.Module): 207 | def __init__(self, n_layers, n_vector_input, hidden_dim, n_scalar_input, activation=nn.SiLU(), norm=False, flat=False): 208 | super(EGMN, self).__init__() 209 | self.layers = nn.ModuleList() 210 | self.n_layers = n_layers 211 | for i in range(self.n_layers): 212 | cur_layer = EquivariantScalarNet(n_vector_input=n_vector_input + i, hidden_dim=hidden_dim, 213 | activation=activation, n_scalar_input=n_scalar_input if i == 0 else hidden_dim, 214 | norm=norm, flat=flat) 215 | self.layers.append(cur_layer) 216 | 217 | def forward(self, vectors, scalars): 218 | cur_vectors = vectors 219 | for i in range(self.n_layers): 220 | vector, scalars = self.layers[i](cur_vectors, scalars) 221 | cur_vectors.append(vector) 222 | return cur_vectors[-1], scalars 223 | 224 | 225 | class GNN_Layer(nn.Module): 226 | def __init__(self, in_edge_nf, hidden_nf, activation=nn.SiLU(), with_v=False, flat=False): 227 | super(GNN_Layer, self).__init__() 228 | self.with_v = with_v 229 | self.edge_message_net = BaseMLP(input_dim=in_edge_nf + 2 * hidden_nf, hidden_dim=hidden_nf, output_dim=hidden_nf, 230 | activation=activation, flat=flat) 231 | self.node_net = BaseMLP(input_dim=hidden_nf + hidden_nf, hidden_dim=hidden_nf, output_dim=hidden_nf, 232 | activation=activation, flat=flat) 233 | 234 | def forward(self, h, edge_index, edge_fea): 235 | row, col = edge_index 236 | hij = torch.cat((h[row], h[col], edge_fea), dim=-1) # [BM, 2K+T] 237 | message = self.edge_message_net(hij) # [BM, K] 238 | agg = aggregate(message=message, row_index=row, n_node=h.shape[0], aggr='mean') # [BN, K] 239 | h = h + self.node_net(torch.cat((agg, h), dim=-1)) 240 | return h 241 | 242 | 243 | class GNN(nn.Module): 244 | def __init__(self, n_layers, in_node_nf, in_edge_nf, hidden_nf, activation=nn.SiLU(), device='cpu', flat=False): 245 | super(GNN, self).__init__() 246 | self.layers = nn.ModuleList() 247 | self.n_layers = n_layers 248 | # input feature mapping 249 | self.embedding = nn.Linear(in_node_nf, hidden_nf) 250 | for i in range(self.n_layers): 251 | layer = GNN_Layer(in_edge_nf, hidden_nf, activation=activation, flat=flat) 252 | self.layers.append(layer) 253 | self.decoder = nn.Sequential( 254 | nn.Linear(hidden_nf, hidden_nf), 255 | activation, 256 | nn.Linear(hidden_nf, 3) 257 | ) 258 | self.to(device) 259 | 260 | def forward(self, h, edge_index, edge_fea): 261 | h = self.embedding(h) 262 | for i in range(self.n_layers): 263 | h = self.layers[i](h, edge_index, edge_fea) 264 | h = self.decoder(h) 265 | return h 266 | 267 | 268 | class Linear_dynamics(nn.Module): 269 | def __init__(self, device='cpu'): 270 | super(Linear_dynamics, self).__init__() 271 | self.time = nn.Parameter(torch.ones(1)) 272 | self.device = device 273 | self.to(self.device) 274 | 275 | def forward(self, x, v): 276 | return x + v * self.time 277 | 278 | 279 | class RF_vel(nn.Module): 280 | def __init__(self, hidden_nf, edge_attr_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4): 281 | super(RF_vel, self).__init__() 282 | self.hidden_nf = hidden_nf 283 | self.device = device 284 | self.n_layers = n_layers 285 | for i in range(0, 
n_layers): 286 | self.add_module("gcl_%d" % i, GCL_rf_vel(nf=hidden_nf, edge_attr_nf=edge_attr_nf, act_fn=act_fn)) 287 | self.to(self.device) 288 | 289 | def forward(self, vel_norm, x, edges, vel, edge_attr): 290 | for i in range(0, self.n_layers): 291 | x, _ = self._modules["gcl_%d" % i](x, vel_norm, vel, edges, edge_attr) 292 | return x 293 | 294 | 295 | class GCL_rf_vel(nn.Module): 296 | def __init__(self, nf=64, edge_attr_nf=0, act_fn=nn.LeakyReLU(0.2), coords_weight=1.0): 297 | super(GCL_rf_vel, self).__init__() 298 | self.coords_weight = coords_weight 299 | self.coord_mlp_vel = nn.Sequential( 300 | nn.Linear(1, nf), 301 | act_fn, 302 | nn.Linear(nf, 1)) 303 | 304 | layer = nn.Linear(nf, 1, bias=False) 305 | torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) 306 | self.phi = nn.Sequential(nn.Linear(1 + edge_attr_nf, nf), 307 | act_fn, 308 | layer, 309 | nn.Tanh()) 310 | 311 | def forward(self, x, vel_norm, vel, edge_index, edge_attr=None): 312 | row, col = edge_index 313 | edge_m = self.edge_model(x[row], x[col], edge_attr) 314 | x = self.node_model(x, edge_index, edge_m) 315 | x += vel * self.coord_mlp_vel(vel_norm) 316 | return x, edge_attr 317 | 318 | def edge_model(self, source, target, edge_attr): 319 | x_diff = source - target 320 | radial = torch.sqrt(torch.sum(x_diff ** 2, dim=1)).unsqueeze(1) 321 | e_input = torch.cat([radial, edge_attr], dim=1) 322 | e_out = self.phi(e_input) 323 | m_ij = x_diff * e_out 324 | return m_ij 325 | 326 | def node_model(self, x, edge_index, edge_m): 327 | row, col = edge_index 328 | agg = unsorted_segment_mean(edge_m, row, num_segments=x.size(0)) 329 | x_out = x + agg * self.coords_weight 330 | return x_out 331 | 332 | 333 | def unsorted_segment_mean(data, segment_ids, num_segments): 334 | result_shape = (num_segments, data.size(1)) 335 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 336 | result = data.new_full(result_shape, 0) # Init empty result tensor. 337 | count = data.new_full(result_shape, 0) 338 | result.scatter_add_(0, segment_ids, data) 339 | count.scatter_add_(0, segment_ids, torch.ones_like(data)) 340 | return result / count.clamp(min=1) 341 | 342 | 343 | class FullMLP(nn.ModuleList): 344 | def __init__(self, in_node_nf, hidden_nf, n_layers, activation=nn.SiLU(), flat=False, device='cpu'): 345 | super(FullMLP, self).__init__() 346 | self.layers = nn.ModuleList() 347 | self.embedding = nn.Linear(in_node_nf, hidden_nf) 348 | for i in range(n_layers): 349 | self.layers.append(BaseMLP(hidden_nf, hidden_nf, hidden_nf, activation, 350 | residual=True, last_act=True, flat=flat)) 351 | self.output = nn.Linear(hidden_nf, 3) 352 | self.to(device) 353 | 354 | def forward(self, x): 355 | x = self.embedding(x) 356 | for i in range(len(self.layers)): 357 | x = self.layers[i](x) 358 | return self.output(x) 359 | --------------------------------------------------------------------------------
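The two scalar-based modules in model/basic.py (EquivariantScalarNet and InvariantScalarNet) obtain O(n) equivariance and invariance by feeding the Gram matrix Z^T Z of the stacked input vectors through plain MLPs. The following stand-alone sketch, which is not part of the repository, instantiates EquivariantScalarNet and numerically checks this property under a random rotation; the sizes (N, K, L, H) and the import path are illustrative assumptions only.

import torch
from torch import nn
from model.basic import EquivariantScalarNet  # assumes the repository root is on PYTHONPATH

torch.manual_seed(0)
N, K, L, H = 8, 4, 16, 32  # nodes, vectors per node, scalar feature dim, hidden dim
net = EquivariantScalarNet(n_vector_input=K, hidden_dim=H, activation=nn.SiLU(),
                           n_scalar_input=L, norm=True, flat=False)

Z = torch.randn(N, 3, K)   # stacked input vectors [N, 3, K]
s = torch.randn(N, L)      # input scalar features [N, L]
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix acting on 3D space

vec, scalar = net(Z, scalars=s)
vec_rot, scalar_rot = net(torch.einsum('ij,njk->nik', Q, Z), scalars=s)

# The output vector should rotate with the input, while the scalar output stays unchanged.
print('equivariance error:', (vec_rot - vec @ Q.T).abs().max().item())
print('invariance error:  ', (scalar_rot - scalar).abs().max().item())

Both printed errors should be on the order of floating-point noise, since the Gram matrix is unchanged by orthogonal transformations of the vectors. Running the same test with InvariantScalarNet (which additionally takes an output_dim argument) gives a purely invariant output.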