├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── control_stats.py ├── data ├── sample_planar.py └── sample_pole.py ├── datasets.py ├── ilqr.py ├── ilqr_config ├── balance.json ├── cartpole.json ├── plane.json ├── swing.json └── threepole.json ├── ilqr_utils.py ├── latent_map_pendulum.py ├── latent_map_planar.py ├── losses.py ├── map_scale_stat.py ├── mdp ├── cartpole_mdp.py ├── common.py ├── pendulum_mdp.py ├── plane_obstacles_mdp.py ├── pole_base.py └── three_pole_mdp.py ├── networks.py ├── pc3.yml ├── pc3_model.py ├── retrain_dynamics.py ├── sample_results ├── balance.gif ├── cartpole.gif ├── pc3_5.png ├── pc3_6.png ├── pc3_7.png ├── pc3_model.png ├── pc3_planar_1.png ├── pc3_planar_2.png ├── pc3_planar_5.png ├── pcc_pendulum_1.png ├── pcc_pendulum_2.png ├── pcc_pendulum_4.png ├── pcc_planar_1.png ├── pcc_planar_10.png ├── pcc_planar_2.png ├── planar.gif ├── swing.gif ├── table_maps.png ├── table_result.png └── threepole.gif └── train_pc3.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__ 3 | .vscode 4 | iwae_result/ 5 | *.pyc 6 | result 7 | logs 8 | planar 9 | pendulum 10 | cartpole 11 | threepole 12 | reacher 13 | .ipynb_checkpoints/ 14 | run.sh 15 | iLQR_result/ 16 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 VinAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Predictive Coding for Locally-Linear Control 2 | 3 | This is a pytorch implementation of the paper "[Predictive Coding for Locally-Linear Control](http://proceedings.mlr.press/v119/shu20a.html)". We propose PC3 - an information-theoretic representation learning framework for optimal control from high-dimensional observations. Experiments show that our proposed method outperforms the existing reconstruction-based approaches significantly. 
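At a high level, training appears to combine a contrastive predictive (NCE) term with consistency and curvature regularizers and a latent-norm term, judging by the `--lam_nce`, `--lam_c`, `--lam_cur` and `--norm_coeff` flags used in the Training section below and the terms implemented in `losses.py`. A minimal sketch of how such a weighted objective is assembled (illustrative names only, not the exact code in `train_pc3.py`; the default weights mirror the example command below):

```python
# Illustrative sketch only: how the weighted PC3 objective might be assembled
# from the individual loss terms (cf. losses.py for the NCE and curvature terms).
def total_loss(nce, consistency, curvature, z_norm,
               lam_nce=1.0, lam_c=1.0, lam_cur=7.0, norm_coeff=0.1):
    # Weights correspond to the --lam_nce/--lam_c/--lam_cur/--norm_coeff flags.
    return lam_nce * nce + lam_c * consistency + lam_cur * curvature + norm_coeff * z_norm
```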
4 | 5 | ![pc3 model](sample_results/pc3_model.png) 6 | 7 | Details of the model architecture and experimental results can be found in our [following paper](http://proceedings.mlr.press/v119/shu20a.html): 8 | ``` 9 | @InProceedings{pmlr-v119-shu20a, 10 | title = {Predictive Coding for Locally-Linear Control}, 11 | author = {Shu, Rui and Nguyen, Tung and Chow, Yinlam and Pham, Tuan and Than, Khoat and Ghavamzadeh, Mohammad and Ermon, Stefano and Bui, Hung}, 12 | booktitle = {Proceedings of the 37th International Conference on Machine Learning}, 13 | year = {2020}, 14 | volume = {119}, 15 | series = {Proceedings of Machine Learning Research}, 16 | publisher = {PMLR}, 17 | } 18 | ``` 19 | **Please CITE** our paper whenever this repository is used to help produce published results or incorporated into other software. 20 | 21 | ## Installing 22 | 23 | First, clone the repository: 24 | 25 | ``` 26 | git clone https://github.com/VinAIResearch/PC3-pytorch.git 27 | ``` 28 | 29 | Then install the dependencies as listed in `pc3.yml` and activate the environment: 30 | 31 | ``` 32 | conda env create -f pc3.yml 33 | conda activate pc3 34 | ``` 35 | 36 | ## Training 37 | 38 | The code currently supports training for the `planar`, `pendulum`, `cartpole` and `3-link` environments. Run `train_pc3.py` with your own settings. For example: 39 | 40 | ``` 41 | python train_pc3.py \ 42 | --env=planar \ 43 | --armotized=False \ 44 | --log_dir=planar_1 \ 45 | --seed=1 \ 46 | --data_size=5000 \ 47 | --noise=0 \ 48 | --batch_size=256 \ 49 | --latent_noise=0.1 \ 50 | --lam_nce=1.0 \ 51 | --lam_c=1.0 \ 52 | --lam_cur=7.0 \ 53 | --norm_coeff=0.1 \ 54 | --lr=0.0005 \ 55 | --decay=0.001 \ 56 | --num_iter=2000 \ 57 | --iter_save=1000 \ 58 | --save_map=False 59 | ``` 60 | 61 | First, data is sampled according to the given data size and noise level; then the PC3 model is trained with the specified settings. 62 | 63 | If the argument `save_map` is set to True, the latent map will be drawn every 10 epochs (for planar only), and the resulting gif file will be saved in the same directory as the trained model. 64 | 65 | You can also visualize the training process by running ``tensorboard --logdir={path_to_log_dir}``, where ``path_to_log_dir`` has the form ``logs/{env}/{log_dir}``. The trained model will be saved at ``result/{env}/{log_dir}``. 66 | 67 | ### Latent maps visualization 68 | 69 | You can visualize the latent maps for planar and pendulum. To do that, simply run: 70 | 71 | ``` 72 | python latent_map_planar.py --log_path={path_to_trained_model} --epoch={epoch} 73 | or 74 | python latent_map_pendulum.py --log_path={path_to_trained_model} --epoch={epoch} 75 | ``` 76 | 77 | ## Data visualization 78 | 79 | You can generate training images for visualization purposes by simply running: 80 | 81 | ``` 82 | cd data 83 | python sample_planar.py --sample_size={sample_size} --noise={noise} 84 | ``` 85 | 86 | Currently the code supports simulating 4 environments: `planar` (via `sample_planar.py`), and `pendulum`, `cartpole`, `threepole` (3-link) via `sample_pole.py --env={env_name}`. 87 | 88 | The raw data (images) is saved in data/{env_name}/raw\_{sample_size}\_{noise} 89 | 90 | ## Running iLQR on latent space 91 | 92 | The configuration file for running iLQR for each task is in the ``ilqr_config`` folder; you can modify it with your own settings. Run: 93 | 94 | ``` 95 | python ilqr.py --task={task} --setting_path={setting_path} --noise={noise} --epoch={epoch} 96 | ``` 97 | 98 | where ``task`` is in ``{plane, swing, balance, cartpole, threepole}``, and `setting_path` is the path to the folder containing your trained models (e.g., result/pendulum/).
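For instance, a typical invocation for the planar task might look like this (a hypothetical example; `result/planar` stands for wherever your trained models were saved):

```
python ilqr.py --task=plane --setting_path=result/planar --noise=0 --epoch=2000
```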
99 | 100 | The code will run iLQR for all trained models for that specific task and compute some statistics. The result is saved in ``iLQR_result``. 101 | 102 | ## Result 103 | ### Quantitative result 104 | 105 | We compare PC3 with two state-of-the-art LCE baselines: PCC ([Levine et al., 2020](https://openreview.net/pdf?id=BJxG_0EtDS)) and SOLAR ([Zhang et al., 106 | 2019](http://proceedings.mlr.press/v97/zhang19m/zhang19m.pdf)). Specifically, we report the percentage of time spent in the goal region in the underlying system. 107 | 108 | ![result table](sample_results/table_result.png) 109 | 110 | Below are videos showing the learned policies in 5 tasks. 111 | 112 | ![planar trajectory](sample_results/planar.gif) 113 | 114 | ![swing trajectory](sample_results/swing.gif) 115 | 116 | ![balance trajectory](sample_results/balance.gif) 117 | 118 | ![cartpole trajectory](sample_results/cartpole.gif) 119 | 120 | ![3-link](sample_results/threepole.gif) 121 | 122 | ### Qualitative result 123 | 124 | We also compare the quality of learned latent maps between PCC and PC3 in planar and pendulum. 125 | 126 | ![maps table](sample_results/table_maps.png) 127 | -------------------------------------------------------------------------------- /control_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | 6 | 7 | def main(args): 8 | path = args.path 9 | all_models = [os.path.join(path, dI) for dI in os.listdir(path) if os.path.isdir(os.path.join(path, dI))] 10 | all_results = [] 11 | for model in all_models: 12 | with open(model + "/result.txt", "r") as f: 13 | content = f.readlines()[:-1] 14 | content = [x.strip() for x in content] 15 | result_subtasks = [float(x[x.find(":") + 1 :].strip()) for x in content] 16 | all_results += result_subtasks 17 | all_results = np.array(all_results) * 100 18 | mean = all_results.mean() 19 | std_of_means = all_results.std() / np.sqrt(len(all_results))  # standard error of the mean 20 | print("Mean: " + str(mean)) 21 | print("Std of means: " + str(std_of_means)) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description="compute control statistics") 26 | 27 | parser.add_argument("--path", required=True, type=str, help="path to ilqr result") 28 | args = parser.parse_args() 29 | 30 | main(args) 31 | -------------------------------------------------------------------------------- /data/sample_planar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from datetime import datetime 5 | from os import path 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 10 | from PIL import Image 11 | from tqdm import trange 12 | 13 | 14 | np.random.seed(1) 15 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent) 16 | 17 | 18 | def get_all_pos(mdp): 19 | start = 0 20 | end = mdp.width 21 | state_list = [] 22 | for x in range(start, end): 23 | for y in range(start, end): 24 | state = np.array([x, y]) 25 | if mdp.is_valid_state(state): 26 | state_list.append(state) 27 | return state_list 28 | 29 | 30 | def sample(sample_size=5000, noise=0.0): 31 | """ 32 | return [(x, u, x_next, s, s_next)] 33 | """ 34 | mdp = PlanarObstaclesMDP(noise=noise) 35 | 36 | # placeholder for data 37 | x_data = np.zeros((sample_size, mdp.width, mdp.height), dtype="float32") 38 | u_data = np.zeros((sample_size, mdp.action_dim), dtype="float32") 39 |
x_next_data = np.zeros((sample_size, mdp.width, mdp.height), dtype="float32") 40 | state_data = np.zeros((sample_size, 2), dtype="float32") 41 | state_next_data = np.zeros((sample_size, 2), dtype="float32") 42 | 43 | # get all possible states (discretized on integer grid) 44 | state_list = get_all_pos(mdp) 45 | 46 | for j in trange(sample_size, desc="Sampling remaining data"): 47 | state_data[j] = state_list[j % len(state_list)] 48 | # state_data[j] = mdp.sample_random_state() 49 | x_data[j] = mdp.render(state_data[j]) 50 | u_data[j] = mdp.sample_valid_random_action(state_data[j]) 51 | state_next_data[j] = mdp.transition_function(state_data[j], u_data[j]) 52 | x_next_data[j] = mdp.render(state_next_data[j]) 53 | return x_data, u_data, x_next_data, state_data, state_next_data 54 | 55 | 56 | def write_to_file(noise, sample_size): 57 | """ 58 | write [(x, u, x_next)] to output dir 59 | """ 60 | output_dir = root_path + "/data/planar/raw_{:d}_{:.0f}".format(sample_size, noise) 61 | if not path.exists(output_dir): 62 | os.makedirs(output_dir) 63 | 64 | x_data, u_data, x_next_data, state_data, state_next_data = sample(sample_size, noise) 65 | 66 | samples = [] 67 | 68 | for i, _ in enumerate(x_data): 69 | before_file = "before-{:05d}.png".format(i) 70 | Image.fromarray(x_data[i] * 255.0).convert("L").save(path.join(output_dir, before_file)) 71 | 72 | after_file = "after-{:05d}.png".format(i) 73 | Image.fromarray(x_next_data[i] * 255.0).convert("L").save(path.join(output_dir, after_file)) 74 | 75 | initial_state = state_data[i] 76 | after_state = state_next_data[i] 77 | u = u_data[i] 78 | 79 | samples.append( 80 | { 81 | "before_state": initial_state.tolist(), 82 | "after_state": after_state.tolist(), 83 | "before": before_file, 84 | "after": after_file, 85 | "control": u.tolist(), 86 | } 87 | ) 88 | 89 | with open(path.join(output_dir, "data.json"), "wt") as outfile: 90 | json.dump( 91 | { 92 | "metadata": {"num_samples": sample_size, "time_created": str(datetime.now()), "version": 1}, 93 | "samples": samples, 94 | }, 95 | outfile, 96 | indent=2, 97 | ) 98 | 99 | 100 | def main(args): 101 | sample_size = args.sample_size 102 | noise = args.noise 103 | write_to_file(noise, sample_size) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description="sample planar data") 108 | 109 | parser.add_argument("--sample_size", required=True, type=int, help="the number of samples") 110 | parser.add_argument("--noise", default=0, type=int, help="level of noise") 111 | 112 | args = parser.parse_args() 113 | 114 | main(args) 115 | -------------------------------------------------------------------------------- /data/sample_pole.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import json 5 | import os 6 | import os.path as path 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | from mdp.cartpole_mdp import CartPoleMDP 12 | from mdp.pendulum_mdp import PendulumMDP 13 | from mdp.three_pole_mdp import ThreePoleMDP 14 | from PIL import Image 15 | from tqdm import trange 16 | 17 | 18 | widths = {"pendulum": 48, "cartpole": 80, "threepole": 80} 19 | heights = {"pendulum": 48, "cartpole": 80, "threepole": 80} 20 | state_dims = {"pendulum": 2, "cartpole": 4, "threepole": 6} 21 | frequencies = {"pendulum": 50, "cartpole": 50, "threepole": 50} 22 | mdps = {"pendulum": PendulumMDP, "cartpole": CartPoleMDP, 
"threepole": ThreePoleMDP} 23 | 24 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent) 25 | 26 | 27 | def sample(env_name, sample_size, noise): 28 | """ 29 | return [(x, u, x_next, s, s_next)] 30 | """ 31 | width, height, frequency = widths[env_name], heights[env_name], frequencies[env_name] 32 | s_dim = state_dims[env_name] 33 | mdp = mdps[env_name](width=width, height=height, frequency=frequency, noise=noise) 34 | 35 | # Data buffers to fill. 36 | x_data = np.zeros((sample_size, width, height, 2), dtype="float32") 37 | u_data = np.zeros((sample_size, mdp.action_dim), dtype="float32") 38 | x_next_data = np.zeros((sample_size, width, height, 2), dtype="float32") 39 | state_data = np.zeros((sample_size, s_dim, 2), dtype="float32") 40 | state_next_data = np.zeros((sample_size, s_dim, 2), dtype="float32") 41 | 42 | # Generate interaction tuples (random states and actions). 43 | for sample in trange(sample_size, desc="Sampling " + env_name + " data"): 44 | s0 = mdp.sample_random_state() 45 | x0 = mdp.render(s0) 46 | a0 = mdp.sample_random_action() 47 | s1 = mdp.transition_function(s0, a0) 48 | 49 | x1 = mdp.render(s1) 50 | a1 = mdp.sample_random_action() 51 | s2 = mdp.transition_function(s1, a1) 52 | x2 = mdp.render(s2) 53 | # Store interaction tuple. 54 | # Current state (w/ history). 55 | x_data[sample, :, :, 0] = x0[:, :, 0] 56 | x_data[sample, :, :, 1] = x1[:, :, 0] 57 | state_data[sample, :, 0] = s0 58 | state_data[sample, :, 1] = s1 59 | # Action. 60 | u_data[sample] = a1 61 | # Next state (w/ history). 62 | x_next_data[sample, :, :, 0] = x1[:, :, 0] 63 | x_next_data[sample, :, :, 1] = x2[:, :, 0] 64 | state_next_data[sample, :, 0] = s1 65 | state_next_data[sample, :, 1] = s2 66 | 67 | return x_data, u_data, x_next_data, state_data, state_next_data 68 | 69 | 70 | def write_to_file(env_name, sample_size, noise): 71 | """ 72 | write [(x, u, x_next)] to output dir 73 | """ 74 | output_dir = root_path + "/data/" + env_name + "/raw_{:d}_{:.0f}".format(sample_size, noise) 75 | if not path.exists(output_dir): 76 | os.makedirs(output_dir) 77 | 78 | samples = [] 79 | data = sample(env_name=env_name, sample_size=sample_size, noise=noise) 80 | x_data, u_data, x_next_data, state_data, state_next_data = data 81 | 82 | for i in range(x_data.shape[0]): 83 | x_1 = x_data[i, :, :, 0] 84 | x_2 = x_data[i, :, :, 1] 85 | before = np.hstack((x_1, x_2)) 86 | before_file = "before-{:05d}.png".format(i) 87 | Image.fromarray(before * 255.0).convert("L").save(path.join(output_dir, before_file)) 88 | 89 | after_file = "after-{:05d}.png".format(i) 90 | x_next_1 = x_next_data[i, :, :, 0] 91 | x_next_2 = x_next_data[i, :, :, 1] 92 | after = np.hstack((x_next_1, x_next_2)) 93 | Image.fromarray(after * 255.0).convert("L").save(path.join(output_dir, after_file)) 94 | 95 | initial_state = state_data[i] 96 | after_state = state_next_data[i] 97 | 98 | samples.append( 99 | { 100 | "before_state": initial_state.tolist(), 101 | "after_state": after_state.tolist(), 102 | "before": before_file, 103 | "after": after_file, 104 | "control": u_data[i].tolist(), 105 | } 106 | ) 107 | 108 | with open(path.join(output_dir, "data.json"), "wt") as outfile: 109 | json.dump( 110 | { 111 | "metadata": {"num_samples": x_data.shape[0], "time_created": str(datetime.now()), "version": 1}, 112 | "samples": samples, 113 | }, 114 | outfile, 115 | indent=2, 116 | ) 117 | 118 | 119 | def main(args): 120 | sample_size = args.sample_size 121 | noise = args.noise 122 | env_name = args.env 123 | assert env_name in ["pendulum", 
"cartpole", "threepole"] 124 | write_to_file(env_name, sample_size, noise) 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser(description="sample pendulum data") 129 | 130 | parser.add_argument("--sample_size", required=True, type=int, help="the number of samples") 131 | parser.add_argument("--noise", default=0, type=int, help="level of noise") 132 | parser.add_argument("--env", required=True, type=str, help="pendulum or cartpole or threepole") 133 | 134 | args = parser.parse_args() 135 | 136 | main(args) 137 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import numpy as np 5 | import torch 6 | from data import sample_planar, sample_pole 7 | from torch.utils.data import Dataset 8 | 9 | 10 | torch.set_default_dtype(torch.float64) 11 | 12 | 13 | class BaseDataset(Dataset): 14 | def __init__(self, data_path, sample_size, noise): 15 | self.sample_size = sample_size 16 | self.noise = noise 17 | self.data_path = data_path 18 | if not os.path.exists(self.data_path): 19 | os.makedirs(self.data_path) 20 | self._process() 21 | self.data_x, self.data_u, self.data_x_next = torch.load( 22 | self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise) 23 | ) 24 | 25 | def __len__(self): 26 | return len(self.data_x) 27 | 28 | def __getitem__(self, index): 29 | return self.data_x[index], self.data_u[index], self.data_x_next[index] 30 | 31 | def _process_image(self, img): 32 | pass 33 | 34 | def check_exists(self): 35 | return path.exists(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise)) 36 | 37 | def _process(self): 38 | pass 39 | 40 | 41 | class PlanarDataset(BaseDataset): 42 | width = 40 43 | height = 40 44 | action_dim = 2 45 | 46 | def __init__(self, sample_size, noise): 47 | data_path = "data/planar/" 48 | super(PlanarDataset, self).__init__(data_path, sample_size, noise) 49 | 50 | def _process_image(self, img): 51 | return torch.from_numpy(img.flatten()).unsqueeze(0) 52 | 53 | def _process(self): 54 | if self.check_exists(): 55 | return 56 | else: 57 | ( 58 | x_numpy_data, 59 | u_numpy_data, 60 | x_next_numpy_data, 61 | state_numpy_data, 62 | state_next_numpy_data, 63 | ) = sample_planar.sample(sample_size=self.sample_size, noise=self.noise) 64 | data_len = len(x_numpy_data) 65 | 66 | # place holder for data 67 | data_x = torch.zeros(data_len, self.width * self.height) 68 | data_u = torch.zeros(data_len, self.action_dim) 69 | data_x_next = torch.zeros(data_len, self.width * self.height) 70 | 71 | for i in range(data_len): 72 | data_x[i] = self._process_image(x_numpy_data[i]) 73 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 74 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 75 | 76 | data_set = (data_x, data_u, data_x_next) 77 | 78 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 79 | torch.save(data_set, f) 80 | 81 | 82 | class PendulumDataset(BaseDataset): 83 | width = 48 84 | height = 48 * 2 85 | action_dim = 1 86 | 87 | def __init__(self, sample_size, noise): 88 | data_path = "data/pendulum/" 89 | super(PendulumDataset, self).__init__(data_path, sample_size, noise) 90 | 91 | def _process_image(self, img): 92 | x = np.vstack((img[:, :, 0], img[:, :, 1])).flatten() 93 | return torch.from_numpy(x).unsqueeze(0) 94 | 95 | def _process(self): 96 | if self.check_exists(): 97 | return 98 | else: 99 | ( 100 
| x_numpy_data, 101 | u_numpy_data, 102 | x_next_numpy_data, 103 | state_numpy_data, 104 | state_next_numpy_data, 105 | ) = sample_pole.sample(env_name="pendulum", sample_size=self.sample_size, noise=self.noise) 106 | data_len = len(x_numpy_data) 107 | 108 | # place holder for data 109 | data_x = torch.zeros(data_len, self.width * self.height) 110 | data_u = torch.zeros(data_len, self.action_dim) 111 | data_x_next = torch.zeros(data_len, self.width * self.height) 112 | 113 | for i in range(data_len): 114 | data_x[i] = self._process_image(x_numpy_data[i]) 115 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 116 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 117 | 118 | data_set = (data_x, data_u, data_x_next) 119 | 120 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 121 | torch.save(data_set, f) 122 | 123 | 124 | class CartPoleDataset(BaseDataset): 125 | width = 80 126 | height = 80 * 2 127 | action_dim = 1 128 | 129 | def __init__(self, sample_size, noise): 130 | data_path = "data/cartpole/" 131 | super(CartPoleDataset, self).__init__(data_path, sample_size, noise) 132 | 133 | def _process_image(self, img): 134 | x = torch.zeros(size=(2, self.width, self.width)) 135 | x[0, :, :] = torch.from_numpy(img[:, :, 0]) 136 | x[1, :, :] = torch.from_numpy(img[:, :, 1]) 137 | return x.unsqueeze(0) 138 | 139 | def _process(self): 140 | if self.check_exists(): 141 | return 142 | else: 143 | ( 144 | x_numpy_data, 145 | u_numpy_data, 146 | x_next_numpy_data, 147 | state_numpy_data, 148 | state_next_numpy_data, 149 | ) = sample_pole.sample(env_name="cartpole", sample_size=self.sample_size, noise=self.noise) 150 | data_len = len(x_numpy_data) 151 | 152 | # place holder for data 153 | data_x = torch.zeros(data_len, 2, self.width, self.width) 154 | data_u = torch.zeros(data_len, self.action_dim) 155 | data_x_next = torch.zeros(data_len, 2, self.width, self.width) 156 | 157 | for i in range(data_len): 158 | data_x[i] = self._process_image(x_numpy_data[i]) 159 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 160 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 161 | 162 | data_set = (data_x, data_u, data_x_next) 163 | 164 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 165 | torch.save(data_set, f) 166 | 167 | 168 | class ThreePoleDataset(BaseDataset): 169 | width = 80 170 | height = 80 * 2 171 | action_dim = 3 172 | 173 | def __init__(self, sample_size, noise): 174 | data_path = "data/threepole/" 175 | super(ThreePoleDataset, self).__init__(data_path, sample_size, noise) 176 | 177 | def _process_image(self, img): 178 | x = torch.zeros(size=(2, self.width, self.width)) 179 | x[0, :, :] = torch.from_numpy(img[:, :, 0]) 180 | x[1, :, :] = torch.from_numpy(img[:, :, 1]) 181 | return x.unsqueeze(0) 182 | 183 | def _process(self): 184 | if self.check_exists(): 185 | return 186 | else: 187 | ( 188 | x_numpy_data, 189 | u_numpy_data, 190 | x_next_numpy_data, 191 | state_numpy_data, 192 | state_next_numpy_data, 193 | ) = sample_pole.sample(env_name="threepole", sample_size=self.sample_size, noise=self.noise) 194 | data_len = len(x_numpy_data) 195 | 196 | # place holder for data 197 | data_x = torch.zeros(data_len, 2, self.width, self.width) 198 | data_u = torch.zeros(data_len, self.action_dim) 199 | data_x_next = torch.zeros(data_len, 2, self.width, self.width) 200 | 201 | for i in range(data_len): 202 | data_x[i] = self._process_image(x_numpy_data[i]) 203 | data_u[i] = 
torch.from_numpy(u_numpy_data[i]) 204 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 205 | 206 | data_set = (data_x, data_u, data_x_next) 207 | 208 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 209 | torch.save(data_set, f) 210 | -------------------------------------------------------------------------------- /ilqr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from ilqr_utils import ( 9 | backward, 10 | compute_latent_traj, 11 | forward, 12 | get_x_data, 13 | latent_cost, 14 | random_actions_trajs, 15 | refresh_actions_trajs, 16 | save_traj, 17 | seq_jacobian, 18 | update_horizon_start, 19 | ) 20 | from mdp.cartpole_mdp import CartPoleMDP 21 | from mdp.pendulum_mdp import PendulumMDP 22 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 23 | from mdp.three_pole_mdp import ThreePoleMDP 24 | from pc3_model import PC3 25 | 26 | 27 | seed = 2020 28 | random.seed(seed) 29 | os.environ["PYTHONHASHSEED"] = str(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 34 | torch.backends.cudnn.benchmark = False 35 | torch.backends.cudnn.deterministic = True 36 | torch.set_default_dtype(torch.float64) 37 | 38 | config_path = { 39 | "plane": "ilqr_config/plane.json", 40 | "swing": "ilqr_config/swing.json", 41 | "balance": "ilqr_config/balance.json", 42 | "cartpole": "ilqr_config/cartpole.json", 43 | "threepole": "ilqr_config/threepole.json", 44 | } 45 | env_data_dim = { 46 | "planar": (1600, 2, 2), 47 | "pendulum": ((2, 48, 48), 3, 1), 48 | "cartpole": ((2, 80, 80), 8, 1), 49 | "threepole": ((2, 80, 80), 8, 3), 50 | } 51 | 52 | 53 | def main(args): 54 | task_name = args.task 55 | assert task_name in ["plane", "balance", "swing", "cartpole", "threepole"] 56 | env_name = "pendulum" if task_name in ["balance", "swing"] else ("planar" if task_name == "plane" else task_name)  # map each task to its underlying environment 57 | 58 | setting_path = args.setting_path 59 | setting = os.path.basename(os.path.normpath(setting_path)) 60 | noise = args.noise 61 | epoch = args.epoch 62 | x_dim, z_dim, u_dim = env_data_dim[env_name] 63 | 64 | # non-convolution encoder 65 | if env_name in ["planar", "pendulum"]: 66 | x_dim = np.prod(x_dim) 67 | 68 | ilqr_result_path = "iLQR_result/" + "_".join([task_name, str(setting), str(noise), str(epoch)]) 69 | if not os.path.exists(ilqr_result_path): 70 | os.makedirs(ilqr_result_path) 71 | with open(ilqr_result_path + "/settings", "w") as f: 72 | json.dump(args.__dict__, f, indent=2) 73 | 74 | # each trained model will perform 10 random tasks 75 | all_task_configs = [] 76 | for task_counter in range(10): 77 | # config for this task 78 | with open(config_path[task_name]) as f: 79 | config = json.load(f) 80 | 81 | # sample random start and goal state 82 | s_start_min, s_start_max = config["start_min"], config["start_max"] 83 | config["s_start"] = np.random.uniform(low=s_start_min, high=s_start_max) 84 | s_goal = config["goal"][np.random.choice(len(config["goal"]))] 85 | config["s_goal"] = np.array(s_goal) 86 | 87 | all_task_configs.append(config) 88 | 89 | # the folder where all trained models are saved 90 | log_folders = [ 91 | os.path.join(setting_path, dI) 92 | for dI in os.listdir(setting_path) 93 | if os.path.isdir(os.path.join(setting_path, dI)) 94 | ] 95 | log_folders.sort() 96 | 97 | # statistics on all trained models 98 |
avg_model_percent = 0.0 99 | best_model_percent = 0.0 100 | for log in log_folders: 101 | with open(log + "/settings", "r") as f: 102 | settings = json.load(f) 103 | armotized = settings["armotized"] 104 | 105 | log_base = os.path.basename(os.path.normpath(log)) 106 | model_path = ilqr_result_path + "/" + log_base 107 | if not os.path.exists(model_path): 108 | os.makedirs(model_path) 109 | print("iLQR for " + log_base) 110 | 111 | # load the trained model 112 | model = PC3(armotized, x_dim, z_dim, u_dim, env_name) 113 | model.load_state_dict(torch.load(log + "/model_" + str(epoch), map_location="cpu")) 114 | model.eval() 115 | dynamics = model.dynamics 116 | encoder = model.encoder 117 | 118 | # run the task with 10 different start and goal states for a particular model 119 | avg_percent = 0.0 120 | for task_counter, config in enumerate(all_task_configs): 121 | 122 | print("Performing task %d: " % (task_counter) + str(config["task"])) 123 | 124 | # environment specification 125 | horizon = config["horizon_prob"] 126 | plan_len = config["plan_len"] 127 | 128 | # ilqr specification 129 | R_z = config["q_weight"] * np.eye(z_dim) 130 | R_u = config["r_weight"] * np.eye(u_dim) 131 | num_uniform = config["uniform_trajs"] 132 | num_extreme = config["extreme_trajs"] 133 | ilqr_iters = config["ilqr_iters"] 134 | inv_regulator_init = config["pinv_init"] 135 | inv_regulator_multi = config["pinv_mult"] 136 | inv_regulator_max = config["pinv_max"] 137 | alpha_init = config["alpha_init"] 138 | alpha_mult = config["alpha_mult"] 139 | alpha_min = config["alpha_min"] 140 | 141 | s_start = config["s_start"] 142 | s_goal = config["s_goal"] 143 | 144 | # mdp 145 | if env_name == "planar": 146 | mdp = PlanarObstaclesMDP(goal=s_goal, goal_thres=config["distance_thresh"], noise=noise) 147 | elif env_name == "pendulum": 148 | mdp = PendulumMDP(frequency=config["frequency"], noise=noise, torque=config["torque"]) 149 | elif env_name == "cartpole": 150 | mdp = CartPoleMDP(frequency=config["frequency"], noise=noise) 151 | elif env_name == "threepole": 152 | mdp = ThreePoleMDP(frequency=config["frequency"], noise=noise, torque=config["torque"]) 153 | # get z_start and z_goal 154 | x_start = get_x_data(mdp, s_start, config) 155 | x_goal = get_x_data(mdp, s_goal, config) 156 | with torch.no_grad(): 157 | z_start = encoder(x_start) 158 | z_goal = encoder(x_goal) 159 | z_start = z_start.squeeze().numpy() 160 | z_goal = z_goal.squeeze().numpy() 161 | 162 | # initialize actions trajectories 163 | all_actions_trajs = random_actions_trajs(mdp, num_uniform, num_extreme, plan_len) 164 | 165 | # perform receding horizon iLQR 166 | s_start_horizon = np.copy(s_start) # s_start and z_start are updated at each horizon 167 | z_start_horizon = np.copy(z_start) 168 | obs_traj = [mdp.render(s_start).squeeze()] 169 | goal_counter = 0.0 170 | for plan_iter in range(1, horizon + 1): 171 | latent_cost_list = [None] * len(all_actions_trajs) 172 | # iterate over all trajectories 173 | for traj_id in range(len(all_actions_trajs)): 174 | # initialize the inverse regulator 175 | inv_regulator = inv_regulator_init 176 | for iter in range(1, ilqr_iters + 1): 177 | u_seq = all_actions_trajs[traj_id] 178 | z_seq = compute_latent_traj(z_start_horizon, u_seq, dynamics) 179 | # compute the linearization matrices 180 | A_seq, B_seq = seq_jacobian(dynamics, z_seq, u_seq) 181 | # run backward 182 | k_small, K_big = backward(R_z, R_u, z_seq, u_seq, z_goal, A_seq, B_seq, inv_regulator) 183 | current_cost = latent_cost(R_z, R_u, z_seq, z_goal, u_seq) 184 | #
forward using line search 185 | alpha = alpha_init 186 | accept = False # if any alpha is accepted 187 | while alpha > alpha_min: 188 | z_seq_cand, u_seq_cand = forward( 189 | z_seq, all_actions_trajs[traj_id], k_small, K_big, dynamics, alpha 190 | ) 191 | cost_cand = latent_cost(R_z, R_u, z_seq_cand, z_goal, u_seq_cand) 192 | if cost_cand < current_cost: # accept the trajectory candidate 193 | accept = True 194 | all_actions_trajs[traj_id] = u_seq_cand 195 | latent_cost_list[traj_id] = cost_cand 196 | break 197 | else: 198 | alpha *= alpha_mult 199 | if accept: 200 | inv_regulator = inv_regulator_init 201 | else: 202 | inv_regulator *= inv_regulator_multi 203 | if inv_regulator > inv_regulator_max: 204 | break 205 | 206 | for i in range(len(latent_cost_list)): 207 | if latent_cost_list[i] is None: 208 | latent_cost_list[i] = np.inf 209 | traj_opt_id = np.argmin(latent_cost_list) 210 | action_chosen = all_actions_trajs[traj_opt_id][0] 211 | s_start_horizon, z_start_horizon = update_horizon_start( 212 | mdp, s_start_horizon, action_chosen, encoder, config 213 | ) 214 | 215 | obs_traj.append(mdp.render(s_start_horizon).squeeze()) 216 | goal_counter += mdp.reward_function(s_start_horizon) 217 | 218 | all_actions_trajs = refresh_actions_trajs( 219 | all_actions_trajs, 220 | traj_opt_id, 221 | mdp, 222 | np.min([plan_len, horizon - plan_iter]), 223 | num_uniform, 224 | num_extreme, 225 | ) 226 | 227 | # compute the percentage close to goal 228 | success_rate = goal_counter / horizon 229 | print("Success rate: %.2f" % (success_rate)) 230 | percent = success_rate 231 | avg_percent += success_rate 232 | with open(model_path + "/result.txt", "a+") as f: 233 | f.write(config["task"] + ": " + str(percent) + "\n") 234 | 235 | # save trajectory as gif file 236 | gif_path = model_path + "/task_{:01d}.gif".format(task_counter + 1) 237 | save_traj(obs_traj, mdp.render(s_goal).squeeze(), gif_path, config["task"]) 238 | 239 | avg_percent = avg_percent / 10 240 | print("Average success rate: " + str(avg_percent)) 241 | print("====================================") 242 | avg_model_percent += avg_percent 243 | if avg_percent > best_model_percent: 244 | best_model = log_base 245 | best_model_percent = avg_percent 246 | with open(model_path + "/result.txt", "a+") as f: 247 | f.write("Average percentage: " + str(avg_percent)) 248 | 249 | avg_model_percent = avg_model_percent / len(log_folders) 250 | with open(ilqr_result_path + "/result.txt", "w") as f: 251 | f.write("Average percentage of all models: " + str(avg_model_percent) + "\n") 252 | f.write("Best model: " + best_model + ", best percentage: " + str(best_model_percent)) 253 | 254 | 255 | if __name__ == "__main__": 256 | parser = argparse.ArgumentParser(description="run iLQR") 257 | parser.add_argument("--task", required=True, type=str, help="task to perform") 258 | parser.add_argument("--setting_path", required=True, type=str, help="path to load trained models") 259 | parser.add_argument("--noise", type=float, default=0.0, help="noise level for mdp") 260 | parser.add_argument("--epoch", type=int, default=2000, help="number of epochs to load model") 261 | args = parser.parse_args() 262 | 263 | main(args) 264 | -------------------------------------------------------------------------------- /ilqr_config/balance.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "balance", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 
1e-4, 10 | 11 | "start_min": [0, 0], 12 | "start_max": [0, 0], 13 | "goal": [[0, 0]], 14 | 15 | "q_weight": 50, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | "torque": 0.65, 21 | 22 | "ilqr_iters": 4, 23 | "horizon_prob": 100, 24 | "plan_len": 10, 25 | "uniform_trajs": 3, 26 | "extreme_trajs": 3, 27 | 28 | "obs_shape": [2, 48, 48], 29 | "action_dim": 1, 30 | "latent_dim": 3 31 | } -------------------------------------------------------------------------------- /ilqr_config/cartpole.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "cartpole", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [-0.39269908169872414, 0, 0, 0], 12 | "start_max": [0.39269908169872414, 0, 0, 0], 13 | "goal": [[0, 0, 0, 0]], 14 | 15 | "q_weight": 1000, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | 21 | "ilqr_iters": 4, 22 | "horizon_prob": 50, 23 | "plan_len": 5, 24 | "uniform_trajs": 3, 25 | "extreme_trajs": 3, 26 | 27 | "obs_shape": [2, 80, 80], 28 | "action_dim": 1, 29 | "latent_dim": 8 30 | } -------------------------------------------------------------------------------- /ilqr_config/plane.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "plane", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [3, 3], 12 | "start_max": [7, 7], 13 | "goal": [[37, 37], [37, 3], [3, 37]], 14 | "distance_thresh": 2, 15 | 16 | "q_weight": 10, 17 | "r_weight": 1, 18 | 19 | "noise": 0.0, 20 | 21 | "ilqr_iters": 4, 22 | "horizon_prob": 40, 23 | "plan_len": 10, 24 | "uniform_trajs": 3, 25 | "extreme_trajs": 3, 26 | 27 | "obs_shape": [40, 40], 28 | "action_dim": 2, 29 | "latent_dim": 2 30 | } 31 | -------------------------------------------------------------------------------- /ilqr_config/swing.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "swing", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [2.827433388230814, 0], 12 | "start_max": [2.9845130209103035, 0], 13 | "goal": [[0, 0]], 14 | 15 | "q_weight": 50, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | "torque": 0.65, 21 | 22 | "ilqr_iters": 4, 23 | "horizon_prob": 100, 24 | "plan_len": 10, 25 | "uniform_trajs": 3, 26 | "extreme_trajs": 3, 27 | 28 | "obs_shape": [2, 48, 48], 29 | "action_dim": 1, 30 | "latent_dim": 3 31 | } -------------------------------------------------------------------------------- /ilqr_config/threepole.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "threepole", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [2.9845130209103035, 0, 2.0420352248333655, 0, 1.0367255756846319, 0], 12 | "start_max": [2.9845130209103035, 0, 2.0420352248333655, 0, 1.0367255756846319, 0], 13 | "goal": [[0, 0, 0, 0, 0, 0]], 14 | 15 | "q_weight": 50, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | "torque": 1.0, 21 | 22 | "ilqr_iters": 4, 23 | "horizon_prob": 200, 24 | "plan_len": 20, 25 | "uniform_trajs": 3, 26 | 
"extreme_trajs": 3, 27 | 28 | "obs_shape": [2, 80, 80], 29 | "action_dim": 3, 30 | "latent_dim": 8 31 | } -------------------------------------------------------------------------------- /ilqr_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | from matplotlib.animation import FuncAnimation, writers 5 | 6 | 7 | np.random.seed(0) 8 | 9 | 10 | def cost_dz(R_z, z, z_goal): 11 | # compute the first-order deravative of latent cost w.r.t z 12 | z_diff = np.expand_dims(z - z_goal, axis=-1) 13 | return np.squeeze(2 * np.matmul(R_z, z_diff)) 14 | 15 | 16 | def cost_du(R_u, u): 17 | # compute the first-order deravative of latent cost w.r.t u 18 | return np.atleast_1d(np.squeeze(2 * np.matmul(R_u, np.expand_dims(u, axis=-1)))) 19 | 20 | 21 | def cost_dzz(R_z): 22 | # compute the second-order deravative of latent cost w.r.t z 23 | return 2 * R_z 24 | 25 | 26 | def cost_duu(R_u): 27 | # compute the second-order deravative of latent cost w.r.t u 28 | return 2 * R_u 29 | 30 | 31 | def cost_duz(z, u): 32 | # compute the second-order deravative of latent cost w.r.t uz 33 | return np.zeros((u.shape[-1], z.shape[-1])) 34 | 35 | 36 | def latent_cost(R_z, R_u, z_seq, z_goal, u_seq): 37 | z_diff = np.expand_dims(z_seq - z_goal, axis=-1) 38 | cost_z = np.squeeze(np.matmul(np.matmul(z_diff.transpose((0, 2, 1)), R_z), z_diff)) 39 | u_seq_reshaped = np.expand_dims(u_seq, axis=-1) 40 | cost_u = np.squeeze(np.matmul(np.matmul(u_seq_reshaped.transpose((0, 2, 1)), R_u), u_seq_reshaped)) 41 | return np.sum(cost_z) + np.sum(cost_u) 42 | 43 | 44 | def one_step_back(R_z, R_u, z, u, z_goal, A, B, V_prime_next_z, V_prime_next_zz, mu_inv_regulator): 45 | """ 46 | V_prime_next_z: first order derivative of the value function at time step t+1 47 | V_prime_next_zz: second order derivative of the value function at time tep t+1 48 | A: derivative of F(z, u) w.r.t z at z_bar_t, u_bar_t 49 | B: derivative of F(z, u) w.r.t u at z_bar_t, u_bar_t 50 | """ 51 | # compute Q_z, Q_u, Q_zz, Q_uu, Q_uz using cost function, A, B and V 52 | Q_z = cost_dz(R_z, z, z_goal) + np.matmul(A.transpose(), V_prime_next_z) 53 | Q_u = cost_du(R_u, u) + np.matmul(B.transpose(), V_prime_next_z) 54 | Q_zz = cost_dzz(R_z) + np.matmul(np.matmul(A.transpose(), V_prime_next_zz), A) 55 | Q_uz = cost_duz(z, u) + np.matmul(np.matmul(B.transpose(), V_prime_next_zz), A) 56 | Q_uu = cost_duu(R_u) + np.matmul(np.matmul(B.transpose(), V_prime_next_zz), B) 57 | 58 | # compute k and K matrix, add regularization to Q_uu 59 | Q_uu_regularized = Q_uu + mu_inv_regulator * np.eye(Q_uu.shape[0]) 60 | Q_uu_in = np.linalg.inv(Q_uu_regularized) 61 | k = -np.matmul(Q_uu_in, Q_u) 62 | K = -np.matmul(Q_uu_in, Q_uz) 63 | 64 | # compute V_z and V_zz using k and K 65 | V_prime_z = Q_z + np.matmul(Q_uz.transpose(), k) 66 | V_prime_zz = Q_zz + np.matmul(Q_uz.transpose(), K) 67 | return k, K, V_prime_z, V_prime_zz 68 | 69 | 70 | def backward(R_z, R_u, z_seq, u_seq, z_goal, A_seq, B_seq, inv_regulator): 71 | """ 72 | do the backward pass 73 | return a sequence of k and K matrices 74 | """ 75 | # first and second order derivative of the value function at the last time step 76 | V_prime_next_z = cost_dz(R_z, z_seq[-1], z_goal) 77 | V_prime_next_zz = cost_dzz(R_z) 78 | k, K = [], [] 79 | act_seq_len = len(u_seq) 80 | for t in reversed(range(act_seq_len)): 81 | k_t, K_t, V_prime_z, V_prime_zz = one_step_back( 82 | R_z, R_u, z_seq[t], u_seq[t], z_goal, A_seq[t], B_seq[t], 
V_prime_next_z, V_prime_next_zz, inv_regulator 83 | ) 84 | k.insert(0, k_t) 85 | K.insert(0, K_t) 86 | V_prime_next_z, V_prime_next_zz = V_prime_z, V_prime_zz 87 | return k, K 88 | 89 | 90 | def forward(z_seq, u_seq, k, K, dynamics, alpha): 91 | """ 92 | update the trajectory, given k and K 93 | """ 94 | z_seq_new = [] 95 | z_seq_new.append(z_seq[0]) 96 | u_seq_new = [] 97 | for i in range(0, len(u_seq)): 98 | u_new = u_seq[i] + alpha * k[i] + np.matmul(K[i], z_seq_new[i] - z_seq[i]) 99 | u_seq_new.append(u_new) 100 | with torch.no_grad(): 101 | z_new = dynamics(torch.from_numpy(z_seq_new[i]).unsqueeze(0), torch.from_numpy(u_new).unsqueeze(0))[0].mean 102 | z_seq_new.append(z_new.squeeze().numpy()) 103 | return np.array(z_seq_new), np.array(u_seq_new) 104 | 105 | 106 | def get_x_data(mdp, state, config): 107 | image_data = mdp.render(state).squeeze() 108 | x_dim = config["obs_shape"] 109 | if config["task"] == "plane": 110 | x_dim = np.prod(x_dim) 111 | x_data = torch.from_numpy(image_data).double().view(x_dim).unsqueeze(0) 112 | elif config["task"] in ["swing", "balance"]: 113 | x_dim = np.prod(x_dim) 114 | x_data = np.vstack((image_data, image_data)) 115 | x_data = torch.from_numpy(x_data).double().view(x_dim).unsqueeze(0) 116 | elif config["task"] in ["cartpole", "threepole"]: 117 | x_data = torch.zeros(size=(2, 80, 80)) 118 | x_data[0, :, :] = torch.from_numpy(image_data) 119 | x_data[1, :, :] = torch.from_numpy(image_data) 120 | x_data = x_data.unsqueeze(0) 121 | return x_data 122 | 123 | 124 | def update_horizon_start(mdp, s, u, encoder, config): 125 | s_next = mdp.transition_function(s, u) 126 | if config["task"] == "plane": 127 | x_next = get_x_data(mdp, s_next, config) 128 | elif config["task"] in ["swing", "balance"]: 129 | obs = mdp.render(s).squeeze() 130 | obs_next = mdp.render(s_next).squeeze() 131 | obs_stacked = np.vstack((obs, obs_next)) 132 | x_dim = np.prod(config["obs_shape"]) 133 | x_next = torch.from_numpy(obs_stacked).view(x_dim).unsqueeze(0).double() 134 | elif config["task"] in ["cartpole", "threepole"]: 135 | obs = mdp.render(s).squeeze() 136 | obs_next = mdp.render(s_next).squeeze() 137 | x_next = torch.zeros(size=config["obs_shape"]) 138 | x_next[0, :, :] = torch.from_numpy(obs) 139 | x_next[1, :, :] = torch.from_numpy(obs_next) 140 | x_next = x_next.unsqueeze(0) 141 | with torch.no_grad(): 142 | z_next = encoder(x_next) 143 | return s_next, z_next.squeeze().numpy() 144 | 145 | 146 | def random_uniform_actions(mdp, plan_len): 147 | # create a trajectory of random actions 148 | random_actions = [] 149 | for i in range(plan_len): 150 | action = mdp.sample_random_action() 151 | random_actions.append(action) 152 | return np.array(random_actions) 153 | 154 | 155 | def random_extreme_actions(mdp, plan_len): 156 | # create a trajectory of extreme actions 157 | extreme_actions = [] 158 | for i in range(plan_len): 159 | action = mdp.sample_extreme_action() 160 | extreme_actions.append(action) 161 | return np.array(extreme_actions) 162 | 163 | 164 | def random_actions_trajs(mdp, num_uniform, num_extreme, plan_len): 165 | actions_trajs = [] 166 | for i in range(num_uniform): 167 | actions_trajs.append(random_uniform_actions(mdp, plan_len)) 168 | for j in range(num_extreme): 169 | actions_trajs.append(random_extreme_actions(mdp, plan_len)) 170 | return actions_trajs 171 | 172 | 173 | def refresh_actions_trajs(actions_trajs, traj_opt_id, mdp, length, num_uniform, num_extreme): 174 | for traj_id in range(len(actions_trajs)): 175 | if traj_id == traj_opt_id: 176 | 
actions_trajs[traj_id] = actions_trajs[traj_id][1:] 177 | if len(actions_trajs[traj_id]) < length: 178 | # Duplicate last action. 179 | actions_trajs[traj_id] = np.append( 180 | actions_trajs[traj_id], actions_trajs[traj_id][-1].reshape(1, -1), axis=0 181 | ) 182 | continue 183 | if traj_id < num_uniform: 184 | actions_trajs[traj_id] = random_uniform_actions(mdp, length) 185 | else: 186 | actions_trajs[traj_id] = random_extreme_actions(mdp, length) 187 | return actions_trajs 188 | 189 | 190 | def update_seq_act(z_seq, z_start, u_seq, k, K, dynamics): 191 | """ 192 | update the trajectory, given k and K 193 | """ 194 | z_new = z_start 195 | u_seq_new = [] 196 | for i in range(0, len(u_seq)): 197 | u_new = u_seq[i] + k[i] + np.matmul(K[i], (z_new - z_seq[i])) 198 | with torch.no_grad(): 199 | z_new = dynamics(torch.from_numpy(z_new).view(1, -1), torch.from_numpy(u_new).view(1, -1))[0].mean 200 | z_new = z_new.squeeze().numpy() 201 | u_seq_new.append(u_new) 202 | return np.array(u_seq_new) 203 | 204 | 205 | def compute_latent_traj(z_start, u_seq, dynamics): 206 | plan_len = len(u_seq) 207 | z_seq = [z_start] 208 | for i in range(plan_len): 209 | z = torch.from_numpy(z_seq[i]).view(1, -1).double() 210 | u = torch.from_numpy(u_seq[i]).view(1, -1).double() 211 | with torch.no_grad(): 212 | z_next = dynamics(z, u)[0].mean 213 | z_seq.append(z_next.squeeze().numpy()) 214 | return z_seq 215 | 216 | 217 | def jacobian(dynamics, z, u): 218 | """ 219 | compute the jacobian of F(z,u) w.r.t z, u 220 | """ 221 | z_dim = z.shape[0] 222 | u_dim = u.shape[0] 223 | z_tensor = torch.from_numpy(z).view(1, -1).double() 224 | u_tensor = torch.from_numpy(u).view(1, -1).double() 225 | if dynamics.armotized: 226 | _, A, B = dynamics(z_tensor, u_tensor) 227 | return A.squeeze().view(z_dim, z_dim).numpy(), B.squeeze().view(z_dim, u_dim).numpy() 228 | z_tensor, u_tensor = z_tensor.squeeze().repeat(z_dim, 1), u_tensor.squeeze().repeat(z_dim, 1) 229 | z_tensor = z_tensor.detach().requires_grad_(True) 230 | u_tensor = u_tensor.detach().requires_grad_(True) 231 | z_next = dynamics(z_tensor, u_tensor)[0].mean 232 | grad_inp = torch.eye(z_dim) 233 | A, B = torch.autograd.grad(z_next, [z_tensor, u_tensor], [grad_inp, grad_inp]) 234 | return A.numpy(), B.numpy() 235 | 236 | 237 | def seq_jacobian(dynamics, z_seq, u_seq): 238 | """ 239 | compute the jacobian w.r.t each pair in the trajectory 240 | """ 241 | A_seq, B_seq = [], [] 242 | horizon = len(u_seq) 243 | for i in range(horizon): 244 | z, u = z_seq[i], u_seq[i] 245 | A, B = jacobian(dynamics, z, u) 246 | A_seq.append(A) 247 | B_seq.append(B) 248 | return A_seq, B_seq 249 | 250 | 251 | def save_traj(images, image_goal, gif_path, task): 252 | # save trajectory as gif file 253 | fig, aa = plt.subplots(1, 2) 254 | m1 = aa[0].matshow(images[0], cmap=plt.cm.gray, vmin=0.0, vmax=1.0) 255 | aa[0].set_title("Time step 0") 256 | aa[0].set_yticklabels([]) 257 | aa[0].set_xticklabels([]) 258 | m2 = aa[1].matshow(image_goal, cmap=plt.cm.gray, vmin=0.0, vmax=1.0) 259 | aa[1].set_title("goal") 260 | aa[1].set_yticklabels([]) 261 | aa[1].set_xticklabels([]) 262 | fig.tight_layout() 263 | 264 | def updatemat2(t): 265 | m1.set_data(images[t]) 266 | aa[0].set_title("Time step " + str(t)) 267 | m2.set_data(image_goal) 268 | return m1, m2 269 | 270 | frames = len(images) 271 | if task in ["plane", "cartpole"]: 272 | fps = 2 273 | else: 274 | fps = 20 275 | 276 | anim = FuncAnimation(fig, updatemat2, frames=frames, interval=200, blit=True, repeat=True) 277 | Writer = writers["imagemagick"] # 
animation.writers.avail 278 | writer = Writer(fps=fps, metadata=dict(artist="Me"), bitrate=1800) 279 | 280 | anim.save(gif_path, writer=writer) 281 | 282 | plt.clf() 283 | plt.cla() 284 | -------------------------------------------------------------------------------- /latent_map_pendulum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from colour import Color 8 | from mdp.pendulum_mdp import PendulumMDP 9 | from pc3_model import PC3 10 | from torchvision.transforms import ToTensor 11 | 12 | 13 | red = Color("red") 14 | blue = Color("blue") 15 | num_angles = 200 16 | num_each_angle = 20 17 | 18 | np.random.seed(0) 19 | torch.manual_seed(0) 20 | 21 | 22 | def map_angle_color(num_angles, mdp): 23 | colors = list(red.range_to(blue, num_angles)) 24 | colors_rgb = [color.rgb for color in colors] 25 | all_angles = np.linspace(start=mdp.angle_range[0], stop=mdp.angle_range[1], num=num_angles) 26 | angle_color_map = dict(zip(all_angles, colors_rgb)) 27 | return angle_color_map, colors_rgb 28 | 29 | 30 | def assign_latent_color(model, angle, mdp): 31 | # the same angle corresponds to multiple states -> multiple latent vectors 32 | # map an angle to multiple latent vectors corresponding to that angle 33 | angle_vels = np.linspace( 34 | start=mdp.angular_velocity_range[0], stop=mdp.angular_velocity_range[1], num=num_each_angle 35 | ) 36 | all_z_for_angle = [] 37 | for i in range(num_each_angle): 38 | ang_velocity = angle_vels[i] 39 | s = np.array([angle, ang_velocity]) 40 | x = mdp.render(s).squeeze() 41 | # take a random action 42 | u = mdp.sample_random_action() 43 | s_next = mdp.transition_function(s, u) 44 | x_next = mdp.render(s_next).squeeze() 45 | # reverse order: the state we want to represent is x not x_next 46 | x_with_history = np.vstack((x_next, x)) 47 | x_with_history = ToTensor()(x_with_history).double() 48 | with torch.no_grad(): 49 | z = model.encode(x_with_history.view(-1, x_with_history.shape[-1] * x_with_history.shape[-2])) 50 | all_z_for_angle.append(z.detach().squeeze().numpy()) 51 | return all_z_for_angle 52 | 53 | 54 | def show_latent_map(model, mdp): 55 | angle_color_map, colors_rgb = map_angle_color(num_angles, mdp) 56 | colors_list = [] 57 | for color in colors_rgb: 58 | for i in range(num_each_angle): 59 | colors_list.append(list(color)) 60 | all_z = [] 61 | 62 | for angle in angle_color_map: 63 | all_z_for_angle = assign_latent_color(model, angle, mdp) 64 | all_z += all_z_for_angle 65 | all_z = np.array(all_z) 66 | 67 | z_min = np.min(all_z, axis=0) 68 | z_max = np.max(all_z, axis=0) 69 | all_z = 2 * (all_z - z_min) / (z_max - z_min) - 1.0 70 | all_z = all_z * 35 71 | 72 | ax = plt.axes(projection="3d") 73 | ax.set_xlim([-60, 60]) 74 | ax.set_ylim([-60, 60]) 75 | ax.set_zlim([-60, 60]) 76 | xdata = all_z[:, 0] 77 | ydata = all_z[:, 1] 78 | zdata = all_z[:, 2] 79 | 80 | ax.scatter(xdata, ydata, zdata, c=colors_list, marker="o", s=10) 81 | plt.show() 82 | 83 | 84 | def main(args): 85 | log_path = args.log_path 86 | epoch = args.epoch 87 | 88 | mdp = PendulumMDP() 89 | 90 | # load the specified model 91 | with open(log_path + "/settings", "r") as f: 92 | settings = json.load(f) 93 | armotized = settings["armotized"] 94 | model = PC3(armotized=armotized, x_dim=4608, z_dim=3, u_dim=1, env="pendulum") 95 | model.load_state_dict(torch.load(log_path + "/model_" + str(epoch), map_location="cpu")) 96 | model.eval() 97 | 98 |
show_latent_map(model, mdp) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser(description="draw pendulum latent map") 103 | 104 | parser.add_argument("--log_path", required=True, type=str, help="path to trained model") 105 | parser.add_argument("--epoch", required=True, type=int, help="load model corresponding to this epoch") 106 | args = parser.parse_args() 107 | 108 | main(args) 109 | -------------------------------------------------------------------------------- /latent_map_planar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from colour import Color 8 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 9 | from pc3_model import PC3 10 | from PIL import Image, ImageDraw 11 | 12 | 13 | blue = Color("blue") 14 | colors = list(blue.range_to(Color("red"), 40)) 15 | colors_rgb = [color.rgb for color in colors] 16 | start, end = 0, 40 17 | width, height = 40, 40 18 | 19 | 20 | # states corresponding to obstacles' positions 21 | def get_invalid_state(mdp): 22 | invalid_pos = [] 23 | for x in range(start, end): 24 | for y in range(start, end): 25 | s = [x, y] 26 | if not mdp.is_valid_state(np.array(s)): 27 | invalid_pos.append(s) 28 | return invalid_pos 29 | 30 | 31 | def color_gradient(): 32 | img = Image.new("RGB", (width, height), "#FFFFFF") 33 | draw = ImageDraw.Draw(img) 34 | 35 | for i, color in zip(range(start, end), colors_rgb): 36 | r1, g1, b1 = color[0] * 255.0, color[1] * 255.0, color[2] * 255.0 37 | draw.line((i, start, i, end), fill=(int(r1), int(g1), int(b1))) 38 | 39 | return img 40 | 41 | 42 | def get_true_map(mdp): 43 | invalid_pos = get_invalid_state(mdp) 44 | color_gradient_img = color_gradient() 45 | img_scaled = Image.new("RGB", (width * 10, height * 10), "#FFFFFF") 46 | draw = ImageDraw.Draw(img_scaled) 47 | for y in range(start, end): 48 | for x in range(start, end): 49 | if [y, x] in invalid_pos: 50 | continue 51 | else: 52 | x_scaled, y_scaled = x * 10, y * 10 53 | draw.ellipse( 54 | (x_scaled - 2, y_scaled - 2, x_scaled + 2, y_scaled + 2), fill=color_gradient_img.getpixel((x, y)) 55 | ) 56 | img_arr_scaled = np.array(img_scaled) / 255.0 57 | return img_arr_scaled 58 | 59 | 60 | def draw_latent_map(model, mdp): 61 | invalid_pos = get_invalid_state(mdp) 62 | img = color_gradient() 63 | # compute latent z 64 | all_z = [] 65 | for x in range(start, end): 66 | for y in range(start, end): 67 | s = np.array([x, y]) 68 | if [x, y] in invalid_pos: 69 | all_z.append(np.zeros(2)) 70 | else: 71 | with torch.no_grad(): 72 | obs = torch.Tensor(mdp.render(s)).unsqueeze(0).view(-1, 1600).double() 73 | if next(model.parameters()).is_cuda: 74 | obs = obs.cuda() 75 | mu = model.encode(obs) 76 | z = mu.squeeze().cpu().numpy() 77 | all_z.append(np.copy(z)) 78 | all_z = np.array(all_z) 79 | 80 | # normalize and scale to plot 81 | z_min = np.min(all_z, axis=0) 82 | z_max = np.max(all_z, axis=0) 83 | all_z = (all_z - z_min) / (z_max - z_min) 84 | all_z = all_z * 350 85 | 86 | # plot 87 | latent_map = {} 88 | i = 0 89 | for x in range(start, end): 90 | for y in range(start, end): 91 | latent_map[(x, y)] = all_z[i] 92 | i += 1 93 | 94 | img_latent = Image.new("RGB", (mdp.width * 10, mdp.height * 10), "#FFFFFF") 95 | draw = ImageDraw.Draw(img_latent) 96 | for k in latent_map: 97 | x, y = k 98 | if [x, y] in invalid_pos: 99 | continue 100 | else: 101 | x_scaled, y_scaled = latent_map[k][1], latent_map[k][0] 102 |
draw.ellipse((x_scaled - 2, y_scaled - 2, x_scaled + 2, y_scaled + 2), fill=img.getpixel((y, x))) 103 | return img_latent 104 | 105 | 106 | def show_latent_map(model, mdp): 107 | true_map = get_true_map(mdp) 108 | latent_map = draw_latent_map(model, mdp) 109 | latent_map = np.array(latent_map) / 255.0 110 | 111 | f, axarr = plt.subplots(1, 2, figsize=(15, 15)) 112 | axarr[0].imshow(true_map) 113 | axarr[1].imshow(latent_map) 114 | plt.show() 115 | 116 | 117 | def main(args): 118 | log_path = args.log_path 119 | epoch = args.epoch 120 | 121 | mdp = PlanarObstaclesMDP() 122 | 123 | # load the specified model 124 | with open(log_path + "/settings", "r") as f: 125 | settings = json.load(f) 126 | armotized = settings["armotized"] 127 | model = PC3(armotized, 1600, 2, 2, "planar") 128 | model.load_state_dict(torch.load(log_path + "/model_" + str(epoch), map_location="cpu")) 129 | model.eval() 130 | 131 | show_latent_map(model, mdp) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description="draw planar latent map") 136 | 137 | parser.add_argument("--log_path", required=True, type=str, help="path to trained model") 138 | parser.add_argument("--epoch", required=True, type=int, help="load model corresponding to this epoch") 139 | args = parser.parse_args() 140 | 141 | main(args) 142 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks import MultivariateNormalDiag 3 | 4 | 5 | torch.set_default_dtype(torch.float64) 6 | 7 | 8 | def repeat_dist(dist, batch_size, z_dim): 9 | # [dist1, dist2, dist3] -> [dist1, dist2, dist3, dist1, dist2, dist3, dist1, dist2, dist3] 10 | mean, stddev = dist.mean, dist.stddev 11 | mean = mean.repeat(batch_size, 1) 12 | stddev = stddev.repeat(batch_size, 1) 13 | return MultivariateNormalDiag(mean, stddev) 14 | 15 | 16 | def nce_past(z_next_trans_dist, z_next_enc): 17 | """ 18 | z_next_trans_dist: p(.|z, u) 19 | z_next_enc: samples from p(.|x') 20 | """ 21 | batch_size, z_dim = z_next_enc.size(0), z_next_enc.size(1) 22 | 23 | z_next_trans_dist_rep = repeat_dist(z_next_trans_dist, batch_size, z_dim) 24 | z_next_enc_rep = z_next_enc.repeat(1, batch_size).view(-1, z_dim) 25 | 26 | # scores[i, j] = p(z'_i | z_j, u_j) 27 | scores = z_next_trans_dist_rep.log_prob(z_next_enc_rep).view(batch_size, batch_size) 28 | with torch.no_grad(): 29 | normalize = torch.max(scores, dim=-1)[0].view(-1, 1) 30 | scores = scores - normalize 31 | scores = torch.exp(scores) 32 | 33 | # I_NCE 34 | positive_samples = scores.diag() 35 | avg_negative_samples = torch.mean(scores, dim=-1) 36 | return -torch.mean(torch.log(positive_samples / avg_negative_samples + 1e-8)) 37 | 38 | 39 | def curvature(model, z, u, delta, armotized): 40 | z_alias = z.detach().requires_grad_(True) 41 | u_alias = u.detach().requires_grad_(True) 42 | eps_z = torch.normal(mean=torch.zeros_like(z), std=torch.empty_like(z).fill_(delta)) 43 | eps_u = torch.normal(mean=torch.zeros_like(u), std=torch.empty_like(u).fill_(delta)) 44 | 45 | z_bar = z_alias + eps_z 46 | u_bar = u_alias + eps_u 47 | 48 | f_z_bar, A, B = model.transition(z_bar, u_bar) 49 | f_z_bar = f_z_bar.mean 50 | f_z, A, B = model.transition(z_alias, u_alias) 51 | f_z = f_z.mean 52 | 53 | z_dim, u_dim = z.size(1), u.size(1) 54 | if not armotized: 55 | _, B = get_jacobian(model.dynamics, z_alias, u_alias) 56 | (grad_z,) = torch.autograd.grad(f_z, z_alias, grad_outputs=eps_z,
39 | def curvature(model, z, u, delta, armotized): 40 | z_alias = z.detach().requires_grad_(True) 41 | u_alias = u.detach().requires_grad_(True) 42 | eps_z = torch.normal(mean=torch.zeros_like(z), std=torch.empty_like(z).fill_(delta)) 43 | eps_u = torch.normal(mean=torch.zeros_like(u), std=torch.empty_like(u).fill_(delta)) 44 | 45 | z_bar = z_alias + eps_z 46 | u_bar = u_alias + eps_u 47 | 48 | f_z_bar, A, B = model.transition(z_bar, u_bar) 49 | f_z_bar = f_z_bar.mean 50 | f_z, A, B = model.transition(z_alias, u_alias) 51 | f_z = f_z.mean 52 | 53 | z_dim, u_dim = z.size(1), u.size(1) 54 | if not armotized: 55 | _, B = get_jacobian(model.dynamics, z_alias, u_alias) 56 | (grad_z,) = torch.autograd.grad(f_z, z_alias, grad_outputs=eps_z, retain_graph=True, create_graph=True) 57 | grad_u = torch.bmm(B, eps_u.view(-1, u_dim, 1)).squeeze() 58 | else: 59 | A = A.view(-1, z_dim, z_dim) 60 | B = B.view(-1, z_dim, u_dim) 61 | eps_z = eps_z.view(-1, z_dim, 1) 62 | eps_u = eps_u.view(-1, u_dim, 1) 63 | grad_z = torch.bmm(A, eps_z).squeeze() 64 | grad_u = torch.bmm(B, eps_u).squeeze() 65 | 66 | taylor_error = f_z_bar - (grad_z + grad_u) - f_z 67 | cur_loss = torch.mean(torch.sum(taylor_error.pow(2), dim=1)) 68 | return cur_loss 69 | 70 | 71 | def get_jacobian(dynamics, batched_z, batched_u): 72 | """ 73 | compute the jacobian of F(z,u) w.r.t z, u 74 | """ 75 | batch_size = batched_z.size(0) 76 | z_dim = batched_z.size(-1) 77 | 78 | z, u = batched_z.unsqueeze(1), batched_u.unsqueeze(1) # batch_size, 1, input_dim 79 | z, u = z.repeat(1, z_dim, 1), u.repeat(1, z_dim, 1) # batch_size, output_dim, input_dim 80 | z_next = dynamics(z, u)[0].mean 81 | grad_inp = torch.eye(z_dim).reshape(1, z_dim, z_dim).repeat(batch_size, 1, 1).cuda() 82 | all_A, all_B = torch.autograd.grad(z_next, [z, u], [grad_inp, grad_inp], create_graph=True, retain_graph=True) 83 | return all_A, all_B 84 | --------------------------------------------------------------------------------
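The curvature loss above penalizes the residual of a first-order Taylor expansion of the transition mean around `(z, u)`. A standalone sanity check of that idea (toy code, not from the repo): for an exactly linear map the residual vanishes, so the loss only activates where the dynamics bend.

```python
import torch

# For f(z, u) = z A^T + u B^T, the Taylor residual
# f(z + eps_z, u + eps_u) - f(z, u) - (eps_z A^T + eps_u B^T) is exactly zero.
torch.manual_seed(0)
z_dim, u_dim = 2, 2
A, B = torch.randn(z_dim, z_dim), torch.randn(z_dim, u_dim)

def f(z, u):
    return z @ A.T + u @ B.T

z, u = torch.randn(5, z_dim), torch.randn(5, u_dim)
eps_z, eps_u = 0.1 * torch.randn_like(z), 0.1 * torch.randn_like(u)
residual = f(z + eps_z, u + eps_u) - f(z, u) - (eps_z @ A.T + eps_u @ B.T)
print(residual.abs().max().item())  # ~0 up to floating-point error
```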
/map_scale_stat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from datasets import CartPoleDataset, PendulumDataset, PlanarDataset, ThreePoleDataset 8 | from pc3_model import PC3 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | env_data_dim = { 13 | "planar": (1600, 2, 2), 14 | "pendulum": ((2, 48, 48), 3, 1), 15 | "cartpole": ((2, 80, 80), 8, 1), 16 | "threepole": ((2, 80, 80), 8, 3), 17 | } 18 | datasets = { 19 | "planar": PlanarDataset, 20 | "pendulum": PendulumDataset, 21 | "cartpole": CartPoleDataset, 22 | "threepole": ThreePoleDataset, 23 | } 24 | 25 | 26 | def calc_scale(model, env_name, sample_size=5000, noise=0): 27 | dataset = datasets[env_name] 28 | dataset = dataset(sample_size=sample_size, noise=noise) 29 | data_loader = DataLoader(dataset, batch_size=100, shuffle=False, drop_last=False, num_workers=1) 30 | 31 | avg_norm_2 = 0.0 32 | avg_dynamics_norm_2 = 0.0 33 | for x, u, _ in data_loader: 34 | with torch.no_grad(): 35 | z = model.encode(x) 36 | z_next = model.transition(z, u)[0].mean 37 | avg_norm_2 += torch.mean(torch.sum(z.pow(2), dim=1)) 38 | avg_dynamics_norm_2 += torch.mean(torch.sum(z_next.pow(2), dim=1)) 39 | return avg_norm_2 / len(data_loader), avg_dynamics_norm_2 / len(data_loader) 40 | 41 | 42 | def main(args): 43 | env_name = args.env 44 | assert env_name in ["planar", "pendulum", "cartpole", "threepole"] 45 | setting_path = args.setting_path 46 | epoch = args.epoch 47 | 48 | x_dim, z_dim, u_dim = env_data_dim[env_name] 49 | if env_name in ["planar", "pendulum"]: 50 | x_dim = np.prod(x_dim) 51 | 52 | all_avg_norm_2 = [] 53 | all_avg_dyn_norm_2 = [] 54 | log_folders = [ 55 | os.path.join(setting_path, dI) 56 | for dI in os.listdir(setting_path) 57 | if os.path.isdir(os.path.join(setting_path, dI)) 58 | ] 59 | for log in log_folders: 60 | with open(log + "/settings", "r") as f: 61 | settings = json.load(f) 62 | armotized = settings["armotized"] 63 | 64 | # load the trained model 65 | model = PC3(armotized, x_dim, z_dim, u_dim, env_name) 66 | model.load_state_dict(torch.load(log + "/model_" + str(epoch), map_location="cpu")) 67 | model.eval() 68 | 69 | avg_norm_2, avg_dyn_norm_2 = calc_scale(model, env_name) 70 | all_avg_norm_2.append(avg_norm_2) 71 | all_avg_dyn_norm_2.append(avg_dyn_norm_2) 72 | 73 | # compute mean and variance 74 | all_avg_norm_2 = np.array(all_avg_norm_2) 75 | mean_norm_2 = np.mean(all_avg_norm_2) 76 | var_norm_2 = np.var(all_avg_norm_2) 77 | 78 | all_avg_dyn_norm_2 = np.array(all_avg_dyn_norm_2) 79 | mean_dyn_norm_2 = np.mean(all_avg_dyn_norm_2) 80 | var_dyn_norm_2 = np.var(all_avg_dyn_norm_2) 81 | 82 | print("Mean of average norm 2: " + str(mean_norm_2)) 83 | print("Variance of average norm 2: " + str(var_norm_2)) 84 | print("Mean of average dynamics norm 2: " + str(mean_dyn_norm_2)) 85 | print("Variance of average dynamics norm 2: " + str(var_dyn_norm_2)) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser(description="compute latent map scale statistics") 90 | 91 | parser.add_argument("--env", required=True, type=str, help="environment to compute statistics") 92 | parser.add_argument("--setting_path", required=True, type=str, help="path to 10 trained models of a setting") 93 | parser.add_argument("--epoch", required=True, type=int, help="load model corresponding to this epoch") 94 | args = parser.parse_args() 95 | 96 | main(args) 97 | -------------------------------------------------------------------------------- /mdp/cartpole_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mdp.common import StateIndex, wrap 3 | from mdp.pole_base import PoleBase 4 | from PIL import Image, ImageDraw 5 | from scipy.integrate import solve_ivp 6 | 7 | 8 | class CartPoleMDP(PoleBase): 9 | goal_range = [-np.pi / 10, np.pi / 10] 10 | 11 | # state range 12 | angle_range = [-np.pi, np.pi] 13 | angular_velocity_range = [-2 * np.pi, 2 * np.pi] 14 | position_range = [-2.4, 2.4] 15 | velocity_range = [-6.0, 6.0] 16 | 17 | # sampling range 18 | angle_samp_range = 2 * goal_range 19 | 20 | # action range 21 | action_dim = 1 22 | action_range = np.array([-10.0, 10.0]) 23 | 24 | def __init__(self, width=80, height=80, frequency=50, noise=0.0, render_width=6): 25 | """ 26 | Args: 27 | width: width of the rendered image. 28 | height: height of the rendered image. 29 | frequency: the simulator frequency, i.e., the number of steps in 1 second. 30 | noise: noise level 31 | render_width: width of the pole in the rendered image. 32 | """ 33 | self.width = width 34 | self.height = height 35 | self.time_interval = 1 / frequency 36 | self.noise = noise 37 | 38 | self.render_width = render_width 39 | self.render_length = height * 0.65 40 | self.cart_render_size = (width / 10.0, height / 20.0) 41 | 42 | super(CartPoleMDP, self).__init__() 43 | 44 | def take_step(self, s, u): 45 | # clip the action 46 | u = np.clip(u, self.action_range[0], self.action_range[1]) 47 | 48 | # concatenate s and u to pass through ds_dt 49 | s_aug = np.append(s, u) 50 | 51 | # solve the differential equation to compute the next state 52 | s_next = solve_ivp(self.ds_dt, (0.0, self.time_interval), s_aug).y[0:4, -1] # last index is the action applied 53 |
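All the pole environments use the same integration idiom as `take_step` above: append the (constant) action to the state, integrate the augmented ODE for one control interval, and keep the state rows of the last column. A toy version with made-up names (`toy_ds_dt` is a 1-D point mass pushed by a force, not one of the repo's systems):

```python
import numpy as np
from scipy.integrate import solve_ivp

def toy_ds_dt(t, s_aug):
    x, x_dot, u = s_aug  # position, velocity, and the appended action
    return np.array([x_dot, u, 0.0])  # d(action)/dt = 0: constant input

def integrate_step(s, u, dt=1.0 / 50):
    s_aug = np.append(s, u)
    # .y has shape (dim, num_eval_points); take the state part of the last
    # column and drop the appended action.
    return solve_ivp(toy_ds_dt, (0.0, dt), s_aug).y[0:2, -1]

print(integrate_step(np.array([0.0, 0.0]), u=1.0))  # ~[0.0002, 0.02]
```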
54 | # project state to the valid space. 55 | s_next[StateIndex.THETA] = wrap(s_next[StateIndex.THETA], self.angle_range[0], self.angle_range[1]) 56 | s_next[StateIndex.THETA_DOT] = np.clip( 57 | s_next[StateIndex.THETA_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 58 | ) 59 | s_next[StateIndex.X_DOT] = np.clip(s_next[StateIndex.X_DOT], self.velocity_range[0], self.velocity_range[1]) 60 | 61 | return s_next 62 | 63 | def ds_dt(self, t, s_augmented): 64 | mass_combined = self.cart_mass + self.pend_mass 65 | 66 | theta = s_augmented[StateIndex.THETA] 67 | theta_dot = s_augmented[StateIndex.THETA_DOT] 68 | x_dot = s_augmented[StateIndex.X_DOT] 69 | force = s_augmented[StateIndex.CARTPOLE_ACTION] 70 | 71 | sin_theta = np.sin(theta) 72 | cos_theta = np.cos(theta) 73 | calc_help = force + self.pend_mass * self.length * theta_dot ** 2 * sin_theta 74 | 75 | # derivative of theta_dot 76 | theta_double_dot_num = mass_combined * self.earth_gravity * sin_theta - cos_theta * calc_help 77 | theta_double_dot_denom = 4.0 / 3 * mass_combined * self.length - self.pend_mass * self.length * cos_theta ** 2 78 | theta_double_dot = theta_double_dot_num / theta_double_dot_denom 79 | 80 | # derivative of x_dot 81 | x_double_dot_num = calc_help - self.pend_mass * self.length * theta_double_dot * cos_theta 82 | x_double_dot = x_double_dot_num / mass_combined 83 | 84 | return np.array([theta_dot, theta_double_dot, x_dot, x_double_dot, 0.0]) 85 | 86 | def render(self, s): 87 | # black background. 88 | im = Image.new("L", (self.width, self.height)) 89 | draw = ImageDraw.Draw(im) 90 | 91 | draw.rectangle((0, 0, self.width, self.height), fill=0) 92 | 93 | # cart location. 94 | x_center_image = im.size[0] / 2.0 95 | y_center_cart = im.size[1] - 2 * self.cart_render_size[1] - 2 96 | x_center_cart = x_center_image + (s[StateIndex.X] / self.position_range[1]) * ( 97 | self.width / 2.0 - 1.0 * self.cart_render_size[0] 98 | ) 99 | 100 | # pole location. 101 | x_pole_end = x_center_cart + np.sin([s[StateIndex.THETA]]) * self.render_length 102 | y_pole_end = y_center_cart - np.cos([s[StateIndex.THETA]]) * self.render_length 103 | 104 | # draw cart. 105 | draw.rectangle( 106 | ( 107 | x_center_cart - self.cart_render_size[0], 108 | y_center_cart - self.cart_render_size[1], 109 | x_center_cart + self.cart_render_size[0], 110 | y_center_cart + self.cart_render_size[1], 111 | ), 112 | fill=255, 113 | ) 114 | 115 | # draw pole.
116 | draw.line((x_center_cart, y_center_cart, x_pole_end, y_pole_end), width=self.render_width, fill=255) 117 | 118 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1) 119 | 120 | def is_fail(self, s): 121 | """check if the current state is failed""" 122 | angle = s[StateIndex.THETA] 123 | position = s[StateIndex.X] 124 | return not ( 125 | (self.goal_range[0] < angle < self.goal_range[1]) 126 | and (self.position_range[0] < position < self.position_range[1]) 127 | ) 128 | 129 | def sample_random_state(self): 130 | angle = np.random.uniform(self.angle_samp_range[0], self.angle_samp_range[1]) 131 | angle_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1]) 132 | pos = np.random.uniform(self.position_range[0], self.position_range[1]) 133 | vel = np.random.uniform(self.velocity_range[0], self.velocity_range[1]) 134 | true_state = np.array([angle, angle_rate, pos, vel]) 135 | return true_state 136 | -------------------------------------------------------------------------------- /mdp/common.py: -------------------------------------------------------------------------------- 1 | """Common functions for the MDP package.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | 6 | def wrap(x, low, high): 7 | """Wraps data between low and high boundaries.""" 8 | diff = high - low 9 | while x > high: 10 | x -= diff 11 | while x < low: 12 | x += diff 13 | return x 14 | 15 | 16 | class StateIndex(object): 17 | """Flexible way to index states in the CartPole Domain. 18 | 19 | This class enumerates the different indices used when indexing the state. 20 | e.g. ``s[StateIndex.THETA]`` is guaranteed to return the angle state. 21 | """ 22 | 23 | THETA, THETA_DOT = 0, 1 24 | X, X_DOT = 2, 3 25 | PEND_ACTION = 2 26 | CARTPOLE_ACTION = 4 27 | 28 | THETA_1, THETA_2, THETA_3 = 0, 2, 4 29 | THETA_1_DOT, THETA_2_DOT, THETA_3_DOT = 1, 3, 5 30 | TORQUE_3_1, TORQUE_3_2, TORQUE_3_3 = 6, 7, 8 31 | -------------------------------------------------------------------------------- /mdp/pendulum_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mdp.common import StateIndex, wrap 3 | from mdp.pole_base import PoleBase 4 | from PIL import Image, ImageDraw 5 | from scipy.integrate import solve_ivp 6 | 7 | 8 | class PendulumMDP(PoleBase): 9 | # goal range 10 | goal_range = [-np.pi / 6, np.pi / 6] 11 | 12 | # state range 13 | angle_range = [-np.pi, np.pi] 14 | angular_velocity_range = [-3 * np.pi, 3 * np.pi] 15 | 16 | action_dim = 1 17 | 18 | def __init__(self, width=48, height=48, frequency=50, noise=0.0, torque=1.0, render_width=4): 19 | """ 20 | Args: 21 | width: width of the rendered image. 22 | height: height of the rendered image. 23 | frequency: the simulator frequency, i.e., the number of steps in 1 second. 24 | noise: noise level 25 | torque: the maximal torque which can be applied 26 | render_width: width of the pendulum in the rendered image. 
27 | """ 28 | self.width = width 29 | self.height = height 30 | self.time_interval = 1.0 / frequency 31 | self.noise = noise 32 | self.action_range = np.array([-torque, torque]) 33 | 34 | self.render_width = render_width 35 | self.render_length = (width / 2) - 2 36 | 37 | super(PendulumMDP, self).__init__() 38 | 39 | def take_step(self, s, u): 40 | # clip the action 41 | u = np.clip(u, self.action_range[0], self.action_range[1]) 42 | 43 | # concatenate s and u to pass through ds_dt 44 | s_aug = np.append(s, u) 45 | 46 | # solve the differientable equation to compute next state 47 | s_next = solve_ivp(self.ds_dt, (0.0, self.time_interval), s_aug).y[0:2, -1] # last index is the action applied 48 | 49 | # project state to the valid range 50 | s_next[StateIndex.THETA] = wrap(s_next[StateIndex.THETA], self.angle_range[0], self.angle_range[1]) 51 | s_next[StateIndex.THETA_DOT] = np.clip( 52 | s_next[StateIndex.THETA_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 53 | ) 54 | 55 | return s_next 56 | 57 | def ds_dt(self, t, s_augmented): 58 | theta = s_augmented[StateIndex.THETA] 59 | theta_dot = s_augmented[StateIndex.THETA_DOT] 60 | torque = s_augmented[StateIndex.PEND_ACTION] 61 | 62 | # theta is w.r.t the upside vertical position, which is = pi - theta in tedrake's note 63 | sine = np.sin(np.pi - theta) 64 | theta_prime_num = self.pend_mass * self.earth_gravity * self.length * sine - torque 65 | theta_prime_denum = 1.0 / 3.0 * self.pend_mass * self.length ** 2 # the moment of inertia 66 | theta_double_dot = theta_prime_num / theta_prime_denum 67 | 68 | return np.array([theta_dot, theta_double_dot, 0.0]) 69 | 70 | def render(self, s): 71 | im = Image.new("L", (self.width, self.height)) 72 | draw = ImageDraw.Draw(im) 73 | # black background 74 | draw.rectangle((0, 0, self.width, self.height), fill=0) 75 | 76 | # pendulum location. 
70 | def render(self, s): 71 | im = Image.new("L", (self.width, self.height)) 72 | draw = ImageDraw.Draw(im) 73 | # black background 74 | draw.rectangle((0, 0, self.width, self.height), fill=0) 75 | 76 | # pendulum location. 77 | x_center = im.size[0] / 2.0 78 | y_center = im.size[1] / 2.0 79 | x_end = x_center + np.sin(s[0]) * self.render_length 80 | y_end = y_center - np.cos(s[0]) * self.render_length 81 | 82 | # white pendulum 83 | draw.line((x_center, y_center, x_end, y_end), width=self.render_width, fill=255) 84 | 85 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1) 86 | 87 | def is_fail(self, s): 88 | return False 89 | 90 | def sample_random_state(self): 91 | angle = np.random.uniform(self.angle_range[0], self.angle_range[1]) 92 | angle_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1]) 93 | true_state = np.array([angle, angle_rate]) 94 | return true_state 95 | -------------------------------------------------------------------------------- /mdp/plane_obstacles_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageDraw 3 | 4 | 5 | class PlanarObstaclesMDP(object): 6 | width = 40 7 | height = 40 8 | 9 | obstacles = np.array([[20.5, 5.5], [20.5, 13.5], [20.5, 27.5], [20.5, 35.5], [10.5, 20.5], [30.5, 20.5]]) 10 | obstacles_r = 2.5 # radius of the obstacles when rendering 11 | half_agent_size = 1.5 # robot half-width 12 | 13 | position_range = np.array([half_agent_size, width - half_agent_size]) 14 | 15 | action_dim = 2 16 | 17 | def __init__(self, rw_rendered=1, max_step=3, goal=[37, 37], goal_thres=2, noise=0): 18 | self.rw_rendered = rw_rendered 19 | self.max_step = max_step 20 | self.action_range = np.array([-max_step, max_step]) 21 | self.goal = goal 22 | self.goal_thres = goal_thres 23 | self.noise = noise 24 | super(PlanarObstaclesMDP, self).__init__() 25 | 26 | def is_valid_state(self, s): 27 | # check if the agent runs out of the map 28 | if np.any(s < self.position_range[0]) or np.any(s > self.position_range[1]): 29 | return False 30 | 31 | # check if the agent crosses any obstacle (the obstacle is inside the agent) 32 | top, bot = s[0] - self.half_agent_size, s[0] + self.half_agent_size 33 | left, right = s[1] - self.half_agent_size, s[1] + self.half_agent_size 34 | for obs in self.obstacles: 35 | if top <= obs[0] <= bot and left <= obs[1] <= right: 36 | return False 37 | return True 38 | 39 | def take_step(self, s, u, anneal_ratio=0.9): # compute the next state given the current state and action 40 | u = np.clip(u, self.action_range[0], self.action_range[1]) 41 | 42 | s_next = np.clip(s + u, self.position_range[0], self.position_range[1]) 43 | if not self.is_valid_state(s_next): 44 | return s 45 | return s_next 46 | 47 | def transition_function(self, s, u): # compute next state and add noise 48 | s_next = self.take_step(s, u) 49 | # add sampled noise and clip back to the valid position range 50 | sample_noise = self.noise * np.random.randn(*s_next.shape) 51 | return np.clip(s_next + sample_noise, self.position_range[0], self.position_range[1]) 52 | 53 | def render(self, s): 54 | top, bottom, left, right = self.get_pixel_location(s) 55 | x = self.generate_env() 56 | x[top:bottom, left:right] = 1.0 # robot is white on black background 57 | return x 58 | 59 | def get_pixel_location(self, s): 60 | # return the location of agent when rendered 61 | center_x, center_y = int(round(s[0])), int(round(s[1])) 62 | top = center_x - self.rw_rendered 63 | bottom = center_x + self.rw_rendered 64 | left = center_y - self.rw_rendered 65 | right = center_y + self.rw_rendered 66 | return top, bottom, left, right 67 |
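A short round-trip through the `PlanarObstaclesMDP` API (illustrative usage; assumes it is run from the repo root so the `mdp` package is importable):

```python
from mdp.plane_obstacles_mdp import PlanarObstaclesMDP

mdp = PlanarObstaclesMDP(noise=0)
s = mdp.sample_random_state()          # a valid agent position, shape (2,)
u = mdp.sample_valid_random_action(s)  # an action keeping the agent valid
s_next = mdp.transition_function(s, u)
x = mdp.render(s)                      # (40, 40) image array in [0, 1]
print(s, u, s_next, x.shape)
```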
68 | def generate_env(self): 69 | """ 70 | return the image with 6 obstacles 71 | """ 72 | img_arr = np.zeros(shape=(self.width, self.height)) 73 | 74 | img_env = Image.fromarray(img_arr) 75 | draw = ImageDraw.Draw(img_env) 76 | for y, x in self.obstacles: 77 | draw.ellipse( 78 | ( 79 | int(x) - int(self.obstacles_r), 80 | int(y) - int(self.obstacles_r), 81 | int(x) + int(self.obstacles_r), 82 | int(y) + int(self.obstacles_r), 83 | ), 84 | fill=255, 85 | ) 86 | img_env = img_env.convert("L") 87 | 88 | img_arr = np.array(img_env) / 255.0 89 | return img_arr 90 | 91 | def is_goal(self, s): 92 | return np.sqrt(np.sum((s - self.goal) ** 2)) <= self.goal_thres 93 | 94 | def is_fail(self, s): 95 | return False 96 | 97 | def reward_function(self, s): 98 | if self.is_goal(s): 99 | reward = 1 100 | else: 101 | reward = 0 102 | return reward 103 | 104 | def sample_random_state(self): 105 | while True: 106 | s = np.random.uniform(self.half_agent_size, self.width - self.half_agent_size, size=2) 107 | if self.is_valid_state(s): 108 | return s 109 | 110 | def is_low_error(self, u, epsilon=0.1): 111 | rounded_u = np.round(u) 112 | diff = np.abs(u - rounded_u) 113 | return np.all(diff <= epsilon) 114 | 115 | def is_valid_action(self, s, u): 116 | return self.is_low_error(u) and self.is_valid_state(s + u) 117 | 118 | def sample_valid_random_action(self, s): 119 | while True: 120 | u = np.random.uniform(self.action_range[0], self.action_range[1], size=self.action_dim) 121 | if self.is_valid_action(s, u): 122 | return u 123 | 124 | def sample_random_action(self): 125 | return np.random.uniform(self.action_range[0], self.action_range[1], size=self.action_dim) 126 | 127 | def sample_extreme_action(self): 128 | x_direction = np.random.choice([self.action_range[0], self.action_range[1]]) 129 | y_direction = np.random.choice([self.action_range[0], self.action_range[1]]) 130 | return np.array([x_direction, y_direction]) 131 | -------------------------------------------------------------------------------- /mdp/pole_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from mdp.common import StateIndex 6 | from numpy import pi 7 | 8 | 9 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent) 10 | os.sys.path.append(root_path) 11 | 12 | 13 | class PoleBase(object): 14 | """ 15 | The base class to store common attributes for pendulum and cartpole 16 | Basically, the MDP works as follows: 17 | - Define a function ds_dt, which computes the derivatives of state w.r.t. time t (from Tedrake's textbook) 18 | - The time interval between 2 consecutive time steps is determined by frequency, time_interval = 1./frequency 19 | - Use solve_ivp package to solve the differential equation and compute the next state 20 | """ 21 | 22 | # environment specifications 23 | earth_gravity = 9.81 24 | pend_mass = 0.1 25 | cart_mass = 1.0 26 | length = 0.5 27 | # reward if close to goal 28 | goal_reward = 1 29 | 30 | def __init__(self): 31 | assert np.all( 32 | 2 * pi / self.time_interval > np.abs(self.angular_velocity_range) 33 | ), """ 34 | WARNING: Your step size is too large or the angular rate limit is too large. 35 | This could lead to a situation in which the pole is at the same state in 2 36 | consecutive time steps (the pole finishes a round).
37 | """ 38 | 39 | def take_step(self, s, u): # compute the next state given the current state and action 40 | pass 41 | 42 | def ds_dt(self, t, s): # derivative of s w.r.t t 43 | pass 44 | 45 | def transition_function(self, s, u): # compute next state and add noise 46 | s_next = self.take_step(s, u) 47 | # add noise 48 | s_next += self.noise * np.random.randn(*s_next.shape) 49 | return s_next 50 | 51 | def render(self, s): 52 | pass 53 | 54 | def is_goal(self, s): 55 | """Check if the pendulum is in goal region""" 56 | angle = s[StateIndex.THETA] 57 | return self.goal_range[0] <= angle <= self.goal_range[1] 58 | 59 | def is_fail(self, s): 60 | pass 61 | 62 | def reward_function(self, s): 63 | """Reward function.""" 64 | return int(self.is_goal(s)) * self.goal_reward 65 | 66 | def sample_random_state(self): 67 | pass 68 | 69 | def sample_random_action(self): 70 | """Sample a random action from action range.""" 71 | return np.array([np.random.uniform(self.action_range[0], self.action_range[1])]) 72 | 73 | def sample_extreme_action(self): 74 | """Sample a random extreme action from action range.""" 75 | return np.array([np.random.choice([self.action_range[0], self.action_range[1]])]) 76 | -------------------------------------------------------------------------------- /mdp/three_pole_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mdp.common import StateIndex, wrap 3 | from mdp.pole_base import PoleBase 4 | from PIL import Image, ImageDraw 5 | from scipy.integrate import solve_ivp 6 | 7 | 8 | class ThreePoleMDP(PoleBase): 9 | # goal_range 10 | goal_range = [-np.pi / 6, np.pi / 6] 11 | 12 | # state range 13 | angle_1_range = [-np.pi, np.pi] 14 | angle_2_range = [-2.0 * np.pi / 3, 2.0 * np.pi / 3] 15 | angle_3_range = [-np.pi / 3.0, np.pi / 3.0] 16 | angular_velocity_range = [-0.5 * np.pi, 0.5 * np.pi] 17 | 18 | mass_pend_1 = 0.1 19 | mass_pend_2 = 0.1 20 | mass_pend_3 = 0.1 21 | length_1 = 0.5 22 | length_2 = 0.5 23 | length_3 = 0.5 24 | 25 | action_dim = 3 26 | 27 | def __init__(self, width=80, height=80, frequency=50, noise=0.0, torque=1.0, line_width=5): 28 | """ 29 | Args: 30 | width: width of the rendered image. 31 | height: height of the rendered image. 32 | frequency: the simulator frequency, i.e., the number of steps in 1 second. 33 | noise: noise level 34 | torque: the maximal torque which can be applied 35 | render_width: width of the line in the rendered image. 36 | """ 37 | self.width = width 38 | self.height = height 39 | self.time_interval = 1.0 / frequency 40 | self.noise = noise 41 | self.action_range = np.array([-torque, torque]) 42 | 43 | self.line_width = line_width 44 | self.visual_length = (width / 6) - 2 45 | 46 | super(ThreePoleMDP, self).__init__() 47 | 48 | def take_step(self, s, a): 49 | """Computes the next state from the current state and the action.""" 50 | torque_1_action = a[0] 51 | # Clip the action to valid values. 52 | torque_1_action = np.clip(torque_1_action, self.action_range[0], self.action_range[1]) 53 | 54 | torque_2_action = a[1] 55 | # Clip the action to valid values. 56 | torque_2_action = np.clip(torque_2_action, self.action_range[0], self.action_range[1]) 57 | 58 | torque_3_action = a[2] 59 | # Clip the action to valid values. 60 | torque_3_action = np.clip(torque_3_action, self.action_range[0], self.action_range[1]) 61 | 62 | # Add the action to the state so it can be passed to _dsdt. 
63 | s_aug = np.append(s, np.array([torque_1_action, torque_2_action, torque_3_action])) 64 | 65 | # Compute next state. 66 | # Type of integration and integration step. 67 | dt_in = self.time_interval 68 | if ( 69 | self.goal_range[0] < s[StateIndex.THETA_1] < self.goal_range[1] 70 | and self.goal_range[0] < s[StateIndex.THETA_1] + s[StateIndex.THETA_2] < self.goal_range[1] 71 | and self.goal_range[0] 72 | < s[StateIndex.THETA_1] + s[StateIndex.THETA_2] + s[StateIndex.THETA_3] 73 | < self.goal_range[1] 74 | ): 75 | dt_in = self.time_interval 76 | else: 77 | dt_in = self.time_interval * 2.5 78 | 79 | ns = solve_ivp(self.ds_dt, (0.0, dt_in), s_aug).y[0:6, -1] 80 | 81 | # Project variables to valid space. 82 | theta_1 = wrap(ns[StateIndex.THETA_1], self.angle_1_range[0], self.angle_1_range[1]) 83 | ns[StateIndex.THETA_1] = np.clip(theta_1, self.angle_1_range[0], self.angle_1_range[1]) 84 | ns[StateIndex.THETA_1_DOT] = np.clip( 85 | ns[StateIndex.THETA_1_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 86 | ) 87 | 88 | theta_2 = wrap(ns[StateIndex.THETA_2], self.angle_2_range[0], self.angle_2_range[1]) 89 | ns[StateIndex.THETA_2] = np.clip(theta_2, self.angle_2_range[0], self.angle_2_range[1]) 90 | ns[StateIndex.THETA_2_DOT] = np.clip( 91 | ns[StateIndex.THETA_2_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 92 | ) 93 | 94 | theta_3 = wrap(ns[StateIndex.THETA_3], self.angle_3_range[0], self.angle_3_range[1]) 95 | ns[StateIndex.THETA_3] = np.clip(theta_3, self.angle_3_range[0], self.angle_3_range[1]) 96 | ns[StateIndex.THETA_3_DOT] = np.clip( 97 | ns[StateIndex.THETA_3_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 98 | ) 99 | 100 | return ns 101 | 102 | def ds_dt(self, t, s_augmented): 103 | """Calculates derivatives at a given state.""" 104 | # Unused. 105 | del t 106 | 107 | # Extracting current state and action. 108 | theta_1 = s_augmented[StateIndex.THETA_1] 109 | theta_1_dot = s_augmented[StateIndex.THETA_1_DOT] 110 | theta_2 = s_augmented[StateIndex.THETA_2] 111 | theta_2_dot = s_augmented[StateIndex.THETA_2_DOT] 112 | theta_3 = s_augmented[StateIndex.THETA_3] 113 | theta_3_dot = s_augmented[StateIndex.THETA_3_DOT] 114 | 115 | theta_dot = np.array([theta_1_dot, theta_2_dot, theta_3_dot]) 116 | 117 | torque_1 = s_augmented[StateIndex.TORQUE_3_1] 118 | torque_2 = s_augmented[StateIndex.TORQUE_3_2] 119 | torque_3 = s_augmented[StateIndex.TORQUE_3_3] 120 | torque = np.array([torque_1, torque_2, torque_3]) 121 | 122 | # Useful mid-calculation. 
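The quantities assembled in the remainder of `ds_dt` are the standard manipulator-dynamics matrices; for reference, the relation being solved is (a sketch of the textbook form, not a line-by-line derivation of this code, whose signs differ because its angles are measured clockwise):

```latex
H(\theta)\,\ddot{\theta} + C(\theta, \dot{\theta})\,\dot{\theta} + G(\theta) = B\tau
\quad\Longrightarrow\quad
\ddot{\theta} = H(\theta)^{-1}\bigl(B\tau - C(\theta, \dot{\theta})\,\dot{\theta} - G(\theta)\bigr)
```

The `pinv` call at the end of `ds_dt` applies exactly this inversion, with a small `1e-6` Tikhonov term added to `h_mat` to guard against ill-conditioned configurations.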
123 | # NOTE: the angle here is clock-wise 124 | # which is -\theta from tedrake's reference 125 | 126 | sine_1 = np.sin(-theta_1) 127 | sine_2 = np.sin(-theta_2) 128 | sine_3 = np.sin(-theta_3) 129 | sine_2_3 = np.sin(-(theta_2 + theta_3)) 130 | 131 | # cosine_1 = np.cos(np.pi - theta_1) 132 | cosine_2 = np.cos(-theta_2) 133 | cosine_3 = np.cos(-theta_3) 134 | cosine_2_3 = np.cos(-(theta_2 + theta_3)) 135 | 136 | sine_1_2 = np.sin(-(theta_1 + theta_2)) 137 | # cosine_1_2 = np.cos(np.pi - (theta_1 + theta_2)) 138 | sine_1_2_3 = np.sin(-(theta_1 + theta_2 + theta_3)) 139 | 140 | i_1 = 1.0 / 3.0 * self.mass_pend_1 * self.length_1 ** 2 141 | i_2 = 1.0 / 3.0 * self.mass_pend_2 * self.length_2 ** 2 142 | i_3 = 1.0 / 3.0 * self.mass_pend_3 * self.length_3 ** 2 143 | 144 | length_c1 = self.length_1 / 2.0 145 | length_c2 = self.length_2 / 2.0 146 | length_c3 = self.length_3 / 2.0 147 | 148 | # point mass version, not a rod, so no inertia and no center-of-mass 149 | alpha_1 = i_1 + (self.mass_pend_2 + self.mass_pend_3) * self.length_1 ** 2 150 | alpha_2 = i_2 + self.mass_pend_3 * self.length_2 ** 2 151 | alpha_3 = (self.mass_pend_2 * length_c2 + self.mass_pend_3 * self.length_2) * self.length_1 152 | alpha_4 = i_3 153 | alpha_5 = self.mass_pend_3 * self.length_1 * length_c3 154 | alpha_6 = self.mass_pend_3 * self.length_2 * length_c3 155 | 156 | h_11 = alpha_1 + alpha_2 + alpha_4 + 2 * alpha_5 * cosine_2_3 + 2 * alpha_3 * cosine_2 + 2 * alpha_6 * cosine_3 157 | h_12 = alpha_2 + alpha_4 + alpha_3 * cosine_2 + alpha_5 * cosine_2_3 + 2 * alpha_6 * cosine_3 158 | h_13 = alpha_4 + alpha_5 * cosine_2_3 + alpha_6 * cosine_3 159 | 160 | h_21 = h_12 161 | h_22 = alpha_2 + alpha_4 + 2 * alpha_6 * cosine_3 162 | h_23 = alpha_4 + alpha_6 * cosine_3 163 | 164 | h_31 = h_13 165 | h_32 = h_23 166 | h_33 = alpha_4 167 | h_mat = np.array([[h_11, h_12, h_13], [h_21, h_22, h_23], [h_31, h_32, h_33]]) 168 | 169 | beta_1 = ( 170 | self.mass_pend_1 * length_c1 + self.mass_pend_2 * self.length_1 + self.mass_pend_3 * self.length_1 171 | ) * self.earth_gravity 172 | beta_2 = (self.mass_pend_2 * length_c2 + self.mass_pend_3 * self.length_2) * self.earth_gravity 173 | beta_3 = self.mass_pend_3 * self.earth_gravity * length_c3 174 | 175 | c_11 = ( 176 | alpha_5 * (theta_2_dot + theta_3_dot) * sine_2_3 177 | + alpha_3 * theta_2_dot * sine_2 178 | + alpha_6 * theta_3_dot * sine_3 179 | ) 180 | c_12 = ( 181 | alpha_5 * (theta_1_dot + theta_2_dot + theta_3_dot) * sine_2_3 182 | + alpha_3 * (theta_1_dot + theta_2_dot) * sine_2 183 | + alpha_6 * theta_3_dot * sine_3 184 | ) 185 | c_13 = (theta_1_dot + theta_2_dot + theta_3_dot) * (alpha_5 * sine_2_3 + alpha_6 * sine_3) 186 | c_21 = -alpha_5 * theta_1_dot * sine_2_3 - alpha_3 * theta_1_dot * sine_2 + alpha_6 * theta_3_dot * sine_3 187 | c_22 = alpha_6 * theta_3_dot * sine_3 188 | c_23 = alpha_6 * (theta_1_dot + theta_2_dot + theta_3_dot) * sine_3 189 | c_31 = -alpha_5 * theta_1_dot * sine_2_3 - alpha_6 * (theta_1_dot + theta_2_dot) * sine_3 190 | c_32 = -alpha_6 * (theta_1_dot + theta_2_dot) * sine_3 191 | c_33 = 0.0 192 | c_mat = np.array([[c_11, c_12, c_13], [c_21, c_22, c_23], [c_31, c_32, c_33]]) 193 | 194 | g_1 = -beta_1 * sine_1 - beta_2 * sine_1_2 - beta_3 * sine_1_2_3 195 | g_2 = -beta_2 * sine_1_2 - beta_3 * sine_1_2_3 196 | g_3 = -beta_3 * sine_1_2_3 197 | g_mat = np.array([g_1, g_2, g_3]) 198 | 199 | b_mat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 200 | 201 | theta_double_dot = -np.linalg.pinv(h_mat + 1e-6 * np.eye(len(h_mat))).dot( 202 | 
b_mat.dot(torque) - c_mat.dot(-theta_dot) - g_mat 203 | ) 204 | 205 | theta_1_double_dot = theta_double_dot[0] 206 | theta_2_double_dot = theta_double_dot[1] 207 | theta_3_double_dot = theta_double_dot[2] 208 | 209 | # Derivatives. 210 | return np.array( 211 | [ 212 | theta_1_dot, 213 | theta_1_double_dot, 214 | theta_2_dot, 215 | theta_2_double_dot, 216 | theta_3_dot, 217 | theta_3_double_dot, 218 | 0.0, 219 | 0.0, 220 | 0.0, 221 | ] 222 | ) 223 | 224 | def render(self, s): 225 | im = Image.new("L", (self.width, self.height)) 226 | draw = ImageDraw.Draw(im) 227 | # Draw background. 228 | draw.rectangle((0, 0, self.width, self.height), fill=0) 229 | 230 | # Pole 1 location. 231 | xstart_1 = im.size[0] / 2.0 232 | ystart_1 = im.size[1] / 2.0 233 | xend_1 = xstart_1 + np.sin(s[0]) * self.visual_length 234 | yend_1 = ystart_1 - np.cos(s[0]) * self.visual_length 235 | 236 | # Draw pole 1. 237 | draw.line((xstart_1, ystart_1, xend_1, yend_1), width=self.line_width, fill=255) 238 | 239 | # Pole 2 location. 240 | xstart_2 = xend_1 241 | ystart_2 = yend_1 242 | xend_2 = xstart_2 + np.sin(s[0] + s[2]) * self.visual_length 243 | yend_2 = ystart_2 - np.cos(s[0] + s[2]) * self.visual_length 244 | 245 | # Draw pole 2. 246 | draw.line((xstart_2, ystart_2, xend_2, yend_2), width=self.line_width, fill=255) 247 | 248 | # Pole 3 location. 249 | xstart_3 = xend_2 250 | ystart_3 = yend_2 251 | xend_3 = xstart_3 + np.sin(s[0] + s[2] + s[4]) * self.visual_length 252 | yend_3 = ystart_3 - np.cos(s[0] + s[2] + s[4]) * self.visual_length 253 | 254 | # Draw pole 3. 255 | draw.line((xstart_3, ystart_3, xend_3, yend_3), width=self.line_width, fill=255) 256 | 257 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1) 258 |
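The three `draw.line` calls in `render` above chain each pole from the previous endpoint using cumulative joint angles. The same forward kinematics in a generic, self-contained form (an illustrative helper, not repo code):

```python
import numpy as np

def link_endpoints(joint_angles, origin, seg_len):
    # Each link starts where the previous one ended; the k-th link's heading
    # is the cumulative sum of the first k joint angles.
    points = [np.asarray(origin, dtype=float)]
    heading = 0.0
    for a in joint_angles:
        heading += a
        x, y = points[-1]
        points.append(np.array([x + np.sin(heading) * seg_len,
                                y - np.cos(heading) * seg_len]))
    return points  # consecutive pairs are the segments render() draws

pts = link_endpoints([0.1, 0.2, -0.3], origin=(40.0, 40.0), seg_len=11.0)
```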
259 | def is_fail(self, s): 260 | """Indicates whether the state results in failure.""" 261 | # Unused. 262 | del s 263 | return False 264 | 265 | def is_goal(self, s): 266 | """Indicates whether the state achieves the goal.""" 267 | angle_1 = s[StateIndex.THETA_1] 268 | angle_2 = s[StateIndex.THETA_2] 269 | angle_3 = s[StateIndex.THETA_3] 270 | if ( 271 | self.goal_range[0] < angle_1 < self.goal_range[1] 272 | and self.goal_range[0] < angle_1 + angle_2 < self.goal_range[1] 273 | and self.goal_range[0] < angle_1 + angle_2 + angle_3 < self.goal_range[1] 274 | ): 275 | return True 276 | else: 277 | return False 278 | 279 | def sample_random_action(self): 280 | """Sample a random action from available force.""" 281 | return np.atleast_1d(np.random.uniform(self.action_range[0], self.action_range[1], self.action_dim)) 282 | 283 | def sample_extreme_action(self): 284 | """Sample a random extreme action from available force.""" 285 | return np.atleast_1d(np.random.choice([self.action_range[0], self.action_range[1]], self.action_dim)) 286 | 287 | def sample_random_state(self): 288 | """Sample a random state.""" 289 | angle_1 = np.random.uniform(self.angle_1_range[0], self.angle_1_range[1]) 290 | angle_1_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1]) 291 | angle_2 = np.random.uniform(self.angle_2_range[0], self.angle_2_range[1]) 292 | angle_2_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1]) 293 | angle_3 = np.random.uniform(self.angle_3_range[0], self.angle_3_range[1]) 294 | angle_3_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1]) 295 | true_state = np.array([angle_1, angle_1_rate, angle_2, angle_2_rate, angle_3, angle_3_rate]) 296 | return true_state 297 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions.independent import Independent 4 | from torch.distributions.normal import Normal 5 | 6 | 7 | torch.set_default_dtype(torch.float64) 8 | 9 | 10 | def MultivariateNormalDiag(loc, scale_diag): 11 | if loc.dim() < 1: 12 | raise ValueError("loc must be at least one-dimensional.") 13 | return Independent(Normal(loc, scale_diag), 1) 14 | 15 | 16 | class Encoder(nn.Module): 17 | # deterministic encoder q(z | x) 18 | def __init__(self, net, x_dim, z_dim): 19 | super(Encoder, self).__init__() 20 | self.net = net 21 | self.x_dim = x_dim 22 | self.z_dim = z_dim 23 | 24 | def forward(self, x): 25 | return self.net(x) 26 | 27 | 28 | class Dynamics(nn.Module): 29 | # stochastic transition model: P(z^_t+1 | z_t, u_t) 30 | def __init__(self, net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized): 31 | super(Dynamics, self).__init__() 32 | self.net_hidden = net_hidden 33 | self.net_mean = net_mean 34 | self.net_logstd = net_logstd 35 | self.net_A = net_A 36 | self.net_B = net_B 37 | self.z_dim = z_dim 38 | self.u_dim = u_dim 39 | self.armotized = armotized 40 | 41 | def forward(self, z_t, u_t): 42 | z_u_t = torch.cat((z_t, u_t), dim=-1) 43 | hidden_neurons = self.net_hidden(z_u_t) 44 | mean = self.net_mean(hidden_neurons) + z_t # skip connection 45 | logstd = self.net_logstd(hidden_neurons) 46 | if self.armotized: 47 | A = self.net_A(hidden_neurons) 48 | B = self.net_B(hidden_neurons) 49 | else: 50 | A, B = None, None 51 | return MultivariateNormalDiag(mean, torch.exp(logstd)), A, B 52 | 53 | 54 | class PlanarEncoder(Encoder): 55 | def __init__(self, x_dim=1600, z_dim=2): 56 | net =
nn.Sequential(nn.Linear(x_dim, 300), nn.ReLU(), nn.Linear(300, 300), nn.ReLU(), nn.Linear(300, z_dim)) 57 | super(PlanarEncoder, self).__init__(net, x_dim, z_dim) 58 | 59 | 60 | class PlanarDynamics(Dynamics): 61 | def __init__(self, armotized, z_dim=2, u_dim=2): 62 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 20), nn.ReLU(), nn.Linear(20, 20), nn.ReLU()) 63 | net_mean = nn.Linear(20, z_dim) 64 | net_logstd = nn.Linear(20, z_dim) 65 | if armotized: 66 | net_A = nn.Linear(20, z_dim ** 2) 67 | net_B = nn.Linear(20, u_dim * z_dim) 68 | else: 69 | net_A, net_B = None, None 70 | super(PlanarDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 71 | 72 | 73 | class PendulumEncoder(Encoder): 74 | def __init__(self, x_dim=4608, z_dim=3): 75 | net = nn.Sequential(nn.Linear(x_dim, 500), nn.ReLU(), nn.Linear(500, 500), nn.ReLU(), nn.Linear(500, z_dim)) 76 | super(PendulumEncoder, self).__init__(net, x_dim, z_dim) 77 | 78 | 79 | class PendulumDynamics(Dynamics): 80 | def __init__(self, armotized, z_dim=3, u_dim=1): 81 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 30), nn.ReLU(), nn.Linear(30, 30), nn.ReLU()) 82 | net_mean = nn.Linear(30, z_dim) 83 | net_logstd = nn.Linear(30, z_dim) 84 | if armotized: 85 | net_A = nn.Linear(30, z_dim * z_dim) 86 | net_B = nn.Linear(30, u_dim * z_dim) 87 | else: 88 | net_A, net_B = None, None 89 | super(PendulumDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 90 | 91 | 92 | class Flatten(nn.Module): 93 | def __init__(self): 94 | super(Flatten, self).__init__() 95 | 96 | def forward(self, x): 97 | return x.view(x.size(0), -1) 98 | 99 | 100 | class View(nn.Module): 101 | def __init__(self, shape): 102 | super(View, self).__init__() 103 | self.shape = shape 104 | 105 | def forward(self, x): 106 | return x.view(*self.shape) 107 | 108 | 109 | class CartPoleEncoder(Encoder): 110 | def __init__(self, x_dim=(2, 80, 80), z_dim=8): 111 | x_channels = x_dim[0] 112 | net = nn.Sequential( 113 | nn.Conv2d(in_channels=x_channels, out_channels=32, kernel_size=5, stride=1, padding=2), 114 | nn.ReLU(), 115 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 116 | nn.ReLU(), 117 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 118 | nn.ReLU(), 119 | nn.Conv2d(in_channels=32, out_channels=10, kernel_size=5, stride=2, padding=2), 120 | nn.ReLU(), 121 | Flatten(), 122 | nn.Linear(10 * 10 * 10, 200), 123 | nn.ReLU(), 124 | nn.Linear(200, z_dim), 125 | ) 126 | super(CartPoleEncoder, self).__init__(net, x_dim, z_dim) 127 | 128 | 129 | class CartPoleDynamics(Dynamics): 130 | def __init__(self, armotized, z_dim=8, u_dim=1): 131 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 40), nn.ReLU(), nn.Linear(40, 40), nn.ReLU()) 132 | net_mean = nn.Linear(40, z_dim) 133 | net_logstd = nn.Linear(40, z_dim) 134 | if armotized: 135 | net_A = nn.Linear(40, z_dim * z_dim) 136 | net_B = nn.Linear(40, u_dim * z_dim) 137 | else: 138 | net_A, net_B = None, None 139 | super(CartPoleDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 140 | 141 | 142 | class ThreePoleEncoder(Encoder): 143 | def __init__(self, x_dim=(2, 80, 80), z_dim=8): 144 | x_channels = x_dim[0] 145 | net = nn.Sequential( 146 | nn.Conv2d(in_channels=x_channels, out_channels=32, kernel_size=5, stride=1, padding=2), 147 | nn.ReLU(), 148 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 149 | 
nn.ReLU(), 150 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 151 | nn.ReLU(), 152 | nn.Conv2d(in_channels=32, out_channels=10, kernel_size=5, stride=2, padding=2), 153 | nn.ReLU(), 154 | Flatten(), 155 | nn.Linear(10 * 10 * 10, 200), 156 | nn.ReLU(), 157 | nn.Linear(200, z_dim), 158 | ) 159 | super(ThreePoleEncoder, self).__init__(net, x_dim, z_dim) 160 | 161 | 162 | class ThreePoleDynamics(Dynamics): 163 | def __init__(self, armotized, z_dim=8, u_dim=3): 164 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 40), nn.ReLU(), nn.Linear(40, 40), nn.ReLU()) 165 | net_mean = nn.Linear(40, z_dim) 166 | net_logstd = nn.Linear(40, z_dim) 167 | if armotized: 168 | net_A = nn.Linear(40, z_dim * z_dim) 169 | net_B = nn.Linear(40, u_dim * z_dim) 170 | else: 171 | net_A, net_B = None, None 172 | super(ThreePoleDynamics, self).__init__( 173 | net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized 174 | ) 175 | 176 | 177 | CONFIG = { 178 | "planar": (PlanarEncoder, PlanarDynamics), 179 | "pendulum": (PendulumEncoder, PendulumDynamics), 180 | "pendulum_gym": (PendulumEncoder, PendulumDynamics), 181 | "cartpole": (CartPoleEncoder, CartPoleDynamics), 182 | "threepole": (ThreePoleEncoder, ThreePoleDynamics), 183 | } 184 | 185 | 186 | def load_config(name): 187 | return CONFIG[name] 188 | 189 | 190 | __all__ = ["load_config"] 191 | -------------------------------------------------------------------------------- /pc3.yml: -------------------------------------------------------------------------------- 1 | name: pc3 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _tflow_select=2.1.0=gpu 10 | - absl-py=0.7.1=py36_0 11 | - asn1crypto=0.24.0=py36_0 12 | - astor=0.8.0=py36_0 13 | - attrs=19.3.0=py_0 14 | - backcall=0.1.0=py_0 15 | - blas=1.0=mkl 16 | - bleach=3.1.0=py_0 17 | - c-ares=1.15.0=h7b6447c_1001 18 | - ca-certificates=2019.9.11=hecc5488_0 19 | - certifi=2019.9.11=py36_0 20 | - cffi=1.12.3=py36h2e261b9_0 21 | - chardet=3.0.4=py36_1003 22 | - colour=0.1.5=py_0 23 | - cryptography=2.7=py36h1ba5d50_0 24 | - cudatoolkit=10.0.130=0 25 | - cudnn=7.6.0=cuda10.0_0 26 | - cupti=10.0.130=0 27 | - cycler=0.10.0=py_1 28 | - dbus=1.13.6=h746ee38_0 29 | - decorator=4.4.0=py_0 30 | - defusedxml=0.6.0=py_0 31 | - entrypoints=0.3=py36_1000 32 | - expat=2.2.6=he6710b0_0 33 | - fontconfig=2.13.1=he4413a7_1000 34 | - freetype=2.9.1=h8a8886c_1 35 | - gast=0.2.2=py36_0 36 | - gettext=0.19.8.1=hc5be6a0_1002 37 | - glib=2.56.2=had28632_1001 38 | - google-pasta=0.1.7=py_0 39 | - grpcio=1.16.1=py36hf8bcb03_1 40 | - gst-plugins-base=1.14.0=hbbd80ab_1 41 | - gstreamer=1.14.0=hb453b48_1 42 | - h5py=2.9.0=py36h7918eee_0 43 | - hdf5=1.10.4=hb1b8bf9_0 44 | - icu=58.2=hf484d3e_1000 45 | - idna=2.8=py36_0 46 | - imageio=2.6.1=py36_0 47 | - importlib_metadata=0.23=py36_0 48 | - intel-openmp=2019.4=243 49 | - ipykernel=5.1.3=py36h5ca1d4c_0 50 | - ipython=7.8.0=py36h5ca1d4c_0 51 | - ipython_genutils=0.2.0=py_1 52 | - jedi=0.15.1=py36_0 53 | - jinja2=2.10.3=py_0 54 | - jpeg=9b=h024ee3a_2 55 | - jsonschema=3.1.1=py36_0 56 | - jupyter_client=5.3.3=py36_1 57 | - jupyter_core=4.4.0=py_0 58 | - keras-applications=1.0.8=py_0 59 | - keras-preprocessing=1.1.0=py_1 60 | - kiwisolver=1.1.0=py36hc9558a2_0 61 | - libedit=3.1.20181209=hc058e9b_0 62 | - libffi=3.2.1=hd88cf55_4 63 | - libgcc-ng=9.1.0=hdf63c60_0 64 | - libgfortran-ng=7.3.0=hdf63c60_0 65 | - libiconv=1.15=h516909a_1005 66 | - libpng=1.6.37=hbc83047_0 67 | - 
libprotobuf=3.8.0=hd408876_0 68 | - libsodium=1.0.17=h516909a_0 69 | - libstdcxx-ng=9.1.0=hdf63c60_0 70 | - libtiff=4.0.10=h2733197_2 71 | - libuuid=2.32.1=h14c3975_1000 72 | - libxcb=1.13=h14c3975_1002 73 | - libxml2=2.9.9=h13577e0_2 74 | - markdown=3.1.1=py36_0 75 | - markupsafe=1.1.1=py36h14c3975_0 76 | - matplotlib=3.1.1=py36h5429711_0 77 | - mistune=0.8.4=py36h14c3975_1000 78 | - mkl=2019.4=243 79 | - mkl-service=2.3.0=py36he904b0f_0 80 | - mkl_fft=1.0.14=py36ha843d7b_0 81 | - mkl_random=1.0.2=py36hd81dba3_0 82 | - more-itertools=7.2.0=py_0 83 | - nbconvert=5.6.0=py36_1 84 | - nbformat=4.4.0=py_1 85 | - ncurses=6.1=he6710b0_1 86 | - ninja=1.9.0=py36hfd86e86_0 87 | - notebook=6.0.1=py36_0 88 | - numpy=1.16.5=py36h7e9f1db_0 89 | - numpy-base=1.16.5=py36hde5b4d6_0 90 | - olefile=0.46=py36_0 91 | - openssl=1.1.1=h7b6447c_0 92 | - pandoc=2.7.3=0 93 | - pandocfilters=1.4.2=py_1 94 | - parso=0.5.1=py_0 95 | - pcre=8.43=he6710b0_0 96 | - pexpect=4.7.0=py36_0 97 | - pickleshare=0.7.5=py36_1000 98 | - pillow=6.1.0=py36h34e0f95_0 99 | - pip=19.2.2=py36_0 100 | - pixman-cos6-x86_64=0.32.8=h7062e45_0 101 | - prometheus_client=0.7.1=py_0 102 | - prompt_toolkit=2.0.10=py_0 103 | - protobuf=3.8.0=py36he6710b0_0 104 | - pthread-stubs=0.4=h14c3975_1001 105 | - ptyprocess=0.6.0=py_1001 106 | - pycparser=2.19=py36_0 107 | - pyglet=1.3.2=py36_1000 108 | - pygments=2.4.2=py_0 109 | - pyopengl=3.1.1a1=py36_0 110 | - pyopenssl=19.0.0=py36_0 111 | - pyparsing=2.4.2=py_0 112 | - pyqt=5.9.2=py36hcca6a23_4 113 | - pyrsistent=0.15.4=py36h516909a_0 114 | - pysocks=1.7.1=py36_0 115 | - python=3.6.8=h0371630_0 116 | - python-dateutil=2.8.0=py_0 117 | - pytorch=1.2.0=py3.6_cuda10.0.130_cudnn7.6.2_0 118 | - pytz=2019.2=py_0 119 | - pyzmq=18.1.0=py36h1768529_0 120 | - qt=5.9.7=h5867ecd_1 121 | - readline=7.0=h7b6447c_5 122 | - requests=2.22.0=py36_0 123 | - scipy=1.3.1=py36h7c811a0_0 124 | - send2trash=1.5.0=py_0 125 | - setuptools=41.2.0=py36_0 126 | - sip=4.19.8=py36hf484d3e_1000 127 | - six=1.12.0=py36_0 128 | - sqlite=3.29.0=h7b6447c_0 129 | - tensorboard=1.14.0=py36hf484d3e_0 130 | - tensorboardx=1.8=py_0 131 | - tensorflow=1.14.0=gpu_py36h57aa796_0 132 | - tensorflow-base=1.14.0=gpu_py36h8d69cac_0 133 | - tensorflow-estimator=1.14.0=py_0 134 | - tensorflow-gpu=1.14.0=h0d30ee6_0 135 | - termcolor=1.1.0=py36_1 136 | - terminado=0.8.2=py36_0 137 | - testpath=0.4.2=py_1001 138 | - tk=8.6.8=hbc83047_0 139 | - torchvision=0.4.0=py36_cu100 140 | - tornado=6.0.3=py36h516909a_0 141 | - tqdm=4.36.1=py_0 142 | - traitlets=4.3.3=py36_0 143 | - urllib3=1.24.2=py36_0 144 | - wcwidth=0.1.7=py_1 145 | - webencodings=0.5.1=py_1 146 | - werkzeug=0.15.5=py_0 147 | - wheel=0.33.6=py36_0 148 | - wrapt=1.11.2=py36h7b6447c_0 149 | - xorg-libxau=1.0.9=h14c3975_0 150 | - xorg-libxdmcp=1.1.3=h516909a_0 151 | - xorg-x11-server-common-cos6-x86_64=1.17.4=he6f580c_0 152 | - xorg-x11-server-xvfb-cos6-x86_64=1.17.4=h5c27f9d_0 153 | - xz=5.2.4=h14c3975_4 154 | - zeromq=4.3.2=he1b5a44_2 155 | - zipp=0.6.0=py_0 156 | - zlib=1.2.11=h7b6447c_3 157 | - zstd=1.3.7=h0b5b093_0 158 | - pip: 159 | - future==0.17.1 -------------------------------------------------------------------------------- /pc3_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks import load_config 3 | from torch import nn 4 | 5 | 6 | torch.set_default_dtype(torch.float64) 7 | # torch.manual_seed(0) 8 | 9 | 10 | class PC3(nn.Module): 11 | def __init__(self, armotized, x_dim, z_dim, u_dim, env): 12 | super(PC3, 
self).__init__() 13 | enc, dyn = load_config(env) 14 | 15 | self.x_dim = x_dim 16 | self.z_dim = z_dim 17 | self.u_dim = u_dim 18 | self.armotized = armotized 19 | 20 | self.encoder = enc(x_dim, z_dim) 21 | self.dynamics = dyn(armotized, z_dim, u_dim) 22 | 23 | def encode(self, x): 24 | return self.encoder(x) 25 | 26 | def transition(self, z, u): 27 | return self.dynamics(z, u) 28 | 29 | def reparam(self, mean, std): 30 | epsilon = torch.randn_like(std) 31 | return mean + torch.mul(epsilon, std) 32 | 33 | def forward(self, x, u, x_next): 34 | # NCE loss and 35 | # consistency loss: in deterministic case = -log p(z' | z, u) 36 | z_enc = self.encode(x) # deterministic p(z | x) 37 | z_next_trans_dist, _, _ = self.transition(z_enc, u) # P(z^_t+1 | z_t, u _t) 38 | z_next_enc = self.encode(x_next) # deterministic Q(z^_t+1 | x_t+1) 39 | 40 | return z_enc, z_next_trans_dist, z_next_enc 41 | -------------------------------------------------------------------------------- /retrain_dynamics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import random 6 | import time 7 | from os import path 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.init as init 12 | import torch.optim as optim 13 | from datasets import CartPoleDataset, PendulumDataset, PlanarDataset, ThreePoleDataset 14 | from losses import curvature, nce_past 15 | from pc3_model import PC3 16 | from tensorboardX import SummaryWriter 17 | from torch import nn 18 | from torch.utils.data import DataLoader 19 | 20 | 21 | torch.set_default_dtype(torch.float64) 22 | 23 | device = torch.device("cuda") 24 | 25 | datasets = { 26 | "planar": PlanarDataset, 27 | "pendulum": PendulumDataset, 28 | "cartpole": CartPoleDataset, 29 | "threepole": ThreePoleDataset, 30 | } 31 | dims = { 32 | "planar": (1600, 2, 2), 33 | "pendulum": (4608, 3, 1), 34 | "cartpole": ((2, 80, 80), 8, 1), 35 | "threepole": ((2, 80, 80), 8, 3), 36 | } 37 | 38 | 39 | def seed_torch(seed): 40 | random.seed(seed) 41 | os.environ["PYTHONHASHSEED"] = str(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
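A minimal smoke test of the `PC3` interface defined above (illustrative; random data on CPU, with shapes for the planar setting):

```python
import torch
from pc3_model import PC3

model = PC3(armotized=False, x_dim=1600, z_dim=2, u_dim=2, env="planar")
x = torch.rand(8, 1600)       # a batch of flattened 40x40 observations
u = torch.rand(8, 2)
x_next = torch.rand(8, 1600)

z_enc, z_next_trans_dist, z_next_enc = model(x, u, x_next)
print(z_enc.shape, z_next_trans_dist.mean.shape, z_next_enc.shape)
# torch.Size([8, 2]) torch.Size([8, 2]) torch.Size([8, 2])
```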
46 | torch.backends.cudnn.benchmark = False 47 | torch.backends.cudnn.deterministic = True 48 | 49 | 50 | # default initialization for linear layers 51 | def weights_init(m): 52 | if isinstance(m, nn.Linear): 53 | init.kaiming_uniform_(m.weight, a=math.sqrt(5)) 54 | if m.bias is not None: 55 | fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight) 56 | bound = 1 / math.sqrt(fan_in) 57 | init.uniform_(m.bias, -bound, bound) 58 | 59 | 60 | def compute_loss(model, armotized, u, z_enc, z_next_trans_dist, z_next_enc, option, lam, delta=0.1): 61 | """ 62 | option: cpc or consistency: retrain dynamics model using cpc loss or consistency 63 | """ 64 | # nce and consistency loss 65 | # nce_loss = nce_future(z_next_trans_dist, z_next_enc) # sampling future 66 | nce_loss = nce_past(z_next_trans_dist, z_next_enc) # sampling past 67 | 68 | consis_loss = -torch.mean(z_next_trans_dist.log_prob(z_next_enc)) 69 | 70 | # curvature loss 71 | cur_loss = curvature(model, z_enc, u, delta, armotized) 72 | # cur_loss = new_curvature(model, z_enc, u) 73 | 74 | # additional norm loss to center z range to (0,0) 75 | norm_loss = torch.sum(torch.mean(z_enc, dim=0).pow(2)) 76 | 77 | # additional norm loss to avoid collapsing 78 | avg_norm_2 = torch.mean(torch.sum(z_enc.pow(2), dim=1)) 79 | 80 | if option == "cpc": 81 | loss = lam[0] * nce_loss + lam[-1] * cur_loss 82 | elif option == "consistency": 83 | loss = lam[1] * consis_loss + lam[-1] * cur_loss 84 | return nce_loss, consis_loss, cur_loss, norm_loss, avg_norm_2, loss 85 | 86 | 87 | def train(model, option, train_loader, lam, latent_noise, optimizer, armotized, epoch): 88 | avg_nce_loss = 0.0 89 | avg_consis_loss = 0.0 90 | avg_cur_loss = 0.0 91 | avg_norm_loss = 0.0 92 | avg_norm_2_loss = 0.0 93 | avg_loss = 0.0 94 | 95 | num_batches = len(train_loader) 96 | model.train() 97 | 98 | start = time.time() 99 | 100 | for iter, (x, u, x_next) in enumerate(train_loader): 101 | x = x.to(device).double() 102 | u = u.to(device).double() 103 | x_next = x_next.to(device).double() 104 | optimizer.zero_grad() 105 | 106 | z_enc, z_next_trans_dist, z_next_enc = model(x, u, x_next) 107 | noise = torch.randn(size=z_next_enc.size()) * latent_noise 108 | if next(model.encoder.parameters()).is_cuda: 109 | noise = noise.cuda() 110 | z_next_enc += noise 111 | 112 | nce_loss, consis_loss, cur_loss, norm_loss, norm_2, loss = compute_loss( 113 | model, armotized, u, z_enc, z_next_trans_dist, z_next_enc, option, lam=lam 114 | ) 115 | 116 | loss.backward() 117 | optimizer.step() 118 | 119 | avg_nce_loss += nce_loss.item() 120 | avg_consis_loss += consis_loss.item() 121 | avg_cur_loss += cur_loss.item() 122 | avg_norm_loss += norm_loss.item() 123 | avg_norm_2_loss += norm_2.item() 124 | avg_loss += loss.item() 125 | 126 | avg_nce_loss /= num_batches 127 | avg_consis_loss /= num_batches 128 | avg_cur_loss /= num_batches 129 | avg_norm_loss /= num_batches 130 | avg_norm_2_loss /= num_batches 131 | avg_loss /= num_batches 132 | 133 | if (epoch + 1) % 1 == 0: 134 | print("Epoch %d" % (epoch + 1)) 135 | print("NCE loss: %f" % (avg_nce_loss)) 136 | print("Consistency loss: %f" % (avg_consis_loss)) 137 | print("Curvature loss: %f" % (avg_cur_loss)) 138 | print("Normalization loss: %f" % (avg_norm_loss)) 139 | print("Norma 2 loss: %f" % (avg_norm_2_loss)) 140 | print("Training loss: %f" % (avg_loss)) 141 | print("Training time: %f" % (time.time() - start)) 142 | print("--------------------------------------") 143 | 144 | return avg_nce_loss, avg_consis_loss, avg_cur_loss, avg_loss 145 | 146 | 147 
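`main()` below freezes the encoder and re-initializes only the dynamics heads before retraining. The `weights_init` helper above pairs with `nn.Module.apply`, which walks every submodule; a tiny illustration (the two-layer head is hypothetical):

```python
import torch.nn as nn
from retrain_dynamics import weights_init

head = nn.Sequential(nn.Linear(4, 20), nn.ReLU(), nn.Linear(20, 2))
head.apply(weights_init)  # every nn.Linear inside is re-initialized in place
```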
| def main(args): 148 | env_name = args.env 149 | assert env_name in ["planar", "pendulum", "cartpole", "threepole"] 150 | option = args.option 151 | assert option in ["cpc", "consistency"] 152 | load_dir = args.load_dir 153 | epoch_load = args.epoch_load 154 | save_dir = args.save_dir 155 | epoches = args.num_iter 156 | iter_save = args.iter_save 157 | 158 | with open(load_dir + "/settings", "r") as f: 159 | settings = json.load(f) 160 | armotized = settings["armotized"] 161 | seed = settings["seed"] 162 | data_size = settings["data_size"] 163 | noise_level = settings["noise"] 164 | batch_size = settings["batch_size"] 165 | lam_nce = settings["lam_nce"] 166 | lam_c = settings["lam_c"] 167 | lam_cur = settings["lam_cur"] 168 | lam = [lam_nce, lam_c, lam_cur] 169 | lr = settings["lr"] 170 | latent_noise = settings["latent_noise"] 171 | weight_decay = settings["decay"] 172 | 173 | seed_torch(seed) 174 | 175 | dataset = datasets[env_name] 176 | data = dataset(sample_size=data_size, noise=noise_level) 177 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4) 178 | 179 | x_dim, z_dim, u_dim = dims[env_name] 180 | model = PC3(armotized=armotized, x_dim=x_dim, z_dim=z_dim, u_dim=u_dim, env=env_name).to(device) 181 | model.load_state_dict(torch.load(load_dir + "/model_" + str(epoch_load))) 182 | 183 | # frozen the encoder 184 | for param in model.encoder.parameters(): 185 | param.requires_grad = False 186 | # re-initialize and train the dynamics only 187 | model.dynamics.net_hidden.apply(weights_init) 188 | model.dynamics.net_mean.apply(weights_init) 189 | model.dynamics.net_logstd.apply(weights_init) 190 | 191 | optimizer = optim.Adam(model.dynamics.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=lr, weight_decay=weight_decay) 192 | 193 | save_path = "logs/" + env_name + "/" + save_dir 194 | if not path.exists(save_path): 195 | os.makedirs(save_path) 196 | 197 | writer = SummaryWriter(save_path) 198 | 199 | result_path = "result/" + env_name + "/" + save_dir 200 | if not path.exists(result_path): 201 | os.makedirs(result_path) 202 | with open(result_path + "/settings", "w") as f: 203 | json.dump(args.__dict__, f, indent=2) 204 | 205 | start = time.time() 206 | for i in range(epoches): 207 | avg_pred_loss, avg_consis_loss, avg_cur_loss, avg_loss = train( 208 | model, option, data_loader, lam, latent_noise, optimizer, armotized, i 209 | ) 210 | # ...log the running loss 211 | writer.add_scalar("NCE loss", avg_pred_loss, i) 212 | writer.add_scalar("consistency loss", avg_consis_loss, i) 213 | writer.add_scalar("curvature loss", avg_cur_loss, i) 214 | writer.add_scalar("training loss", avg_loss, i) 215 | 216 | # save model 217 | if (i + 1) % iter_save == 0: 218 | print("Saving the model.............") 219 | torch.save(model.state_dict(), result_path + "/model_" + str(i + 1)) 220 | with open(result_path + "/loss_" + str(i + 1), "w") as f: 221 | f.write( 222 | "\n".join( 223 | [ 224 | "NCE loss: " + str(avg_pred_loss), 225 | "Consistency loss: " + str(avg_consis_loss), 226 | "Curvature loss: " + str(avg_cur_loss), 227 | "Training loss: " + str(avg_loss), 228 | ] 229 | ) 230 | ) 231 | end = time.time() 232 | print("time: " + str(end - start)) 233 | with open(result_path + "/time", "w") as f: 234 | f.write(str(end - start)) 235 | writer.close() 236 | 237 | 238 | if __name__ == "__main__": 239 | parser = argparse.ArgumentParser(description="retrain the dynamics") 240 | 241 | parser.add_argument("--env", required=True, type=str, help="environment used for 
training") 242 | parser.add_argument("--option", required=True, type=str, help="option for re-training dynamics") 243 | parser.add_argument("--load_dir", required=True, type=str, help="path to load the trained model") 244 | parser.add_argument("--epoch_load", default=2000, type=int, help="epoch to load") 245 | parser.add_argument("--save_dir", required=True, type=str, help="path to save retrined model") 246 | parser.add_argument("--num_iter", default=2000, type=int, help="number of epoches") 247 | parser.add_argument( 248 | "--iter_save", default=1000, type=int, help="save model and result after this number of iterations" 249 | ) 250 | args = parser.parse_args() 251 | 252 | main(args) 253 | -------------------------------------------------------------------------------- /sample_results/balance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/balance.gif -------------------------------------------------------------------------------- /sample_results/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/cartpole.gif -------------------------------------------------------------------------------- /sample_results/pc3_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_5.png -------------------------------------------------------------------------------- /sample_results/pc3_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_6.png -------------------------------------------------------------------------------- /sample_results/pc3_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_7.png -------------------------------------------------------------------------------- /sample_results/pc3_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_model.png -------------------------------------------------------------------------------- /sample_results/pc3_planar_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_planar_1.png -------------------------------------------------------------------------------- /sample_results/pc3_planar_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pc3_planar_2.png -------------------------------------------------------------------------------- /sample_results/pc3_planar_5.png: -------------------------------------------------------------------------------- 
/sample_results/pcc_pendulum_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_pendulum_1.png
--------------------------------------------------------------------------------
/sample_results/pcc_pendulum_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_pendulum_2.png
--------------------------------------------------------------------------------
/sample_results/pcc_pendulum_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_pendulum_4.png
--------------------------------------------------------------------------------
/sample_results/pcc_planar_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_planar_1.png
--------------------------------------------------------------------------------
/sample_results/pcc_planar_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_planar_10.png
--------------------------------------------------------------------------------
/sample_results/pcc_planar_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/pcc_planar_2.png
--------------------------------------------------------------------------------
/sample_results/planar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/planar.gif
--------------------------------------------------------------------------------
/sample_results/swing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/swing.gif
--------------------------------------------------------------------------------
/sample_results/table_maps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/table_maps.png
--------------------------------------------------------------------------------
/sample_results/table_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/table_result.png
--------------------------------------------------------------------------------
/sample_results/threepole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PC3-pytorch/ec20761ebda3d1e6ed286150e462b0b4de47c81f/sample_results/threepole.gif
--------------------------------------------------------------------------------
/train_pc3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | import time
6 | from os import path
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | from datasets import CartPoleDataset, PendulumDataset, PlanarDataset, ThreePoleDataset
12 | from latent_map_planar import draw_latent_map
13 | from losses import curvature, nce_past
14 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP
15 | from pc3_model import PC3
16 | from tensorboardX import SummaryWriter
17 | from torch.utils.data import DataLoader
18 | 
19 | 
20 | torch.set_default_dtype(torch.float64)
21 | 
22 | device = torch.device("cuda")
23 | datasets = {
24 |     "planar": PlanarDataset,
25 |     "pendulum": PendulumDataset,
26 |     "cartpole": CartPoleDataset,
27 |     "threepole": ThreePoleDataset,
28 | }
29 | dims = {
30 |     "planar": (1600, 2, 2),
31 |     "pendulum": (4608, 3, 1),
32 |     "cartpole": ((2, 80, 80), 8, 1),
33 |     "threepole": ((2, 80, 80), 8, 3),
34 | }
35 | 
36 | 
37 | def seed_torch(seed):
38 |     random.seed(seed)
39 |     os.environ["PYTHONHASHSEED"] = str(seed)
40 |     np.random.seed(seed)
41 |     torch.manual_seed(seed)
42 |     torch.cuda.manual_seed(seed)
43 |     torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
44 |     torch.backends.cudnn.benchmark = False
45 |     torch.backends.cudnn.deterministic = True
46 | 
47 | 
48 | def compute_loss(model, armotized, u, z_enc, z_next_trans_dist, z_next_enc, lam, delta=0.1, norm_coeff=0.01):
49 |     # NCE and consistency losses
50 |     nce_loss = nce_past(z_next_trans_dist, z_next_enc)  # sampling past
51 | 
52 |     consis_loss = -torch.mean(z_next_trans_dist.log_prob(z_next_enc))
53 | 
54 |     # curvature loss
55 |     cur_loss = curvature(model, z_enc, u, delta, armotized)
56 | 
57 |     # additional norm loss to center the latent codes at the origin
58 |     center_loss = torch.sum(torch.mean(z_enc, dim=0).pow(2))
59 | 
60 |     # printed to monitor the scale of the latent map; not part of the objective
61 |     norm_2 = torch.mean(torch.sum(z_enc.pow(2), dim=1))
62 | 
63 |     lam_nce, lam_c, lam_cur = lam
64 |     return (
65 |         nce_loss,
66 |         consis_loss,
67 |         cur_loss,
68 |         center_loss,
69 |         norm_2,
70 |         lam_nce * nce_loss + lam_c * consis_loss + lam_cur * cur_loss + norm_coeff * center_loss,
71 |     )
72 | 
73 | 
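For intuition, the NCE term above can be read as an InfoNCE-style contrastive estimator over the batch. The snippet below is a minimal self-contained sketch, not the repository's `nce_past` (which lives in `losses.py` and, per the "sampling past" comment, may draw negatives differently):

```
import torch
from torch.distributions import Normal

def info_nce_sketch(z_next_trans_dist, z_next_enc):
    # Pairwise scores S[j, i] = log q(z'_j | z_i, u_i): evaluate every encoded
    # next-state under every predicted transition distribution in the batch.
    scores = z_next_trans_dist.log_prob(z_next_enc.unsqueeze(1)).sum(-1)
    # Matched pairs sit on the diagonal; all other entries act as negatives.
    positives = scores.diagonal()
    # Maximize each positive score relative to the batch of negatives.
    return -(positives - torch.logsumexp(scores, dim=0)).mean()

# Toy usage with random latent codes (B transitions, 2-d latent space):
B, z_dim = 8, 2
dist = Normal(torch.randn(B, z_dim), torch.ones(B, z_dim))
z_next = torch.randn(B, z_dim)
print(info_nce_sketch(dist, z_next))
```
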
74 | def train(model, train_loader, lam, norm_coeff, latent_noise, optimizer, armotized, epoch):
75 |     avg_nce_loss = 0.0
76 |     avg_consis_loss = 0.0
77 |     avg_cur_loss = 0.0
78 |     avg_center_loss = 0.0
79 |     avg_norm_2_loss = 0.0
80 |     avg_loss = 0.0
81 | 
82 |     num_batches = len(train_loader)
83 |     model.train()
84 | 
85 |     start = time.time()
86 | 
87 |     for x, u, x_next in train_loader:
88 |         x = x.to(device).double()
89 |         u = u.to(device).double()
90 |         x_next = x_next.to(device).double()
91 |         optimizer.zero_grad()
92 | 
93 |         z_enc, z_next_trans_dist, z_next_enc = model(x, u, x_next)
94 |         noise = torch.randn(size=z_next_enc.size()) * latent_noise
95 |         if next(model.encoder.parameters()).is_cuda:
96 |             noise = noise.cuda()
97 |         z_next_enc += noise
98 | 
99 |         nce_loss, consis_loss, cur_loss, center_loss, norm_2, loss = compute_loss(
100 |             model, armotized, u, z_enc, z_next_trans_dist, z_next_enc, lam=lam, norm_coeff=norm_coeff
101 |         )
102 | 
103 |         loss.backward()
104 |         optimizer.step()
105 | 
106 |         avg_nce_loss += nce_loss.item()
107 |         avg_consis_loss += consis_loss.item()
108 |         avg_cur_loss += cur_loss.item()
109 |         avg_center_loss += center_loss.item()
110 |         avg_norm_2_loss += norm_2.item()
111 |         avg_loss += loss.item()
112 | 
113 |     avg_nce_loss /= num_batches
114 |     avg_consis_loss /= num_batches
115 |     avg_cur_loss /= num_batches
116 |     avg_center_loss /= num_batches
117 |     avg_norm_2_loss /= num_batches
118 |     avg_loss /= num_batches
119 | 
120 |     if (epoch + 1) % 1 == 0:  # prints every epoch; raise the modulus to print less often
121 |         print("Epoch %d" % (epoch + 1))
122 |         print("NCE loss: %f" % avg_nce_loss)
123 |         print("Consistency loss: %f" % avg_consis_loss)
124 |         print("Curvature loss: %f" % avg_cur_loss)
125 |         print("Center loss: %f" % avg_center_loss)
126 |         print("Map scale: %f" % avg_norm_2_loss)
127 |         print("Training loss: %f" % avg_loss)
128 |         print("Training time: %f" % (time.time() - start))
129 |         print("--------------------------------------")
130 | 
131 |     return avg_nce_loss, avg_consis_loss, avg_cur_loss, avg_loss
132 | 
133 | 
134 | def main(args):
135 |     env_name = args.env
136 |     assert env_name in ["planar", "pendulum", "cartpole", "threepole"]
137 |     armotized = args.armotized
138 |     log_dir = args.log_dir
139 |     seed = args.seed
140 |     data_size = args.data_size
141 |     noise_level = args.noise
142 |     batch_size = args.batch_size
143 |     lam_nce = args.lam_nce
144 |     lam_c = args.lam_c
145 |     lam_cur = args.lam_cur
146 |     lam = [lam_nce, lam_c, lam_cur]
147 |     norm_coeff = args.norm_coeff
148 |     lr = args.lr
149 |     latent_noise = args.latent_noise
150 |     weight_decay = args.decay
151 |     epochs = args.num_iter
152 |     iter_save = args.iter_save
153 |     save_map = args.save_map
154 | 
155 |     seed_torch(seed)
156 | 
157 |     dataset = datasets[env_name]
158 |     data = dataset(sample_size=data_size, noise=noise_level)
159 |     data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)
160 | 
161 |     x_dim, z_dim, u_dim = dims[env_name]
162 |     model = PC3(armotized=armotized, x_dim=x_dim, z_dim=z_dim, u_dim=u_dim, env=env_name).to(device)
163 | 
164 |     if save_map and env_name == "planar":
165 |         mdp = PlanarObstaclesMDP(noise=noise_level)
166 | 
167 |     optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=lr, weight_decay=weight_decay)
168 | 
169 |     log_path = "logs/" + env_name + "/" + log_dir
170 |     if not path.exists(log_path):
171 |         os.makedirs(log_path)
172 |     writer = SummaryWriter(log_path)
173 | 
174 |     result_path = "result/" + env_name + "/" + log_dir
175 |     if not path.exists(result_path):
176 |         os.makedirs(result_path)
177 |     with open(result_path + "/settings", "w") as f:
178 |         json.dump(args.__dict__, f, indent=2)
179 | 
180 |     if save_map and env_name == "planar":
181 |         latent_maps = [draw_latent_map(model, mdp)]
182 | 
183 |     start = time.time()
184 |     for i in range(epochs):
185 |         avg_pred_loss, avg_consis_loss, avg_cur_loss, avg_loss = train(
186 |             model, data_loader, lam, norm_coeff, latent_noise, optimizer, armotized, i
187 |         )
188 |         # log the running losses
189 |         writer.add_scalar("NCE loss", avg_pred_loss, i)
190 |         writer.add_scalar("consistency loss", avg_consis_loss, i)
191 |         writer.add_scalar("curvature loss", avg_cur_loss, i)
192 |         writer.add_scalar("training loss", avg_loss, i)
193 |         if save_map and env_name == "planar":
194 |             if (i + 1) % 10 == 0:
195 |                 map_i = draw_latent_map(model, mdp)
196 |                 latent_maps.append(map_i)
197 | 
198 |         # save model
199 |         if (i + 1) % iter_save == 0:
200 |             print("Saving the model.............")
201 | 
202 |             torch.save(model.state_dict(), result_path + "/model_" + str(i + 1))
203 |             with open(result_path + "/loss_" + str(i + 1), "w") as f:
204 |                 f.write(
205 |                     "\n".join(
206 |                         [
207 |                             "NCE loss: " + str(avg_pred_loss),
208 |                             "Consistency loss: " + str(avg_consis_loss),
209 |                             "Curvature loss: " + str(avg_cur_loss),
210 |                             "Training loss: " + str(avg_loss),
211 |                         ]
212 |                     )
213 |                 )
214 |     end = time.time()
215 |     with open(result_path + "/time", "w") as f:
216 |         f.write(str(end - start))
217 |     if env_name == "planar" and save_map:
218 |         latent_maps[0].save(
219 |             result_path + "/latent_map.gif",
220 |             format="GIF",
221 |             append_images=latent_maps[1:],
222 |             save_all=True,
223 |             duration=100,
224 |             loop=0,
225 |         )
226 |     writer.close()
227 | 
228 | 
229 | def str2bool(v):
230 |     if isinstance(v, bool):
231 |         return v
232 |     if v.lower() in ("yes", "true", "t", "y", "1"):
233 |         return True
234 |     elif v.lower() in ("no", "false", "f", "n", "0"):
235 |         return False
236 |     else:
237 |         raise argparse.ArgumentTypeError("Boolean value expected.")
238 | 
239 | 
240 | if __name__ == "__main__":
241 |     parser = argparse.ArgumentParser(description="train pc3 model")
242 | 
243 |     parser.add_argument("--env", required=True, type=str, help="environment used for training")
244 |     parser.add_argument(
245 |         "--armotized",
246 |         required=True,
247 |         type=str2bool,
248 |         nargs="?",
249 |         const=True,
250 |         default=False,
251 |         help="type of dynamics model",
252 |     )
253 |     parser.add_argument("--log_dir", required=True, type=str, help="directory to save training log")
254 |     parser.add_argument("--seed", required=True, type=int, help="seed number")
255 |     parser.add_argument("--data_size", required=True, type=int, help="the number of data points used for training")
256 |     parser.add_argument("--noise", default=0, type=float, help="the level of noise")
257 |     parser.add_argument("--batch_size", default=256, type=int, help="batch size")
258 |     parser.add_argument("--lam_nce", default=1.0, type=float, help="weight of prediction loss")
259 |     parser.add_argument("--lam_c", default=1.0, type=float, help="weight of consistency loss")
260 |     parser.add_argument("--lam_cur", default=7.0, type=float, help="weight of curvature loss")
261 |     parser.add_argument("--norm_coeff", default=0.1, type=float, help="coefficient of the additional normalization loss")
262 |     parser.add_argument("--lr", default=0.0005, type=float, help="learning rate")
263 |     parser.add_argument("--latent_noise", default=0.1, type=float, help="level of noise added to the latent code")
264 |     parser.add_argument("--decay", default=0.001, type=float, help="L2 regularization")
265 |     parser.add_argument("--num_iter", default=2000, type=int, help="number of epochs")
266 |     parser.add_argument(
267 |         "--iter_save", default=1000, type=int, help="save model and result after this number of iterations"
268 |     )
269 |     parser.add_argument("--save_map", default=False, type=str2bool, help="save the latent map during training or not")
270 |     args = parser.parse_args()
271 | 
272 |     main(args)
273 | 
--------------------------------------------------------------------------------
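For reference, a minimal sketch of loading a trained planar checkpoint and encoding an observation. The checkpoint path is illustrative, and the encoder's call signature and input layout are assumptions inferred from this file (planar `x_dim` is 1600, i.e., a flattened 40x40 image), not confirmed by the repository:

```
# Hypothetical usage sketch, not part of the repo.
import torch
from pc3_model import PC3

torch.set_default_dtype(torch.float64)  # matches train_pc3.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_dim, z_dim, u_dim = 1600, 2, 2  # dims["planar"] in train_pc3.py
model = PC3(armotized=False, x_dim=x_dim, z_dim=z_dim, u_dim=u_dim, env="planar").to(device)
# Illustrative path, following the result layout written by train_pc3.py.
model.load_state_dict(torch.load("result/planar/planar_1/model_2000", map_location=device))
model.eval()

with torch.no_grad():
    x = torch.rand(1, x_dim).to(device)  # stand-in for a real observation
    z = model.encoder(x)                 # assumed encoder call, as used in retrain_dynamics.py
print(z.shape)  # expected: (1, 2)
```
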