├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── data.py ├── data ├── data_Rope │ ├── demo │ │ ├── 0.param │ │ ├── 0.rollout.h5 │ │ ├── 1.param │ │ ├── 1.rollout.h5 │ │ ├── 100.rollout.h5 │ │ ├── 101.rollout.h5 │ │ ├── 102.rollout.h5 │ │ ├── 103.rollout.h5 │ │ ├── 104.rollout.h5 │ │ ├── 105.rollout.h5 │ │ ├── 106.rollout.h5 │ │ ├── 107.rollout.h5 │ │ ├── 108.rollout.h5 │ │ ├── 125.rollout.h5 │ │ ├── 126.rollout.h5 │ │ ├── 127.rollout.h5 │ │ ├── 128.rollout.h5 │ │ ├── 129.rollout.h5 │ │ ├── 130.rollout.h5 │ │ ├── 131.rollout.h5 │ │ ├── 132.rollout.h5 │ │ ├── 133.rollout.h5 │ │ ├── 150.rollout.h5 │ │ ├── 151.rollout.h5 │ │ ├── 152.rollout.h5 │ │ ├── 153.rollout.h5 │ │ ├── 154.rollout.h5 │ │ ├── 155.rollout.h5 │ │ ├── 156.rollout.h5 │ │ ├── 157.rollout.h5 │ │ ├── 158.rollout.h5 │ │ ├── 175.rollout.h5 │ │ ├── 176.rollout.h5 │ │ ├── 177.rollout.h5 │ │ ├── 178.rollout.h5 │ │ ├── 179.rollout.h5 │ │ ├── 180.rollout.h5 │ │ ├── 181.rollout.h5 │ │ ├── 182.rollout.h5 │ │ ├── 183.rollout.h5 │ │ ├── 2.param │ │ ├── 2.rollout.h5 │ │ ├── 25.rollout.h5 │ │ ├── 26.rollout.h5 │ │ ├── 27.rollout.h5 │ │ ├── 28.rollout.h5 │ │ ├── 29.rollout.h5 │ │ ├── 3.param │ │ ├── 3.rollout.h5 │ │ ├── 30.rollout.h5 │ │ ├── 31.rollout.h5 │ │ ├── 32.rollout.h5 │ │ ├── 33.rollout.h5 │ │ ├── 4.param │ │ ├── 4.rollout.h5 │ │ ├── 5.param │ │ ├── 5.rollout.h5 │ │ ├── 50.rollout.h5 │ │ ├── 51.rollout.h5 │ │ ├── 52.rollout.h5 │ │ ├── 53.rollout.h5 │ │ ├── 54.rollout.h5 │ │ ├── 55.rollout.h5 │ │ ├── 56.rollout.h5 │ │ ├── 57.rollout.h5 │ │ ├── 58.rollout.h5 │ │ ├── 6.param │ │ ├── 6.rollout.h5 │ │ ├── 7.param │ │ ├── 7.rollout.h5 │ │ ├── 75.rollout.h5 │ │ ├── 76.rollout.h5 │ │ ├── 77.rollout.h5 │ │ ├── 78.rollout.h5 │ │ ├── 79.rollout.h5 │ │ ├── 8.rollout.h5 │ │ ├── 80.rollout.h5 │ │ ├── 81.rollout.h5 │ │ ├── 82.rollout.h5 │ │ └── 83.rollout.h5 │ └── stat_demo.h5 ├── data_Soft │ ├── demo │ │ ├── 0.param │ │ ├── 0.rollout.h5 │ │ ├── 1.param │ │ ├── 1.rollout.h5 │ │ ├── 100.rollout.h5 │ │ ├── 
101.rollout.h5 │ │ ├── 102.rollout.h5 │ │ ├── 103.rollout.h5 │ │ ├── 104.rollout.h5 │ │ ├── 105.rollout.h5 │ │ ├── 106.rollout.h5 │ │ ├── 107.rollout.h5 │ │ ├── 108.rollout.h5 │ │ ├── 125.rollout.h5 │ │ ├── 126.rollout.h5 │ │ ├── 127.rollout.h5 │ │ ├── 128.rollout.h5 │ │ ├── 129.rollout.h5 │ │ ├── 130.rollout.h5 │ │ ├── 131.rollout.h5 │ │ ├── 132.rollout.h5 │ │ ├── 133.rollout.h5 │ │ ├── 150.rollout.h5 │ │ ├── 151.rollout.h5 │ │ ├── 152.rollout.h5 │ │ ├── 153.rollout.h5 │ │ ├── 154.rollout.h5 │ │ ├── 155.rollout.h5 │ │ ├── 156.rollout.h5 │ │ ├── 157.rollout.h5 │ │ ├── 158.rollout.h5 │ │ ├── 175.rollout.h5 │ │ ├── 176.rollout.h5 │ │ ├── 177.rollout.h5 │ │ ├── 178.rollout.h5 │ │ ├── 179.rollout.h5 │ │ ├── 180.rollout.h5 │ │ ├── 181.rollout.h5 │ │ ├── 182.rollout.h5 │ │ ├── 183.rollout.h5 │ │ ├── 2.param │ │ ├── 2.rollout.h5 │ │ ├── 25.rollout.h5 │ │ ├── 26.rollout.h5 │ │ ├── 27.rollout.h5 │ │ ├── 28.rollout.h5 │ │ ├── 29.rollout.h5 │ │ ├── 3.param │ │ ├── 3.rollout.h5 │ │ ├── 30.rollout.h5 │ │ ├── 31.rollout.h5 │ │ ├── 32.rollout.h5 │ │ ├── 33.rollout.h5 │ │ ├── 4.param │ │ ├── 4.rollout.h5 │ │ ├── 5.param │ │ ├── 5.rollout.h5 │ │ ├── 50.rollout.h5 │ │ ├── 51.rollout.h5 │ │ ├── 52.rollout.h5 │ │ ├── 53.rollout.h5 │ │ ├── 54.rollout.h5 │ │ ├── 55.rollout.h5 │ │ ├── 56.rollout.h5 │ │ ├── 57.rollout.h5 │ │ ├── 58.rollout.h5 │ │ ├── 6.param │ │ ├── 6.rollout.h5 │ │ ├── 7.param │ │ ├── 7.rollout.h5 │ │ ├── 75.rollout.h5 │ │ ├── 76.rollout.h5 │ │ ├── 77.rollout.h5 │ │ ├── 78.rollout.h5 │ │ ├── 79.rollout.h5 │ │ ├── 8.rollout.h5 │ │ ├── 80.rollout.h5 │ │ ├── 81.rollout.h5 │ │ ├── 82.rollout.h5 │ │ └── 83.rollout.h5 │ └── stat_demo.h5 └── data_Swim │ ├── demo │ ├── 0.param │ ├── 0.rollout.h5 │ ├── 1.param │ ├── 1.rollout.h5 │ ├── 100.rollout.h5 │ ├── 101.rollout.h5 │ ├── 102.rollout.h5 │ ├── 103.rollout.h5 │ ├── 104.rollout.h5 │ ├── 105.rollout.h5 │ ├── 106.rollout.h5 │ ├── 107.rollout.h5 │ ├── 108.rollout.h5 │ ├── 125.rollout.h5 │ ├── 126.rollout.h5 │ ├── 127.rollout.h5 │ 
├── 128.rollout.h5 │ ├── 129.rollout.h5 │ ├── 130.rollout.h5 │ ├── 131.rollout.h5 │ ├── 132.rollout.h5 │ ├── 133.rollout.h5 │ ├── 150.rollout.h5 │ ├── 151.rollout.h5 │ ├── 152.rollout.h5 │ ├── 153.rollout.h5 │ ├── 154.rollout.h5 │ ├── 155.rollout.h5 │ ├── 156.rollout.h5 │ ├── 157.rollout.h5 │ ├── 158.rollout.h5 │ ├── 175.rollout.h5 │ ├── 176.rollout.h5 │ ├── 177.rollout.h5 │ ├── 178.rollout.h5 │ ├── 179.rollout.h5 │ ├── 180.rollout.h5 │ ├── 181.rollout.h5 │ ├── 182.rollout.h5 │ ├── 183.rollout.h5 │ ├── 2.param │ ├── 2.rollout.h5 │ ├── 25.rollout.h5 │ ├── 26.rollout.h5 │ ├── 27.rollout.h5 │ ├── 28.rollout.h5 │ ├── 29.rollout.h5 │ ├── 3.param │ ├── 3.rollout.h5 │ ├── 30.rollout.h5 │ ├── 31.rollout.h5 │ ├── 32.rollout.h5 │ ├── 33.rollout.h5 │ ├── 4.param │ ├── 4.rollout.h5 │ ├── 5.param │ ├── 5.rollout.h5 │ ├── 50.rollout.h5 │ ├── 51.rollout.h5 │ ├── 52.rollout.h5 │ ├── 53.rollout.h5 │ ├── 54.rollout.h5 │ ├── 55.rollout.h5 │ ├── 56.rollout.h5 │ ├── 57.rollout.h5 │ ├── 58.rollout.h5 │ ├── 6.param │ ├── 6.rollout.h5 │ ├── 7.param │ ├── 7.rollout.h5 │ ├── 75.rollout.h5 │ ├── 76.rollout.h5 │ ├── 77.rollout.h5 │ ├── 78.rollout.h5 │ ├── 79.rollout.h5 │ ├── 8.rollout.h5 │ ├── 80.rollout.h5 │ ├── 81.rollout.h5 │ ├── 82.rollout.h5 │ └── 83.rollout.h5 │ └── stat_demo.h5 ├── dump_Rope └── train_Rope_CKO_demo │ └── net_best.pth ├── dump_Soft └── train_Soft_CKO_demo │ └── net_best.pth ├── dump_Swim └── train_Swim_CKO_demo │ └── net_best.pth ├── eval.py ├── figures ├── rope.gif ├── soft.gif └── swim.gif ├── models ├── CompositionalKoopmanOperators.py ├── KoopmanBaselineModel.py └── __init__.py ├── physics_engine.py ├── preprocess_data.py ├── requirements.txt ├── scripts ├── eval_Rope.sh ├── eval_Soft.sh ├── eval_Swim.sh ├── mpc_Rope.sh ├── mpc_Soft.sh ├── mpc_Swim.sh ├── train_Rope.sh ├── train_Soft.sh └── train_Swim.sh ├── shoot.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | 
*.zip 4 | .idea 5 | *.swp 6 | *.swoarchived 7 | models/__pycache__ 8 | data/data_Rope/valid 9 | data/data_Rope/train 10 | data/data_Rope/stat.h5 11 | data/data_Soft/valid 12 | data/data_Soft/train 13 | data/data_Soft/stat.h5 14 | data/data_Swim/valid 15 | data/data_Swim/train 16 | data/data_Swim/stat.h5 17 | dump_Rope/eval* 18 | dump_Rope/shoot* 19 | dump_Rope/train* 20 | dump_Soft/eval* 21 | dump_Soft/shoot* 22 | dump_Soft/train* 23 | dump_Swim/eval* 24 | dump_Swim/shoot* 25 | dump_Swim/train* 26 | 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yunzhu Li, Hao He, Jiajun Wu, Dina Katabi, Antonio Torralba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Compositional Koopman Operators for Model-Based Control 2 | 3 | Yunzhu Li*, Hao He*, Jiajun Wu, Dina Katabi, Antonio Torralba 4 | 5 | (* indicates equal contributions) 6 | 7 | **ICLR 2020 (Spotlight)** 8 | [[website]](http://koopman.csail.mit.edu/) [[openreview]](https://openreview.net/forum?id=H1ldzA4tPr) [[video]](https://youtu.be/MnXo_hjh1Q4) 9 | 10 | Demo 11 | ------------- 12 | #### Rope Manipulation 13 | In this task, we manipulate a rope by applying forces to the top of it. The target shape is shown as red dots. 14 | Note that the number of rope masses, the gravity constant, and the spring stiffness vary across cases. 15 | ![](figures/rope.gif) 16 | 17 | #### Soft Robot - Swing 18 | Consider controlling a soft robot composed of boxes made of different materials, such as rigid tissue (gray) and soft tissue (light blue). 19 | Some of the boxes have actuators that can contract (red) or expand (green). 20 | The task is to swing the robot into a target shape, shown as a red grid. 21 | ![](figures/soft.gif) 22 | 23 | #### Soft Robot - Swim 24 | The task is to make the soft robot swim forward into a target shape, shown as a red grid. 25 | ![](figures/swim.gif) 26 | 27 | 28 | Installation 29 | ------------- 30 | This codebase is tested with Ubuntu 16.04 LTS, Python 3.6+, PyTorch 1.2.0+, and CUDA 10.0. 31 | Dependencies are listed in `requirements.txt`. 32 | 33 | Play with Pre-trained Models 34 | ------------- 35 | We provide pretrained models for the three environments shown in the demo. The model parameters are stored in `dump_{env}/train_{env}_CKO_demo/net_best.pth`. 36 | 37 | We provide the following scripts to run simulations with the pretrained models. The results will be stored in `dump_{env}/eval_{env}_CKO_demo/`.
38 | 39 | bash scripts/eval_Rope.sh 40 | bash scripts/eval_Soft.sh 41 | bash scripts/eval_Swim.sh 42 | 43 | 44 | We provide the following scripts to perform model-based control with the pretrained models. The results will be stored in `dump_{env}/shoot_{env}_CKO_demo/`. 45 | 46 | bash scripts/mpc_Rope.sh 47 | bash scripts/mpc_Soft.sh 48 | bash scripts/mpc_Swim.sh 49 | 50 | Mac users will get the results as image frames, but the generated video may be empty because of an incompatible combination of file extension and fourcc in the OpenCV video writer (see this [[issue]](https://github.com/YunzhuLi/CompositionalKoopmanOperators/issues/1)). 51 | 52 | 53 | 54 | Train Your Own Models 55 | ----- 56 | 57 | ### Training 58 | We also provide scripts to train compositional Koopman operators from scratch. 59 | **Note that if you are running the script for the first time**, it will start by generating training and validation data in parallel using `num_workers` workers. 60 | To avoid regenerating data unnecessarily, set `--gen_data 0` in the scripts once the data has been generated. 61 | 62 | bash scripts/train_Rope.sh 63 | bash scripts/train_Soft.sh 64 | bash scripts/train_Swim.sh 65 | 66 | ### Evaluation and Model-Based Control 67 | You can use the scripts introduced in the previous section, `eval_{env}.sh` and `mpc_{env}.sh`, to evaluate the model you trained. 68 | Do not forget to change the arguments `--eval_set` and `--shoot_set` to `valid`, so that your model is evaluated on the validation data you generated instead of the demo data. 69 | 70 | Comparison with the Koopman Baseline Method 71 | --- 72 | We also provide an implementation of the baseline Koopman model using a polynomial Koopman basis. 73 | By adding the argument `--baseline` to the above eval/mpc scripts, you can generate the simulation/control results for the Koopman baseline model.
74 | For example, you can evaluate the Koopman baseline on the demo data we provided. The result will be stored in `dump_{env}/eval_{env}_KoopmanBaseline_demo/` and `dump_{env}/shoot_{env}_KoopmanBaseline_demo/`. 75 | 76 | 77 | Citing Our Paper 78 | ----------------- 79 | 80 | If you find this codebase useful in your research, please consider citing: 81 | 82 | @inproceedings{ 83 | li2020learning, 84 | title={Learning Compositional Koopman Operators for Model-Based Control}, 85 | author={Yunzhu Li and Hao He and Jiajun Wu and Dina Katabi and Antonio Torralba}, 86 | booktitle={International Conference on Learning Representations}, 87 | year={2020}, 88 | url={https://openreview.net/forum?id=H1ldzA4tPr} 89 | } 90 | 91 | For any questions, please contact Yunzhu Li (liyunzhu@mit.edu) and Hao He (haohe@mit.edu). 92 | 93 | Related Work 94 | --------------- 95 | Propagation Networks for Model-Based Control Under Partial Observation [(website)](http://propnet.csail.mit.edu/) [(code)](https://github.com/YunzhuLi/PropNet) 96 | 97 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | ''' 5 | General 6 | ''' 7 | parser.add_argument('--env', default='', required=True, help='Rope | Soft | Swim') 8 | parser.add_argument('--dt', type=float, default=1. / 50.) 
9 | 10 | ''' 11 | Compositional Koopman Operator model 12 | ''' 13 | parser.add_argument('--pstep', type=int, default=2, help='number of propagation steps in GNN model') 14 | parser.add_argument('--nf_relation', type=int, default=120, help='length of relation encoding') 15 | parser.add_argument('--nf_particle', type=int, default=100, help='length of object encoding') 16 | parser.add_argument('--nf_effect', type=int, default=100, help='length of effect encoding') 17 | parser.add_argument('--g_dim', type=int, default=32, help='dimention of latent linear dynamics') 18 | parser.add_argument('--fit_type', default='structured', 19 | help="what is the structure of AB matrix in koopman: structured | unstructured | diagonal") 20 | # input dimensions 21 | parser.add_argument('--attr_dim', type=int, default=0) 22 | parser.add_argument('--state_dim', type=int, default=0) 23 | parser.add_argument('--action_dim', type=int, default=0) 24 | parser.add_argument('--relation_dim', type=int, default=0) 25 | 26 | ''' 27 | Koopman baseline with polynomial Koopman basis 28 | ''' 29 | parser.add_argument('--baseline', default=False, action='store_true') 30 | parser.add_argument('--baseline_order', type=int, default=3, help='order of polynomial basis') 31 | 32 | ''' 33 | data 34 | ''' 35 | parser.add_argument('--dataf', default='data') 36 | parser.add_argument('--regular_data', type=int, default=0, help='generate regular shape of soft robot (used in Swim env)') 37 | parser.add_argument('--num_workers', type=int, default=10) 38 | parser.add_argument('--gen_data', type=int, default=0, help="whether to generate new data") 39 | parser.add_argument('--gen_stat', type=int, default=1, help="whether to generate statistic for the data") 40 | parser.add_argument('--group_size', type=int, default=25, help='# of episodes sharing the same physical parameters') 41 | 42 | ''' 43 | train 44 | ''' 45 | parser.add_argument('--outf', default='train') 46 | parser.add_argument('--lr', type=float, default=1e-4) 
47 | parser.add_argument('--batch_size', type=int, default=8) 48 | parser.add_argument('--grad_clip', type=float, default=5.0) 49 | parser.add_argument('--n_epoch', type=int, default=1000) 50 | parser.add_argument('--beta1', type=float, default=0.9) 51 | parser.add_argument('--log_per_iter', type=int, default=100, help="print log every x iterations") 52 | parser.add_argument('--ckp_per_iter', type=int, default=1000, help="save checkpoint every x iterations") 53 | parser.add_argument('--resume_epoch', type=int, default=-1, help="resume epoch of previous trained checkpoint") 54 | parser.add_argument('--resume_iter', type=int, default=-1, help="resume iteration of previous trained checkpoint") 55 | parser.add_argument('--lambda_loss_metric', type=float, default=0.3) 56 | parser.add_argument('--len_seq', type=int, default=64, help='length of every episodes used in training') 57 | 58 | ''' 59 | system identification 60 | ''' 61 | parser.add_argument('--I_factor', type=float, default=10, help='l2 regularization factor of least-square fitting') 62 | parser.add_argument('--fit_num', type=int, default=8, help='number of episodes used for system identification') 63 | 64 | ''' 65 | eval 66 | ''' 67 | parser.add_argument('--eval', type=int, default=0) 68 | parser.add_argument('--evalf', default='eval') 69 | parser.add_argument('--eval_type', default='koopman', help='rollout|valid|koopman') 70 | parser.add_argument('--eval_epoch', type=int, default=-1) 71 | parser.add_argument('--eval_iter', type=int, default=-1) 72 | parser.add_argument('--eval_set', default='valid', help='train|valid|demo') 73 | 74 | ''' 75 | shoot 76 | ''' 77 | parser.add_argument('--shootf', default='shoot') 78 | parser.add_argument('--optim_iter_init', type=int, default=100) 79 | parser.add_argument('--optim_iter', type=int, default=10) 80 | parser.add_argument('--optim_type', default='qp', help="qp|lqr") 81 | parser.add_argument('--feedback', type=int, default=1, help="optimize the control signals every x 
steps") 82 | parser.add_argument('--shoot_set', default='valid', help='train|valid|demo') 83 | parser.add_argument('--roll_start', type=int, default=0) 84 | parser.add_argument('--roll_step', type=int, default=40) 85 | parser.add_argument('--shoot_epoch', type=int, default=-1) 86 | parser.add_argument('--shoot_iter', type=int, default=-1) 87 | 88 | 89 | 90 | 91 | def gen_args(): 92 | args = parser.parse_args() 93 | assert args.batch_size == args.fit_num 94 | if args.env == 'Rope': 95 | args.data_names = ['attrs', 'states', 'actions'] 96 | 97 | args.n_rollout = 10000 98 | args.train_valid_ratio = 0.9 99 | 100 | args.time_step = 101 101 | # one hot to indicate root/children 102 | args.attr_dim = 2 103 | # state [x, y, xdot, ydot] 104 | args.state_dim = 4 105 | # action [x] 106 | args.action_dim = 1 107 | # relation [spring, ghost spring] 108 | args.relation_dim = 8 109 | 110 | args.param_dim = 5 111 | 112 | args.n_splits = 5 113 | args.num_obj_range = [*range(5, 5 + 5)] 114 | args.extra_num_obj_range = [10, 11, 12, 13, 14] 115 | 116 | args.act_scale = 2. 117 | 118 | elif args.env == 'Soft': 119 | args.data_names = ['attrs', 'states', 'actions'] 120 | 121 | args.n_rollout = 50000 122 | args.train_valid_ratio = 0.9 123 | 124 | args.time_step = 101 125 | # one hot to indicate actuated / soft / rigid / fixed 126 | args.attr_dim = 4 127 | # state [x, y] * 4 + [xdot, ydot] * 4 128 | args.state_dim = 16 129 | # action 1-dim scalar of extending or contracting 130 | args.action_dim = 1 131 | # relation: #relations types = #spaical position types * #box types 132 | args.relation_dim = 9 * 4 133 | 134 | args.param_dim = 4 135 | args.n_splits = 5 136 | args.num_obj_range = [*range(5, 5 + 5)] 137 | args.extra_num_obj_range = [10, 11, 12, 13, 14] 138 | 139 | args.act_scale = 650. 
140 | 141 | elif args.env == 'Swim': 142 | args.data_names = ['attrs', 'states', 'actions'] 143 | 144 | args.n_rollout = 50000 145 | args.train_valid_ratio = 0.9 146 | 147 | args.time_step = 101 148 | # one hot to indicate actuated / soft / rigid 149 | args.attr_dim = 3 150 | # state [x, y] * 4 + [xdot, ydot] * 4 151 | args.state_dim = 16 152 | # action 1-dim scalar of extending or contracting 153 | args.action_dim = 1 154 | # relation: #relations types = #spaical position types * #box types 155 | args.relation_dim = 9 * 3 156 | 157 | args.param_dim = 4 158 | args.n_splits = 5 159 | args.num_obj_range = [*range(5, 5 + 5)] 160 | args.extra_num_obj_range = [10, 11, 12, 13, 14] 161 | 162 | args.act_scale = 500. 163 | 164 | else: 165 | raise AssertionError("Unsupported env") 166 | 167 | assert args.n_rollout % (args.group_size * args.n_splits * args.batch_size) == 0 168 | 169 | args.demo = args.eval_set == 'demo' or args.shoot_set == 'demo' 170 | data_root = 'data' 171 | args.dataf = data_root + '/' + args.dataf + '_' + args.env 172 | 173 | dump_prefix = 'dump_{}/'.format(args.env) 174 | args.outf = dump_prefix + args.outf 175 | args.evalf = dump_prefix + args.evalf 176 | args.shootf = dump_prefix + args.shootf 177 | args.tmpf = dump_prefix + 'tmp' 178 | args.outf = args.outf + '_' + args.env 179 | args.stat_path = args.dataf + '/' + ('stat.h5' if not args.demo else 'stat_demo.h5') 180 | 181 | if not args.baseline: 182 | # compositional koopman operators 183 | args.outf += '_CKO_pstep_' + str(args.pstep) 184 | args.outf += '_lenseq_' + str(args.len_seq) 185 | args.outf += '_gdim_' + str(args.g_dim) 186 | args.outf += '_bs_' + str(args.batch_size) 187 | args.outf += '_' + str(args.fit_type) 188 | 189 | args.evalf += '_CKO_pstep_' + str(args.pstep) 190 | args.evalf += '_lenseq_' + str(args.len_seq) 191 | args.evalf += '_gdim_' + str(args.g_dim) 192 | args.evalf += '_fitnum_' + str(args.fit_num) 193 | args.evalf += '_' + str(args.fit_type) 194 | args.evalf += '_' + 
str(args.eval_set) 195 | if args.eval_epoch > -1: 196 | args.evalf += '_epoch_' + str(args.eval_epoch) 197 | args.evalf += '_iter_' + str(args.eval_iter) 198 | else: 199 | args.evalf += '_epoch_best' 200 | 201 | args.shootf += '_CKO_pstep_' + str(args.pstep) 202 | args.shootf += '_lenseq_' + str(args.len_seq) 203 | args.shootf += '_gdim_' + str(args.g_dim) 204 | args.shootf += '_fitnum_' + str(args.fit_num) 205 | args.shootf += '_' + args.fit_type 206 | args.shootf += '_' + args.optim_type 207 | args.shootf += '_roll_' + str(args.roll_step) 208 | if args.shoot_epoch > -1: 209 | args.shootf += '_epoch_' + str(args.shoot_epoch) 210 | args.shootf += '_iter_' + str(args.shoot_iter) 211 | else: 212 | args.shootf += '_epoch_best' 213 | 214 | args.shootf += '_feedback_' + str(args.feedback) 215 | args.shootf += '_' + str(args.shoot_set) 216 | 217 | # for demo 218 | if args.demo: 219 | args.outf = dump_prefix + f'train_{args.env}_CKO_demo' 220 | args.evalf = dump_prefix + f'eval_{args.env}_CKO_demo' 221 | args.shootf = dump_prefix + f'shoot_{args.env}_CKO_demo' 222 | 223 | else: 224 | 225 | args.evalf += '_KoopmanBaseline' 226 | args.evalf += '_fitnum_' + str(args.fit_num) 227 | args.evalf += '_' + str(args.fit_type) 228 | args.evalf += '_I_' + str(args.I_factor) 229 | args.evalf += '_order_' + str(args.baseline_order) 230 | args.evalf += '_' + str(args.eval_set) 231 | 232 | args.shootf += '_KoopmanBaseline' 233 | args.shootf += '_fitnum_' + str(args.fit_num) 234 | args.shootf += '_' + args.fit_type 235 | args.shootf += '_I_' + str(args.I_factor) 236 | args.shootf += '_order_' + str(args.baseline_order) 237 | args.shootf += '_roll_' + str(args.roll_step) 238 | args.shootf += '_feedback_' + str(args.feedback) 239 | 240 | # for demo 241 | if args.demo: 242 | args.outf = dump_prefix + f'train_{args.env}_KoopmanBaseline_demo' 243 | args.evalf = dump_prefix + f'eval_{args.env}_KoopmanBaseline_demo' 244 | args.shootf = dump_prefix + f'shoot_{args.env}_KoopmanBaseline_demo' 245 | 
246 | return args 247 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from progressbar import ProgressBar 8 | from torch.autograd import Variable 9 | from torch.utils.data import Dataset 10 | from socket import gethostname 11 | 12 | from physics_engine import RopeEngine, SoftEngine, SwimEngine 13 | from physics_engine import sample_init_p_flight 14 | 15 | from utils import rand_float, rand_int, calc_dis 16 | from utils import init_stat, combine_stat, load_data, store_data 17 | 18 | 19 | # ====================================================================================================================== 20 | def normalize(data, stat, var=False): 21 | for i in range(len(stat)): 22 | stat[i][stat[i][:, 1] == 0, 1] = 1.0 23 | if var: 24 | for i in range(len(stat)): 25 | s = Variable(torch.FloatTensor(stat[i]).to(data[i].device)) 26 | data[i] = (data[i] - s[:, 0]) / s[:, 1] 27 | else: 28 | for i in range(len(stat)): 29 | data[i] = (data[i] - stat[i][:, 0]) / stat[i][:, 1] 30 | return data 31 | 32 | 33 | def denormalize(data, stat, var=False): 34 | if var: 35 | for i in range(len(stat)): 36 | s = Variable(torch.FloatTensor(stat[i])).to(data[i].device) 37 | data[i] = data[i] * s[:, 1] + s[:, 0] 38 | else: 39 | for i in range(len(stat)): 40 | data[i] = data[i] * stat[i][:, 1] + stat[i][:, 0] 41 | return data 42 | 43 | 44 | # ====================================================================================================================== 45 | 46 | def prepare_input(data, stat, args, param=None, var=False): 47 | if args.env == 'Rope': 48 | data = normalize(data, stat, var) 49 | attrs, states, actions = data 50 | 51 | # print('attrs', attrs.shape, np.mean(attrs), np.std(attrs)) 52 | # print('states', states.shape, np.mean(states), 
np.std(states)) 53 | # print('acts', acts.shape, np.mean(actions), np.std(actions)) 54 | 55 | N = len(attrs) 56 | 57 | # print('N', N) 58 | 59 | rel_attrs = np.zeros((N, N, args.relation_dim)) 60 | 61 | '''relation #0 self: root <- root''' 62 | rel_attrs[0, 0, 0] = 1 63 | 64 | '''relation #1 spring: root <- child''' 65 | rel_attrs[0, 1, 1] = 1 66 | 67 | '''relation #2 spring: child <- root''' 68 | rel_attrs[1, 0, 2] = 1 69 | 70 | '''relation #3 spring bihop: root <- child''' 71 | rel_attrs[0, 2, 3] = 1 72 | 73 | '''relation #4 spring bihop: child <- root''' 74 | rel_attrs[2, 0, 4] = 1 75 | 76 | '''relation #5 spring: child <- child''' 77 | for i in range(1, N - 1): 78 | rel_attrs[i, i + 1, 5] = rel_attrs[i + 1, i, 5] = 1 79 | 80 | '''relation #6 spring bihop: child <- child''' 81 | for i in range(1, N - 2): 82 | rel_attrs[i, i + 2, 6] = rel_attrs[i + 2, i, 6] = 1 83 | 84 | '''relation #7 self: child <- child''' 85 | np.fill_diagonal(rel_attrs[1:, 1:, 7], 1) 86 | 87 | assert (rel_attrs.sum(2) <= 1).all() 88 | 89 | # check the number of each edge type 90 | rel_type_sum = np.sum(rel_attrs, axis=(0, 1)) 91 | assert rel_type_sum[0] == 1 92 | assert rel_type_sum[1] == 1 93 | assert rel_type_sum[2] == 1 94 | assert rel_type_sum[3] == 1 95 | assert rel_type_sum[4] == 1 96 | assert rel_type_sum[5] == (N - 2) * 2 97 | assert rel_type_sum[6] == (N - 3) * 2 98 | assert rel_type_sum[7] == N - 1 99 | 100 | elif args.env in ['Soft', 'Swim']: 101 | init_p = param[3] 102 | data = normalize(data, stat, var) 103 | attrs, states, actions = data 104 | 105 | # print('attrs', attrs.shape, np.mean(attrs), np.std(attrs)) 106 | # print('states', states.shape, np.mean(states), np.std(states)) 107 | # print('acts', actions.shape, np.mean(actions), np.std(actions)) 108 | 109 | N = len(attrs) 110 | 111 | # print('N', N) 112 | rel_attrs = np.zeros((N, N, args.relation_dim)) 113 | 114 | num_spacial_rel_type = 9 115 | num_box_type = 3 if args.env == 'Swim' else 4 116 | 117 | for i in range(N): 118 
| # normalized attributes 119 | type_i = np.where(attrs[i] > 0)[0][0] 120 | type_id = type_i 121 | # type_id = type_i * num_box_type + type_i 122 | rel_attrs[i, i, type_id * num_spacial_rel_type + 0] = 1 # self 123 | 124 | for j in range(N): 125 | if i == j: 126 | continue 127 | 128 | delta = init_p[i, :2] - init_p[j, :2] 129 | 130 | assert (np.abs(delta) > 0).any() 131 | 132 | if (np.abs(delta) > 1).any(): 133 | # no contact 134 | continue 135 | 136 | """ 137 | get i and j type 138 | Soft: [0: soft actuator, 1: soft, 2: rigid, 3: fixed] 139 | Swim: [0: soft actuator, 1: soft, 2: rigid] 140 | """ 141 | # normalized attributes 142 | type_i = np.where(attrs[i] > 0)[0][0] 143 | type_id = type_i 144 | 145 | if np.sum(np.abs(delta)) == 1: 146 | # contact at a corner 147 | if delta[0] == 1: 148 | rel_attrs[i, j, 1 + type_id * num_spacial_rel_type] = 1 149 | elif delta[0] == -1: 150 | rel_attrs[i, j, 2 + type_id * num_spacial_rel_type] = 1 151 | elif delta[1] == 1: 152 | rel_attrs[i, j, 3 + type_id * num_spacial_rel_type] = 1 153 | elif delta[1] == -1: 154 | rel_attrs[i, j, 4 + type_id * num_spacial_rel_type] = 1 155 | 156 | elif np.sum(np.abs(delta)) == 2: 157 | # contact at a side 158 | if delta[0] == 1 and delta[1] == 1: 159 | rel_attrs[i, j, 5 + type_id * num_spacial_rel_type] = 1 160 | elif delta[0] == 1 and delta[1] == -1: 161 | rel_attrs[i, j, 6 + type_id * num_spacial_rel_type] = 1 162 | elif delta[0] == -1 and delta[1] == 1: 163 | rel_attrs[i, j, 7 + type_id * num_spacial_rel_type] = 1 164 | elif delta[0] == -1 and delta[1] == -1: 165 | rel_attrs[i, j, 8 + type_id * num_spacial_rel_type] = 1 166 | else: 167 | raise AssertionError( 168 | "Unknown contact pattern %d %d" % (delta[0], delta[1])) 169 | else: 170 | raise AssertionError( 171 | "Unknown contact pattern %d %d" % (delta[0], delta[1])) 172 | 173 | else: 174 | raise AssertionError("unsupported env") 175 | 176 | return attrs, states, actions, rel_attrs 177 | 178 | 179 | def gen_Rope(info): 180 | thread_idx, 
data_dir, data_names = info['thread_idx'], info['data_dir'], info['data_names'] 181 | n_rollout, time_step = info['n_rollout'], info['time_step'] 182 | dt, video, args, phase = info['dt'], info['video'], info['args'], info['phase'] 183 | 184 | np.random.seed(round(time.time() * 1000 + thread_idx) % 2 ** 32) 185 | 186 | attr_dim = args.attr_dim # root, child 187 | state_dim = args.state_dim # x, y, xdot, ydot 188 | action_dim = args.action_dim 189 | param_dim = args.param_dim # n_ball, init_x, k, damping, gravity 190 | 191 | act_scale = 2. 192 | ret_scale = 1. 193 | 194 | # attr, state, action 195 | stats = [init_stat(attr_dim), init_stat(state_dim), init_stat(action_dim)] 196 | 197 | engine = RopeEngine(dt, state_dim, action_dim, param_dim) 198 | 199 | group_size = args.group_size 200 | sub_dataset_size = n_rollout * args.num_workers // args.n_splits 201 | print('group size', group_size, 'sub_dataset_size', sub_dataset_size) 202 | assert n_rollout % group_size == 0 203 | assert args.n_rollout % args.n_splits == 0 204 | 205 | bar = ProgressBar() 206 | for i in bar(range(n_rollout)): 207 | rollout_idx = thread_idx * n_rollout + i 208 | group_idx = rollout_idx // group_size 209 | sub_idx = rollout_idx // sub_dataset_size 210 | 211 | num_obj_range = args.num_obj_range if phase in {'train', 'valid'} else args.extra_num_obj_range 212 | num_obj = num_obj_range[sub_idx] 213 | 214 | rollout_dir = os.path.join(data_dir, str(rollout_idx)) 215 | 216 | param_file = os.path.join(data_dir, str(group_idx) + '.param') 217 | 218 | os.system('mkdir -p ' + rollout_dir) 219 | 220 | if rollout_idx % group_size == 0: 221 | engine.init(param=(num_obj, None, None, None, None)) 222 | torch.save(engine.get_param(), param_file) 223 | else: 224 | while not os.path.isfile(param_file): 225 | time.sleep(0.5) 226 | param = torch.load(param_file) 227 | engine.init(param=param) 228 | 229 | for j in range(time_step): 230 | states_ctl = engine.get_state()[0] 231 | act_t = np.zeros((engine.num_obj, 
action_dim)) 232 | act_t[0, 0] = (np.random.rand() * 2 - 1.) * act_scale - states_ctl[0] * ret_scale 233 | 234 | engine.set_action(action=act_t) 235 | 236 | states = engine.get_state() 237 | actions = engine.get_action() 238 | 239 | n_obj = engine.num_obj 240 | 241 | pos = states[:, :2].copy() 242 | vec = states[:, 2:].copy() 243 | 244 | '''reset velocity''' 245 | if j > 0: 246 | vec = (pos - states_all[j - 1, :, :2]) / dt 247 | 248 | if j == 0: 249 | attrs_all = np.zeros((time_step, n_obj, attr_dim)) 250 | states_all = np.zeros((time_step, n_obj, state_dim)) 251 | actions_all = np.zeros((time_step, n_obj, action_dim)) 252 | 253 | '''attrs: [1, 0] => root; [0, 1] => child''' 254 | assert attr_dim == 2 255 | attrs = np.zeros((n_obj, attr_dim)) 256 | # category: the first ball is fixed 257 | attrs[0, 0] = 1 258 | attrs[1:, 1] = 1 259 | 260 | assert np.sum(attrs[:, 0]) == 1 261 | assert np.sum(attrs[:, 1]) == engine.num_obj - 1 262 | 263 | attrs_all[j] = attrs 264 | states_all[j, :, :2] = pos 265 | states_all[j, :, 2:] = vec 266 | actions_all[j] = actions 267 | 268 | data = [attrs, states_all[j], actions_all[j]] 269 | 270 | store_data(data_names, data, os.path.join(rollout_dir, str(j) + '.h5')) 271 | 272 | engine.step() 273 | 274 | datas = [attrs_all.astype(np.float64), states_all.astype(np.float64), actions_all.astype(np.float64)] 275 | 276 | for j in range(len(stats)): 277 | stat = init_stat(stats[j].shape[0]) 278 | stat[:, 0] = np.mean(datas[j], axis=(0, 1))[:] 279 | stat[:, 1] = np.std(datas[j], axis=(0, 1))[:] 280 | stat[:, 2] = datas[j].shape[0] 281 | stats[j] = combine_stat(stats[j], stat) 282 | 283 | return stats 284 | 285 | 286 | def gen_Soft(info): 287 | thread_idx, data_dir, data_names = info['thread_idx'], info['data_dir'], info['data_names'] 288 | n_rollout, time_step = info['n_rollout'], info['time_step'] 289 | dt, video, args, phase = info['dt'], info['video'], info['args'], info['phase'] 290 | 291 | np.random.seed(round(time.time() * 1000 + thread_idx) 
% 2 ** 32) 292 | 293 | attr_dim = args.attr_dim # attrs: actuated/soft/rigid/fixed 294 | 295 | state_dim = args.state_dim # x, y, xdot, ydot 296 | action_dim = args.action_dim 297 | param_dim = args.param_dim # n_box, k, damping, init_p 298 | 299 | act_scale = 650. 300 | act_delta = 200. 301 | 302 | # attr, state, action 303 | stats = [init_stat(attr_dim), init_stat(state_dim), init_stat(action_dim)] 304 | 305 | engine = SoftEngine(dt, state_dim, action_dim, param_dim) 306 | 307 | group_size = args.group_size 308 | sub_dataset_size = n_rollout * args.num_workers // args.n_splits 309 | print('group size', group_size, 'sub_dataset_size', sub_dataset_size) 310 | assert n_rollout % group_size == 0 311 | assert args.n_rollout % args.n_splits == 0 312 | 313 | bar = ProgressBar() 314 | for i in bar(range(n_rollout)): 315 | rollout_idx = thread_idx * n_rollout + i 316 | group_idx = rollout_idx // group_size 317 | sub_idx = rollout_idx // sub_dataset_size 318 | 319 | num_obj_range = args.num_obj_range if phase in {'train', 'valid'} else args.extra_num_obj_range 320 | num_obj = num_obj_range[sub_idx] 321 | 322 | rollout_dir = os.path.join(data_dir, str(rollout_idx)) 323 | param_file = os.path.join(data_dir, str(group_idx) + '.param') 324 | os.system('mkdir -p ' + rollout_dir) 325 | 326 | if rollout_idx % group_size == 0: 327 | engine.init(param=(num_obj, None, None, None)) 328 | torch.save(engine.get_param(), param_file) 329 | else: 330 | while not os.path.isfile(param_file): 331 | time.sleep(0.5) 332 | param = torch.load(param_file) 333 | engine.init(param=param) 334 | 335 | # act_t_param = np.zeros((engine.n_box, 1)) 336 | 337 | for j in range(time_step): 338 | box_type = engine.init_p[:, 2] 339 | act_t = np.zeros((engine.n_box, action_dim)) 340 | 341 | for k in range(engine.n_box): 342 | if box_type[k] == 0: 343 | ''' 344 | # if this is an actuated box 345 | if j == 0: 346 | act_t_param[k] = np.array([rand_float(0., 1.)]) 347 | 348 | if act_t_param[k] < 0.5: 349 | # using 
random action 350 | act_t[k] = rand_float(-act_scale, act_scale) 351 | 352 | else: 353 | ''' 354 | # using smooth action 355 | if j == 0: 356 | act_t[k] = rand_float(-act_delta, act_delta) 357 | else: 358 | act_t[k] = actions_all[j - 1, k] + rand_float(-act_delta, act_delta) 359 | act_t[k] = np.clip(act_t[k], -act_scale, act_scale) 360 | 361 | engine.set_action(act_t) 362 | 363 | states = engine.get_state() 364 | actions = engine.get_action() 365 | 366 | pos = states[:, :8].copy() 367 | vec = states[:, 8:].copy() 368 | 369 | '''reset velocity''' 370 | if j > 0: 371 | vec = (pos - states_all[j - 1, :, :8]) / dt 372 | 373 | if j == 0: 374 | attrs_all = np.zeros((time_step, num_obj, attr_dim)) 375 | states_all = np.zeros((time_step, num_obj, state_dim)) 376 | actions_all = np.zeros((time_step, num_obj, action_dim)) 377 | 378 | '''attrs: actuated/soft/rigid/fixed''' 379 | assert attr_dim == 4 380 | attrs = np.zeros((num_obj, attr_dim)) 381 | 382 | for k in range(engine.n_box): 383 | attrs[k, int(engine.init_p[k, 2])] = 1 384 | 385 | assert np.sum(attrs[:, 0]) == np.sum(engine.init_p[:, 2] == 0) 386 | assert np.sum(attrs[:, 1]) == np.sum(engine.init_p[:, 2] == 1) 387 | assert np.sum(attrs[:, 2]) == np.sum(engine.init_p[:, 2] == 2) 388 | assert np.sum(attrs[:, 3]) == np.sum(engine.init_p[:, 2] == 3) 389 | assert (np.sum(attrs, 1) == 1).all() 390 | 391 | attrs_all[j] = attrs 392 | states_all[j, :, :8] = pos 393 | states_all[j, :, 8:] = vec 394 | actions_all[j] = actions 395 | 396 | data = [attrs, states_all[j], actions_all[j]] 397 | 398 | store_data(data_names, data, os.path.join(rollout_dir, str(j) + '.h5')) 399 | 400 | engine.step() 401 | 402 | datas = [attrs_all.astype(np.float64), states_all.astype(np.float64), actions_all.astype(np.float64)] 403 | 404 | for j in range(len(stats)): 405 | stat = init_stat(stats[j].shape[0]) 406 | stat[:, 0] = np.mean(datas[j], axis=(0, 1))[:] 407 | stat[:, 1] = np.std(datas[j], axis=(0, 1))[:] 408 | stat[:, 2] = datas[j].shape[0] 409 | 
stats[j] = combine_stat(stats[j], stat) 410 | 411 | return stats 412 | 413 | 414 | def gen_Swim(info): 415 | thread_idx, data_dir, data_names = info['thread_idx'], info['data_dir'], info['data_names'] 416 | n_rollout, time_step = info['n_rollout'], info['time_step'] 417 | dt, video, args, phase = info['dt'], info['video'], info['args'], info['phase'] 418 | 419 | np.random.seed(round(time.time() * 1000 + thread_idx) % 2 ** 32) 420 | 421 | attr_dim = args.attr_dim # actuated, soft, rigid 422 | state_dim = args.state_dim # x, y, xdot, ydot 423 | action_dim = args.action_dim 424 | param_dim = args.param_dim # n_box, k, damping, init_p 425 | 426 | act_scale = 500. 427 | act_delta = 250. 428 | 429 | # attr, state, action 430 | stats = [init_stat(attr_dim), init_stat(state_dim), init_stat(action_dim)] 431 | 432 | engine = SwimEngine(dt, state_dim, action_dim, param_dim) 433 | 434 | group_size = args.group_size 435 | sub_dataset_size = n_rollout * args.num_workers // args.n_splits 436 | print('group size', group_size, 'sub_dataset_size', sub_dataset_size) 437 | assert n_rollout % group_size == 0 438 | assert args.n_rollout % args.n_splits == 0 439 | 440 | bar = ProgressBar() 441 | for i in bar(range(n_rollout)): 442 | rollout_idx = thread_idx * n_rollout + i 443 | group_idx = rollout_idx // group_size 444 | sub_idx = rollout_idx // sub_dataset_size 445 | 446 | num_obj_range = args.num_obj_range if phase in {'train', 'valid'} else args.extra_num_obj_range 447 | num_obj = num_obj_range[sub_idx] 448 | 449 | rollout_dir = os.path.join(data_dir, str(rollout_idx)) 450 | param_file = os.path.join(data_dir, str(group_idx) + '.param') 451 | os.system('mkdir -p ' + rollout_dir) 452 | 453 | if rollout_idx % group_size == 0: 454 | init_p = None if not args.regular_data else sample_init_p_flight(n_box=num_obj, aug=True, train=phase=='train') 455 | engine.init(param=(num_obj, None, None, init_p)) 456 | torch.save(engine.get_param(), param_file) 457 | else: 458 | while not 
os.path.isfile(param_file): 459 | time.sleep(0.5) 460 | param = torch.load(param_file) 461 | engine.init(param=param) 462 | 463 | act_t_param = np.zeros((engine.n_box, 3)) 464 | 465 | for j in range(time_step): 466 | box_type = engine.init_p[:, 2] 467 | act_t = np.zeros((engine.n_box, action_dim)) 468 | 469 | for k in range(engine.n_box): 470 | if box_type[k] == 0: 471 | # if this is an actuated box 472 | if j == 0: 473 | act_t_param[k] = np.array([rand_float(0., 1.), rand_float(1., 2.5), rand_float(0, np.pi * 2)]) 474 | 475 | if act_t_param[k, 0] < 0.3: 476 | # using smooth action 477 | if j == 0: 478 | act_t[k] = rand_float(-act_delta, act_delta) 479 | else: 480 | lo = max(actions_all[j - 1, k] - act_delta, - act_scale - 20) 481 | hi = min(actions_all[j - 1, k] + act_delta, act_scale + 20) 482 | act_t[k] = rand_float(lo, hi) 483 | act_t[k] = np.clip(act_t[k], -act_scale, act_scale) 484 | 485 | elif act_t_param[k, 0] < 0.6: 486 | # using random action 487 | act_t[k] = rand_float(-act_scale, act_scale) 488 | 489 | else: 490 | # using sin action 491 | act_t[k] = np.sin(j / act_t_param[k, 1] + act_t_param[k, 2]) * \ 492 | rand_float(act_scale / 2., act_scale) 493 | 494 | engine.set_action(act_t) 495 | 496 | states = engine.get_state() 497 | actions = engine.get_action() 498 | 499 | pos = states[:, :8].copy() 500 | vec = states[:, 8:].copy() 501 | 502 | '''reset velocity''' 503 | if j > 0: 504 | vec = (pos - states_all[j - 1, :, :8]) / dt 505 | 506 | if j == 0: 507 | attrs_all = np.zeros((time_step, num_obj, attr_dim)) 508 | states_all = np.zeros((time_step, num_obj, state_dim)) 509 | actions_all = np.zeros((time_step, num_obj, action_dim)) 510 | 511 | '''attrs: actuated/soft/rigid''' 512 | assert attr_dim == 3 513 | attrs = np.zeros((num_obj, attr_dim)) 514 | 515 | for k in range(engine.n_box): 516 | attrs[k, int(engine.init_p[k, 2])] = 1 517 | 518 | assert np.sum(attrs[:, 0]) == np.sum(engine.init_p[:, 2] == 0) 519 | assert np.sum(attrs[:, 1]) == 
np.sum(engine.init_p[:, 2] == 1) 520 | assert np.sum(attrs[:, 2]) == np.sum(engine.init_p[:, 2] == 2) 521 | 522 | attrs_all[j] = attrs 523 | states_all[j, :, :8] = pos 524 | states_all[j, :, 8:] = vec 525 | actions_all[j] = actions 526 | 527 | data = [attrs, states_all[j], actions_all[j]] 528 | 529 | store_data(data_names, data, os.path.join(rollout_dir, str(j) + '.h5')) 530 | 531 | engine.step() 532 | 533 | datas = [attrs_all.astype(np.float64), states_all.astype(np.float64), actions_all.astype(np.float64)] 534 | 535 | for j in range(len(stats)): 536 | stat = init_stat(stats[j].shape[0]) 537 | stat[:, 0] = np.mean(datas[j], axis=(0, 1))[:] 538 | stat[:, 1] = np.std(datas[j], axis=(0, 1))[:] 539 | stat[:, 2] = datas[j].shape[0] 540 | stats[j] = combine_stat(stats[j], stat) 541 | 542 | return stats 543 | 544 | 545 | 546 | class PhysicsDataset(Dataset): 547 | 548 | def __init__(self, args, phase): 549 | self.args = args 550 | self.phase = phase 551 | self.data_dir = os.path.join(self.args.dataf, phase) 552 | if gethostname().startswith('netmit') and phase == 'extra': 553 | self.data_dir = self.args.dataf + '_' + phase 554 | 555 | self.stat_path = os.path.join(self.args.dataf, 'stat.h5') 556 | self.stat = None 557 | 558 | os.system('mkdir -p ' + self.data_dir) 559 | 560 | if args.env in ['Rope', 'Soft', 'Swim']: 561 | self.data_names = ['attrs', 'states', 'actions'] 562 | else: 563 | raise AssertionError("Unknown env") 564 | 565 | ratio = self.args.train_valid_ratio 566 | if phase == 'train': 567 | self.n_rollout = int(self.args.n_rollout * ratio) 568 | elif phase in {'valid', 'extra'}: 569 | self.n_rollout = self.args.n_rollout - int(self.args.n_rollout * ratio) 570 | else: 571 | raise AssertionError("Unknown phase") 572 | 573 | self.T = self.args.len_seq 574 | 575 | def load_data(self): 576 | self.stat = load_data(self.data_names, self.stat_path) 577 | 578 | def gen_data(self): 579 | # if the data hasn't been generated, generate the data 580 | n_rollout, time_step, 
dt = self.n_rollout, self.args.time_step, self.args.dt 581 | assert n_rollout % self.args.num_workers == 0 582 | 583 | print("Generating data ... n_rollout=%d, time_step=%d" % (n_rollout, time_step)) 584 | 585 | infos = [] 586 | for i in range(self.args.num_workers): 587 | info = {'thread_idx': i, 588 | 'data_dir': self.data_dir, 589 | 'data_names': self.data_names, 590 | 'n_rollout': n_rollout // self.args.num_workers, 591 | 'time_step': time_step, 592 | 'dt': dt, 593 | 'video': False, 594 | 'phase': self.phase, 595 | 'args': self.args} 596 | 597 | infos.append(info) 598 | 599 | cores = self.args.num_workers 600 | pool = mp.Pool(processes=cores) 601 | 602 | env = self.args.env 603 | 604 | if env == 'Rope': 605 | data = pool.map(gen_Rope, infos) 606 | elif env == 'Soft': 607 | data = pool.map(gen_Soft, infos) 608 | elif env == 'Swim': 609 | data = pool.map(gen_Swim, infos) 610 | else: 611 | raise AssertionError("Unknown env") 612 | 613 | print("Training data generated, warpping up stats ...") 614 | 615 | if self.phase == 'train': 616 | # states [x, y, angle, xdot, ydot, angledot], action [x, xdot] 617 | if env in ['Rope', 'Soft', 'Swim']: 618 | self.stat = [init_stat(self.args.attr_dim), 619 | init_stat(self.args.state_dim), 620 | init_stat(self.args.action_dim)] 621 | 622 | for i in range(len(data)): 623 | for j in range(len(self.stat)): 624 | self.stat[j] = combine_stat(self.stat[j], data[i][j]) 625 | 626 | if self.args.gen_stat: 627 | print("Storing stat to %s" % self.stat_path) 628 | store_data(self.data_names, self.stat, self.stat_path) 629 | else: 630 | print("stat will be discarded") 631 | else: 632 | print("Loading stat from %s ..." 
% self.stat_path) 633 | 634 | if env in ['Rope', 'Soft', 'Swim']: 635 | self.stat = load_data(self.data_names, self.stat_path) 636 | 637 | def __len__(self): 638 | return self.n_rollout * (self.args.time_step - self.T) 639 | 640 | def __getitem__(self, idx): 641 | idx_rollout = idx // (self.args.time_step - self.T) 642 | idx_timestep = idx % (self.args.time_step - self.T) 643 | 644 | # prepare input data 645 | seq_data = None 646 | for t in range(self.T + 1): 647 | data_path = os.path.join(self.data_dir, str(idx_rollout), str(idx_timestep + t) + '.h5') 648 | data = load_data(self.data_names, data_path) 649 | data = prepare_input(data, self.stat, self.args) 650 | if seq_data is None: 651 | seq_data = [[d] for d in data] 652 | else: 653 | for i, d in enumerate(data): 654 | seq_data[i].append(d) 655 | seq_data = [np.array(d).astype(np.float32) for d in seq_data] 656 | 657 | return seq_data 658 | 659 | 660 | if __name__ == '__main__': 661 | from easydict import EasyDict 662 | 663 | args = EasyDict() 664 | args.dataf = 'data' 665 | 666 | args.train_valid_ratio = 0.9 667 | args.num_workers = 10 668 | args.len_seq = 64 669 | 670 | # args.env = 'Rope' 671 | args.env = 'Soft' 672 | args.dataf = 'data/' + args.dataf + '_' + args.env 673 | 674 | if args.env == 'Rope': 675 | args.dt = 1.0 / 50 676 | args.n_rollout = 1000 677 | args.time_step = 101 678 | 679 | args.attr_dim = 2 680 | args.state_dim = 4 681 | args.action_dim = 1 682 | 683 | args.relation_dim = 8 684 | 685 | args.param_dim = 5 686 | args.n_splits = 5 687 | 688 | elif args.env == 'Soft': 689 | args.dt = 1.0 / 50 690 | args.n_rollout = 1000 691 | args.time_step = 101 692 | 693 | args.attr_dim = 3 # actuated, soft tissue, rigid tissue 694 | args.state_dim = 4 695 | args.action_dim = 1 696 | 697 | args.relation_dim = 9 698 | 699 | args.param_dim = 4 700 | args.n_splits = 10 701 | 702 | dataset = PhysicsDataset(args, phase='train') 703 | dataset.gen_data() 704 | 
-------------------------------------------------------------------------------- /data/data_Rope/demo/0.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/0.param -------------------------------------------------------------------------------- /data/data_Rope/demo/0.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/0.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/1.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/1.param -------------------------------------------------------------------------------- /data/data_Rope/demo/1.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/1.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/100.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/100.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/101.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/101.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/102.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/102.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/103.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/103.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/104.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/104.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/105.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/105.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/106.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/106.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Rope/demo/107.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/107.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/108.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/108.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/125.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/125.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/126.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/126.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/127.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/127.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/128.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/128.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/129.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/129.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/130.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/130.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/131.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/131.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/132.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/132.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/133.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/133.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Rope/demo/150.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/150.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/151.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/151.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/152.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/152.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/153.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/153.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/154.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/154.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/155.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/155.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/156.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/156.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/157.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/157.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/158.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/158.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/175.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/175.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/176.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/176.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Rope/demo/177.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/177.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/178.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/178.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/179.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/179.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/180.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/180.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/181.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/181.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/182.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/182.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/183.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/183.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/2.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/2.param -------------------------------------------------------------------------------- /data/data_Rope/demo/2.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/2.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/25.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/25.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/26.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/26.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/27.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/27.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/28.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/28.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/29.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/29.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/3.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/3.param -------------------------------------------------------------------------------- /data/data_Rope/demo/3.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/3.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/30.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/30.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Rope/demo/31.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/31.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/32.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/32.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/33.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/33.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/4.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/4.param -------------------------------------------------------------------------------- /data/data_Rope/demo/4.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/4.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/5.param: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/5.param -------------------------------------------------------------------------------- /data/data_Rope/demo/5.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/5.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/50.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/50.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/51.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/51.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/52.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/52.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/53.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/53.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/54.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/54.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/55.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/55.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/56.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/56.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/57.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/57.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/58.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/58.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/6.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/6.param 
-------------------------------------------------------------------------------- /data/data_Rope/demo/6.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/6.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/7.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/7.param -------------------------------------------------------------------------------- /data/data_Rope/demo/7.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/7.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/75.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/75.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/76.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/76.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/77.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/77.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/78.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/78.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/79.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/79.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/8.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/8.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/80.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/80.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/81.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/81.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/82.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/82.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/demo/83.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/demo/83.rollout.h5 -------------------------------------------------------------------------------- /data/data_Rope/stat_demo.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Rope/stat_demo.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/0.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/0.param -------------------------------------------------------------------------------- /data/data_Soft/demo/0.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/0.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/1.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/1.param 
-------------------------------------------------------------------------------- /data/data_Soft/demo/1.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/1.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/100.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/100.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/101.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/101.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/102.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/102.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/103.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/103.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/104.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/104.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/105.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/105.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/106.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/106.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/107.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/107.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/108.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/108.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/125.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/125.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Soft/demo/126.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/126.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/127.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/127.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/128.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/128.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/129.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/129.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/130.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/130.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/131.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/131.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/132.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/132.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/133.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/133.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/150.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/150.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/151.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/151.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/152.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/152.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Soft/demo/153.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/153.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/154.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/154.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/155.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/155.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/156.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/156.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/157.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/157.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/158.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/158.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/175.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/175.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/176.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/176.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/177.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/177.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/178.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/178.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/179.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/179.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Soft/demo/180.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/180.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/181.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/181.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/182.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/182.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/183.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/183.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/2.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/2.param -------------------------------------------------------------------------------- /data/data_Soft/demo/2.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/2.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Soft/demo/25.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/25.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/26.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/26.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/27.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/27.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/28.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/28.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/29.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/29.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/3.param: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/3.param -------------------------------------------------------------------------------- /data/data_Soft/demo/3.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/3.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/30.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/30.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/31.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/31.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/32.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/32.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/33.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/33.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/4.param: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/4.param -------------------------------------------------------------------------------- /data/data_Soft/demo/4.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/4.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/5.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/5.param -------------------------------------------------------------------------------- /data/data_Soft/demo/5.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/5.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/50.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/50.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/51.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/51.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Soft/demo/52.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/52.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/53.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/53.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/54.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/54.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/55.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/55.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/56.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/56.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/57.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/57.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/58.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/58.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/6.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/6.param -------------------------------------------------------------------------------- /data/data_Soft/demo/6.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/6.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/7.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/7.param -------------------------------------------------------------------------------- /data/data_Soft/demo/7.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/7.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/75.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/75.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/76.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/76.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/77.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/77.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/78.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/78.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/79.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/79.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/8.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/8.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Soft/demo/80.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/80.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/81.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/81.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/82.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/82.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/demo/83.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/demo/83.rollout.h5 -------------------------------------------------------------------------------- /data/data_Soft/stat_demo.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Soft/stat_demo.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/0.param: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/0.param -------------------------------------------------------------------------------- /data/data_Swim/demo/0.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/0.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/1.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/1.param -------------------------------------------------------------------------------- /data/data_Swim/demo/1.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/1.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/100.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/100.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/101.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/101.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/102.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/102.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/103.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/103.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/104.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/104.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/105.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/105.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/106.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/106.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/107.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/107.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Swim/demo/108.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/108.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/125.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/125.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/126.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/126.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/127.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/127.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/128.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/128.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/129.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/129.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/130.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/130.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/131.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/131.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/132.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/132.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/133.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/133.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/150.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/150.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Swim/demo/151.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/151.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/152.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/152.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/153.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/153.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/154.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/154.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/155.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/155.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/156.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/156.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/157.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/157.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/158.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/158.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/175.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/175.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/176.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/176.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/177.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/177.rollout.h5 -------------------------------------------------------------------------------- 
/data/data_Swim/demo/178.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/178.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/179.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/179.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/180.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/180.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/181.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/181.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/182.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/182.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/183.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/183.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/2.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/2.param -------------------------------------------------------------------------------- /data/data_Swim/demo/2.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/2.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/25.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/25.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/26.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/26.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/27.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/27.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/28.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/28.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/29.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/29.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/3.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/3.param -------------------------------------------------------------------------------- /data/data_Swim/demo/3.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/3.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/30.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/30.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/31.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/31.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Swim/demo/32.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/32.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/33.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/33.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/4.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/4.param -------------------------------------------------------------------------------- /data/data_Swim/demo/4.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/4.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/5.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/5.param -------------------------------------------------------------------------------- /data/data_Swim/demo/5.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/5.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/50.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/50.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/51.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/51.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/52.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/52.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/53.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/53.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/54.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/54.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/55.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/55.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/56.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/56.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/57.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/57.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/58.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/58.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/6.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/6.param -------------------------------------------------------------------------------- /data/data_Swim/demo/6.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/6.rollout.h5 
-------------------------------------------------------------------------------- /data/data_Swim/demo/7.param: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/7.param -------------------------------------------------------------------------------- /data/data_Swim/demo/7.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/7.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/75.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/75.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/76.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/76.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/77.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/77.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/78.rollout.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/78.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/79.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/79.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/8.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/8.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/80.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/80.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/81.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/81.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/82.rollout.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/82.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/demo/83.rollout.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/demo/83.rollout.h5 -------------------------------------------------------------------------------- /data/data_Swim/stat_demo.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/data/data_Swim/stat_demo.h5 -------------------------------------------------------------------------------- /dump_Rope/train_Rope_CKO_demo/net_best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/dump_Rope/train_Rope_CKO_demo/net_best.pth -------------------------------------------------------------------------------- /dump_Soft/train_Soft_CKO_demo/net_best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/dump_Soft/train_Soft_CKO_demo/net_best.pth -------------------------------------------------------------------------------- /dump_Swim/train_Swim_CKO_demo/net_best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/dump_Swim/train_Swim_CKO_demo/net_best.pth -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from config import gen_args 4 | from data import normalize, denormalize 5 | from models.CompositionalKoopmanOperators import 
CompositionalKoopmanOperators 6 | from models.KoopmanBaselineModel import KoopmanBaseline 7 | from physics_engine import SoftEngine, RopeEngine, SwimEngine 8 | from utils import * 9 | from utils import to_var, to_np, Tee 10 | from progressbar import ProgressBar 11 | import time 12 | 13 | args = gen_args() 14 | print_args(args) 15 | ''' 16 | args.fit_num is # of trajectories used for SysID 17 | ''' 18 | assert args.group_size - 1 >= args.fit_num 19 | 20 | data_names = ['attrs', 'states', 'actions'] 21 | prepared_names = ['attrs', 'states', 'actions', 'rel_attrs'] 22 | 23 | data_dir = os.path.join(args.dataf, args.eval_set) 24 | 25 | print(f"Load stored dataset statistics from {args.stat_path}!") 26 | stat = load_data(data_names, args.stat_path) 27 | 28 | if args.env == 'Rope': 29 | engine = RopeEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 30 | elif args.env == 'Soft': 31 | engine = SoftEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 32 | elif args.env == 'Swim': 33 | engine = SwimEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 34 | else: 35 | assert False 36 | 37 | 38 | os.system('mkdir -p ' + args.evalf) 39 | log_path = os.path.join(args.evalf, 'log.txt') 40 | tee = Tee(log_path, 'w') 41 | 42 | ''' 43 | model 44 | ''' 45 | # build model 46 | use_gpu = torch.cuda.is_available() 47 | if not args.baseline: 48 | """ Koopman model""" 49 | model = CompositionalKoopmanOperators(args, residual=False, use_gpu=use_gpu) 50 | 51 | # load pretrained checkpoint 52 | if args.eval_epoch == -1: 53 | model_path = os.path.join(args.outf, 'net_best.pth') 54 | else: 55 | model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter)) 56 | print("Loading saved checkpoint from %s" % model_path) 57 | device = torch.device('cuda:0') if use_gpu else torch.device('cpu') 58 | model.load_state_dict(torch.load(model_path,map_location=device)) 59 | model.eval() 60 | if use_gpu: model.cuda() 61 | 62 | else: 
63 | """ Koopman Baselinese """ 64 | model = KoopmanBaseline(args) 65 | 66 | ''' 67 | eval 68 | ''' 69 | 70 | 71 | def get_more_trajectories(roll_idx): 72 | group_idx = roll_idx // args.group_size 73 | offset = group_idx * args.group_size 74 | 75 | all_seq = [[], [], [], []] 76 | 77 | for i in range(1, args.fit_num + 1): 78 | new_idx = (roll_idx + i - offset) % args.group_size + offset 79 | seq_data = load_data(prepared_names, os.path.join(data_dir, str(new_idx) + '.rollout.h5')) 80 | for j in range(4): 81 | all_seq[j].append(seq_data[j]) 82 | 83 | all_seq = [np.array(all_seq[j], dtype=np.float32) for j in range(4)] 84 | return all_seq 85 | 86 | def eval(idx_rollout, video=True): 87 | print(f'\n=== Forward Simulation on Example {roll_idx} ===') 88 | 89 | seq_data = load_data(prepared_names, os.path.join(data_dir, str(idx_rollout) + '.rollout.h5')) 90 | attrs, states, actions, rel_attrs = [to_var(d.copy(), use_gpu=use_gpu) for d in seq_data] 91 | 92 | seq_data = denormalize(seq_data, stat) 93 | attrs_gt, states_gt, action_gt = seq_data[:3] 94 | 95 | param_file = os.path.join(data_dir, str(idx_rollout // args.group_size) + '.param') 96 | param = torch.load(param_file) 97 | engine.init(param) 98 | 99 | ''' 100 | fit data 101 | ''' 102 | fit_data = get_more_trajectories(roll_idx) 103 | fit_data = [to_var(d, use_gpu=use_gpu) for d in fit_data] 104 | bs = args.fit_num 105 | 106 | ''' T x N x D (denormalized)''' 107 | states_pred = states_gt.copy() 108 | states_pred[1:] = 0 109 | 110 | ''' T x N x D (normalized)''' 111 | s_pred = states.clone() 112 | 113 | ''' 114 | reconstruct loss 115 | ''' 116 | attrs_flat = get_flat(fit_data[0]) 117 | states_flat = get_flat(fit_data[1]) 118 | actions_flat = get_flat(fit_data[2]) 119 | rel_attrs_flat = get_flat(fit_data[3]) 120 | 121 | g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep) 122 | g = g.view(torch.Size([bs, args.time_step]) + g.size()[1:]) 123 | 124 | G_tilde = g[:, :-1] 125 | H_tilde = g[:, 1:] 126 | 
U_tilde = fit_data[2][:, :-1] 127 | 128 | G_tilde = get_flat(G_tilde, keep_dim=True) 129 | H_tilde = get_flat(H_tilde, keep_dim=True) 130 | U_tilde = get_flat(U_tilde, keep_dim=True) 131 | 132 | _t = time.time() 133 | A, B, fit_err = model.system_identify( 134 | G=G_tilde, H=H_tilde, U=U_tilde, rel_attrs=fit_data[3][:1, 0], I_factor=args.I_factor) 135 | _t = time.time() - _t 136 | 137 | ''' 138 | predict 139 | ''' 140 | 141 | g = model.to_g(attrs, states, rel_attrs, args.pstep) 142 | 143 | pred_g = None 144 | for step in range(0, args.time_step - 1): 145 | # prepare input data 146 | 147 | if step == 0: 148 | current_s = states[step:step + 1] 149 | current_g = g[step:step + 1] 150 | states_pred[step] = states_gt[step] 151 | else: 152 | '''current state''' 153 | if args.eval_type == 'valid': 154 | current_s = states[step:step + 1] 155 | elif args.eval_type == 'rollout': 156 | current_s = s_pred[step:step + 1] 157 | 158 | '''current g''' 159 | if args.eval_type in {'valid', 'rollout'}: 160 | current_g = model.to_g(attrs[step:step + 1], current_s, rel_attrs[step:step + 1], args.pstep) 161 | elif args.eval_type == 'koopman': 162 | current_g = pred_g 163 | 164 | '''next g''' 165 | pred_g = model.step(g=current_g, u=actions[step:step + 1], rel_attrs=rel_attrs[step:step + 1]) 166 | 167 | '''decode s''' 168 | pred_s = model.to_s(attrs=attrs[step:step + 1], gcodes=pred_g, 169 | rel_attrs=rel_attrs[step:step + 1], pstep=args.pstep) 170 | 171 | pred_s_np_denorm = denormalize([to_np(pred_s)], [stat[1]])[0] 172 | 173 | states_pred[step + 1:step + 2] = pred_s_np_denorm 174 | d = args.state_dim // 2 175 | states_pred[step + 1:step + 2, :, :d] = states_pred[step:step + 1, :, :d] + \ 176 | args.dt * states_pred[step + 1:step + 2, :, d:] 177 | 178 | s_pred_next = normalize([states_pred[step + 1:step + 2]], [stat[1]])[0] 179 | s_pred[step + 1:step + 2] = to_var(s_pred_next, use_gpu=use_gpu) 180 | 181 | if video: 182 | engine.render(states_pred, seq_data[2], param, 
act_scale=args.act_scale, video=True, image=True, 183 | path=os.path.join(args.evalf, str(idx_rollout) + '.pred'), 184 | states_gt=states_gt) 185 | 186 | if __name__ == '__main__': 187 | 188 | num_train = int(args.n_rollout * args.train_valid_ratio) 189 | num_valid = args.n_rollout - num_train 190 | 191 | ls_rollout_idx = np.arange(0, num_valid, num_valid // args.n_splits) 192 | 193 | if args.demo: 194 | ls_rollout_idx = np.arange(8) * 25 195 | 196 | for roll_idx in ls_rollout_idx: 197 | eval(roll_idx) 198 | -------------------------------------------------------------------------------- /figures/rope.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/figures/rope.gif -------------------------------------------------------------------------------- /figures/soft.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/figures/soft.gif -------------------------------------------------------------------------------- /figures/swim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/figures/swim.gif -------------------------------------------------------------------------------- /models/CompositionalKoopmanOperators.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | from data import denormalize, normalize 9 | from utils import load_data 10 | 11 | 12 | class RelationEncoder(nn.Module): 13 | def __init__(self, input_size, hidden_size, output_size): 14 | 
super(RelationEncoder, self).__init__() 15 | 16 | self.model = nn.Sequential( 17 | nn.Linear(input_size, hidden_size), 18 | nn.ReLU(), 19 | nn.Linear(hidden_size, output_size), 20 | nn.ReLU() 21 | ) 22 | 23 | def forward(self, x): 24 | """ 25 | Args: 26 | x: [n_relations, input_size] 27 | Returns: 28 | [n_relations, output_size] 29 | """ 30 | return self.model(x) 31 | 32 | 33 | class ParticleEncoder(nn.Module): 34 | def __init__(self, input_size, hidden_size, output_size): 35 | super(ParticleEncoder, self).__init__() 36 | 37 | self.model = nn.Sequential( 38 | nn.Linear(input_size, hidden_size), 39 | nn.ReLU(), 40 | nn.Linear(hidden_size, output_size), 41 | nn.ReLU() 42 | ) 43 | 44 | def forward(self, x): 45 | """ 46 | Args: 47 | x: [n_particles, input_size] 48 | Returns: 49 | [n_particles, output_size] 50 | """ 51 | return self.model(x) 52 | 53 | 54 | class Propagator(nn.Module): 55 | def __init__(self, input_size, output_size, residual=False): 56 | super(Propagator, self).__init__() 57 | 58 | self.residual = residual 59 | 60 | self.linear = nn.Linear(input_size, output_size) 61 | self.relu = nn.ReLU() 62 | 63 | def forward(self, x, res=None): 64 | """ 65 | Args: 66 | x: [n_relations/n_particles, input_size] 67 | Returns: 68 | [n_relations/n_particles, output_size] 69 | """ 70 | if self.residual: 71 | x = self.relu(self.linear(x) + res) 72 | else: 73 | x = self.relu(self.linear(x)) 74 | 75 | return x 76 | 77 | 78 | class ParticlePredictor(nn.Module): 79 | def __init__(self, input_size, hidden_size, output_size): 80 | super(ParticlePredictor, self).__init__() 81 | 82 | self.linear_0 = nn.Linear(input_size, hidden_size) 83 | self.linear_1 = nn.Linear(hidden_size, output_size) 84 | self.relu = nn.ReLU() 85 | 86 | def forward(self, x): 87 | """ 88 | Args: 89 | x: [n_particles, input_size] 90 | Returns: 91 | [n_particles, output_size] 92 | """ 93 | x = self.relu(self.linear_0(x)) 94 | 95 | return self.linear_1(x) 96 | 97 | 98 | class l2_norm_layer(nn.Module): 99 | def 
__init__(self): 100 | super(l2_norm_layer, self).__init__() 101 | 102 | def forward(self, x): 103 | """ 104 | :param x: B x D 105 | :return: 106 | """ 107 | norm_x = torch.sqrt((x ** 2).sum(1) + 1e-10) 108 | return x / norm_x[:, None] 109 | 110 | 111 | class PropagationNetwork(nn.Module): 112 | def __init__(self, args, input_particle_dim=None, input_relation_dim=None, output_dim=None, action=True, tanh=False, 113 | residual=False, use_gpu=False): 114 | 115 | super(PropagationNetwork, self).__init__() 116 | 117 | self.args = args 118 | self.action = action 119 | 120 | if input_particle_dim is None: 121 | input_particle_dim = args.attr_dim + args.state_dim 122 | input_particle_dim += args.action_dim if action else 0 123 | 124 | if input_relation_dim is None: 125 | input_relation_dim = args.relation_dim + args.state_dim 126 | 127 | if output_dim is None: 128 | output_dim = args.state_dim 129 | 130 | nf_particle = args.nf_particle 131 | nf_relation = args.nf_relation 132 | nf_effect = args.nf_effect 133 | 134 | self.nf_effect = args.nf_effect 135 | 136 | self.use_gpu = use_gpu 137 | self.residual = residual 138 | 139 | # (1) state 140 | self.obj_encoder = ParticleEncoder(input_particle_dim, nf_particle, nf_effect) 141 | 142 | # (1) state receiver (2) state_diff 143 | self.relation_encoder = RelationEncoder(input_relation_dim, nf_relation, nf_relation) 144 | 145 | # (1) relation encode (2) sender effect (3) receiver effect 146 | self.relation_propagator = Propagator(nf_relation + 2 * nf_effect, nf_effect) 147 | 148 | # (1) particle encode (2) particle effect 149 | self.particle_propagator = Propagator(2 * nf_effect, nf_effect, self.residual) 150 | 151 | # rigid predictor 152 | # (1) particle encode (2) set particle effect 153 | 154 | self.particle_predictor = ParticlePredictor(nf_effect, nf_effect, output_dim) 155 | 156 | if tanh: 157 | self.particle_predictor = nn.Sequential( 158 | self.particle_predictor, nn.Tanh() 159 | ) 160 | 161 | def forward(self, attrs, states, 
actions, rel_attrs, pstep): 162 | """ 163 | :param attrs: B x N x attr_dim 164 | :param states: B x N x state_dim 165 | :param actions: B x N x action_dim 166 | :param rel_attrs: B x N x N x relation_dim 167 | :param pstep: 1 or 2 168 | :return: 169 | """ 170 | B, N = attrs.size(0), attrs.size(1) 171 | '''encode node''' 172 | obj_input_list = [attrs, states] 173 | if self.action: 174 | obj_input_list += [actions] 175 | 176 | tmp = torch.cat(obj_input_list, 2) 177 | obj_encode = self.obj_encoder(tmp.reshape(tmp.size(0) * tmp.size(1), tmp.size(2))).reshape(B, N, -1) 178 | 179 | '''encode edge''' 180 | rel_states = states[:, :, None, :] - states[:, None, :, :] 181 | receiver_attr = attrs[:, :, None, :].repeat(1, 1, N, 1) 182 | sender_attr = attrs[:, None, :, :].repeat(1, N, 1, 1) 183 | tmp = torch.cat([rel_attrs, rel_states, receiver_attr, sender_attr], 3) 184 | rel_encode = self.relation_encoder(tmp.reshape(B * N * N, -1)).reshape(B, N, N, -1) 185 | 186 | for i in range(pstep): 187 | '''calculate relation effect''' 188 | 189 | receiver_code = obj_encode[:, :, None, :].repeat(1, 1, N, 1) 190 | sender_code = obj_encode[:, None, :, :].repeat(1, N, 1, 1) 191 | tmp = torch.cat([rel_encode, receiver_code, sender_code], 3) 192 | rel_effect = self.relation_propagator(tmp.reshape(B * N * N, -1)).reshape(B, N, N, -1) 193 | 194 | '''aggregating relation effect''' 195 | 196 | rel_agg_effect = rel_effect.sum(2) 197 | 198 | '''calc particle effect''' 199 | tmp = torch.cat([obj_encode, rel_agg_effect], 2) 200 | obj_encode = self.particle_propagator(tmp.reshape(B * N, -1)).reshape(B, N, -1) 201 | 202 | obj_prediction = self.particle_predictor(obj_encode.reshape(B * N, -1)).reshape(B, N, -1) 203 | return obj_prediction 204 | 205 | 206 | # ====================================================================================================================== 207 | class CompositionalKoopmanOperators(nn.Module, ABC): 208 | def __init__(self, args, residual=False, use_gpu=False): 209 | 
super(CompositionalKoopmanOperators, self).__init__() 210 | 211 | self.args = args 212 | 213 | self.stat = load_data(['attrs', 'states', 'actions'], args.stat_path) 214 | 215 | g_dim = args.g_dim 216 | 217 | self.nf_effect = args.nf_effect 218 | 219 | self.use_gpu = use_gpu 220 | self.residual = residual 221 | 222 | ''' state ''' 223 | # we should not include action in state encoder 224 | input_particle_dim = args.attr_dim + args.state_dim 225 | input_relation_dim = args.state_dim + args.relation_dim + args.attr_dim * 2 226 | 227 | # print('state_encoder', 'node', input_particle_dim, 'edge', input_relation_dim) 228 | 229 | self.state_encoder = PropagationNetwork( 230 | args, input_particle_dim=input_particle_dim, input_relation_dim=input_relation_dim, 231 | output_dim=g_dim, action=False, tanh=True, # use tanh to enforce the shape of the code space 232 | residual=residual, use_gpu=use_gpu) 233 | 234 | # the state for decoding phase is replaced with code of g_dim 235 | input_particle_dim = args.attr_dim + args.g_dim 236 | input_relation_dim = args.g_dim + args.relation_dim + args.attr_dim * 2 237 | 238 | # print('state_decoder', 'node', input_particle_dim, 'edge', input_relation_dim) 239 | 240 | self.state_decoder = PropagationNetwork( 241 | args, input_particle_dim=input_particle_dim, input_relation_dim=input_relation_dim, 242 | output_dim=args.state_dim, action=False, tanh=False, 243 | residual=residual, use_gpu=use_gpu) 244 | 245 | ''' dynamical system coefficient: A and B ''' 246 | self.A = None 247 | self.B = None 248 | if args.fit_type == 'structured': 249 | self.system_identify = self.fit 250 | self.simulate = self.rollout 251 | self.step = self.linear_forward 252 | if args.fit_type == 'unstructured': 253 | self.system_identify = self.fit_unstructured 254 | self.simulate = self.rollout_unstructured 255 | self.step = self.linear_forward_unstructured 256 | elif args.fit_type == 'diagonal': 257 | self.system_identify = self.fit_diagonal 258 | self.simulate = 
self.rollout_diagonal 259 | self.step = self.linear_forward_diagonal 260 | 261 | def to_s(self, attrs, gcodes, rel_attrs, pstep): 262 | """ state decoder """ 263 | 264 | if self.args.env in ['Soft', 'Swim']: 265 | states = self.state_decoder(attrs=attrs, states=gcodes, actions=None, rel_attrs=rel_attrs, pstep=pstep) 266 | return regularize_state_Soft(states, rel_attrs, self.stat) 267 | 268 | return self.state_decoder(attrs=attrs, states=gcodes, actions=None, rel_attrs=rel_attrs, pstep=pstep) 269 | 270 | def to_g(self, attrs, states, rel_attrs, pstep): 271 | """ state encoder """ 272 | return self.state_encoder(attrs=attrs, states=states, actions=None, rel_attrs=rel_attrs, pstep=pstep) 273 | 274 | @staticmethod 275 | def get_aug(G, rel_attrs): 276 | """ 277 | :param G: B x T x N x D 278 | :param rel_attrs: B x N x N x R 279 | :return: 280 | """ 281 | B, T, N, D = G.size() 282 | R = rel_attrs.size(-1) 283 | 284 | sumG_list = [] 285 | for i in range(R): 286 | ''' B x T x N x N ''' 287 | adj = rel_attrs[:, :, :, i][:, None, :, :].repeat(1, T, 1, 1) 288 | sumG = torch.bmm( 289 | adj.reshape(B * T, N, N), 290 | G.reshape(B * T, N, D) 291 | ).reshape(B, T, N, D) 292 | sumG_list.append(sumG) 293 | 294 | augG = torch.cat(sumG_list, 3) 295 | 296 | return augG 297 | 298 | # structured A 299 | 300 | def fit(self, G, H, U, rel_attrs, I_factor): 301 | """ 302 | :param G: B x T x N x D 303 | :param H: B x T x N x D 304 | :param U: B x T x N x a_dim 305 | :param rel_attrs: B x N x N x R (relation_dim) rel_attrs[i,j] ==> receiver i, sender j 306 | :param I_factor: scalor 307 | :return: 308 | A: B x R D x D 309 | B: B x R a_dim x D 310 | s.t. 
311 | H = augG @ A + augU @ B 312 | """ 313 | 314 | ''' B x R: sqrt(# of appearance of block matrices of the same type)''' 315 | rel_weights = torch.sqrt(rel_attrs.sum(1).sum(1)) 316 | rel_weights = torch.clamp(rel_weights, min=1e-8) 317 | 318 | bs, T, N, D = G.size() 319 | R = rel_attrs.size(-1) 320 | a_dim = U.size(3) 321 | 322 | ''' B x T x N x R D ''' 323 | augG = self.get_aug(G, rel_attrs) 324 | ''' B x T x N x R a_dim''' 325 | augU = self.get_aug(U, rel_attrs) 326 | 327 | augG_reweight = augG.reshape(bs, T, N, R, D) / rel_weights[:, None, None, :, None] 328 | augU_reweight = augU.reshape(bs, T, N, R, a_dim) / rel_weights[:, None, None, :, None] 329 | 330 | ''' B x TN x R(D + a_dim)''' 331 | GU_reweight = torch.cat([augG_reweight.reshape(bs, T * N, R * D), 332 | augU_reweight.reshape(bs, T * N, R * a_dim)], 2) 333 | 334 | '''B x (R * D + R * a_dim) x D''' 335 | AB_reweight = torch.bmm( 336 | self.batch_pinv(GU_reweight, I_factor), 337 | H.reshape(bs, T * N, D) 338 | ) 339 | self.A = AB_reweight[:, :R * D].reshape(bs, R, D, D) / rel_weights[:, :, None, None] 340 | self.B = AB_reweight[:, R * D:].reshape(bs, R, a_dim, D) / rel_weights[:, :, None, None] 341 | 342 | self.A = self.A.reshape(bs, R * D, D) 343 | self.B = self.B.reshape(bs, R * a_dim, D) 344 | 345 | fit_err = H.reshape(bs, T * N, D) - torch.bmm(GU_reweight, AB_reweight) 346 | fit_err = torch.sqrt((fit_err ** 2).mean()) 347 | 348 | return self.A, self.B, fit_err 349 | 350 | def linear_forward(self, g, u, rel_attrs): 351 | """ 352 | :param g: B x N x D 353 | :param u: B x N x a_dim 354 | :param rel_attrs: B x N x N x R 355 | :return: 356 | """ 357 | ''' B x N x R D ''' 358 | aug_g = self.get_aug(G=g[:, None, :, :], rel_attrs=rel_attrs)[:, 0] 359 | ''' B x N x R a_dim''' 360 | aug_u = self.get_aug(G=u[:, None, :, :], rel_attrs=rel_attrs)[:, 0] 361 | 362 | new_g = torch.bmm(aug_g, self.A) + torch.bmm(aug_u, self.B) 363 | return new_g 364 | 365 | def rollout(self, g, u_seq, T, rel_attrs): 366 | """ 367 | 
:param g: B x N x D 368 | :param u_seq: B x T x N x a_dim 369 | :param rel_attrs: B x N x N x R 370 | :param T: 371 | :return: 372 | """ 373 | g_list = [] 374 | for t in range(T): 375 | g = self.linear_forward(g, u_seq[:, t], rel_attrs) 376 | g_list.append(g[:, None, :, :]) 377 | return torch.cat(g_list, 1) 378 | 379 | # unstructured large A 380 | 381 | def fit_unstructured(self, G, H, U, I_factor, rel_attrs=None): 382 | """ 383 | :param G: B x T x N x D 384 | :param H: B x T x N x D 385 | :param U: B x T x N x a_dim 386 | :param I_factor: scalor 387 | :return: A, B 388 | s.t. 389 | H = catG @ A + catU @ B 390 | """ 391 | bs, T, N, D = G.size() 392 | G = G.reshape(bs, T, -1) 393 | H = H.reshape(bs, T, -1) 394 | U = U.reshape(bs, T, -1) 395 | 396 | G_U = torch.cat([G, U], 2) 397 | A_B = torch.bmm( 398 | self.batch_pinv(G_U, I_factor), 399 | H 400 | ) 401 | self.A = A_B[:, :N * D] 402 | self.B = A_B[:, N * D:] 403 | 404 | fit_err = H - torch.bmm(G_U, A_B) 405 | fit_err = torch.sqrt((fit_err ** 2).mean()) 406 | 407 | return self.A, self.B, fit_err 408 | 409 | def linear_forward_unstructured(self, g, u, rel_attrs=None): 410 | B, N, D = g.size() 411 | a_dim = u.size(-1) 412 | g = g.reshape(B, 1, N * D) 413 | u = u.reshape(B, 1, N * a_dim) 414 | new_g = torch.bmm(g, self.A) + torch.bmm(u, self.B) 415 | return new_g.reshape(B, N, D) 416 | 417 | def rollout_unstructured(self, g, u_seq, T, rel_attrs=None): 418 | g_list = [] 419 | for t in range(T): 420 | g = self.linear_forward_unstructured(g, u_seq[:, t]) 421 | g_list.append(g[:, None, :, :]) 422 | return torch.cat(g_list, 1) 423 | 424 | # shared small A 425 | 426 | def fit_diagonal(self, G, H, U, I_factor, rel_attrs=None): 427 | bs, T, N, D = G.size() 428 | a_dim = U.size(3) 429 | 430 | G_U = torch.cat([G, U], 3) 431 | 432 | '''B x (D + a_dim) x D''' 433 | A_B = torch.bmm( 434 | self.batch_pinv(G_U.reshape(bs, T * N, D + a_dim), I_factor), 435 | H.reshape(bs, T * N, D) 436 | ) 437 | self.A = A_B[:, :D] 438 | self.B = 
A_B[:, D:] 439 | 440 | fit_err = H.reshape(bs, T * N, D) - torch.bmm(G_U.reshape(bs, T * N, D + a_dim), A_B) 441 | fit_err = torch.sqrt((fit_err ** 2).mean()) 442 | 443 | return self.A, self.B, fit_err 444 | 445 | def linear_forward_diagonal(self, g, u, rel_attrs=None): 446 | new_g = torch.bmm(g, self.A) + torch.bmm(u, self.B) 447 | return new_g 448 | 449 | def rollout_diagonal(self, g, u_seq, T, rel_attrs=None): 450 | g_list = [] 451 | for t in range(T): 452 | g = self.linear_forward_diagonal(g, u_seq[:, t]) 453 | g_list.append(g[:, None, :, :]) 454 | return torch.cat(g_list, 1) 455 | 456 | @staticmethod 457 | def batch_pinv(x, I_factor): 458 | 459 | """ 460 | :param x: B x N x D (N > D) 461 | :param I_factor: 462 | :return: 463 | """ 464 | 465 | B, N, D = x.size() 466 | 467 | if N < D: 468 | x = torch.transpose(x, 1, 2) 469 | N, D = D, N 470 | trans = True 471 | else: 472 | trans = False 473 | 474 | x_t = torch.transpose(x, 1, 2) 475 | 476 | use_gpu = torch.cuda.is_available() 477 | I = torch.eye(D)[None, :, :].repeat(B, 1, 1) 478 | if use_gpu: 479 | I = I.cuda() 480 | 481 | x_pinv = torch.bmm( 482 | torch.inverse(torch.bmm(x_t, x) + I_factor * I), 483 | x_t 484 | ) 485 | 486 | if trans: 487 | x_pinv = torch.transpose(x_pinv, 1, 2) 488 | 489 | return x_pinv 490 | 491 | 492 | def regularize_state_Soft(states, rel_attrs, stat): 493 | """ 494 | :param states: B x N x state_dim 495 | :param rel_attrs: B x N x N x relation_dim 496 | :param stat: [xxx] 497 | :return new states: B x N x state_dim 498 | """ 499 | states_denorm = denormalize([states], [stat[1]], var=True)[0] 500 | states_denorm_acc = denormalize([states.clone()], [stat[1]], var=True)[0] 501 | 502 | rel_attrs = rel_attrs[0] 503 | 504 | rel_attrs_np = rel_attrs.detach().cpu().numpy() 505 | 506 | def get_rel_id(x): 507 | return np.where(x > 0)[0][0] 508 | 509 | B, N, state_dim = states.size() 510 | count = Variable(torch.FloatTensor(np.zeros((1, N, 1, 8))).to(states.device)) 511 | 512 | for i in range(N): 
513 | for j in range(N): 514 | 515 | if i == j: 516 | assert get_rel_id(rel_attrs_np[i, j]) % 9 == 0 # rel_attrs[i, j, 0] == 1 517 | count[:, i, :, :] += 1 518 | continue 519 | 520 | assert torch.sum(rel_attrs[i, j]) <= 1 521 | 522 | if torch.sum(rel_attrs[i, j]) == 0: 523 | continue 524 | 525 | if get_rel_id(rel_attrs_np[i, j]) % 9 == 1: # rel_attrs[i, j, 1] == 1: 526 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 2 # rel_attrs[j, i, 2] == 1 527 | x0 = 1; 528 | y0 = 3 529 | x1 = 0; 530 | y1 = 2 531 | idx = 1 532 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 2: # rel_attrs[i, j, 2] == 1: 533 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 1 # rel_attrs[j, i, 1] == 1 534 | x0 = 3; 535 | y0 = 1 536 | x1 = 2; 537 | y1 = 0 538 | idx = 2 539 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 3: # rel_attrs[i, j, 3] == 1: 540 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 4 # rel_attrs[j, i, 4] == 1 541 | x0 = 0; 542 | y0 = 1 543 | x1 = 2; 544 | y1 = 3 545 | idx = 3 546 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 4: # rel_attrs[i, j, 4] == 1: 547 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 3 # rel_attrs[j, i, 3] == 1 548 | x0 = 1; 549 | y0 = 0 550 | x1 = 3; 551 | y1 = 2 552 | idx = 4 553 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 5: # rel_attrs[i, j, 5] == 1: 554 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 8 # rel_attrs[j, i, 8] == 1 555 | x = 0; 556 | y = 3 557 | idx = 5 558 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 8: # rel_attrs[i, j, 8] == 1: 559 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 5 # rel_attrs[j, i, 5] == 1 560 | x = 3; 561 | y = 0 562 | idx = 8 563 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 6: # rel_attrs[i, j, 6] == 1: 564 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 7 # rel_attrs[j, i, 7] == 1 565 | x = 1; 566 | y = 2 567 | idx = 6 568 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 7: # rel_attrs[i, j, 7] == 1: 569 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 6 # rel_attrs[j, i, 6] == 1 570 | x = 2; 571 | y = 1 572 | idx = 7 573 | else: 574 | 
AssertionError("Unknown rel_attr %f" % rel_attrs[i, j]) 575 | 576 | if idx < 5: 577 | # if connect by two points 578 | x0 *= 2; 579 | y0 *= 2 580 | x1 *= 2; 581 | y1 *= 2 582 | count[:, i, :, x0:x0 + 2] += 1 583 | count[:, i, :, x1:x1 + 2] += 1 584 | states_denorm_acc[:, i, x0:x0 + 2] += states_denorm[:, j, y0:y0 + 2] 585 | states_denorm_acc[:, i, x0 + 8:x0 + 10] += states_denorm[:, j, y0 + 8:y0 + 10] 586 | states_denorm_acc[:, i, x1:x1 + 2] += states_denorm[:, j, y1:y1 + 2] 587 | states_denorm_acc[:, i, x1 + 8:x1 + 10] += states_denorm[:, j, y1 + 8:y1 + 10] 588 | 589 | else: 590 | # if connected by a corner 591 | x *= 2; 592 | y *= 2 593 | count[:, i, :, x:x + 2] += 1 594 | states_denorm_acc[:, i, x:x + 2] += states_denorm[:, j, y:y + 2] 595 | states_denorm_acc[:, i, x + 8:x + 10] += states_denorm[:, j, y + 8:y + 10] 596 | 597 | states_denorm = states_denorm_acc.view(B, N, 2, state_dim // 2) / count 598 | states_denorm = states_denorm.view(B, N, state_dim) 599 | 600 | return normalize([states_denorm], [stat[1]], var=True)[0] -------------------------------------------------------------------------------- /models/KoopmanBaselineModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from data import denormalize, normalize 8 | from utils import load_data 9 | 10 | 11 | class KoopmanBaseline(object): 12 | def __init__(self, args): 13 | self.args = args 14 | self.stat_path = os.path.join(self.args.dataf, 'stat.h5' if not args.demo else 'stat_demo.h5') 15 | self.stat = load_data(['attrs', 'states', 'actions'], self.stat_path) 16 | self.A = None 17 | self.B = None 18 | if args.fit_type == 'structured': 19 | self.system_identify = self.fit 20 | self.simulate = self.rollout 21 | self.step = self.linear_forward 22 | 23 | def to_s(self, attrs, gcodes, rel_attrs, pstep=None): 24 | """ 25 | :param gcodes: B x N x G 26 | :return states: B 
x N x D 27 | """ 28 | states = gcodes[:, :, :self.args.state_dim] 29 | if self.args.env in ['Soft', 'Swim']: 30 | states = self.regularize_state_Soft(states, rel_attrs, self.stat) 31 | return states 32 | 33 | def to_g(self, attrs, states, rel_attrs, pstep=None): 34 | """ 35 | :param states: B x N x D 36 | :return gcodes: B x N x G 37 | """ 38 | B, N, D = states.size() 39 | if self.args.env in ['Soft', 'Swim']: 40 | base = (states[:, :, :4] + states[:, :, 4:8] + states[:, :, 8:12] + states[:, :, 12:16]) / 4 41 | else: 42 | base = states 43 | 44 | one = states[:, :, :1].clone() 45 | one[:] = 1 46 | g_list = [states, one] 47 | 48 | for order in range(2, self.args.baseline_order + 1): 49 | # for j in range(1, order): 50 | # poly = ((base[:, :, :, None] ** j) * (base[:, :, None, :] ** (order - j))).reshape(B, N, -1) 51 | # g_list.append(poly) 52 | poly = ((base[:, :, :, None] ** (order - 1)) * (base[:, :, None, :] ** 1)).reshape(B, N, -1) / (3 ** order) 53 | g_list.append(poly) 54 | 55 | # base_square = (base[:,:,:,None] * base[:,:,None,:]).reshape(B,N,-1) 56 | # base_cube = ((base[:,:,:,None] ** 2) * base[:,:,None,:]).reshape(B,N,-1) 57 | 58 | # bavg = base.mean(1)[:,None,:].repeat(1,N,1) 59 | # base_avg = base * bavg 60 | 61 | # gcodes = torch.cat([states, base_square, base_cube], 2) 62 | gcodes = torch.cat(g_list, 2) 63 | return gcodes 64 | 65 | @staticmethod 66 | def get_aug(G, rel_attrs): 67 | """ 68 | :param G: B x T x N x D 69 | :param rel_attrs: B x N x N x R 70 | :return augG: B x T x N x R D 71 | """ 72 | B, T, N, D = G.size() 73 | R = rel_attrs.size(-1) 74 | 75 | sumG_list = [] 76 | for i in range(R): 77 | ''' B x T x N x N ''' 78 | adj = rel_attrs[:, :, :, i][:, None, :, :].repeat(1, T, 1, 1) 79 | sumG = torch.bmm( 80 | adj.reshape(B * T, N, N), 81 | G.reshape(B * T, N, D) 82 | ).reshape(B, T, N, D) 83 | sumG_list.append(sumG) 84 | 85 | augG = torch.cat(sumG_list, 3) 86 | 87 | return augG 88 | 89 | def fit(self, G, H, U, rel_attrs, I_factor): 90 | """ 91 | 
:param G: B x T x N x D 92 | :param H: B x T x N x D 93 | :param U: B x T x N x a_dim 94 | :param rel_attrs: B x N x N x R (relation_dim) rel_attrs[i,j] ==> receiver i, sender j 95 | :param I_factor: scalor 96 | :return: 97 | A: B x R D x D 98 | B: B x R a_dim x D 99 | s.t. 100 | H = augG @ A + augU @ B 101 | """ 102 | 103 | ''' B x R: sqrt(# of appearance of block matrices of the same type)''' 104 | rel_weights = torch.sqrt(rel_attrs.sum(1).sum(1)) 105 | rel_weights = torch.clamp(rel_weights, min=1e-8) 106 | 107 | bs, T, N, D = G.size() 108 | R = rel_attrs.size(-1) 109 | a_dim = U.size(3) 110 | 111 | ''' B x T x N x R D ''' 112 | augG = self.get_aug(G, rel_attrs) 113 | ''' B x T x N x R a_dim''' 114 | augU = self.get_aug(U, rel_attrs) 115 | 116 | augG_reweight = augG.reshape(bs, T, N, R, D) / rel_weights[:, None, None, :, None] 117 | augU_reweight = augU.reshape(bs, T, N, R, a_dim) / rel_weights[:, None, None, :, None] 118 | 119 | ''' B x TN x R(D + a_dim)''' 120 | GU_reweight = torch.cat([augG_reweight.reshape(bs, T * N, R * D), 121 | augU_reweight.reshape(bs, T * N, R * a_dim)], 2) 122 | 123 | '''B x (R * D + R * a_dim) x D''' 124 | AB_reweight = torch.bmm( 125 | self.batch_pinv(GU_reweight, I_factor), 126 | H.reshape(bs, T * N, D) 127 | ) 128 | self.A = AB_reweight[:, :R * D].reshape(bs, R, D, D) / rel_weights[:, :, None, None] 129 | self.B = AB_reweight[:, R * D:].reshape(bs, R, a_dim, D) / rel_weights[:, :, None, None] 130 | 131 | self.A = self.A.reshape(bs, R * D, D) 132 | self.B = self.B.reshape(bs, R * a_dim, D) 133 | 134 | fit_err = H.reshape(bs, T * N, D) - torch.bmm(GU_reweight, AB_reweight) 135 | fit_err = torch.sqrt((fit_err ** 2).mean()) 136 | 137 | return self.A, self.B, fit_err 138 | 139 | def linear_forward(self, g, u, rel_attrs): 140 | """ 141 | :param g: B x N x D 142 | :param u: B x N x a_dim 143 | :param rel_attrs: B x N x N x R 144 | :return B x N x D 145 | """ 146 | ''' B x N x R D ''' 147 | aug_g = self.get_aug(G=g[:, None, :, :], 
rel_attrs=rel_attrs)[:, 0] 148 | ''' B x N x R a_dim''' 149 | aug_u = self.get_aug(G=u[:, None, :, :], rel_attrs=rel_attrs)[:, 0] 150 | 151 | new_g = torch.bmm(aug_g, self.A) + torch.bmm(aug_u, self.B) 152 | return new_g 153 | 154 | def rollout(self, g, u_seq, T, rel_attrs): 155 | """ 156 | :param g: B x N x D 157 | :param u_seq: B x T x N x a_dim 158 | :param rel_attrs: B x N x N x R 159 | :param T: 160 | :return: 161 | """ 162 | g_list = [] 163 | for t in range(T): 164 | g = self.linear_forward(g, u_seq[:, t], rel_attrs) 165 | g_list.append(g[:, None, :, :]) 166 | return torch.cat(g_list, 1) 167 | 168 | @staticmethod 169 | def batch_pinv(x, I_factor): 170 | 171 | """ 172 | :param x: B x N x D (N > D) 173 | :param I_factor: 174 | :return: 175 | """ 176 | 177 | B, N, D = x.size() 178 | 179 | if N < D: 180 | x = torch.transpose(x, 1, 2) 181 | N, D = D, N 182 | trans = True 183 | else: 184 | trans = False 185 | 186 | x_t = torch.transpose(x, 1, 2) 187 | 188 | I = torch.eye(D)[None, :, :].repeat(B, 1, 1) 189 | use_gpu = torch.cuda.is_available() 190 | if use_gpu: I = I.cuda() 191 | 192 | x_pinv = torch.bmm( 193 | torch.inverse(torch.bmm(x_t, x) + I_factor * I), 194 | x_t 195 | ) 196 | 197 | if trans: 198 | x_pinv = torch.transpose(x_pinv, 1, 2) 199 | 200 | return x_pinv 201 | 202 | @staticmethod 203 | def regularize_state_Soft(states, rel_attrs, stat): 204 | """ 205 | :param states: B x N x state_dim 206 | :param rel_attrs: B x N x N x relation_dim 207 | :param stat: [xxx] 208 | :return new states: B x N x state_dim 209 | """ 210 | states_denorm = denormalize([states], [stat[1]], var=True)[0] 211 | states_denorm_acc = denormalize([states.clone()], [stat[1]], var=True)[0] 212 | 213 | rel_attrs = rel_attrs[0] 214 | 215 | rel_attrs_np = rel_attrs.detach().cpu().numpy() 216 | 217 | def get_rel_id(x): 218 | return np.where(x > 0)[0][0] 219 | 220 | B, N, state_dim = states.size() 221 | count = Variable(torch.FloatTensor(np.zeros((1, N, 1, 8))).to(states.device)) 222 | 223 | 
for i in range(N): 224 | for j in range(N): 225 | 226 | if i == j: 227 | assert get_rel_id(rel_attrs_np[i, j]) % 9 == 0 # rel_attrs[i, j, 0] == 1 228 | count[:, i, :, :] += 1 229 | continue 230 | 231 | assert torch.sum(rel_attrs[i, j]) <= 1 232 | 233 | if torch.sum(rel_attrs[i, j]) == 0: 234 | continue 235 | 236 | if get_rel_id(rel_attrs_np[i, j]) % 9 == 1: # rel_attrs[i, j, 1] == 1: 237 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 2 # rel_attrs[j, i, 2] == 1 238 | x0 = 1; 239 | y0 = 3 240 | x1 = 0; 241 | y1 = 2 242 | idx = 1 243 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 2: # rel_attrs[i, j, 2] == 1: 244 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 1 # rel_attrs[j, i, 1] == 1 245 | x0 = 3; 246 | y0 = 1 247 | x1 = 2; 248 | y1 = 0 249 | idx = 2 250 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 3: # rel_attrs[i, j, 3] == 1: 251 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 4 # rel_attrs[j, i, 4] == 1 252 | x0 = 0; 253 | y0 = 1 254 | x1 = 2; 255 | y1 = 3 256 | idx = 3 257 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 4: # rel_attrs[i, j, 4] == 1: 258 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 3 # rel_attrs[j, i, 3] == 1 259 | x0 = 1; 260 | y0 = 0 261 | x1 = 3; 262 | y1 = 2 263 | idx = 4 264 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 5: # rel_attrs[i, j, 5] == 1: 265 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 8 # rel_attrs[j, i, 8] == 1 266 | x = 0; 267 | y = 3 268 | idx = 5 269 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 8: # rel_attrs[i, j, 8] == 1: 270 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 5 # rel_attrs[j, i, 5] == 1 271 | x = 3; 272 | y = 0 273 | idx = 8 274 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 6: # rel_attrs[i, j, 6] == 1: 275 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 7 # rel_attrs[j, i, 7] == 1 276 | x = 1; 277 | y = 2 278 | idx = 6 279 | elif get_rel_id(rel_attrs_np[i, j]) % 9 == 7: # rel_attrs[i, j, 7] == 1: 280 | assert get_rel_id(rel_attrs_np[j, i]) % 9 == 6 # rel_attrs[j, i, 6] == 1 281 | x = 2; 282 | y = 1 283 | idx = 7 
284 | else: 285 | AssertionError("Unknown rel_attr %f" % rel_attrs[i, j]) 286 | 287 | if idx < 5: 288 | # if connect by two points 289 | x0 *= 2; 290 | y0 *= 2 291 | x1 *= 2; 292 | y1 *= 2 293 | count[:, i, :, x0:x0 + 2] += 1 294 | count[:, i, :, x1:x1 + 2] += 1 295 | states_denorm_acc[:, i, x0:x0 + 2] += states_denorm[:, j, y0:y0 + 2] 296 | states_denorm_acc[:, i, x0 + 8:x0 + 10] += states_denorm[:, j, y0 + 8:y0 + 10] 297 | states_denorm_acc[:, i, x1:x1 + 2] += states_denorm[:, j, y1:y1 + 2] 298 | states_denorm_acc[:, i, x1 + 8:x1 + 10] += states_denorm[:, j, y1 + 8:y1 + 10] 299 | 300 | else: 301 | # if connected by a corner 302 | x *= 2; 303 | y *= 2 304 | count[:, i, :, x:x + 2] += 1 305 | states_denorm_acc[:, i, x:x + 2] += states_denorm[:, j, y:y + 2] 306 | states_denorm_acc[:, i, x + 8:x + 10] += states_denorm[:, j, y + 8:y + 10] 307 | 308 | states_denorm = states_denorm_acc.view(B, N, 2, state_dim // 2) / count 309 | states_denorm = states_denorm.view(B, N, state_dim) 310 | 311 | return normalize([states_denorm], [stat[1]], var=True)[0] 312 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunzhuLi/CompositionalKoopmanOperators/116057b11192bb2fbea2b9af411cddcee354dae8/models/__init__.py -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | from config import gen_args 2 | import os 3 | from utils import * 4 | from data import prepare_input 5 | from progressbar import ProgressBar 6 | import multiprocessing as mp 7 | from socket import gethostname 8 | 9 | args = gen_args() 10 | 11 | data_names = ['attrs', 'states', 'actions'] 12 | prepared_names = ['attrs', 'states', 'actions', 'rel_attrs'] 13 | 14 | stat_path = os.path.join(args.dataf, 'stat.h5') 15 | stat = 
load_data(data_names, stat_path) 16 | 17 | 18 | def prepare_seq(info): 19 | phase, rollout_idx = info 20 | data_dir = os.path.join(args.dataf, phase) 21 | if phase == 'extra' and gethostname().startswith('netmit'): 22 | data_dir = args.dataf + '_' + phase 23 | 24 | # get param 25 | if args.env == 'Rope': 26 | param = None 27 | elif args.env in ['Soft', 'Swim']: 28 | param_file = os.path.join(data_dir, str(rollout_idx // args.group_size) + '.param') 29 | param = torch.load(param_file) 30 | else: 31 | assert False 32 | 33 | # prepare input data 34 | seq_data = None 35 | for t in range(args.time_step): 36 | data_path = os.path.join(data_dir, str(rollout_idx), str(t) + '.h5') 37 | data = load_data(data_names, data_path) 38 | data = prepare_input(data, stat, args, param=param) 39 | if seq_data is None: 40 | seq_data = [[d] for d in data] 41 | else: 42 | for i, d in enumerate(data): 43 | seq_data[i].append(d) 44 | seq_data = [np.array(d).astype(np.float32) for d in seq_data] 45 | 46 | assert len(seq_data) == len(prepared_names) 47 | 48 | store_data(prepared_names, seq_data, os.path.join(data_dir, str(rollout_idx) + '.rollout.h5')) 49 | 50 | 51 | def sub_thread(info): 52 | n_workers, idx, n_rollout, phase = info 53 | bar = ProgressBar() 54 | n = n_rollout // n_workers 55 | for i in bar(range(n)): 56 | prepare_seq(info=(phase, n * idx + i)) 57 | 58 | 59 | n_workers = 10 60 | pool = mp.Pool(processes=n_workers) 61 | 62 | num_train = int(args.n_rollout * args.train_valid_ratio) 63 | num_valid = args.n_rollout - num_train 64 | 65 | infos = [(n_workers, idx, num_train, 'train') for idx in range(n_workers)] 66 | pool.map(sub_thread, infos) 67 | 68 | infos = [(n_workers, idx, num_valid, 'valid') for idx in range(n_workers)] 69 | pool.map(sub_thread, infos) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | scipy 4 | matplotlib 
5 | progressbar2 6 | opencv-python 7 | pymunk==6.0.0 8 | cvxpy 9 | -------------------------------------------------------------------------------- /scripts/eval_Rope.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 3 | --env Rope \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --fit_num 8 \ 10 | --eval_set demo \ 11 | -------------------------------------------------------------------------------- /scripts/eval_Soft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 3 | --env Soft \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --fit_num 8 \ 10 | --eval_set demo \ 11 | -------------------------------------------------------------------------------- /scripts/eval_Swim.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 3 | --env Swim \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --fit_num 8 \ 10 | --eval_set demo \ 11 | -------------------------------------------------------------------------------- /scripts/mpc_Rope.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python shoot.py \ 3 | --env Rope \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --optim_type qp \ 10 | --fit_num 8 \ 11 | --roll_step 40 \ 12 | --roll_start 0 \ 13 | --feedback 40 \ 14 | --shoot_set demo \ 15 | -------------------------------------------------------------------------------- /scripts/mpc_Soft.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python shoot.py \ 3 | --env Soft \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --optim_type qp \ 10 | --fit_num 8 \ 11 | --roll_step 64 \ 12 | --roll_start 0 \ 13 | --feedback 32 \ 14 | --shoot_set demo \ 15 | -------------------------------------------------------------------------------- /scripts/mpc_Swim.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python shoot.py \ 3 | --env Swim \ 4 | --pstep 2 \ 5 | --g_dim 32 \ 6 | --len_seq 64 \ 7 | --I_factor 10 \ 8 | --fit_type structured \ 9 | --optim_type qp \ 10 | --fit_num 8 \ 11 | --roll_step 64 \ 12 | --roll_start 0 \ 13 | --feedback 32 \ 14 | --shoot_set demo \ 15 | -------------------------------------------------------------------------------- /scripts/train_Rope.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python train.py \ 3 | --env Rope \ 4 | --len_seq 64 \ 5 | --I_factor 10 \ 6 | --batch_size 8 \ 7 | --lr 1e-4 \ 8 | --g_dim 32 \ 9 | --pstep 2 \ 10 | --fit_type structured \ 11 | --gen_data 1 \ 12 | -------------------------------------------------------------------------------- /scripts/train_Soft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python train.py \ 3 | --env Soft \ 4 | --len_seq 64 \ 5 | --I_factor 10 \ 6 | --batch_size 8 \ 7 | --lr 1e-4 \ 8 | --g_dim 32 \ 9 | --pstep 2 \ 10 | --fit_type structured \ 11 | --gen_data 1 \ 12 | -------------------------------------------------------------------------------- /scripts/train_Swim.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 
CUDA_VISIBLE_DEVICES=0 python train.py \ 3 | --env Swim \ 4 | --len_seq 64 \ 5 | --I_factor 10 \ 6 | --batch_size 8 \ 7 | --lr 1e-4 \ 8 | --g_dim 32 \ 9 | --pstep 2 \ 10 | --fit_type structured \ 11 | --log_per_iter 100 \ 12 | --regular_data 1 \ 13 | --gen_data 1 \ 14 | -------------------------------------------------------------------------------- /shoot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cvxpy as cp 4 | from cvxpy import quad_form 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | from physics_engine import RopeEngine, SoftEngine, SwimEngine 11 | from data import load_data, normalize, denormalize 12 | from models.CompositionalKoopmanOperators import CompositionalKoopmanOperators, regularize_state_Soft 13 | from models.KoopmanBaselineModel import KoopmanBaseline 14 | from utils import to_var, to_np, Tee, norm, get_flat, print_args 15 | 16 | from progressbar import ProgressBar 17 | 18 | from config import gen_args 19 | from socket import gethostname 20 | 21 | args = gen_args() 22 | 23 | os.system("mkdir -p " + args.shootf) 24 | 25 | log_path = os.path.join(args.shootf, 'log.txt') 26 | tee = Tee(log_path, 'w') 27 | 28 | print_args(args) 29 | 30 | print(f"Load stored dataset statistics from {args.stat_path}!") 31 | stat = load_data(args.data_names, args.stat_path) 32 | 33 | data_names = ['attrs', 'states', 'actions'] 34 | prepared_names = ['attrs', 'states', 'actions', 'rel_attrs'] 35 | data_dir = os.path.join(args.dataf, args.shoot_set) 36 | 37 | if args.shoot_set == 'extra' and gethostname().startswith('netmit'): 38 | data_dir = args.dataf + '_' + args.shoot_set 39 | 40 | ''' 41 | model 42 | ''' 43 | # build model 44 | use_gpu = torch.cuda.is_available() 45 | if not args.baseline: 46 | """ Koopman model""" 47 | model = CompositionalKoopmanOperators(args, residual=False, use_gpu=use_gpu) 48 | 49 | # load pretrained 
checkpoint 50 | if args.shoot_epoch == -1: 51 | model_path = os.path.join(args.outf, 'net_best.pth') 52 | else: 53 | model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (args.shoot_epoch, args.shoot_iter)) 54 | 55 | print("Loading saved ckp from %s" % model_path) 56 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda:0' if use_gpu else 'cpu'))) 57 | model.eval() 58 | if use_gpu: model.cuda() 59 | 60 | else: 61 | """ Koopman Baselinese """ 62 | model = KoopmanBaseline(args) 63 | 64 | ''' 65 | shoot 66 | ''' 67 | 68 | if args.env == 'Rope': 69 | engine = RopeEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 70 | elif args.env == 'Soft': 71 | engine = SoftEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 72 | elif args.env == 'Swim': 73 | engine = SwimEngine(args.dt, args.state_dim, args.action_dim, args.param_dim) 74 | else: 75 | assert False 76 | 77 | 78 | def get_more_trajectories(roll_idx): 79 | group_idx = roll_idx // args.group_size 80 | offset = group_idx * args.group_size 81 | 82 | all_seq = [[], [], [], []] 83 | 84 | for i in range(1, args.fit_num + 1): 85 | new_idx = (roll_idx + i - offset) % args.group_size + offset 86 | seq_data = load_data(prepared_names, os.path.join(data_dir, str(new_idx) + '.rollout.h5')) 87 | for j in range(4): 88 | all_seq[j].append(seq_data[j]) 89 | 90 | all_seq = [np.array(all_seq[j], dtype=np.float32) for j in range(4)] 91 | return all_seq 92 | 93 | def mpc_qp(g_cur, g_goal, time_cur, T, rel_attrs, A_t, B_t, Q, R, node_attrs=None, 94 | actions=None, gt_info=None): 95 | """ 96 | Model Predictive Control + Quadratic Programming 97 | :param rel_attrs: N x N x relation_dim 98 | :param node_attrs: N x attributes_dim 99 | :return action sequence u: T - 1 x N x action_dim 100 | """ 101 | 102 | n_obj = engine.num_obj 103 | constraints = [] 104 | 105 | if not args.baseline: 106 | D = args.g_dim 107 | else: 108 | D = g_goal.shape[-1] 109 | 110 | if args.fit_type == 
'structured': 111 | dim_a = args.action_dim 112 | g = cp.Variable((T * n_obj, D)) 113 | u = cp.Variable(((T - 1) * n_obj, args.action_dim)) 114 | augG = cp.Variable(((T - 1) * n_obj, D * args.relation_dim)) 115 | augU = cp.Variable(((T - 1) * n_obj, args.action_dim * args.relation_dim)) 116 | 117 | for t in range(T - 1): 118 | st_idx = t * n_obj 119 | ed_idx = (t + 1) * n_obj 120 | for r in range(args.relation_dim): 121 | constraints.append(augG[st_idx:ed_idx, r * D: (r + 1) * D] == 122 | rel_attrs[:, :, r] @ g[st_idx:ed_idx]) 123 | for r in range(args.relation_dim): 124 | constraints.append(augU[st_idx:ed_idx, r * dim_a: (r + 1) * dim_a] == 125 | rel_attrs[:, :, r] @ u[st_idx:ed_idx]) 126 | 127 | cost = 0 128 | 129 | for idx in range(n_obj): 130 | # constrain the initial g 131 | constraints.append(g[idx] == g_cur[idx]) 132 | 133 | for t in range(1, T): 134 | cur_idx = t * n_obj + idx 135 | prv_idx = (t - 1) * n_obj + idx 136 | 137 | zero_normed = -stat[2][:, 0] / stat[2][:, 1] 138 | act_scale_max_normed = (args.act_scale - stat[2][:, 0]) / stat[2][:, 1] 139 | act_scale_min_normed = (- args.act_scale - stat[2][:, 0]) / stat[2][:, 1] 140 | constraints.append(u[prv_idx] >= act_scale_min_normed) 141 | constraints.append(u[prv_idx] <= act_scale_max_normed) 142 | 143 | if args.env == 'Rope': 144 | if idx == 0: 145 | # first mass: action_y = 0 (no action_y now) 146 | pass 147 | else: 148 | # other mass: action_x = action_y = 0 149 | constraints.append(u[prv_idx][:] == zero_normed) 150 | 151 | elif args.env in ['Soft', 'Swim']: 152 | if node_attrs[idx, 0] < 1e-6: 153 | # if there is no actuation 154 | constraints.append(u[prv_idx][:] == zero_normed) 155 | else: 156 | pass 157 | 158 | constraints.append(g[cur_idx] == A_t @ augG[prv_idx] + B_t @ augU[prv_idx]) 159 | # penalize large actions 160 | cost += quad_form(u[prv_idx] - zero_normed, R) 161 | cost += quad_form(g[(T - 1) * n_obj + idx] - g_goal[idx], Q) 162 | 163 | elif args.fit_type == 'unstructured': 164 | 165 | 
zero_normed = -stat[2][:, 0] / stat[2][:, 1] 166 | 167 | g = cp.Variable((T, n_obj * args.g_dim)) 168 | u = cp.Variable((T - 1, n_obj * args.action_dim)) 169 | 170 | cost = 0 171 | 172 | constraints.append(g[0] == g_cur.ravel()) 173 | 174 | for t in range(1, T): 175 | 176 | act_scale_normed = (args.act_scale - stat[2][:, 0]) / stat[2][:, 1] 177 | act_scale_normed = np.repeat(act_scale_normed, n_obj, 0) 178 | constraints.append(u[t - 1] >= - act_scale_normed) 179 | constraints.append(u[t - 1] <= act_scale_normed) 180 | 181 | if args.env == 'Rope': 182 | # set action on balls to zeros expect the first one 183 | for idx in range(1, n_obj): 184 | constraints.append(u[t - 1][idx] == zero_normed) 185 | 186 | elif args.env in ['Soft', 'Swim']: 187 | for idx in range(0, n_obj): 188 | if node_attrs[idx, 0] < 1e-6: 189 | constraints.append(u[t - 1][idx * args.action_dim: (idx + 1) * args.action_dim] == zero_normed) 190 | 191 | constraints.append(g[t] == A_t @ g[t - 1] + B_t @ u[t - 1]) 192 | 193 | for i in range(n_obj): 194 | cost += quad_form(u[t - 1][i * args.action_dim:(i + 1) * args.action_dim] - zero_normed, R) 195 | 196 | for i in range(n_obj): 197 | cost += quad_form(g[T - 1][i * args.g_dim:(i + 1) * args.g_dim] - g_goal[i], Q) 198 | 199 | elif args.fit_type == 'diagonal': 200 | 201 | zero_normed = -stat[2][:, 0] / stat[2][:, 1] 202 | 203 | g = cp.Variable((T, n_obj * args.g_dim)) 204 | u = cp.Variable((T - 1, n_obj * args.action_dim)) 205 | 206 | cost = 0 207 | constraints.append(g[0] == g_cur.ravel()) 208 | 209 | for t in range(1, T): 210 | act_scale_normed = (args.act_scale - stat[2][:, 0]) / stat[2][:, 1] 211 | act_scale_normed = np.repeat(act_scale_normed, n_obj, 0) 212 | constraints.append(u[t - 1] >= - act_scale_normed) 213 | constraints.append(u[t - 1] <= act_scale_normed) 214 | if args.env == 'Rope': 215 | # set action on balls to zeros expect the first one 216 | for idx in range(1, n_obj): 217 | constraints.append(u[t - 1][idx] == zero_normed) 218 | elif 
args.env in ['Soft', 'Swim']: 219 | for idx in range(0, n_obj): 220 | if node_attrs[idx, 0] < 1e-6: 221 | constraints.append(u[t - 1][idx * args.action_dim: (idx + 1) * args.action_dim] == zero_normed) 222 | 223 | for i in range(n_obj): 224 | t1 = A_t @ g[t - 1][i * args.g_dim:(i + 1) * args.g_dim] 225 | t2 = B_t @ u[t - 1][i * args.action_dim:(i + 1) * args.action_dim] 226 | if args.env == 'Rope': 227 | t2 = t2[:, 0] 228 | constraints.append(g[t][i * args.g_dim:(i + 1) * args.g_dim] == t1 + t2) 229 | cost += quad_form(u[t - 1][i * args.action_dim:(i + 1) * args.action_dim] - zero_normed, R) 230 | for i in range(n_obj): 231 | cost += quad_form(g[T - 1][i * args.g_dim:(i + 1) * args.g_dim] - g_goal[i], Q) 232 | 233 | objective = cp.Minimize(cost) 234 | prob = cp.Problem(objective, constraints) 235 | result = prob.solve() 236 | 237 | u_val = u.value 238 | g_val = g.value 239 | u = u_val.reshape(T - 1, n_obj, args.action_dim) 240 | 241 | u = denormalize([u], [stat[2]])[0] 242 | g = g_val.reshape(T, n_obj, D) 243 | 244 | return u 245 | 246 | 247 | def shoot_mpc_qp(roll_idx): 248 | print(f'\n=== Model Based Control on Example {roll_idx} ===') 249 | 250 | ''' 251 | load data 252 | ''' 253 | seq_data = load_data(prepared_names, os.path.join(data_dir, str(roll_idx) + '.rollout.h5')) 254 | attrs, states, actions, rel_attrs = [to_var(d.copy(), use_gpu=use_gpu) for d in seq_data] 255 | 256 | seq_data = denormalize(seq_data, stat) 257 | attrs_gt, states_gt, actions_gt = seq_data[:3] 258 | 259 | ''' 260 | setup engine 261 | ''' 262 | param_file = os.path.join(data_dir, str(roll_idx // args.group_size) + '.param') 263 | param = torch.load(param_file) 264 | engine.init(param) 265 | n_obj = engine.num_obj 266 | 267 | ''' 268 | fit koopman 269 | ''' 270 | print('===> system identification!') 271 | fit_data = get_more_trajectories(roll_idx) 272 | fit_data = [to_var(d, use_gpu=use_gpu) for d in fit_data] 273 | bs = args.fit_num 274 | 275 | attrs_flat = get_flat(fit_data[0]) 276 | 
states_flat = get_flat(fit_data[1]) 277 | actions_flat = get_flat(fit_data[2]) 278 | rel_attrs_flat = get_flat(fit_data[3]) 279 | 280 | g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep) 281 | g = g.view(torch.Size([bs, args.time_step]) + g.size()[1:]) 282 | 283 | G_tilde = g[:, :-1] 284 | H_tilde = g[:, 1:] 285 | U_left = fit_data[2][:, :-1] 286 | 287 | G_tilde = get_flat(G_tilde, keep_dim=True) 288 | H_tilde = get_flat(H_tilde, keep_dim=True) 289 | U_left = get_flat(U_left, keep_dim=True) 290 | 291 | A, B, fit_err = model.system_identify(G=G_tilde, H=H_tilde, U=U_left, 292 | rel_attrs=fit_data[3][:1, 0], I_factor=args.I_factor) 293 | 294 | ''' 295 | shooting 296 | ''' 297 | print('===> model based control start!') 298 | # current can not set engine to a middle state 299 | assert args.roll_start == 0 300 | 301 | start_step = args.roll_start 302 | g_start_v = model.to_g(attrs=attrs[start_step:start_step + 1], states=states[start_step:start_step + 1], 303 | rel_attrs=rel_attrs[start_step:start_step + 1], pstep=args.pstep) 304 | g_start = to_np(g_start_v[0]) 305 | 306 | if args.env == 'Rope': 307 | goal_step = args.roll_step + args.roll_start 308 | elif args.env == 'Soft': 309 | goal_step = args.roll_step + args.roll_start 310 | elif args.env == 'Swim': 311 | goal_step = args.roll_step + args.roll_start 312 | 313 | g_goal_v = model.to_g(attrs=attrs[goal_step:goal_step + 1], states=states[goal_step:goal_step + 1], 314 | rel_attrs=rel_attrs[goal_step:goal_step + 1], pstep=args.pstep) 315 | g_goal = to_np(g_goal_v[0]) 316 | 317 | states_start = states_gt[start_step] 318 | states_goal = states_gt[goal_step] 319 | states_roll = np.zeros((args.roll_step + 1, n_obj, args.state_dim)) 320 | states_roll[0] = states_start 321 | 322 | control = np.zeros((args.roll_step + 1, n_obj, args.action_dim)) 323 | # control_v = to_var(control, use_gpu, requires_grad=True) 324 | bar = ProgressBar() 325 | for step in bar(range(args.roll_step)): 326 | states_input = 
normalize([states_roll[step:step + 1]], [stat[1]])[0] 327 | states_input_v = to_var(states_input, use_gpu=use_gpu) 328 | g_cur_v = model.to_g(attrs=attrs[:1], states=states_input_v, 329 | rel_attrs=rel_attrs[:1], pstep=args.pstep) 330 | g_cur = to_np(g_cur_v[0]) 331 | 332 | ''' 333 | setup parameters 334 | ''' 335 | T = args.roll_step - step + 1 336 | 337 | A_v, B_v = model.A, model.B 338 | A_t = to_np(A_v[0]).T 339 | B_t = to_np(B_v[0]).T 340 | 341 | if not args.baseline: 342 | Q = np.eye(args.g_dim) 343 | else: 344 | Q = np.eye(g_goal.shape[-1]) 345 | 346 | if args.env == 'Rope': 347 | R_factor = 0.01 348 | elif args.env == 'Soft': 349 | R_factor = 0.001 350 | elif args.env == 'Swim': 351 | R_factor = 0.0001 352 | else: 353 | assert False 354 | 355 | R = np.eye(args.action_dim) * R_factor 356 | 357 | ''' 358 | generate action 359 | ''' 360 | rel_attrs_np = to_np(rel_attrs)[0] 361 | assert args.optim_type == 'qp' 362 | if step % args.feedback == 0: 363 | node_attrs = attrs_gt[0] if args.env in ['Soft', 'Swim'] else None 364 | u = mpc_qp(g_cur, g_goal, step, T, rel_attrs_np, A_t, B_t, Q, R, node_attrs=node_attrs, 365 | actions=to_np(actions[step:]), 366 | gt_info=[param, states_gt[goal_step:goal_step + 1], attrs[step:step + T], 367 | rel_attrs[step:step + T]]) 368 | else: 369 | u = u[1:] 370 | pass 371 | 372 | ''' 373 | execute action 374 | ''' 375 | engine.set_action(u[0]) # execute the first action 376 | control[step] = engine.get_action() 377 | engine.step() 378 | states_roll[step + 1] = engine.get_state() 379 | 380 | ''' 381 | render 382 | ''' 383 | engine.render(states_roll, control, param, act_scale=args.act_scale, video=True, image=True, 384 | path=os.path.join(args.shootf, str(roll_idx) + '.shoot'), 385 | states_gt=np.tile(states_gt[goal_step:goal_step + 1], (args.roll_step + 1, 1, 1)), 386 | count_down=True, gt_border=True) 387 | 388 | states_result = states_roll[args.roll_step] 389 | 390 | states_goal_normalized = normalize([states_goal], [stat[1]])[0] 
391 | states_result_normalized = normalize([states_result], [stat[1]])[0] 392 | 393 | return norm(states_goal - states_result), (states_goal, states_result, states_goal_normalized, states_result_normalized) 394 | 395 | 396 | if __name__ == '__main__': 397 | os.system('mkdir -p ' + args.shootf) 398 | num_train = int(args.n_rollout * args.train_valid_ratio) 399 | num_valid = args.n_rollout - num_train 400 | ls_rollout_idx = np.arange(0, num_valid, num_valid // args.group_size // 5) 401 | 402 | if args.demo: 403 | ls_rollout_idx = np.arange(8) * 25 404 | 405 | for roll_idx in ls_rollout_idx: 406 | shoot_mpc_qp(roll_idx) 407 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from progressbar import ProgressBar 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | from config import gen_args 14 | from data import PhysicsDataset 15 | from data import load_data 16 | from models.CompositionalKoopmanOperators import CompositionalKoopmanOperators 17 | from utils import count_parameters, Tee, AverageMeter, rand_int, mix_iters, get_flat, print_args 18 | 19 | args = gen_args() 20 | 21 | os.system('mkdir -p ' + args.outf) 22 | os.system('mkdir -p ' + args.dataf) 23 | tee = Tee(os.path.join(args.outf, 'train.log'), 'w') 24 | print_args(args) 25 | 26 | # generate data 27 | datasets = {phase: PhysicsDataset(args, phase) for phase in ['train', 'valid']} 28 | for phase in ['train', 'valid']: 29 | if args.gen_data: 30 | datasets[phase].gen_data() 31 | else: 32 | datasets[phase].load_data() 33 | 34 | if args.gen_data: 35 | print("Preprocessing data ...") 36 | os.system('python preprocess_data.py --env ' + args.env) 37 | 38 | 
args.stat = datasets['train'].stat


def _load_sample(mother, prepared_names, idx_rollout, idx_timestep):
    """
    Load one sample shared by both dataset variants.

    :param mother: parent dataset providing ``data_dir``
    :param prepared_names: field names stored in each ``.rollout.h5``
    :param idx_rollout: global rollout index
    :param idx_timestep: first time step of the window
    :return: (seq_data, fit_data).  seq_data is the ``args.len_seq + 1`` step
             window of rollout ``idx_rollout``; fit_data is a full rollout
             drawn via ``rand_int`` from the same group (used to fit A / B).
    """
    seq_data = load_data(prepared_names, os.path.join(mother.data_dir, str(idx_rollout) + '.rollout.h5'))
    seq_data = [d[idx_timestep:idx_timestep + args.len_seq + 1] for d in seq_data]

    # prepare fit data: random offset inside the group + the group's first global index
    fit_idx = rand_int(0, args.group_size - 1)
    fit_idx = fit_idx + idx_rollout // args.group_size * args.group_size
    fit_data = load_data(prepared_names, os.path.join(mother.data_dir, str(fit_idx) + '.rollout.h5'))

    return seq_data, fit_data


class ShuffledDataset(Dataset):
    """
    Shard ``idx`` (of ``args.n_splits``) of ``mother_dataset``, served in a
    pre-shuffled order where every consecutive run of ``batch_size`` samples
    comes from a single group of rollouts.  Training relies on that, because
    rel_attrs must be the same within a batch.  Use with ``shuffle=False``.
    """

    def __init__(self,
                 mother_dataset,
                 idx,
                 batch_size):
        self.samples_per_rollout = args.time_step - args.len_seq
        self.mother = mother_dataset
        self.n_rollout = mother_dataset.n_rollout // args.n_splits
        self.idx = idx
        self.prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']
        self.batch_size = batch_size

        self.build_table()

    def __len__(self):
        # The table is padded to whole batches; report its real length so the
        # padded samples and the tail batches are not cut off by DataLoader.
        return len(self.sample_table)

    def build_table(self):
        """Build ``self.sample_table``: (local rollout, timestep) pairs in batch order."""
        assert self.n_rollout % args.group_size == 0
        bs = self.batch_size
        num_groups = self.n_rollout // args.group_size

        sample_list = [[] for _ in range(num_groups)]
        for i in range(self.n_rollout):
            for j in range(self.samples_per_rollout):
                sample_list[i // args.group_size].append((i, j))

        '''shuffle samples within each group'''
        for samples in sample_list:
            random.shuffle(samples)

        '''pad each group (cycling its own samples) to a multiple of the batch size'''
        for samples in sample_list:
            deficit = -len(samples) % bs
            if deficit and samples:
                reps = deficit // len(samples) + 1
                samples += (samples * reps)[:deficit]

        '''create single-group batches'''
        batch_list = []
        for samples in sample_list:
            for j in range(len(samples) // bs):
                batch_list.append(samples[j * bs:j * bs + bs])

        '''shuffle the batch order and flatten into one sample table'''
        random.shuffle(batch_list)
        self.sample_table = [sample for batch in batch_list for sample in batch]

    def __getitem__(self, idx):
        idx_rollout, idx_timestep = self.sample_table[idx]
        idx_rollout += self.n_rollout * self.idx
        return _load_sample(self.mother, self.prepared_names, idx_rollout, idx_timestep)


class SubPreparedDataset(Dataset):
    """
    Shard ``idx`` (of ``args.n_splits``) of ``mother_dataset`` in natural
    order: rollout-major, then time step.
    """

    def __init__(self,
                 mother_dataset,
                 idx, ):
        self.samples_per_rollout = args.time_step - args.len_seq
        self.mother = mother_dataset
        self.n_rollout = mother_dataset.n_rollout // args.n_splits
        self.idx = idx
        self.prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']

    def __len__(self):
        return self.n_rollout * self.samples_per_rollout

    def __getitem__(self, idx):
        idx_rollout = idx // self.samples_per_rollout + self.n_rollout * self.idx
        idx_timestep = idx % self.samples_per_rollout
        return _load_sample(self.mother, self.prepared_names, idx_rollout, idx_timestep)


def split_dataset(ds):
    """
    Split ``ds`` into ``args.n_splits`` ShuffledDataset shards.

    :raises AssertionError: if ``ds.n_rollout`` is not a multiple of both
                            ``args.group_size`` and ``args.n_splits``
    """
    assert ds.n_rollout % args.group_size == 0
    assert ds.n_rollout % args.n_splits == 0
    sub_datasets = [ShuffledDataset(mother_dataset=ds, idx=i, batch_size=args.batch_size) for i in range(args.n_splits)]
    return sub_datasets


use_gpu = torch.cuda.is_available()

# Samples contain a varying number of objects, so `split_dataset` breaks each
# phase into several datasets (one DataLoader each) and `mix_iters` interleaves
# their batches in random order.
dataloaders = {}
data_n_batches = {}
loaders = {}
for phase in ['train', 'valid']:
    loaders[phase] = [DataLoader(
        dataset=dataset, batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers, )
        for dataset in split_dataset(datasets[phase])]

    # Bind this phase's loader list as a default argument. A bare closure over
    # `phase` is late-bound: it would resolve to whatever `phase` holds when the
    # lambda is *called* (the last value of this loop, or the training loop's
    # variable), not the phase it was created for.
    dataloaders[phase] = lambda phase_loaders=loaders[phase]: mix_iters(
        iters=[iter(loader) for loader in phase_loaders])

    data_n_batches[phase] = sum(len(loader) for loader in loaders[phase])

# Compositional Koopman Operator
model = CompositionalKoopmanOperators(args, residual=False, use_gpu=use_gpu)

# print model #params
print("model #params: %d" % count_parameters(model))

# if resume from a pretrained checkpoint
if args.resume_epoch >= 0:
    model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (args.resume_epoch, args.resume_iter))
    print("Loading saved ckp from %s" % model_path)
    # Load tensors onto the CPU first: the model is moved to the GPU below when
    # one is available, and this lets a checkpoint saved on a GPU machine be
    # resumed on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

# criterion (not used by the losses below, which call F.l1_loss directly)
criterionMSE = nn.MSELoss()

# optimizer
params = model.parameters()
optimizer = optim.Adam(params, lr=args.lr, betas=(args.beta1, 0.999))
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.6, patience=2, verbose=True)

if use_gpu:
    model = model.cuda()
    criterionMSE = criterionMSE.cuda()

st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
best_valid_loss = np.inf

# `with` guarantees the log file is closed even if training raises.
with open(os.path.join(args.outf, 'log_st_epoch_%d.txt' % st_epoch), 'w') as log_fout:
    for epoch in range(st_epoch, args.n_epoch):

        phases = ['train', 'valid'] if args.eval == 0 else ['valid']

        for phase in phases:
            model.train(phase == 'train')
            meter_loss = AverageMeter()
            meter_loss_metric = AverageMeter()
            meter_loss_ae = AverageMeter()
            meter_loss_pred = AverageMeter()
            meter_fit_error = AverageMeter()
            meter_dist_g = AverageMeter()
            meter_dist_s = AverageMeter()

            bar = ProgressBar(max_value=data_n_batches[phase])

            loader = dataloaders[phase]()

            for i, (seq_data, fit_data) in bar(enumerate(loader)):

                # seq_data drives the rollout and the losses; fit_data is only used
                # to identify the linear system (A, B) for this batch.
                # Expected shapes:
                #   attrs:     bs x len_seq x num_obj x attr_dim
                #   states:    bs x len_seq x num_obj x state_dim
                #   actions:   bs x len_seq x num_obj x action_dim
                #   rel_attrs: bs x len_seq x num_obj x num_obj x rel_dim
                attrs, states, actions, rel_attrs = seq_data
                attrs_2, states_2, actions_2, rel_attrs_2 = fit_data

                if use_gpu:
                    attrs_2, states_2, actions_2, rel_attrs_2 = [x.cuda() for x in fit_data]
                    fit_data = [attrs_2, states_2, actions_2, rel_attrs_2]

                with torch.set_grad_enabled(phase == 'train'):
                    if use_gpu:
                        attrs, states, actions, rel_attrs = [x.cuda() for x in seq_data]
                    data = [attrs, states, actions, rel_attrs]

                    T = args.len_seq
                    bs = len(attrs)

                    # ---- encode the fit data (flattened over batch and time) ----
                    attrs_flat = get_flat(attrs_2)
                    states_flat = get_flat(states_2)
                    rel_attrs_flat = get_flat(rel_attrs_2)

                    g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)
                    g = g.view(torch.Size([bs, args.time_step]) + g.size()[1:])

                    # ---- fit A, B on consecutive (g_t -> g_{t+1}, u_t) pairs ----
                    # NOTE: a single system is identified using the relations of the
                    # first sample (rel_attrs[:1, 0]), so rel_attrs must be identical
                    # within one group.
                    G_tilde = g[:, :-1]
                    H_tilde = g[:, 1:]
                    U_left = actions_2[:, :-1]

                    G_tilde = get_flat(G_tilde, keep_dim=True)
                    H_tilde = get_flat(H_tilde, keep_dim=True)
                    U_left = get_flat(U_left, keep_dim=True)

                    A, B, fit_err = model.system_identify(G=G_tilde, H=H_tilde, U=U_left,
                                                          rel_attrs=rel_attrs[:1, 0], I_factor=args.I_factor)

                    # Tile the identified system once per batch element.
                    model.A = model.A.repeat(bs, 1, 1)
                    model.B = model.B.repeat(bs, 1, 1)

                    meter_fit_error.update(fit_err.item(), bs)

                    # ---- encode the sequential data ----
                    attrs_flat = get_flat(attrs)
                    states_flat = get_flat(states)
                    rel_attrs_flat = get_flat(rel_attrs)

                    g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)

                    # Metric loss: over random disjoint pairs of frames, the (scaled)
                    # squared distance in g-space should match that in state space.
                    # Both halves get exactly n_frames // 2 frames so the elementwise
                    # difference is well-defined even when bs * (T + 1) is odd.
                    n_frames = bs * (T + 1)
                    n_pairs = n_frames // 2
                    permu = np.random.permutation(n_frames)
                    split_0 = permu[:n_pairs]
                    split_1 = permu[n_pairs:2 * n_pairs]

                    dist_g = torch.mean((g[split_0] - g[split_1]) ** 2, dim=(1, 2))
                    dist_s = torch.mean((states_flat[split_0] - states_flat[split_1]) ** 2, dim=(1, 2))
                    scaling_factor = 10
                    loss_metric = torch.abs(dist_g * scaling_factor - dist_s).mean()

                    g = g.view(torch.Size([bs, T + 1]) + g.size()[1:])

                    # ---- roll out T steps from g_0 (compared against frames 1..T) ----
                    U_for_pred = actions[:, :T]
                    G_for_pred = model.simulate(T=T, g=g[:, 0], u_seq=U_for_pred, rel_attrs=rel_attrs[:, 0])

                    # auto-encode frames 0..T; prediction targets are frames 1..T
                    data_for_ae = [x[:, :T + 1] for x in data]
                    data_for_pred = [x[:, 1:T + 1] for x in data]

                    # decode states for auto-encoding (flattened over batch and time)
                    attrs_for_ae_flat = get_flat(data_for_ae[0])
                    rel_attrs_for_ae_flat = get_flat(data_for_ae[3])
                    decode_s_for_ae = model.to_s(attrs=attrs_for_ae_flat, gcodes=get_flat(g[:, :T + 1]),
                                                 rel_attrs=rel_attrs_for_ae_flat, pstep=args.pstep)

                    # decode states for prediction (flattened over batch and time)
                    attrs_for_pred_flat = get_flat(data_for_pred[0])
                    rel_attrs_for_pred_flat = get_flat(data_for_pred[3])
                    decode_s_for_pred = model.to_s(attrs=attrs_for_pred_flat, gcodes=get_flat(G_for_pred),
                                                   rel_attrs=rel_attrs_for_pred_flat, pstep=args.pstep)

                    loss_auto_encode = F.l1_loss(
                        decode_s_for_ae, states[:, :T + 1].reshape(decode_s_for_ae.shape))
                    loss_prediction = F.l1_loss(
                        decode_s_for_pred, states[:, 1:T + 1].reshape(decode_s_for_pred.shape))

                    loss = loss_auto_encode + loss_prediction + loss_metric * args.lambda_loss_metric

                    meter_loss_metric.update(loss_metric.item(), bs)
                    meter_loss_ae.update(loss_auto_encode.item(), bs)
                    meter_loss_pred.update(loss_prediction.item(), bs)

                    meter_dist_g.update(dist_g.mean().item(), bs)
                    meter_dist_s.update(dist_s.mean().item(), bs)

                    # total loss
                    meter_loss.update(loss.item(), bs)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                    optimizer.step()

                if i % args.log_per_iter == 0:
                    log = '%s [%d/%d][%d/%d] Loss: %.6f (%.6f), sysid_error: %.6f (%.6f), loss_ae: %.6f (%.6f), loss_pred: %.6f (%.6f), ' \
                          'loss_metric: %.6f (%.6f)' % (
                              phase, epoch, args.n_epoch, i, data_n_batches[phase],
                              loss.item(), meter_loss.avg,
                              fit_err.item(), meter_fit_error.avg,
                              loss_auto_encode.item(), meter_loss_ae.avg,
                              loss_prediction.item(), meter_loss_pred.avg,
                              loss_metric.item(), meter_loss_metric.avg,
                          )

                    print()
                    print(log)
                    log_fout.write(log + '\n')
                    log_fout.flush()

                if phase == 'train' and i % args.ckp_per_iter == 0:
                    torch.save(model.state_dict(), '%s/net_epoch_%d_iter_%d.pth' % (args.outf, epoch, i))

            log = '%s [%d/%d] Loss: %.4f, Best valid: %.4f' % (phase, epoch, args.n_epoch,
                                                                meter_loss.avg, best_valid_loss)
            print(log)
            log_fout.write(log + '\n')
            log_fout.flush()

            if phase == 'valid' and not args.eval:
                scheduler.step(meter_loss.avg)
                if meter_loss.avg < best_valid_loss:
                    best_valid_loss = meter_loss.avg
                    torch.save(model.state_dict(), '%s/net_best.pth' % (args.outf))

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import sys

import h5py
import numpy as np
import torch
from torch.autograd import Variable


def print_args(args):
    """Print every attribute of an argparse-style namespace, one per line."""
    print("===== Experiment Configuration =====")
    options = vars(args)
    for key, value in options.items():
        print(f'{key}: {value}')
    print("====================================")


def rand_float(lo, hi):
    """Return a uniform random float in [lo, hi) from NumPy's global RNG."""
    return np.random.rand() * (hi - lo) + lo


def rand_int(lo, hi):
    """Return a random integer in [lo, hi) from NumPy's global RNG."""
    return np.random.randint(lo, hi)


def calc_dis(a, b):
    """Return the Euclidean distance between the first two coordinates of a and b."""
    return np.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)


def norm(x, p=2):
    """Return the L-p norm of x: (sum_i |x_i| ** p) ** (1 / p).

    Absolute values are taken so odd orders (e.g. p=1) are correct for negative
    entries; for the default p=2 the result is unchanged. Python sequences are
    accepted as well as NumPy arrays.
    """
    return np.power(np.sum(np.abs(np.asarray(x)) ** p), 1. / p)


def store_data(data_names, data, path):
    """Write data[i] to an HDF5 dataset named data_names[i] in a new file at path.

    The file is opened in 'w' mode (overwriting any existing file) and is closed
    even if writing a dataset fails.
    """
    with h5py.File(path, 'w') as hf:
        for i, name in enumerate(data_names):
            hf.create_dataset(name, data=data[i])


def load_data(data_names, path):
    """Read the named HDF5 datasets from path and return them as NumPy arrays.

    Returns a list in the same order as data_names. A missing name yields
    np.array(None), since h5py's get() returns None for absent keys.
    """
    with h5py.File(path, 'r') as hf:
        return [np.array(hf.get(name)) for name in data_names]


def combine_stat(stat_0, stat_1):
    """Merge two per-dimension (mean, std, count) tables into one.

    Each input is a 2-D array whose columns 0, 1, 2 are mean, std and count (the
    layout created by init_stat). The merged mean is the count-weighted mean and
    the merged std follows the parallel-variance formula. Dimensions whose
    combined count is zero (e.g. two fresh init_stat tables) come out as zeros
    instead of NaN.
    """
    mean_0, std_0, n_0 = stat_0[:, 0], stat_0[:, 1], stat_0[:, 2]
    mean_1, std_1, n_1 = stat_1[:, 0], stat_1[:, 1], stat_1[:, 2]

    n = n_0 + n_1
    # Avoid 0 / 0 for empty dimensions; their numerators are 0 as well.
    safe_n = np.where(n > 0, n, 1)

    mean = (mean_0 * n_0 + mean_1 * n_1) / safe_n
    std = np.sqrt(
        (std_0 ** 2 * n_0 + std_1 ** 2 * n_1 + (mean_0 - mean) ** 2 * n_0 + (mean_1 - mean) ** 2 * n_1) / safe_n)

    return np.stack([mean, std, n], axis=-1)


def init_stat(dim):
    """Return an empty (dim, 3) stats table with columns mean, std, count."""
    return np.zeros((dim, 3))


def var_norm(x):
    """Return the square root of the sum of squares of all entries of x, as a float."""
    return torch.sqrt((x ** 2).sum()).item()


def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_flat(x, keep_dim=False):
    """Merge the first two dimensions of x into one.

    With keep_dim=True a leading dimension of size 1 is prepended to the result.
    """
    if keep_dim:
        return x.reshape(torch.Size([1, x.size(0) * x.size(1)]) + x.size()[2:])
    return x.reshape(torch.Size([x.size(0) * x.size(1)]) + x.size()[2:])


def to_var(tensor, use_gpu, requires_grad=False):
    """Wrap array-like data as a float32 tensor, moved to the GPU when use_gpu is set.

    `Variable` is deprecated in modern PyTorch and simply returns a tensor.
    """
    if use_gpu:
        return Variable(torch.FloatTensor(tensor).cuda(), requires_grad=requires_grad)
    else:
        return Variable(torch.FloatTensor(tensor), requires_grad=requires_grad)


def to_np(x):
    """Detach tensor x, copy it to the CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()


def mix_iters(iters):
    """Yield items from several sized iterators in a random interleaved order.

    Each iterator is advanced exactly len(it) times, so every item is produced
    once. Uses the builtin next() because DataLoader iterators in recent
    PyTorch versions no longer provide a Python-2-style .next() method.
    """
    table = []
    for idx, it in enumerate(iters):
        table += [idx] * len(it)
    np.random.shuffle(table)
    for idx in table:
        yield next(iters[idx])


class Tee(object):
    """Duplicate everything written to sys.stdout into a file.

    Construction replaces sys.stdout with this object; close() (or garbage
    collection) restores the previous stdout and closes the file. Closing is
    idempotent, so an explicit close() followed by __del__ does not reassign
    sys.stdout a second time.
    """

    def __init__(self, name, mode):
        self.file = open(name, mode)
        self.stdout = sys.stdout
        sys.stdout = self
        self._closed = False

    def __del__(self):
        self.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)

    def flush(self):
        # Flush both sinks so buffered console output is not lost either.
        self.file.flush()
        self.stdout.flush()

    def close(self):
        # The getattr default covers __del__ running on a partially constructed
        # instance (e.g. open() raised in __init__), where nothing was redirected.
        if getattr(self, '_closed', True):
            return
        self._closed = True
        sys.stdout = self.stdout
        self.file.close()


class AverageMeter(object):
    """Track the latest value and a running count-weighted average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record val as the mean over n samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# --------------------------------------------------------------------------------