├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── environments ├── __init__.py ├── chaotic_pendulum.py ├── datasets.py ├── environment.py ├── environment_factory.py ├── gravity.py ├── pendulum.py └── spring.py ├── experiment_params ├── dataset_offline_default.yaml ├── dataset_online_default.yaml ├── default_environments │ ├── chaotic_pendulum.yaml │ ├── damped_spring.yaml │ ├── pendulum.yaml │ ├── spring.yaml │ ├── three_bodies.yaml │ └── two_bodies.yaml └── train_config_default.yaml ├── generate_data.py ├── hamiltonian_generative_network.py ├── networks ├── __init__.py ├── debug_networks.py ├── decoder_net.py ├── encoder_net.py ├── hamiltonian_net.py └── transformer_net.py ├── requirements.txt ├── sample_rollouts.py ├── tests ├── __init__.py ├── check_gradients.py ├── grid_search.py ├── test_conversions.py ├── test_decoder.py ├── test_inference_net.py ├── test_losses.py └── test_networks.py ├── train.py └── utilities ├── __init__.py ├── conversions.py ├── gradient_flow_utils.py ├── hgn_result.py ├── integrator.py ├── loader.py ├── losses.py ├── statistics.py └── training_logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | #datasets 2 | datasets 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Pycharm 135 | .idea/ 136 | 137 | # Visual studio 138 | .vscode/ 139 | 140 | saved_models/ 141 | datasets/ 142 | 143 | runs/ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | For collaborations reach us at [ai.campus.ai@gmail.com](mailto:ai.campus.ai@gmail.com) 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CampusAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hamiltonian-Generative-Networks 2 | [![DOI](https://zenodo.org/badge/295400716.svg)](https://zenodo.org/badge/latestdoi/295400716) 3 | 4 | Re-implementation of Hamiltonian Generative Networks [paper](https://arxiv.org/abs/1909.13789). 5 | You can find the re-implementation publication details [here](https://rescience.github.io/bibliography/Balsells_Rodas_2021.html), and article [here](https://zenodo.org/record/4835278#.YQUhI3UzaEC). 6 | 7 | 8 | ## Setup 9 | 10 | 1. Install (CPU or GPU) [PyTorch](https://pytorch.org/). Tested on 1.6.0 11 | 2. Install other project dependencies: 12 | `pip install -r requirements.txt` 13 | 14 | ## Modules 15 | 16 | - **[Environments](environments/)**: You can find the implementation of the physical environments used to get the ground truth data. 17 | 18 | - **[Networks](networks/)**: Contains the definitions of the main ANNs used: Encoder, Transformer, Hamiltonian, and Decoder. 19 | 20 | - **[Utilities](utilities/)**: Holds helper classes such as the HGN integrator and a HGN output bundler class. 21 | 22 | - **[Experiment Params](experiment_params/)**: Contains .yaml files with the meta-parameters used in different experiments. 23 | 24 | - **[hamiltonian_generative_network.py](hamiltonian_generative_network.py)** script contains the definition of the HGN architecture. 25 | 26 | ## How to train 27 | The [train.py](train.py) script takes care of performing the training. 28 | To start a training, run 29 | ```commandline 30 | python train.py --train-config --dataset-config 31 | ``` 32 | 33 | ``` 34 | optional arguments: 35 | -h, --help show this help message and exit 36 | --train-config TRAIN_CONFIG 37 | Path to the training configuration yaml file. 38 | --dataset-config DATASET_CONFIG 39 | Path to the dataset configuration yaml file. 40 | --name NAME If specified, this name will be used instead of 41 | experiment_id of the yaml file. 42 | --epochs EPOCHS The number of training epochs. If not specified, 43 | optimization.epochs of the training configuration will 44 | be used. 45 | --env ENV The environment to use (for online training only). 46 | Possible values are 'pendulum', 'spring', 47 | 'two_bodies', 'three_bodies', corresponding to 48 | environment configurations in 49 | experiment_params/default_environments/. If not 50 | specified, the environment specified in the given 51 | --dataset-config will be used. 52 | --dataset-path DATASET_PATH 53 | Path to a stored dataset to use for training. For 54 | offline training only. In this case no dataset 55 | configuration file will be loaded. 56 | --params PARAMS [PARAMS ...] 57 | Override one or more parameters in the config. The 58 | format of an argument is param_name=param_value. 59 | Nested parameters are accessible by using a dot, i.e. 60 | --param dataset.img_size=32. IMPORTANT: lists must be 61 | enclosed in double quotes, i.e. --param 62 | environment.mass:"[0.5, 0.5]". 63 | --resume [RESUME] NOT IMPLEMENTED YET. Resume the training from a saved 64 | model. If a path is provided, the training will be 65 | resumed from the given checkpoint. Otherwise, the last 66 | checkpoint will be taken from 67 | saved_models/. 68 | ``` 69 | The `experiment_params/` folder contains default dataset and training configuration files. 70 | Training can be done in on-line or off-line mode. 71 | 72 | - In **on-line mode** data is generated during training, eliminating the need for a 73 | heavy dataset. A dataset configuration file must be provided in the `--dataset-config` 74 | argument. This file must define the `environment:` and `dataset:` sections 75 | (see [experiment_params/dataset_online_default.yaml](experiment_params/dataset_online_default.yaml)) 76 | . The `--env` argument may be used to override the environment defined in the config file 77 | with one of the default environments in `experiment_params/default_environments/`. 78 | 79 | - In **off-line mode** the training is performed from a stored dataset (see the section below 80 | on how to generate datasets). A dataset config specifying the train and test dataset paths 81 | in the `train_data:` and `test_data:` sections can be given to `--dataset-config` (see 82 | [experiment_params/dataset_offline_default.yaml](experiment_params/dataset_offline_default.yaml)) 83 | . Otherwise, the path to an existing dataset root folder (the one containing the 84 | `parameters.yaml` file ) must be provided to the `--dataset-path` argument. 85 | ## Generating and saving datasets 86 | A dataset can be generated starting from a `yaml` parameter file that specifies all its parameters 87 | in the `environment` and `dataset` sections. To create a dataset, run 88 | ```commandline 89 | python generate_data.py 90 | ``` 91 | which will create the dataset in a folder with the given name (see args below) and will 92 | write a `parameters.yaml` file within it, that can be directly used for off-line training 93 | on the created dataset. 94 | 95 | ``` 96 | optional arguments: 97 | -h, --help show this help message and exit 98 | --name NAME The dataset name. 99 | --dataset-config DATASET_CONFIG 100 | YAML file from which to read the dataset parameters. 101 | If not specified, 102 | experiment_params/dataset_online_default.yaml will be 103 | used. 104 | --ntrain NTRAIN Number of training sample to generate. 105 | --ntest NTEST Number of test samples to generate. 106 | --env ENV The default environment specifications to use. Can be 107 | 'pendulum', 'spring', 'two_bodies', 'three_bodies', 108 | 'chaotic_pendulum'. If this argument is specified, a 109 | default environment section will be loaded from the 110 | correspondent yaml file in 111 | experiment_params/default_environments/ 112 | --datasets-root DATASETS_ROOT 113 | Root of the datasets folder in which the dataset will 114 | be stored. If not specified, datasets/ will be used as 115 | default. 116 | --params PARAMS [PARAMS ...] 117 | Override one or more parameters in the config. The 118 | format of an argument is param_name=param_value. 119 | Nested parameters are accessible by using a dot, i.e. 120 | --param dataset.img_size=32. IMPORTANT: lists must be 121 | enclosed in double quotes, i.e. --param 122 | environment.mass:"[0.5, 0.5]". 123 | ``` 124 | -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- 1 | # init file 2 | -------------------------------------------------------------------------------- /environments/chaotic_pendulum.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from environment import Environment, visualize_rollout 5 | 6 | 7 | class ChaoticPendulum(Environment): 8 | """Chaotic Pendulum System: 2 Objects 9 | 10 | Hamiltonian system is: 11 | 12 | H = (1/2*m*L^2)* (p_1^2 + 2*p_2^2 - 2*p_1*p_2* \ 13 | cos(q_1 - q_2)) / (1 + sin^2(q_1 - q_2)) 14 | + mgL*(3 - 2*cos q_1 - cos q_2) 15 | 16 | """ 17 | 18 | WORLD_SIZE = 2.5 19 | 20 | def __init__(self, mass, length, g, q=None, p=None): 21 | """Constructor for pendulum system 22 | 23 | Args: 24 | mass (float): Pendulum mass (kg) 25 | length (float): Pendulum length (m) 26 | g (float): Gravity of the environment (m/s^2) 27 | q ([float], optional): Generalized position in 1-D space: Phase (rad). Defaults to None 28 | p ([float], optional): Generalized momentum in 1-D space: Angular momentum (kg*m^2/s). Defaults to None 29 | """ 30 | self.mass = mass 31 | self.length = length 32 | self.g = g 33 | super().__init__(q=q, p=p) 34 | 35 | def set(self, q, p): 36 | """Sets initial conditions for pendulum 37 | 38 | Args: 39 | q ([float]): Generalized position in 1-D space: Phase (rad) 40 | p ([float]): Generalized momentum in 1-D space: Angular momentum (kg*m^2/s) 41 | 42 | Raises: 43 | ValueError: If p and q are not in 1-D space 44 | """ 45 | if q is None or p is None: 46 | return 47 | if len(q) != 2 or len(p) != 2: 48 | raise ValueError( 49 | "q and p must be 2 objects in 1-D space: Angular momentum and Phase." 50 | ) 51 | self.q = q 52 | self.p = p 53 | 54 | def get_world_size(self): 55 | """Return world size for correctly render the environment. 56 | """ 57 | return self.WORLD_SIZE 58 | 59 | def get_max_noise_std(self): 60 | """Return maximum noise std that keeps the environment stable.""" 61 | return 0.05 62 | 63 | def get_default_radius_bounds(self): 64 | """Returns: 65 | radius_bounds (tuple): (min, max) radius bounds for the environment. 66 | """ 67 | return (0.5, 1.3) 68 | 69 | def _dynamics(self, t, states): 70 | """Defines system dynamics 71 | 72 | Args: 73 | t (float): Time parameter of the dynamic equations. 74 | states ([float]) Phase states at time t 75 | 76 | Returns: 77 | equations ([float]): Movement equations of the physical system 78 | """ 79 | states_resh = states.reshape(2, 2) 80 | dyn = np.zeros_like(states_resh) 81 | 82 | # dq_1 and dq_2 83 | quot = self.mass*(self.length**2) * \ 84 | (1 + (np.sin(states_resh[0, 0] - states_resh[0, 1])**2)) 85 | dyn[0, 0] = states_resh[1, 0] - states_resh[1, 1] * \ 86 | np.cos(states_resh[0, 0] - states_resh[0, 1]) 87 | dyn[0, 1] = states_resh[1, 1] - states_resh[1, 0] * \ 88 | np.cos(states_resh[0, 0] - states_resh[0, 1]) 89 | dyn[0, :] /= quot 90 | # dp_1 and dp_2 91 | dyn[1, :] -= 2 * self.mass * self.g * self.length * np.sin( 92 | states_resh[0, :]) 93 | cst = 1 / (2 * self.mass * (self.length**2)) 94 | term1 = states_resh[1, 0]**2 + states_resh[1, 1]**2 + \ 95 | 2*states_resh[1, 0]*states_resh[1, 1] * \ 96 | np.cos(states_resh[0, 0] - states_resh[0, 1]) 97 | term2 = (1 + (np.sin(states_resh[0, 0] - states_resh[0, 1])**2)) 98 | 99 | dterm1_dq_1 = 2*states_resh[1, 0]*states_resh[1, 1] * \ 100 | np.sin(states_resh[0, 0] - states_resh[0, 1]) 101 | dterm1_dq_2 = -dterm1_dq_1 102 | 103 | dterm2_dq_1 = 2 * np.cos(states_resh[0, 0] - states_resh[0, 1]) 104 | dterm2_dq_2 = -dterm2_dq_1 105 | 106 | dyn[1, 0] -= cst * (dterm1_dq_1 * term2 - term1 * 107 | dterm2_dq_1) / (term2 ** 2) 108 | dyn[1, 1] -= cst * (dterm1_dq_2 * term2 - term1 * 109 | dterm2_dq_2) / (term2 ** 2) 110 | 111 | return dyn.reshape(-1) 112 | 113 | def _draw(self, res=32, color=True): 114 | """Returns array of the environment evolution 115 | 116 | Args: 117 | res (int): Image resolution (images are square). 118 | color (bool): True if RGB, false if grayscale. 119 | 120 | Returns: 121 | vid (np.ndarray): Rendered rollout as a sequence of images 122 | """ 123 | q = self._rollout.reshape(2, 2, -1)[0, :, :] 124 | length = q.shape[-1] 125 | vid = np.zeros((length, res, res, 3), dtype='float') 126 | ball_colors = self._default_ball_colors 127 | space_res = 2.*self.get_world_size()/res 128 | for t in range(length): 129 | coords_1 = self._world_to_pixels( 130 | self.length * np.sin(q[0, t]), self.length * np.cos(q[0, t]), res) 131 | coords_2 = self._world_to_pixels( 132 | self.length * np.sin(q[0, t]) + self.length * np.sin(q[1, t]), 133 | self.length * np.cos(q[0, t]) + self.length * np.cos(q[1, t]), 134 | res) 135 | vid[t] = cv2.circle(vid[t], coords_1, int( 136 | self.length/(space_res*3)), ball_colors[0], -1) 137 | vid[t] = cv2.circle(vid[t], coords_2, int( 138 | self.length/(space_res*3)), ball_colors[1], -1) 139 | vid[t] = cv2.blur(cv2.blur(vid[t], (2, 2)), (2, 2)) 140 | vid += self._default_background_color 141 | vid[vid > 1.] = 1. 142 | if not color: 143 | vid = np.expand_dims(np.max(vid, axis=-1), -1) 144 | return vid 145 | 146 | def _sample_init_conditions(self, radius): 147 | """Samples random initial conditions for the environment 148 | 149 | Args: 150 | radius (float): Radius of the sampling process 151 | """ 152 | states_q = np.random.rand(2) * 2. - 1 153 | states_q = (states_q / np.sqrt((states_q**2).sum())) * radius 154 | states_p = np.random.rand(2) * 2. - 1 155 | states_p = (states_p / np.sqrt((states_p**2).sum())) * radius 156 | self.set(states_q, states_p) 157 | 158 | 159 | # Sample code for sampling rollouts 160 | if __name__ == "__main__": 161 | 162 | pd = ChaoticPendulum(mass=1., length=1, g=3) 163 | rolls = pd.sample_random_rollouts(number_of_frames=300, 164 | delta_time=.125, 165 | number_of_rollouts=1, 166 | img_size=64, 167 | noise_level=0., 168 | radius_bound=(0.5, 1.3), 169 | seed=None) 170 | idx = np.random.randint(rolls.shape[0]) 171 | visualize_rollout(rolls[idx]) 172 | -------------------------------------------------------------------------------- /environments/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from utilities import conversions 11 | 12 | 13 | class EnvironmentSampler(Dataset): 14 | """Dataset for rollout sampling. 15 | 16 | Given an environment and sampling conditions, the dataset samples rollouts as pytorch tensors. 17 | """ 18 | 19 | def __init__(self, 20 | environment, 21 | dataset_len, 22 | number_of_frames, 23 | delta_time, 24 | number_of_rollouts, 25 | img_size, 26 | color, 27 | noise_level, 28 | radius_bound, 29 | seed, 30 | dtype=torch.float): 31 | """Instantiate the EnvironmentSampler. 32 | 33 | Args: 34 | environment (Environment): Instance belonging to Environment abstract base class. 35 | dataset_len (int): Length of the dataset. 36 | number_of_frames (int): Total duration of video (in frames). 37 | delta_time (float): Frame interval of generated data (in seconds). 38 | number_of_rollouts (int): Number of rollouts to generate. 39 | img_size (int): Size of the frames (in pixels). 40 | color (bool): Whether to have colored or grayscale frames. 41 | noise_level (float): Value in [0, 1] to tune the noise added to the environment 42 | trajectory. 43 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 44 | Init phase states will be sampled from a circle (q, p) of radius 45 | r ~ U(radius_bound[0], radius_bound[1]) https://arxiv.org/pdf/1909.13789.pdf (Sec 4) 46 | Optionally, it can be a string 'auto'. In that case, the value returned by 47 | environment.get_default_radius_bounds() will be returned. 48 | seed (int): Seed for reproducibility. 49 | dtype (torch.type): Type of the sampled tensors. 50 | """ 51 | self.environment = environment 52 | self.dataset_len = dataset_len 53 | self.number_of_frames = number_of_frames 54 | self.delta_time = delta_time 55 | self.number_of_rollouts = number_of_rollouts 56 | self.img_size = img_size 57 | self.color = color 58 | self.noise_level = noise_level 59 | self.radius_bound = radius_bound 60 | self.seed = seed 61 | self.dtype = dtype 62 | 63 | def __len__(self): 64 | """Get dataset length 65 | 66 | Returns: 67 | length (int): Length of the dataset. 68 | """ 69 | return self.dataset_len 70 | 71 | def __getitem__(self, i): 72 | """Iterator for rollout sampling. 73 | Samples a rollout and converts it to a Pytorch tensor. 74 | 75 | Args: 76 | i (int): Index of the dataset sample (ignored since we sample random data). 77 | Returns: 78 | (Torch.tensor): Tensor of shape (batch_len, seq_len, channels, height, width) with the 79 | sampled rollouts. 80 | """ 81 | rolls = self.environment.sample_random_rollouts( 82 | number_of_frames=self.number_of_frames, 83 | delta_time=self.delta_time, 84 | number_of_rollouts=self.number_of_rollouts, 85 | img_size=self.img_size, 86 | color=self.color, 87 | noise_level=self.noise_level, 88 | radius_bound=self.radius_bound, 89 | seed=self.seed) 90 | rolls = torch.from_numpy(rolls).type(self.dtype) 91 | return conversions.to_channels_first(rolls) 92 | 93 | 94 | class EnvironmentLoader(Dataset): 95 | def __init__(self, root_dir): 96 | self.root_dir = root_dir 97 | self.file_list = os.listdir(root_dir) 98 | 99 | def __len__(self): 100 | return len(self.file_list) 101 | 102 | def __getitem__(self, i): 103 | rolls = np.load(os.path.join( 104 | self.root_dir, self.file_list[i]))['arr_0'] 105 | return rolls.transpose((0, 3, 1, 2)) 106 | 107 | 108 | # Sample code for DataLoader call 109 | if __name__ == "__main__": 110 | import time 111 | from .pendulum import Pendulum 112 | 113 | pd = Pendulum(mass=.5, length=1, g=3) 114 | trainDS = EnvironmentSampler(environment=pd, 115 | dataset_len=100, 116 | number_of_frames=100, 117 | delta_time=.1, 118 | number_of_rollouts=4, 119 | img_size=64, 120 | noise_level=0., 121 | radius_bound=(1.3, 2.3), 122 | seed=23) 123 | # Dataloader instance test, batch_mode disabled 124 | train = torch.utils.data.DataLoader(trainDS, 125 | shuffle=False, 126 | batch_size=None) 127 | start = time.time() 128 | sample = next(iter(train)) 129 | end = time.time() 130 | 131 | print(sample.size(), "Sampled in " + str(end - start) + " s") 132 | 133 | # Dataloader instance test, batch_mode enabled 134 | train = torch.utils.data.DataLoader(trainDS, 135 | shuffle=False, 136 | batch_size=4, 137 | num_workers=1) 138 | start = time.time() 139 | sample = next(iter(train)) 140 | end = time.time() 141 | 142 | print(sample.size(), "Sampled in " + str(end - start) + " s") 143 | 144 | trainDS = EnvironmentLoader('../datasets/pendulum_data/train') 145 | 146 | train = torch.utils.data.DataLoader(trainDS, 147 | shuffle=False, 148 | batch_size=10, 149 | num_workers=4) 150 | start = time.time() 151 | sample = next(iter(train)) 152 | end = time.time() 153 | print(sample.size(), "Sampled in " + str(end - start) + " s") 154 | -------------------------------------------------------------------------------- /environments/environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | 4 | import cv2 5 | from matplotlib import pyplot as plt, animation 6 | import numpy as np 7 | from scipy.integrate import solve_ivp 8 | 9 | 10 | class Environment(ABC): 11 | def __init__(self, q=None, p=None): 12 | """Instantiate new environment with the provided position and momentum 13 | 14 | Args: 15 | q ([float], optional): generalized position in n-d space 16 | p ([float], optional): generalized momentum in n-d space 17 | """ 18 | self._default_background_color = [81./255, 88./255, 93./255] 19 | self._default_ball_colors = [ 20 | (173./255, 146./255, 0.), (173./255, 0., 0.), (0., 146./255, 0.)] 21 | self._rollout = None 22 | self.q = None 23 | self.p = None 24 | self.set(q=q, p=p) 25 | 26 | @abstractmethod 27 | def set(self, q, p): 28 | """Sets initial conditions for physical system 29 | 30 | Args: 31 | q ([float]): generalized position in n-d space 32 | p ([float]): generalized momentum in n-d space 33 | 34 | Raises: 35 | NotImplementedError: Class instantiation has no implementation 36 | """ 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def _dynamics(self, t, states): 41 | """Defines system dynamics 42 | 43 | Args: 44 | t (float): Time parameter of the dynamic equations. 45 | states ([float]) Phase states at time t 46 | 47 | Raises: 48 | NotImplementedError: Class instantiation has no implementation 49 | """ 50 | raise NotImplementedError 51 | 52 | @abstractmethod 53 | def _draw(self, img_size, color): 54 | """Returns array of the environment evolution. 55 | 56 | Args: 57 | img_size (int): Size of the frames (in pixels). 58 | color (bool): Whether to have colored or grayscale frames. 59 | 60 | Raises: 61 | NotImplementedError: Class instantiation has no implementation 62 | """ 63 | raise NotImplementedError 64 | 65 | @abstractmethod 66 | def get_world_size(self): 67 | """Returns the world size for the environment. 68 | """ 69 | raise NotImplementedError 70 | 71 | @abstractmethod 72 | def get_max_noise_std(self): 73 | """Returns the maximum noise standard deviation that maintains a stable environment. 74 | """ 75 | raise NotImplementedError 76 | 77 | @abstractmethod 78 | def get_default_radius_bounds(self): 79 | """Returns a tuple (min, max) with the default radius bounds for the environment. 80 | """ 81 | raise NotImplementedError 82 | 83 | @abstractmethod 84 | def _sample_init_conditions(self, radius_bound): 85 | """Samples random initial conditions for the environment 86 | 87 | Args: 88 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 89 | Optionally, it can be a string 'auto'. In that case, the value returned by 90 | get_default_radius_bounds() will be returned. 91 | 92 | Raises: 93 | NotImplementedError: Class instantiation has no implementation 94 | """ 95 | raise NotImplementedError 96 | 97 | def _world_to_pixels(self, x, y, res): 98 | """Maps coordinates from world space to pixel space 99 | 100 | Args: 101 | x (float): x coordinate of the world space. 102 | y (float): y coordinate of the world space. 103 | res (int): Image resolution in pixels (images are square). 104 | 105 | Returns: 106 | (int, int): Tuple of coordinates in pixel space. 107 | """ 108 | pix_x = int(res*(x + self.get_world_size())/(2*self.get_world_size())) 109 | pix_y = int(res*(y + self.get_world_size())/(2*self.get_world_size())) 110 | 111 | return (pix_x, pix_y) 112 | 113 | def _evolution(self, total_time=10, delta_time=0.1): 114 | """Performs rollout of the physical system given some initial conditions. 115 | Sets rollout phase states to self.rollout 116 | 117 | Args: 118 | total_time (float): Total duration of the rollout (in seconds) 119 | delta_time (float): Sample interval in the rollout (in seconds) 120 | 121 | Raises: 122 | AssertError: If p or q are None 123 | """ 124 | if isinstance(self.q, np.ndarray): 125 | assert self.q.all() != None 126 | assert self.p.all() != None 127 | else: 128 | assert self.q != None 129 | assert self.p != None 130 | 131 | t_eval = np.linspace(0, total_time, 132 | round(total_time / delta_time) + 1)[:-1] 133 | t_span = [0, total_time] 134 | y0 = np.array([np.array(self.q), np.array(self.p)]).reshape(-1) 135 | self._rollout = solve_ivp(self._dynamics, t_span, y0, t_eval=t_eval).y 136 | 137 | def sample_random_rollouts(self, 138 | number_of_frames=100, 139 | delta_time=0.1, 140 | number_of_rollouts=16, 141 | img_size=32, 142 | color=True, 143 | noise_level=0.1, 144 | radius_bound=(1.3, 2.3), 145 | seed=None): 146 | """Samples random rollouts for a given environment 147 | 148 | Args: 149 | number_of_frames (int): Total duration of video (in frames). 150 | delta_time (float): Frame interval of generated data (in seconds). 151 | number_of_rollouts (int): Number of rollouts to generate. 152 | img_size (int): Size of the frames (in pixels). 153 | color (bool): Whether to have colored or grayscale frames. 154 | noise_level (float): Level of noise, in [0, 1]. 0 means no noise, 1 means max noise. 155 | Maximum noise is defined in each environment. 156 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 157 | Init phase states will be sampled from a circle (q, p) of radius 158 | r ~ U(radius_bound[0], radius_bound[1]) https://arxiv.org/pdf/1909.13789.pdf (Sec. 4) 159 | Optionally, it can be a string 'auto'. In that case, the value returned by 160 | get_default_radius_bounds() will be returned. 161 | seed (int): Seed for reproducibility. 162 | Raises: 163 | AssertError: If radius_bound[0] > radius_bound[1] 164 | Returns: 165 | (ndarray): Array of shape (Batch, Nframes, Height, Width, Channels). 166 | Contains sampled rollouts 167 | """ 168 | if radius_bound == 'auto': 169 | radius_bound = self.get_default_radius_bounds() 170 | radius_lb, radius_ub = radius_bound 171 | assert radius_lb <= radius_ub 172 | if seed is not None: 173 | np.random.seed(seed) 174 | total_time = number_of_frames * delta_time 175 | batch_sample = [] 176 | for i in range(number_of_rollouts): 177 | self._sample_init_conditions(radius_bound) 178 | self._evolution(total_time, delta_time) 179 | if noise_level > 0.: 180 | self._rollout += np.random.randn( 181 | *self._rollout.shape) * noise_level * self.get_max_noise_std() 182 | batch_sample.append(self._draw(img_size, color)) 183 | 184 | return np.array(batch_sample) 185 | 186 | 187 | def visualize_rollout(rollout, interval=50, show_step=False): 188 | """Visualization for a single sample rollout of a physical system. 189 | 190 | Args: 191 | rollout (numpy.ndarray): Numpy array containing the sequence of images. It's shape must be 192 | (seq_len, height, width, channels). 193 | interval (int): Delay between frames (in millisec). 194 | show_step (bool): Whether to draw the step number in the image 195 | """ 196 | fig = plt.figure() 197 | img = [] 198 | for i, im in enumerate(rollout): 199 | if show_step: 200 | black_img = np.zeros(list(im.shape)) 201 | cv2.putText( 202 | black_img, str(i), (0, 30), fontScale=0.22, color=(255, 255, 255), thickness=1, 203 | fontFace=cv2.LINE_AA) 204 | res_img = (im + black_img / 255.) / 2 205 | else: 206 | res_img = im 207 | img.append([plt.imshow(res_img, animated=True)]) 208 | ani = animation.ArtistAnimation(fig, 209 | img, 210 | interval=interval, 211 | blit=True, 212 | repeat_delay=100) 213 | plt.show() 214 | -------------------------------------------------------------------------------- /environments/environment_factory.py: -------------------------------------------------------------------------------- 1 | """Environment factory class. Given a valid environment name and its constructor args, returns an instantiation of it 2 | """ 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | from environment import Environment 9 | from pendulum import Pendulum 10 | from spring import Spring 11 | from gravity import NObjectGravity 12 | from chaotic_pendulum import ChaoticPendulum 13 | 14 | 15 | class EnvFactory(): 16 | """Return a new Environment""" 17 | 18 | # Map the name of the name of the Environment concrete class by retrieving all its subclasses 19 | _name_to_env = {cl.__name__: cl for cl in Environment.__subclasses__()} 20 | 21 | @staticmethod 22 | def get_environment(name, **kwargs): 23 | """Return an environment object based on the environment identifier. 24 | 25 | Args: 26 | name (string); name of the class of the concrete Environment. 27 | **kwargs: args supplied to the constructor of the object of class name. 28 | 29 | Raises: 30 | (NameError): if the given environment type is not supported. 31 | 32 | Returns: 33 | (Environment): concrete instantiation of the Environment. 34 | """ 35 | try: 36 | return EnvFactory._name_to_env[name](**kwargs) 37 | except KeyError: 38 | msg = "%s is not a supported type by Environment." % (name) 39 | msg += "Available types are: " + "".join("%s " % eef for eef in EnvFactory._name_to_env.keys()) 40 | raise NameError(msg) 41 | 42 | 43 | if __name__ == "__main__": 44 | # EnvFactory test 45 | env = EnvFactory.get_environment("Pendulum", mass=0.5, length=1, g=10) 46 | print(type(env)) 47 | 48 | from matplotlib import pyplot as plt, animation 49 | import numpy as np 50 | rolls = env.sample_random_rollouts(number_of_frames=100, 51 | delta_time=0.1, 52 | number_of_rollouts=16, 53 | img_size=32, 54 | color=False, 55 | noise_level=0., 56 | seed=23) 57 | fig = plt.figure() 58 | img = [] 59 | idx = np.random.randint(rolls.shape[0]) 60 | for im in rolls[idx]: 61 | img.append([plt.imshow(im, animated=True)]) 62 | ani = animation.ArtistAnimation(fig, 63 | img, 64 | interval=50, 65 | blit=True, 66 | repeat_delay=1000) 67 | plt.show() -------------------------------------------------------------------------------- /environments/gravity.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from environment import Environment, visualize_rollout 7 | 8 | 9 | class NObjectGravity(Environment): 10 | 11 | """N Object Gravity Atraction System 12 | 13 | Equations of movement are: 14 | 15 | m_i*q'' = G * sum_{i!=j} ( m_i*m_j*(q_j - q_i)/ abs(q_j - q_i)^3 ) 16 | 17 | """ 18 | 19 | WORLD_SIZE = 3. 20 | 21 | def __init__(self, mass, g, orbit_noise=.01, q=None, p=None): 22 | """Contructor for spring system 23 | 24 | Args: 25 | mass (list): List of floats corresponding to object masses (kg). 26 | g (float): Constant for the intensity of gravitational field (m^3/kg*s^2) 27 | orbit_noise (float, optional): Noise for object orbits when sampling initial conditions 28 | q (ndarray, optional): Object generalized positions in 2-D space: Positions (m). Defaults to None 29 | p (ndarray, optional): Object generalized momentums in 2-D space : Linear momentums (kg*m/s). Defaults to None 30 | Raises: 31 | NotImplementedError: If more than 7 objects are considered 32 | """ 33 | self.mass = mass 34 | self.colors = ['r', 'y', 'g', 'b', 'c', 'p', 'w'] 35 | self.n_objects = len(mass) 36 | self.g = g 37 | self.orbit_noise = orbit_noise 38 | if self.n_objects > 3: 39 | raise NotImplementedError( 40 | 'Gravity interaction for ' + str(self.n_objects) + ' bodies is not implemented.') 41 | super().__init__(q=q, p=p) 42 | 43 | def set(self, q, p): 44 | """Sets initial conditions for pendulum 45 | 46 | Args: 47 | q (ndarray): Object generalized positions in 2-D space: Positions (m) 48 | p (ndarray): Object generalized momentums in 2-D space : Linear momentums (kg*m/s) 49 | 50 | Raises: 51 | ValueError: If q and p are not in 2-D space or do not refer to all the objects in space 52 | """ 53 | if q is None or p is None: 54 | return 55 | if q.shape[0] != self.n_objects or p.shape[0] != self.n_objects: 56 | raise ValueError( 57 | "q and p do not refer to the same number of objects in the system.") 58 | if q.shape[-1] != 2 or p.shape[-1] != 2: 59 | raise ValueError( 60 | "q and p must be in 2-D space: Position and Linear momentum.") 61 | self.q = q.copy() 62 | self.p = p.copy() 63 | 64 | def get_world_size(self): 65 | """Return world size for correctly render the environment. 66 | """ 67 | return self.WORLD_SIZE 68 | 69 | def get_max_noise_std(self): 70 | """Return maximum noise std that keeps the environment stable.""" 71 | if self.n_objects == 2: 72 | return 0.05 73 | elif self.n_objects == 3: 74 | return 0.2 75 | else: 76 | return 0. 77 | 78 | def get_default_radius_bounds(self): 79 | """Returns: 80 | radius_bounds (tuple): (min, max) radius bounds for the environment. 81 | """ 82 | if self.n_objects == 2: 83 | return (0.5, 1.5) 84 | elif self.n_objects == 3: 85 | return (0.9, 1.2) 86 | else: 87 | warnings.warn( 88 | 'Gravity for n > 3 objects can have undefined behavior.') 89 | return (0.3, 0.5) 90 | 91 | def _dynamics(self, t, states): 92 | """Defines system dynamics 93 | 94 | Args: 95 | t (float): Time parameter of the dynamic equations. 96 | states (numpy.ndarray): 1-D array that contains the information of the phase 97 | state, in the format of np.array([q,p]).reshape(-1). 98 | 99 | Returns: 100 | equations (numpy.ndarray): Numpy array with derivatives of q and p w.r.t. time 101 | """ 102 | # Convert states to n_object arrays of q and p 103 | states_resh = states.reshape(2, self.n_objects, 2) 104 | dyn = np.zeros_like(states_resh) 105 | states_q = states_resh[0, :, :] 106 | states_p = states_resh[1, :, :] 107 | dyn[0, :, :] = states_p/(np.array(self.mass)[:, np.newaxis]) 108 | 109 | # Distance calculation 110 | object_distance = np.zeros((self.n_objects, self.n_objects)) 111 | for i in range(self.n_objects): 112 | for j in range(i, self.n_objects): 113 | object_distance[i, j] = np.linalg.norm( 114 | states_q[i] - states_q[j]) 115 | object_distance[j, i] = object_distance[i, j] 116 | object_distance = np.power(object_distance, 3)/self.g 117 | 118 | for d in range(2): 119 | for i in range(self.n_objects): 120 | mom_term = 0 121 | for j in range(self.n_objects): 122 | if i != j: 123 | mom_term += self.mass[j]*(states_q[j, d] - 124 | states_q[i, d])/object_distance[i, j] 125 | dyn[1, i, d] += mom_term*self.mass[i] 126 | return dyn.reshape(-1) 127 | 128 | def _draw(self, res=32, color=True): 129 | """Returns array of the environment evolution 130 | 131 | Args: 132 | res (int): Image resolution (images are square). 133 | color (bool): True if RGB, false if grayscale. 134 | 135 | Returns: 136 | vid (np.ndarray): Numpy array of shape (seq_len, height, width, channels) 137 | containing the rendered rollout as a sequence of images. 138 | """ 139 | q = self._rollout.reshape(2, self.n_objects, 2, -1)[0] 140 | length = q.shape[-1] 141 | vid = np.zeros((length, res, res, 3), dtype='float') 142 | ball_colors = self._default_ball_colors 143 | space_res = 2.*self.get_world_size()/res 144 | if self.n_objects == 2: 145 | factor = 0.55 146 | else: 147 | factor = 0.25 148 | for t in range(length): 149 | for n in range(self.n_objects): 150 | brush = self.colors[n] 151 | if brush == 'y': 152 | vid[t] = cv2.circle(vid[t], 153 | self._world_to_pixels( 154 | q[n, 0, t], q[n, 1, t], res), 155 | int(self.mass[n]*factor/space_res), ball_colors[0], -1) 156 | elif brush == 'r': 157 | vid[t] = cv2.circle(vid[t], 158 | self._world_to_pixels( 159 | q[n, 0, t], q[n, 1, t], res), 160 | int(self.mass[n]*factor/space_res), ball_colors[1], -1) 161 | elif brush == 'g': 162 | vid[t] = cv2.circle(vid[t], 163 | self._world_to_pixels( 164 | q[n, 0, t], q[n, 1, t], res), 165 | int(self.mass[n]*factor/space_res), ball_colors[2], -1) 166 | vid[t] = cv2.blur(cv2.blur(vid[t], (2, 2)), (2, 2)) 167 | vid += self._default_background_color 168 | vid[vid > 1.] = 1. 169 | if not color: 170 | vid = np.expand_dims(np.max(vid, axis=-1), -1) 171 | return vid 172 | 173 | def _sample_init_conditions(self, radius_bound): 174 | """Samples random initial conditions for the environment 175 | Args: 176 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 177 | Optionally, it can be a string 'auto'. In that case, the value returned by 178 | get_default_radius_bounds() will be returned. 179 | """ 180 | radius_lb, radius_ub = radius_bound 181 | radius = np.random.rand()*(radius_ub - radius_lb) + radius_lb 182 | states = np.zeros((2, self.n_objects, 2)) 183 | # first position 184 | pos = np.random.rand(2)*2. - 1 185 | pos = (pos/np.sqrt((pos**2).sum()))*radius 186 | 187 | # velocity that yields a circular orbit 188 | vel = self.__rotate2d(pos, theta=np.pi/2) 189 | if np.random.randn() < .5: 190 | vel = -vel 191 | if self.n_objects == 2: 192 | factor = 2 193 | vel /= (factor*radius**1.5) 194 | 195 | else: 196 | factor = np.sqrt(np.sin(np.pi/3)/(2*np.cos(np.pi/6)**2)) 197 | vel *= factor/(radius**1.5) 198 | 199 | states[0, 0, :] = pos 200 | states[1, 0, :] = vel 201 | 202 | rot_angle = 2*np.pi/self.n_objects 203 | for i in range(1, self.n_objects): 204 | states[0, i, :] = self.__rotate2d( 205 | states[0, i - 1, :], theta=rot_angle) 206 | states[1, i, :] = self.__rotate2d( 207 | states[1, i - 1, :], theta=rot_angle) 208 | for i in range(self.n_objects): 209 | states[1, i, :] *= 1 + \ 210 | self.orbit_noise*(2*np.random.rand(2) - 1) 211 | self.set(states[0], states[1]) 212 | 213 | def __rotate2d(self, p, theta): 214 | c, s = np.cos(theta), np.sin(theta) 215 | Rot = np.array([[c, -s], [s, c]]) 216 | return np.dot(Rot, p.reshape(2, 1)).squeeze() 217 | 218 | 219 | # Sample code for sampling rollouts 220 | if __name__ == "__main__": 221 | 222 | og = NObjectGravity(mass=[1., 1.], 223 | g=1., orbit_noise=0.05) 224 | rolls = og.sample_random_rollouts(number_of_frames=30, 225 | delta_time=0.125, 226 | number_of_rollouts=1, 227 | img_size=32, 228 | noise_level=0., 229 | radius_bound=(.5, 1.5), 230 | seed=None) 231 | idx = np.random.randint(rolls.shape[0]) 232 | visualize_rollout(rolls[idx]) 233 | -------------------------------------------------------------------------------- /environments/pendulum.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from environment import Environment, visualize_rollout 5 | 6 | 7 | class Pendulum(Environment): 8 | """Pendulum System 9 | 10 | Equations of movement are: 11 | 12 | theta'' = -(g/l)*sin(theta) 13 | 14 | """ 15 | 16 | WORLD_SIZE = 2. 17 | 18 | def __init__(self, mass, length, g, q=None, p=None): 19 | """Constructor for pendulum system 20 | 21 | Args: 22 | mass (float): Pendulum mass (kg) 23 | length (float): Pendulum length (m) 24 | g (float): Gravity of the environment (m/s^2) 25 | q ([float], optional): Generalized position in 1-D space: Phase (rad). Defaults to None 26 | p ([float], optional): Generalized momentum in 1-D space: Angular momentum (kg*m^2/s). 27 | Defaults to None 28 | """ 29 | self.mass = mass 30 | self.length = length 31 | self.g = g 32 | super().__init__(q=q, p=p) 33 | 34 | def set(self, q, p): 35 | """Sets initial conditions for pendulum 36 | 37 | Args: 38 | q ([float]): Generalized position in 1-D space: Phase (rad) 39 | p ([float]): Generalized momentum in 1-D space: Angular momentum (kg*m^2/s) 40 | 41 | Raises: 42 | ValueError: If p and q are not in 1-D space 43 | """ 44 | if q is None or p is None: 45 | return 46 | if len(q) != 1 or len(p) != 1: 47 | raise ValueError( 48 | "q and p must be in 1-D space: Angular momentum and Phase.") 49 | self.q = q 50 | self.p = p 51 | 52 | def get_world_size(self): 53 | """Return world size for correctly render the environment. 54 | """ 55 | return self.WORLD_SIZE 56 | 57 | def get_max_noise_std(self): 58 | """Return maximum noise std that keeps the environment stable.""" 59 | return 0.1 60 | 61 | def get_default_radius_bounds(self): 62 | """Returns: 63 | radius_bounds (tuple): (min, max) radius bounds for the environment. 64 | """ 65 | return (1.3, 2.3) 66 | 67 | def _dynamics(self, t, states): 68 | """Defines system dynamics 69 | 70 | Args: 71 | t (float): Time parameter of the dynamic equations. 72 | states ([float]) Phase states at time t 73 | 74 | Returns: 75 | equations ([float]): Movement equations of the physical system 76 | """ 77 | return [(states[1] / (self.mass * self.length * self.length)), 78 | -self.g * self.mass * self.length * np.sin(states[0])] 79 | 80 | def _draw(self, res=32, color=True): 81 | """Returns array of the environment evolution 82 | 83 | Args: 84 | res (int): Image resolution (images are square). 85 | color (bool): True if RGB, false if grayscale. 86 | 87 | Returns: 88 | vid (np.ndarray): Rendered rollout as a sequence of images 89 | """ 90 | q = self._rollout[0, :] 91 | length = len(q) 92 | vid = np.zeros((length, res, res, 3), dtype='float') 93 | ball_color = self._default_ball_colors[0] 94 | space_res = 2.*self.get_world_size()/res 95 | for t in range(length): 96 | vid[t] = cv2.circle(vid[t], self._world_to_pixels(self.length*np.sin(q[t]), 97 | self.length*np.cos(q[t]), res), 98 | int(self.mass/space_res), ball_color, -1) 99 | vid[t] = cv2.blur(cv2.blur(vid[t], (2, 2)), (2, 2)) 100 | vid += self._default_background_color 101 | vid[vid > 1.] = 1. 102 | if not color: 103 | vid = np.expand_dims(np.max(vid, axis=-1), -1) 104 | return vid 105 | 106 | def _sample_init_conditions(self, radius_bound): 107 | """Samples random initial conditions for the environment 108 | 109 | Args: 110 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 111 | Optionally, it can be a string 'auto'. In that case, the value returned by 112 | get_default_radius_bounds() will be returned. 113 | """ 114 | radius_lb, radius_ub = radius_bound 115 | radius = np.random.rand()*(radius_ub - radius_lb) + radius_lb 116 | states = np.random.rand(2) * 2. - 1 117 | states = (states / np.sqrt((states**2).sum())) * radius 118 | self.set([states[0]], [states[1]]) 119 | 120 | 121 | # Sample code for sampling rollouts 122 | if __name__ == "__main__": 123 | 124 | pd = Pendulum(mass=.5, length=1, g=3) 125 | rolls = pd.sample_random_rollouts(number_of_frames=100, 126 | delta_time=0.1, 127 | number_of_rollouts=16, 128 | img_size=32, 129 | noise_level=0., 130 | radius_bound=(1.3, 2.3), 131 | color=True, 132 | seed=23) 133 | idx = np.random.randint(rolls.shape[0]) 134 | visualize_rollout(rolls[idx]) 135 | -------------------------------------------------------------------------------- /environments/spring.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from environment import Environment, visualize_rollout 5 | 6 | 7 | class Spring(Environment): 8 | """Damped spring System 9 | 10 | Equations of movement are: 11 | 12 | x'' = -2*c*sqrt(k/m)*x' -(k/m)*x 13 | 14 | """ 15 | 16 | WORLD_SIZE = 2. 17 | 18 | def __init__(self, mass, elastic_cst, damping_ratio=0., q=None, p=None): 19 | """Constructor for spring system 20 | 21 | Args: 22 | mass (float): Spring mass (kg) 23 | elastic_cst (float): Spring elastic constant (kg/s^2) 24 | damping_ratio (float): Damping ratio of the oscillator 25 | if damping_ratio > 1: Oscillator is overdamped 26 | if damping_ratio = 1: Oscillator is critically damped 27 | if damping_ratio < 1: Oscillator is underdamped 28 | q ([float], optional): Generalized position in 1-D space: Position (m). Defaults to None 29 | p ([float], optional): Generalized momentum in 1-D space: Linear momentum (kg*m/s). Defaults to None 30 | """ 31 | self.mass = mass 32 | self.elastic_cst = elastic_cst 33 | self.damping_ratio = damping_ratio 34 | super().__init__(q=q, p=p) 35 | 36 | def set(self, q, p): 37 | """Sets initial conditions for spring system 38 | 39 | Args: 40 | q ([float]): Generalized position in 1-D space: Position (m) 41 | p ([float]): Generalized momentum in 1-D space: Linear momentum (kg*m/s) 42 | 43 | Raises: 44 | ValueError: If p and q are not in 1-D space 45 | """ 46 | if q is None or p is None: 47 | return 48 | if len(q) != 1 or len(p) != 1: 49 | raise ValueError( 50 | "q and p must be in 1-D space: Angular momentum and Phase.") 51 | self.q = q 52 | self.p = p 53 | 54 | def get_world_size(self): 55 | """Return world size for correctly render the environment. 56 | """ 57 | return self.WORLD_SIZE 58 | 59 | def get_max_noise_std(self): 60 | """Return maximum noise std that keeps the environment stable.""" 61 | return 0.1 62 | 63 | def get_default_radius_bounds(self): 64 | """Returns: 65 | radius_bounds (tuple): (min, max) radius bounds for the environment. 66 | """ 67 | return (0.1, 1.0) 68 | 69 | def _dynamics(self, t, states): 70 | """Defines system dynamics 71 | 72 | Args: 73 | t (float): Time parameter of the dynamic equations. 74 | states ([float]) Phase states at time t 75 | 76 | Returns: 77 | equations ([float]): Movement equations of the physical system 78 | """ 79 | # angular freq of the undamped oscillator 80 | w0 = np.sqrt(self.elastic_cst/self.mass) 81 | # dynamics of the damped oscillator 82 | return [states[1] / self.mass, -2*self.damping_ratio*w0*states[1] - self.elastic_cst*states[0]] 83 | 84 | def _draw(self, res=32, color=True): 85 | """Returns array of the environment evolution 86 | 87 | Args: 88 | res (int): Image resolution (images are square). 89 | color (bool): True if RGB, false if grayscale. 90 | 91 | Returns: 92 | vid (np.ndarray): Rendered rollout as a sequence of images 93 | """ 94 | q = self._rollout[0, :] 95 | length = len(q) 96 | vid = np.zeros((length, res, res, 3), dtype='float') 97 | ball_color = self._default_ball_colors[0] 98 | space_res = 2.*self.get_world_size()/res 99 | for t in range(length): 100 | vid[t] = cv2.circle(vid[t], self._world_to_pixels(0, q[t], res), 101 | int(self.mass/space_res), ball_color, -1) 102 | vid[t] = cv2.blur(cv2.blur(vid[t], (2, 2)), (2, 2)) 103 | vid += self._default_background_color 104 | vid[vid > 1.] = 1. 105 | if not color: 106 | vid = np.expand_dims(np.max(vid, axis=-1), -1) 107 | return vid 108 | 109 | def _sample_init_conditions(self, radius_bound): 110 | """Samples random initial conditions for the environment 111 | 112 | Args: 113 | radius_bound (float, float): Radius lower and upper bound of the phase state sampling. 114 | Optionally, it can be a string 'auto'. In that case, the value returned by 115 | get_default_radius_bounds() will be returned. 116 | """ 117 | radius_lb, radius_ub = radius_bound 118 | radius = np.random.rand()*(radius_ub - radius_lb) + radius_lb 119 | states = np.random.rand(2) * 2. - 1 120 | states = (states / np.sqrt((states**2).sum())) * radius 121 | self.set([states[0]], [states[1]]) 122 | 123 | 124 | # Sample code for sampling rollouts 125 | if __name__ == "__main__": 126 | 127 | sp = Spring(mass=.5, elastic_cst=2, damping_ratio=0.) 128 | rolls = sp.sample_random_rollouts(number_of_frames=100, 129 | delta_time=0.1, 130 | number_of_rollouts=16, 131 | img_size=32, 132 | noise_level=0., 133 | radius_bound=(.5, 1.4), 134 | color=True, 135 | seed=None) 136 | idx = np.random.randint(rolls.shape[0]) 137 | visualize_rollout(rolls[idx]) 138 | -------------------------------------------------------------------------------- /experiment_params/dataset_offline_default.yaml: -------------------------------------------------------------------------------- 1 | # Define data paths 2 | dataset: 3 | train_data: "datasets/pendulum_data/train" # Should be of 50k samples 4 | test_data: "datasets/pendulum_data/test" # Should be of 10k samples 5 | -------------------------------------------------------------------------------- /experiment_params/dataset_online_default.yaml: -------------------------------------------------------------------------------- 1 | # Define environment 2 | environment: 3 | # The following parameters must correspond in name and type 4 | # to the environment __init__() arguments 5 | name: "Pendulum" 6 | mass: 0.5 7 | length: 1 8 | g: 3 9 | 10 | # Define data characteristics 11 | dataset: 12 | img_size: 32 13 | radius_bound: 'auto' 14 | num_train_samples: 50000 # Total number of rollouts used when training on-line. 15 | num_test_samples: 10000 # Number of test samples. 16 | rollout: 17 | seq_length: 34 18 | delta_time: 0.125 19 | n_channels: 3 20 | noise_level: 1 # Level of environment noise. 0 means no noise, 1 means max noise. 21 | # Maximum values are defined in each environment. 22 | -------------------------------------------------------------------------------- /experiment_params/default_environments/chaotic_pendulum.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "ChaoticPendulum" 3 | mass: 1.0 4 | length: 1.0 5 | g: 3.0 -------------------------------------------------------------------------------- /experiment_params/default_environments/damped_spring.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "Spring" 3 | mass: 0.5 4 | elastic_cst: 2.0 5 | damping_ratio: 0.3 6 | -------------------------------------------------------------------------------- /experiment_params/default_environments/pendulum.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "Pendulum" 3 | mass: 0.5 4 | length: 1 5 | g: 3 -------------------------------------------------------------------------------- /experiment_params/default_environments/spring.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "Spring" 3 | mass: 0.5 4 | elastic_cst: 2.0 -------------------------------------------------------------------------------- /experiment_params/default_environments/three_bodies.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "NObjectGravity" 3 | mass: [1., 1., 1.] 4 | g: 1.0 5 | orbit_noise: 0.1 6 | -------------------------------------------------------------------------------- /experiment_params/default_environments/two_bodies.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: "NObjectGravity" 3 | mass: [1., 1.] 4 | g: 1.0 5 | orbit_noise: 0.1 6 | -------------------------------------------------------------------------------- /experiment_params/train_config_default.yaml: -------------------------------------------------------------------------------- 1 | experiment_id: "default_offline" 2 | model_save_dir: "saved_models" 3 | 4 | device: 'cuda:0' # Will use this device if available, otherwise wil use cpu 5 | 6 | # Define networks architectures 7 | networks: 8 | variational: True 9 | dtype : "float" 10 | encoder: 11 | hidden_conv_layers: 6 12 | n_filters: [32, 64, 64, 64, 64, 64, 64] # first + hidden 13 | kernel_sizes: [3, 3, 3, 3, 3, 3, 3, 3] # first + hidden + last 14 | strides: [1, 1, 1, 1, 1, 1, 1, 1] # first + hidden + last 15 | out_channels: 48 16 | transformer: 17 | hidden_conv_layers: 1 18 | n_filters: [64, 64] # first + hidden 19 | kernel_sizes: [3, 3, 3] # first + hidden + last 20 | strides: [2, 2, 2] # first + hidden + last 21 | out_channels: 16 # Channels of q, and p splitted 22 | hamiltonian: 23 | hidden_conv_layers: 3 24 | in_shape: [16, 4, 4] # Should be coherent with transformer output 25 | n_filters: [32, 64, 64, 64] # first + hidden 26 | kernel_sizes: [3, 2, 2, 2, 2] # first + hidden + last 27 | strides: [1, 2, 1, 1, 1] # first + hidden + last 28 | paddings: [1, 0, [0, 1, 0, 1], [0, 1, 0, 1], 0] # first + hidden + last 29 | decoder: 30 | n_residual_blocks: 3 31 | n_filters: [64, 64, 64] 32 | kernel_sizes: [3, 3, 3, 3] 33 | 34 | # Define HGN Integrator 35 | integrator: 36 | method: "Leapfrog" 37 | 38 | # Define optimization 39 | optimization: 40 | epochs: 5 41 | batch_size: 16 42 | input_frames: 5 # Number of frames to feed to the encoder while training 43 | # Learning rates 44 | encoder_lr: 1.5e-4 45 | transformer_lr: 1.5e-4 46 | hnn_lr: 1.5e-4 47 | decoder_lr: 1.5e-4 48 | 49 | geco: 50 | alpha: 0.99 # decay of the moving average 51 | tol: 3.3e-2 # per pixel error tolerance. keep in mind this gets squared 52 | initial_lagrange_multiplier: 1.0 # this is 1/beta 53 | lagrange_multiplier_param: 0.1 # adjust update on langrange multiplier 54 | # To train in a beta-vae fashion use the following parameters: 55 | # alpha: 0.0 56 | # tol: 0.0 57 | # initial_lagrange_multiplier: 1 / beta 58 | # lagrange_multiplier_param = 1.0 -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import copy 4 | import shutil 5 | import sys 6 | import os 7 | import yaml 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from environments import environment_factory 13 | 14 | 15 | def generate_and_save(root_path, environment, n_samples, n_frames, delta_time, img_size, 16 | radius_bound, noise_level, color, start_seed, train=True): 17 | path = os.path.join(root_path, 'train' if train else 'test') 18 | if not os.path.exists(path): 19 | os.makedirs(path) 20 | for i in tqdm(range(n_samples)): 21 | rolls = environment.sample_random_rollouts( 22 | number_of_frames=n_frames, 23 | delta_time=delta_time, 24 | number_of_rollouts=1, 25 | img_size=img_size, 26 | noise_level=noise_level, 27 | radius_bound=radius_bound, 28 | color=color, 29 | seed=i + start_seed 30 | )[0] 31 | filename = "{0:05d}".format(i) 32 | np.savez(os.path.join(path, filename), rolls) 33 | return path 34 | 35 | 36 | def _read_config(config_file): 37 | with open(config_file, 'r') as file: 38 | config = yaml.load(file, Loader=yaml.FullLoader) 39 | return config 40 | 41 | 42 | def _prepare_out_config(config, train_path, test_path): 43 | out_config = copy.deepcopy(config) 44 | out_config['dataset']['train_data'] = train_path 45 | out_config['dataset']['test_data'] = test_path 46 | return out_config 47 | 48 | 49 | def _overwrite_config_with_cmd_arguments(config, args): 50 | # This function overwrites parameters in the given dictionary 51 | # with the correspondent command line arguments. 52 | if args.ntrain is not None: 53 | config['dataset']['num_train_samples'] = args.ntrain[0] 54 | if args.ntest is not None: 55 | config['dataset']['num_test_samples'] = args.ntest[0] 56 | if args.env is not None: 57 | env_params = _read_config(DEFAULT_ENVIRONMENTS_PATH + args.env[0] + '.yaml') 58 | config['environment'] = env_params['environment'] 59 | if args.params is not None: 60 | for p in args.params: 61 | key, value = p.split('=') 62 | ptr = config 63 | keys = key.split('.') 64 | for i, k in enumerate(keys): 65 | if i == len(keys) - 1: 66 | ptr[k] = ast.literal_eval(value) 67 | else: 68 | ptr = ptr[k] 69 | 70 | 71 | if __name__ == '__main__': 72 | DEFAULT_DATASETS_ROOT = 'datasets/' 73 | DEFAULT_DATASET_CONFIG_FILE = 'experiment_params/dataset_online_default.yaml' 74 | DEFAULT_TRAIN_CONFIG_FILE = 'experiment_params/train_config_default.yaml' 75 | DEFAULT_ENVIRONMENTS_PATH = 'experiment_params/default_environments/' 76 | 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument( 79 | '--name', action='store', nargs=1, required=True, help='The dataset name.' 80 | ) 81 | parser.add_argument( 82 | '--dataset-config', action='store', nargs=1, type=str, required=False, 83 | help=f'YAML file from which to read the dataset parameters. If not specified, ' 84 | f'{DEFAULT_DATASET_CONFIG_FILE} will be used.' 85 | ) 86 | parser.add_argument( 87 | '--ntrain', action='store', nargs=1, required=False, type=int, 88 | help='Number of training sample to generate.' 89 | ) 90 | parser.add_argument( 91 | '--ntest', action='store', nargs=1, required=False, type=int, 92 | help='Number of test samples to generate.' 93 | ) 94 | parser.add_argument( 95 | '--env', action='store', nargs=1, required=False, type=str, 96 | help=f'The default environment specifications to use. Can be \'pendulum\', \'spring\', ' 97 | f'\'two_bodies\', \'three_bodies\', \'chaotic_pendulum\'. If this argument is ' 98 | f'specified, a default environment section will be loaded from the correspondent yaml ' 99 | f'file in {DEFAULT_ENVIRONMENTS_PATH}' 100 | ) 101 | parser.add_argument( 102 | '--datasets-root', action='store', nargs=1, required=False, type=str, 103 | help=f'Root of the datasets folder in which the dataset will be stored. If not specified, ' 104 | f'{DEFAULT_DATASETS_ROOT} will be used as default.' 105 | ) 106 | parser.add_argument( 107 | '--params', action='store', nargs='+', required=False, 108 | help='Override one or more parameters in the config. The format of an argument is ' 109 | 'param_name=param_value. Nested parameters are accessible by using a dot, ' 110 | 'i.e. --param dataset.img_size=32. IMPORTANT: lists must be enclosed in double ' 111 | 'quotes, i.e. --param environment.mass:"[0.5, 0.5]".' 112 | ) 113 | _args = parser.parse_args() 114 | 115 | # Read yaml file with parameters definition 116 | _dataset_config_file = _args.dataset_config[0] if _args.dataset_config is not None else \ 117 | DEFAULT_DATASET_CONFIG_FILE 118 | _dataset_config = _read_config(_dataset_config_file) 119 | 120 | # Overwrite dictionary from command line args to ensure they will be used 121 | _overwrite_config_with_cmd_arguments(_dataset_config, _args) 122 | 123 | # Extract environment parameters 124 | EXP_NAME = _args.name[0] 125 | N_TRAIN_SAMPLES = _dataset_config['dataset']['num_train_samples'] 126 | N_TEST_SAMPLES = _dataset_config['dataset']['num_test_samples'] 127 | IMG_SIZE = _dataset_config['dataset']['img_size'] 128 | RADIUS_BOUND = _dataset_config['dataset']['radius_bound'] 129 | NOISE_LEVEL = _dataset_config['dataset']['rollout']['noise_level'] 130 | N_FRAMES = _dataset_config['dataset']['rollout']['seq_length'] 131 | DELTA_TIME = _dataset_config['dataset']['rollout']['delta_time'] 132 | N_CHANNELS = _dataset_config['dataset']['rollout']['n_channels'] 133 | 134 | # Get dataset output path 135 | dataset_root = DEFAULT_DATASETS_ROOT if _args.datasets_root is None else _args.datasets_root[0] 136 | dataset_root = os.path.join(dataset_root, EXP_NAME) 137 | 138 | # Get the environment object from dictionary parameters 139 | environment = environment_factory.EnvFactory.get_environment(**_dataset_config['environment']) 140 | 141 | # Ask user confirmation 142 | print(f'The dataset will be generated with the following configuration:') 143 | print(f'PATH: {dataset_root}') 144 | print(f'dataset: {_dataset_config["dataset"]}') 145 | print(f'environment: {_dataset_config["environment"]}') 146 | print('\nProceed? (y/n):') 147 | if input() != 'y': 148 | print('Aborting') 149 | exit() 150 | 151 | # Generate train samples 152 | _train_path = generate_and_save( 153 | root_path=dataset_root, environment=environment, 154 | n_samples=N_TRAIN_SAMPLES, n_frames=N_FRAMES, delta_time=DELTA_TIME, img_size=IMG_SIZE, 155 | radius_bound=RADIUS_BOUND, noise_level=NOISE_LEVEL, color=N_CHANNELS == 3, 156 | start_seed=0, train=True 157 | ) 158 | 159 | # Generate test samples 160 | _test_path = None 161 | if N_TEST_SAMPLES > 0: 162 | _test_path = generate_and_save( 163 | root_path=dataset_root, environment=environment, 164 | n_samples=N_TEST_SAMPLES, n_frames=N_FRAMES, delta_time=DELTA_TIME, 165 | img_size=IMG_SIZE, radius_bound=RADIUS_BOUND, noise_level=NOISE_LEVEL, 166 | color=N_CHANNELS == 3, start_seed=N_TRAIN_SAMPLES, train=False 167 | ) 168 | 169 | # Convert parameters to offline train parameters and write them in the dataset 170 | _out_config = _prepare_out_config(_dataset_config, _train_path, _test_path) 171 | yaml_content = yaml.dump(_out_config, default_flow_style=True) 172 | config_out_path = os.path.join(dataset_root, 'parameters.yaml') 173 | with open(config_out_path, 'x') as f: 174 | f.write(yaml_content) 175 | f.close() 176 | 177 | print(f'A parameter file ready to be trained on was generated at {config_out_path}') 178 | -------------------------------------------------------------------------------- /hamiltonian_generative_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import torch 5 | 6 | from utilities import conversions 7 | from utilities.integrator import Integrator 8 | from utilities.hgn_result import HgnResult 9 | 10 | 11 | class HGN: 12 | """Hamiltonian Generative Network model. 13 | 14 | This class models the HGN and implements its training and evaluation. 15 | """ 16 | ENCODER_FILENAME = "encoder.pt" 17 | TRANSFORMER_FILENAME = "transformer.pt" 18 | HAMILTONIAN_FILENAME = "hamiltonian.pt" 19 | DECODER_FILENAME = "decoder.pt" 20 | 21 | def __init__(self, 22 | encoder, 23 | transformer, 24 | hnn, 25 | decoder, 26 | integrator, 27 | device, 28 | dtype, 29 | seq_len, 30 | channels): 31 | """Instantiate a Hamiltonian Generative Network. 32 | 33 | Args: 34 | encoder (networks.encoder_net.EncoderNet): Encoder neural network. 35 | transformer (networks.transformer_net.TransformerNet): Transformer neural network. 36 | hnn (networks.hamiltonian_net.HamiltonianNet): Hamiltonian neural network. 37 | decoder (networks.decoder_net.DecoderNet): Decoder neural network. 38 | integrator (Integrator): HGN integrator. 39 | device (str): String with the device to use. E.g. 'cuda:0', 'cpu'. 40 | dtype (torch.dtype): Data type used for the networks. 41 | seq_len (int): Number of frames in each rollout. 42 | channels (int, optional): Number of channels of the images. Defaults to 3. 43 | """ 44 | # Parameters 45 | self.seq_len = seq_len 46 | self.channels = channels 47 | self.device = device 48 | self.dtype = dtype 49 | 50 | # Modules 51 | self.encoder = encoder 52 | self.transformer = transformer 53 | self.hnn = hnn 54 | self.decoder = decoder 55 | self.integrator = integrator 56 | 57 | def forward(self, rollout_batch, n_steps=None, variational=True): 58 | """Get the prediction of the HGN for a given rollout_batch of n_steps. 59 | 60 | Args: 61 | rollout_batch (torch.Tensor): Minibatch of rollouts as a Tensor of shape 62 | (batch_size, seq_len, channels, height, width). 63 | n_steps (integer, optional): Number of guessed steps, if None will match seq_len. 64 | Defaults to None. 65 | variational (bool): Whether to sample from the encoder distribution or take the mean. 66 | 67 | Returns: 68 | (utilities.HgnResult): An HgnResult object containing data of the forward pass over the 69 | given minibatch. 70 | """ 71 | n_steps = self.seq_len if n_steps is None else n_steps 72 | 73 | # Instantiate prediction object 74 | prediction_shape = list(rollout_batch.shape) 75 | prediction_shape[1] = n_steps + 1 # Count the first one 76 | prediction = HgnResult(batch_shape=torch.Size(prediction_shape), 77 | device=self.device) 78 | prediction.set_input(rollout_batch) 79 | 80 | # Concat along channel dimension 81 | rollout_batch = conversions.concat_rgb(rollout_batch) 82 | 83 | # Latent distribution 84 | z, z_mean, z_logvar = self.encoder(rollout_batch, sample=variational) 85 | prediction.set_z(z_sample=z, z_mean=z_mean, z_logvar=z_logvar) 86 | 87 | # Initial state 88 | q, p = self.transformer(z) 89 | prediction.append_state(q=q, p=p) 90 | 91 | # Initial state reconstruction 92 | x_reconstructed = self.decoder(q) 93 | prediction.append_reconstruction(x_reconstructed) 94 | 95 | # Estimate predictions 96 | for _ in range(n_steps): 97 | # Compute next state 98 | q, p = self.integrator.step(q=q, p=p, hnn=self.hnn) 99 | prediction.append_state(q=q, p=p) 100 | prediction.append_energy(self.integrator.energy) # This is the energy of previous timestep 101 | 102 | # Compute state reconstruction 103 | x_reconstructed = self.decoder(q) 104 | prediction.append_reconstruction(x_reconstructed) 105 | 106 | # We need to add the energy of the system at the last time-step 107 | with torch.no_grad(): 108 | last_energy = self.hnn(q=q, p=p).detach().cpu().numpy() 109 | prediction.append_energy(last_energy) # This is the energy of previous timestep 110 | return prediction 111 | 112 | def load(self, directory): 113 | """Load networks' parameters 114 | 115 | Args: 116 | directory (string): Path to the saved models 117 | """ 118 | self.encoder = torch.load(os.path.join(directory, 119 | self.ENCODER_FILENAME), 120 | map_location=self.device) 121 | self.transformer = torch.load(os.path.join(directory, 122 | self.TRANSFORMER_FILENAME), 123 | map_location=self.device) 124 | self.hnn = torch.load(os.path.join(directory, 125 | self.HAMILTONIAN_FILENAME), 126 | map_location=self.device) 127 | self.decoder = torch.load(os.path.join(directory, 128 | self.DECODER_FILENAME), 129 | map_location=self.device) 130 | 131 | def save(self, directory): 132 | """Save networks' parameters 133 | 134 | Args: 135 | directory (string): Path where to save the models, if does not exist it, is created 136 | """ 137 | pathlib.Path(directory).mkdir(parents=True, exist_ok=True) 138 | torch.save(self.encoder, os.path.join(directory, 139 | self.ENCODER_FILENAME)) 140 | torch.save(self.transformer, 141 | os.path.join(directory, self.TRANSFORMER_FILENAME)) 142 | torch.save(self.hnn, os.path.join(directory, 143 | self.HAMILTONIAN_FILENAME)) 144 | torch.save(self.decoder, os.path.join(directory, 145 | self.DECODER_FILENAME)) 146 | 147 | def debug_mode(self): 148 | """Set the network to debug mode, i.e. allow intermediate gradients to be retrieved. 149 | """ 150 | for module in [self.encoder, self.transformer, self.decoder, self.hnn]: 151 | for name, layer in module.named_parameters(): 152 | layer.retain_grad() 153 | 154 | def get_random_sample(self, n_steps, img_shape=(32, 32)): 155 | """Sample a rollout from the HGN 156 | 157 | Args: 158 | n_steps (int): Length of the sampled rollout 159 | img_shape (tuple(int, int), optional): Size of the images, should match the trained ones. Defaults to (32, 32). 160 | 161 | Returns: 162 | (utilities.HgnResult): An HgnResult object containing data of the forward pass over the 163 | given minibatch. 164 | """ 165 | # Sample from a normal distribution the latent representation of the rollout 166 | latent_shape = (1, self.encoder.out_mean.out_channels, img_shape[0], 167 | img_shape[1]) 168 | latent_representation = torch.randn(latent_shape).to(self.device) 169 | 170 | # Instantiate prediction object 171 | prediction_shape = (1, n_steps, self.channels, img_shape[0], 172 | img_shape[1]) 173 | prediction = HgnResult(batch_shape=torch.Size(prediction_shape), 174 | device=self.device) 175 | 176 | prediction.set_z(z_sample=latent_representation) 177 | 178 | # Initial state 179 | q, p = self.transformer(latent_representation) 180 | prediction.append_state(q=q, p=p) 181 | 182 | # Initial state reconstruction 183 | x_reconstructed = self.decoder(q) 184 | prediction.append_reconstruction(x_reconstructed) 185 | 186 | # Estimate predictions 187 | for _ in range(n_steps - 1): 188 | # Compute next state 189 | q, p = self.integrator.step(q=q, p=p, hnn=self.hnn) 190 | prediction.append_state(q=q, p=p) 191 | 192 | # Compute state reconstruction 193 | x_reconstructed = self.decoder(q) 194 | prediction.append_reconstruction(x_reconstructed) 195 | return prediction 196 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CampusAI/Hamiltonian-Generative-Networks/702d3ff3aec40eba20e17c5a1612b5b0b1e2f831/networks/__init__.py -------------------------------------------------------------------------------- /networks/debug_networks.py: -------------------------------------------------------------------------------- 1 | """This module contains simple, easy to debug networks that have the same interface as the other 2 | networks in the package. 3 | """ 4 | 5 | import sys 6 | import os 7 | 8 | import torch 9 | from torch import nn 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | from networks import decoder_net 14 | from networks import encoder_net 15 | from networks import hamiltonian_net 16 | from networks import transformer_net 17 | 18 | 19 | class EncoderNet(encoder_net.EncoderNet): 20 | """Network that encodes the input value into a two-dimensional latent encoding. 21 | """ 22 | def __init__(self, phi=None, dtype=torch.float): 23 | """Create the encoder network. 24 | 25 | Args: 26 | phi (torch.Tensor): 2-dimensional Tensor with weights. 27 | dtype (torch.dtype): Type of the weights. 28 | """ 29 | nn.Module.__init__(self) 30 | self.phi = phi if phi is not None else nn.Parameter( 31 | torch.tensor([1., 2.], requires_grad=True, dtype=dtype)) 32 | 33 | def forward(self, x): 34 | """Compute the encoding of x. 35 | 36 | Args: 37 | x (torch.Tensor) must be a N x k x shape where N is the batch size, k is the length 38 | of the sequence, and shape is the shape of each frame in the sequence. Here we 39 | assume k=2 and shape = 1. The encoding is given by 40 | [phi[0] * x[0], phi[1] * (x[1] - x[0])] 41 | Returns: 42 | A tuple (encoding, mean, var), where encoding is a N x 2 Tensor, N is the batch 43 | size. mean and var are the mean and variance tensors, returned for compatibility. 44 | """ 45 | q = self.phi[0] * x[:, 0, :] 46 | p = self.phi[1] * (x[:, 1, :] - x[:, 0, :]) 47 | encoding = torch.stack((q, p), dim=1) 48 | return encoding, torch.zeros_like(encoding), torch.ones_like(encoding) 49 | 50 | 51 | class TransformerNet(transformer_net.TransformerNet): 52 | """Transforms the given encoding into abstract phase space q and p. 53 | """ 54 | def __init__(self, w=None, dtype=torch.float): 55 | """Create the transformer net by setting the weighs. 56 | 57 | Args: 58 | w (torch.Tensor): Tensor of weights. 59 | dtype (torch.dtype): Type of the weights. 60 | """ 61 | nn.Module.__init__(self) 62 | self.w = w if w is not None else nn.Parameter( 63 | torch.tensor([1., 1.], requires_grad=True, dtype=dtype)) 64 | 65 | def forward(self, x): 66 | """Transform the two dimensional input tensor x into q, p as q = w_0 * x_0, p = w_1 * x_1 67 | 68 | Args: 69 | x (torch.Tensor): Two dimensional latent encoding. 70 | 71 | Returns: 72 | A tuple two Tensors q and p. 73 | """ 74 | q = self.w[0] * x[:, 0, :] 75 | p = self.w[1] * x[:, 1, :] 76 | encoding = torch.stack((q, p), dim=1) 77 | return self.to_phase_space(encoding) 78 | 79 | 80 | class HamiltonianNet(hamiltonian_net.HamiltonianNet): 81 | """Computes the hamiltonian from the given q and p. 82 | """ 83 | def __init__(self, gamma=None, dtype=torch.float): 84 | """Create the hamiltonian net. 85 | 86 | Args: 87 | gamma (torch.Tensor): A two dimensional tensor with the weights. 88 | dtype (torch.type): Type of the weights. 89 | """ 90 | nn.Module.__init__(self) 91 | self.gamma = gamma if gamma is not None else nn.Parameter( 92 | torch.tensor([3., 4.], requires_grad=True, dtype=dtype)) 93 | 94 | def forward(self, q, p): 95 | return self.gamma[0] * q + self.gamma[1] * p**2 96 | 97 | 98 | class DecoderNet(decoder_net.DecoderNet): 99 | """Decoder net debug implementation, where the input q is decoded as q * theta 100 | """ 101 | def __init__(self, theta=None, dtype=torch.float): 102 | """Create the debug decoder network. 103 | 104 | Args: 105 | theta (torch.Tensor): 1 dimensional Tensor containing theta. 106 | dtype (torch.dtype): Type of theta. 107 | """ 108 | nn.Module.__init__(self) 109 | self.theta = theta if theta is not None else nn.Parameter( 110 | torch.tensor([2.], dtype=dtype)) 111 | 112 | def forward(self, q): 113 | """Returns q * theta. 114 | 115 | Args: 116 | q (torch.tensor): A one-dimensional Tensor. 117 | 118 | Returns: 119 | A one dimensional Tensor. 120 | """ 121 | return self.theta * q 122 | 123 | 124 | if __name__ == '__main__': 125 | enc = EncoderNet() 126 | transf = TransformerNet() 127 | ham = HamiltonianNet() 128 | dec = DecoderNet() -------------------------------------------------------------------------------- /networks/decoder_net.py: -------------------------------------------------------------------------------- 1 | """This module contains the implementation of a decoder network, that applies 3 residual blocks 2 | to the input abstract position q. In the paper q is a (16, 4, 4) tensor that can be seen as a 4x4 3 | image with 16 channels, but here other sizes may be used. 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class ResidualBlock(nn.Module): 11 | """A residual block that up-samples the input image by a factor of 2. 12 | """ 13 | def __init__(self, in_channels, n_filters=64, kernel_size=3, dtype=torch.float): 14 | """Instantiate the residual block, composed by a 2x up-sampling and two convolutional 15 | layers. 16 | 17 | Args: 18 | in_channels (int): Number of input channels. 19 | n_filters (int): Number of filters, and thus output channels. 20 | kernel_size (int): Size of the convolutional kernels. 21 | dtype (torch.dtype): Type to be used in tensors. 22 | """ 23 | super().__init__() 24 | self.channels = in_channels 25 | self.n_filters = n_filters 26 | padding = int(kernel_size / 2) 27 | self.conv1 = nn.Conv2d( 28 | in_channels=in_channels, 29 | out_channels=n_filters, 30 | kernel_size=kernel_size, 31 | padding=padding, 32 | ) 33 | self.conv2 = nn.Conv2d( 34 | in_channels=n_filters, 35 | out_channels=n_filters, 36 | kernel_size=kernel_size, 37 | padding=padding, 38 | ) 39 | if in_channels != n_filters: 40 | self.dim_match_conv = nn.Conv2d(in_channels=in_channels, 41 | out_channels=n_filters, 42 | kernel_size=1, 43 | padding=0) 44 | self.leaky_relu = nn.LeakyReLU() 45 | self.upsample = nn.UpsamplingNearest2d(scale_factor=2) 46 | self.sigmoid = nn.Sigmoid() 47 | self.type(dtype) 48 | 49 | def forward(self, x): 50 | """Apply 2x up-sampling, followed by two convolutional layers with leaky relu. A sigmoid 51 | activation is applied at the end. 52 | 53 | TODO: Should we use batch normalization? It is often common in residual blocks. 54 | TODO: Here we apply a convolutional layer to the input up-sampled tensor if its number 55 | of channels does not match the convolutional layer channels. Is this the correct way? 56 | 57 | Args: 58 | x (torch.Tensor): Input image of shape (N, C, H, W) where N is the batch size and C 59 | is the number of in_channels. 60 | 61 | Returns: 62 | A torch.Tensor with the up-sampled images, of shape (N, n_filters, H, W). 63 | """ 64 | x = self.upsample(x) 65 | residual = self.dim_match_conv( 66 | x) if self.channels != self.n_filters else x 67 | x = self.leaky_relu(self.conv1(x)) 68 | x = self.leaky_relu(self.conv2(x)) 69 | x = self.sigmoid(x + residual) 70 | return x 71 | 72 | 73 | class DecoderNet(nn.Module): 74 | """The Decoder network, that takes a latent encoding of shape (in_channels, H, W) 75 | and produces the output image by applying 3 ResidualBlock modules and a final 1x1 convolution. 76 | Each residual block up-scales the image by 2, and the convolution produces the desired number 77 | of output channels, thus the output shape is (out_channels, H*2^3, W*2^3). 78 | """ 79 | 80 | DEFAULT_PARAMS = { 81 | 'n_residual_blocks': 3, 82 | 'n_filters': [64, 64, 64], 83 | 'kernel_sizes': [3, 3, 3, 3], 84 | } 85 | 86 | def __init__(self, 87 | in_channels, 88 | out_channels=3, 89 | n_residual_blocks=None, 90 | n_filters=None, 91 | kernel_sizes=None, 92 | dtype=torch.float): 93 | """Create the decoder network composed of the given number of residual blocks. 94 | 95 | Args: 96 | in_channels (int): Number of input encodings channels. 97 | out_channels (int): Number output image channels (1 for grayscale, 3 for RGB). 98 | n_residual_blocks (int): Number of residual blocks in the network. 99 | n_filters (list): List where the i-th element is the number of filters for 100 | convolutional layers for the i-th residual block, excluding the output block. 101 | Therefore, n_filters must be of length n_residual_blocks - 1 102 | kernel_sizes(list): List where the i-th element is the kernel size of convolutional 103 | layers for the i-th residual block. 104 | """ 105 | super().__init__() 106 | if all(var is None for var in (n_residual_blocks, n_filters, kernel_sizes)): 107 | n_residual_blocks = DecoderNet.DEFAULT_PARAMS['n_residual_blocks'] 108 | n_filters = DecoderNet.DEFAULT_PARAMS['n_filters'] 109 | kernel_sizes = DecoderNet.DEFAULT_PARAMS['kernel_sizes'] 110 | elif all(var is not None for var in (n_residual_blocks, n_filters, kernel_sizes)): 111 | assert len(kernel_sizes) == n_residual_blocks + 1, \ 112 | 'kernel_sizes and upsample must be of length n_residual_blocks + 1 ('\ 113 | + str(n_residual_blocks + 1) + ' in this case).' 114 | assert len(n_filters) == n_residual_blocks, 'n_filters must be of length ' \ 115 | 'n_residual_blocks (' + str(n_residual_blocks) + ' in this case).' 116 | else: 117 | raise ValueError( 118 | 'Args n_residual_blocks, n_filters, kernel_size, upsample ' 119 | 'can only be either all None, or all defined by the user.') 120 | filters = [in_channels] + n_filters 121 | self.residual_blocks = nn.ModuleList([ 122 | ResidualBlock( 123 | in_channels=int(filters[i]), 124 | n_filters=int(filters[i + 1]), 125 | kernel_size=int(kernel_sizes[i]), 126 | dtype=dtype 127 | ) for i in range(n_residual_blocks) 128 | ]) 129 | self.out_conv = nn.Conv2d( 130 | in_channels=filters[-1], 131 | out_channels=out_channels, 132 | kernel_size=kernel_sizes[-1], 133 | padding=int(kernel_sizes[-1] / 2) # To not resize the image 134 | ) 135 | self.sigmoid = nn.Sigmoid() 136 | self.type(dtype) 137 | 138 | def forward(self, x): 139 | """Apply the three residual blocks and the final convolutional layer. 140 | 141 | Args: 142 | x (torch.Tensor): Tensor of shape (N, in_channels, H, W) where N is the batch size. 143 | 144 | Returns: 145 | Tensor of shape (out_channels, H * 2^3, W * 2^3) with the reconstructed image. 146 | """ 147 | for layer in self.residual_blocks: 148 | x = layer(x) 149 | x = self.sigmoid(self.out_conv(x)) 150 | return x 151 | -------------------------------------------------------------------------------- /networks/encoder_net.py: -------------------------------------------------------------------------------- 1 | """This module contains the implementation of the Encoder step of the Hamiltonian Generative Networks 2 | paper. The encoder maps the input sequence of frames into a latent distribution and samples a Tensor z 3 | from it using the re-parametrization trick. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class EncoderNet(nn.Module): 11 | """Implementation of the encoder network, that encodes the input frames sequence into a 12 | distribution over the latent space and samples with the common reparametrization trick. 13 | 14 | The network expects the images to be concatenated along channel dimension. This means that if 15 | a batch of sequences has shape (batch_size, seq_len, channels, height, width) the network 16 | will accept an input of shape (batch_size, seq_len * channels, height, width). 17 | """ 18 | 19 | DEFAULT_PARAMS = { 20 | 'hidden_conv_layers': 6, 21 | 'n_filters': [32, 64, 64, 64, 64, 64, 64], 22 | 'kernel_sizes': [3, 3, 3, 3, 3, 3, 3, 3], 23 | 'strides': [1, 1, 1, 1, 1, 1, 1, 1], 24 | } 25 | 26 | def __init__(self, 27 | seq_len, 28 | in_channels, 29 | out_channels, 30 | hidden_conv_layers=None, 31 | n_filters=None, 32 | kernel_sizes=None, 33 | strides=None, 34 | act_func=nn.ReLU(), 35 | dtype=torch.float): 36 | """Instantiate the convolutional layers that compose the input network with the 37 | appropriate shapes. 38 | 39 | If K is the total number of layers, then hidden_conv_layers = K - 2. The length of 40 | n_filters must be K - 1, and that of kernel_sizes and strides must be K. If all 41 | them are None, EncoderNet.DEFAULT_PARAMS will be used. 42 | 43 | Args: 44 | seq_len (int): Number of frames that compose a sequence. 45 | in_channels (int): Number of channels of images in the input sequence. 46 | out_channels (int): Number of channels of the output latent encoding. 47 | hidden_conv_layers (int): Number of hidden convolutional layers (excluding the input 48 | and the two output layers for mean and variance). 49 | n_filters (list): List with number of filters for each of the hidden layers. 50 | kernel_sizes (list): List with kernel sizes for each convolutional layer. 51 | strides (list): List with strides for each convolutional layer. 52 | act_func (torch.nn.Module): The activation function to apply after each layer. 53 | dtype (torch.dtype): Type of the weights. 54 | """ 55 | super().__init__() 56 | if all(var is None for var in (hidden_conv_layers, n_filters, 57 | kernel_sizes, strides)): 58 | hidden_conv_layers = EncoderNet.DEFAULT_PARAMS[ 59 | 'hidden_conv_layers'] 60 | n_filters = EncoderNet.DEFAULT_PARAMS['n_filters'] 61 | kernel_sizes = EncoderNet.DEFAULT_PARAMS['kernel_sizes'] 62 | strides = EncoderNet.DEFAULT_PARAMS['strides'] 63 | elif all(var is not None for var in (hidden_conv_layers, n_filters, 64 | kernel_sizes, strides)): 65 | # If no Nones, check consistency 66 | assert len(n_filters) == hidden_conv_layers + 1,\ 67 | 'n_filters must be a list of length hidden_conv_layers + 1 ' \ 68 | '(' + str(hidden_conv_layers + 1) + ' in this case).' 69 | assert len(kernel_sizes) == hidden_conv_layers + 2 and \ 70 | len(strides) == hidden_conv_layers + 2, \ 71 | 'kernel_sizes and strides must be lists with values foreach layer in the ' \ 72 | 'network (' + str(hidden_conv_layers + 2) + ' in this case).' 73 | else: 74 | raise ValueError( 75 | 'Args hidden_conv_layers, n_filters, kernel_sizes, and strides' 76 | 'can only be either all None, or all defined by the user.') 77 | paddings = [int(k / 2) for k in kernel_sizes] 78 | self.input_conv = nn.Conv2d(in_channels=seq_len * in_channels, 79 | out_channels=n_filters[0], 80 | kernel_size=kernel_sizes[0], 81 | padding=paddings[0], 82 | stride=strides[0]) 83 | self.hidden_layers = nn.ModuleList(modules=[ 84 | nn.Conv2d(in_channels=n_filters[i], 85 | out_channels=n_filters[i + 1], 86 | kernel_size=kernel_sizes[i + 1], 87 | padding=paddings[i + 1], 88 | stride=strides[i + 1]) for i in range(hidden_conv_layers) 89 | ]) 90 | self.out_mean = nn.Conv2d(in_channels=n_filters[-1], 91 | out_channels=out_channels, 92 | kernel_size=kernel_sizes[-1], 93 | padding=paddings[-1], 94 | stride=strides[-1]) 95 | self.out_logvar = nn.Conv2d(in_channels=n_filters[-1], 96 | out_channels=out_channels, 97 | kernel_size=kernel_sizes[-1], 98 | padding=paddings[-1], 99 | stride=strides[-1]) 100 | self.activation = act_func 101 | self.type(dtype) 102 | 103 | def forward(self, x, sample=True): 104 | """Compute the encoding of the given sequence of images. 105 | 106 | Args: 107 | x (torch.Tensor): A (batch_size, seq_len * channels, height, width) tensor containing 108 | the sequence of frames. 109 | sample (bool): Whether to sample from the encoding distribution or returning the mean. 110 | 111 | Returns: 112 | A tuple (z, mu, log_var), which are all N x 48 x H x W tensors. z is the latent encoding 113 | for the given input sequence, while mu and log_var are distribution parameters. 114 | """ 115 | x = self.activation(self.input_conv(x)) 116 | for layer in self.hidden_layers: 117 | x = self.activation(layer(x)) 118 | mean = self.out_mean(x) 119 | if not sample: 120 | return mean, None, None # Return None to ensure that they're not used in loss 121 | log_var = self.out_logvar(x) 122 | stddev = torch.exp(0.5 * log_var) 123 | epsilon = torch.randn_like(mean) 124 | z = mean + stddev * epsilon 125 | return z, mean, log_var -------------------------------------------------------------------------------- /networks/hamiltonian_net.py: -------------------------------------------------------------------------------- 1 | """This module contains an implementation of the Hamiltonian network described in the paper. 2 | The Hamiltonian network takes the abstract positions and momenta, q and p, and computes a scalar 3 | value that is interpreted as the Hamiltonian. 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | class HamiltonianNet(nn.Module): 12 | """The Hamiltonian network, composed of 6 convolutional layers and a final linear layer. 13 | """ 14 | 15 | DEFAULT_PARAMS = { 16 | 'hidden_conv_layers': 6, 17 | 'n_filters': [32, 64, 64, 64, 64, 64, 64, 64], 18 | 'kernel_sizes': [3, 3, 3, 3, 3, 3, 3, 3], 19 | 'strides': [1, 1, 1, 1, 1, 1, 1, 1], 20 | } 21 | 22 | def __init__(self, 23 | in_shape, 24 | hidden_conv_layers=None, 25 | n_filters=None, 26 | kernel_sizes=None, 27 | strides=None, 28 | paddings=None, 29 | act_func=nn.Softplus(), 30 | dtype=torch.float): 31 | """Create the layers of the Hamiltonian network. 32 | 33 | If K is the total number of convolutional layers, then hidden_conv_layers = K - 2. 34 | The length of n_filters, kernel_sizes, and strides must be K. If all of them are None, 35 | HamiltonianNet.DEFAULT_PARAMS will be used. 36 | 37 | Args: 38 | in_shape (tuple): Shape of input elements (channels, height, width). 39 | hidden_conv_layers (int): Number of hidden convolutional layers (excluding the input and 40 | the two output layers for mean and variance). 41 | n_filters (list): List with number of filters for each of the hidden layers. 42 | kernel_sizes (list): List with kernel sizes for each convolutional layer. 43 | strides (list): List with strides for each convolutional layer. 44 | act_func (torch.nn.Module): The activation function to apply after each layer. 45 | dtype (torch.dtype): Type of the weights. 46 | """ 47 | super().__init__() 48 | if all(var is None for var in (hidden_conv_layers, n_filters, 49 | kernel_sizes, strides, paddings)): 50 | hidden_conv_layers = HamiltonianNet.DEFAULT_PARAMS[ 51 | 'hidden_conv_layers'] 52 | n_filters = HamiltonianNet.DEFAULT_PARAMS['n_filters'] 53 | kernel_sizes = HamiltonianNet.DEFAULT_PARAMS['kernel_sizes'] 54 | strides = HamiltonianNet.DEFAULT_PARAMS['strides'] 55 | paddings = HamiltonianNet.DEFAULT_PARAMS['paddings'] 56 | elif all(var is not None for var in (hidden_conv_layers, n_filters, 57 | kernel_sizes, strides, paddings)): 58 | # If no Nones, check consistency 59 | assert len(n_filters) == hidden_conv_layers + 1,\ 60 | 'n_filters must be a list of length hidden_conv_layers + 2 ' \ 61 | '(' + str(hidden_conv_layers + 2) + ' in this case).' 62 | assert len(kernel_sizes) == hidden_conv_layers + 2 and \ 63 | len(strides) == len(kernel_sizes) and \ 64 | len(paddings) == len(kernel_sizes), \ 65 | 'kernel_sizes and strides must be lists with values foreach layer in the ' \ 66 | 'network (' + str(hidden_conv_layers + 2) + ' in this case).' 67 | else: 68 | raise ValueError( 69 | 'Args hidden_conv_layers, n_filters, kernel_sizes, and strides' 70 | 'can only be either all None, or all defined by the user.' 71 | ) 72 | self.paddings = paddings 73 | conv_paddings = [0 if isinstance(p, list) else p for p in paddings] 74 | in_channels = in_shape[0] * 2 75 | self.in_conv = nn.Conv2d(in_channels=in_channels, 76 | out_channels=n_filters[0], 77 | kernel_size=kernel_sizes[0], 78 | padding=conv_paddings[0], 79 | stride=strides[0]) 80 | self.hidden_layers = nn.ModuleList(modules=[ 81 | nn.Conv2d(in_channels=n_filters[i], 82 | out_channels=n_filters[i + 1], 83 | kernel_size=kernel_sizes[i + 1], 84 | padding=conv_paddings[i + 1], 85 | stride=strides[i + 1]) for i in range(hidden_conv_layers) 86 | ]) 87 | self.out_conv = nn.Conv2d(in_channels=n_filters[-1], 88 | out_channels=1, 89 | kernel_size=2, 90 | padding=0) 91 | self.activation = act_func 92 | self.type(dtype) 93 | 94 | def forward(self, q, p): 95 | """Forward pass that returns the Hamiltonian for the given q and p inputs. 96 | 97 | q and p must be two (batch_size, channels, height, width) tensors. 98 | 99 | Args: 100 | q (torch.Tensor): The tensor corresponding to the position in abstract space. 101 | p (torch.Tensor): The tensor corresponding to the momentum in abstract space. 102 | 103 | Returns: 104 | A (batch_size, 1) shaped tensor with the energy for each input in the batch. 105 | """ 106 | x = torch.cat( 107 | (q, p), 108 | dim=1) # Concatenate q and p to obtain a N x 2C x H x W tensor 109 | if isinstance(self.paddings[0], list): 110 | x = F.pad(x, self.paddings[0]) 111 | x = self.activation(self.in_conv(x)) 112 | for i, layer in enumerate(self.hidden_layers): 113 | if isinstance(self.paddings[i + 1], list): 114 | x = F.pad(x, self.paddings[i + 1]) 115 | x = self.activation(layer(x)) 116 | x = self.activation(self.out_conv(x)) 117 | x = x.squeeze(dim=1).squeeze(dim=1) 118 | return x 119 | 120 | 121 | if __name__ == '__main__': 122 | hamiltonian_net = HamiltonianNet(in_shape=(16, 4, 4)) 123 | q, p = torch.randn((2, 128, 16, 4, 4)) 124 | h = hamiltonian_net(q, p) 125 | -------------------------------------------------------------------------------- /networks/transformer_net.py: -------------------------------------------------------------------------------- 1 | """This module contains the implementation of the Transformer step of the Hamiltonian Generative Networks paper. 2 | The Transformer network takes the latent Tensor z and maps it into the abstract phase space (q, p). 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class TransformerNet(nn.Module): 10 | """Implementation of the encoder-transformer network, that maps the latent space into the 11 | phase space. 12 | """ 13 | 14 | DEFAULT_PARAMS = { 15 | 'hidden_conv_layers': 2, 16 | 'n_filters': [64, 64, 64], 17 | 'kernel_sizes': [3, 3, 3, 3], 18 | 'strides': [2, 2, 2, 1], 19 | } 20 | 21 | def __init__(self, 22 | in_channels, 23 | out_channels, 24 | hidden_conv_layers=None, 25 | n_filters=None, 26 | kernel_sizes=None, 27 | strides=None, 28 | act_func=torch.nn.ReLU(), 29 | dtype=torch.float): 30 | """Instantiate the convolutional layers with the given attributes or using the default 31 | parameters. 32 | 33 | If K is the total number of layers, then hidden_conv_layers = K - 2. The length of 34 | n_filters must be K - 1, and that of kernel_sizes and strides must be K. If all 35 | them are None, TransformerNet.DEFAULT_PARAMS will be used. 36 | 37 | 38 | Args: 39 | in_channels (int): Number of input in_channels. 40 | out_channels (int): Number of in_channels of q and p 41 | hidden_conv_layers (int): Number of hidden convolutional layers (excluding the input 42 | and the two output layers for mean and variance). 43 | n_filters (list): List with number of filters for each of the hidden layers. 44 | kernel_sizes (list): List with kernel sizes for each convolutional layer. 45 | strides (list): List with strides for each convolutional layer. 46 | act_func (torch.nn.Module): The activation function to apply after each layer. 47 | dtype (torch.dtype): Type of the weights. 48 | """ 49 | super().__init__() 50 | if all(var is None for var in (hidden_conv_layers, n_filters, 51 | kernel_sizes, strides)): 52 | hidden_conv_layers = TransformerNet.DEFAULT_PARAMS[ 53 | 'hidden_conv_layers'] 54 | n_filters = TransformerNet.DEFAULT_PARAMS['n_filters'] 55 | kernel_sizes = TransformerNet.DEFAULT_PARAMS['kernel_sizes'] 56 | strides = TransformerNet.DEFAULT_PARAMS['strides'] 57 | elif all(var is not None for var in (hidden_conv_layers, n_filters, 58 | kernel_sizes, strides)): 59 | # If no Nones, check consistency 60 | assert len(n_filters) == hidden_conv_layers + 1,\ 61 | 'n_filters must be of length hidden_conv_layers + 1 ' \ 62 | '(' + str(hidden_conv_layers + 1) + ' in this case).' 63 | assert len(kernel_sizes) == hidden_conv_layers + 2 \ 64 | and len(strides) == hidden_conv_layers + 2, \ 65 | 'kernel_sizes and strides must be lists with values foreach layer in the ' \ 66 | 'network (' + str(hidden_conv_layers + 2) + ' in this case).' 67 | else: 68 | raise ValueError( 69 | 'Args hidden_conv_layers, n_filters, kernel_sizes, and strides' 70 | 'can only be either all None, or all defined by the user.') 71 | 72 | paddings = [int(k / 2) for k in kernel_sizes] 73 | self.in_conv = nn.Conv2d(in_channels=in_channels, 74 | out_channels=n_filters[0], 75 | kernel_size=kernel_sizes[0], 76 | padding=paddings[0], 77 | stride=strides[0]) 78 | self.hidden_layers = nn.ModuleList(modules=[ 79 | nn.Conv2d(in_channels=n_filters[i], 80 | out_channels=n_filters[i + 1], 81 | kernel_size=kernel_sizes[i + 1], 82 | padding=paddings[i + 1], 83 | stride=strides[i + 1]) for i in range(hidden_conv_layers) 84 | ]) 85 | self.out_conv = nn.Conv2d(in_channels=n_filters[-1], 86 | out_channels=out_channels * 2, 87 | kernel_size=kernel_sizes[-1], 88 | padding=paddings[-1], 89 | stride=strides[-1]) 90 | self.activation = act_func 91 | self.type(dtype) 92 | 93 | def forward(self, x): 94 | """Transforms the given encoding into two tensors q, p. 95 | 96 | Args: 97 | x (torch.Tensor): A Tensor of shape (batch_size, channels, H, W). 98 | 99 | Returns: 100 | Two Tensors q, p corresponding to vectors of abstract positions and momenta. 101 | """ 102 | x = self.activation(self.in_conv(x)) 103 | for layer in self.hidden_layers: 104 | x = self.activation(layer(x)) 105 | x = self.activation(self.out_conv(x)) 106 | q, p = self.to_phase_space(x) 107 | return q, p 108 | 109 | @staticmethod 110 | def to_phase_space(encoding): 111 | """Takes the encoder-transformer output and returns the q and p tensors. 112 | 113 | Args: 114 | encoding (torch.Tensor): A tensor of shape (batch_size, channels, ...). 115 | 116 | Returns: 117 | Two tensors of shape (batch_size, channels/2, ...) resulting from splitting the given 118 | tensor along the second dimension. 119 | """ 120 | assert encoding.shape[1] % 2 == 0,\ 121 | 'The number of in_channels is odd. Cannot split properly.' 122 | half_len = int(encoding.shape[1] / 2) 123 | q = encoding[:, :half_len] 124 | p = encoding[:, half_len:] 125 | return q, p 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.4.0.44 2 | scipy==1.4.1 3 | torch==1.6.0 4 | matplotlib==3.1.2 5 | torchvision==0.7.0 6 | numpy==1.18.1 7 | tqdm==4.42.0 8 | pytest==6.1.1 9 | PyYAML==5.3.1 10 | moviepy 11 | tensorboard==2.1.1 -------------------------------------------------------------------------------- /sample_rollouts.py: -------------------------------------------------------------------------------- 1 | """This script shows how to sample new rollouts from a variational-trained HGN. 2 | """ 3 | from hamiltonian_generative_network import HGN 4 | from utilities.integrator import Integrator 5 | 6 | if __name__=="__main__": 7 | model_to_load = "saved_models/two_bodies_default" 8 | 9 | integrator = Integrator(delta_t=0.125, method="Leapfrog") 10 | hgn = HGN(integrator=integrator) # If going to load, no need to specify networks 11 | hgn.load(model_to_load) 12 | 13 | # Sample a rollout of n_steps 14 | prediction = hgn.get_random_sample(n_steps=50, img_shape=(32, 32)) 15 | prediction.visualize() 16 | 17 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CampusAI/Hamiltonian-Generative-Networks/702d3ff3aec40eba20e17c5a1612b5b0b1e2f831/tests/__init__.py -------------------------------------------------------------------------------- /tests/check_gradients.py: -------------------------------------------------------------------------------- 1 | """Script to train the Hamiltonian Generative Network 2 | """ 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from hamiltonian_generative_network import HGN 12 | import environments.test_env as test_env 13 | import networks.debug_networks as debug_networks 14 | import utilities.integrator as integrator 15 | 16 | epsilon = 1e-6 17 | 18 | if __name__ == "__main__": 19 | rollouts = torch.tensor([[[43.23], [22.12], [3.], [4.]]], requires_grad=True).double() 20 | 21 | # Instantiate networks 22 | encoder = debug_networks.EncoderNet(seq_len=rollouts.shape[1]) 23 | transformer = debug_networks.TransformerNet() 24 | hnn = debug_networks.HamiltonianNet() 25 | decoder = debug_networks.DecoderNet() 26 | 27 | # Define HGN integrator 28 | hgn_integrator = integrator.Integrator(delta_t=0.1, method="Euler") 29 | 30 | # Define optimization module 31 | optim_params = [ 32 | {'params': encoder.parameters(),}, 33 | {'params': transformer.parameters(),}, 34 | {'params': hnn.parameters(),}, 35 | {'params': decoder.parameters(),}, 36 | ] 37 | optimizer = torch.optim.SGD(optim_params, lr = 0.01, momentum=0.9) 38 | loss = torch.nn.MSELoss() 39 | 40 | # Instantiate Hamiltonian Generative Network 41 | hgn = HGN(encoder=encoder, 42 | transformer=transformer, 43 | hnn=hnn, 44 | decoder=decoder, 45 | integrator=hgn_integrator, 46 | loss=loss, 47 | optimizer=optimizer, 48 | device="cpu", 49 | dtype=torch.double, 50 | seq_len=rollouts.shape[1], 51 | channels=1) 52 | 53 | base_error = hgn.fit(rollouts) 54 | print(base_error) 55 | 56 | # print(torch.autograd.gradcheck(hgn.fit, rollouts)) 57 | networks = [encoder, transformer, hnn, decoder] 58 | 59 | # Automatic gradients 60 | base_gradients = [] 61 | print("Automatic gradients:") 62 | for network in networks: 63 | for param in network.parameters(): 64 | base_gradients.append(param.grad.numpy()) 65 | base_gradients = np.array(base_gradients).flatten() 66 | print(base_gradients) 67 | 68 | # Numeric gradients 69 | print("\nNumeric gradients:") 70 | num_grads = [] 71 | for network in networks: 72 | net_grads = [] 73 | for param in network.parameters(): 74 | param_copy = param.data.clone() 75 | for indx, _ in np.ndenumerate(param_copy): 76 | param_copy[indx] += epsilon 77 | param.data = param_copy 78 | error = hgn.fit(rollouts) 79 | param_copy[indx] -= epsilon 80 | param.data = param_copy 81 | print(hgn.fit(rollouts)) 82 | 83 | estimated_grad = (error - base_error)/epsilon 84 | net_grads.append(estimated_grad.detach().numpy()) 85 | num_grads.append(np.array(net_grads)) 86 | num_grads = np.array(num_grads) 87 | print(num_grads) 88 | 89 | print("\nRelative errors:") 90 | errors = 100*np.abs((base_gradients - num_grads)/num_grads) 91 | print(errors) 92 | 93 | # print("Average:", np.average(errors.flatten())) 94 | 95 | -------------------------------------------------------------------------------- /tests/grid_search.py: -------------------------------------------------------------------------------- 1 | """ Iterates through all the yaml files of the given directory and trains a model 2 | """ 3 | import os 4 | import sys 5 | 6 | import yaml 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | from train import train 10 | 11 | if __name__=="__main__": 12 | base_params_file = "experiment_params/overfit_test.yaml" 13 | train_params_dir = "experiment_params/overfit_grid_search/" 14 | 15 | for file in os.listdir(train_params_dir): 16 | if file.endswith(".yaml"): 17 | with open(base_params_file, 'r') as f: 18 | params = yaml.load(f, Loader=yaml.FullLoader) 19 | 20 | with open(os.path.join(train_params_dir, file), 'r') as f: 21 | params_to_update = yaml.load(f, Loader=yaml.FullLoader) 22 | 23 | params.update(params_to_update) 24 | train(params) 25 | -------------------------------------------------------------------------------- /tests/test_conversions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from utilities import conversions 5 | 6 | 7 | def test_to_channels_last(): 8 | batch_size, seq_len, height, width, channels = 5, 5, 32, 32, 3 9 | tensor = torch.randn((batch_size, seq_len, channels, height, width)) 10 | converted = conversions.to_channels_last(tensor) 11 | 12 | for b in range(batch_size): 13 | for s in range(seq_len): 14 | for h in range(height): 15 | for w in range(width): 16 | for c in range(channels): 17 | assert tensor[b, s, c, h, w] == converted[b, s, h, w, c] 18 | 19 | 20 | def test_to_channels_first(): 21 | batch_size, seq_len, height, width, channels = 5, 5, 32, 32, 3 22 | tensor = torch.randn((batch_size, seq_len, height, width, channels)) 23 | converted = conversions.to_channels_first(tensor) 24 | 25 | for b in range(batch_size): 26 | for s in range(seq_len): 27 | for h in range(height): 28 | for w in range(width): 29 | for c in range(channels): 30 | assert tensor[b, s, h, w, c] == converted[b, s, c, h, w] 31 | 32 | 33 | def test_concat_rgb(): 34 | """Test that concat_rgb correctly reshapes the tensor by concatenating sequences along 35 | channel dimensions 36 | """ 37 | batch_len = 2 38 | seq_len = 5 39 | channels = 3 40 | img_size = 32 41 | batch = torch.randn((batch_len, seq_len, channels, img_size, img_size)) 42 | concatenated = conversions.concat_rgb(batch) 43 | 44 | for concat_seq, batch_seq in zip(concatenated, batch): 45 | expected = torch.empty((seq_len * channels, 32, 32)) 46 | for i in range(32): 47 | for j in range(32): 48 | rgb = torch.empty(channels * seq_len) 49 | for s in range(seq_len): 50 | rgb[s * channels + 0] = batch_seq[s, 0, i, j] 51 | rgb[s * channels + 1] = batch_seq[s, 1, i, j] 52 | rgb[s * channels + 2] = batch_seq[s, 2, i, j] 53 | expected[:, i, j] = rgb 54 | assert torch.equal(expected, concat_seq) 55 | 56 | 57 | def test_batch_to_sequence(): 58 | batch_size, seq_len, height, width, channels = 15, 10, 32, 32, 3 59 | batch = np.random.normal(size=(batch_size, seq_len, height, width, 60 | channels)) 61 | sequence = conversions.batch_to_sequence(batch) 62 | 63 | for b in range(batch_size): 64 | for s in range(seq_len): 65 | assert np.array_equal(batch[b], 66 | sequence[b * seq_len:(b + 1) * seq_len]) 67 | -------------------------------------------------------------------------------- /tests/test_decoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from networks import decoder_net 5 | 6 | 7 | def test_decoder_shape_default_params(): 8 | batch_size = 10 9 | # With default params 10 | decoder = decoder_net.DecoderNet(in_channels=16, out_channels=3) 11 | inp = torch.randn((batch_size, 16, 4, 4)) 12 | out = decoder(inp) 13 | expected_size = torch.Size([batch_size, 3, 32, 32]) 14 | actual_size = out.size() 15 | assert expected_size == actual_size 16 | 17 | # With default params, different in-out channels 18 | decoder = decoder_net.DecoderNet(in_channels=32, out_channels=10) 19 | inp = torch.randn((batch_size, 32, 4, 4)) 20 | out = decoder(inp) 21 | expected_size = torch.Size([batch_size, 10, 32, 32]) 22 | actual_size = out.size() 23 | assert expected_size == actual_size 24 | 25 | # With custom params 26 | decoder = decoder_net.DecoderNet(in_channels=16, out_channels=3, n_residual_blocks=4, 27 | n_filters=[32, 64, 128, 256], 28 | kernel_sizes=[3, 3, 3, 5, 5]) 29 | inp = torch.randn((batch_size, 16, 4, 4)) 30 | out = decoder(inp) 31 | expected_size = torch.Size([batch_size, 3, 64, 64]) 32 | actual_size = out.size() 33 | assert expected_size == actual_size 34 | 35 | # With custom params, different in-out shape 36 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=4, 37 | n_filters=[32, 64, 128, 256], 38 | kernel_sizes=[3, 3, 3, 5, 5]) 39 | inp = torch.randn((batch_size, 8, 4, 4)) 40 | out = decoder(inp) 41 | expected_size = torch.Size([batch_size, 1, 64, 64]) 42 | actual_size = out.size() 43 | assert expected_size == actual_size 44 | 45 | 46 | def test_decoder_raises_exception(): 47 | """Test the encoder correctly raises exceptions for wrong custom params 48 | """ 49 | with pytest.raises(ValueError): 50 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=4, 51 | n_filters=[32, 64, 128, 256]) 52 | with pytest.raises(ValueError): 53 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=4, 54 | kernel_sizes=[3, 3, 3, 5, 5]) 55 | with pytest.raises(ValueError): 56 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, 57 | n_filters=[32, 64, 128, 256], 58 | kernel_sizes=[3, 3, 3, 5, 5]) 59 | with pytest.raises(AssertionError): # Wrong n_filters and kernel_sizes 60 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=3, 61 | n_filters=[32, 64, 128, 256], 62 | kernel_sizes=[3, 3, 3, 5, 5]) 63 | with pytest.raises(AssertionError): # Wrong n_filters 64 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=4, 65 | n_filters=[32, 64, 128], 66 | kernel_sizes=[3, 3, 3, 5, 5]) 67 | with pytest.raises(AssertionError): # Wrong kernel_sizes 68 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=4, 69 | n_filters=[32, 64, 128, 256], 70 | kernel_sizes=[3, 3, 3]) 71 | with pytest.raises(AssertionError): # Wrong n_filters for no residual blocks 72 | decoder = decoder_net.DecoderNet(in_channels=8, out_channels=1, n_residual_blocks=0, 73 | n_filters=[3], 74 | kernel_sizes=[5]) -------------------------------------------------------------------------------- /tests/test_inference_net.py: -------------------------------------------------------------------------------- 1 | """This module provides tests for the network architectures. Requires pytest (pip install pytest). 2 | 3 | Run by command line: 4 | pytest tests/networks.py 5 | """ 6 | import os 7 | import sys 8 | 9 | import torch 10 | import pytest 11 | 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from networks import decoder_net 14 | from networks import encoder_net 15 | from networks import hamiltonian_net 16 | from networks import transformer_net 17 | from utilities import conversions 18 | 19 | 20 | def test_to_phase_space(): 21 | batch_size = 10 22 | channels = 48 23 | img_size = 32 24 | batch = torch.randn((batch_size, channels, img_size, img_size)) 25 | q, p = transformer_net.TransformerNet.to_phase_space(batch) 26 | 27 | for k in range(batch_size): 28 | for i in range(int(channels / 2)): 29 | for h in range(img_size): 30 | for w in range(img_size): 31 | assert q[k, i, h, w] == batch[k, i, h, w] 32 | assert p[k, i, h, w] == batch[k, 33 | int(channels / 2) + i, h, w] 34 | 35 | 36 | def test_encoder_out_shape(): 37 | seq_len = 10 38 | img_size = 32 39 | in_channels = 3 40 | out_channels = 48 41 | hidden_conv_layers = 4 42 | n_filters = [32, 48, 64, 80, 96] # Must be hidden_conv_layers + 1 43 | kernel_sizes = [3, 5, 7, 7, 5, 3] # Must be hidden_conv_layers + 2 44 | strides = [1, 2, 1, 2, 1, 1] # Must be hidden_conv_layers + 2 45 | encoder = encoder_net.EncoderNet(seq_len=seq_len, 46 | in_channels=in_channels, 47 | out_channels=out_channels, 48 | hidden_conv_layers=hidden_conv_layers, 49 | n_filters=n_filters, 50 | kernel_sizes=kernel_sizes, 51 | strides=strides) 52 | 53 | expected_out_size = torch.Size( 54 | [128, 48, int(img_size / 4), 55 | int(img_size / 4)]) 56 | 57 | inputs = torch.randn((128, seq_len, in_channels, img_size, img_size)) 58 | inputs = conversions.concat_rgb(inputs) 59 | z, mu, var = encoder(inputs) 60 | 61 | assert z.size() == expected_out_size 62 | 63 | 64 | def test_encoder_raises_exception(): 65 | """Test that the encoder correctly raises exceptions if called with wrong params. 66 | """ 67 | seq_len = 10 68 | img_size = 32 69 | in_channels = 3 70 | out_channels = 48 71 | hidden_conv_layers = 4 72 | # Args to be tested 73 | n_filters = [32, 48, 64, 80, 96] # Must be hidden_conv_layers + 1 74 | kernel_sizes = [3, 5, 7, 7, 5, 3] # Must be hidden_conv_layers + 2 75 | strides = [1, 2, 1, 2, 1, 1] # Must be hidden_conv_layers + 2 76 | with pytest.raises(AssertionError): 77 | encoder = encoder_net.EncoderNet( 78 | seq_len=seq_len, 79 | in_channels=in_channels, 80 | out_channels=out_channels, 81 | hidden_conv_layers=hidden_conv_layers, 82 | n_filters=n_filters[:-1], # n_filters is shorter than it should be 83 | kernel_sizes=kernel_sizes, 84 | strides=strides) 85 | with pytest.raises(AssertionError): 86 | encoder = encoder_net.EncoderNet( 87 | seq_len=seq_len, 88 | in_channels=in_channels, 89 | out_channels=out_channels, 90 | hidden_conv_layers=hidden_conv_layers, 91 | n_filters=n_filters, 92 | kernel_sizes= 93 | kernel_sizes[:-1], # kernel_sizes is shorter than it should be 94 | strides=strides) 95 | with pytest.raises(AssertionError): 96 | encoder = encoder_net.EncoderNet( 97 | seq_len=seq_len, 98 | in_channels=in_channels, 99 | out_channels=out_channels, 100 | hidden_conv_layers=hidden_conv_layers, 101 | n_filters=n_filters, 102 | kernel_sizes=kernel_sizes, 103 | strides=strides[:-1] # strides is shorter than it should be 104 | ) 105 | with pytest.raises(AssertionError): 106 | encoder = encoder_net.EncoderNet( 107 | seq_len=seq_len, 108 | in_channels=in_channels, 109 | out_channels=out_channels, 110 | hidden_conv_layers=hidden_conv_layers, 111 | n_filters=n_filters + 112 | [64], # n_filters is longer than it should be 113 | kernel_sizes=kernel_sizes, 114 | strides=strides) 115 | with pytest.raises(AssertionError): 116 | encoder = encoder_net.EncoderNet( 117 | seq_len=seq_len, 118 | in_channels=in_channels, 119 | out_channels=out_channels, 120 | hidden_conv_layers=hidden_conv_layers, 121 | n_filters=n_filters, 122 | kernel_sizes=kernel_sizes + 123 | [5], # kernel_sizes is longer than it should be 124 | strides=strides) 125 | with pytest.raises(AssertionError): 126 | encoder = encoder_net.EncoderNet( 127 | seq_len=seq_len, 128 | in_channels=in_channels, 129 | out_channels=out_channels, 130 | hidden_conv_layers=hidden_conv_layers, 131 | n_filters=n_filters, 132 | kernel_sizes=kernel_sizes, 133 | strides=strides + [2] # strides is longer than it should be 134 | ) 135 | # Test not all arguments are given 136 | with pytest.raises(ValueError): 137 | encoder = encoder_net.EncoderNet( 138 | seq_len=seq_len, 139 | in_channels=in_channels, 140 | out_channels=out_channels, 141 | hidden_conv_layers=hidden_conv_layers, 142 | n_filters=n_filters, 143 | kernel_sizes=kernel_sizes, 144 | # Missing strides 145 | ) 146 | # Test that correctly works if no args are given 147 | encoder = encoder_net.EncoderNet(seq_len=seq_len, 148 | in_channels=in_channels, 149 | out_channels=out_channels) 150 | 151 | 152 | def test_transformer_shape(): 153 | img_size = 32 154 | in_channels = 48 155 | out_channels = 16 # For q and p separately 156 | hidden_conv_layers = 4 157 | n_filters = [32, 48, 64, 80, 96] # Must be hidden_conv_layers + 1 158 | kernel_sizes = [3, 5, 7, 7, 5, 3] # Must be hidden_conv_layers + 2 159 | strides = [1, 1, 1, 1, 1, 1] # Must be hidden_conv_layers + 2 160 | transformer = transformer_net.TransformerNet( 161 | in_channels=in_channels, 162 | out_channels=out_channels, 163 | hidden_conv_layers=hidden_conv_layers, 164 | n_filters=n_filters, 165 | kernel_sizes=kernel_sizes, 166 | strides=strides) 167 | 168 | expected_out_size = torch.Size([128, 16, img_size, img_size]) 169 | 170 | inputs = torch.randn((128, in_channels, img_size, img_size)) 171 | q, p = transformer(inputs) 172 | 173 | assert q.size() == expected_out_size 174 | assert p.size() == expected_out_size 175 | 176 | 177 | def test_transformer_raises_exception(): 178 | """Test that the transformer correctly raises exceptions if called with wrong params. 179 | """ 180 | img_size = 32 181 | in_channels = 3 182 | out_channels = 48 183 | hidden_conv_layers = 4 184 | # Args to be tested 185 | n_filters = [32, 48, 64, 80, 96] # Must be hidden_conv_layers + 1 186 | kernel_sizes = [3, 5, 7, 7, 5, 3] # Must be hidden_conv_layers + 2 187 | strides = [1, 2, 1, 2, 1, 1] # Must be hidden_conv_layers + 2 188 | with pytest.raises(AssertionError): 189 | transformer = transformer_net.TransformerNet( 190 | in_channels=in_channels, 191 | out_channels=out_channels, 192 | hidden_conv_layers=hidden_conv_layers, 193 | n_filters=n_filters[:-1], # n_filters is shorter than it should be 194 | kernel_sizes=kernel_sizes, 195 | strides=strides) 196 | with pytest.raises(AssertionError): 197 | transformer = transformer_net.TransformerNet( 198 | in_channels=in_channels, 199 | out_channels=out_channels, 200 | hidden_conv_layers=hidden_conv_layers, 201 | n_filters=n_filters, 202 | kernel_sizes= 203 | kernel_sizes[:-1], # kernel_sizes is shorter than it should be 204 | strides=strides) 205 | with pytest.raises(AssertionError): 206 | transformer = transformer_net.TransformerNet( 207 | in_channels=in_channels, 208 | out_channels=out_channels, 209 | hidden_conv_layers=hidden_conv_layers, 210 | n_filters=n_filters, 211 | kernel_sizes=kernel_sizes, 212 | strides=strides[:-1] # strides is shorter than it should be 213 | ) 214 | with pytest.raises(AssertionError): 215 | transformer = transformer_net.TransformerNet( 216 | in_channels=in_channels, 217 | out_channels=out_channels, 218 | hidden_conv_layers=hidden_conv_layers, 219 | n_filters=n_filters + 220 | [64], # n_filters is longer than it should be 221 | kernel_sizes=kernel_sizes, 222 | strides=strides) 223 | with pytest.raises(AssertionError): 224 | transformer = transformer_net.TransformerNet( 225 | in_channels=in_channels, 226 | out_channels=out_channels, 227 | hidden_conv_layers=hidden_conv_layers, 228 | n_filters=n_filters, 229 | kernel_sizes=kernel_sizes + 230 | [5], # kernel_sizes is longer than it should be 231 | strides=strides) 232 | with pytest.raises(AssertionError): 233 | transformer = transformer_net.TransformerNet( 234 | in_channels=in_channels, 235 | out_channels=out_channels, 236 | hidden_conv_layers=hidden_conv_layers, 237 | n_filters=n_filters, 238 | kernel_sizes=kernel_sizes, 239 | strides=strides + [2] # strides is longer than it should be 240 | ) 241 | # Test not all arguments are given 242 | with pytest.raises(ValueError): 243 | transformer = transformer_net.TransformerNet( 244 | in_channels=in_channels, 245 | out_channels=out_channels, 246 | hidden_conv_layers=hidden_conv_layers, 247 | n_filters=n_filters, 248 | kernel_sizes=kernel_sizes, 249 | # Missing strides 250 | ) 251 | # Test that correctly works if no args are given 252 | transformer = transformer_net.TransformerNet(in_channels=in_channels, 253 | out_channels=out_channels) 254 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from utilities.losses import kld_loss 6 | 7 | def test_kld_loss(): 8 | batch_sizes = [1, 10, 100] 9 | latent_size = 8 10 | for batch_size in [10]: #batch_sizes: 11 | mu = torch.randn((batch_size, latent_size)) 12 | logvar = torch.randn((batch_size, latent_size)) 13 | 14 | kld = kld_loss(mu, logvar) 15 | assert kld.dim() == 0 16 | 17 | mu = torch.randn((batch_size, latent_size, latent_size, latent_size)) 18 | logvar = torch.randn((batch_size, latent_size, latent_size, latent_size)) 19 | 20 | kld = kld_loss(mu, logvar) 21 | assert kld.dim() == 0 22 | 23 | test_kld_loss() -------------------------------------------------------------------------------- /tests/test_networks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | 7 | sys.path.append(str(Path('.').absolute().parent)) 8 | from networks import encoder_net 9 | from networks import hamiltonian_net 10 | from networks import transformer_net 11 | 12 | if __name__ == '__main__': 13 | SEQUENCE_LENGTH = 10 14 | BATCH_SIZE = 64 15 | enc_net = encoder_net.EncoderNet(seq_len=SEQUENCE_LENGTH, out_channels=48) 16 | 17 | rand_images = np.random.randint(0, 18 | 255, 19 | size=(BATCH_SIZE, SEQUENCE_LENGTH, 32, 32)) 20 | rand_images_ts = torch.tensor(rand_images).float() 21 | 22 | z, mean, stdev = enc_net(rand_images_ts) 23 | 24 | trans_net = transformer_net.TransformerNet(in_channels=48, out_channels=32) 25 | q, p = trans_net(z) 26 | 27 | ham_net = hamiltonian_net.HamiltonianNet(in_channels=32) 28 | 29 | h = ham_net(q, p) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Script to train the Hamiltonian Generative Network 2 | """ 3 | import ast 4 | import argparse 5 | import copy 6 | import pprint 7 | import os 8 | import warnings 9 | import yaml 10 | 11 | import numpy as np 12 | import torch 13 | import tqdm 14 | 15 | from utilities.integrator import Integrator 16 | from utilities.training_logger import TrainingLogger 17 | from utilities import loader 18 | from utilities.loader import load_hgn, get_online_dataloaders, get_offline_dataloaders 19 | from utilities.losses import reconstruction_loss, kld_loss, geco_constraint 20 | from utilities.statistics import mean_confidence_interval 21 | 22 | def _avoid_overwriting(experiment_id): 23 | # This function throws an error if the given experiment data already exists in runs/ 24 | logdir = os.path.join('runs', experiment_id) 25 | if os.path.exists(logdir): 26 | assert len(os.listdir(logdir)) == 0,\ 27 | f'Experiment id {experiment_id} already exists in runs/. Remove it, change the name ' \ 28 | f'in the yaml file.' 29 | 30 | 31 | class HgnTrainer: 32 | 33 | def __init__(self, params, resume=False): 34 | """Instantiate and train the Hamiltonian Generative Network. 35 | 36 | Args: 37 | params (dict): Experiment parameters (see experiment_params folder). 38 | """ 39 | 40 | self.params = params 41 | self.resume = resume 42 | 43 | if not resume: # Fail if experiment_id already exist in runs/ 44 | _avoid_overwriting(params["experiment_id"]) 45 | 46 | # Set device 47 | self.device = params["device"] 48 | if "cuda" in self.device and not torch.cuda.is_available(): 49 | warnings.warn( 50 | "Warning! Set to train in GPU but cuda is not available. Device is set to CPU.") 51 | self.device = "cpu" 52 | 53 | # Get dtype, will raise a 'module 'torch' has no attribute' if there is a typo 54 | self.dtype = torch.__getattribute__(params["networks"]["dtype"]) 55 | 56 | # Load hgn from parameters to deice 57 | self.hgn = load_hgn(params=self.params, 58 | device=self.device, 59 | dtype=self.dtype) 60 | if 'load_path' in self.params: 61 | self.load_and_reset(self.params, self.device, self.dtype) 62 | 63 | # Either generate data on-the-fly or load the data from disk 64 | if "train_data" in self.params["dataset"]: 65 | print("Training with OFFLINE data...") 66 | self.train_data_loader, self.test_data_loader = get_offline_dataloaders(self.params) 67 | else: 68 | print("Training with ONLINE data...") 69 | self.train_data_loader, self.test_data_loader = get_online_dataloaders(self.params) 70 | 71 | # Initialize training logger 72 | self.training_logger = TrainingLogger( 73 | hyper_params=self.params, 74 | loss_freq=100, 75 | rollout_freq=1000, 76 | model_freq=10000 77 | ) 78 | 79 | # Initialize tensorboard writer 80 | self.model_save_file = os.path.join( 81 | self.params["model_save_dir"], 82 | self.params["experiment_id"] 83 | ) 84 | 85 | # Define optimization modules 86 | optim_params = [ 87 | { 88 | 'params': self.hgn.encoder.parameters(), 89 | 'lr': params["optimization"]["encoder_lr"] 90 | }, 91 | { 92 | 'params': self.hgn.transformer.parameters(), 93 | 'lr': params["optimization"]["transformer_lr"] 94 | }, 95 | { 96 | 'params': self.hgn.hnn.parameters(), 97 | 'lr': params["optimization"]["hnn_lr"] 98 | }, 99 | { 100 | 'params': self.hgn.decoder.parameters(), 101 | 'lr': params["optimization"]["decoder_lr"] 102 | }, 103 | ] 104 | self.optimizer = torch.optim.Adam(optim_params) 105 | 106 | def load_and_reset(self, params, device, dtype): 107 | """Load the HGN from the path specified in params['load_path'] and reset the networks in 108 | params['reset']. 109 | 110 | Args: 111 | params (dict): Dictionary with all the necessary parameters to load the networks. 112 | device (str): 'gpu:N' or 'cpu' 113 | dtype (torch.dtype): Data type to be used in computations. 114 | """ 115 | self.hgn.load(params['load_path']) 116 | if 'reset' in params: 117 | if isinstance(params['reset'], list): 118 | for net in params['reset']: 119 | assert net in ['encoder', 'decoder', 'hamiltonian', 'transformer'] 120 | else: 121 | assert params['reset'] in ['encoder', 'decoder', 'hamiltonian', 'transformer'] 122 | if 'encoder' in params['reset']: 123 | self.hgn.encoder = loader.instantiate_encoder(params, device, dtype) 124 | if 'decoder' in params['reset']: 125 | self.hgn.decoder = loader.instantiate_decoder(params, device, dtype) 126 | if 'transformer' in params['reset']: 127 | self.hgn.transformer = loader.instantiate_transformer(params, device, dtype) 128 | if 'hamiltonian' in params['reset']: 129 | self.hgn.hnn = loader.instantiate_hamiltonian(params, device, dtype) 130 | 131 | def training_step(self, rollouts): 132 | """Perform a training step with the given rollouts batch. 133 | 134 | Args: 135 | rollouts (torch.Tensor): Tensor of shape (batch_size, seq_len, channels, height, width) 136 | corresponding to a batch of sampled rollouts. 137 | 138 | Returns: 139 | A dictionary of losses and the model's prediction of the rollout. The reconstruction loss and 140 | KL divergence are floats and prediction is the HGNResult object with data of the forward pass. 141 | """ 142 | self.optimizer.zero_grad() 143 | 144 | rollout_len = rollouts.shape[1] 145 | input_frames = self.params['optimization']['input_frames'] 146 | assert(input_frames <= rollout_len) # optimization.use_steps must be smaller (or equal) to rollout.sequence_length 147 | roll = rollouts[:, :input_frames] 148 | 149 | hgn_output = self.hgn.forward(rollout_batch=roll, n_steps=rollout_len - input_frames) 150 | target = rollouts[:, input_frames-1:] # Fit first input_frames and try to predict the last + the next (rollout_len - input_frames) 151 | prediction = hgn_output.reconstructed_rollout 152 | 153 | if self.params["networks"]["variational"]: 154 | tol = self.params["geco"]["tol"] 155 | alpha = self.params["geco"]["alpha"] 156 | lagrange_mult_param = self.params["geco"]["lagrange_multiplier_param"] 157 | 158 | C, rec_loss = geco_constraint(target, prediction, tol) # C has gradient 159 | 160 | # Compute moving average of constraint C (without gradient) 161 | if self.C_ma is None: 162 | self.C_ma = C.detach() 163 | else: 164 | self.C_ma = alpha * self.C_ma + (1 - alpha) * C.detach() 165 | C_curr = C.detach().item() # keep track for logging 166 | C = C + (self.C_ma - C.detach()) # Move C without affecting its gradient 167 | 168 | # Compute KL divergence 169 | mu = hgn_output.z_mean 170 | logvar = hgn_output.z_logvar 171 | kld = kld_loss(mu=mu, logvar=logvar) 172 | 173 | # normalize by number of frames, channels and pixels per frame 174 | kld_normalizer = prediction.flatten(1).size(1) 175 | kld = kld / kld_normalizer 176 | 177 | # Compute losses 178 | train_loss = kld + self.langrange_multiplier * C 179 | 180 | # clamping the langrange multiplier to avoid inf values 181 | self.langrange_multiplier = self.langrange_multiplier * torch.exp( 182 | lagrange_mult_param * C.detach()) 183 | self.langrange_multiplier = torch.clamp(self.langrange_multiplier, 1e-10, 1e10) 184 | 185 | losses = { 186 | 'loss/train': train_loss.item(), 187 | 'loss/kld': kld.item(), 188 | 'loss/C': C_curr, 189 | 'loss/C_ma': self.C_ma.item(), 190 | 'loss/rec': rec_loss.item(), 191 | 'other/langrange_mult': self.langrange_multiplier.item() 192 | } 193 | 194 | else: # not variational 195 | # Compute frame reconstruction error 196 | train_loss = reconstruction_loss( 197 | target=target, 198 | prediction=prediction) 199 | losses = {'loss/train': train_loss.item()} 200 | 201 | train_loss.backward() 202 | self.optimizer.step() 203 | 204 | return losses, hgn_output 205 | 206 | def fit(self): 207 | """The trainer fits an HGN. 208 | 209 | Returns: 210 | (HGN) An HGN model that has been fitted to the data 211 | """ 212 | 213 | # Initial values for geco algorithm 214 | if self.params["networks"]["variational"]: 215 | self.langrange_multiplier = self.params["geco"]["initial_lagrange_multiplier"] 216 | self.C_ma = None 217 | 218 | # TRAIN 219 | for ep in range(self.params["optimization"]["epochs"]): 220 | print("Epoch %s / %s" % (str(ep + 1), str(self.params["optimization"]["epochs"]))) 221 | pbar = tqdm.tqdm(self.train_data_loader) 222 | for batch_idx, rollout_batch in enumerate(pbar): 223 | # Move to device and change dtype 224 | rollout_batch = rollout_batch.to(self.device).type(self.dtype) 225 | 226 | # Do an optimization step 227 | losses, prediction = self.training_step(rollouts=rollout_batch) 228 | 229 | # Log progress 230 | self.training_logger.step(losses=losses, 231 | rollout_batch=rollout_batch, 232 | prediction=prediction, 233 | model=self.hgn) 234 | 235 | # Progress-bar msg 236 | msg = ", ".join([ 237 | f"{k}: {v:.2e}" for k, v in losses.items() if v is not None 238 | ]) 239 | pbar.set_description(msg) 240 | # Save model 241 | self.hgn.save(self.model_save_file) 242 | 243 | self.test() 244 | return self.hgn 245 | 246 | def compute_reconst_kld_errors(self, dataloader): 247 | """Computes reconstruction error and KL divergence. 248 | 249 | Args: 250 | dataloader (torch.utils.data.DataLoader): DataLoader to retrieve errors from. 251 | 252 | Returns: 253 | (reconst_error_mean, reconst_error_h), (kld_mean, kld_h): Tuples where the mean and 95% 254 | conficence interval is shown. 255 | """ 256 | first = True 257 | pbar = tqdm.tqdm(dataloader) 258 | 259 | for _, rollout_batch in enumerate(pbar): 260 | # Move to device and change dtype 261 | rollout_batch = rollout_batch.to(self.device).type(self.dtype) 262 | rollout_len = rollout_batch.shape[1] 263 | input_frames = self.params['optimization']['input_frames'] 264 | assert(input_frames <= rollout_len) # optimization.use_steps must be smaller (or equal) to rollout.sequence_length 265 | roll = rollout_batch[:, :input_frames] 266 | hgn_output = self.hgn.forward(rollout_batch=roll, n_steps=rollout_len - input_frames) 267 | target = rollout_batch[:, input_frames-1:] # Fit first input_frames and try to predict the last + the next (rollout_len - input_frames) 268 | prediction = hgn_output.reconstructed_rollout 269 | error = reconstruction_loss( 270 | target=target, 271 | prediction=prediction, mean_reduction=False).detach().cpu( 272 | ).numpy() 273 | if self.params["networks"]["variational"]: 274 | kld = kld_loss(mu=hgn_output.z_mean, logvar=hgn_output.z_logvar, mean_reduction=False).detach().cpu( 275 | ).numpy() 276 | # normalize by number of frames, channels and pixels per frame 277 | kld_normalizer = prediction.flatten(1).size(1) 278 | kld = kld / kld_normalizer 279 | if first: 280 | first = False 281 | set_errors = error 282 | if self.params["networks"]["variational"]: 283 | set_klds = kld 284 | else: 285 | set_errors = np.concatenate((set_errors, error)) 286 | if self.params["networks"]["variational"]: 287 | set_klds = np.concatenate((set_klds, kld)) 288 | err_mean, err_h = mean_confidence_interval(set_errors) 289 | if self.params["networks"]["variational"]: 290 | kld_mean, kld_h = mean_confidence_interval(set_klds) 291 | return (err_mean, err_h), (kld_mean, kld_h) 292 | else: 293 | return (err_mean, err_h), None 294 | 295 | def test(self): 296 | """Test after the training is finished and logs result to tensorboard. 297 | """ 298 | print("Calculating final training error...") 299 | (err_mean, err_h), kld = self.compute_reconst_kld_errors(self.train_data_loader) 300 | self.training_logger.log_error("Train reconstruction error", err_mean, err_h) 301 | if kld is not None: 302 | kld_mean, kld_h = kld 303 | self.training_logger.log_error("Train KL divergence", kld_mean, kld_h) 304 | 305 | print("Calculating final test error...") 306 | (err_mean, err_h), kld = self.compute_reconst_kld_errors(self.test_data_loader) 307 | self.training_logger.log_error("Test reconstruction error", err_mean, err_h) 308 | if kld is not None: 309 | kld_mean, kld_h = kld 310 | self.training_logger.log_error("Test KL divergence", kld_mean, kld_h) 311 | 312 | def _overwrite_config_with_cmd_arguments(config, args): 313 | if args.name is not None: 314 | config['experiment_id'] = args.name[0] 315 | if args.epochs is not None: 316 | config['optimization']['epochs'] = args.epochs[0] 317 | if args.dataset_path is not None: 318 | # Read the parameters.yaml file in the given dataset path 319 | dataset_config = _read_config(os.path.join(_args.dataset_path[0], 'parameters.yaml')) 320 | for key, value in dataset_config.items(): 321 | config[key] = value 322 | if args.env is not None: 323 | if 'train_data' in config['dataset']: 324 | raise ValueError( 325 | f'--env was given but configuration is set for offline training: ' 326 | f'train_data={config["dataset"]["train_data"]}' 327 | ) 328 | env_params = _read_config(DEFAULT_ENVIRONMENTS_PATH + args.env[0] + '.yaml') 329 | config['environment'] = env_params['environment'] 330 | if args.params is not None: 331 | for p in args.params: 332 | key, value = p.split('=') 333 | ptr = config 334 | keys = key.split('.') 335 | for i, k in enumerate(keys): 336 | if i == len(keys) - 1: 337 | ptr[k] = ast.literal_eval(value) 338 | else: 339 | ptr = ptr[k] 340 | if args.load is not None: 341 | config['load_path'] = args.load[0] 342 | if args.reset is not None: 343 | config['reset'] = args.reset 344 | 345 | 346 | def _read_config(config_file): 347 | with open(config_file, 'r') as f: 348 | config = yaml.load(f, Loader=yaml.FullLoader) 349 | return config 350 | 351 | 352 | def _merge_configs(train_config, dataset_config): 353 | config = copy.deepcopy(train_config) 354 | for key, value in dataset_config.items(): 355 | config[key] = value 356 | # If the config specifies a dataset path, we take the rollout from the configuration file 357 | # in the given dataset 358 | if 'dataset' in config and 'train_data' in config['dataset']: 359 | dataset_config = _read_config( # Read parameters.yaml in root of given dataset 360 | os.path.join(os.path.dirname(config['dataset']['train_data']), 'parameters.yaml')) 361 | config['dataset']['rollout'] = dataset_config['dataset']['rollout'] 362 | return config 363 | 364 | 365 | def _ask_confirmation(config): 366 | printer = pprint.PrettyPrinter(indent=4) 367 | print(f'The training will be run with the following configuration:') 368 | printed_config = copy.deepcopy(_config) 369 | printed_config.pop('networks') 370 | printer.pprint(printed_config) 371 | print('Proceed? (y/n):') 372 | if input() != 'y': 373 | print('Abort.') 374 | exit() 375 | 376 | 377 | if __name__ == "__main__": 378 | 379 | DEFAULT_TRAIN_CONFIG_FILE = "experiment_params/train_config_default.yaml" 380 | DEFAULT_DATASET_CONFIG_FILE = "experiment_params/dataset_online_default.yaml" 381 | DEFAULT_ENVIRONMENTS_PATH = "experiment_params/default_environments/" 382 | DEFAULT_SAVE_MODELS_DIR = "saved_models/" 383 | 384 | parser = argparse.ArgumentParser() 385 | parser.add_argument( 386 | '--train-config', action='store', nargs=1, type=str, required=True, 387 | help=f'Path to the training configuration yaml file.' 388 | ) 389 | parser.add_argument( 390 | '--dataset-config', action='store', nargs=1, type=str, required=False, 391 | help=f'Path to the dataset configuration yaml file.' 392 | ) 393 | parser.add_argument( 394 | '--name', action='store', nargs=1, required=False, 395 | help='If specified, this name will be used instead of experiment_id of the yaml file.' 396 | ) 397 | parser.add_argument( 398 | '--epochs', action='store', nargs=1, type=int, required=False, 399 | help='The number of training epochs. If not specified, optimization.epochs of the ' 400 | 'training configuration will be used.' 401 | ) 402 | parser.add_argument( 403 | '--env', action='store', nargs=1, type=str, required=False, 404 | help='The environment to use (for online training only). Possible values are ' 405 | '\'pendulum\', \'spring\', \'two_bodies\', \'three_bodies\', corresponding to ' 406 | 'environment configurations in experiment_params/default_environments/. If not ' 407 | 'specified, the environment specified in the given --dataset-config will be used.' 408 | ) 409 | parser.add_argument( 410 | '--dataset-path', action='store', nargs=1, type=str, required=False, 411 | help='Path to a stored dataset to use for training. For offline training only. In this ' 412 | 'case no dataset configuration file will be loaded.' 413 | ) 414 | parser.add_argument( 415 | '--params', action='store', nargs='+', required=False, 416 | help='Override one or more parameters in the config. The format of an argument is ' 417 | 'param_name=param_value. Nested parameters are accessible by using a dot, ' 418 | 'i.e. --param dataset.img_size=32. IMPORTANT: lists must be enclosed in double ' 419 | 'quotes, i.e. --param environment.mass:"[0.5, 0.5]".' 420 | ) 421 | parser.add_argument( 422 | '-y', '-y', action='store_true', default=False, required=False, 423 | help='Whether to skip asking for user confirmation before starting the training.' 424 | ) 425 | parser.add_argument( 426 | '--resume', action='store', required=False, nargs='?', default=None, 427 | help='NOT IMPLEMENTED YET. Resume the training from a saved model. If a path is provided, ' 428 | 'the training will be resumed from the given checkpoint. Otherwise, the last ' 429 | 'checkpoint will be taken from saved_models/.' 430 | ) 431 | parser.add_argument( 432 | '--load', action='store', type=str, required=False, nargs=1, 433 | help='Path from which to load the HGN.' 434 | ) 435 | parser.add_argument( 436 | '--reset', action='store', nargs='+', required=False, 437 | help='Use only in combimation with --load, tells the trainer to reinstantiate the given ' 438 | 'networks. Values: \'encoder\', \'transformer\', \'decoder\', \'hamiltonian\'.' 439 | ) 440 | _args = parser.parse_args() 441 | 442 | # Read configurations 443 | _train_config = _read_config(_args.train_config[0]) 444 | if _args.dataset_path is None: # Will use the dataset config file (or default if not given) 445 | _dataset_config_file = DEFAULT_DATASET_CONFIG_FILE if _args.dataset_config is None else \ 446 | _args.dataset_config[0] 447 | _dataset_config = _read_config(_dataset_config_file) 448 | _config = _merge_configs(_train_config, _dataset_config) 449 | else: # Will use the dataset given in the command line arguments 450 | assert _args.dataset_config is None, 'Both --dataset-path and --dataset-config were given.' 451 | _config = _train_config 452 | 453 | # Overwrite configuration with command line arguments 454 | _overwrite_config_with_cmd_arguments(_config, _args) 455 | 456 | # Show configuration and ask user for confirmation 457 | if not _args.y: 458 | _ask_confirmation(_config) 459 | 460 | # Train HGN network 461 | trainer = HgnTrainer(_config) 462 | hgn = trainer.fit() 463 | -------------------------------------------------------------------------------- /utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CampusAI/Hamiltonian-Generative-Networks/702d3ff3aec40eba20e17c5a1612b5b0b1e2f831/utilities/__init__.py -------------------------------------------------------------------------------- /utilities/conversions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def to_channels_last(tensor): 6 | """Convert a tensor from shape (batch_size, seq_len, channels, height, width) to 7 | shape (batch_size, seq_len, height, width, channels). 8 | 9 | Args: 10 | tensor (torch.Tensor): Tensor to be converted. 11 | 12 | Returns: 13 | A view of the given tensor with shape (batch_size, seq_len, height, channels, width). 14 | Any change applied to this out tensor will be applied to the input tensor. 15 | """ 16 | return tensor.permute(0, 1, 3, 4, 2) 17 | 18 | 19 | def to_channels_first(tensor): 20 | """Convert a tensor from shape (batch_size, seq_len, height, width, channels) to 21 | shape (batch_size, seq_len, channels, height, width). 22 | 23 | Args: 24 | tensor (torch.Tensor): Tensor to be converted. 25 | 26 | Returns: 27 | A view of the given tensor with shape (batch_size, seq_len, channels, height, width). 28 | Any change applied to this out tensor will be applied to the input tensor. 29 | """ 30 | return tensor.permute(0, 1, 4, 2, 3) 31 | 32 | 33 | def concat_rgb(batch): 34 | """Concatenate the images along channel dimension. 35 | 36 | Args: 37 | batch (torch.Tensor): A Tensor with shape (batch_size, seq_len, channels, height, width) 38 | containing the images of the sequence. 39 | 40 | Returns: 41 | A Tensor with shape (batch_size, seq_len * channels, height, width) with the images 42 | concatenated along the channel dimension. 43 | """ 44 | batch_size, seq_len, channels, h, w = batch.size() 45 | return batch.reshape((batch_size, seq_len * channels, h, w)) 46 | 47 | 48 | def batch_to_sequence(batch): 49 | """Convert a batch of sequence of images into a single sequence composed by the concatenation 50 | of sequences in the batch. 51 | 52 | Args: 53 | batch (numpy.ndarray): Numpy array of sequences of images, must have shape 54 | (batch_size, seq_len, height, width, channels). 55 | 56 | Returns: 57 | A numpy array of shape (batch_size * seq_len, height, width, channels) with the 58 | concatenation of the given batch of sequences. 59 | """ 60 | return np.concatenate(batch, axis=0) 61 | -------------------------------------------------------------------------------- /utilities/gradient_flow_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | 5 | from matplotlib import pyplot as plt 6 | from matplotlib.lines import Line2D 7 | import numpy as np 8 | import torch 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from networks import encoder_net 12 | from networks import transformer_net 13 | import train 14 | 15 | 16 | torch.autograd.set_detect_anomaly(True) 17 | 18 | 19 | GRADIENTS = {} # Each element is [counter, grad] 20 | 21 | 22 | def backward_hook(module, grad_input, grad_output): 23 | """Hook called in the backward pass of modules, saving the gradients in the GRADIENTS dict. 24 | 25 | Args: 26 | module (torch.nn.Module): The module for which this backward pass is called. Must provide a 27 | 'name' attribute, that will be used as key in the GRADIENTS dict. 28 | grad_input (tuple): Tuple (dL/dx, dL/dw, dL/db) 29 | grad_output (tuple): 1-tuple (dL/do) i.e. the gradient of the loss w.r.t. the layer output. 30 | """ 31 | if module.name == 'Transformer_out': 32 | q, p = transformer_net.TransformerNet.to_phase_space(grad_output[0]) 33 | set_gradient('Transformer_out_q', q.detach().cpu().numpy()) 34 | set_gradient('Transformer_out_p', p.detach().cpu().numpy()) 35 | else: 36 | set_gradient(module.name, grad_output[0].detach().cpu().numpy()) 37 | return None 38 | 39 | 40 | def set_gradient(name, gradient): 41 | if name in GRADIENTS: 42 | GRADIENTS[name][0] += 1 43 | GRADIENTS[name][1] += gradient 44 | else: 45 | GRADIENTS[name] = [1, gradient] 46 | 47 | 48 | def register_hooks(hgn): 49 | """Set a name to all the interesting layers of the hamiltonian generative networks and register 50 | hook. 51 | 52 | Args: 53 | hgn (hamiltonian_generative_network.HGN): The HGN to analyse. 54 | """ 55 | # Setting name variable to be used in hook 56 | hgn.encoder.input_conv.name = 'Encoder_in' 57 | hgn.encoder.out_mean.name = 'Encoder_out_mean' 58 | hgn.encoder.out_logvar.name = 'Encoder_out_logvar' 59 | hgn.transformer.in_conv.name = 'Transformer_in' 60 | hgn.transformer.out_conv.name = 'Transformer_out' 61 | hgn.hnn.in_conv.name = 'Hamiltonian_in' 62 | hgn.hnn.linear.name = 'Hamiltonian_out' 63 | hgn.decoder.residual_blocks[0].name = 'Decoder_in' 64 | hgn.decoder.out_conv.name = 'Decoder_out' 65 | 66 | # Registering hooks 67 | hgn.encoder.input_conv.register_backward_hook(backward_hook) 68 | hgn.encoder.out_mean.register_backward_hook(backward_hook) 69 | hgn.encoder.out_logvar.register_backward_hook(backward_hook) 70 | hgn.transformer.in_conv.register_backward_hook(backward_hook) 71 | hgn.transformer.out_conv.register_backward_hook(backward_hook) 72 | hgn.hnn.in_conv.register_backward_hook(backward_hook) 73 | hgn.hnn.linear.register_backward_hook(backward_hook) 74 | hgn.decoder.residual_blocks[0].register_backward_hook(backward_hook) 75 | hgn.decoder.out_conv.register_backward_hook(backward_hook) 76 | 77 | 78 | def get_grads(hgn, batch_size, dtype): 79 | """Plot the gradients of each input-output layer of the hamiltonian generative network model. 80 | 81 | Args: 82 | hgn (hamiltonian_generative_network.HGN): The HGN to analyze. 83 | batch_size (int): Batch size used when testing gradients 84 | dtype (torch.dtype): Type to be used in tensor operations. 85 | 86 | """ 87 | register_hooks(hgn) 88 | rand_in = torch.rand((batch_size, hgn.seq_len, hgn.channels, 32, 32)).type(dtype) 89 | hgn.fit(rand_in) 90 | 91 | names = GRADIENTS.keys() 92 | max_grads = [np.abs((GRADIENTS[k][1] / GRADIENTS[k][0])).max() for k in names] 93 | mean_grads = [np.abs((GRADIENTS[k][1] / GRADIENTS[k][0])).mean() for k in names] 94 | 95 | return names, max_grads, mean_grads 96 | 97 | 98 | def plot_grads(names, max_grads, mean_grads): 99 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.3, lw=1, color="c") 100 | plt.bar(np.arange(len(max_grads)), mean_grads, alpha=0.3, lw=1, color="b") 101 | plt.hlines(0, 0, len(mean_grads) + 1, lw=2, color="k") 102 | plt.xticks(range(0, len(mean_grads), 1), names, rotation="vertical") 103 | plt.xlim(left=0, right=len(mean_grads)) 104 | plt.ylim(bottom=-0.000001, top=0.0001) # zoom in on the lower gradient regions 105 | plt.xlabel("Layers") 106 | plt.ylabel("average gradient") 107 | plt.title("Gradient flow") 108 | plt.grid(True) 109 | plt.legend([Line2D([0], [0], color="c", lw=4), 110 | Line2D([0], [0], color="b", lw=4), 111 | Line2D([0], [0], color="k", lw=4)], 112 | ['max-gradient', 'mean-gradient', 'zero-gradient']) 113 | plt.show() 114 | 115 | 116 | if __name__ == '__main__': 117 | params_file = "experiment_params/default.yaml" 118 | with open(params_file, 'r') as f: 119 | params = yaml.load(f, Loader=yaml.FullLoader) 120 | device = params["device"] if torch.cuda.is_available() else "cpu" 121 | 122 | hgn = train.load_hgn(params, device=device, dtype=torch.float) 123 | 124 | names, max_grads, mean_grads = get_grads( 125 | hgn, batch_size=params['optimization']['batch_size'], dtype=torch.float) 126 | 127 | print('-------------------BACKWARD CALL COUNTS------------------------------------------------') 128 | for k, v in GRADIENTS.items(): 129 | print(f'{k:20} backward called {v[0]:10} times') 130 | print('-------------------------GRADIENTS-----------------------------------------------------') 131 | for name, max_grad, mean_grad in zip(names, max_grads, mean_grads): 132 | print(f'{name:20} max_grad: {max_grad:25} mean_grad: {mean_grad:25}') 133 | print('---------------------------------------------------------------------------------------') 134 | -------------------------------------------------------------------------------- /utilities/hgn_result.py: -------------------------------------------------------------------------------- 1 | from utilities import conversions 2 | from environments.environment import visualize_rollout 3 | import os 4 | import sys 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | 13 | class HgnResult(): 14 | """Class to bundle HGN guessed output information. 15 | """ 16 | 17 | def __init__(self, batch_shape, device): 18 | """Instantiate the HgnResult that will contain all the information of the forward pass 19 | over a single batch of rollouts. 20 | 21 | Args: 22 | batch_shape (torch.Size): Shape of a batch of reconstructed rollouts, returned by 23 | batch.shape. 24 | device (str): String with the device to use. E.g. 'cuda:0', 'cpu'. 25 | """ 26 | self.input = None 27 | self.z_mean = None 28 | self.z_logvar = None 29 | self.z_sample = None 30 | self.q_s = [] 31 | self.p_s = [] 32 | self.energies = [] # Estimated energy of the system by the Hamiltonian network 33 | self.reconstructed_rollout = torch.empty(batch_shape).to(device) 34 | self.reconstruction_ptr = 0 35 | 36 | def set_input(self, rollout): 37 | """Store ground truth of system evolution. 38 | 39 | Args: 40 | rollout (torch.Tensor): Tensor of shape (batch_size, seq_len, channels, height, width) 41 | containing the ground truth rollouts of a batch. 42 | """ 43 | self.input = rollout 44 | 45 | def set_z(self, z_sample, z_mean=None, z_logvar=None): 46 | """Store latent encodings and correspondent distribution parameters. 47 | 48 | Args: 49 | z_sample (torch.Tensor): Batch of latent encodings. 50 | z_mean (torch.Tensor, optional): Batch of means of the latent distribution. 51 | z_logvar (torch.Tensor, optional): Batch of log variances of the latent distributions. 52 | """ 53 | self.z_mean = z_mean 54 | self.z_logvar = z_logvar 55 | self.z_sample = z_sample 56 | 57 | def append_state(self, q, p): 58 | """Append the guessed position (q) and momentum (p) to guessed list . 59 | 60 | Args: 61 | q (torch.Tensor): Tensor with the abstract position. 62 | p (torch.Tensor): Tensor with the abstract momentum. 63 | """ 64 | self.q_s.append(q) 65 | self.p_s.append(p) 66 | 67 | def append_reconstruction(self, reconstruction): 68 | """Append guessed reconstruction. 69 | 70 | Args: 71 | reconstruction (torch.Tensor): Tensor of shape (seq_len, channels, height, width). 72 | containing the reconstructed rollout. 73 | """ 74 | assert self.reconstruction_ptr < self.reconstructed_rollout.shape[1],\ 75 | 'Trying to add rollout number ' + str(self.reconstruction_ptr) + ' when batch has ' +\ 76 | str(self.reconstructed_rollout.shape[0]) 77 | self.reconstructed_rollout[:, self.reconstruction_ptr] = reconstruction 78 | self.reconstruction_ptr += 1 79 | 80 | def append_energy(self, energy): 81 | """Append the guessed system energy to energy list. 82 | 83 | Args: 84 | energy (torch.Tensor): Energy of each trajectory in the batch. 85 | """ 86 | self.energies.append(energy) 87 | 88 | def get_energy(self): 89 | """Get the average energy of that rollout and the average of each trajectory std. 90 | 91 | Returns: 92 | (tuple(float, float)): (average_energy, average_std_energy) average_std_energy is computed as follows: 93 | For each trajectory in the rollout, compute the std of the energy and average across trajectories. 94 | """ 95 | energies = np.array(self.energies) 96 | energy_std = np.std(energies, axis=0) 97 | return np.mean(energies), np.mean(energy_std) 98 | 99 | def visualize(self, interval=50, show_step=False): 100 | """Visualize the predicted rollout. 101 | """ 102 | rollout_batch = conversions.to_channels_last( 103 | self.reconstructed_rollout).detach().cpu().numpy() 104 | sequence = conversions.batch_to_sequence(rollout_batch) 105 | visualize_rollout(sequence, interval=interval, show_step=show_step) 106 | -------------------------------------------------------------------------------- /utilities/integrator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Integrator: 5 | """HGN integrator class: Implements different integration methods for Hamiltonian differential equations. 6 | """ 7 | METHODS = ["Euler", "RK4", "Leapfrog", "Yoshida"] 8 | 9 | def __init__(self, delta_t, method="Euler"): 10 | """Initialize HGN integrator. 11 | 12 | Args: 13 | delta_t (float): Time difference between integration steps. 14 | method (str, optional): Integration method, must be "Euler", "RK4", "Leapfrog" or "Yoshida". Defaults to "Euler". 15 | 16 | Raises: 17 | KeyError: If the integration method passed is invalid. 18 | """ 19 | if method not in self.METHODS: 20 | msg = "%s is not a supported method. " % (method) 21 | msg += "Available methods are: " + "".join("%s " % m 22 | for m in self.METHODS) 23 | raise KeyError(msg) 24 | 25 | self.delta_t = delta_t 26 | self.method = method 27 | 28 | def _get_grads(self, q, p, hnn, remember_energy=False): 29 | """Apply the Hamiltonian equations to the Hamiltonian network to get dq_dt, dp_dt. 30 | 31 | Args: 32 | q (torch.Tensor): Latent-space position tensor. 33 | p (torch.Tensor): Latent-space momentum tensor. 34 | hnn (HamiltonianNet): Hamiltonian Neural Network. 35 | remember_energy (bool): Whether to store the computed energy in self.energy. 36 | 37 | Returns: 38 | tuple(torch.Tensor, torch.Tensor): Position and momentum time derivatives: dq_dt, dp_dt. 39 | """ 40 | # Compute energy of the system 41 | energy = hnn(q=q, p=p) 42 | 43 | # dq_dt = dH/dp 44 | dq_dt = torch.autograd.grad(energy, 45 | p, 46 | create_graph=True, 47 | retain_graph=True, 48 | grad_outputs=torch.ones_like(energy))[0] 49 | 50 | # dp_dt = -dH/dq 51 | dp_dt = -torch.autograd.grad(energy, 52 | q, 53 | create_graph=True, 54 | retain_graph=True, 55 | grad_outputs=torch.ones_like(energy))[0] 56 | 57 | if remember_energy: 58 | self.energy = energy.detach().cpu().numpy() 59 | 60 | return dq_dt, dp_dt 61 | 62 | def _euler_step(self, q, p, hnn): 63 | """Compute next latent-space position and momentum using Euler integration method. 64 | 65 | Args: 66 | q (torch.Tensor): Latent-space position tensor. 67 | p (torch.Tensor): Latent-space momentum tensor. 68 | hnn (HamiltonianNet): Hamiltonian Neural Network. 69 | 70 | Returns: 71 | tuple(torch.Tensor, torch.Tensor): Next time-step position, momentum and energy: q_next, p_next. 72 | """ 73 | dq_dt, dp_dt = self._get_grads(q, p, hnn, remember_energy=True) 74 | 75 | # Euler integration 76 | q_next = q + self.delta_t * dq_dt 77 | p_next = p + self.delta_t * dp_dt 78 | return q_next, p_next 79 | 80 | def _rk_step(self, q, p, hnn): 81 | """Compute next latent-space position and momentum using Runge-Kutta 4 integration method. 82 | 83 | Args: 84 | q (torch.Tensor): Latent-space position tensor. 85 | p (torch.Tensor): Latent-space momentum tensor. 86 | hnn (HamiltonianNet): Hamiltonian Neural Network. 87 | 88 | Returns: 89 | tuple(torch.Tensor, torch.Tensor): Next time-step position and momentum: q_next, p_next. 90 | """ 91 | # k1 92 | k1_q, k1_p = self._get_grads(q, p, hnn, remember_energy=True) 93 | 94 | # k2 95 | q_2 = q + self.delta_t * k1_q / 2 # x = x_t + dt * k1 / 2 96 | p_2 = p + self.delta_t * k1_p / 2 # x = x_t + dt * k1 / 2 97 | k2_q, k2_p = self._get_grads(q_2, p_2, hnn) 98 | 99 | # k3 100 | q_3 = q + self.delta_t * k2_q / 2 # x = x_t + dt * k2 / 2 101 | p_3 = p + self.delta_t * k2_p / 2 # x = x_t + dt * k2 / 2 102 | k3_q, k3_p = self._get_grads(q_3, p_3, hnn) 103 | 104 | # k4 105 | q_3 = q + self.delta_t * k3_q / 2 # x = x_t + dt * k3 106 | p_3 = p + self.delta_t * k3_p / 2 # x = x_t + dt * k3 107 | k4_q, k4_p = self._get_grads(q_3, p_3, hnn) 108 | 109 | # Runge-Kutta 4 integration 110 | q_next = q + self.delta_t * ((k1_q / 6) + (k2_q / 3) + (k3_q / 3) + 111 | (k4_q / 6)) 112 | p_next = p + self.delta_t * ((k1_p / 6) + (k2_p / 3) + (k3_p / 3) + 113 | (k4_p / 6)) 114 | return q_next, p_next 115 | 116 | def _lf_step(self, q, p, hnn): 117 | """Compute next latent-space position and momentum using LeapFrog integration method. 118 | 119 | Args: 120 | q (torch.Tensor): Latent-space position tensor. 121 | p (torch.Tensor): Latent-space momentum tensor. 122 | hnn (HamiltonianNet): Hamiltonian Neural Network. 123 | 124 | Returns: 125 | tuple(torch.Tensor, torch.Tensor): Next time-step position and momentum: q_next, p_next. 126 | """ 127 | # get acceleration 128 | _, dp_dt = self._get_grads(q, p, hnn, remember_energy=True) 129 | # leapfrog step 130 | p_next_half = p + dp_dt * (self.delta_t) / 2 131 | q_next = q + p_next_half * self.delta_t 132 | # momentum synchronization 133 | _, dp_next_dt = self._get_grads(q_next, p_next_half, hnn) 134 | p_next = p_next_half + dp_next_dt * (self.delta_t) / 2 135 | return q_next, p_next 136 | 137 | def _ys_step(self, q, p, hnn): 138 | """Compute next latent-space position and momentum using 4th order Yoshida integration method. 139 | 140 | Args: 141 | q (torch.Tensor): Latent-space position tensor. 142 | p (torch.Tensor): Latent-space momentum tensor. 143 | hnn (HamiltonianNet): Hamiltonian Neural Network. 144 | 145 | Returns: 146 | tuple(torch.Tensor, torch.Tensor): Next time-step position and momentum: q_next, p_next. 147 | """ 148 | # yoshida coeficients c_n and d_m 149 | w_1 = 1./(2 - 2**(1./3)) 150 | w_0 = -(2**(1./3))*w_1 151 | c_1 = c_4 = w_1/2. 152 | c_2 = c_3 = (w_0 + w_1)/2. 153 | d_1 = d_3 = w_1 154 | d_2 = w_0 155 | 156 | # first order 157 | q_1 = q + c_1*p*self.delta_t 158 | _, a_1 = self._get_grads(q_1, p, hnn, remember_energy=True) 159 | p_1 = p + d_1*a_1*self.delta_t 160 | # second order 161 | q_2 = q_1 + c_2*p_1*self.delta_t 162 | _, a_2 = self._get_grads(q_2, p, hnn, remember_energy=False) 163 | p_2 = p_1 + d_2*a_2*self.delta_t 164 | # third order 165 | q_3 = q_2 + c_3*p_2*self.delta_t 166 | _, a_3 = self._get_grads(q_3, p, hnn, remember_energy=False) 167 | p_3 = p_2 + d_3*a_3*self.delta_t 168 | # fourth order 169 | q_4 = q_3 + c_4*p_3*self.delta_t 170 | 171 | return q_4, p_3 172 | 173 | def step(self, q, p, hnn): 174 | """Compute next latent-space position and momentum. 175 | 176 | Args: 177 | q (torch.Tensor): Latent-space position tensor. 178 | p (torch.Tensor): Latent-space momentum tensor. 179 | hnn (HamiltonianNet): Hamiltonian Neural Network. 180 | 181 | Raises: 182 | NotImplementedError: If the integration method requested is not implemented. 183 | 184 | Returns: 185 | tuple(torch.Tensor, torch.Tensor): Next time-step position and momentum: q_next, p_next. 186 | """ 187 | if self.method == "Euler": 188 | return self._euler_step(q, p, hnn) 189 | if self.method == "RK4": 190 | return self._rk_step(q, p, hnn) 191 | if self.method == "Leapfrog": 192 | return self._lf_step(q, p, hnn) 193 | if self.method == "Yoshida": 194 | return self._ys_step(q, p, hnn) 195 | raise NotImplementedError 196 | -------------------------------------------------------------------------------- /utilities/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | from environments.datasets import EnvironmentSampler, EnvironmentLoader 9 | from environments.environment_factory import EnvFactory 10 | from hamiltonian_generative_network import HGN 11 | from networks.decoder_net import DecoderNet 12 | from networks.encoder_net import EncoderNet 13 | from networks.hamiltonian_net import HamiltonianNet 14 | from networks.transformer_net import TransformerNet 15 | from utilities.integrator import Integrator 16 | 17 | 18 | def instantiate_encoder(params, device, dtype): 19 | encoder = EncoderNet(seq_len=params["dataset"]["rollout"]["seq_length"], 20 | in_channels=params["dataset"]["rollout"]["n_channels"], 21 | **params["networks"]["encoder"], 22 | dtype=dtype).to(device) 23 | return encoder 24 | 25 | 26 | def instantiate_transformer(params, device, dtype): 27 | transformer = TransformerNet( 28 | in_channels=params["networks"]["encoder"]["out_channels"], 29 | **params["networks"]["transformer"], 30 | dtype=dtype).to(device) 31 | return transformer 32 | 33 | 34 | def instantiate_hamiltonian(params, device, dtype): 35 | hnn = HamiltonianNet(**params["networks"]["hamiltonian"], 36 | dtype=dtype).to(device) 37 | return hnn 38 | 39 | 40 | def instantiate_decoder(params, device, dtype): 41 | decoder = DecoderNet( 42 | in_channels=params["networks"]["transformer"]["out_channels"], 43 | out_channels=params["dataset"]["rollout"]["n_channels"], 44 | **params["networks"]["decoder"], 45 | dtype=dtype).to(device) 46 | return decoder 47 | 48 | 49 | def load_hgn(params, device, dtype): 50 | """Return the Hamiltonian Generative Network created from the given parameters. 51 | 52 | Args: 53 | params (dict): Experiment parameters (see experiment_params folder). 54 | device (str): String with the device to use. E.g. 'cuda:0', 'cpu'. 55 | dtype (torch.dtype): Data type to be used by the networks. 56 | """ 57 | # Define networks 58 | encoder = EncoderNet(seq_len=params["optimization"]["input_frames"], 59 | in_channels=params["dataset"]["rollout"]["n_channels"], 60 | **params["networks"]["encoder"], 61 | dtype=dtype).to(device) 62 | transformer = TransformerNet( 63 | in_channels=params["networks"]["encoder"]["out_channels"], 64 | **params["networks"]["transformer"], 65 | dtype=dtype).to(device) 66 | hnn = HamiltonianNet(**params["networks"]["hamiltonian"], 67 | dtype=dtype).to(device) 68 | decoder = DecoderNet( 69 | in_channels=params["networks"]["transformer"]["out_channels"], 70 | out_channels=params["dataset"]["rollout"]["n_channels"], 71 | **params["networks"]["decoder"], 72 | dtype=dtype).to(device) 73 | 74 | # Define HGN integrator 75 | integrator = Integrator(delta_t=params["dataset"]["rollout"]["delta_time"], 76 | method=params["integrator"]["method"]) 77 | 78 | # Instantiate Hamiltonian Generative Network 79 | hgn = HGN(encoder=encoder, 80 | transformer=transformer, 81 | hnn=hnn, 82 | decoder=decoder, 83 | integrator=integrator, 84 | device=device, 85 | dtype=dtype, 86 | seq_len=params["dataset"]["rollout"]["seq_length"], 87 | channels=params["dataset"]["rollout"]["n_channels"]) 88 | return hgn 89 | 90 | 91 | def get_online_dataloaders(params): 92 | """Get train and test online environment dataloaders for the given params 93 | 94 | Args: 95 | params (dict): Experiment parameters (see experiment_params folder). 96 | 97 | Returns: 98 | tuple(torch.utils.data.DataLoader, torch.utils.data.DataLoader): Train and test dataloader 99 | """ 100 | # Pick environment 101 | env = EnvFactory.get_environment(**params["environment"]) 102 | 103 | # Train 104 | trainDS = EnvironmentSampler( 105 | environment=env, 106 | dataset_len=params["dataset"]["num_train_samples"], 107 | number_of_frames=params["dataset"]["rollout"]["seq_length"], 108 | delta_time=params["dataset"]["rollout"]["delta_time"], 109 | number_of_rollouts=params["optimization"]["batch_size"], 110 | img_size=params["dataset"]["img_size"], 111 | color=params["dataset"]["rollout"]["n_channels"] == 3, 112 | radius_bound=params["dataset"]["radius_bound"], 113 | noise_level=params["dataset"]["rollout"]["noise_level"], 114 | seed=None) 115 | train_data_loader = torch.utils.data.DataLoader(trainDS, 116 | shuffle=False, 117 | batch_size=None) 118 | # Test 119 | testDS = EnvironmentSampler( 120 | environment=env, 121 | dataset_len=params["dataset"]["num_test_samples"], 122 | number_of_frames=params["dataset"]["rollout"]["seq_length"], 123 | delta_time=params["dataset"]["rollout"]["delta_time"], 124 | number_of_rollouts=params["optimization"]["batch_size"], 125 | img_size=params["dataset"]["img_size"], 126 | color=params["dataset"]["rollout"]["n_channels"] == 3, 127 | radius_bound=params["dataset"]["radius_bound"], 128 | noise_level=params["dataset"]["rollout"]["noise_level"], 129 | seed=None) 130 | test_data_loader = torch.utils.data.DataLoader(testDS, 131 | shuffle=False, 132 | batch_size=None) 133 | return train_data_loader, test_data_loader 134 | 135 | 136 | def get_offline_dataloaders(params): 137 | """Get train and test online environment dataloaders for the given params 138 | 139 | Args: 140 | params (dict): Experiment parameters (see experiment_params folder). 141 | 142 | Returns: 143 | tuple(torch.utils.data.DataLoader, torch.utils.data.DataLoader): Train and test dataloader 144 | """ 145 | # Train 146 | trainDS = EnvironmentLoader(params["dataset"]["train_data"]) 147 | train_data_loader = torch.utils.data.DataLoader( 148 | trainDS, shuffle=True, batch_size=params["optimization"]["batch_size"]) 149 | 150 | # Test 151 | test_DS = EnvironmentLoader(params["dataset"]["test_data"]) 152 | test_data_loader = torch.utils.data.DataLoader( 153 | test_DS, shuffle=True, batch_size=params["optimization"]["batch_size"]) 154 | 155 | return train_data_loader, test_data_loader 156 | -------------------------------------------------------------------------------- /utilities/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reconstruction_loss(prediction, target, mean_reduction=True): 5 | """Computes the MSE loss between the target and the predictions. 6 | 7 | Args: 8 | prediction (Tensor) The prediction of the model 9 | target (Tensor): The target batch 10 | mean_reduction (bool): Whether to perform mean reduction across batch (default is true) 11 | Returns: 12 | (Tensor): MSE loss 13 | """ 14 | reduction = 'mean' if mean_reduction else 'none' 15 | mse = torch.nn.MSELoss(reduction=reduction) 16 | if mean_reduction: 17 | return mse(input=prediction, target=target) 18 | else: 19 | return mse(input=prediction, target=target).flatten(1).mean(-1) 20 | 21 | 22 | def kld_loss(mu, logvar, mean_reduction=True): 23 | """ First it computes the KLD over each datapoint in the batch as a sum over all latent dims. 24 | It returns the mean KLD over the batch size. 25 | The KLD is computed in comparison to a multivariate Gaussian with zero mean and identity covariance. 26 | 27 | Args: 28 | mu (torch.Tensor): the part of the latent vector that corresponds to the mean 29 | logvar (torch.Tensor): the log of the variance (sigma squared) 30 | mean_reduction (bool): Whether to perform mean reduction across batch (default is true) 31 | 32 | Returns: 33 | (torch.Tensor): KL divergence. 34 | """ 35 | mu = mu.flatten(1) 36 | logvar = logvar.flatten(1) 37 | kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim = 1) 38 | if mean_reduction: 39 | return torch.mean(kld_per_sample, dim = 0) 40 | else: 41 | return kld_per_sample 42 | 43 | 44 | def geco_constraint(target, prediction, tol): 45 | """Computes the constraint for the geco algorithm. 46 | 47 | Args: 48 | target (toch.Tensor): the rollout target. 49 | prediction (torch.Tensor): the prediction of the model. 50 | tol (float): the tolerance we accept between target and prediction. 51 | 52 | Returns: 53 | (tuple(torch.Tensor, torch.Tensor)): the constraing value as MSE minus the tolerance, and MSE. 54 | """ 55 | rec_loss = reconstruction_loss(prediction=prediction, target=target) 56 | return rec_loss - tol**2, rec_loss 57 | -------------------------------------------------------------------------------- /utilities/statistics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | 4 | 5 | def mean_confidence_interval(data, confidence=0.95): 6 | """Calculates mean and confidence interval from samples such that they lie within m +/- h 7 | with the given confidence. 8 | 9 | Args: 10 | data (np.array): Sample to calculate the confidence interval. 11 | confidence (float): Confidence of the interval (betwen 0 and 1). 12 | """ 13 | n = len(data) 14 | m, se = np.mean(data), scipy.stats.sem(data) 15 | h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1) 16 | return m, h 17 | -------------------------------------------------------------------------------- /utilities/training_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | 6 | class TrainingLogger: 7 | def __init__(self, 8 | hyper_params, 9 | loss_freq=100, 10 | rollout_freq=1000, 11 | model_freq=1000): 12 | """Instantiate a TrainingLogger. 13 | 14 | Args: 15 | hyper_params (dict): Parameters used to train the model (for reproducibility). 16 | loss_freq (int, optional): Frequency at which the loss values are updated in 17 | TensorBoard. Defaults to 100. 18 | rollout_freq (int, optional): Frequency at which videos are updated in TensorBoard. 19 | Defaults to 1000. 20 | model_freq (int, optional): Frequency at which a checkpoint of the model is saved. 21 | Defaults to 1000. 22 | """ 23 | self.writer = SummaryWriter( 24 | log_dir=os.path.join("runs", hyper_params["experiment_id"])) 25 | self.writer.add_text('data/hyperparams', str(hyper_params), 0) 26 | self.hparams = hyper_params 27 | self.iteration = 0 28 | self.loss_freq = loss_freq 29 | self.rollout_freq = rollout_freq 30 | self.model_freq = model_freq 31 | 32 | def step(self, losses, rollout_batch, prediction, model): 33 | """Perform a logging step: update inner iteration counter and log info if needed. 34 | 35 | Args: 36 | losses (tuple): Tuple of two floats, corresponding to reconstruction loss and KLD. 37 | rollout_batch (torch.Tensor): Batch of ground truth rollouts, as a Tensor of shape 38 | (batch_size, seq_len, channels, height, width). 39 | prediction (utilities.hgn_result.HgnResult): The HgnResult object containing data of 40 | the models forward pass on the rollout_batch. 41 | """ 42 | if self.iteration % self.loss_freq == 0: 43 | for loss_name, loss_value in losses.items(): 44 | if loss_value is not None: 45 | self.writer.add_scalar(f'{loss_name}', loss_value, self.iteration) 46 | enery_mean, energy_std = prediction.get_energy() 47 | self.writer.add_scalar(f'energy/mean', enery_mean, self.iteration) 48 | self.writer.add_scalar(f'energy/std', energy_std, self.iteration) 49 | 50 | if self.iteration % self.rollout_freq == 0: 51 | self.writer.add_video('data/input', 52 | rollout_batch.detach().cpu(), self.iteration) 53 | self.writer.add_video( 54 | 'data/reconstruction', 55 | prediction.reconstructed_rollout.detach().cpu(), 56 | self.iteration) 57 | 58 | # Sample from HGN and add to tensorboard 59 | random_sample = model.get_random_sample(n_steps=50, img_shape=(32, 32)) 60 | self.writer.add_video( 61 | 'data/sample', 62 | random_sample.reconstructed_rollout.detach().cpu(), 63 | self.iteration) 64 | 65 | if self.iteration % self.model_freq == 0: 66 | save_dir = os.path.join( 67 | self.hparams["model_save_dir"], self.hparams["experiment_id"] + 68 | "_checkpoint_" + str(self.iteration)) 69 | model.save(save_dir) 70 | self.iteration += 1 71 | 72 | def log_text(self, label, msg): 73 | """Add text to tensorboard 74 | Args: 75 | label (str): Label to identify in tensorboard display 76 | msg (str, float): Message to display (can be a numericsl value) 77 | """ 78 | self.writer.add_text('data/' + label, str(msg), 0) 79 | 80 | def log_error(self, label, mean, dist): 81 | """Add text to tensorboard 82 | Args: 83 | mean (float): Mean of the error interval to display. 84 | dist (float): distance of the error corresponding to the confidence. 85 | """ 86 | self.log_text(label, "{:.8f} +/- {:.8f}".format(mean, dist)) 87 | --------------------------------------------------------------------------------