├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── __init__.py
├── baseline_models
│   ├── __init__.py
│   ├── latent_ode_lib
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── create_latent_ode_model.py
│   │   ├── diffeq_solver.py
│   │   ├── encoder_decoder.py
│   │   ├── latent_ode.py
│   │   ├── likelihood_eval.py
│   │   ├── ode_func.py
│   │   ├── ode_rnn.py
│   │   ├── parse_datasets.py
│   │   ├── plotting.py
│   │   ├── rnn_baselines.py
│   │   └── utils.py
│   └── original_latent_ode.py
├── config.py
├── envs
│   ├── __init__.py
│   └── oderl
│       ├── .gitignore
│       ├── LICENSE
│       ├── README.md
│       ├── __init__.py
│       ├── ctrl
│       │   ├── __init__.py
│       │   ├── ctrl.py
│       │   ├── dataset.py
│       │   ├── dynamics.py
│       │   ├── policy.py
│       │   └── utils.py
│       ├── env_simulator.py
│       ├── envs
│       │   ├── __init__.py
│       │   ├── base_env.py
│       │   ├── ctacrobot.py
│       │   ├── ctcartpole.py
│       │   └── ctpendulum.py
│       ├── runner.py
│       └── utils
│           ├── __init__.py
│           ├── benn.py
│           ├── bnn.py
│           ├── dropout_bnn.py
│           ├── enn.py
│           ├── ibnn.py
│           └── utils.py
├── mppi_dataset_collector.py
├── mppi_optim.yaml
├── mppi_with_model.py
├── oracle.py
├── overlay.py
├── planners
│   ├── __init__.py
│   └── mppi_delay.py
├── process_results
│   ├── __init__.py
│   ├── files
│   │   └── .gitkeep
│   ├── plot_util.py
│   └── process_logs.py
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── run_exp_multi.py
├── saved_models
│   └── .gitkeep
├── setup.cfg
├── setup
│   └── install.sh
├── train_utils.py
├── w_latent_ode.py
└── w_nl.py
/.gitignore: -------------------------------------------------------------------------------- 1 | replay_buffer_env*.pt 2 | *.npy 3 | *.pdf 4 | *.mp4 5 | wandb 6 | *.png 7 | results/ 8 | *.bak 9 | *.dat 10 | *.dir 11 | logs/ 12 | *.db 13 | .vscode 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # Custom 146 | saved_models/*.pt 147 | saved_models/*.zip 148 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | args: ['--markdown-linebreak-ext=md'] 7 | - id: check-added-large-files 8 | args: ['--maxkb=1000'] 9 | - id: check-ast 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-toml 13 | - id: check-yaml 14 | - id: check-executables-have-shebangs 15 | - id: debug-statements 16 | - id: end-of-file-fixer 17 | - id: requirements-txt-fixer 18 | - id: mixed-line-ending 19 | args: ['--fix=auto'] 20 | 21 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 22 | rev: v2.8.0 23 | hooks: 24 | - id: pretty-format-toml 25 | args: [--autofix] 26 | 27 | - repo: https://github.com/pycqa/isort 28 | rev: 5.12.0 29 | hooks: 30 | - id: isort 31 | 32 | - repo: https://github.com/psf/black 33 | rev: 23.1.0 34 | hooks: 35 | - id: black-jupyter 36 | language_version: python3 37 | 38 | - repo: https://github.com/PyCQA/flake8 39 | rev: 6.0.0 40 | hooks: 41 | - id: flake8 42 | 43 | - repo: https://github.com/PyCQA/bandit 44 | rev: 1.7.5 45 | hooks: 46 | - id: bandit 47 | args: ["-c", "pyproject.toml", "-q", "-lll"] 48 | additional_dependencies: ["bandit[toml]"] 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Samuel Holt, Hao Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Laplace Control for Continuous-time Delayed Systems (Code) 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2206.04843-b31b1b.svg)](https://arxiv.org/abs/2302.12604) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) 5 | [![code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | 7 | This repository is the official implementation of [Neural Laplace Control for Continuous-time Delayed Systems](https://arxiv.org/abs/2302.12604). 8 | 9 | 1. Run/Follow steps in [install.sh](setup/install.sh) 10 | 2. Replicate experimental results by running and configuring [run_exp_multi.py](run_exp_multi.py). 11 | ```sh 12 | python run_exp_multi.py 13 | ``` 14 | 3. Process the output log file using [process_logs.py](process_results/process_logs.py) by updating the `LOG_PATH` variable to point to the recently generated log file. 15 | ```sh 16 | python process_results/process_logs.py 17 | ``` 18 | 19 | #### Retraining 20 | To retrain all models from scratch (much slower), set the following variables to `True` in [run_exp_multi.py](run_exp_multi.py) before running it: 21 | ```python 22 | RETRAIN = True 23 | FORCE_RETRAIN = True 24 | ``` 25 | 26 | #### Large files: 27 | To obtain large files like saved models for this work, please download these from Google Drive [here](https://drive.google.com/drive/folders/1j8IijW5iVrxD7hSstBfFpmojAkArP5CU?usp=sharing) and place them into corresponding directories. 28 | 29 | 30 | ## Resources & Other Great Tools 📝 31 | * 💻 [Neural Laplace](https://github.com/samholt/NeuralLaplace): Neural Laplace: Differentiable Laplace Reconstructions for modelling any time observation with O(1) complexity. 
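A note on step 3 above: `LOG_PATH` is edited in place inside [process_logs.py](process_results/process_logs.py). A minimal sketch of that edit (the file name below is hypothetical; use whichever log file `run_exp_multi.py` just produced):
```python
# process_results/process_logs.py -- sketch of the step-3 edit; this path is illustrative only
LOG_PATH = "logs/run_exp_multi.log"  # point at the recently generated log file
```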
32 | 33 | ### Acknowledgements & Citing `Neural Laplace Control` ✏️ 34 | 35 | If you use `Neural Laplace Control` in your research, please cite it as follows: 36 | 37 | ``` 38 | @inproceedings{holt2023neural, 39 | title={Neural Laplace Control for Continuous-time Delayed Systems}, 40 | author={Holt, Samuel and H{\"u}y{\"u}k, Alihan and Qian, Zhaozhi and Sun, Hao and van der Schaar, Mihaela}, 41 | booktitle={International Conference on Artificial Intelligence and Statistics}, 42 | pages={1747--1778}, 43 | year={2023}, 44 | organization={PMLR} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/__init__.py -------------------------------------------------------------------------------- /baseline_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/baseline_models/__init__.py -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/README.md: -------------------------------------------------------------------------------- 1 | Code in this folder references [Latent ODEs for Irregularly-Sampled Time Series](https://github.com/YuliaRubanova/latent_ode), used here as a baseline 2 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/baseline_models/latent_ode_lib/__init__.py -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/create_latent_ode_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | from .diffeq_solver import DiffeqSolver 9 | from .encoder_decoder import Decoder, Encoder_z0_ODE_RNN, Encoder_z0_RNN 10 | from .latent_ode import LatentODE 11 | from .ode_func import ODEFunc, ODEFunc_w_Poisson 12 | from .utils import create_net 13 | 14 | ##################################################################################################### 15 | 16 | 17 | def create_LatentODE_model_direct( 18 | input_dim, 19 | z0_prior, 20 | obsrv_std, 21 | device, 22 | classif_per_tp=False, 23 | n_labels=1, 24 | latents=2, 25 | units=100, 26 | poisson=False, 27 | gen_layers=1, 28 | rec_dims=20, 29 | rec_layers=1, 30 | z0_encoder="odernn", 31 | gru_units=100, 32 | classif=False, 33 | linear_classif=False, 34 | dataset="", 35 | ): 36 | dim = latents 37 | if poisson: 38 | lambda_net = create_net(dim, input_dim, n_layers=1, n_units=units, nonlinear=nn.Tanh) 39 | 40 | # ODE function produces the gradient for latent state and for poisson rate 41 | ode_func_net = create_net(dim * 2, latents * 2, n_layers=gen_layers, n_units=units, nonlinear=nn.Tanh) 42 | 43 | gen_ode_func = ( 44 | ODEFunc_w_Poisson( 45 | input_dim=input_dim, 46 | latent_dim=latents * 2, 47 | ode_func_net=ode_func_net, 48 | lambda_net=lambda_net, 49 | device=device, 50 | ) 51 | .to(device) 52 | .double() 53 | ) 54 | 
else: 55 | dim = latents 56 | ode_func_net = create_net(dim, latents, n_layers=gen_layers, n_units=units, nonlinear=nn.Tanh) 57 | 58 | gen_ode_func = ( 59 | ODEFunc( 60 | input_dim=input_dim, 61 | latent_dim=latents, 62 | ode_func_net=ode_func_net, 63 | device=device, 64 | ) 65 | .to(device) 66 | .double() 67 | ) 68 | 69 | z0_diffeq_solver = None 70 | n_rec_dims = rec_dims 71 | enc_input_dim = int(input_dim) * 2 # we concatenate the mask 72 | gen_data_dim = input_dim 73 | 74 | z0_dim = latents 75 | if poisson: 76 | z0_dim += latents # predict the initial poisson rate 77 | 78 | if z0_encoder == "odernn": 79 | ode_func_net = create_net( 80 | n_rec_dims, 81 | n_rec_dims, 82 | n_layers=rec_layers, 83 | n_units=units, 84 | nonlinear=nn.Tanh, 85 | ) 86 | 87 | rec_ode_func = ( 88 | ODEFunc( 89 | input_dim=enc_input_dim, 90 | latent_dim=n_rec_dims, 91 | ode_func_net=ode_func_net, 92 | device=device, 93 | ) 94 | .to(device) 95 | .double() 96 | ) 97 | 98 | z0_diffeq_solver = DiffeqSolver( 99 | enc_input_dim, 100 | rec_ode_func, 101 | "euler", 102 | latents, 103 | odeint_rtol=1e-3, 104 | odeint_atol=1e-4, 105 | device=device, 106 | ) 107 | 108 | encoder_z0 = ( 109 | Encoder_z0_ODE_RNN( 110 | n_rec_dims, 111 | enc_input_dim, 112 | z0_diffeq_solver, 113 | z0_dim=z0_dim, 114 | n_gru_units=gru_units, 115 | device=device, 116 | ) 117 | .to(device) 118 | .double() 119 | ) 120 | 121 | elif z0_encoder == "rnn": 122 | encoder_z0 = ( 123 | Encoder_z0_RNN(z0_dim, enc_input_dim, lstm_output_size=n_rec_dims, device=device).to(device).double() 124 | ) 125 | else: 126 | raise Exception("Unknown encoder for Latent ODE model: " + z0_encoder) # pylint: disable=broad-exception-raised 127 | 128 | decoder = Decoder(latents, gen_data_dim).to(device).double() 129 | 130 | diffeq_solver = DiffeqSolver( 131 | gen_data_dim, 132 | gen_ode_func, 133 | "dopri5", 134 | latents, 135 | odeint_rtol=1e-3, 136 | odeint_atol=1e-4, 137 | device=device, 138 | ) 139 | 140 | model = ( 141 | LatentODE( 142 | input_dim=gen_data_dim, 143 | latent_dim=latents, 144 | encoder_z0=encoder_z0, 145 | decoder=decoder, 146 | diffeq_solver=diffeq_solver, 147 | z0_prior=z0_prior, 148 | device=device, 149 | obsrv_std=obsrv_std, 150 | use_poisson_proc=poisson, 151 | use_binary_classif=classif, 152 | linear_classifier=linear_classif, 153 | classif_per_tp=classif_per_tp, 154 | n_labels=n_labels, 155 | train_classif_w_reconstr=(dataset == "physionet"), 156 | ) 157 | .to(device) 158 | .double() 159 | ) 160 | 161 | return model 162 | 163 | 164 | def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, classif_per_tp=False, n_labels=1): 165 | dim = args.latents 166 | if args.poisson: 167 | lambda_net = create_net(dim, input_dim, n_layers=1, n_units=args.units, nonlinear=nn.Tanh) 168 | 169 | # ODE function produces the gradient for latent state and for poisson rate 170 | ode_func_net = create_net( 171 | dim * 2, 172 | args.latents * 2, 173 | n_layers=args.gen_layers, 174 | n_units=args.units, 175 | nonlinear=nn.Tanh, 176 | ) 177 | 178 | gen_ode_func = ( 179 | ODEFunc_w_Poisson( 180 | input_dim=input_dim, 181 | latent_dim=args.latents * 2, 182 | ode_func_net=ode_func_net, 183 | lambda_net=lambda_net, 184 | device=device, 185 | ) 186 | .to(device) 187 | .double() 188 | ) 189 | else: 190 | dim = args.latents 191 | ode_func_net = create_net( 192 | dim, 193 | args.latents, 194 | n_layers=args.gen_layers, 195 | n_units=args.units, 196 | nonlinear=nn.Tanh, 197 | ) 198 | 199 | gen_ode_func = ( 200 | ODEFunc( 201 | input_dim=input_dim, 202 | 
latent_dim=args.latents, 203 | ode_func_net=ode_func_net, 204 | device=device, 205 | ) 206 | .to(device) 207 | .double() 208 | ) 209 | 210 | z0_diffeq_solver = None 211 | n_rec_dims = args.rec_dims 212 | enc_input_dim = int(input_dim) * 2 # we concatenate the mask 213 | gen_data_dim = input_dim 214 | 215 | z0_dim = args.latents 216 | if args.poisson: 217 | z0_dim += args.latents # predict the initial poisson rate 218 | 219 | if args.z0_encoder == "odernn": 220 | ode_func_net = create_net( 221 | n_rec_dims, 222 | n_rec_dims, 223 | n_layers=args.rec_layers, 224 | n_units=args.units, 225 | nonlinear=nn.Tanh, 226 | ) 227 | 228 | rec_ode_func = ( 229 | ODEFunc( 230 | input_dim=enc_input_dim, 231 | latent_dim=n_rec_dims, 232 | ode_func_net=ode_func_net, 233 | device=device, 234 | ) 235 | .to(device) 236 | .double() 237 | ) 238 | 239 | z0_diffeq_solver = DiffeqSolver( 240 | enc_input_dim, 241 | rec_ode_func, 242 | "euler", 243 | args.latents, 244 | odeint_rtol=1e-3, 245 | odeint_atol=1e-4, 246 | device=device, 247 | ) 248 | 249 | encoder_z0 = ( 250 | Encoder_z0_ODE_RNN( 251 | n_rec_dims, 252 | enc_input_dim, 253 | z0_diffeq_solver, 254 | z0_dim=z0_dim, 255 | n_gru_units=args.gru_units, 256 | device=device, 257 | ) 258 | .to(device) 259 | .double() 260 | ) 261 | 262 | elif args.z0_encoder == "rnn": 263 | encoder_z0 = ( 264 | Encoder_z0_RNN(z0_dim, enc_input_dim, lstm_output_size=n_rec_dims, device=device).to(device).double() 265 | ) 266 | else: 267 | raise Exception( # pylint: disable=broad-exception-raised 268 | "Unknown encoder for Latent ODE model: " + args.z0_encoder 269 | ) 270 | 271 | decoder = Decoder(args.latents, gen_data_dim).to(device).double() 272 | 273 | diffeq_solver = DiffeqSolver( 274 | gen_data_dim, 275 | gen_ode_func, 276 | "dopri5", 277 | args.latents, 278 | odeint_rtol=1e-3, 279 | odeint_atol=1e-4, 280 | device=device, 281 | ) 282 | 283 | model = ( 284 | LatentODE( 285 | input_dim=gen_data_dim, 286 | latent_dim=args.latents, 287 | encoder_z0=encoder_z0, 288 | decoder=decoder, 289 | diffeq_solver=diffeq_solver, 290 | z0_prior=z0_prior, 291 | device=device, 292 | obsrv_std=obsrv_std, 293 | use_poisson_proc=args.poisson, 294 | use_binary_classif=args.classif, 295 | linear_classifier=args.linear_classif, 296 | classif_per_tp=classif_per_tp, 297 | n_labels=n_labels, 298 | train_classif_w_reconstr=(args.dataset == "physionet"), 299 | ) 300 | .to(device) 301 | .double() 302 | ) 303 | 304 | return model 305 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/diffeq_solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchdiffeq import odeint as odeint 9 | 10 | # ^ git clone https://github.com/rtqichen/torchdiffeq.git 11 | 12 | ##################################################################################################### 13 | 14 | 15 | class DiffeqSolver(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim, 19 | ode_func, 20 | method, 21 | latents, 22 | odeint_rtol=1e-4, 23 | odeint_atol=1e-5, 24 | device=torch.device("cpu"), 25 | ): 26 | super(DiffeqSolver, self).__init__() 27 | 28 | self.ode_method = method 29 | self.latents = latents 30 | self.device = device 31 | self.ode_func = ode_func 32 | 33 | self.odeint_rtol = odeint_rtol 34 | self.odeint_atol = odeint_atol 35 | 36 | def forward(self, first_point, 
time_steps_to_predict, backwards=False): 37 | """ 38 | # Decode the trajectory through ODE Solver 39 | """ 40 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 41 | n_dims = first_point.size()[-1] # pylint: disable=unused-variable # noqa: F841 42 | 43 | pred_y = odeint( 44 | self.ode_func, 45 | first_point, 46 | time_steps_to_predict, 47 | rtol=self.odeint_rtol, 48 | atol=self.odeint_atol, 49 | method=self.ode_method, 50 | ) 51 | pred_y = pred_y.permute(1, 2, 0, 3) # pyright: ignore 52 | 53 | assert torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001 54 | assert pred_y.size()[0] == n_traj_samples 55 | assert pred_y.size()[1] == n_traj 56 | 57 | return pred_y 58 | 59 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict, n_traj_samples=1): 60 | """ 61 | # Decode the trajectory through ODE Solver using samples from the prior 62 | 63 | time_steps_to_predict: time steps at which we want to sample the new trajectory 64 | """ 65 | func = self.ode_func.sample_next_point_from_prior 66 | 67 | pred_y = odeint( 68 | func, 69 | starting_point_enc, 70 | time_steps_to_predict, 71 | rtol=self.odeint_rtol, 72 | atol=self.odeint_atol, 73 | method=self.ode_method, 74 | ) 75 | # shape: [n_traj_samples, n_traj, n_tp, n_dim] 76 | pred_y = pred_y.permute(1, 2, 0, 3) # pyright: ignore 77 | return pred_y 78 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.modules.rnn import GRU 9 | 10 | from .utils import ( 11 | check_mask, 12 | get_device, 13 | init_network_weights, 14 | linspace_vector, 15 | reverse, 16 | split_last_dim, 17 | ) 18 | 19 | 20 | # GRU description: 21 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/ 22 | class GRU_unit(nn.Module): 23 | def __init__( 24 | self, 25 | latent_dim, 26 | input_dim, 27 | update_gate=None, 28 | reset_gate=None, 29 | new_state_net=None, 30 | n_units=100, 31 | device=torch.device("cpu"), 32 | ): 33 | super(GRU_unit, self).__init__() 34 | 35 | if update_gate is None: 36 | self.update_gate = nn.Sequential( 37 | nn.Linear(latent_dim * 2 + input_dim, n_units), 38 | nn.Tanh(), 39 | nn.Linear(n_units, latent_dim), 40 | nn.Sigmoid(), 41 | ) 42 | init_network_weights(self.update_gate) 43 | else: 44 | self.update_gate = update_gate 45 | 46 | if reset_gate is None: 47 | self.reset_gate = nn.Sequential( 48 | nn.Linear(latent_dim * 2 + input_dim, n_units), 49 | nn.Tanh(), 50 | nn.Linear(n_units, latent_dim), 51 | nn.Sigmoid(), 52 | ) 53 | init_network_weights(self.reset_gate) 54 | else: 55 | self.reset_gate = reset_gate 56 | 57 | if new_state_net is None: 58 | self.new_state_net = nn.Sequential( 59 | nn.Linear(latent_dim * 2 + input_dim, n_units), 60 | nn.Tanh(), 61 | nn.Linear(n_units, latent_dim * 2), 62 | ) 63 | init_network_weights(self.new_state_net) 64 | else: 65 | self.new_state_net = new_state_net 66 | 67 | def forward(self, y_mean, y_std, x, masked_update=True): 68 | y_concat = torch.cat([y_mean, y_std, x], -1) 69 | 70 | update_gate = self.update_gate(y_concat) 71 | reset_gate = self.reset_gate(y_concat) 72 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1) 73 | 74 | new_state, new_state_std = 
split_last_dim(self.new_state_net(concat)) 75 | new_state_std = new_state_std.abs() 76 | 77 | new_y = (1 - update_gate) * new_state + update_gate * y_mean 78 | new_y_std = (1 - update_gate) * new_state_std + update_gate * y_std 79 | 80 | assert not torch.isnan(new_y).any() 81 | 82 | if masked_update: 83 | # IMPORTANT: assumes that x contains both data and mask 84 | # update only the hidden states for hidden state only if at least one feature is present for the current time point 85 | n_data_dims = x.size(-1) // 2 86 | mask = x[:, :, n_data_dims:] 87 | check_mask(x[:, :, :n_data_dims], mask) 88 | 89 | mask = (torch.sum(mask, -1, keepdim=True) > 0).float() 90 | 91 | assert not torch.isnan(mask).any() 92 | 93 | new_y = mask * new_y + (1 - mask) * y_mean 94 | new_y_std = mask * new_y_std + (1 - mask) * y_std 95 | 96 | if torch.isnan(new_y).any(): 97 | print("new_y is nan!") 98 | print(mask) 99 | print(y_mean) 100 | exit() 101 | 102 | new_y_std = new_y_std.abs() 103 | return new_y, new_y_std 104 | 105 | 106 | class Encoder_z0_RNN(nn.Module): 107 | def __init__( 108 | self, 109 | latent_dim, 110 | input_dim, 111 | lstm_output_size=20, 112 | use_delta_t=True, 113 | device=torch.device("cpu"), 114 | ): 115 | super(Encoder_z0_RNN, self).__init__() 116 | 117 | self.gru_rnn_output_size = lstm_output_size 118 | self.latent_dim = latent_dim 119 | self.input_dim = input_dim 120 | self.device = device 121 | self.use_delta_t = use_delta_t 122 | 123 | self.hiddens_to_z0 = nn.Sequential( 124 | nn.Linear(self.gru_rnn_output_size, 50), 125 | nn.Tanh(), 126 | nn.Linear(50, latent_dim * 2), 127 | ) 128 | 129 | init_network_weights(self.hiddens_to_z0) 130 | 131 | input_dim = self.input_dim 132 | 133 | if use_delta_t: 134 | self.input_dim += 1 135 | self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device).double() 136 | 137 | def forward(self, data, time_steps, run_backwards=True): 138 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 139 | 140 | # data shape: [n_traj, n_tp, n_dims] 141 | # shape required for rnn: (seq_len, batch, input_size) 142 | # t0: not used here 143 | n_traj = data.size(0) 144 | 145 | assert not torch.isnan(data).any() 146 | assert not torch.isnan(time_steps).any() 147 | 148 | data = data.permute(1, 0, 2) 149 | 150 | if run_backwards: 151 | # Look at data in the reverse order: from later points to the first 152 | data = reverse(data) 153 | 154 | if self.use_delta_t: 155 | delta_t = time_steps[1:] - time_steps[:-1] 156 | if run_backwards: 157 | # we are going backwards in time with 158 | delta_t = reverse(delta_t) 159 | # append zero delta t in the end 160 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device))) 161 | delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1) 162 | data = torch.cat((delta_t, data), -1) 163 | 164 | outputs, _ = self.gru_rnn(data) 165 | 166 | # LSTM output shape: (seq_len, batch, num_directions * hidden_size) 167 | last_output = outputs[-1] 168 | 169 | self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps} 170 | 171 | mean, std = split_last_dim(self.hiddens_to_z0(last_output)) 172 | std = std.abs() 173 | 174 | assert not torch.isnan(mean).any() 175 | assert not torch.isnan(std).any() 176 | 177 | return mean.unsqueeze(0), std.unsqueeze(0) 178 | 179 | 180 | class Encoder_z0_ODE_RNN(nn.Module): 181 | # Derive z0 by running ode backwards. 
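# (Implemented by run_odernn below: the ODE estimate yi_ode from each backwards solve is blended with the observation xi via GRU_update.)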
182 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i 183 | # Compute a weighted sum of y_i from data and y_i from ode. Use weighted y_i as an initial value for ODE runing from t_i to t_i-1 184 | # Continue until we get to z0 185 | def __init__( 186 | self, 187 | latent_dim, 188 | input_dim, 189 | z0_diffeq_solver=None, 190 | z0_dim=None, 191 | GRU_update=None, 192 | n_gru_units=100, 193 | device=torch.device("cpu"), 194 | ): 195 | super(Encoder_z0_ODE_RNN, self).__init__() 196 | 197 | if z0_dim is None: 198 | self.z0_dim = latent_dim 199 | else: 200 | self.z0_dim = z0_dim 201 | 202 | if GRU_update is None: 203 | self.GRU_update = GRU_unit(latent_dim, input_dim, n_units=n_gru_units, device=device).to(device).double() 204 | else: 205 | self.GRU_update = GRU_update 206 | 207 | self.z0_diffeq_solver = z0_diffeq_solver 208 | self.latent_dim = latent_dim 209 | self.input_dim = input_dim 210 | self.device = device 211 | self.extra_info = None 212 | 213 | self.transform_z0 = nn.Sequential( 214 | nn.Linear(latent_dim * 2, 100), 215 | nn.Tanh(), 216 | nn.Linear(100, self.z0_dim * 2), 217 | ) 218 | init_network_weights(self.transform_z0) 219 | 220 | def forward(self, data, time_steps, run_backwards=True, save_info=False): 221 | # data, time_steps -- observations and their time stamps 222 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 223 | assert not torch.isnan(data).any() 224 | assert not torch.isnan(time_steps).any() 225 | 226 | n_traj, n_tp, n_dims = data.size() # pylint: disable=unused-variable 227 | if len(time_steps) == 1: 228 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 229 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 230 | 231 | xi = data[:, 0, :].unsqueeze(0) 232 | 233 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi) 234 | extra_info = None 235 | else: 236 | last_yi, last_yi_std, _, extra_info = self.run_odernn( 237 | data, time_steps, run_backwards=run_backwards, save_info=save_info 238 | ) 239 | 240 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim) 241 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim) 242 | 243 | mean_z0, std_z0 = split_last_dim(self.transform_z0(torch.cat((means_z0, std_z0), -1))) 244 | std_z0 = std_z0.abs() 245 | if save_info: 246 | self.extra_info = extra_info 247 | 248 | return mean_z0, std_z0 249 | 250 | def run_odernn(self, data, time_steps, run_backwards=True, save_info=False): 251 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 252 | 253 | n_traj, n_tp, n_dims = data.size() # pylint: disable=unused-variable 254 | extra_info = [] 255 | 256 | device = get_device(data) 257 | 258 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device).double() 259 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device).double() 260 | 261 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1] 262 | 263 | interval_length = time_steps[-1] - time_steps[0] 264 | minimum_step = interval_length / 50 265 | 266 | # print("minimum step: {}".format(minimum_step)) 267 | 268 | assert not torch.isnan(data).any() 269 | assert not torch.isnan(time_steps).any() 270 | 271 | latent_ys = [] 272 | # Run ODE backwards and combine the y(t) estimates using gating 273 | time_points_iter = range(0, len(time_steps)) 274 | if run_backwards: 275 | time_points_iter = reversed(time_points_iter) 276 | 277 | for i in time_points_iter: 278 | if (prev_t - t_i) < minimum_step: 279 | time_points = 
torch.stack((prev_t, t_i)) 280 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t) # pyright: ignore 281 | 282 | assert not torch.isnan(inc).any() 283 | 284 | ode_sol = prev_y + inc 285 | ode_sol = torch.stack((prev_y, ode_sol), 2).to(device).double() 286 | 287 | assert not torch.isnan(ode_sol).any() 288 | else: 289 | n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int()) 290 | 291 | time_points = linspace_vector(prev_t, t_i, n_intermediate_tp) 292 | ode_sol = self.z0_diffeq_solver(prev_y, time_points) # pyright: ignore 293 | 294 | assert not torch.isnan(ode_sol).any() 295 | 296 | if torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001: 297 | print("Error: first point of the ODE is not equal to initial value") 298 | print(torch.mean(ode_sol[:, :, 0, :] - prev_y)) 299 | exit() 300 | # assert(torch.mean(ode_sol[:, :, 0, :] - prev_y) < 0.001) 301 | 302 | yi_ode = ode_sol[:, :, -1, :] 303 | xi = data[:, i, :].unsqueeze(0) 304 | 305 | yi, yi_std = self.GRU_update(yi_ode, prev_std, xi) 306 | 307 | prev_y, prev_std = yi, yi_std 308 | prev_t, t_i = time_steps[i], time_steps[i - 1] 309 | 310 | latent_ys.append(yi) 311 | 312 | if save_info: 313 | d = { 314 | "yi_ode": yi_ode.detach(), # "yi_from_data": yi_from_data, 315 | "yi": yi.detach(), 316 | "yi_std": yi_std.detach(), 317 | "time_points": time_points.detach(), 318 | "ode_sol": ode_sol.detach(), 319 | } 320 | extra_info.append(d) 321 | 322 | latent_ys = torch.stack(latent_ys, 1) 323 | 324 | assert not torch.isnan(yi).any() # pyright: ignore 325 | assert not torch.isnan(yi_std).any() # pyright: ignore 326 | 327 | return yi, yi_std, latent_ys, extra_info # pyright: ignore 328 | 329 | 330 | class Decoder(nn.Module): 331 | def __init__(self, latent_dim, input_dim): 332 | super(Decoder, self).__init__() 333 | # decode data from latent space where we are solving an ODE back to the data space 334 | 335 | decoder = nn.Sequential( 336 | nn.Linear(latent_dim, input_dim), 337 | ) 338 | 339 | init_network_weights(decoder) 340 | self.decoder = decoder 341 | 342 | def forward(self, data): 343 | return self.decoder(data) 344 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/latent_ode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | 8 | from .base_models import VAE_Baseline 9 | from .encoder_decoder import Encoder_z0_ODE_RNN, Encoder_z0_RNN 10 | from .utils import get_device, sample_standard_gaussian 11 | 12 | 13 | class LatentODE(VAE_Baseline): # pylint: disable=abstract-method 14 | def __init__( 15 | self, 16 | input_dim, 17 | latent_dim, 18 | encoder_z0, 19 | decoder, 20 | diffeq_solver, 21 | z0_prior, 22 | device, 23 | obsrv_std=None, 24 | use_binary_classif=False, 25 | use_poisson_proc=False, 26 | linear_classifier=False, 27 | classif_per_tp=False, 28 | n_labels=1, 29 | train_classif_w_reconstr=False, 30 | ): 31 | super(LatentODE, self).__init__( 32 | input_dim=input_dim, 33 | latent_dim=latent_dim, 34 | z0_prior=z0_prior, 35 | device=device, 36 | obsrv_std=obsrv_std, # pyright: ignore 37 | use_binary_classif=use_binary_classif, 38 | classif_per_tp=classif_per_tp, 39 | linear_classifier=linear_classifier, 40 | use_poisson_proc=use_poisson_proc, 41 | n_labels=n_labels, 42 | train_classif_w_reconstr=train_classif_w_reconstr, 43 | ) 44 | 45 | self.encoder_z0 = encoder_z0 46 | self.diffeq_solver = diffeq_solver 47 
| self.decoder = decoder 48 | self.use_poisson_proc = use_poisson_proc 49 | 50 | def get_reconstruction( 51 | self, 52 | time_steps_to_predict_i, 53 | truth, 54 | truth_time_steps_i, 55 | mask=None, 56 | n_traj_samples=1, 57 | run_backwards=True, 58 | mode=None, 59 | ): 60 | time_steps_to_predict = torch.flatten(time_steps_to_predict_i) 61 | truth_time_steps = torch.flatten(truth_time_steps_i) 62 | 63 | if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or isinstance(self.encoder_z0, Encoder_z0_RNN): 64 | truth_w_mask = truth 65 | if mask is not None: 66 | truth_w_mask = torch.cat((truth, mask), -1) 67 | first_point_mu, first_point_std = self.encoder_z0( 68 | truth_w_mask, truth_time_steps, run_backwards=run_backwards 69 | ) 70 | 71 | means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1) 72 | sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1) 73 | first_point_enc = sample_standard_gaussian(means_z0, sigma_z0) 74 | 75 | else: 76 | raise Exception( # pylint: disable=broad-exception-raised 77 | f"Unknown encoder type {type(self.encoder_z0).__name__}" 78 | ) 79 | 80 | first_point_std = first_point_std.abs() 81 | assert torch.sum(first_point_std < 0) == 0.0 82 | 83 | if self.use_poisson_proc: 84 | n_traj_samples, n_traj, n_dims = first_point_enc.size() # pylint: disable=unused-variable 85 | # append a vector of zeros to compute the integral of lambda 86 | zeros = torch.zeros([n_traj_samples, n_traj, self.input_dim]).to(get_device(truth)) 87 | first_point_enc_aug = torch.cat((first_point_enc, zeros), -1) 88 | else: 89 | first_point_enc_aug = first_point_enc 90 | 91 | assert not torch.isnan(time_steps_to_predict).any() 92 | assert not torch.isnan(first_point_enc).any() 93 | assert not torch.isnan(first_point_enc_aug).any() 94 | 95 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents] 96 | sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict) 97 | 98 | if self.use_poisson_proc: 99 | ( 100 | sol_y, 101 | log_lambda_y, 102 | int_lambda, 103 | _, 104 | ) = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 105 | 106 | assert torch.sum(int_lambda[:, :, 0, :]) == 0.0 107 | assert torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.0 108 | 109 | pred_x = self.decoder(sol_y) 110 | 111 | all_extra_info = { 112 | "first_point": (first_point_mu, first_point_std, first_point_enc), 113 | "latent_traj": sol_y.detach(), 114 | } 115 | 116 | if self.use_poisson_proc: 117 | # integral of lambda from the last step of ODE Solver 118 | all_extra_info["int_lambda"] = int_lambda[:, :, -1, :] # pyright: ignore 119 | all_extra_info["log_lambda_y"] = log_lambda_y # pyright: ignore 120 | 121 | if self.use_binary_classif: 122 | if self.classif_per_tp: 123 | all_extra_info["label_predictions"] = self.classifier(sol_y) 124 | else: 125 | all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1) 126 | 127 | return pred_x, all_extra_info 128 | 129 | def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples=1): 130 | # input_dim = starting_point.size()[-1] 131 | # starting_point = starting_point.view(1,1,input_dim) 132 | 133 | # Sample z0 from prior 134 | starting_point_enc = self.z0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1) 135 | 136 | starting_point_enc_aug = starting_point_enc 137 | if self.use_poisson_proc: 138 | n_traj_samples, n_traj, n_dims = starting_point_enc.size() # pylint: disable=unused-variable 139 | # append a vector of zeros to compute the integral of lambda 140 | zeros = torch.zeros(n_traj_samples, n_traj, 
self.input_dim).to(self.device) 141 | starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1) 142 | 143 | sol_y = self.diffeq_solver.sample_traj_from_prior( 144 | starting_point_enc_aug, time_steps_to_predict, n_traj_samples=3 145 | ) 146 | 147 | if self.use_poisson_proc: 148 | ( 149 | sol_y, 150 | log_lambda_y, # pylint: disable=unused-variable 151 | int_lambda, # pylint: disable=unused-variable 152 | _, 153 | ) = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 154 | 155 | return self.decoder(sol_y) 156 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/likelihood_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.distributions import Independent 9 | from torch.distributions.normal import Normal 10 | 11 | from .utils import get_device 12 | 13 | 14 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices=None): 15 | n_data_points = mu_2d.size()[-1] 16 | 17 | if n_data_points > 0: 18 | gaussian = Independent(Normal(loc=mu_2d, scale=obsrv_std.repeat(n_data_points)), 1) 19 | log_prob = gaussian.log_prob(data_2d) 20 | log_prob = log_prob / n_data_points 21 | else: 22 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze() 23 | return log_prob 24 | 25 | 26 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas): 27 | # masked_log_lambdas and masked_data 28 | n_data_points = masked_data.size()[-1] 29 | 30 | if n_data_points > 0: 31 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices] 32 | # log_prob = log_prob / n_data_points 33 | else: 34 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze() 35 | return log_prob 36 | 37 | 38 | def compute_binary_CE_loss(label_predictions, mortality_label): 39 | # print("Computing binary classification loss: compute_CE_loss") 40 | 41 | mortality_label = mortality_label.reshape(-1) 42 | 43 | if len(label_predictions.size()) == 1: 44 | label_predictions = label_predictions.unsqueeze(0) 45 | 46 | n_traj_samples = label_predictions.size(0) 47 | label_predictions = label_predictions.reshape(n_traj_samples, -1) 48 | 49 | idx_not_nan = ~torch.isnan(mortality_label) 50 | if torch.sum(idx_not_nan) == 0:  # no valid (non-NaN) labels in this batch 51 | print("All labels are NaNs!") 52 | return torch.tensor(0.0).to(get_device(mortality_label))  # degenerate batch: zero loss 53 | 54 | label_predictions = label_predictions[:, idx_not_nan] 55 | mortality_label = mortality_label[idx_not_nan] 56 | 57 | if torch.sum(mortality_label == 0.0) == 0 or torch.sum(mortality_label == 1.0) == 0: 58 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.") 59 | 60 | assert not torch.isnan(label_predictions).any() 61 | assert not torch.isnan(mortality_label).any() 62 | 63 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 64 | mortality_label = mortality_label.repeat(n_traj_samples, 1) 65 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label) 66 | 67 | # divide by number of patients in a batch 68 | ce_loss = ce_loss / n_traj_samples 69 | return ce_loss 70 | 71 | 72 | def compute_multiclass_CE_loss(label_predictions, true_label, mask): 73 | # print("Computing multi-class classification loss: compute_multiclass_CE_loss") 74 | 75 | if len(label_predictions.size()) == 3: 76 | label_predictions = 
label_predictions.unsqueeze(0) 77 | 78 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size() 79 | 80 | # assert(not torch.isnan(label_predictions).any()) 81 | # assert(not torch.isnan(true_label).any()) 82 | 83 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 84 | true_label = true_label.repeat(n_traj_samples, 1, 1) 85 | 86 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims) 87 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims) 88 | 89 | # choose time points with at least one measurement 90 | mask = torch.sum(mask, -1) > 0 91 | 92 | # repeat the mask for each label to mark that the label for this time point is present 93 | pred_mask = mask.repeat(n_dims, 1, 1).permute(1, 2, 0) 94 | 95 | label_mask = mask 96 | pred_mask = pred_mask.repeat(n_traj_samples, 1, 1, 1) 97 | label_mask = label_mask.repeat(n_traj_samples, 1, 1, 1) 98 | 99 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims) 100 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1) 101 | 102 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1): 103 | assert label_predictions.size(-1) == true_label.size(-1) 104 | # targets are in one-hot encoding -- convert to indices 105 | _, true_label = true_label.max(-1) 106 | 107 | res = [] 108 | for i in range(true_label.size(0)): 109 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool()) 110 | labels = torch.masked_select(true_label[i], label_mask[i].bool()) 111 | 112 | pred_masked = pred_masked.reshape(-1, n_dims) 113 | 114 | if len(labels) == 0: 115 | continue 116 | 117 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long()) 118 | res.append(ce_loss) 119 | 120 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions)) 121 | ce_loss = torch.mean(ce_loss) 122 | # # divide by number of patients in a batch 123 | # ce_loss = ce_loss / n_traj_samples 124 | return ce_loss 125 | 126 | 127 | def compute_masked_likelihood(mu, data, mask, likelihood_func): 128 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements 129 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() # pylint: disable=unused-variable 130 | 131 | res = [] 132 | for i in range(n_traj_samples): 133 | for k in range(n_traj): 134 | for j in range(n_dims): 135 | data_masked = torch.masked_select(data[i, k, :, j], mask[i, k, :, j].bool()) 136 | 137 | # assert(torch.sum(data_masked == 0.) < 10) 138 | 139 | mu_masked = torch.masked_select(mu[i, k, :, j], mask[i, k, :, j].bool()) 140 | log_prob = likelihood_func(mu_masked, data_masked, indices=(i, k, j)) 141 | res.append(log_prob) 142 | # shape: [n_traj*n_traj_samples, 1] 143 | 144 | res = torch.stack(res, 0).to(get_device(data)) 145 | res = res.reshape((n_traj_samples, n_traj, n_dims)) 146 | # Take mean over the number of dimensions 147 | res = torch.mean(res, -1) # !!!!!!!!!!! 
changed from sum to mean 148 | res = res.transpose(0, 1) 149 | return res 150 | 151 | 152 | def masked_gaussian_log_density(mu, data, obsrv_std, mask=None): 153 | # these cases are for plotting through plot_estim_density 154 | if len(mu.size()) == 3: 155 | # add additional dimension for gp samples 156 | mu = mu.unsqueeze(0) 157 | 158 | if len(data.size()) == 2: 159 | # add additional dimension for gp samples and time step 160 | data = data.unsqueeze(0).unsqueeze(2) 161 | elif len(data.size()) == 3: 162 | # add additional dimension for gp samples 163 | data = data.unsqueeze(0) 164 | 165 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 166 | 167 | assert data.size()[-1] == n_dims 168 | 169 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 170 | if mask is None: 171 | mu_flat = mu.reshape(n_traj_samples * n_traj, n_timepoints * n_dims) 172 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 173 | data_flat = data.reshape(n_traj_samples * n_traj, n_timepoints * n_dims) 174 | 175 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std) 176 | res = res.reshape(n_traj_samples, n_traj).transpose(0, 1) 177 | else: 178 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements 179 | def func(mu, data, indices): 180 | return gaussian_log_likelihood(mu, data, obsrv_std=obsrv_std, indices=indices) 181 | 182 | res = compute_masked_likelihood(mu, data, mask, func) 183 | return res 184 | 185 | 186 | def mse(mu, data, indices=None): 187 | n_data_points = mu.size()[-1] 188 | 189 | if n_data_points > 0: 190 | mse = nn.MSELoss()(mu, data) # pylint: disable=redefined-outer-name 191 | else: 192 | mse = torch.zeros([1]).to(get_device(data)).squeeze() 193 | return mse 194 | 195 | 196 | def compute_mse(mu, data, mask=None): 197 | # these cases are for plotting through plot_estim_density 198 | if len(mu.size()) == 3: 199 | # add additional dimension for gp samples 200 | mu = mu.unsqueeze(0) 201 | 202 | if len(data.size()) == 2: 203 | # add additional dimension for gp samples and time step 204 | data = data.unsqueeze(0).unsqueeze(2) 205 | elif len(data.size()) == 3: 206 | # add additional dimension for gp samples 207 | data = data.unsqueeze(0) 208 | 209 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 210 | assert data.size()[-1] == n_dims 211 | 212 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 213 | if mask is None: 214 | mu_flat = mu.reshape(n_traj_samples * n_traj, n_timepoints * n_dims) 215 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 216 | data_flat = data.reshape(n_traj_samples * n_traj, n_timepoints * n_dims) 217 | res = mse(mu_flat, data_flat) 218 | else: 219 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements 220 | res = compute_masked_likelihood(mu, data, mask, mse) 221 | return res 222 | 223 | 224 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask=None): 225 | # Compute Poisson likelihood 226 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process 227 | # Sum log lambdas across all time points 228 | if mask is None: 229 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"] 230 | # Sum over data dims 231 | poisson_log_l = torch.mean(poisson_log_l, -1) 232 | else: 233 | # Compute likelihood of the data under the predictions 234 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 235 | mask_repeated = mask.repeat(pred_y.size(0), 1, 
1, 1) 236 | 237 | # Compute the likelihood per patient and per attribute so that we don't priorize patients with more measurements 238 | int_lambda = info["int_lambda"] 239 | 240 | def f(log_lam, data, indices): 241 | return poisson_log_likelihood(log_lam, data, indices, int_lambda) 242 | 243 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f) 244 | poisson_log_l = poisson_log_l.permute(1, 0) 245 | # Take mean over n_traj 246 | # poisson_log_l = torch.mean(poisson_log_l, 1) 247 | 248 | # poisson_log_l shape: [n_traj_samples, n_traj] 249 | return poisson_log_l 250 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/ode_func.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .utils import init_network_weights 10 | 11 | ##################################################################################################### 12 | 13 | 14 | class ODEFunc(nn.Module): 15 | def __init__(self, input_dim, latent_dim, ode_func_net, device=torch.device("cpu")): 16 | """ 17 | input_dim: dimensionality of the input 18 | latent_dim: dimensionality used for ODE. Analog of a continous latent state 19 | """ 20 | super(ODEFunc, self).__init__() 21 | 22 | self.input_dim = input_dim 23 | self.device = device 24 | 25 | init_network_weights(ode_func_net) 26 | self.gradient_net = ode_func_net 27 | self.nfe = 0 28 | 29 | def forward(self, t_local, y, backwards=False): 30 | """ 31 | Perform one step in solving ODE. Given current data point y and current time point t_local, 32 | returns gradient dy/dt at this time point 33 | 34 | t_local: current time point 35 | y: value at the current time point 36 | """ 37 | grad = self.get_ode_gradient_nn(t_local, y) 38 | if backwards: 39 | grad = -grad 40 | return grad 41 | 42 | def get_ode_gradient_nn(self, t_local, y): 43 | self.nfe += 1 44 | return self.gradient_net(y) 45 | 46 | def sample_next_point_from_prior(self, t_local, y): 47 | """ 48 | t_local: current time point 49 | y: value at the current time point 50 | """ 51 | return self.get_ode_gradient_nn(t_local, y) 52 | 53 | 54 | ##################################################################################################### 55 | 56 | 57 | class ODEFunc_w_Poisson(ODEFunc): 58 | def __init__( 59 | self, 60 | input_dim, 61 | latent_dim, 62 | ode_func_net, 63 | lambda_net, 64 | device=torch.device("cpu"), 65 | ): 66 | """ 67 | input_dim: dimensionality of the input 68 | latent_dim: dimensionality used for ODE. Analog of a continous latent state 69 | """ 70 | super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, ode_func_net, device) 71 | 72 | self.latent_ode = ODEFunc( 73 | input_dim=input_dim, 74 | latent_dim=latent_dim, 75 | ode_func_net=ode_func_net, 76 | device=device, 77 | ) 78 | 79 | self.latent_dim = latent_dim 80 | self.lambda_net = lambda_net 81 | # The computation of poisson likelihood can become numerically unstable. 82 | # The integral lambda(t) dt can take large values. 
In fact, it is equal to the expected number of 83 | # events on the interval [0,T] 84 | # Exponent of lambda can also take large values 85 | # So we divide lambda by the constant and then multiply the integral of lambda by the constant 86 | self.const_for_lambda = torch.Tensor([100.0]).to(device).double() 87 | 88 | def extract_poisson_rate(self, augmented, final_result=True): 89 | y, log_lambdas, int_lambda = None, None, None 90 | 91 | assert augmented.size(-1) == self.latent_dim + self.input_dim 92 | latent_lam_dim = self.latent_dim // 2 93 | 94 | if len(augmented.size()) == 3: 95 | int_lambda = augmented[:, :, -self.input_dim :] 96 | y_latent_lam = augmented[:, :, : -self.input_dim] 97 | 98 | log_lambdas = self.lambda_net(y_latent_lam[:, :, -latent_lam_dim:]) 99 | y = y_latent_lam[:, :, :-latent_lam_dim] 100 | 101 | elif len(augmented.size()) == 4: 102 | int_lambda = augmented[:, :, :, -self.input_dim :] 103 | y_latent_lam = augmented[:, :, :, : -self.input_dim] 104 | 105 | log_lambdas = self.lambda_net(y_latent_lam[:, :, :, -latent_lam_dim:]) 106 | y = y_latent_lam[:, :, :, :-latent_lam_dim] 107 | 108 | # Multiply the integral over lambda by a constant 109 | # only when we have finished the integral computation (i.e. this is not a call in get_ode_gradient_nn) 110 | if final_result: 111 | int_lambda = int_lambda * self.const_for_lambda 112 | 113 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas) 114 | assert y.size(-1) == latent_lam_dim # pyright: ignore 115 | 116 | return y, log_lambdas, int_lambda, y_latent_lam # pyright: ignore 117 | 118 | def get_ode_gradient_nn(self, t_local, augmented): # pylint: disable=arguments-renamed 119 | # pylint: disable-next=unused-variable 120 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result=False) 121 | dydt_dldt = self.latent_ode(t_local, y_latent_lam) 122 | 123 | log_lam = log_lam - torch.log(self.const_for_lambda) 124 | return torch.cat((dydt_dldt, torch.exp(log_lam)), -1) 125 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/ode_rnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .base_models import Baseline 10 | from .encoder_decoder import Encoder_z0_ODE_RNN 11 | from .utils import init_network_weights, shift_outputs 12 | 13 | 14 | class ODE_RNN(Baseline): # pylint: disable=abstract-method 15 | def __init__( 16 | self, 17 | input_dim, 18 | latent_dim, 19 | device=torch.device("cpu"), 20 | z0_diffeq_solver=None, 21 | n_gru_units=100, 22 | n_units=100, 23 | concat_mask=False, 24 | obsrv_std=0.1, 25 | use_binary_classif=False, 26 | classif_per_tp=False, 27 | n_labels=1, 28 | train_classif_w_reconstr=False, 29 | ): 30 | Baseline.__init__( 31 | self, 32 | input_dim, 33 | latent_dim, 34 | device=device, 35 | obsrv_std=obsrv_std, 36 | use_binary_classif=use_binary_classif, 37 | classif_per_tp=classif_per_tp, 38 | n_labels=n_labels, 39 | train_classif_w_reconstr=train_classif_w_reconstr, 40 | ) 41 | 42 | ode_rnn_encoder_dim = latent_dim 43 | 44 | self.ode_gru = ( 45 | Encoder_z0_ODE_RNN( 46 | latent_dim=ode_rnn_encoder_dim, 47 | input_dim=(input_dim) * 2, # input and the mask 48 | z0_diffeq_solver=z0_diffeq_solver, 49 | n_gru_units=n_gru_units, 50 | device=device, 51 | ) 52 | .to(device) 53 | .double() 54 | ) 55 
| 56 | self.z0_diffeq_solver = z0_diffeq_solver 57 | 58 | self.decoder = nn.Sequential( 59 | nn.Linear(latent_dim, n_units), 60 | nn.Tanh(), 61 | nn.Linear(n_units, input_dim), 62 | ) 63 | 64 | init_network_weights(self.decoder) 65 | 66 | def get_reconstruction( 67 | self, 68 | time_steps_to_predict, 69 | data, 70 | truth_time_steps, 71 | mask=None, 72 | n_traj_samples=None, 73 | mode=None, 74 | ): 75 | if (len(truth_time_steps) != len(time_steps_to_predict)) or ( 76 | torch.sum(time_steps_to_predict - truth_time_steps) != 0 77 | ): 78 | raise Exception("Extrapolation mode not implemented for ODE-RNN") # pylint: disable=broad-exception-raised 79 | 80 | # time_steps_to_predict and truth_time_steps should be the same 81 | assert len(truth_time_steps) == len(time_steps_to_predict) 82 | assert mask is not None 83 | 84 | data_and_mask = data 85 | if mask is not None: 86 | data_and_mask = torch.cat([data, mask], -1) 87 | 88 | _, _, latent_ys, _ = self.ode_gru.run_odernn(data_and_mask, truth_time_steps, run_backwards=False) 89 | 90 | latent_ys = latent_ys.permute(0, 2, 1, 3) 91 | last_hidden = latent_ys[:, :, -1, :] 92 | 93 | # assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 94 | 95 | outputs = self.decoder(latent_ys) 96 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc. 97 | first_point = data[:, 0, :] 98 | outputs = shift_outputs(outputs, first_point) 99 | 100 | extra_info = {"first_point": (latent_ys[:, :, -1, :], 0.0, latent_ys[:, :, -1, :])} 101 | 102 | if self.use_binary_classif: 103 | if self.classif_per_tp: 104 | extra_info["label_predictions"] = self.classifier(latent_ys) 105 | else: 106 | extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1) 107 | 108 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims] 109 | return outputs, extra_info 110 | -------------------------------------------------------------------------------- /baseline_models/latent_ode_lib/parse_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Latent ODEs for Irregularly-Sampled Time Series 3 | Author: Yulia Rubanova 4 | """ 5 | 6 | import torch 7 | from torch.distributions import uniform 8 | from torch.utils.data import DataLoader 9 | 10 | from .utils import inf_generator, split_and_subsample_batch, split_train_test 11 | 12 | ##################################################################################################### 13 | 14 | 15 | def sine(trajectories_to_sample, device): 16 | t_end = 20.0 17 | t_nsamples = 200 18 | t_begin = t_end / t_nsamples 19 | t = torch.linspace(t_begin, t_end, t_nsamples).to(device).double() 20 | y = torch.sin(t) 21 | trajectories = y.view(1, -1, 1).repeat(trajectories_to_sample, 1, 1) 22 | return trajectories, t 23 | 24 | 25 | def dde_ramp_loading_time_sol(trajectories_to_sample, device): 26 | t_end = 20.0 27 | t_nsamples = 200 28 | t_begin = t_end / t_nsamples 29 | ti = torch.linspace(t_begin, t_end, t_nsamples).to(device).double() 30 | result = [] 31 | for t in ti: 32 | if t < 5: 33 | result.append(0) 34 | elif 5 <= t < 10: 35 | result.append((1.0 / 4.0) * ((t - 5) - 0.5 * torch.sin(2 * (t - 5)))) 36 | elif 10 <= t: 37 | result.append( 38 | (1.0 / 4.0) * ((t - 5) - (t - 10) - 0.5 * torch.sin(2 * (t - 5)) + 0.5 * torch.sin(2 * (t - 10))) 39 | ) 40 | y = torch.Tensor(result).to(device).double() / 5.0 41 | trajectories = y.view(1, -1, 1).repeat(trajectories_to_sample, 1, 1) 42 | return trajectories, ti 43 | 44 | 45 | def parse_datasets(args, 
device): 46 | def basic_collate_fn(batch, time_steps, args=args, device=device, data_type="train"): 47 | batch = torch.stack(batch) 48 | data_dict = {"data": batch, "time_steps": time_steps} 49 | 50 | data_dict = split_and_subsample_batch(data_dict, args, data_type=data_type) 51 | return data_dict 52 | 53 | dataset_name = args.dataset 54 | 55 | n_total_tp = args.timepoints + args.extrap 56 | max_t_extrap = args.max_t / args.timepoints * n_total_tp 57 | if dataset_name == "sine" or dataset_name == "dde_ramp_loading_time_sol": 58 | trajectories_to_sample = 1000 59 | if dataset_name == "sine": 60 | trajectories, t = sine(trajectories_to_sample, device) 61 | elif dataset_name == "dde_ramp_loading_time_sol": 62 | trajectories, t = dde_ramp_loading_time_sol(trajectories_to_sample, device) 63 | 64 | # # Normalise 65 | # samples = trajectories.shape[0] 66 | # dim = trajectories.shape[2] 67 | # traj = (trajectories.view(-1, dim) - trajectories.view(-1, 68 | # dim).mean(0)) / trajectories.view(-1, dim).std(0) 69 | # trajectories = torch.reshape(traj, (samples, -1, dim)) 70 | 71 | traj_index = torch.randperm(trajectories.shape[0]) # pyright: ignore 72 | train_split = int(0.8 * trajectories.shape[0]) # pyright: ignore 73 | test_split = int(0.9 * trajectories.shape[0]) # pyright: ignore 74 | train_trajectories = trajectories[traj_index[:train_split], :, :] # pyright: ignore 75 | test_trajectories = trajectories[traj_index[test_split:], :, :] # pyright: ignore 76 | 77 | # test_plot_traj = test_trajectories[0, :, :] 78 | 79 | input_dim = train_trajectories.shape[2] 80 | # output_dim = input_dim 81 | batch_size = 128 82 | 83 | train_dataloader = DataLoader( 84 | train_trajectories, # pyright: ignore 85 | batch_size=batch_size, 86 | shuffle=False, 87 | collate_fn=lambda batch: basic_collate_fn(batch, t, data_type="train"), 88 | ) 89 | test_dataloader = DataLoader( 90 | test_trajectories, # pyright: ignore 91 | batch_size=batch_size, 92 | shuffle=False, 93 | collate_fn=lambda batch: basic_collate_fn(batch, t, data_type="test"), 94 | ) 95 | 96 | data_objects = { 97 | "dataset_obj": "", 98 | "train_dataloader": inf_generator(train_dataloader), 99 | "test_dataloader": inf_generator(test_dataloader), 100 | "input_dim": input_dim, 101 | "n_train_batches": len(train_dataloader), 102 | "n_test_batches": len(test_dataloader), 103 | } 104 | return data_objects 105 | 106 | ########### 1d datasets ########### 107 | 108 | # Sampling args.timepoints time points in the interval [0, args.max_t] 109 | # Sample points for both the training sequence and extrapolation (test) 110 | distribution = uniform.Uniform(torch.Tensor([0.0]), torch.Tensor([max_t_extrap])) 111 | time_steps_extrap = distribution.sample(torch.Size([n_total_tp - 1]))[:, 0] 112 | time_steps_extrap = torch.cat((torch.Tensor([0.0]), time_steps_extrap)) 113 | time_steps_extrap = torch.sort(time_steps_extrap)[0] 114 | 115 | dataset_obj = None 116 | 117 | if dataset_obj is None: 118 | raise Exception(f"Unknown dataset: {dataset_name}") # pylint: disable=broad-exception-raised 119 | 120 | dataset = dataset_obj.sample_traj(time_steps_extrap, n_samples=args.n, noise_weight=args.noise_weight) 121 | 122 | # Process small datasets 123 | dataset = dataset.to(device).double() 124 | time_steps_extrap = time_steps_extrap.to(device).double() 125 | 126 | train_y, test_y = split_train_test(dataset, train_fraq=0.8) 127 | 128 | # n_samples = len(dataset) 129 | input_dim = dataset.size(-1) 130 | 131 | batch_size = min(args.batch_size, args.n) 132 | train_dataloader = DataLoader( 133
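# shuffle stays False here and below: per-batch splitting/subsampling is
# delegated to basic_collate_fn via split_and_subsample_batch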
| train_y, 134 | batch_size=batch_size, 135 | shuffle=False, 136 | collate_fn=lambda batch: basic_collate_fn(batch, time_steps_extrap, data_type="train"), 137 | ) 138 | test_dataloader = DataLoader( 139 | test_y, 140 | batch_size=args.n, 141 | shuffle=False, 142 | collate_fn=lambda batch: basic_collate_fn(batch, time_steps_extrap, data_type="test"), 143 | ) 144 | 145 | data_objects = { # "dataset_obj": dataset_obj, 146 | "train_dataloader": inf_generator(train_dataloader), 147 | "test_dataloader": inf_generator(test_dataloader), 148 | "input_dim": input_dim, 149 | "n_train_batches": len(train_dataloader), 150 | "n_test_batches": len(test_dataloader), 151 | } 152 | 153 | return data_objects 154 | -------------------------------------------------------------------------------- /baseline_models/original_latent_ode.py: -------------------------------------------------------------------------------- 1 | """Ref: [Latent ODEs for Irregularly-Sampled Time Series](https://github.com/YuliaRubanova/latent_ode) 2 | """ 3 | 4 | import matplotlib 5 | import matplotlib.pyplot 6 | import torch 7 | from torch import nn 8 | 9 | from .latent_ode_lib.create_latent_ode_model import create_LatentODE_model_direct 10 | from .latent_ode_lib.plotting import Normal 11 | from .latent_ode_lib.utils import compute_loss_all_batches_direct 12 | 13 | matplotlib.use("Agg") 14 | 15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | 18 | class GeneralLatentODEOfficial(nn.Module): # pylint: disable=abstract-method 19 | def __init__( 20 | self, 21 | input_dim, 22 | classif_per_tp=False, 23 | n_labels=1, 24 | obsrv_std=0.01, 25 | latents=2, 26 | hidden_units=100, 27 | ): 28 | super(GeneralLatentODEOfficial, self).__init__() 29 | 30 | obsrv_std = torch.Tensor([obsrv_std]).to(DEVICE) 31 | 32 | z0_prior = Normal(torch.Tensor([0.0]).to(DEVICE), torch.Tensor([1.0]).to(DEVICE)) 33 | 34 | self.model = create_LatentODE_model_direct( 35 | input_dim, 36 | z0_prior, 37 | obsrv_std, 38 | DEVICE, 39 | classif_per_tp=classif_per_tp, 40 | n_labels=n_labels, 41 | latents=latents, 42 | units=hidden_units, 43 | gru_units=hidden_units, 44 | ).to(DEVICE) 45 | 46 | self.latents = latents 47 | 48 | def _get_loss(self, dl): 49 | loss = compute_loss_all_batches_direct(self.model, dl, device=DEVICE, classif=0) 50 | return loss["loss"], loss["mse"] 51 | 52 | def training_step(self, batch): 53 | loss = self.model.compute_all_losses(batch) 54 | return loss["loss"] 55 | 56 | def validation_step(self, dlval): 57 | loss, mse = self._get_loss(dlval) 58 | return loss, mse 59 | 60 | def test_step(self, dltest): 61 | loss, mse = self._get_loss(dltest) 62 | return loss, mse 63 | 64 | def predict(self, dl): 65 | predictions = [] 66 | for batch in dl: 67 | pred_y, _ = self.model.get_reconstruction( 68 | batch["tp_to_predict"], 69 | batch["observed_data"], 70 | batch["observed_tp"], 71 | mask=batch["observed_mask"], 72 | n_traj_samples=1, 73 | mode=batch["mode"], 74 | ) 75 | predictions.append(pred_y.squeeze(0)) 76 | return torch.cat(predictions, 0) 77 | 78 | def encode(self, dl): 79 | encodings = [] 80 | for batch in dl: 81 | mask = batch["observed_mask"] 82 | truth_w_mask = batch["observed_data"] 83 | if mask is not None: 84 | truth_w_mask = torch.cat((batch["observed_data"], mask), -1) 85 | # pylint: disable-next=unused-variable 86 | mean, std = self.model.encoder_z0(truth_w_mask, torch.flatten(batch["observed_tp"]), run_backwards=True) 87 | encodings.append(mean.view(-1, self.latents)) 88 | return torch.cat(encodings, 0) 89 | 90 | def 
_get_and_reset_nfes(self): 91 | """Returns and resets the number of function evaluations for model.""" 92 | iteration_nfes = ( # pyright: ignore 93 | self.model.encoder_z0.z0_diffeq_solver.ode_func.nfe # pyright: ignore 94 | + self.model.diffeq_solver.ode_func.nfe 95 | ) 96 | self.model.encoder_z0.z0_diffeq_solver.ode_func.nfe = 0 # pyright: ignore 97 | self.model.diffeq_solver.ode_func.nfe = 0 98 | return iteration_nfes 99 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/envs/__init__.py -------------------------------------------------------------------------------- /envs/oderl/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | !.vscode/settings.json 3 | !.vscode/tasks.json 4 | !.vscode/launch.json 5 | !.vscode/extensions.json 6 | *.code-workspace 7 | *.torch 8 | *.pth 9 | *.out 10 | *.pt 11 | *.mp4 12 | *.pkl 13 | video/.Rhistory 14 | *.idea 15 | *.ds 16 | *checkpoints 17 | *.txt 18 | *.html 19 | *.tgz 20 | *.npz 21 | *.sh 22 | 23 | *.mat 24 | *.json 25 | *.hkl 26 | *.data-00000-of-00001 27 | *.index 28 | *.meta 29 | *.png 30 | *.DS_Store 31 | *.tar.gz 32 | *.png 33 | *.amc 34 | *.eps 35 | *.jpg 36 | *.pyc 37 | *.ckpt 38 | *.m~ 39 | *.mexmaci64 40 | *.out 41 | *.mat 42 | *.spyderworkspace 43 | *.svn-base 44 | *.pkl 45 | *.npy 46 | *.gz 47 | 48 | # Byte-compiled / optimized / DLL files 49 | __pycache__/ 50 | *.py[cod] 51 | *$py.class 52 | 53 | # C extensions 54 | *.so 55 | 56 | # Distribution / packaging 57 | .Python 58 | build/ 59 | develop-eggs/ 60 | dist/ 61 | downloads/ 62 | eggs/ 63 | .eggs/ 64 | lib/ 65 | lib64/ 66 | parts/ 67 | sdist/ 68 | var/ 69 | wheels/ 70 | pip-wheel-metadata/ 71 | share/python-wheels/ 72 | *.egg-info/ 73 | .installed.cfg 74 | *.egg 75 | MANIFEST 76 | 77 | # PyInstaller 78 | # Usually these files are written by a python script from a template 79 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 80 | *.manifest 81 | *.spec 82 | 83 | # Installer logs 84 | pip-log.txt 85 | pip-delete-this-directory.txt 86 | 87 | # Unit test / coverage reports 88 | htmlcov/ 89 | .tox/ 90 | .nox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *.cover 97 | *.py,cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | db.sqlite3 109 | db.sqlite3-journal 110 | 111 | # Flask stuff: 112 | instance/ 113 | .webassets-cache 114 | 115 | # Scrapy stuff: 116 | .scrapy 117 | 118 | # Sphinx documentation 119 | docs/_build/ 120 | 121 | # PyBuilder 122 | target/ 123 | 124 | # Jupyter Notebook 125 | .ipynb_checkpoints 126 | 127 | # IPython 128 | profile_default/ 129 | ipython_config.py 130 | 131 | # pyenv 132 | .python-version 133 | 134 | # pipenv 135 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 136 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 137 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 138 | # install all needed dependencies. 
139 | #Pipfile.lock 140 | 141 | # celery beat schedule file 142 | celerybeat-schedule 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | -------------------------------------------------------------------------------- /envs/oderl/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cagatay Yildiz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /envs/oderl/README.md: -------------------------------------------------------------------------------- 1 | 2 | # ODE-RL 3 | Experiment code for ICML 2021 paper [Continuous-time Model-based Reinforcement Learning](https://arxiv.org/pdf/2102.04764.pdf). Implemented in `Python 3.7.7` and `torch 1.6.0` (later versions should be OK). Also requires `torchdiffeq`, `TorchDiffEqPack` and `gym`. 4 | 5 | ![ENODE simulation](img1.png) 6 | 7 | ## Quick introduction 8 | - `runner.py` should run off-the-shelf. The file can be used to reproduce our results and it also demonstrates how to 9 | - create a continuous-time RL environment 10 | - initiate our model (with different variational formulations) as well as baselines (PETS & deep PILCO) 11 | - visualize the dynamics fits 12 | - execute the main learning loop (Algorithm-1 in the paper) 13 | - `ctrl` folder has our model implementation as well as helper functions for training. 14 | - `ctrl/ctrl`: creates our model and serves as an interface between the model and training/visualization functions. 15 | - `ctrl/dataset`: contains state-action-reward trajectories and interpolation (for continuous-time action) classes. 16 | - `ctrl/dynamics`: implements the dynamics model and is responsible for forward simulating all models. 17 | - `ctrl/policy`: deterministic policy implementation 18 | - `envs` contains our continuous-time implementation of RL environments. 19 | - `utils` includes the function approximators. 
20 | -------------------------------------------------------------------------------- /envs/oderl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/envs/oderl/__init__.py -------------------------------------------------------------------------------- /envs/oderl/ctrl/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset # noqa: F401 2 | from .dynamics import NODE, PETS, DeepPILCO # noqa: F401 3 | from .policy import Policy # noqa: F401 4 | -------------------------------------------------------------------------------- /envs/oderl/ctrl/ctrl.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pickle 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import envs.oderl.ctrl.dataset as ds 8 | from envs.oderl.utils import BNN 9 | 10 | from .dynamics import NODE, PETS, DeepPILCO 11 | from .policy import Policy 12 | 13 | DEFAULT_PAR_MAP = { 14 | "nl_f": 3, 15 | "nn_f": 200, 16 | "act_f": "elu", 17 | "dropout_f": 0.05, 18 | "n_ens": 10, 19 | "learn_sigma": False, 20 | "nl_g": 2, 21 | "nn_g": 200, 22 | "act_g": "relu", 23 | "nl_V": 2, 24 | "nn_V": 200, 25 | "act_V": "tanh", 26 | } 27 | 28 | 29 | class CTRL(nn.Module): # pylint: disable=abstract-method 30 | def __init__(self, env, dynamics, **kwargs): 31 | super().__init__() 32 | for PAR in DEFAULT_PAR_MAP: 33 | kwargs[PAR] = kwargs.get(PAR) # returns value in DEFAULT_PARS or None 34 | self.kwargs = kwargs 35 | self.env = env 36 | self.n_ens = self.kwargs["n_ens"] 37 | self.learn_sigma = self.kwargs["learn_sigma"] 38 | self.dynamics = dynamics 39 | self.set_solver("dopri5") 40 | self._g = Policy(self.env, nl=kwargs["nl_g"], nn=kwargs["nn_g"], act=kwargs["act_g"]) 41 | self.make_dynamics_model( 42 | nl_f=kwargs["nl_f"], 43 | nn_f=kwargs["nn_f"], 44 | act_f=kwargs["act_f"], 45 | dropout_f=kwargs["dropout_f"], 46 | ) 47 | self.V = BNN( 48 | self.env.n, 49 | 1, 50 | n_hid_layers=kwargs["nl_V"], 51 | act=kwargs["act_V"], 52 | n_hidden=kwargs["nn_V"], 53 | bnn=False, 54 | ) 55 | self.reset_parameters() 56 | 57 | @property 58 | def device(self): 59 | return next(self._g.parameters()).device 60 | 61 | @property 62 | def dtype(self): 63 | return torch.float32 64 | 65 | @property 66 | def sn(self): 67 | return self.logsn.exp() 68 | 69 | @property 70 | def dynamics_parameters(self): 71 | if self.learn_sigma: 72 | return [self.logsn] + list(self._f.parameters()) 73 | else: 74 | return list(self._f.parameters()) 75 | 76 | @property 77 | def is_cont(self): 78 | return "ode" in self.dynamics 79 | 80 | @property 81 | def name(self): 82 | return self.env.name + "-" + self.dynamics 83 | 84 | def make_dynamics_model(self, nl_f=2, nn_f=200, act_f="elu", dropout_f=0.05): 85 | if self.learn_sigma: 86 | self.logsn = torch.nn.Parameter(-torch.ones(self.env.n + self.env.m) * 3.0, requires_grad=True) 87 | else: 88 | self.register_buffer("logsn", -torch.ones(self.env.n + self.env.m) * 3.0) 89 | if self.is_cont: 90 | if dropout_f > 0.0: 91 | print("Dropout is set to 0 since NODE is running") 92 | self._f = NODE(self.env, self.dynamics, self.n_ens, nl=nl_f, nn=nn_f, act=act_f) 93 | elif self.dynamics == "pets": 94 | if dropout_f > 0.0: 95 | print("Dropout is set to 0 since PETS is running") 96 | self._f = PETS(self.env, "epnn", self.n_ens, nl=nl_f, nn=nn_f, act=act_f) 97 | elif self.dynamics == 
"deep_pilco": 98 | self._f = DeepPILCO( 99 | self.env, 100 | "dbnn", 101 | self.n_ens, 102 | nl=nl_f, 103 | nn=nn_f, 104 | act=act_f, 105 | dropout=dropout_f, 106 | ) 107 | 108 | def set_solver(self, solver): 109 | assert solver in ["euler", "midpoint", "rk4", "dopri5", "rk23", "rk45"] 110 | self.solver = {} 111 | self.solver["method"] = solver 112 | self.solver["step_size"] = self.env.dt / 10 # in case fixed step solvers are used 113 | self.solver["rtol"] = 1e-3 114 | self.solver["atol"] = 1e-6 115 | 116 | def draw_f(self, L=1, noise_vec=None, true_rhs=False): 117 | if self._f.ens_method: 118 | return self._f._f.draw_f() 119 | if noise_vec is None: 120 | noise_vec = self.draw_noise(L) 121 | return self._f._f.draw_f(L, noise_vec) # pyright: ignore # TODO - check if mean is a parameter 122 | 123 | def get_L(self, L=1): 124 | """returns the number of samples from the function 125 | this is needed as ensembles have a fixed number of possible fnc draws. 126 | """ 127 | return self.n_ens if self._f.ens_method else L 128 | 129 | def draw_noise(self, L=1, true_rhs=False): 130 | return self._f._f.draw_noise(L=L) 131 | 132 | def forward_simulate(self, H_ts, s0, g, f=None, L=10, tau=None, compute_rew=False): 133 | """Performs forward simulation for L different vector fields 134 | If H_ts is a float, then we form a uniform time grid for integration [0, dt, 2dt, ..., H_ts]. 135 | If H_ts is a torch vector (possibly nonuniform), H_ts is used as the integration time points 136 | Inputs 137 | H_ts - either a float denoting the integration time (in seconds) or integration time points 138 | s0 - [N,n] initial values 139 | g - policy function that interpolates between actions 140 | Outputs 141 | st - [L,N,T,n] 142 | rt - [L,N,T,n] 143 | at - dict of {t:a}, {[T]:[L,N,m]} 144 | t - [N,T] 145 | """ 146 | L = self.get_L(L) 147 | if f is None: 148 | f = self.draw_f(L, None) 149 | # integration time points is a uniform grid 150 | if isinstance(H_ts, float) or isinstance(H_ts, int): 151 | return self._f.forward_simulate( 152 | solver=self.solver, 153 | H=H_ts, 154 | s0=s0, 155 | f=f, 156 | g=g, 157 | L=L, 158 | tau=tau, 159 | compute_rew=compute_rew, 160 | ) 161 | else: 162 | return self._f.forward_simulate_nonuniform_ts( 163 | solver=self.solver, 164 | ts=H_ts, 165 | s0=s0, 166 | f=f, 167 | g=g, 168 | L=L, 169 | tau=tau, 170 | compute_rew=compute_rew, 171 | ) 172 | 173 | def reset_parameters(self, w=0.1): 174 | self._f.reset_parameters(w) 175 | self._g.reset_parameters(w) 176 | self.V.reset_parameters(w) 177 | nn.init.uniform_(self.logsn, -1.0, -1.0) 178 | 179 | @staticmethod 180 | def load(env, fname, verbose=True): 181 | fname = fname[:-4] if fname.endswith(".pkl") else fname 182 | if verbose: 183 | print("{:s} is loading.".format(fname)) 184 | 185 | class CPU_Unpickler(pickle.Unpickler): 186 | def find_class(self, module, name): 187 | if module == "torch.storage" and name == "_load_from_bytes": 188 | return lambda b: torch.load(io.BytesIO(b), map_location="cpu") 189 | else: 190 | return super().find_class(module, name) 191 | 192 | f = open(fname + ".pkl", "rb") 193 | stuff = CPU_Unpickler(f).load() 194 | dynamics, kwargs, state_dict = ( 195 | stuff["dynamics"], 196 | stuff["kwargs"], 197 | stuff["state_dict"], 198 | ) 199 | ctrl = CTRL(env, dynamics, **kwargs).to(env.device) 200 | ctrl.load_state_dict(state_dict) 201 | ctrl.eval() 202 | if "D_D" in list(stuff.keys()): 203 | D, ts = stuff["D_D"], stuff["D_ts"] 204 | D = ds.Dataset(env, D, ts).to(env.device) 205 | if verbose: 206 | print(D.shape) 207 | 
print(stuff["dynamics"]) 208 | else: 209 | D = None 210 | return ctrl.to(env.device), D 211 | 212 | def save(self, D=None, fname=None, verbose=False): 213 | if fname is None: 214 | fname = self.name 215 | if verbose: 216 | print("model save name is {:s}".format(fname)) 217 | state_dict = self.state_dict() 218 | save_dict = {} 219 | save_dict["state_dict"] = state_dict 220 | save_dict["kwargs"] = self.kwargs 221 | save_dict["dynamics"] = self.dynamics 222 | if D is not None: 223 | save_dict["D_D"] = D.D 224 | save_dict["D_ts"] = D.ts 225 | pickle.dump(save_dict, open(fname + ".pkl", "wb")) 226 | 227 | def __repr__(self): 228 | text = "Env solver: " + str(self.env.solver) + "\n" 229 | text += "Model solver: " + str(self.solver) + "\n" 230 | text += self._f.__repr__() + "\n" 231 | text += self._g.__repr__() + "\n" 232 | text += self.V.__repr__() + "\n" 233 | return text 234 | -------------------------------------------------------------------------------- /envs/oderl/ctrl/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from envs.oderl.utils.utils import KernelInterpolation 7 | 8 | 9 | class Dataset: 10 | def __init__(self, env, D, ts): 11 | D = Dataset.__compute_rewards_if_needed(env, D) # N,T,n+m+1 12 | self.env = env 13 | self.n = env.n 14 | self.m = env.m 15 | self.D = D 16 | self.ts = ts 17 | 18 | @property 19 | def device(self): 20 | return self.env.device 21 | 22 | @property 23 | def shape(self): 24 | return self.D.shape 25 | 26 | @property 27 | def dt(self): 28 | return self.env.dt 29 | 30 | @property 31 | def N(self): 32 | return self.shape[0] 33 | 34 | @property 35 | def T(self): 36 | return self.shape[1] 37 | 38 | @property 39 | def s(self): 40 | return self.D[:, :, : self.n] 41 | 42 | @property 43 | def a(self): 44 | return self.D[:, :, self.n : self.n + self.m] 45 | 46 | @property 47 | def sa(self): 48 | return self.D[:, :, : self.n + self.m] 49 | 50 | @property 51 | def r(self): 52 | return self.D[:, :, -1:] 53 | 54 | def clone(self): 55 | return copy.deepcopy(self) 56 | 57 | def add_experience(self, Dnew, ts): 58 | assert len(Dnew.shape) == 3, "New experience must be a 3D torch tensor" # N,T,nm 59 | Dnew = self.__compute_rewards_if_needed(self.env, Dnew) # N,T,n+m+1 60 | Dnew, ts = Dnew.to(self.device), ts.to(self.device) 61 | self.D = torch.cat([self.D, Dnew]) 62 | self.ts = torch.cat([self.ts, ts]) 63 | 64 | def crop_last(self, N=1): 65 | self.D = self.D[:-N] 66 | self.ts = self.ts[:-N] 67 | 68 | @staticmethod 69 | def __compute_rewards_if_needed(env, D): 70 | """returns (s,a,r)""" 71 | assert len(D.shape) == 3, "Dataset must be a 3D torch tensor" # N,T,nm 72 | if D.shape[-1] == env.n + env.m: 73 | [N, T, nm] = D.shape # pylint: disable=unused-variable 74 | with torch.no_grad(): 75 | s_ = D[:, :, : env.n].view([-1, env.n]) 76 | a_ = D[:, :, env.n :].view([-1, env.m]) 77 | rewards = env.diff_reward(s_, a_).view([N, T, 1]) 78 | D = torch.cat([D, rewards], 2) # N,T,n+m+1 79 | return D 80 | 81 | def to(self, device): 82 | self.D = self.D.to(device) 83 | self.ts = self.ts.to(device) 84 | return self 85 | 86 | def extract_data(self, H, cont, nrep=1, idx=None): 87 | """extracts sequences randomly subsequenced from the dataset 88 | H - in second 89 | cont - boolean denoting whether the system is continuous 90 | returns 91 | g - policy or None 92 | st - [N,T,n] 93 | at - [N,T,m] 94 | rt - [N,T,1] 95 | """ 96 | idx = list(np.arange(0, self.N)) if idx is None else list(idx) 97 
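# H is given in seconds: e.g. H=2.0 with dt=0.1 yields T=20-step windows,
# whose start indices t0 are drawn uniformly over [0, self.T - T] just below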
| T = int(H / self.dt) # convert sec to # data points 98 | idx = [item for sublist in nrep * [idx] for item in sublist] 99 | t0s = torch.tensor(np.random.randint(0, 1 + self.T - T, len(idx)), dtype=torch.int32).to(self.device) 100 | st_at_rt = torch.stack([self.D[seq_idx_, t0 : t0 + T] for t0, seq_idx_ in zip(t0s, idx)]) 101 | st, at, rt = ( 102 | st_at_rt[:, :, : self.n], 103 | st_at_rt[:, :, self.n : self.n + self.m], 104 | st_at_rt[:, :, -1:], 105 | ) 106 | ts = torch.stack([self.ts[seq_idx_, t0 : t0 + T] for t0, seq_idx_ in zip(t0s, idx)]) 107 | g = self.__extract_policy(idx, at=at, cont=cont, ts=ts, T=T) 108 | return g, st, at, rt, ts 109 | 110 | def __extract_policy(self, idx, at, cont, ts, T): 111 | if cont: 112 | return KernelInterpolatePolicy(at, ts) 113 | else: 114 | return DiscreteActions(at, ts) 115 | 116 | 117 | class DiscreteActions: 118 | def __init__(self, at, ts): 119 | if len(at.shape) != 3: 120 | raise ValueError("Actions must be 3D!\n") 121 | self.at = at.to(at.device) # N,T,m 122 | self.ts = ts.to(at.device) # N,T 123 | self.N = self.ts.shape[0] 124 | self.max_idx = self.at.shape[1] - 1 125 | 126 | def __call__(self, s, t): 127 | # t = t.item() if isinstance(t,torch.Tensor) else t 128 | if t[0].item() > self.ts[0, -1].item(): # actions outside the defined range 129 | actions = self.at[:, -1] 130 | else: 131 | before_idx = [(t[i] + 1e-5 > self.ts[i]).sum().item() - 1 for i in range(self.N)] 132 | before_idx = [min(item, self.max_idx) for item in before_idx] 133 | actions = self.at[np.arange(self.N), before_idx] 134 | if actions.isnan().sum() > 0: 135 | raise ValueError("Action interpolation is wrong!") 136 | if s.ndim == 2: 137 | return actions 138 | elif s.ndim == 3: 139 | return torch.stack([actions] * s.shape[0]) 140 | elif s.ndim == 4: 141 | tmp = torch.stack([actions] * s.shape[1]) 142 | return torch.stack([tmp] * s.shape[0]) 143 | 144 | 145 | class KernelInterpolatePolicy: 146 | def __init__(self, at, ts): 147 | [N, T, m] = at.shape # pylint: disable=unused-variable 148 | sfs = 1.0 * torch.ones([N, 1, 1], device=at.device, dtype=torch.float32) 149 | ells = 0.5 * torch.ones([N, 1, 1], device=at.device, dtype=torch.float32) 150 | self.kernel_int = KernelInterpolation(sfs, ells, ts.unsqueeze(-1), at, eps=1e-5) 151 | 152 | def __call__(self, s, t): 153 | actions = self.kernel_int(t.unsqueeze(-1).unsqueeze(-1)) # N,1,n_out 154 | actions = actions.permute(1, 0, 2) # 1,N,n_out 155 | if s.ndim == 2: 156 | return actions 157 | elif s.ndim == 3: 158 | return torch.cat([actions] * s.shape[0]) 159 | elif s.ndim == 4: 160 | tmp = torch.stack([actions] * s.shape[1]) 161 | return torch.stack([tmp] * s.shape[0]) 162 | -------------------------------------------------------------------------------- /envs/oderl/ctrl/dynamics.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from envs.oderl.utils import BENN, ENN, EPNN, IBNN, DropoutBNN 8 | from envs.oderl.utils.utils import odesolve 9 | 10 | 11 | class Dynamics(nn.Module, metaclass=ABCMeta): 12 | @abstractmethod 13 | def __init__( 14 | self, 15 | env, 16 | dynamics, 17 | L, 18 | nl=2, 19 | nn=100, # pylint: disable=redefined-outer-name 20 | act="relu", 21 | dropout=0.0, 22 | bnn=False, 23 | ): 24 | super().__init__() 25 | n, m = env.n, env.m 26 | self.qin, self.qout = n + m, n 27 | self.qin = self.qin 28 | self.env = env 29 | self.dynamics = dynamics 30 | self.L = L 31 | 
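# ens_method marks dynamics with a fixed, finite set of ensemble members
# (enode/benode/pets below); sampling-based models (ibnode, deep_pilco) leave it False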
self.ens_method = False 32 | 33 | if self.dynamics == "ibnode": 34 | self._f = IBNN(L, self.qin, self.qout, n_hid_layers=nl, n_hidden=nn, act=act) 35 | 36 | elif self.dynamics == "benode": 37 | self._f = BENN(L, self.qin, self.qout, n_hid_layers=nl, n_hidden=nn, act=act) 38 | self.ens_method = True 39 | 40 | elif self.dynamics == "enode": 41 | self._f = ENN(L, self.qin, self.qout, n_hid_layers=nl, n_hidden=nn, act=act) 42 | self.ens_method = True 43 | 44 | elif self.dynamics == "pets": 45 | self._f = EPNN(L, self.qin, self.qout, n_hid_layers=nl, n_hidden=nn, act=act) 46 | self.ens_method = True 47 | 48 | elif self.dynamics == "deep_pilco": 49 | self._f = DropoutBNN( 50 | self.qin, 51 | self.qout, 52 | n_hid_layers=nl, 53 | n_hidden=nn, 54 | act=act, 55 | dropout_rate=dropout, 56 | ) 57 | 58 | self.reset_parameters() 59 | 60 | @property 61 | def device(self): 62 | return self._f.device 63 | 64 | def reset_parameters(self, w=0.1): 65 | self._f.reset_parameters(w) 66 | 67 | def kl(self): 68 | try: 69 | return self._f.kl() # pyright: ignore 70 | except Exception: # pylint: disable=broad-exception-caught 71 | return torch.Tensor(np.zeros(1) * 1.0).to(self.device) 72 | 73 | def ds_dt(self, f, s, a): 74 | return f(torch.cat([s, a], -1)) 75 | 76 | def dv_dt(self, t, s, a, v, tau, compute_rew): 77 | r = self.env.diff_reward(s, a).unsqueeze(2) if compute_rew else v # L,N,1 78 | if tau is not None: 79 | t = t.item() if isinstance(t, torch.Tensor) else t 80 | r *= np.exp(-t / tau) 81 | return r 82 | 83 | def forward_simulate(self, solver, H, s0, f, g, L, tau=None, compute_rew=True): 84 | # starting from t=0 85 | T = int(H / self.env.dt) 86 | ts = self.env.dt * torch.arange(T + 1, dtype=torch.float32, device=self._f.device) 87 | t0s = torch.zeros(s0.shape[0], dtype=torch.float32, device=self._f.device) 88 | st, rt, at = self._forward_simulate(solver, ts, t0s, s0, f, g, L, tau=tau, compute_rew=compute_rew) 89 | return st, rt, at, torch.stack([ts[:-1]] * s0.shape[0]) 90 | 91 | def forward_simulate_nonuniform_ts(self, solver, ts, s0, f, g, L, tau=None, compute_rew=True): 92 | # all t0s are set to be zero 93 | [N, T] = ts.shape 94 | ts_norm = ts - ts[:, 0:1] # all starting from 0 [[.0 .1 .4],[.1 .3 .4]] --> [[.0 .1 .4],[.0 .2 .3]] 95 | ts_ode = ts_norm[:, 1:].reshape(-1).unique() # [.0 .1 .4 .2 .3] 96 | ts_ode = torch.cat([torch.zeros(1, device=ts_ode.device), ts_ode]) # handle numerical issues 97 | ts_ode_sorted, _ = ts_ode.sort() # [.0 .1 .2 .3 .4] 98 | ts_ode_sorted = torch.cat([ts_ode_sorted, ts_ode_sorted[-1:] + 1e-3]) # handle numerical issues 99 | Tidx = [[torch.where(ts_norm[n, t] == ts_ode_sorted)[0].item() for t in range(T)] for n in range(N)] 100 | st, rt, at = self._forward_simulate( 101 | solver, 102 | ts_ode_sorted, 103 | ts[:, 0], 104 | s0, 105 | f, 106 | g, 107 | L, 108 | tau=tau, 109 | compute_rew=compute_rew, 110 | ) 111 | sts = [st[:, i, Tidx[i]] for i in range(N)] 112 | rts = [rt[:, i, Tidx[i]] for i in range(N)] 113 | return torch.stack(sts, 1), torch.stack(rts, 1), at, ts 114 | 115 | @abstractmethod 116 | def _forward_simulate(self, solver, ts, t0s, s0, f, g, L, tau, compute_rew): 117 | """Performs forward simulation for L different vector fields 118 | ts - [T+1], starting from 0 119 | t0s - [N] 120 | Output 121 | st - [L,N,T,n] 122 | rt - [L,N,T,n] 123 | at - [L,N,T-1,m] 124 | t - [T] 125 | """ 126 | raise NotImplementedError 127 | 128 | 129 | class NODE(Dynamics): # pylint: disable=abstract-method 130 | def __init__( 131 | self, 132 | env, 133 | dynamics, 134 | L, 135 | nl=2, 136 | 
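# nn is the hidden-layer width; it shadows the torch.nn import, hence the
# pylint disable on the next line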
nn=100, # pylint: disable=redefined-outer-name 137 | act="relu", 138 | ): 139 | super().__init__(env, dynamics, L, nl=nl, nn=nn, act=act, dropout=0.0) 140 | self.at = dict() 141 | 142 | def odef(self, t, sv, f, g, t0s=None, tau=None, compute_rew=True): 143 | """Input 144 | t - current time (add t0s to get inputs to the policy!) 145 | sv - state&value - [Nens,N,n+1] 146 | f - time differential 147 | g - action function 148 | """ 149 | t = t if isinstance(t, torch.Tensor) else torch.tensor(t).to(self.device) 150 | s, v = sv[:, :, :-1], sv[:, :, -1:] 151 | a = g(s, t + t0s) # L,N,m 152 | self.at[t + t0s] = a 153 | ds = self.ds_dt(f, s, a) 154 | dv = self.dv_dt(t, s, a, v, tau, compute_rew) 155 | return torch.cat([ds, dv], -1) # Nens,N,n+1 156 | 157 | def _forward_simulate(self, solver, ts, t0s, s0, f, g, L, tau=None, compute_rew=True): 158 | T = len(ts) - 1 159 | # N,n: 160 | [N, n] = s0.shape # pylint: disable=unused-variable 161 | s0 = torch.stack([s0] * L) # Nens,N,n 162 | r0 = torch.zeros(s0.shape[:-1], device=s0.device).unsqueeze(2) # Nens,N,1 163 | s0r0 = torch.cat([s0, r0], -1) # Nens,N,n+1 164 | odef = lambda t, s: self.odef(t, s, f, g, t0s, tau=tau, compute_rew=compute_rew) # noqa: E731 165 | strt = odesolve( 166 | odef, 167 | s0r0, 168 | ts, 169 | solver["step_size"], 170 | solver["method"], 171 | solver["rtol"], 172 | solver["atol"], 173 | ) 174 | # L,N,T,n & L,N,T: 175 | st, rt = strt[:T, ..., :n].permute([1, 2, 0, 3]), strt[:T, ..., -1].permute([1, 2, 0]) # pyright: ignore 176 | return st, rt, self.at 177 | 178 | 179 | class PETS(Dynamics): # pylint: disable=abstract-method 180 | def __init__( 181 | self, 182 | env, 183 | dynamics, 184 | L, 185 | nl=2, 186 | nn=100, # pylint: disable=redefined-outer-name 187 | act="relu", 188 | ): 189 | super().__init__(env, dynamics, L, nl=nl, nn=nn, act=act, dropout=0.0) 190 | self.P = 20 191 | 192 | def _forward_simulate(self, solver, ts, t0s, s0, f, g, L, tau=None, compute_rew=True): 193 | H = len(ts) - 1 194 | [N, n] = s0.shape # N,n 195 | s0 = torch.cat([s0] * self.P) # PN,n 196 | s0 = torch.stack([s0] * L) # Nens,PN,n 197 | V0 = torch.zeros([*s0.shape[:-1], 1], device=s0.device) # Nens,PN,1 198 | st, Vt, at = [s0], [V0], {} 199 | delta_t = ts[1:] - ts[:-1] 200 | for t_, delta_t_ in zip(ts, delta_t): # 0 & 0.1 201 | a = g(st[-1].reshape(L, self.P, N, n), t_ + t0s).reshape(L, self.P * N, -1) # Nens,PN,m 202 | at[t_ + t0s] = a 203 | dV = self.dv_dt(t_, st[-1], a, Vt[-1], tau, compute_rew) 204 | V = Vt[-1] + delta_t_ * dV 205 | Vt.append(V) 206 | ds = self.ds_dt(f, st[-1], a) 207 | s = st[-1] + delta_t_ * ds 208 | st.append(s) 209 | self._f.shuffle() # pyright: ignore 210 | st, Vt = torch.stack(st)[:H].permute(1, 2, 0, 3), torch.stack(Vt)[:H].permute(1, 2, 0, 3).squeeze(-1) 211 | # Nens,PN,T,n & Nens,PN,T & Nens,PN,T-1,m 212 | st = st.reshape([L, self.P, N, H, n]).view(L * self.P, N, H, n) 213 | Vt = Vt.reshape([L, self.P, N, H]).view(L * self.P, N, H) 214 | return st, Vt, at 215 | 216 | 217 | class DeepPILCO(Dynamics): # pylint: disable=abstract-method 218 | def __init__( 219 | self, 220 | env, 221 | dynamics, 222 | L, 223 | nl=2, 224 | nn=100, # pylint: disable=redefined-outer-name 225 | act="relu", 226 | dropout=0.0, 227 | ): 228 | super().__init__(env, "dbnn", L, nl=nl, nn=nn, act=act, dropout=dropout, bnn=True) 229 | 230 | def _forward_simulate(self, solver, ts, t0s, s0, f, g, L, tau=None, compute_rew=True): 231 | H = len(ts) - 1 232 | # N,n: 233 | [N, n] = s0.shape # pylint: disable=unused-variable 234 | s0 = torch.stack([s0] * L) # 
Nens,N,n 235 | V0 = torch.zeros([*s0.shape[:-1], 1], device=s0.device) # Nens,N,1 236 | st, Vt, at = [s0], [V0], {} 237 | delta_t = ts[1:] - ts[:-1] 238 | for t_, delta_t_ in zip(ts, delta_t): # 0 & 0.1 239 | a = g(st[-1], t_ + t0s) # L,N,m 240 | at[t_ + t0s] = a 241 | dV = self.dv_dt(t_, st[-1], a, Vt[-1], tau, compute_rew) 242 | V = Vt[-1] + delta_t_ * dV 243 | Vt.append(V) 244 | ds = self.ds_dt(f, st[-1], a) 245 | s = st[-1] + delta_t_ * ds 246 | mu, sig = s.mean(0), s.std(0) 247 | s_new = torch.randn_like(s) * sig + mu 248 | st.append(s_new) 249 | st, Vt = torch.stack(st)[:H].permute(1, 2, 0, 3), torch.stack(Vt)[:H].permute( 250 | 1, 2, 0, 3 251 | ) # Nens,N,T,n & Nens,N,T,1 252 | Vt = Vt.squeeze(-1) # Nens,N,T 253 | return st, Vt, at 254 | -------------------------------------------------------------------------------- /envs/oderl/ctrl/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from envs.oderl.utils import BNN 5 | 6 | tanh_ = torch.nn.Tanh() 7 | 8 | 9 | def final_activation(env, a): 10 | return tanh_(a) * env.act_rng 11 | 12 | 13 | class Policy(nn.Module): 14 | def __init__(self, env, nl=2, nn=100, act="relu"): # pylint: disable=redefined-outer-name 15 | super().__init__() 16 | self.env = env 17 | self.act = act 18 | self._g = BNN(env.n, env.m, n_hid_layers=nl, act=act, n_hidden=nn, dropout=0.0, bnn=False) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self, w=0.1): 22 | self._g.reset_parameters(w) 23 | 24 | def forward(self, s, t): 25 | a = self._g(s) 26 | return final_activation(self.env, a) 27 | -------------------------------------------------------------------------------- /envs/oderl/env_simulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import imageio 4 | import pyvirtualdisplay 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import ctrl.ctrl as base 9 | import envs 10 | 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 'cpu' 13 | torch.set_default_dtype(torch.float32) 14 | 15 | ################## environment and dataset ################## 16 | dt = 0.1 # mean time difference between observations 17 | noise = 0.0 # observation noise std 18 | ts_grid = "fixed" # the distribution for the observation time differences: ['fixed','uniform','exp'] 19 | # ENV_CLS = envs.CTCartpole # [CTPendulum, CTCartpole, CTAcrobot] 20 | # [CTPendulum, CTCartpole, CTAcrobot]: 21 | ENV_CLS = envs.CTAcrobot # pyright: ignore 22 | env = ENV_CLS( 23 | dt=dt, 24 | obs_trans=True, 25 | device=device, 26 | obs_noise=noise, 27 | ts_grid=ts_grid, 28 | solver="euler", 29 | ) 30 | # D = utils.collect_data(env, H=5.0, N=env.N0) 31 | 32 | 33 | ################## model ################## 34 | dynamics = "enode" # ensemble of neural ODEs 35 | # dynamics = 'benode' # batch ensemble of neural ODEs 36 | # dynamics = 'ibnode' # implicit BNN ODEs 37 | # dynamics = 'pets' # PETS 38 | # dynamics = 'deep_pilco' # deep PILCO 39 | n_ens = 5 # ensemble size 40 | nl_f = 3 # number of hidden layers in the differential function 41 | nn_f = 200 # number of hidden neurons in each hidden layer of f 42 | act_f = "elu" # activation of f (should be smooth) 43 | dropout_f = 0.05 # dropout parameter (needed only for deep pilco) 44 | learn_sigma = False # whether to learn the observation noise or keep it fixed 45 | nl_g = 2 # number of hidden layers in the policy function 46 | nn_g = 
200 # number of hidden neurons in each hidden layer of g 47 | act_g = "relu" # activation of g 48 | nl_V = 2 # number of hidden layers in the state-value function 49 | nn_V = 200 # number of hidden neurons in each hidden layer of V 50 | act_V = "tanh" # activation of V (should be smooth) 51 | 52 | ctrl = base.CTRL( 53 | env, 54 | dynamics, 55 | n_ens=n_ens, 56 | learn_sigma=learn_sigma, 57 | nl_f=nl_f, 58 | nn_f=nn_f, 59 | act_f=act_f, 60 | dropout_f=dropout_f, 61 | nl_g=nl_g, 62 | nn_g=nn_g, 63 | act_g=act_g, 64 | nl_V=nl_V, 65 | nn_V=nn_V, 66 | act_V=act_V, 67 | ).to(device) 68 | 69 | print("Env dt={:.3f}\nObservation noise={:.3f}\nTime increments={:s}".format(env.dt, env.obs_noise, str(env.ts_grid))) 70 | 71 | display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start() # pyright: ignore 72 | 73 | 74 | j = 0 75 | filename = f"./logs/videos/out_video_{j}.mp4" 76 | fps = int(1 / 0.05) 77 | state = env.reset() 78 | done = False 79 | 80 | 81 | def g(state, t): # pylint: disable=redefined-outer-name 82 | return torch.from_numpy(env.action_space.sample()) 83 | 84 | 85 | with imageio.get_writer(filename, fps=fps) as video: 86 | for iter_ in tqdm(range(100)): 87 | # action = env.action_space.sample() 88 | returns = env.integrate_system(2, g, s0=state, return_states=True) 89 | state = returns[-1][-1] 90 | env.set_state_(state) 91 | # state_next, reward, done, info = env.integrate_system() 92 | video.append_data(env.render(mode="rgb_array", last_act=returns[1][-1])) # pyright: ignore 93 | env.close() 94 | 95 | 96 | # ################## learning ################## 97 | # utils.plot_model(ctrl, D, L=30, H=2.0, rep_buf=10, fname=ctrl.name+'-train.png') 98 | # utils.plot_test( ctrl, D, L=30, H=2.5, N=5, fname=ctrl.name+'-test.png') 99 | 100 | # utils.train_loop(ctrl, D, ctrl.name, 50, L=30, H=2.0) 101 | 102 | # ctrl.save(D=D,fname=ctrl.name) # save the model & dataset 103 | # ctrl_,D_ = base.CTRL.load(env, f'{fname}') # load the model & dataset 104 | -------------------------------------------------------------------------------- /envs/oderl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_env import BaseEnv # noqa: F401 2 | from .ctacrobot import CTAcrobot # noqa: F401 3 | from .ctcartpole import CTCartpole # noqa: F401 4 | from .ctpendulum import CTPendulum # noqa: F401 5 | -------------------------------------------------------------------------------- /envs/oderl/envs/base_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | from gym import spaces 7 | from gym.utils import seeding 8 | from torchdiffeq import odeint 9 | 10 | from envs.oderl.utils.utils import numpy_to_torch 11 | 12 | 13 | class BaseEnv(gym.Env, metaclass=ABCMeta): 14 | @abstractmethod 15 | def __init__( 16 | self, 17 | dt, 18 | n, 19 | m, 20 | act_rng, 21 | obs_trans, 22 | name, 23 | state_actions_names, 24 | device, 25 | solver, 26 | obs_noise, 27 | ts_grid, 28 | ac_rew_const=0.01, 29 | vel_rew_const=0.01, 30 | n_steps=200, 31 | ): 32 | self.dt = dt 33 | self.n = n 34 | self.m = m 35 | self.act_rng = act_rng 36 | self.obs_trans = obs_trans 37 | self.name = name 38 | self.reward_range = [-ac_rew_const * act_rng**2, 1.0] # pyright: ignore 39 | self.state_actions_names = state_actions_names 40 | self.ac_rew_const = ac_rew_const 41 | self.vel_rew_const = vel_rew_const 42 | self.obs_noise = obs_noise 43 | self.ts_grid = 
ts_grid 44 | # derived 45 | self.viewer = None 46 | self.action_space = spaces.Box(low=-self.act_rng, high=self.act_rng, shape=(self.m,)) 47 | self.seed() 48 | self.ac_lb = numpy_to_torch(self.action_space.low, device=device) # pyright: ignore 49 | self.ac_ub = numpy_to_torch(self.action_space.high, device=device) # pyright: ignore 50 | self.set_solver(method=solver) 51 | 52 | self.n_steps = n_steps 53 | self.time_step = 0 54 | 55 | def set_solver(self, method="euler", rtol=1e-6, atol=1e-9, num_bins=None): 56 | if num_bins is None: 57 | if method == "euler": 58 | num_bins = 1 59 | elif method == "rk4": 60 | num_bins = 50 61 | else: 62 | num_bins = 1 63 | self.solver = { 64 | "method": method, 65 | "rtol": rtol, 66 | "atol": atol, 67 | "step_size": self.dt / num_bins, 68 | } 69 | 70 | def seed(self, seed=None): 71 | self.np_random, seed = seeding.np_random(seed) 72 | return [seed] 73 | 74 | @property 75 | def device(self): 76 | return self.ac_lb.device 77 | 78 | def close(self): 79 | if self.viewer: 80 | self.viewer.close() 81 | self.viewer = None 82 | 83 | def get_obs(self): 84 | if self.obs_trans: 85 | torch_state = torch.tensor(self.state).unsqueeze(0) # pyright: ignore # pylint: disable=no-member 86 | # return list(self.torch_transform_states(torch_state)[0].numpy()) 87 | return self.torch_transform_states(torch_state)[0].numpy() 88 | else: 89 | return self.state # pyright: ignore # pylint: disable=no-member 90 | 91 | def reward(self, obs, a): 92 | return self.np_obs_reward_fn(obs) + self.np_ac_reward_fn(a) # pyright: ignore # pylint: disable=no-member 93 | 94 | def diff_reward(self, s, a): 95 | if not isinstance(s, torch.Tensor) or not isinstance(a, torch.Tensor): 96 | raise NotImplementedError("Differentiable reward only accepts torch.Tensor inputs\n") 97 | return self.diff_obs_reward_(s) + self.diff_ac_reward_(a) 98 | 99 | def build_time_grid(self, T=None, only_one_step=True, device=None): 100 | if device is None: 101 | device = self.device 102 | if only_one_step: 103 | if self.ts_grid == "fixed": 104 | ts = torch.arange(2, device=device) * self.dt 105 | elif self.ts_grid == "uniform" or self.ts_grid == "random": 106 | ts = torch.cat( 107 | ( 108 | torch.tensor([0.0], device=device), 109 | (torch.rand(1, device=device) * 2 * self.dt), 110 | ) 111 | ) 112 | elif self.ts_grid == "exp": 113 | ts = torch.cat( 114 | ( 115 | torch.tensor([0.0], device=device), 116 | torch.distributions.exponential.Exponential(1 / self.dt) 117 | .sample([1]) # pyright: ignore 118 | .to(device), 119 | ) 120 | ) 121 | else: 122 | raise ValueError("Time grid parameter is wrong!") 123 | return ts 124 | else: 125 | if self.ts_grid == "fixed": 126 | ts = torch.arange(T, device=device) * self.dt # pyright: ignore 127 | elif self.ts_grid == "uniform" or self.ts_grid == "random": 128 | ts = (torch.rand(T, device=device) * 2 * self.dt).cumsum(0) # pyright: ignore 129 | elif self.ts_grid == "exp": 130 | ts = torch.distributions.exponential.Exponential(1 / self.dt).sample([T]) # pyright: ignore 131 | ts = ts.cumsum(0).to(device) 132 | else: 133 | raise ValueError("Time grid parameter is wrong!") 134 | return ts 135 | 136 | def integrate_system(self, T, g, s0=None, N=1, return_states=False): 137 | """Returns torch tensors 138 | states - [N,T,n] where s0=[N,n] 139 | actions - [N,T,m] 140 | rewards - [N,T] 141 | ts - [N,T] 142 | """ 143 | with torch.no_grad(): 144 | s0 = ( 145 | torch.stack([numpy_to_torch(self.reset()) for _ in range(N)]).to(self.device) 146 | if s0 is None 147 | else numpy_to_torch(s0) 148 | ) 149 | s0 
= self.obs2state(s0) 150 | ts = self.build_time_grid(T) 151 | 152 | def odefnc(t, s): 153 | a = g(self.torch_transform_states(s), t) # 1,m 154 | return self.torch_rhs(s, a) 155 | 156 | st = odeint( 157 | odefnc, 158 | s0, 159 | ts, 160 | rtol=self.solver["rtol"], 161 | atol=self.solver["atol"], 162 | method=self.solver["method"], 163 | ) 164 | at = torch.stack([g(self.torch_transform_states(s_), t_) for s_, t_ in zip(st, ts)]) 165 | rt = self.diff_reward(st, at) # T,N 166 | if len(rt.shape) > 1: 167 | st, at, rt = st.permute(1, 0, 2), at.permute(1, 0, 2), rt.T # pyright: ignore 168 | st_obs = self.torch_transform_states(st) 169 | st_obs += torch.randn_like(st_obs) * self.obs_noise # pyright: ignore 170 | returns = [st_obs, at, rt, torch.stack([ts] * st_obs.shape[0])] 171 | if return_states: 172 | returns.append(st) 173 | return returns 174 | 175 | def batch_integrate_system_double_time(self, is0s, actions, device=None): 176 | """Returns torch tensors 177 | states - [N,T,n] where s0=[N,n] 178 | actions - [N,T,m] 179 | rewards - [N,T] 180 | ts - [N,T] 181 | """ 182 | # from tqdm import tqdm 183 | if device is None: 184 | device = self.device 185 | # from time import time 186 | # t0 = time() 187 | with torch.no_grad(): 188 | # print(f'{time() - t0}') 189 | s0s = self.obs2state(is0s) 190 | ts = self.build_time_grid(device=device, only_one_step=False, T=3) 191 | sb_l = [] 192 | sn_l = [] 193 | # for a in tqdm(actions): 194 | for a in actions: 195 | ab = a.view(1, -1).repeat(s0s.shape[0], 1) 196 | 197 | def odefnc(t, s): 198 | return self.torch_rhs(s, ab) # pylint: disable=cell-var-from-loop 199 | 200 | st = odeint( 201 | odefnc, 202 | s0s, 203 | ts, 204 | rtol=self.solver["rtol"], 205 | atol=self.solver["atol"], 206 | method=self.solver["method"], 207 | ) 208 | sb_l.append(st[-2, :, :]) # pyright: ignore 209 | sn_l.append(st[-1, :, :]) # pyright: ignore 210 | sn = torch.stack(sn_l) 211 | sb = torch.stack(sb_l) 212 | # print(f'OUT {time() - t0}') 213 | sn = sn.view(-1, sn.shape[2]) 214 | sb = sb.view(-1, sb.shape[2]) 215 | s0s = self.torch_transform_states(s0s) 216 | sb = self.torch_transform_states(sb) 217 | # print(f'{time() - t0}') 218 | sn = self.torch_transform_states(sn) 219 | if len(actions.shape) == 1: 220 | a0 = actions.repeat_interleave(s0s.shape[0]).view(-1, 1) 221 | else: 222 | a0 = actions.repeat_interleave(s0s.shape[0]).view(-1, actions.shape[1]) 223 | # a0 = actions.view(1,-1).repeat(s0s.shape[0],1).view(-1,1) 224 | # print(f'{time() - t0}') 225 | s0s_out = s0s.unsqueeze(0).repeat(actions.shape[0], 1, 1).view(-1, s0s.shape[1]) 226 | if self.obs_noise != 0.0: 227 | sn += torch.randn_like(sn) * self.obs_noise 228 | # print(f' FIN {time() - t0}') 229 | return s0s_out, a0, sb, sn, ts[1] 230 | 231 | def batch_integrate_system(self, is0s, actions, device=None): 232 | """Returns torch tensors 233 | states - [N,T,n] where s0=[N,n] 234 | actions - [N,T,m] 235 | rewards - [N,T] 236 | ts - [N,T] 237 | """ 238 | # from tqdm import tqdm 239 | if device is None: 240 | device = self.device 241 | # from time import time 242 | # t0 = time() 243 | with torch.no_grad(): 244 | # print(f'{time() - t0}') 245 | s0s = self.obs2state(is0s) 246 | ts = self.build_time_grid(device=device) 247 | sn_l = [] 248 | # for a in tqdm(actions): 249 | for a in actions: 250 | ab = a.view(1, -1).repeat(s0s.shape[0], 1) 251 | 252 | def odefnc(t, s): 253 | return self.torch_rhs(s, ab) # pylint: disable=cell-var-from-loop 254 | 255 | st = odeint( 256 | odefnc, 257 | s0s, 258 | ts, 259 | rtol=self.solver["rtol"], 260 | 
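# solver options come from set_solver (defaults: method='euler',
# rtol=1e-6, atol=1e-9)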
atol=self.solver["atol"], 261 | method=self.solver["method"], 262 | ) 263 | sn_l.append(st[-1, :, :]) # pyright: ignore 264 | sn = torch.stack(sn_l) 265 | # print(f'OUT {time() - t0}') 266 | sn = sn.view(-1, sn.shape[2]) 267 | s0s = self.torch_transform_states(s0s) 268 | # print(f'{time() - t0}') 269 | sn = self.torch_transform_states(sn) 270 | if len(actions.shape) == 1: 271 | a0 = actions.repeat_interleave(s0s.shape[0]).view(-1, 1) 272 | else: 273 | a0 = actions.repeat_interleave(s0s.shape[0]).view(-1, actions.shape[1]) 274 | # a0 = actions.view(1,-1).repeat(s0s.shape[0],1).view(-1,1) 275 | # print(f'{time() - t0}') 276 | s0s_out = s0s.unsqueeze(0).repeat(actions.shape[0], 1, 1).view(-1, s0s.shape[1]) 277 | if self.obs_noise != 0.0: 278 | sn += torch.randn_like(sn) * self.obs_noise 279 | # print(f' FIN {time() - t0}') 280 | return s0s_out, a0, sn, ts[-1] 281 | 282 | def torch_transform_states(self, state): 283 | if self.obs_trans: 284 | raise NotImplementedError 285 | else: 286 | return state 287 | 288 | def obs2state(self, state): 289 | if self.obs_trans: 290 | raise NotImplementedError 291 | else: 292 | return state 293 | 294 | def np_terminating_reward(self, state): # [...,n] 295 | return np.zeros(state.shape[:-1]) * 0.0 296 | 297 | def trigonometric2angle(self, costheta, sintheta): 298 | C = (costheta**2 + sintheta**2).detach() 299 | costheta, sintheta = costheta / C, sintheta / C 300 | theta = torch.atan2(sintheta / C, costheta / C) 301 | return theta 302 | 303 | @abstractmethod 304 | def reset(self): 305 | raise NotImplementedError 306 | 307 | @abstractmethod 308 | def torch_rhs(self, state, action): 309 | raise NotImplementedError 310 | 311 | @abstractmethod 312 | def diff_obs_reward_(self, s): 313 | raise NotImplementedError 314 | 315 | @abstractmethod 316 | def diff_ac_reward_(self, a): 317 | raise NotImplementedError 318 | 319 | @abstractmethod 320 | def render(self, mode, **kwargs): # pylint: disable=signature-differs 321 | raise NotImplementedError 322 | -------------------------------------------------------------------------------- /envs/oderl/envs/ctacrobot.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | from numpy import cos, pi, sin 8 | 9 | from .base_env import BaseEnv 10 | 11 | 12 | def angle_normalize(x): 13 | return ((x + np.pi) % (2 * np.pi)) - np.pi # [-3, -1, 0, 1, 2, 3] --> [-1, -1, 0, -1, 0] 14 | 15 | 16 | class CTAcrobot(BaseEnv): 17 | """ 18 | Code modified from https://github.com/openai/gym/blob/master/gym/envs/classic_control/acrobot.py 19 | Acrobot is a 2-link pendulum with only the second joint actuated. 20 | Initially, both links point downwards. The goal is to swing the 21 | end-effector at a height at least the length of one link above the base. 22 | Both links can swing freely and can pass by each other, i.e., they don't 23 | collide when they have the same angle. 24 | **STATE:** 25 | The state consists of the sin() and cos() of the two rotational joint 26 | angles and the joint angular velocities : 27 | [cos(theta1) sin(theta1) cos(theta2) sin(theta2) thetaDot1 thetaDot2]. 28 | For the first link, an angle of 0 corresponds to the link pointing downwards. 29 | The angle of the second link is relative to the angle of the first link. 30 | An angle of 0 corresponds to having the same angle between the two links. 31 | A state of [1, 0, 1, 0, ..., ...] 
means that both links point downwards. 32 | **ACTIONS:** 33 | Restricted to a range [-4,4] 34 | .. note:: 35 | The dynamics equations were missing some terms in the NIPS paper which 36 | are present in the book. R. Sutton confirmed in personal correspondence 37 | that the experimental results shown in the paper and the book were 38 | generated with the equations shown in the book. 39 | However, there is the option to run the domain with the paper equations 40 | by setting book_or_nips = 'nips' 41 | **REFERENCE:** 42 | .. seealso:: 43 | R. Sutton: Generalization in Reinforcement Learning: 44 | Successful Examples Using Sparse Coarse Coding (NIPS 1996) 45 | .. seealso:: 46 | R. Sutton and A. G. Barto: 47 | Reinforcement learning: An introduction. 48 | Cambridge: MIT press, 1998. 49 | .. warning:: 50 | This version of the domain uses the Runge-Kutta method for integrating 51 | the system dynamics and is more realistic, but also considerably harder 52 | than the original version which employs Euler integration, 53 | see the AcrobotLegacy class. 54 | """ 55 | 56 | metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 15} # pyright: ignore 57 | 58 | LINK_LENGTH_1 = 1.0 # [m] 59 | LINK_LENGTH_2 = 1.0 # [m] 60 | LINK_MASS_1 = 1.0 #: [kg] mass of link 1 61 | LINK_MASS_2 = 1.0 #: [kg] mass of link 2 62 | LINK_COM_POS_1 = 0.5 #: [m] position of the center of mass of link 1 63 | LINK_COM_POS_2 = 0.5 #: [m] position of the center of mass of link 2 64 | LINK_MOI = 1.0 #: moments of inertia for both links 65 | 66 | def __init__( 67 | self, 68 | dt=0.1, 69 | device="cpu", 70 | obs_trans=True, 71 | obs_noise=0.0, 72 | ts_grid="fixed", 73 | solver="dopri8", 74 | fully_act=True, 75 | friction=False, 76 | ): 77 | self.fully_act = fully_act 78 | self.N0 = 7 79 | self.Nexpseq = 3 80 | name = "acrobot" 81 | if obs_trans: 82 | state_action_names = [ 83 | "cos_theta1", 84 | "sin_theta1", 85 | "cos_theta2", 86 | "sin_theta2", 87 | "velocity1", 88 | "velocity2", 89 | ] 90 | name += "-trig" 91 | else: 92 | state_action_names = ["theta1", "theta2", "velocity1", "velocity2"] 93 | if fully_act: 94 | state_action_names += ["action1", "action2"] 95 | # print('Running fully actuated Acrobot') 96 | else: 97 | state_action_names += ["action"] 98 | super().__init__( 99 | dt, 100 | 4 + 2 * obs_trans, 101 | 1 + fully_act, 102 | 5.0, 103 | obs_trans, 104 | name, 105 | state_action_names, 106 | device, 107 | solver, 108 | obs_noise, 109 | ts_grid, 110 | 1e-4, 111 | 1e-1, 112 | ) 113 | self.reset() 114 | 115 | #################### environment specific ################## 116 | def extract_velocity(self, state): 117 | return state[..., -2:] 118 | 119 | def extract_position(self, state): 120 | return state[..., :-2] 121 | 122 | def merge_velocity_acceleration(self, ds, dv): 123 | return torch.cat(ds, dv, -1) # pyright: ignore 124 | 125 | def torch_transform_states(self, state): 126 | """Input - [N,n] or [L,N,n]""" 127 | if self.obs_trans: 128 | state_ = state.detach().clone() 129 | theta1, theta2, vel1, vel2 = ( 130 | state_[..., 0:1], 131 | state_[..., 1:2], 132 | state_[..., 2:3], 133 | state_[..., 3:4], 134 | ) 135 | return torch.cat([theta1.cos(), theta1.sin(), theta2.cos(), theta2.sin(), vel1, vel2], -1) 136 | else: 137 | return state 138 | 139 | def set_state_(self, state): 140 | assert state.shape[-1] == 4, "Trigonometrically transformed states cannot be set!\n" 141 | self.state = copy.deepcopy(state) 142 | return self.get_obs() 143 | 144 | def df_du(self, state): 145 | raise NotImplementedError() 146 
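# Trig observation layout: [cos th1, sin th1, cos th2, sin th2, vel1, vel2];
# obs2state below inverts it, recovering th1 = atan2(sin th1, cos th1) (and th2 likewise)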
| 147 | #################### override ################## 148 | def reset(self): 149 | self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)) 150 | self.time_step = 0 151 | return self.get_obs() 152 | 153 | def obs2state(self, obs): # pylint: disable=arguments-renamed 154 | if obs.shape[-1] == 4: 155 | return obs 156 | cos_th1, sin_th1, cos_th2, sin_th2, vel1, vel2 = ( 157 | obs[..., 0], 158 | obs[..., 1], 159 | obs[..., 2], 160 | obs[..., 3], 161 | obs[..., 4], 162 | obs[..., 5], 163 | ) 164 | theta1 = self.trigonometric2angle(cos_th1, sin_th1) 165 | theta2 = self.trigonometric2angle(cos_th2, sin_th2) 166 | return torch.stack([theta1, theta2, vel1, vel2], -1) 167 | 168 | def torch_rhs(self, state, action): 169 | """Input 170 | state [N,n] 171 | action [N,m] 172 | """ 173 | sixD = state.shape[-1] == 6 174 | m1 = self.LINK_MASS_1 175 | m2 = self.LINK_MASS_2 176 | l1 = self.LINK_LENGTH_1 177 | lc1 = self.LINK_COM_POS_1 178 | lc2 = self.LINK_COM_POS_2 179 | I1 = self.LINK_MOI 180 | I2 = self.LINK_MOI 181 | g = 9.8 182 | if sixD: 183 | costtheta1 = state[..., 0] 184 | sintheta1 = state[..., 1] 185 | costtheta2 = state[..., 2] 186 | sintheta2 = state[..., 3] 187 | dtheta1 = state[..., 4] 188 | dtheta2 = state[..., 5] 189 | C1 = (costtheta1**2 + sintheta1**2).detach() 190 | costheta1, sintheta1 = costtheta1 / C1, sintheta1 / C1 191 | theta1 = torch.atan2(sintheta1 / C1, costheta1 / C1) 192 | C2 = (costtheta2**2 + sintheta2**2).detach() 193 | costheta2, sintheta2 = costtheta2 / C2, sintheta2 / C2 194 | theta2 = torch.atan2(sintheta2 / C2, costheta2 / C2) 195 | else: 196 | theta1, theta2, dtheta1, dtheta2 = ( 197 | state[..., 0], 198 | state[..., 1], 199 | state[..., 2], 200 | state[..., 3], 201 | ) 202 | d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 + 2 * l1 * lc2 * torch.cos(theta2)) + I1 + I2 203 | d2 = m2 * (lc2**2 + l1 * lc2 * torch.cos(theta2)) + I2 204 | phi2 = m2 * lc2 * g * torch.cos(theta1 + theta2 - pi / 2.0) 205 | phi1 = ( 206 | -m2 * l1 * lc2 * dtheta2**2 * torch.sin(theta2) 207 | - 2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * torch.sin(theta2) 208 | + (m1 * lc1 + m2 * l1) * g * torch.cos(theta1 - pi / 2) 209 | + phi2 210 | ) 211 | ddtheta2 = (action[..., 0] + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 * torch.sin(theta2) - phi2) / ( 212 | m2 * lc2**2 + I2 - d2**2 / d1 213 | ) 214 | if self.fully_act: 215 | ddtheta1 = -(action[..., 1] + d2 * ddtheta2 + phi1) / d1 216 | else: 217 | ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 218 | if sixD: 219 | return torch.stack( 220 | [ 221 | -sintheta1 * dtheta1 / C1, # pyright: ignore 222 | costheta1 * dtheta1 / C1, # pyright: ignore 223 | -sintheta2 * dtheta2 / C2, # pyright: ignore 224 | costheta2 * dtheta2 / C2, # pyright: ignore 225 | ddtheta1, 226 | ddtheta2, 227 | ], 228 | -1, 229 | ) 230 | else: 231 | return torch.stack([dtheta1, dtheta2, ddtheta1, ddtheta2], -1) 232 | 233 | def diff_obs_reward_(self, state, exp_reward=False): # pylint: disable=arguments-renamed 234 | if state.shape[-1] == 6: 235 | state = self.obs2state(state) 236 | th1, th2, vel1, vel2 = ( 237 | state[..., 0], 238 | state[..., 1], 239 | state[..., 2], 240 | state[..., 3], 241 | ) 242 | velocity_reward = -(vel1**2) - vel2**2 243 | p1 = [-self.LINK_LENGTH_1 * torch.cos(th1), self.LINK_LENGTH_1 * torch.sin(th1)] 244 | p2 = [ 245 | p1[0] - self.LINK_LENGTH_2 * torch.cos(th1 + th2), 246 | p1[1] + self.LINK_LENGTH_2 * torch.sin(th1 + th2), 247 | ] 248 | state_reward = -((p2[0] - self.LINK_LENGTH_1 - self.LINK_LENGTH_2) ** 2) - (p2[1]) ** 2 249 | if exp_reward: 250 | return 
(state_reward + self.vel_rew_const * velocity_reward).exp() 251 | else: 252 | return state_reward + self.vel_rew_const * velocity_reward 253 | 254 | def diff_ac_reward_(self, action): # pylint: disable=arguments-renamed 255 | return -self.ac_rew_const * torch.sum(action**2, -1) 256 | 257 | def render(self, *args, mode="human", **kwargs): 258 | from gym.envs.classic_control import rendering 259 | 260 | s = self.state 261 | if self.viewer is None: 262 | self.viewer = rendering.Viewer(512, 512) 263 | bound = self.LINK_LENGTH_1 + self.LINK_LENGTH_2 + 0.2 # 2.2 for default 264 | self.viewer.set_bounds(-bound, bound, -bound, bound) 265 | if s is None: 266 | return None 267 | p1 = [-self.LINK_LENGTH_1 * cos(s[0]), self.LINK_LENGTH_1 * sin(s[0])] 268 | p2 = [ 269 | p1[0] - self.LINK_LENGTH_2 * cos(s[0] + s[1]), 270 | p1[1] + self.LINK_LENGTH_2 * sin(s[0] + s[1]), 271 | ] 272 | # print(p1+p2) 273 | xys = np.array([[0, 0], p1, p2])[:, ::-1] 274 | thetas = [s[0] - pi / 2, s[0] + s[1] - pi / 2] 275 | link_lengths = [self.LINK_LENGTH_1, self.LINK_LENGTH_2] 276 | self.viewer.draw_line((-2.2, 1), (2.2, 1)) 277 | for (x, y), th, llen in zip(xys, thetas, link_lengths): 278 | l, r, t, b = 0, llen, 0.1, -0.1 279 | jtransform = rendering.Transform(rotation=th, translation=(x, y)) 280 | link = self.viewer.draw_polygon([(l, b), (l, t), (r, t), (r, b)]) 281 | link.add_attr(jtransform) 282 | link.set_color(0, 0.8, 0.8) 283 | circ = self.viewer.draw_circle(0.1) # pyright: ignore 284 | circ.set_color(0.8, 0.8, 0) 285 | circ.add_attr(jtransform) 286 | return self.viewer.render(return_rgb_array=mode == "rgb_array") 287 | -------------------------------------------------------------------------------- /envs/oderl/envs/ctpendulum.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import copy 4 | from os import path 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from .base_env import BaseEnv 10 | 11 | 12 | def angle_normalize(x): 13 | return ((x + np.pi) % (2 * np.pi)) - np.pi # [-3, -1, 0, 1, 2, 3] --> [-1, -1, 0, -1, 0] 14 | 15 | 16 | class CTPendulum(BaseEnv): # pylint: disable=abstract-method 17 | """The precise equation for reward: 18 | -(theta^2 + 0.1*theta_dt^2 + 0.001*action^2) 19 | Theta is normalized between -pi and pi. Therefore, the lowest reward is -(pi^2 + 0.1*8^2 + 0.001*2^2) = -16.2736044, 20 | and the highest reward is 0. In essence, the goal is to remain at zero angle (vertical), 21 | with the least rotational velocity, and the least effort. 
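Note (editor's clarification, from the code below): the differentiable reward actually implemented here is -l^2*((1 - cos(theta))^2 + sin(theta)^2) - vel_rew_const*theta_dot^2 (see diff_obs_reward_), plus a separate action cost of -ac_rew_const*action^2 (see diff_ac_reward_); the classic formula above is kept for reference.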
22 | """ 23 | 24 | metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30} # pyright: ignore 25 | 26 | def __init__( 27 | self, 28 | dt=0.1, 29 | device="cpu", 30 | obs_trans=True, 31 | obs_noise=0.0, 32 | ts_grid="fixed", 33 | solver="dopri8", 34 | friction=False, 35 | ): 36 | name = "pendulum" 37 | if obs_trans: 38 | state_action_names = ["cos_theta", "sin_theta", "velocity", "action"] 39 | name += "-trig" 40 | else: 41 | state_action_names = ["angle", "velocity", "action"] 42 | # self.reward_range = [-17.0, 0.0] # for visualization 43 | super().__init__( 44 | dt, 45 | 2 + obs_trans, 46 | 1, 47 | 2.0, 48 | obs_trans, 49 | name, 50 | state_action_names, 51 | device, 52 | solver, 53 | obs_noise, 54 | ts_grid, 55 | ) 56 | self.N0 = 3 57 | self.Nexpseq = 0 58 | self.g = 10.0 59 | self.mass = 1.0 60 | self.l = 1.0 # noqa: E741 61 | self.reset() 62 | 63 | #################### environment specific ################## 64 | def extract_velocity(self, state): 65 | return state[..., -1:] 66 | 67 | def extract_position(self, state): 68 | return state[..., :-1] 69 | 70 | def merge_velocity_acceleration(self, ds, dv): 71 | return torch.cat([ds, dv], -1) 72 | 73 | def torch_transform_states(self, state): 74 | """Input - [N,n] or [L,N,n]""" 75 | if self.obs_trans: 76 | theta, theta_dot = state[..., 0:1], state[..., 1:2] 77 | return torch.cat([theta.cos(), theta.sin(), theta_dot], -1) 78 | else: 79 | return state 80 | 81 | def set_state_(self, state): 82 | assert state.shape[-1] == 2, "Trigonometrically transformed states cannot be set!\n" 83 | self.state = copy.deepcopy(state) 84 | return self.get_obs() 85 | 86 | def df_du(self, state): 87 | theta, theta_dot = state[..., 0], state[..., 1] 88 | m, l = self.mass, self.l # noqa: E741 89 | return torch.stack([theta * 0.0, torch.ones_like(theta_dot) * 3.0 / (m * l**2)], -1) 90 | 91 | #################### override ################## 92 | def reset(self): 93 | # low, high = np.array([-np.pi, -3]), np.array([np.pi, 3]) 94 | rand_state = self.np_random.uniform(low=-0.1, high=0.1, size=(2,)) 95 | rand_state[0] += np.pi 96 | self.state = rand_state 97 | self.time_step = 0 98 | return self.get_obs() 99 | # # low, high = np.array([-0.75*np.pi, -1]), np.array([-0.5*np.pi, 1]) 100 | # self.state = self.np_random.uniform(low=low, high=high) 101 | # self.time_step = 0 102 | # return self.get_obs() 103 | 104 | def obs2state(self, obs): # pylint: disable=arguments-renamed 105 | if obs.shape[-1] == 2: 106 | return obs 107 | cos_th, sin_th, vel = obs[..., 0], obs[..., 1], obs[..., 2] 108 | theta = self.trigonometric2angle(cos_th, sin_th) 109 | return torch.stack([theta, vel], -1) 110 | 111 | def torch_rhs(self, state, action): 112 | """Input 113 | state [N,n] 114 | action [N,m] 115 | """ 116 | # assert state.shape[-1]==2, 'Trigonometrically transformed states do not define ODE rhs!\n' 117 | g, m, l = self.g, self.mass, self.l # noqa: E741 118 | if state.shape[-1] == 2: 119 | th, thdot = state[..., 0], state[..., 1] 120 | return torch.stack( 121 | [ 122 | thdot, 123 | (-3 * g / (2 * l) * torch.sin(th + np.pi) + 3.0 / (m * l**2) * action[..., 0]), 124 | ], 125 | -1, 126 | ) 127 | elif state.shape[-1] == 3: 128 | costh, sinth, thdot = state[..., 0], state[..., 1], state[..., 2] 129 | th = self.obs2state(state)[..., 0] 130 | return torch.stack( 131 | [ 132 | -sinth * thdot, 133 | costh * thdot, 134 | (-3 * g / (2 * l) * torch.sin(th + np.pi) + 3.0 / (m * l**2) * action[..., 0]), 135 | ], 136 | -1, 137 | ) 138 | 139 | def diff_obs_reward_(self, state, 
exp_reward=False): # pylint: disable=arguments-renamed 140 | if state.shape[-1] == 2: 141 | th, thdot = state[..., 0], state[..., 1] 142 | cos_th, sin_th = th.cos(), th.sin() 143 | else: 144 | cos_th, sin_th, thdot = state[..., 0], state[..., 1], state[..., 2] 145 | state_reward = -self.l**2 * ((1 - cos_th) ** 2 + sin_th**2) 146 | velocity_reward = -(thdot**2) 147 | # return state_reward.exp() + self.vel_rew_const*velocity_reward 148 | if exp_reward: 149 | return (state_reward + self.vel_rew_const * velocity_reward).exp() # works superb 150 | else: 151 | return state_reward + self.vel_rew_const * velocity_reward # works superb 152 | # return (state_reward + self.vel_rew_const*velocity_reward).exp() # works superb 153 | 154 | def diff_ac_reward_(self, action): # pylint: disable=arguments-renamed 155 | return -self.ac_rew_const * torch.sum(action**2, -1) 156 | 157 | def render(self, mode="human", **kwargs): 158 | if self.viewer is None: 159 | from gym.envs.classic_control import rendering 160 | 161 | self.viewer = rendering.Viewer(500, 500) 162 | self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2) 163 | rod = rendering.make_capsule(1, 0.2) 164 | rod.set_color(0.8, 0.3, 0.3) 165 | self.pole_transform = rendering.Transform() 166 | rod.add_attr(self.pole_transform) 167 | self.viewer.add_geom(rod) 168 | axle = rendering.make_circle(0.05) # pyright: ignore 169 | axle.set_color(0, 0, 0) 170 | self.viewer.add_geom(axle) 171 | fname = path.join(path.dirname(__file__), "assets/clockwise.png") 172 | self.img = rendering.Image(fname, 1.0, 1.0) 173 | self.imgtrans = rendering.Transform() 174 | self.img.add_attr(self.imgtrans) 175 | 176 | self.viewer.add_onetime(self.img) 177 | self.pole_transform.set_rotation(self.state[0] + np.pi / 2) 178 | try: 179 | last_act = kwargs["last_act"] 180 | self.imgtrans.scale = (-last_act / 2, np.abs(last_act) / 2) 181 | except Exception: # pylint: disable=broad-exception-caught 182 | pass 183 | return self.viewer.render(return_rgb_array=mode == "rgb_array") 184 | -------------------------------------------------------------------------------- /envs/oderl/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import ctrl.ctrl as base 6 | import envs 7 | from ctrl import utils 8 | 9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 'cpu' 11 | torch.set_default_dtype(torch.float32) 12 | 13 | ################## environment and dataset ################## 14 | dt = 0.1 # mean time difference between observations 15 | noise = 0.0 # observation noise std 16 | ts_grid = "fixed" # the distribution for the observation time differences: ['fixed','uniform','exp'] 17 | # [CTPendulum, CTCartpole, CTAcrobot]: 18 | ENV_CLS = envs.CTCartpole # pyright: ignore 19 | env = ENV_CLS( 20 | dt=dt, 21 | obs_trans=True, 22 | device=device, 23 | obs_noise=noise, 24 | ts_grid=ts_grid, 25 | solver="dopri5", 26 | ) 27 | D = utils.collect_data(env, H=5.0, N=env.N0) 28 | 29 | 30 | ################## model ################## 31 | dynamics = "enode" # ensemble of neural ODEs 32 | # dynamics = 'benode' # batch ensemble of neural ODEs 33 | # dynamics = 'ibnode' # implicit BNN ODEs 34 | # dynamics = 'pets' # PETS 35 | # dynamics = 'deep_pilco' # deep PILCO 36 | n_ens = 5 # ensemble size 37 | nl_f = 3 # number of hidden layers in the differential function 38 | nn_f = 200 # number of hidden neurons in each hidden layer of f 39 | act_f = "elu" # activation of f 
(should be smooth) 40 | dropout_f = 0.05 # dropout parameter (needed only for deep pilco) 41 | learn_sigma = False # whether to learn the observation noise or keep it fixed 42 | nl_g = 2 # number of hidden layers in the policy function 43 | nn_g = 200 # number of hidden neurons in each hidden layer of g 44 | act_g = "relu" # activation of g 45 | nl_V = 2 # number of hidden layers in the state-value function 46 | nn_V = 200 # number of hidden neurons in each hidden layer of V 47 | act_V = "tanh" # activation of V (should be smooth) 48 | 49 | ctrl = base.CTRL( 50 | env, 51 | dynamics, 52 | n_ens=n_ens, 53 | learn_sigma=learn_sigma, 54 | nl_f=nl_f, 55 | nn_f=nn_f, 56 | act_f=act_f, 57 | dropout_f=dropout_f, 58 | nl_g=nl_g, 59 | nn_g=nn_g, 60 | act_g=act_g, 61 | nl_V=nl_V, 62 | nn_V=nn_V, 63 | act_V=act_V, 64 | ).to(device) 65 | 66 | print("Env dt={:.3f}\nObservation noise={:.3f}\nTime increments={:s}".format(env.dt, env.obs_noise, str(env.ts_grid))) 67 | 68 | 69 | ################## learning ################## 70 | utils.plot_model(ctrl, D, L=30, H=2.0, rep_buf=10, fname=ctrl.name + "-train.png") 71 | utils.plot_test(ctrl, D, L=30, H=2.5, N=5, fname=ctrl.name + "-test.png") 72 | 73 | utils.train_loop(ctrl, D, ctrl.name, 50, L=30, H=2.0) 74 | 75 | # ctrl.save(D=D,fname=ctrl.name) # save the model & dataset 76 | # ctrl_,D_ = base.CTRL.load(env, f'{fname}') # load the model & dataset 77 | -------------------------------------------------------------------------------- /envs/oderl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .benn import BENN # noqa: F401 2 | from .bnn import BNN # noqa: F401 3 | from .dropout_bnn import DropoutBNN # noqa: F401 4 | from .enn import ENN, EPNN # noqa: F401 5 | from .ibnn import IBNN # noqa: F401 6 | from .utils import * # noqa: F401,F403 7 | -------------------------------------------------------------------------------- /envs/oderl/utils/benn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | 6 | from .utils import get_act 7 | 8 | 9 | class BENN(nn.Module): 10 | def __init__( 11 | self, 12 | n_ens: int, 13 | n_in: int, 14 | n_out: int, 15 | n_hid_layers: int = 2, 16 | n_hidden: int = 250, 17 | act: str = "relu", 18 | requires_grad=True, 19 | bias=True, 20 | layer_norm=False, 21 | skip_con=False, 22 | ): 23 | super().__init__() 24 | layers_dim = [n_in] + n_hid_layers * [n_hidden] + [n_out] 25 | self.n_ens = n_ens 26 | self.skip_con = skip_con 27 | self.act = act 28 | self.bias = bias 29 | self.acts = [] 30 | self.weights, self.biases = nn.ParameterList([]), nn.ParameterList([]) 31 | self.rs, self.ss = nn.ParameterList([]), nn.ParameterList([]) 32 | for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])): 33 | self.weights.append(Parameter(torch.Tensor(n_in, n_out), requires_grad=requires_grad)) 34 | self.biases.append(None if not bias else Parameter(torch.Tensor(1, n_out), requires_grad=requires_grad)) 35 | self.acts.append(get_act(act) if i < n_hid_layers else get_act("linear")) # no act. 
in final layer 36 | self.rs.append(Parameter(torch.Tensor(n_ens, 1, n_in), requires_grad=requires_grad)) # Nens,1,n 37 | self.ss.append(Parameter(torch.Tensor(n_ens, 1, n_out), requires_grad=requires_grad)) # Nens,1,n 38 | self.reset_parameters() 39 | 40 | def shuffle(self): 41 | rand_idx = torch.randperm(self.n_ens) 42 | for r, s in zip(self.rs, self.ss): 43 | r.data = r.data[rand_idx] 44 | s.data = s.data[rand_idx] 45 | 46 | @property 47 | def device(self): 48 | return self.weights[0].device 49 | 50 | def __transform_sig(self, sig): # pyright: ignore # pylint: disable=unused-private-member 51 | # return F.softplus(sig) 52 | return sig.exp() + 1e-6 53 | 54 | def reset_parameters(self, gain=1.0): 55 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)): # pylint: disable=unused-variable 56 | nn.init.xavier_uniform_(weight, gain) 57 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) 58 | bound = 1 / np.sqrt(fan_in) 59 | if self.bias: 60 | nn.init.uniform_(bias, -bound, bound) 61 | for r, s in zip(self.rs, self.ss): 62 | nn.init.normal_(r, 1.0, 0.25) 63 | nn.init.normal_(s, 1.0, 0.25) 64 | 65 | def draw_noise(self, **kwargs): 66 | return None 67 | 68 | def draw_f(self, L=1, noise_vec=None): 69 | """Draws L//n_ens samples from each ensemble component 70 | Assigns each x[i] to a different sample in a different component 71 | x - [L,N,n] 72 | output - [L,N,n] 73 | """ 74 | 75 | def f(x): 76 | for r, s, weight, bias, act in zip(self.rs, self.ss, self.weights, self.biases, self.acts): 77 | x_ = (x * r) @ weight + bias 78 | x_ = x_ + x if x.shape == x_.shape and self.skip_con else x_ 79 | x = act(x_ * s) 80 | return x 81 | 82 | return f 83 | 84 | def forward(self, x, L=1): 85 | return self.draw_f()(x) 86 | 87 | def kl(self): 88 | return torch.zeros(1).to(self.device) 89 | 90 | def __repr__(self): 91 | str_ = f"BENN - {self.n_ens} members\n" 92 | for i, (weight, act) in enumerate(zip(self.weights, self.acts)): 93 | str_ += "Layer-{:d}: ".format(i + 1) + "".join(str([*weight.shape][::-1])) + "\t" + str(act) + "\n" 94 | return str_ 95 | -------------------------------------------------------------------------------- /envs/oderl/utils/bnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal 6 | from torch.distributions import kl_divergence as kl 7 | from torch.nn.parameter import Parameter 8 | from torch.nn.utils.convert_parameters import parameters_to_vector 9 | 10 | from .utils import get_act 11 | 12 | 13 | class BNN(nn.Module): 14 | def __init__( 15 | self, 16 | n_in: int, 17 | n_out: int, 18 | n_hid_layers: int = 2, 19 | n_hidden: int = 100, 20 | act: str = "relu", 21 | dropout=0.0, 22 | requires_grad=True, 23 | logsig0=-3, 24 | bnn=True, 25 | layer_norm=False, 26 | batch_norm=False, 27 | bias=True, 28 | var_apr="mf", 29 | ): 30 | super().__init__() 31 | layers_dim = [n_in] + n_hid_layers * [n_hidden] + [n_out] 32 | assert not (layer_norm and batch_norm), "Either layer_norm or batch_norm should be True" 33 | self.weight_mus = nn.ParameterList([]) 34 | self.bias_mus = nn.ParameterList([]) 35 | self.norms = nn.ModuleList([]) 36 | self.dropout_rate = dropout 37 | self.dropout = nn.Dropout(dropout) 38 | self.acts = [] 39 | self.act = act 40 | self.bnn = bnn 41 | self.bias = bias 42 | self.var_apr = var_apr 43 | for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])): 44 | 
self.weight_mus.append(Parameter(torch.Tensor(n_in, n_out), requires_grad=requires_grad)) 45 | self.bias_mus.append(None if not bias else Parameter(torch.Tensor(1, n_out), requires_grad=requires_grad)) 46 | self.acts.append(get_act(act) if i < n_hid_layers else get_act("linear")) # no act. in final layer 47 | norm = nn.Identity() 48 | if i < n_hid_layers: 49 | if layer_norm: 50 | norm = nn.LayerNorm(n_out) 51 | elif batch_norm: 52 | norm = nn.BatchNorm1d(n_out) 53 | self.norms.append(norm) 54 | if bnn: 55 | self.weight_logsigs = nn.ParameterList([]) 56 | self.bias_logsigs = nn.ParameterList([]) 57 | self.logsig0 = logsig0 58 | for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])): 59 | self.weight_logsigs.append(Parameter(torch.Tensor(n_in, n_out), requires_grad=requires_grad)) 60 | self.bias_logsigs.append( 61 | None if not bias else Parameter(torch.Tensor(1, n_out), requires_grad=requires_grad) 62 | ) 63 | self.reset_parameters() 64 | 65 | @property 66 | def device(self): 67 | return self.weight_mus[0].device 68 | 69 | def __transform_sig(self, sig): 70 | return torch.log(1 + torch.exp(sig)) 71 | # return sig.exp() 72 | 73 | def reset_parameters(self, gain=1.0): 74 | for i, (weight, bias) in enumerate(zip(self.weight_mus, self.bias_mus)): # pylint: disable=unused-variable 75 | nn.init.xavier_uniform_(weight, gain) 76 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) 77 | bound = 1 / np.sqrt(fan_in) 78 | if self.bias: 79 | nn.init.uniform_(bias, -bound, bound) 80 | for norm in self.norms[:-1]: # pyright: ignore 81 | if isinstance(norm, nn.LayerNorm): 82 | norm.reset_parameters() 83 | if self.bnn: 84 | for w, b in zip(self.weight_logsigs, self.bias_logsigs): 85 | nn.init.uniform_(w, self.logsig0 - 1, self.logsig0 + 1) 86 | if self.bias: 87 | nn.init.uniform_(b, self.logsig0 - 1, self.logsig0 + 1) 88 | 89 | def draw_noise(self, L): 90 | P = parameters_to_vector(self.parameters()).numel() // 2 # single noise term needed per (mean,var) pair 91 | noise = torch.randn([L, P], device=self.weight_mus[0].device) 92 | if self.var_apr == "mf": 93 | return noise 94 | elif self.var_apr == "radial": 95 | noise /= noise.norm(dim=1, keepdim=True) 96 | r = torch.randn([L, 1], device=self.weight_mus[0].device) 97 | return noise * r 98 | 99 | def __sample_weights(self, L, noise_vec=None): 100 | if self.bnn: 101 | if noise_vec is None: 102 | noise_vec = self.draw_noise(L) # L,P 103 | weights = [] 104 | i = 0 105 | for weight_mu, weight_sig in zip(self.weight_mus, self.weight_logsigs): 106 | p = weight_mu.numel() 107 | weights.append( 108 | weight_mu 109 | + noise_vec[:, i : i + p].view(L, *weight_mu.shape) # pyright: ignore 110 | * self.__transform_sig(weight_sig) 111 | ) 112 | i += p 113 | if self.bias: 114 | biases = [] 115 | for bias_mu, bias_sig in zip(self.bias_mus, self.bias_logsigs): 116 | p = bias_mu.numel() 117 | biases.append( 118 | bias_mu 119 | + noise_vec[:, i : i + p].view(L, *bias_mu.shape) # pyright: ignore 120 | * self.__transform_sig(bias_sig) 121 | ) 122 | i += p 123 | else: 124 | biases = [ 125 | torch.zeros([L, 1, weight_mu.shape[1]], device=weight_mu.device) * 1.0 126 | for weight_mu, bias_mu in zip(self.weight_mus, self.bias_mus) 127 | ] # list of zeros 128 | else: 129 | raise ValueError("This is a NN, not a BNN!") 130 | return weights, biases 131 | 132 | def draw_f(self, L=1, noise_vec=None): 133 | """ 134 | x=[N,n] & bnn=False ---> out=[N,n] 135 | x=[N,n] & L=1 ---> out=[N,n] 136 | x=[N,n] & L>1 ---> out=[L,N,n] 137 | x=[L,N,n] -------> out=[L,N,n] 138 | 
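Passing a noise_vec previously returned by draw_noise(L) reuses one fixed set of sampled weights across calls; otherwise fresh weights are sampled on every call.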
""" 139 | if not self.bnn: 140 | 141 | def f(x): # pyright: ignore 142 | for weight, bias, act, norm in zip(self.weight_mus, self.bias_mus, self.acts, self.norms): 143 | x = act(norm(self.dropout(F.linear(x, weight.T, bias)))) 144 | return x 145 | 146 | return f 147 | else: 148 | weights, biases = self.__sample_weights(L, noise_vec) 149 | 150 | def f(x): # pyright: ignore 151 | x2d = x.ndim == 2 152 | if x2d: 153 | x = torch.stack([x] * L) # [L,N,n] 154 | for weight, bias, act, norm in zip(weights, biases, self.acts, self.norms): 155 | x = act(norm(self.dropout(torch.baddbmm(bias, x, weight)))) 156 | return x.squeeze(0) if x2d and L == 1 else x 157 | 158 | return f 159 | 160 | def forward(self, x, L=1): 161 | return self.draw_f(L)(x) 162 | 163 | def kl(self, L=100): 164 | if not self.bnn: 165 | return torch.zeros([1], device=self.device) * 1.0 166 | if self.var_apr == "mf": 167 | mus = [weight_mu.view([-1]) for weight_mu in self.weight_mus] 168 | logsigs = [weight_logsig.view([-1]) for weight_logsig in self.weight_logsigs] 169 | if self.bias: 170 | mus += [bias_mu.view([-1]) for bias_mu in self.bias_mus] 171 | logsigs += [bias_logsigs.view([-1]) for bias_logsigs in self.bias_logsigs] 172 | mus = torch.cat(mus) 173 | sigs = self.__transform_sig(torch.cat(logsigs)) 174 | q = Normal(mus, sigs) 175 | N = Normal(torch.zeros_like(mus), torch.ones_like(mus)) 176 | return kl(q, N) 177 | elif self.var_apr == "radial": 178 | weights, biases = self.__sample_weights(L) 179 | weights = torch.cat([w.view(L, -1) for w in weights], 1) 180 | sigs = torch.cat([weight_sig.view([-1]) for weight_sig in self.weight_logsigs]) 181 | if self.bias: 182 | biases = torch.cat([b.view(L, -1) for b in biases], 1) 183 | weights = torch.cat([weights, biases], 1) 184 | bias_sigs = torch.cat([bias_sig.view([-1]) for bias_sig in self.bias_logsigs]) 185 | sigs = torch.cat([sigs, bias_sigs]) 186 | cross_entr = -(weights**2).mean(0) / 2 - np.log(2 * np.pi) 187 | entr = -self.__transform_sig(sigs).log() 188 | return entr - cross_entr 189 | 190 | def __repr__(self): 191 | str_ = "BNN\n" if self.bnn else "NN\n" 192 | for i, (weight, act) in enumerate(zip(self.weight_mus, self.acts)): 193 | str_ += "Layer-{:d}: ".format(i + 1) + "".join(str([*weight.shape][::-1])) + "\t" + str(act) + "\n" 194 | return str_ 195 | -------------------------------------------------------------------------------- /envs/oderl/utils/dropout_bnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.distributions import Bernoulli 5 | from torch.nn.parameter import Parameter 6 | 7 | from .bnn import get_act 8 | 9 | 10 | class DropoutBNN(nn.Module): 11 | def __init__( 12 | self, 13 | n_in: int, 14 | n_out: int, 15 | n_hid_layers: int = 2, 16 | act: str = "relu", 17 | dropout_rate=0.0, 18 | n_hidden: int = 100, 19 | bias=True, 20 | requires_grad=True, 21 | layer_norm=False, 22 | ): 23 | super().__init__() 24 | self.layers_dim = [n_in] + n_hid_layers * [n_hidden] + [n_out] 25 | self.weights = nn.ParameterList([]) 26 | self.biases = nn.ParameterList([]) 27 | self.layer_norms = nn.ModuleList([]) 28 | self.dropout_rate = dropout_rate 29 | self.acts = [] 30 | self.bias = bias 31 | self.act = act 32 | for i, (n_in, n_out) in enumerate(zip(self.layers_dim[:-1], self.layers_dim[1:])): 33 | self.weights.append(Parameter(torch.Tensor(n_in, n_out), requires_grad=requires_grad)) 34 | self.biases.append(None if not bias else Parameter(torch.Tensor(n_out), 
requires_grad=requires_grad)) 35 | self.acts.append(get_act(act) if i < n_hid_layers else get_act("linear")) # no act. in final layer 36 | self.layer_norms.append(nn.LayerNorm(n_out) if layer_norm and i < n_hid_layers else nn.Identity()) 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self, gain=1.0): 40 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)): # pylint: disable=unused-variable 41 | nn.init.xavier_uniform_(weight, gain) 42 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) 43 | bound = 1 / np.sqrt(fan_in) 44 | nn.init.uniform_(bias, -bound, bound) 45 | for norm in self.layer_norms[:-1]: # pyright: ignore 46 | if isinstance(norm, nn.LayerNorm): 47 | norm.reset_parameters() 48 | 49 | def sample_weights(self): 50 | pass 51 | 52 | @property 53 | def device(self): 54 | return self.weights[0].device 55 | 56 | def draw_noise(self, L=1): 57 | dropout_masks = [] 58 | dropout_rate = self.dropout_rate 59 | b = Bernoulli(1 - dropout_rate) 60 | for h in self.layers_dim[1:-1]: 61 | dropout_masks.append(b.sample([L, 1, h]).to(self.device)) # pyright: ignore 62 | dropout_masks.append(torch.ones([L, 1, self.layers_dim[-1]], device=self.device)) 63 | return dropout_masks 64 | 65 | def draw_f(self, L=1, noise_vec=None): 66 | dropout_masks = self.draw_noise(L) if noise_vec is None else noise_vec # list of [L,1,h] 67 | 68 | def f(x): 69 | x2d = x.ndim == 2 70 | if x2d: 71 | x = torch.stack([x] * L) # [L,N,n] 72 | for weight, bias, dropout_mask, act, norm in zip( 73 | self.weights, self.biases, dropout_masks, self.acts, self.layer_norms 74 | ): 75 | x = act(norm(dropout_mask * (x @ weight + bias))) 76 | return x.squeeze(0) if x2d and L == 1 else x 77 | 78 | return f 79 | 80 | def forward(self, x, L=1): 81 | return self.draw_f(L, None)(x) 82 | 83 | def __repr__(self): 84 | str_ = "DBBB\\dropout rate = {:.2f}\n".format(self.dropout_rate) 85 | for i, (weight, act) in enumerate(zip(self.weights, self.acts)): 86 | str_ += "Layer-{:d}: ".format(i + 1) + "".join(str([*weight.shape][::-1])) + "\t" + str(act) + "\n" 87 | return str_ 88 | -------------------------------------------------------------------------------- /envs/oderl/utils/enn.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parameter import Parameter 7 | 8 | from .utils import get_act 9 | 10 | 11 | class ENN_BASE(nn.Module, metaclass=ABCMeta): 12 | @abstractmethod 13 | def __init__( 14 | self, 15 | n_ens, 16 | layers_ins, 17 | layers_outs, 18 | n_hid_layers=2, 19 | act="celu", 20 | dropout=0.0, 21 | skip_con=False, 22 | n_hidden=100, 23 | requires_grad=True, 24 | logsig0=-3, 25 | layer_norm=False, 26 | ): 27 | super().__init__() 28 | self.n_ens = n_ens 29 | self.weights = nn.ParameterList([]) 30 | self.biases = nn.ParameterList([]) 31 | self.layer_norms = nn.ModuleList([]) 32 | self.dropout_rate = dropout 33 | self.skip_con = skip_con 34 | self.dropout = nn.Dropout(dropout) 35 | self.acts = [] 36 | for i, (n_in, n_out) in enumerate(zip(layers_ins, layers_outs)): 37 | self.weights.append(Parameter(torch.Tensor(n_ens, n_in, n_out), requires_grad=requires_grad)) 38 | self.biases.append(Parameter(torch.Tensor(n_ens, 1, n_out), requires_grad=requires_grad)) 39 | self.acts.append(get_act(act) if i < n_hid_layers else get_act("linear")) # no act. 
in final layer 40 | self.layer_norms.append( 41 | nn.LayerNorm(n_out, elementwise_affine=False) if layer_norm and i < n_hid_layers else nn.Identity() 42 | ) 43 | self.reset_parameters() 44 | 45 | @property 46 | def device(self): 47 | return self.weights[0].device 48 | 49 | def reset_parameters(self, gain=1.0): 50 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)): # pylint: disable=unused-variable 51 | for w, b in zip(weight, bias): 52 | nn.init.xavier_uniform_(w, gain) 53 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) 54 | bound = 1 / np.sqrt(fan_in) 55 | nn.init.uniform_(b, -bound, bound) 56 | for norm in self.layer_norms[:-1]: # pyright: ignore 57 | if isinstance(norm, nn.LayerNorm): 58 | norm.reset_parameters() 59 | 60 | def kl(self): 61 | return torch.zeros(1).to(self.device) 62 | 63 | def set_num_particles(self, N): 64 | self.n_ens = N 65 | weights_new = [Parameter(weight[:N]) for weight in self.weights] 66 | biases_new = [Parameter(bias[:N]) for bias in self.biases] 67 | del self.weights 68 | del self.biases 69 | self.weights = nn.ParameterList(weights_new) 70 | self.biases = nn.ParameterList(biases_new) 71 | 72 | def shuffle(self): 73 | rand_idx = torch.randperm(self.n_ens) 74 | for w, b in zip(self.weights, self.biases): 75 | w.data = w.data[rand_idx] 76 | b.data = b.data[rand_idx] 77 | 78 | def name(self): 79 | str_ = "" 80 | for i, (weight, act) in enumerate(zip(self.weights, self.acts)): 81 | str_ += "Layer-{:d}: ".format(i + 1) + "".join(str([*weight.shape][::-1])) + "\t" + str(act) + "\n" 82 | return str_ 83 | 84 | def draw_noise(self, **kwargs): 85 | return None 86 | 87 | def forward(self, x): 88 | return self.draw_f()(x) 89 | 90 | @abstractmethod 91 | def draw_f(self): 92 | raise NotImplementedError 93 | 94 | 95 | class ENN(ENN_BASE): 96 | def __init__( 97 | self, 98 | n_ens, 99 | n_in, 100 | n_out, 101 | n_hid_layers=2, 102 | act="relu", 103 | dropout=0.0, 104 | skip_con=False, 105 | n_hidden=100, 106 | requires_grad=True, 107 | logsig0=-3, 108 | layer_norm=False, 109 | ): 110 | layers_ins = [n_in] + n_hid_layers * [n_hidden] 111 | layers_outs = n_hid_layers * [n_hidden] + [n_out] 112 | super().__init__( 113 | n_ens, 114 | layers_ins, 115 | layers_outs, 116 | n_hid_layers=n_hid_layers, 117 | skip_con=skip_con, 118 | act=act, 119 | dropout=dropout, 120 | n_hidden=n_hidden, 121 | requires_grad=requires_grad, 122 | logsig0=logsig0, 123 | layer_norm=layer_norm, 124 | ) 125 | 126 | def draw_f(self, **kwargs): 127 | """Returns 2D if input is 2D""" 128 | 129 | def f(x): # input/output is [Nens,N,nin] or [N,nin] 130 | x2d = x.ndim == 2 131 | x = torch.stack([x] * self.n_ens) if x2d else x 132 | for W, b, act, norm in zip(self.weights, self.biases, self.acts, self.layer_norms): 133 | x_ = self.dropout(torch.baddbmm(b, x, W)) 134 | x_ = x_ + x if x.shape == x_.shape and self.skip_con else x_ 135 | x = norm(act(x_)) # Nens,1,nout & Nens,N,nin & Nens,nin,nout 136 | return x.mean(0) if x2d else x 137 | 138 | return f 139 | 140 | def __repr__(self): 141 | super_name = super().name() 142 | return f"ENN - {self.n_ens} members\n" + super_name 143 | 144 | 145 | class EPNN(ENN): 146 | def __init__( 147 | self, 148 | n_ens, 149 | n_in, 150 | n_out, 151 | n_hid_layers=2, 152 | act="relu", 153 | dropout=0.0, 154 | skip_con=False, 155 | n_hidden=100, 156 | requires_grad=True, 157 | logsig0=-3, 158 | layer_norm=False, 159 | ): 160 | super().__init__( 161 | n_ens, 162 | n_in, 163 | 2 * n_out, 164 | n_hid_layers=n_hid_layers, 165 | act=act, 166 | dropout=dropout, 167 | 
skip_con=skip_con, 168 | n_hidden=n_hidden, 169 | requires_grad=requires_grad, 170 | logsig0=logsig0, 171 | layer_norm=layer_norm, 172 | ) 173 | self.n_out = n_out 174 | self.sp = nn.Softplus() 175 | self.max_logsig = nn.Parameter(torch.ones([n_out]), requires_grad=requires_grad) 176 | self.min_logsig = nn.Parameter(-2 * torch.ones([n_out]), requires_grad=requires_grad) 177 | 178 | def get_probs(self, x): 179 | x2d = x.ndim == 2 180 | x = torch.stack([x] * self.n_ens) if x2d else x 181 | for W, b, act, norm in zip(self.weights, self.biases, self.acts, self.layer_norms): 182 | x_ = self.dropout(torch.baddbmm(b, x, W)) 183 | x_ = x_ + x if x.shape == x_.shape and self.skip_con else x_ 184 | x = norm(act(x_)) # Nens,1,2nout & Nens,N,nin & Nens,nin,2nout 185 | x = x.mean(0) if x2d else x # ...,2nout 186 | mean, logvar = x[..., : self.n_out], x[..., self.n_out :] 187 | logvar = self.max_logsig - self.sp(self.max_logsig - logvar) 188 | logvar = self.min_logsig + self.sp(logvar - self.min_logsig) 189 | return mean, logvar.exp() 190 | 191 | def draw_f(self, **kwargs): 192 | """Returns 2D if input is 2D""" 193 | 194 | def f(x): # input/output is [Nens,N,nin] or [N,nin] 195 | mean, sig = self.get_probs(x) 196 | return mean + torch.randn_like(sig) * sig 197 | 198 | return f 199 | 200 | def __repr__(self): 201 | super_name = super().name() 202 | return f"EPNN - {self.n_ens} members\n" + super_name 203 | -------------------------------------------------------------------------------- /envs/oderl/utils/ibnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.distributions import Normal 5 | from torch.distributions import kl_divergence as kl 6 | from torch.nn.parameter import Parameter 7 | 8 | from .utils import get_act 9 | 10 | 11 | class IBNN(nn.Module): 12 | def __init__( 13 | self, 14 | n_ens: int, 15 | n_in: int, 16 | n_out: int, 17 | n_hid_layers: int = 2, 18 | n_hidden: int = 250, 19 | act: str = "relu", 20 | requires_grad=True, 21 | bias=True, 22 | layer_norm=False, 23 | dropout=0.0, 24 | bnn=True, 25 | skip_con=False, 26 | ): 27 | super().__init__() 28 | print("IBNN: layer_norm, dropout, bnn parameters are discarded") 29 | layers_dim = [n_in] + n_hid_layers * [n_hidden] + [n_out] 30 | self.weights = nn.ParameterList([]) 31 | self.biases = nn.ParameterList([]) 32 | self.acts = [] 33 | self.n_ens = n_ens 34 | self.skip_con = skip_con 35 | self.act = act 36 | self.bias = bias 37 | for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])): 38 | self.weights.append(Parameter(torch.Tensor(n_in, n_out), requires_grad=requires_grad)) 39 | self.biases.append(None if not bias else Parameter(torch.Tensor(1, n_out), requires_grad=requires_grad)) 40 | self.acts.append(get_act(act) if i < n_hid_layers else get_act("linear")) # no act. 
in final layer 41 | self.z_mus = nn.ParameterList([]) 42 | self.z_logsigs = nn.ParameterList([]) 43 | for i, n_node in enumerate(layers_dim[:-1]): 44 | self.z_mus.append(Parameter(torch.Tensor(n_ens, 1, n_node), requires_grad=requires_grad)) # Nens,1,n 45 | self.z_logsigs.append(Parameter(torch.Tensor(n_ens, 1, n_node), requires_grad=requires_grad)) 46 | self.reset_parameters() 47 | 48 | def shuffle(self): 49 | rand_idx = torch.randperm(self.n_ens) 50 | for mu, logsig in zip(self.z_mus, self.z_logsigs): 51 | mu.data = mu.data[rand_idx] 52 | logsig.data = logsig.data[rand_idx] 53 | 54 | @property 55 | def device(self): 56 | return self.weights[0].device 57 | 58 | def __transform_sig(self, sig): 59 | # return F.softplus(sig) 60 | return sig.exp() + 1e-6 61 | 62 | def reset_parameters(self, gain=1.0): 63 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)): # pylint: disable=unused-variable 64 | nn.init.xavier_uniform_(weight, gain) 65 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) 66 | bound = 1 / np.sqrt(fan_in) 67 | if self.bias: 68 | nn.init.uniform_(bias, -bound, bound) 69 | for z_mu, z_logsig in zip(self.z_mus, self.z_logsigs): 70 | nn.init.normal_(z_mu, 1.0, 0.25) 71 | # nn.init.normal_(z_logsig, 0.05, 0.02) 72 | nn.init.normal_(z_logsig, -2, 0.01) 73 | 74 | def draw_noise(self, L): 75 | assert L % self.n_ens == 0, f"L={L} must be a multiple of n_ens={self.n_ens}" 76 | return [torch.randn([L, 1, z_mu.shape[-1]], device=self.device) for z_mu in self.z_mus] # L,1,N 77 | 78 | def __draw_multiplicative_factors(self, noise_vec): 79 | zs = [] 80 | for i, noise in enumerate(noise_vec): # for each layer 81 | noise = noise.view([-1, *self.z_mus[i].shape]) # L/Nens,Nens,1,n 82 | sig = self.__transform_sig(self.z_logsigs[i]) 83 | z = self.z_mus[i] + noise * sig # L/Nens,Nens,1,n 84 | zs.append(z.reshape(-1, 1, self.z_mus[i].shape[-1])) # L,1,n 85 | return zs # list of L,1,n 86 | 87 | def draw_f(self, L=1, noise_vec=None): 88 | """Draws L//n_ens samples from each ensemble component 89 | Assigns each x[i] to a different sample in a different component 90 | x - [N,n] or [L,N,n] 91 | output - the same shape as input 92 | """ 93 | # assert L//self.n_ens, f'L={L} must be a multiple of n_ens={self.n_ens}' 94 | noise_vec = noise_vec if noise_vec is not None else self.draw_noise(L) 95 | zs = self.__draw_multiplicative_factors(noise_vec) # list of [L,1,n_hidden] 96 | 97 | def f(x): 98 | x2d = x.ndim == 2 99 | x = torch.stack([x] * L) if x2d else x # L,N,n 100 | for z, weight, bias, act in zip(zs, self.weights, self.biases, self.acts): 101 | x_ = (x * z) @ weight + bias 102 | x_ = x_ + x if x.shape == x_.shape and self.skip_con else x_ 103 | x = act(x_) 104 | return x.mean(0) if x2d else x 105 | 106 | return f 107 | 108 | def forward(self, x, L=1): 109 | return self.draw_f(L)(x) 110 | 111 | def kl(self): 112 | kls = [] 113 | for mu, logsig in zip(self.z_mus, self.z_logsigs): 114 | mu_ = mu.mean([0])[0] # n 115 | sig_ = self.__transform_sig(logsig).pow(2).mean(0)[0].pow(0.5) # n 116 | qhat = Normal(mu_, sig_) 117 | p = Normal(torch.ones_like(mu_), torch.ones_like(sig_)) 118 | kl_ = kl(qhat, p).sum() 119 | kls.append(kl_) 120 | return torch.stack(kls).sum() 121 | 122 | def __repr__(self): 123 | str_ = f"iBNN - {self.n_ens} components\n" 124 | for i, (weight, act) in enumerate(zip(self.weights, self.acts)): 125 | str_ += "Layer-{:d}: ".format(i + 1) + "".join(str([*weight.shape][::-1])) + "\t" + str(act) + "\n" 126 | return str_ 127 |
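# --- Editor's usage sketch (not part of the original file): BNN, ENN/EPNN and IBNN
# above share the same draw_f/forward/kl interface, so the dynamics code can swap
# them freely. The import path and shapes are assumptions based on this package's
# __init__.py.
if __name__ == "__main__":
    import torch

    from envs.oderl.utils import BNN, IBNN

    bnn = BNN(n_in=4, n_out=4, n_hid_layers=2, n_hidden=100, act="relu")
    ibnn = IBNN(n_ens=5, n_in=4, n_out=4)
    x = torch.randn(8, 4)  # [N, n] batch of states
    y = bnn(x, L=10)  # [10, 8, 4]: one output per sampled weight set
    f = ibnn.draw_f(L=5)  # freeze 5 multiplicative-factor draws (L % n_ens == 0)
    y2 = f(x)  # [8, 4]: mean over the ensemble draws
    reg = bnn.kl().sum() + ibnn.kl()  # variational penalties added to the training loss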
-------------------------------------------------------------------------------- /envs/oderl/utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from TorchDiffEqPack.odesolver import odesolve as torchdiffeqpack_odesolve 7 | 8 | 9 | def get_act(act="relu"): 10 | if act == "relu": 11 | return nn.ReLU() 12 | elif act == "elu": 13 | return nn.ELU() 14 | elif act == "celu": 15 | return nn.CELU() 16 | elif act == "leaky_relu": 17 | return nn.LeakyReLU() 18 | elif act == "sigmoid": 19 | return nn.Sigmoid() 20 | elif act == "tanh": 21 | return nn.Tanh() 22 | elif act == "sin": 23 | return torch.sin 24 | elif act == "linear": 25 | return nn.Identity() 26 | elif act == "softplus": 27 | return nn.modules.activation.Softplus() 28 | elif act == "swish": 29 | return lambda x: x * torch.sigmoid(x) 30 | else: 31 | return None 32 | 33 | 34 | def sq_dist(X1, X2, ell=1.0): 35 | X1 = X1 / ell 36 | X1s = torch.sum(X1**2, dim=-1).view([-1, 1]) 37 | X2 = X2 / ell 38 | X2s = torch.sum(X2**2, dim=-1).view([1, -1]) 39 | sq_dist = -2 * torch.mm(X1, X2.t()) + X1s + X2s # pylint: disable=redefined-outer-name 40 | return sq_dist 41 | 42 | 43 | def sq_dist3(X1, X2, ell=1.0): 44 | N = X1.shape[0] 45 | X1 = X1 / ell 46 | X1s = torch.sum(X1**2, dim=-1).view([N, -1, 1]) 47 | X2 = X2 / ell 48 | X2s = torch.sum(X2**2, dim=-1).view([N, 1, -1]) 49 | sq_dist = -2 * X1 @ X2.transpose(-1, -2) + X1s + X2s # pylint: disable=redefined-outer-name 50 | return sq_dist 51 | 52 | 53 | def batch_sq_dist(x, y, ell=1.0): 54 | """ 55 | Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3 56 | Input: x is a bxNxd matrix y is an optional bxMxd matrix 57 | Output: dist is a bxNxM matrix where dist[b,i,j] is the square norm between x[b,i,:] and y[b,j,:] 58 | i.e. 
dist[i,j] = ||x[b,i,:]-y[b,j,:]||^2 59 | """ 60 | assert x.ndim == 3, f"Input1 must be 3D, not {x.shape}" 61 | y = y if y.ndim == 3 else torch.stack([y] * x.shape[0]) 62 | assert y.ndim == 3, f"Input2 must be 3D, not {y.shape}" 63 | x, y = x / ell, y / ell 64 | x_norm = (x**2).sum(2).view(x.shape[0], x.shape[1], 1) 65 | y_t = y.permute(0, 2, 1).contiguous() 66 | y_norm = (y**2).sum(2).view(y.shape[0], 1, y.shape[1]) 67 | dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t) 68 | dist[dist != dist] = 0 # replace nan values with 0 69 | return torch.clamp(dist, 0.0, np.inf) 70 | 71 | 72 | def K(X1, X2, ell=1.0, sf=1.0, eps=1e-5): 73 | dnorm2 = sq_dist(X1, X2, ell) if X1.ndim == 2 else sq_dist3(X1, X2, ell) 74 | K_ = sf**2 * torch.exp(-0.5 * dnorm2) 75 | if X1.shape[-2] == X2.shape[-2]: 76 | return K_ + torch.eye(X1.shape[-2], device=X1.device) * eps 77 | return K_ 78 | 79 | 80 | def torch_to_numpy(a): 81 | if isinstance(a, torch.Tensor): 82 | return a.cpu().detach().numpy() 83 | else: 84 | return a 85 | 86 | 87 | def numpy_to_torch(a, device="cpu", dtype=torch.float32): 88 | if isinstance(a, np.ndarray) or isinstance(a, list): 89 | return torch.tensor(a, dtype=dtype).to(device) 90 | else: 91 | return a 92 | 93 | 94 | def log_sum_exp(value, dim=None, keepdim=False): 95 | """Numerically stable implementation of the operation 96 | 97 | value.exp().sum(dim, keepdim).log() 98 | """ 99 | if dim is not None: 100 | m, _ = torch.max(value, dim=dim, keepdim=True) 101 | value0 = value - m 102 | if keepdim is False: 103 | m = m.squeeze(dim) 104 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 105 | else: 106 | m = torch.max(value) 107 | sum_exp = torch.sum(torch.exp(value - m)) 108 | if isinstance(sum_exp, torch.Tensor): 109 | return m + torch.log(sum_exp) 110 | else: 111 | return m + math.log(sum_exp) 112 | 113 | 114 | def flatten_(sequence): 115 | flat = [p.contiguous().view(-1) for p in sequence] 116 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 117 | 118 | 119 | def smooth(x, w=7): 120 | x = np.array(x) 121 | y = np.zeros_like(x) 122 | for i in range(len(y)): 123 | y[i] = x[max(0, i - w) : min(i + w, len(y))].mean() 124 | return y 125 | 126 | 127 | def odesolve(f, z0, ts, step_size, method, rtol, atol): 128 | options = {} 129 | method = "midpoint" if method == "RK2" else method 130 | options.update({"method": method}) 131 | options.update({"step_size": step_size}) 132 | options.update({"t0": ts[0].item()}) 133 | options.update({"t1": ts[-1].item()}) 134 | options.update({"rtol": rtol}) 135 | options.update({"atol": atol}) 136 | options.update({"t_eval": ts.tolist()}) 137 | return torchdiffeqpack_odesolve(f, z0, options) 138 | 139 | 140 | def Klinear(X1, X2, ell=1.0, sf=1.0, eps=1e-5): 141 | dnorm2 = sq_dist(X1, X2, ell) if X1.ndim == 2 else sq_dist3(X1, X2, ell) 142 | K_ = sf**2 * torch.exp(-0.5 * dnorm2) 143 | if X1.shape[-2] == X2.shape[-2]: 144 | return K_ + torch.eye(X1.shape[-2], device=X1.device) * eps 145 | return K_ 146 | 147 | 148 | class KernelInterpolation: 149 | def __init__(self, sf, ell, X, y, eps=1e-5, kernel="exp"): 150 | self.sf = sf 151 | self.ell = ell 152 | self.X = X 153 | self.y = y 154 | self.eps = eps 155 | self.K = K if kernel == "exp" else Klinear 156 | self.KXX_inv_y = torch.linalg.solve(self.K(X, X, ell, sf, eps), y) 157 | # self.KXX_inv_y = y.solve(self.K(X,X,ell,sf,eps))[0] 158 | 159 | def __call__(self, x): 160 | x = x if isinstance(x, torch.Tensor) else torch.tensor(x) 161 | kxX = self.K(x, self.X, self.ell,
self.sf, self.eps) 162 | out = kxX @ self.KXX_inv_y # 1,nout 163 | return out 164 | -------------------------------------------------------------------------------- /mppi_optim.yaml: -------------------------------------------------------------------------------- 1 | program: mppi_with_model.py 2 | method: bayes 3 | metric: 4 | goal: maximize 5 | name: total_reward 6 | parameters: 7 | mppi_roll_outs: 8 | # distribution: log_uniform 9 | # # min: 0.0 10 | # # max: 1.0 11 | # # q: 1 12 | values: [1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144] 13 | mppi_time_steps: 14 | # distribution: q_log_uniform 15 | # min: 1 16 | # max: 6 17 | # q: 1 18 | values: [1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144] 19 | mppi_lambda: 20 | # distribution: q_log_uniform 21 | # min: -3 22 | # max: 2 23 | # q: 0.0001 24 | values: [0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0] 25 | mppi_sigma: 26 | # distribution: q_log_uniform 27 | # min: -3 28 | # max: 2 29 | # q: 0.0001 30 | values: [0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.8, 1.0, 1.5, 2.0, 10.0, 100.0, 1000.0] 31 | 32 | early_terminate: 33 | type: hyperband 34 | s: 2 35 | eta: 3 36 | max_iter: 27 37 | -------------------------------------------------------------------------------- /planners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/planners/__init__.py -------------------------------------------------------------------------------- /process_results/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/process_results/__init__.py -------------------------------------------------------------------------------- /process_results/files/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/process_results/files/.gitkeep -------------------------------------------------------------------------------- /process_results/plot_util.py: -------------------------------------------------------------------------------- 1 | def get_normalized_policy_values_delay_zero(): 2 | random_policy = { 3 | "oderl-acrobot": -2948.636826752257, 4 | "oderl-cartpole": -14246.301963850627, 5 | "oderl-pendulum": -616.7659306662474, 6 | } 7 | best_policy = { 8 | "oderl-acrobot": -571.1055129432718, 9 | "oderl-cartpole": -139.68956484338668, 10 | "oderl-pendulum": -121.04611233502484, 11 | } 12 | return random_policy, best_policy 13 | 14 | 15 | def get_normalized_policy_values_delay_one(): 16 | random_policy = { 17 | "oderl-acrobot": -2910.5048468493706, 18 | "oderl-cartpole": -9713.192129825948, 19 | "oderl-pendulum": -575.9776055772861, 20 | } 21 | best_policy = { 22 | "oderl-acrobot": -558.764978724654, 23 | "oderl-cartpole": -146.26268198045534, 24 | "oderl-pendulum": -123.43791579297383, 25 | } 26 | return random_policy, best_policy 27 | -------------------------------------------------------------------------------- /process_results/process_logs.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.stats 7 | import seaborn as 
sn 8 | import torch 9 | from tqdm import tqdm 10 | 11 | pd.set_option("mode.chained_assignment", None) 12 | # SCALE = 7 13 | SCALE = 13 14 | # HEIGHT_SCALE = 1.5 15 | HEIGHT_SCALE = 0.5 16 | sn.set(rc={"figure.figsize": (SCALE, int(HEIGHT_SCALE * SCALE))}) 17 | # sn.set(font_scale=1.4) 18 | sn.set(font_scale=2.0) 19 | sn.set_style(style="white") 20 | # sn.color_palette("tab10") 21 | sn.color_palette("colorblind") 22 | # plt.style.use('tableau-colorblind10') 23 | 24 | 25 | # LEGEND_Y_CORD = -0.70 # * (HEIGHT_SCALE / 2.0) 26 | LEGEND_Y_CORD = -0.75 # * (HEIGHT_SCALE / 2.0) 27 | SUBPLOT_ADJUST = 1 / HEIGHT_SCALE # -(0.05 + LEGEND_Y_CORD) 28 | LEGEND_X_CORD = 0.45 29 | 30 | # plt.gcf().subplots_adjust(bottom=(1-1/HEIGHT_SCALE), left=0.15, top=0.99) 31 | plt.gcf().subplots_adjust(bottom=0.40, left=0.2, top=0.95) 32 | # LINE_WIDTH = 3 33 | 34 | PLOT_FROM_CACHE = False 35 | PLOT_SAFETY_MARGIN = 1.25 36 | 37 | N = 3 # Significant Figures for Results 38 | DP = 5 39 | 40 | np.random.seed(999) 41 | torch.random.manual_seed(999) 42 | 43 | 44 | def is_float(element) -> bool: 45 | try: 46 | float(element) 47 | return True 48 | except ValueError: 49 | return False 50 | 51 | 52 | def string_to_float_dict(d): 53 | return {k: float(v) if is_float(v) else v for k, v in d.items()} 54 | 55 | 56 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 57 | 58 | X_METRIC = "nevals" 59 | # Y_METRIC='nmse_test' 60 | Y_METRIC = "nmse_train" 61 | # Y_METRIC='r_best' 62 | 63 | 64 | def mean_confidence_interval(data, confidence=0.95): 65 | a = 1.0 * np.array(data) 66 | n = len(a) 67 | m, se = np.mean(a), scipy.stats.sem(a) 68 | h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1) 69 | return m, 2 * h # m-h, m+h 70 | 71 | 72 | def confidence_interval(prob, n): 73 | return 1.96 * np.sqrt((prob * (1 - prob)) / n) 74 | 75 | 76 | N = 5 # Significant figure 77 | 78 | LOG_PATH = "./process_results/files/main_table_results.txt" 79 | 80 | GENERATE_FIGS = False 81 | 82 | HEADINGS = [ 83 | "val_loss", 84 | "train_loss", 85 | "best_val_loss", 86 | "total_reward", 87 | "delay", 88 | "model_name", 89 | "seed", 90 | "planner", 91 | ] 92 | HEADINGS_NEW = [ 93 | "val_loss", 94 | "train_loss", 95 | "best_val_loss", 96 | "total_reward", 97 | "delay", 98 | "model_name", 99 | "seed", 100 | "planner", 101 | "model_env", 102 | ] 103 | ENVS = ["oderl-pendulum", "oderl-cartpole", "oderl-acrobot"] 104 | env_inx = 0 105 | 106 | name_map = { 107 | # pylint: disable=anomalous-backslash-in-string 108 | "delta_t_rnn+mpc": "$\Delta t-$RNN", # pyright: ignore # noqa: W605 109 | "latent_ode+mpc": "Latent-ODE", 110 | "nl+mpc": "NLC \\textbf{(Ours)}", 111 | "node+mpc": "NODE", 112 | "oracle+mpc": "Oracle", 113 | "random+mpc": "Random", 114 | } 115 | 116 | custom_method_order = { 117 | "delta_t_rnn+mpc": 2, 118 | "latent_ode+mpc": 3, 119 | "nl+mpc": 5, 120 | "node+mpc": 4, 121 | "oracle+mpc": 1, 122 | "random+mpc": 0, 123 | } 124 | 125 | 126 | def name_mapper(name): 127 | return name_map[name] 128 | 129 | 130 | NORMALIZE = True 131 | CHANGE_COLUMN_HEADINGS = True 132 | if __name__ == "__main__": 133 | with open(LOG_PATH) as f: # pylint: disable=unspecified-encoding 134 | lines = f.readlines() 135 | 136 | # datasets = {} 137 | pd_l = [] 138 | df_tmp = [] # Drop last entry if not completed 139 | delay = None 140 | training = False 141 | lines_to_skip = 5 142 | lines_seen = 0 143 | delay = 0 144 | 145 | for line in tqdm(lines): 146 | if "[Model Completed evaluation mppi] {" in line and not training: 147 | result_dict = line.split("[Model Completed 
evaluation mppi] ")[1].strip() 148 | result_dict = result_dict.replace("nan", "'nan'") 149 | result_dict = ast.literal_eval(result_dict) 150 | pd_l.append(result_dict) 151 | if "[Model Completed evaluation q] {" in line and not training: 152 | result_dict = line.split("[Model Completed evaluation q] ")[1].strip() 153 | result_dict = result_dict.replace("nan", "'nan'") 154 | result_dict = ast.literal_eval(result_dict) 155 | pd_l.append(result_dict) 156 | 157 | dfm = pd.DataFrame(pd_l) 158 | dfm[["total_reward", "delay", "seed"]] = dfm[["total_reward", "delay", "seed"]].apply( 159 | pd.to_numeric, errors="coerce" 160 | ) 161 | # dfm[['val_loss', 'train_loss', 'best_val_loss', 'total_reward', 'delay', 'seed']] = 162 | # dfm[['val_loss', 'train_loss', 'best_val_loss', 'total_reward', 'delay', 'seed']] 163 | # .apply(pd.to_numeric, errors='coerce') 164 | dfm["name"] = dfm["model_name"] + "+" + dfm["planner"] 165 | dfm.drop(columns=["model_name", "planner"], inplace=True) 166 | t = dfm.groupby(["delay", "env_name", "name", "seed"]).agg("mean")["total_reward"] 167 | 168 | delay_results = {} 169 | finals_t = [] 170 | for delay in [d for d in dfm["delay"].unique() if d >= 1]: 171 | b = t.unstack(level=0)[delay] 172 | if NORMALIZE: 173 | # if delay == 1: 174 | # print('') 175 | best_policy = b.unstack(level=-1).mean(1).unstack()["oracle+mpc"] 176 | # best_policy = b.unstack(level=-1).mean(1).unstack().max(1) 177 | random_policy = b.unstack(level=-1).mean(1).unstack()["random+mpc"] 178 | # random_policy = b.unstack(level=-1).mean(1).unstack().min(1) 179 | bi = b.unstack() 180 | delay_l = [] 181 | # for env_name in b.unstack(level=0).columns: 182 | for env_name in ["oderl-cartpole", "oderl-pendulum", "oderl-acrobot"]: 183 | if NORMALIZE: 184 | vals = (b.unstack(level=0)[env_name] - random_policy[env_name]) / ( # pyright: ignore 185 | best_policy[env_name] - random_policy[env_name] # pyright: ignore 186 | ) 187 | vm = vals.unstack().mean(1) * 100.0 188 | vstd = vals.unstack().std(1) * 100.0 189 | vstd[vm < 0] = 0 190 | vm[vm < 0] = 0 191 | else: 192 | vals = b.unstack(level=0)[env_name] 193 | vm = vals.unstack().mean(1) 194 | vstd = vals.unstack().std(1) 195 | # pylint: disable=anomalous-backslash-in-string 196 | res = ( 197 | vm.round(2).astype("string") + "$\pm$" + vstd.round(2).astype("string") # pyright: ignore # noqa: W605 198 | ) 199 | res.name = env_name 200 | delay_l.append(res) 201 | final = pd.concat(delay_l, axis=1).transpose() 202 | final.index = final.index + f"_d={delay}" 203 | # if delay != 0: 204 | finals_t.append(final.transpose()) 205 | # print(f'DELAY: {delay}') 206 | # str_p = final.to_latex(escape=False).replace('\\textbackslash', '\\') 207 | str_p = final.to_latex(escape=False) 208 | str_p = str_p.replace("", "NA") 209 | # print(str_p) 210 | # print('') 211 | delay_results[delay] = final 212 | final_df = pd.concat(finals_t, axis=1) 213 | final_df = final_df[["+mpc" in s for s in final_df.index]] 214 | final_df = final_df.drop("rnn+mpc", errors="ignore") 215 | final_df = final_df.sort_values(by=["name"], key=lambda x: x.map(custom_method_order)) 216 | final_df.index = final_df.index.map(name_mapper) 217 | str_p = final_df.to_latex(escape=False) 218 | str_p = str_p.replace("", "NA") 219 | if CHANGE_COLUMN_HEADINGS: 220 | lines = str_p.split("\n") 221 | lines[0] = r"\begin{tabular}{c|ccc|ccc|ccc}" 222 | lines[2] = ( 223 | r" & \multicolumn{3}{c}{Action Delay~$\tau=\bar{\Delta}$ s} " 224 | r"& \multicolumn{3}{c}{Action Delay~$\tau=2\bar{\Delta}$ s} " 225 | r"& 
\multicolumn{3}{c}{Action Delay~$\tau=3\bar{\Delta}$ s} \\" 226 | ) 227 | lines[3] = ( 228 | r" Dynamics Model & Cartpole & Pendulum & Acrobot " 229 | r"& Cartpole & Pendulum & Acrobot & Cartpole & Pendulum & Acrobot \\ " 230 | ) 231 | lines.insert(-4, r"\midrule") 232 | str_p = "\n".join(lines) 233 | print(str_p) 234 | print("") 235 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=46.1.0"] 4 | 5 | [tool.bandit] 6 | exclude_dirs = ["tests"] 7 | 8 | [tool.black] 9 | include = '\.pyi?$' 10 | line-length = 120 11 | target-version = ['py38', 'py39', 'py310'] 12 | 13 | [tool.isort] 14 | known_first_party = """baseline_models,envs,planners,process_results,ctrl,config, 15 | mppi_dataset_collector,mppi_with_model,oracle,overlay,train_utils,w_latent_ode,w_nl 16 | """ 17 | profile = "black" 18 | src_paths = ["src"] 19 | 20 | [tool.pylint] 21 | disable = "R,C,fixme,unused-argument,protected-access,attribute-defined-outside-init,import-error" 22 | generated-members = "torch.*" 23 | 24 | # ignored-modules = "scipy.special" 25 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # tox-conda 2 | bandit[toml] 3 | black[jupyter] 4 | flake8 5 | isort 6 | pre-commit 7 | pylint 8 | tox 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | control 2 | gym~=0.21.0 3 | imageio 4 | imageio-ffmpeg 5 | matplotlib 6 | pandas 7 | pandas 8 | pyglet~=1.5.27 9 | pyvirtualdisplay 10 | scikit-learn 11 | scipy 12 | seaborn 13 | sklearn 14 | torch 15 | torchdiffeq>=0.2.3 16 | TorchDiffEqPack 17 | torchlaplace 18 | torchvision 19 | tqdm 20 | wandb 21 | -------------------------------------------------------------------------------- /run_exp_multi.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import traceback 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | import torch 9 | import wandb 10 | from torch import multiprocessing 11 | from tqdm import tqdm 12 | 13 | from config import dotdict, get_config, seed_all 14 | from mppi_with_model import mppi_with_model_evaluate_single_step 15 | from train_utils import train_model 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | 19 | MODELS = ["nl", "oracle", "random", "delta_t_rnn", "node", "latent_ode"] 20 | ENVIRONMENTS = ["oderl-cartpole", "oderl-acrobot", "oderl-pendulum"] 21 | DELAYS = list(range(4)) 22 | RETRAIN = False 23 | FORCE_RETRAIN = False 24 | START_FROM_CHECKPOINT = True 25 | MODEL_TRAIN_SEED = 0 26 | PRINT_SETTINGS = False 27 | 28 | trainable_models = [model_name for model_name in MODELS if not ("random" in model_name or "oracle" in model_name)] 29 | 30 | 31 | def train_model_wrapper(args, **kwargs): 32 | try: 33 | (env_name, delay, model_name) = args 34 | 35 | config = kwargs["config"] # pylint: disable=redefined-outer-name 36 | config = dotdict(config) 37 | kwargs["config"] = config 38 | # pylint: disable-next=logging-fstring-interpolation 39 | logger = create_logger_in_process(config.log_path) # pylint: disable=redefined-outer-name 
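# (Editor's note) The logger is re-created inside each worker on purpose: with the
# "spawn" start method set in __main__, child processes start a fresh interpreter and
# do not inherit the parent's logging handlers, and the hasHandlers() guard in
# create_logger_in_process keeps handlers from being attached twice in long-lived workers.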
40 | # pylint: disable-next=logging-fstring-interpolation 41 | logger.info(f"[Now training model] {model_name} \t {env_name} \t {delay}") 42 | seed_all(config.seed_start) 43 | kwargs["delay"] = delay 44 | model, results = train_model(model_name, env_name, **kwargs) # pylint: disable=unused-variable 45 | results["errored"] = False 46 | except Exception as e: # pylint: disable=broad-exception-caught 47 | # pylint: disable-next=logging-fstring-interpolation 48 | logger.exception(f"[Error] {e}") # pyright: ignore 49 | # pylint: disable-next=logging-fstring-interpolation 50 | logger.info( # pyright: ignore 51 | f"[Failed training model] {env_name} {model_name} delay={delay} \t " # pyright: ignore 52 | f"model_seed={MODEL_TRAIN_SEED} \t | error={e}" 53 | ) 54 | traceback.print_exc() 55 | results = {"errored": True} 56 | print("") 57 | results.update({"delay": delay, "model_name": model_name, "env_name": env_name}) # pyright: ignore 58 | # pylint: disable-next=logging-fstring-interpolation 59 | logger.info(f"[Training Result] {model_name} result={results}") # pyright: ignore 60 | return results 61 | 62 | 63 | def mppi_with_model_evaluate_single_step_wrapper(args, **kwargs): 64 | try: 65 | (env_name, delay, model_name, seed) = args 66 | 67 | seed_all(seed) 68 | config = kwargs["config"] # pylint: disable=redefined-outer-name 69 | config = dotdict(config) 70 | kwargs["config"] = config 71 | logger = create_logger_in_process(config.log_path) # pylint: disable=redefined-outer-name 72 | # pylint: disable-next=logging-fstring-interpolation 73 | logger.info(f"[Now evaluating mppi model] {model_name} \t {env_name} \t {delay}") 74 | results = mppi_with_model_evaluate_single_step( 75 | model_name=model_name, 76 | action_delay=delay, 77 | env_name=env_name, 78 | seed=seed, 79 | **kwargs, 80 | ) 81 | results["errored"] = False 82 | except Exception as e: # pylint: disable=broad-exception-caught 83 | # pylint: disable-next=logging-fstring-interpolation 84 | logger.exception(f"[Error] {e}") # pyright: ignore 85 | # pylint: disable-next=logging-fstring-interpolation 86 | logger.info( # pyright: ignore 87 | f"[Failed evaluating mppi model] {env_name} {model_name} delay={delay} \t " # pyright: ignore 88 | f"model_seed={MODEL_TRAIN_SEED} \t | error={e}" 89 | ) 90 | traceback.print_exc() 91 | results = {"errored": True} 92 | print("") 93 | results.update({"delay": delay, "model_name": model_name, "env_name": env_name, "seed": seed}) # pyright: ignore 94 | # pylint: disable-next=logging-fstring-interpolation 95 | logger.info(f"[Evaluate Result] result={results}") # pyright: ignore 96 | return results 97 | 98 | 99 | def main(config, wandb=None): # pylint: disable=redefined-outer-name 100 | model_training_results_l = [] 101 | model_eval_results_l = [] 102 | 103 | pool_outer = multiprocessing.Pool(config.collect_expert_cores_per_env_sampler) 104 | if config.retrain: 105 | train_all_model_inputs = [ 106 | (env_name, delay, model_name) 107 | for env_name in ENVIRONMENTS 108 | for delay in DELAYS 109 | for model_name in trainable_models 110 | ] 111 | # pylint: disable-next=logging-fstring-interpolation 112 | logger.info(f"Going to train for {len(train_all_model_inputs)} tasks") 113 | with multiprocessing.Pool(1) as pool_outer: # 12 114 | multi_wrapper_train_model = partial( 115 | train_model_wrapper, 116 | config=dict(config), 117 | wandb=None, 118 | model_seed=config.model_seed, 119 | retrain=config.retrain, 120 | start_from_checkpoint=config.start_from_checkpoint, 121 | force_retrain=config.force_retrain, 122 | 
122 |                 print_settings=config.print_settings,
123 |                 evaluate_model_when_trained=False,
124 |             )
125 |             for i, result in tqdm(  # pylint: disable=unused-variable
126 |                 enumerate(pool_outer.imap_unordered(multi_wrapper_train_model, train_all_model_inputs)),
127 |                 total=len(train_all_model_inputs),
128 |                 smoothing=0,
129 |             ):
130 |                 # pylint: disable-next=logging-fstring-interpolation
131 |                 logger.info(f"[Model Completed training] {result}")
132 |                 model_training_results_l.append(result)
133 | 
134 |     # Evaluate every (env, delay, model, seed) combination, in parallel when configured
135 |     mppi_evaluate_all_model_inputs = [
136 |         (env_name, delay, model_name, seed)
137 |         for env_name in ENVIRONMENTS
138 |         for delay in DELAYS
139 |         for model_name in MODELS
140 |         for seed in range(config.seed_start, config.seed_runs + config.seed_start)
141 |     ]
142 |     # pylint: disable-next=logging-fstring-interpolation
143 |     logger.info(f"Evaluating mppi across {len(mppi_evaluate_all_model_inputs)} (env, delay, model, seed) tasks")
144 |     if config.multi_process_results:
145 |         pool_outer = multiprocessing.Pool(12)  # number of evaluation worker processes
146 |     multi_wrapper_mppi_evaluate = partial(
147 |         mppi_with_model_evaluate_single_step_wrapper,
148 |         config=dict(config),
149 |         roll_outs=config.mppi_roll_outs,
150 |         time_steps=config.mppi_time_steps,
151 |         lambda_=config.mppi_lambda,
152 |         sigma=config.mppi_sigma,
153 |         dt=config.dt,
154 |         encode_obs_time=config.encode_obs_time,
155 |         save_video=config.save_video,
156 |     )
157 |     if config.multi_process_results:
158 |         for i, result in tqdm(
159 |             enumerate(pool_outer.imap_unordered(multi_wrapper_mppi_evaluate, mppi_evaluate_all_model_inputs)),
160 |             total=len(mppi_evaluate_all_model_inputs),
161 |             smoothing=0,
162 |         ):
163 |             # pylint: disable-next=logging-fstring-interpolation
164 |             logger.info(f"[Model Completed evaluation mppi] {result}")
165 |             model_eval_results_l.append(result)
166 |     else:
167 |         for i, task_input in tqdm(
168 |             enumerate(mppi_evaluate_all_model_inputs),
169 |             total=len(mppi_evaluate_all_model_inputs),
170 |             smoothing=0,
171 |         ):
172 |             result = multi_wrapper_mppi_evaluate(task_input)
173 |             # pylint: disable-next=logging-fstring-interpolation
174 |             logger.info(f"[Model Completed evaluation mppi] {result}")
175 |             model_eval_results_l.append(result)
176 |     if config.multi_process_results:
177 |         pool_outer.close()
178 | 
179 | 
180 | def generate_log_file_path(file, log_folder="logs"):
181 |     file_name = os.path.basename(os.path.realpath(file)).split(".py")[0]
182 |     Path(f"./{log_folder}").mkdir(parents=True, exist_ok=True)
183 |     path_run_name = "{}-{}".format(file_name, time.strftime("%Y%m%d-%H%M%S"))
184 |     return f"{log_folder}/{path_run_name}_log.txt"
185 | 
186 | 
187 | def create_logger_in_process(log_file_path):
188 |     logger = multiprocessing.get_logger()  # pylint: disable=redefined-outer-name
189 |     if not logger.hasHandlers():
190 |         formatter = logging.Formatter("%(processName)s| %(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s")
191 |         stream_handler = logging.StreamHandler()
192 |         file_handler = logging.FileHandler(log_file_path)
193 |         stream_handler.setFormatter(formatter)
194 |         file_handler.setFormatter(formatter)
195 |         logger.addHandler(stream_handler)
196 |         logger.addHandler(file_handler)
197 |         logger.setLevel(logging.INFO)
198 |     return logger
199 | 
200 | 
201 | if __name__ == "__main__":
202 |     log_path = generate_log_file_path(__file__)
203 |     logger = create_logger_in_process(log_path)
204 |     defaults = get_config()
205 |     defaults["log_path"] = log_path
206 |     if defaults["multi_process_results"]:
207 |         torch.multiprocessing.set_start_method("spawn")
208 |     defaults["retrain"] = RETRAIN
209 |     defaults["force_retrain"] = FORCE_RETRAIN
210 |     defaults["start_from_checkpoint"] = START_FROM_CHECKPOINT
211 |     defaults["print_settings"] = PRINT_SETTINGS
212 |     defaults["model_train_seed"] = MODEL_TRAIN_SEED
213 |     defaults["sweep_mode"] = True  # Real run settings
214 |     defaults["end_training_after_seconds"] = int(1350 * 6.0)
215 | 
216 |     wandb.init(config=defaults, project=defaults["wandb_project"])  # pyright: ignore
217 |     config = wandb.config
218 |     seed_all(0)
219 |     # pylint: disable-next=logging-fstring-interpolation
220 |     logger.info(f"Starting run \t | See log at : {log_path}")
221 |     main(config, wandb)
222 |     wandb.finish()
223 |     logger.info("Run over. Fin.")
224 |     # pylint: disable-next=logging-fstring-interpolation
225 |     logger.info(f"[Log found at] {log_path}")
226 | 
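Note: run_exp_multi.py above fans its work out with the standard functools.partial + Pool.imap_unordered idiom: fixed keyword arguments are bound to the worker function once, per-task tuples are streamed to the pool, and results are consumed as workers finish rather than in submission order. The sketch below isolates that idiom; toy_worker, TASKS and the scale argument are hypothetical stand-ins, not repository code.

# Minimal sketch of the partial + imap_unordered idiom used above.
# toy_worker, TASKS and scale are hypothetical stand-ins, not repository code.
from functools import partial
from multiprocessing import Pool


def toy_worker(args, scale=1):
    env_name, delay = args  # one task tuple, mirroring (env_name, delay, ...)
    return f"{env_name}: delay={delay * scale}"


TASKS = [(env, delay) for env in ("cartpole", "pendulum") for delay in range(2)]

if __name__ == "__main__":
    worker = partial(toy_worker, scale=10)  # bind the fixed kwargs once
    with Pool(2) as pool:
        # imap_unordered yields results as workers finish, not in input order
        for result in pool.imap_unordered(worker, TASKS):
            print(result)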
--------------------------------------------------------------------------------
/saved_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/samholt/NeuralLaplaceControl/e827944748538b7356ff6bda04aaac34a0da41d2/saved_models/.gitkeep
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max_line_length = 120
3 | select = C,E,F,W,B,B950
4 | extend_ignore = E203,E501,W503,E266
5 | # ^ Black-compatible
6 | # E203 and W503 have edge cases handled by black
7 | # Additionally updated from:
8 | # https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-length
9 | exclude =
10 |     build
11 |     dist
12 |     .eggs
13 | 
--------------------------------------------------------------------------------
/setup/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Create a virtual environment with Python 3.9 (here using conda):
4 | conda create --name nlc python=3.9.7
5 | conda activate nlc
6 | 
7 | # Install dependencies:
8 | conda install pytorch torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
9 | pip install -r requirements.txt
10 | 
11 | # If you have any issues with ffmpeg, try:
12 | # conda update ffmpeg
13 | # pip install imageio-ffmpeg
14 | 
15 | # Optional. For library development, install development dependencies:
16 | pip install -r requirements-dev.txt
17 | 
--------------------------------------------------------------------------------
/w_latent_ode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import matplotlib
4 | import matplotlib.pyplot
5 | import torch
6 | from torch import nn
7 | from torchlaplace import laplace_reconstruct  # noqa: F401  # unused here; see train_loss note below
8 | 
9 | from baseline_models.latent_ode_lib.create_latent_ode_model import (
10 |     create_LatentODE_model_direct,
11 | )
12 | from baseline_models.latent_ode_lib.plotting import Normal
13 | from baseline_models.latent_ode_lib.utils import compute_loss_all_batches_direct
14 | 
15 | matplotlib.use("Agg")
16 | 
17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18 | logger = logging.getLogger()
19 | 
20 | 
21 | class GeneralLatentODEOfficial(nn.Module):
22 |     def __init__(
23 |         self,
24 |         state_dim,
25 |         action_dim,
26 |         latent_dim,
27 |         hidden_units=64,
28 |         state_mean=None,
29 |         state_std=None,
30 |         action_mean=None,
31 |         action_std=None,
32 |         normalize=False,
33 |         normalize_time=False,
34 |         dt=0.05,
35 |         classif_per_tp=False,
36 |         n_labels=1,
37 |         obsrv_std=0.01,
38 |     ):
39 |         super(GeneralLatentODEOfficial, self).__init__()
40 |         input_dim = state_dim + action_dim
41 |         action_encoder_latent_dim = 2
42 |         latents = state_dim + action_encoder_latent_dim
43 |         # latents = 2
44 |         self.latents = latents
45 |         self.output_dim = state_dim
46 |         self.normalize = normalize
47 |         self.normalize_time = normalize_time
48 |         self.register_buffer("state_mean", torch.tensor(state_mean))
49 |         self.register_buffer("state_std", torch.tensor(state_std))
50 |         self.register_buffer("action_mean", torch.tensor(action_mean))
51 |         self.register_buffer("action_std", torch.tensor(action_std))
52 |         self.register_buffer("dt", torch.tensor(dt))
53 | 
54 |         obsrv_std = torch.Tensor([obsrv_std]).to(device)
55 |         z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.0]).to(device))
56 |         self.model = create_LatentODE_model_direct(
57 |             input_dim,
58 |             z0_prior,
59 |             obsrv_std,
60 |             device,
61 |             classif_per_tp=classif_per_tp,
62 |             n_labels=n_labels,
63 |             latents=latents,
64 |             units=hidden_units,
65 |             gru_units=hidden_units,
66 |         ).to(device)
67 |         # Rolling buffer of the most recent 4 observations; used by forward() for single-step planning
68 |         self.batch_obs_buffer = torch.zeros(1, 4, state_dim).to(device)
69 | 
70 |     def _get_loss(self, dl):
71 |         loss = compute_loss_all_batches_direct(self.model, dl, device=device, classif=0)
72 |         return loss["loss"], loss["mse"]
73 | 
74 |     def train_loss(self, in_batch_obs, in_batch_action, ts_pred):
75 |         # NOTE: the original body of this method was copied from the Neural
76 |         # Laplace model (see NeuralLaplaceModel.forward in w_nl.py): it
77 |         # encoded the action sequence with self.action_encoder, concatenated
78 |         # the resulting code with the observation, and decoded the next
79 |         # state through laplace_reconstruct(self.laplace_rep_func, ...),
80 |         # exactly as w_nl.py does.
81 |         #
82 |         # None of those attributes (action_encoder, laplace_rep_func,
83 |         # ilt_algorithm) is ever defined on GeneralLatentODEOfficial, so any
84 |         # call could only raise an AttributeError at runtime. The dead body
85 |         # is therefore replaced by an explicit NotImplementedError.
86 |         #
87 |         # Training for this baseline goes through train_step below, which
88 |         # assembles an extrapolation batch (observed data, observed time
89 |         # points, prediction targets) and scores it with
90 |         # self.model.compute_all_losses.
91 |         del in_batch_obs, in_batch_action, ts_pred  # intentionally unused
92 |         raise NotImplementedError(
93 |             "train_loss is not supported for GeneralLatentODEOfficial; "
94 |             "use train_step instead"
95 |         )
96 | 
97 |     def train_step(self, in_batch_obs, in_batch_action, ts_pred, observed_tp, pred_batch_obs_diff):
98 |         if self.normalize:
99 |             batch_obs = (in_batch_obs - self.state_mean) / self.state_std
100 |             batch_action = (in_batch_action - self.action_mean) / self.action_std
101 |         else:
102 |             batch_obs = in_batch_obs
103 |             batch_action = in_batch_action / 3.0
104 |         # if self.normalize_time:
105 |         #     ts_pred = (ts_pred / (self.dt*8.0))
106 |         batch_size = batch_obs.shape[0]
107 | 
108 |         if len(batch_action.shape) == 2:
109 |             batch_action = batch_action.unsqueeze(1)
110 | 
111 |         observed_data = torch.cat((batch_obs, in_batch_action), dim=2)  # NOTE: concatenates the raw in_batch_action, not the scaled batch_action above
112 |         data_to_predict = torch.cat(
113 |             (
114 |                 pred_batch_obs_diff.view(batch_size, 1, -1),
115 |                 torch.zeros((batch_size, 1, batch_action.shape[2]), device=device, dtype=torch.double),
116 |             ),
117 |             dim=2,
118 |         )
119 | 
120 |         batch = {
121 |             "observed_data": observed_data,
122 |             "observed_tp": observed_tp,
123 |             "data_to_predict": data_to_predict,
124 |             "tp_to_predict": ts_pred,
125 |             "observed_mask": torch.ones_like(observed_data),
126 |             "mask_predicted_data": torch.ones_like(data_to_predict),
127 |             "labels": None,
128 |             "mode": "extrap",
129 |         }
130 |         loss = self.model.compute_all_losses(batch)
131 |         return loss["loss"]
132 | 
133 |     def training_step_(self, batch):
134 |         loss = self.model.compute_all_losses(batch)
135 |         return loss["loss"]
136 | 
137 |     def validation_step(self, dlval):
138 |         loss, mse = self._get_loss(dlval)
139 |         return loss, mse
140 | 
141 |     def test_step(self, dltest):
142 |         loss, mse = self._get_loss(dltest)
143 |         return loss, mse
144 | 
145 |     def forward(self, in_batch_obs, in_batch_action, ts_pred):
146 |         if self.normalize:
147 |             batch_obs = (in_batch_obs - self.state_mean) / self.state_std
148 |             batch_action = (in_batch_action - self.action_mean) / self.action_std
149 |         else:
150 |             batch_obs = in_batch_obs
151 |             batch_action = in_batch_action / 3.0
152 |         # if self.normalize_time:
153 |         #     ts_pred = (ts_pred / (self.dt*8.0))
154 | 
155 |         if len(in_batch_obs.shape) == 3:
156 |             observed_data = torch.cat((batch_obs, batch_action), dim=2)
157 |         else:
158 |             if len(batch_action.shape) == 2:
159 |                 batch_action = batch_action.unsqueeze(1)
160 | 
161 |             if batch_obs.shape[0] == 1:
162 |                 self.batch_obs_buffer[0,] = torch.roll(self.batch_obs_buffer[0,], -1, dims=0)
163 |                 self.batch_obs_buffer[:, -1, :] = batch_obs
164 |                 observed_data = torch.cat((self.batch_obs_buffer, batch_action), dim=2)
165 |             else:
166 |                 if self.batch_obs_buffer.shape[0] != batch_obs.shape[0]:
167 |                     self.batch_obs_buffer = torch.zeros(batch_obs.shape[0], 4, batch_obs.shape[1]).to(device)
168 |                 self.batch_obs_buffer = torch.roll(self.batch_obs_buffer, -1, dims=1)
169 |                 self.batch_obs_buffer[:, -1, :] = batch_obs
170 |                 observed_data = torch.cat((self.batch_obs_buffer, batch_action), dim=2)
171 |         # observed_data = torch.cat((batch_obs.view(batch_size, 1, -1)\
172 |         #     .repeat(1, batch_action.shape[1], 1), batch_action),dim=2)
173 |         observed_ts = (
174 |             torch.arange(-(in_batch_action.shape[1] - 1), 1, 1, device=device, dtype=torch.double) * self.dt
175 |         ).view(1, -1)
176 | 
177 |         if ts_pred.shape[0] > 1:
178 |             if ts_pred.unique().size()[0] == 1:
179 |                 ts_pred = ts_pred[0].view(1, 1)
180 |             else:
181 |                 raise ValueError("all entries of ts_pred must be identical for batched prediction")
182 | 
183 |         batch = {
184 |             "observed_data": observed_data,
185 |             "observed_tp": observed_ts,
186 |             "data_to_predict": None,
187 |             "tp_to_predict": ts_pred,
188 |             "observed_mask": torch.ones_like(observed_data),
189 |             "mask_predicted_data": None,
190 |             "labels": None,
191 |             "mode": "extrap",
192 |         }
193 |         predict = self.predict_(batch)
194 |         return predict[:, :, : -in_batch_action.shape[2]].squeeze()
195 | 
196 |     def predict_(self, batch):
197 |         pred_y, _ = self.model.get_reconstruction(
198 |             batch["tp_to_predict"],
batch["observed_data"], 200 | batch["observed_tp"], 201 | mask=batch["observed_mask"], 202 | n_traj_samples=1, 203 | mode=batch["mode"], 204 | ) 205 | return pred_y.squeeze(0) 206 | 207 | def encode(self, dl): 208 | encodings = [] 209 | for batch in dl: 210 | mask = batch["observed_mask"] 211 | truth_w_mask = batch["observed_data"] 212 | if mask is not None: 213 | truth_w_mask = torch.cat((batch["observed_data"], mask), -1) 214 | # pylint: disable-next=unused-variable 215 | mean, std = self.model.encoder_z0(truth_w_mask, torch.flatten(batch["observed_tp"]), run_backwards=True) 216 | encodings.append(mean.view(-1, self.latents)) 217 | return torch.cat(encodings, 0) 218 | 219 | def _get_and_reset_nfes(self): 220 | """Returns and resets the number of function evaluations for model.""" 221 | iteration_nfes = ( 222 | self.model.encoder_z0.z0_diffeq_solver.ode_func.nfe # pyright: ignore 223 | + self.model.diffeq_solver.ode_func.nfe 224 | ) 225 | self.model.encoder_z0.z0_diffeq_solver.ode_func.nfe = 0 # pyright: ignore 226 | self.model.diffeq_solver.ode_func.nfe = 0 227 | return iteration_nfes 228 | -------------------------------------------------------------------------------- /w_nl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torchlaplace import laplace_reconstruct 7 | 8 | from config import CME_reconstruction_terms 9 | 10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | logger = logging.getLogger() 12 | 13 | 14 | class ReverseGRUEncoder(nn.Module): 15 | # Encodes observed trajectory into latent vector 16 | def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True): 17 | super(ReverseGRUEncoder, self).__init__() 18 | self.encode_obs_time = encode_obs_time 19 | if self.encode_obs_time: 20 | dimension_in += 1 21 | self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True) 22 | self.linear_out = nn.Linear(hidden_units, latent_dim) 23 | nn.init.xavier_uniform_(self.linear_out.weight) 24 | 25 | def forward(self, observed_data): 26 | trajs_to_encode = observed_data # (batch_size, t_observed_dim, observed_dim) 27 | reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,)) 28 | out, _ = self.gru(reversed_trajs_to_encode) 29 | return self.linear_out(out[:, -1, :]) 30 | 31 | 32 | class LaplaceRepresentationFunc(nn.Module): 33 | # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : 34 | # b dim s reconstruction terms, k is latent encoding dimension, d is output dimension 35 | def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64): 36 | super(LaplaceRepresentationFunc, self).__init__() 37 | self.s_dim = s_dim 38 | self.output_dim = output_dim 39 | self.latent_dim = latent_dim 40 | self.linear_tanh_stack = nn.Sequential( 41 | nn.Linear(s_dim * 2 + latent_dim, hidden_units), 42 | nn.Tanh(), 43 | nn.Linear(hidden_units, hidden_units), 44 | nn.Tanh(), 45 | nn.Linear(hidden_units, (s_dim) * 2 * output_dim), 46 | ) 47 | 48 | for m in self.linear_tanh_stack.modules(): 49 | if isinstance(m, nn.Linear): 50 | nn.init.xavier_uniform_(m.weight) 51 | 52 | phi_max = torch.pi / 2.0 53 | self.phi_scale = phi_max - -torch.pi / 2.0 54 | 55 | def forward(self, i): 56 | out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view( 57 | -1, 2 * self.output_dim, self.s_dim 58 | ) 59 | theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi # From - pi to + pi 60 | phi = ( 61 | 
--------------------------------------------------------------------------------
/w_nl.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torchlaplace import laplace_reconstruct
7 | 
8 | from config import CME_reconstruction_terms
9 | 
10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 | logger = logging.getLogger()
12 | 
13 | 
14 | class ReverseGRUEncoder(nn.Module):
15 |     # Encodes an observed trajectory into a latent vector (the GRU reads the sequence in reverse time order)
16 |     def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
17 |         super(ReverseGRUEncoder, self).__init__()
18 |         self.encode_obs_time = encode_obs_time
19 |         if self.encode_obs_time:
20 |             dimension_in += 1
21 |         self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
22 |         self.linear_out = nn.Linear(hidden_units, latent_dim)
23 |         nn.init.xavier_uniform_(self.linear_out.weight)
24 | 
25 |     def forward(self, observed_data):
26 |         trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
27 |         reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
28 |         out, _ = self.gru(reversed_trajs_to_encode)
29 |         return self.linear_out(out[:, -1, :])
30 | 
31 | 
32 | class LaplaceRepresentationFunc(nn.Module):
33 |     # SphereSurfaceModel: maps C^{b+k} -> C^{b x d} in Riemann-sphere coordinates, where
34 |     # b is the number of s reconstruction terms, k the latent encoding dimension and d the output dimension
35 |     def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
36 |         super(LaplaceRepresentationFunc, self).__init__()
37 |         self.s_dim = s_dim
38 |         self.output_dim = output_dim
39 |         self.latent_dim = latent_dim
40 |         self.linear_tanh_stack = nn.Sequential(
41 |             nn.Linear(s_dim * 2 + latent_dim, hidden_units),
42 |             nn.Tanh(),
43 |             nn.Linear(hidden_units, hidden_units),
44 |             nn.Tanh(),
45 |             nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
46 |         )
47 | 
48 |         for m in self.linear_tanh_stack.modules():
49 |             if isinstance(m, nn.Linear):
50 |                 nn.init.xavier_uniform_(m.weight)
51 | 
52 |         phi_max = torch.pi / 2.0
53 |         self.phi_scale = phi_max - -torch.pi / 2.0
54 | 
55 |     def forward(self, i):
56 |         out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
57 |             -1, 2 * self.output_dim, self.s_dim
58 |         )
59 |         theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From -pi to +pi
60 |         phi = (
61 |             nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0 - torch.pi / 2.0 + self.phi_scale / 2.0
62 |         )  # From -pi / 2 to +pi / 2
63 |         return theta, phi
64 | 
65 | 
66 | class NeuralLaplaceModel(nn.Module):
67 |     def __init__(
68 |         self,
69 |         state_dim,
70 |         action_dim,
71 |         latent_dim,
72 |         hidden_units=64,
73 |         s_recon_terms=33,
74 |         ilt_algorithm="fourier",
75 |         encode_obs_time=False,
76 |         state_mean=None,
77 |         state_std=None,
78 |         action_mean=None,
79 |         action_std=None,
80 |         normalize=False,
81 |         normalize_time=False,
82 |         dt=0.05,
83 |     ):
84 |         super(NeuralLaplaceModel, self).__init__()
85 |         self.ilt_algorithm = ilt_algorithm
86 |         if ilt_algorithm == "cme":
87 |             terms = CME_reconstruction_terms()
88 |             s_recon_terms = terms[np.argmin(terms < s_recon_terms) - 2]  # snap to a supported CME term count
89 |         action_encoder_latent_dim = 2
90 |         laplace_latent_dim = state_dim + action_encoder_latent_dim
91 |         self.latent_dim = latent_dim
92 |         self.action_encoder = ReverseGRUEncoder(
93 |             action_dim,
94 |             action_encoder_latent_dim,
95 |             hidden_units // 2,
96 |             encode_obs_time=encode_obs_time,
97 |         )
98 |         self.laplace_rep_func = LaplaceRepresentationFunc(
99 |             s_recon_terms, state_dim, laplace_latent_dim, hidden_units=hidden_units
100 |         )
101 |         self.encode_obs_time = encode_obs_time
102 |         self.output_dim = state_dim
103 |         self.normalize = normalize
104 |         self.normalize_time = normalize_time
105 |         self.s_recon_terms = s_recon_terms
106 |         # if self.encode_obs_time:
107 |         #     state_mean = np.concatenate((state_mean, np.array([1])))
108 |         #     state_std = np.concatenate((state_std, np.array([1])))
109 |         #     action_mean = np.concatenate((action_mean, np.array([1])))
110 |         #     action_std = np.concatenate((action_std, np.array([1])))
111 |         self.register_buffer("state_mean", torch.tensor(state_mean))
112 |         self.register_buffer("state_std", torch.tensor(state_std))
113 |         self.register_buffer("action_mean", torch.tensor(action_mean))
114 |         self.register_buffer("action_std", torch.tensor(action_std))
115 |         self.register_buffer("dt", torch.tensor(dt))
116 | 
117 |     def forward(self, in_batch_obs, in_batch_action, ts_pred):
118 |         # in_batch_action = in_batch_action[:, :1, :]
119 |         if self.normalize:
120 |             batch_obs = (in_batch_obs - self.state_mean) / self.state_std
121 |             batch_action = (in_batch_action - self.action_mean) / self.action_std
122 |             if self.normalize_time:
123 |                 ts_pred = ts_pred / (self.dt * 8.0)  # pyright: ignore
124 |                 # ts_pred = (((ts_pred - self.dt) / (self.dt*4.0)) + 0.05)
125 |             # batch_action = in_batch_action.view(-1, in_batch_action.shape[2])
126 |             # batch_action = batch_action.view(*in_batch_action.shape)
127 |         else:
128 |             batch_obs = in_batch_obs
129 |             batch_action = in_batch_action / 3.0
130 |         # p_action = batch_action.view(batch_action.shape[0], batch_action.shape[-1])
131 |         if len(batch_action.shape) == 2:
132 |             batch_action = batch_action.unsqueeze(1)
133 |         p_action = self.action_encoder(batch_action)
134 |         sa_in = torch.cat((batch_obs, p_action), axis=1)  # pyright: ignore
135 |         p = sa_in
136 |         return torch.squeeze(
137 |             laplace_reconstruct(
138 |                 self.laplace_rep_func,
139 |                 p,
140 |                 ts_pred,
141 |                 recon_dim=self.output_dim,
142 |                 ilt_algorithm=self.ilt_algorithm,
143 |                 ilt_reconstruction_terms=self.s_recon_terms,
144 |             )
145 |         )
146 | 
147 | 
148 | def load_replay_buffer(fn):
149 |     offline_dataset = np.load(fn, allow_pickle=True).item()
150 |     return offline_dataset
151 | 
--------------------------------------------------------------------------------
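Note: to make the tensor contract of NeuralLaplaceModel concrete, here is a minimal usage sketch. It assumes it is run from the repository root (so that w_nl and config import), and all dimensions, statistics and the prediction time are illustrative toy values rather than settings from the repository's training pipeline.

# Minimal, illustrative driver for NeuralLaplaceModel (assumed toy sizes).
import torch

from w_nl import NeuralLaplaceModel

state_dim, action_dim, horizon, batch = 4, 1, 5, 32  # hypothetical dimensions

model = NeuralLaplaceModel(
    state_dim=state_dim,
    action_dim=action_dim,
    latent_dim=2,
    hidden_units=64,
    s_recon_terms=33,
    ilt_algorithm="fourier",
    # Placeholder statistics; illustrative values, see design note below.
    state_mean=[0.0] * state_dim,
    state_std=[1.0] * state_dim,
    action_mean=[0.0] * action_dim,
    action_std=[1.0] * action_dim,
    normalize=False,
    dt=0.05,
)

obs = torch.randn(batch, state_dim)                 # current observations
actions = torch.randn(batch, horizon, action_dim)   # planned action sequence
ts_pred = torch.full((batch, 1), 0.05)              # predict dt seconds ahead

# forward() encodes the action sequence with the reverse GRU, concatenates the
# code with the observation, and decodes the next state via laplace_reconstruct.
next_state = model(obs, actions, ts_pred)           # -> (batch, state_dim)
print(next_state.shape)

The mean/std statistics are passed even though normalize=False because the constructor registers them as buffers unconditionally.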